llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clangir Author: Amr Hesham (AmrDeveloper) <details> <summary>Changes</summary> This change adds support for the equality (`==`) operation on ComplexType https://github.com/llvm/llvm-project/issues/141365 --- Full diff: https://github.com/llvm/llvm-project/pull/145769.diff 5 Files Affected: - (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+25) - (modified) clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp (+11-3) - (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+39-1) - (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h (+10) - (modified) clang/test/CIR/CodeGen/complex.cpp (+74-1) ``````````diff diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 5f24ab7816cbc..6eef525f52f8e 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -2429,6 +2429,31 @@ def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> { let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// ComplexEqualOp +//===----------------------------------------------------------------------===// + +def ComplexEqualOp : CIR_Op<"complex.eq", [Pure, SameTypeOperands]> { + + let summary = "Computes whether two complex values are equal"; + let description = [{ + The `complex.eq` op takes two complex numbers and returns whether + they are equal, i.e. whether both their real parts and their imaginary parts compare equal.
+ + ```mlir + %r = cir.complex.eq %a, %b : !cir.complex<!cir.float> + ``` + }]; + + let results = (outs CIR_BoolType:$result); + let arguments = (ins CIR_ComplexType:$lhs, CIR_ComplexType:$rhs); + + let assemblyFormat = [{ + $lhs `,` $rhs + `:` qualified(type($lhs)) attr-dict + }]; +} + //===----------------------------------------------------------------------===// // Assume Operations //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index 7f8dcd96a6bff..4bcbc6d7ce798 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -894,9 +894,17 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> { } } else { // Complex Comparison: can only be an equality comparison. - assert(!cir::MissingFeatures::complexType()); - cgf.cgm.errorNYI(loc, "complex comparison"); - result = builder.getBool(false, loc); + assert(e->getOpcode() == BO_EQ || e->getOpcode() == BO_NE); + + BinOpInfo boInfo = emitBinOps(e); + if (e->getOpcode() == BO_EQ) { + result = + builder.create<cir::ComplexEqualOp>(loc, boInfo.lhs, boInfo.rhs); + } else { + assert(!cir::MissingFeatures::complexType()); + cgf.cgm.errorNYI(loc, "complex not equal"); + result = builder.getBool(false, loc); + } } return emitScalarConversion(result, cgf.getContext().BoolTy, e->getType(), diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index d41afbdd0b69e..1d33b00d026f4 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1905,7 +1905,8 @@ void ConvertCIRToLLVMPass::runOnOperation() { CIRToLLVMVecTernaryOpLowering, CIRToLLVMComplexCreateOpLowering, CIRToLLVMComplexRealOpLowering, - CIRToLLVMComplexImagOpLowering + CIRToLLVMComplexImagOpLowering, + CIRToLLVMComplexEqualOpLowering 
// clang-format on >(converter, patterns.getContext()); @@ -2227,6 +2228,43 @@ mlir::LogicalResult CIRToLLVMComplexImagOpLowering::matchAndRewrite( return mlir::success(); } +mlir::LogicalResult CIRToLLVMComplexEqualOpLowering::matchAndRewrite( + cir::ComplexEqualOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + mlir::Value lhs = adaptor.getLhs(); + mlir::Value rhs = adaptor.getRhs(); + + auto complexType = mlir::cast<cir::ComplexType>(op.getLhs().getType()); + mlir::Type complexElemTy = + getTypeConverter()->convertType(complexType.getElementType()); + + mlir::Location loc = op.getLoc(); + auto lhsReal = + rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0); + auto lhsImag = + rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1); + auto rhsReal = + rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0); + auto rhsImag = + rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1); + + if (complexElemTy.isInteger()) { + auto realCmp = rewriter.create<mlir::LLVM::ICmpOp>( + loc, mlir::LLVM::ICmpPredicate::eq, lhsReal, rhsReal); + auto imagCmp = rewriter.create<mlir::LLVM::ICmpOp>( + loc, mlir::LLVM::ICmpPredicate::eq, lhsImag, rhsImag); + rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, realCmp, imagCmp); + return mlir::success(); + } + + auto realCmp = rewriter.create<mlir::LLVM::FCmpOp>( + loc, mlir::LLVM::FCmpPredicate::oeq, lhsReal, rhsReal); + auto imagCmp = rewriter.create<mlir::LLVM::FCmpOp>( + loc, mlir::LLVM::FCmpPredicate::oeq, lhsImag, rhsImag); + rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, realCmp, imagCmp); + return mlir::success(); +} + std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() { return std::make_unique<ConvertCIRToLLVMPass>(); } diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index 8502cb1ae5d9f..25cf218cf8b6c 100644 --- 
a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -463,6 +463,16 @@ class CIRToLLVMComplexImagOpLowering mlir::ConversionPatternRewriter &) const override; }; +class CIRToLLVMComplexEqualOpLowering + : public mlir::OpConversionPattern<cir::ComplexEqualOp> { +public: + using mlir::OpConversionPattern<cir::ComplexEqualOp>::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::ComplexEqualOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + } // namespace direct } // namespace cir diff --git a/clang/test/CIR/CodeGen/complex.cpp b/clang/test/CIR/CodeGen/complex.cpp index ad3720097a795..1e9ce0e29fd46 100644 --- a/clang/test/CIR/CodeGen/complex.cpp +++ b/clang/test/CIR/CodeGen/complex.cpp @@ -368,4 +368,77 @@ int foo17(int _Complex a, int _Complex b) { // OGCG: %[[B_REAL:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 0 // OGCG: %[[TMP_B:.*]] = load i32, ptr %[[B_REAL]], align 4 // OGCG: %[[ADD:.*]] = add nsw i32 %[[TMP_A]], %[[TMP_B]] -// OGCG: ret i32 %[[ADD]] \ No newline at end of file +// OGCG: ret i32 %[[ADD]] + +bool foo18(int _Complex a, int _Complex b) { + return a == b; +} + +// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i> +// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i> +// CIR: %[[RESULT:.*]] = cir.complex.eq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex<!s32i> + +// LLVM: %[[COMPLEX_A:.*]] = load { i32, i32 }, ptr {{.*}}, align 4 +// LLVM: %[[COMPLEX_B:.*]] = load { i32, i32 }, ptr {{.*}}, align 4 +// LLVM: %[[A_REAL:.*]] = extractvalue { i32, i32 } %[[COMPLEX_A]], 0 +// LLVM: %[[A_IMAG:.*]] = extractvalue { i32, i32 } %[[COMPLEX_A]], 1 +// LLVM: %[[B_REAL:.*]] = extractvalue { i32, i32 } %[[COMPLEX_B]], 0 +// LLVM: %[[B_IMAG:.*]] = extractvalue { i32, i32 } %[[COMPLEX_B]], 1 +// LLVM: %[[CMP_REAL:.*]] = icmp eq i32 
%[[A_REAL]], %[[B_REAL]] +// LLVM: %[[CMP_IMAG:.*]] = icmp eq i32 %[[A_IMAG]], %[[B_IMAG]] +// LLVM: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]] + +// OGCG: %[[COMPLEX_A:.*]] = alloca { i32, i32 }, align 4 +// OGCG: %[[COMPLEX_B:.*]] = alloca { i32, i32 }, align 4 +// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 0 +// OGCG: %[[A_REAL:.*]] = load i32, ptr %[[A_REAL_PTR]], align 4 +// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 1 +// OGCG: %[[A_IMAG:.*]] = load i32, ptr %[[A_IMAG_PTR]], align 4 +// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 0 +// OGCG: %[[B_REAL:.*]] = load i32, ptr %[[B_REAL_PTR]], align 4 +// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 1 +// OGCG: %[[B_IMAG:.*]] = load i32, ptr %[[B_IMAG_PTR]], align 4 +// OGCG: %[[CMP_REAL:.*]] = icmp eq i32 %[[A_REAL]], %[[B_REAL]] +// OGCG: %[[CMP_IMAG:.*]] = icmp eq i32 %[[A_IMAG]], %[[B_IMAG]] +// OGCG: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]] + +bool foo19(double _Complex a, double _Complex b) { + return a == b; +} + +// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double> +// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double> +// CIR: %[[RESULT:.*]] = cir.complex.eq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex<!cir.double> + +// LLVM: %[[COMPLEX_A:.*]] = load { double, double }, ptr {{.*}}, align 8 +// LLVM: %[[COMPLEX_B:.*]] = load { double, double }, ptr {{.*}}, align 8 +// LLVM: %[[A_REAL:.*]] = extractvalue { double, double } %[[COMPLEX_A]], 0 +// LLVM: %[[A_IMAG:.*]] = extractvalue { double, double } %[[COMPLEX_A]], 1 +// LLVM: %[[B_REAL:.*]] = extractvalue { double, double } %[[COMPLEX_B]], 0 +// LLVM: %[[B_IMAG:.*]] = extractvalue { double, 
double } %[[COMPLEX_B]], 1 +// LLVM: %[[CMP_REAL:.*]] = fcmp oeq double %[[A_REAL]], %[[B_REAL]] +// LLVM: %[[CMP_IMAG:.*]] = fcmp oeq double %[[A_IMAG]], %[[B_IMAG]] +// LLVM: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]] + +// OGCG: %[[COMPLEX_A:.*]] = alloca { double, double }, align 8 +// OGCG: %[[COMPLEX_B:.*]] = alloca { double, double }, align 8 +// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 0 +// OGCG: store double {{.*}}, ptr %[[A_REAL_PTR]], align 8 +// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 1 +// OGCG: store double {{.*}}, ptr %[[A_IMAG_PTR]], align 8 +// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 0 +// OGCG: store double {{.*}}, ptr %[[B_REAL_PTR]], align 8 +// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 1 +// OGCG: store double {{.*}}, ptr %[[B_IMAG_PTR]], align 8 +// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 0 +// OGCG: %[[A_REAL:.*]] = load double, ptr %[[A_REAL_PTR]], align 8 +// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 1 +// OGCG: %[[A_IMAG:.*]] = load double, ptr %[[A_IMAG_PTR]], align 8 +// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 0 +// OGCG: %[[B_REAL:.*]] = load double, ptr %[[B_REAL_PTR]], align 8 +// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 1 +// OGCG: %[[B_IMAG:.*]] = load double, ptr %[[B_IMAG_PTR]], align 8 +// OGCG: %[[CMP_REAL:.*]] = fcmp oeq double %[[A_REAL]], %[[B_REAL]] +// OGCG: %[[CMP_IMAG:.*]] = fcmp oeq double %[[A_IMAG]], %[[B_IMAG]] +// OGCG: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]] + `````````` </details> 
https://github.com/llvm/llvm-project/pull/145769 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits