https://github.com/AmrDeveloper created https://github.com/llvm/llvm-project/pull/141369
This change adds support for zero initialization and constant global initialization of ComplexType (see #141365) >From 3a8bcd052d25d138b3a9a53bbcc69d48500b4b41 Mon Sep 17 00:00:00 2001 From: AmrDeveloper <am...@programmer.net> Date: Sat, 24 May 2025 14:18:06 +0200 Subject: [PATCH] [CIR] Upstream global initialization for ComplexType --- .../CIR/Dialect/Builder/CIRBaseBuilder.h | 2 + .../include/clang/CIR/Dialect/IR/CIRAttrs.td | 34 ++++++++++ .../include/clang/CIR/Dialect/IR/CIRTypes.td | 43 ++++++++++++ clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp | 27 ++++++-- clang/lib/CIR/CodeGen/CIRGenTypes.cpp | 7 ++ clang/lib/CIR/Dialect/IR/CIRAttrs.cpp | 20 ++++++ clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 6 +- clang/lib/CIR/Dialect/IR/CIRTypes.cpp | 26 ++++++++ .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 65 +++++++++++++++---- clang/test/CIR/CodeGen/complex.cpp | 29 +++++++++ clang/test/CIR/IR/complex.cir | 16 +++++ 11 files changed, 257 insertions(+), 18 deletions(-) create mode 100644 clang/test/CIR/CodeGen/complex.cpp create mode 100644 clang/test/CIR/IR/complex.cir diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h index 9de3a66755778..878aba69c0e24 100644 --- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h +++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h @@ -89,6 +89,8 @@ class CIRBaseBuilderTy : public mlir::OpBuilder { return cir::IntAttr::get(ty, 0); if (cir::isAnyFloatingPointType(ty)) return cir::FPAttr::getZero(ty); + if (auto complexType = mlir::dyn_cast<cir::ComplexType>(ty)) + return cir::ZeroAttr::get(complexType); if (auto arrTy = mlir::dyn_cast<cir::ArrayType>(ty)) return cir::ZeroAttr::get(arrTy); if (auto vecTy = mlir::dyn_cast<cir::VectorType>(ty)) diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td index 8152535930095..4effae1cf2e29 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td +++ 
b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td @@ -276,4 +276,38 @@ def ConstPtrAttr : CIR_Attr<"ConstPtr", "ptr", [TypedAttrInterface]> { }]; } +//===----------------------------------------------------------------------===// +// ConstComplexAttr +//===----------------------------------------------------------------------===// + +def ConstComplexAttr : CIR_Attr<"ConstComplex", "const_complex", [TypedAttrInterface]> { + let summary = "An attribute that contains a constant complex value"; + let description = [{ + The `#cir.const_complex` attribute contains a constant value of complex number + type. The `real` parameter gives the real part of the complex number and the + `imag` parameter gives the imaginary part of the complex number. + + The `real` and `imag` parameters must both be IntAttrs or both be FPAttrs, + and each must have the element type of the complex type. + }]; + + let parameters = (ins + AttributeSelfTypeParameter<"", "cir::ComplexType">:$type, + "mlir::TypedAttr":$real, "mlir::TypedAttr":$imag); + + let builders = [ + AttrBuilderWithInferredContext<(ins "cir::ComplexType":$type, + "mlir::TypedAttr":$real, + "mlir::TypedAttr":$imag), [{ + return $_get(type.getContext(), type, real, imag); + }]>, + ]; + + let genVerifyDecl = 1; + + let assemblyFormat = [{ + `<` qualified($real) `,` qualified($imag) `>` + }]; +} + #endif // LLVM_CLANG_CIR_DIALECT_IR_CIRATTRS_TD diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td index 26f1122a4b261..ec994620893fe 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td @@ -161,6 +161,49 @@ def CIR_LongDouble : CIR_FloatType<"LongDouble", "long_double"> { }]; } +//===----------------------------------------------------------------------===// +// ComplexType +//===----------------------------------------------------------------------===// + +def CIR_ComplexType : CIR_Type<"Complex", "complex", 
[DeclareTypeInterfaceMethods<DataLayoutTypeInterface>]> { + + let summary = "CIR complex type"; + let description = [{ + CIR type that represents a C complex number. `cir.complex` models the C type + `T _Complex`. + + The type models complex values, per C99 6.2.5p11. It supports the C99 + complex float types as well as the GCC integer complex extensions. + + The parameter `elementType` gives the type of the real and imaginary part of + the complex number. `elementType` must be either a CIR integer type or a CIR + floating-point type. + }]; + + let parameters = (ins CIR_AnyIntOrFloatType:$elementType); + + let builders = [ + TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{ + return $_get(elementType.getContext(), elementType); + }]>, + ]; + + let assemblyFormat = [{ + `<` $elementType `>` + }]; + + let extraClassDeclaration = [{ + bool isFloatingPointComplex() const { + return isAnyFloatingPointType(getElementType()); + } + + bool isIntegerComplex() const { + return mlir::isa<cir::IntType>(getElementType()); + } + }]; +} + //===----------------------------------------------------------------------===// // PointerType //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp index 9085ee2dfe506..973349b8c0443 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp @@ -577,12 +577,31 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &value, case APValue::Union: cgm.errorNYI("ConstExprEmitter::tryEmitPrivate struct or union"); return {}; - case APValue::FixedPoint: case APValue::ComplexInt: - case APValue::ComplexFloat: + case APValue::ComplexFloat: { + mlir::Type desiredType = cgm.convertType(destType); + cir::ComplexType complexType = + mlir::dyn_cast<cir::ComplexType>(desiredType); + + mlir::Type compelxElemTy = complexType.getElementType(); + if 
(isa<cir::IntType>(compelxElemTy)) { + llvm::APSInt real = value.getComplexIntReal(); + llvm::APSInt imag = value.getComplexIntImag(); + return builder.getAttr<cir::ConstComplexAttr>( + complexType, builder.getAttr<cir::IntAttr>(compelxElemTy, real), + builder.getAttr<cir::IntAttr>(compelxElemTy, imag)); + } + + llvm::APFloat real = value.getComplexFloatReal(); + llvm::APFloat imag = value.getComplexFloatImag(); + return builder.getAttr<cir::ConstComplexAttr>( + complexType, builder.getAttr<cir::FPAttr>(compelxElemTy, real), + builder.getAttr<cir::FPAttr>(compelxElemTy, imag)); + } + case APValue::FixedPoint: case APValue::AddrLabelDiff: - cgm.errorNYI("ConstExprEmitter::tryEmitPrivate fixed point, complex int, " - "complex float, addr label diff"); + cgm.errorNYI( + "ConstExprEmitter::tryEmitPrivate fixed point, addr label diff"); return {}; } llvm_unreachable("Unknown APValue kind"); diff --git a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp index 0665ea0506875..948be813ebe51 100644 --- a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp @@ -385,6 +385,13 @@ mlir::Type CIRGenTypes::convertType(QualType type) { break; } + case Type::Complex: { + const ComplexType *ct = cast<ComplexType>(ty); + mlir::Type elementTy = convertType(ct->getElementType()); + resultType = cir::ComplexType::get(elementTy); + break; + } + case Type::LValueReference: case Type::RValueReference: { const ReferenceType *refTy = cast<ReferenceType>(ty); diff --git a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp index c4fb4942aec75..d9426ced5f5ab 100644 --- a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp @@ -184,6 +184,26 @@ LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError, return success(); } +//===----------------------------------------------------------------------===// +// ConstComplexAttr definitions 
+//===----------------------------------------------------------------------===// + +LogicalResult +ConstComplexAttr::verify(function_ref<InFlightDiagnostic()> emitError, + cir::ComplexType type, mlir::TypedAttr real, + mlir::TypedAttr imag) { + mlir::Type elemType = type.getElementType(); + if (real.getType() != elemType) + return emitError() + << "type of the real part does not match the complex type"; + + if (imag.getType() != elemType) + return emitError() + << "type of the imaginary part does not match the complex type"; + + return success(); +} + //===----------------------------------------------------------------------===// // CIR ConstArrayAttr //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 36dcbc6a4be4a..4a9386b1eed0f 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -231,7 +231,8 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, } if (isa<cir::ZeroAttr>(attrType)) { - if (isa<cir::RecordType, cir::ArrayType, cir::VectorType>(opType)) + if (isa<cir::RecordType, cir::ArrayType, cir::VectorType, cir::ComplexType>( + opType)) return success(); return op->emitOpError("zero expects struct or array type"); } @@ -253,7 +254,8 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, return success(); } - if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr>(attrType)) + if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr, + cir::ConstComplexAttr>(attrType)) return success(); assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?"); diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp index b402177a5ec18..14050f36bbfdc 100644 --- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp @@ -552,6 +552,32 @@ 
LongDoubleType::getABIAlignment(const mlir::DataLayout &dataLayout, .getABIAlignment(dataLayout, params); } +//===----------------------------------------------------------------------===// +// ComplexType Definitions +//===----------------------------------------------------------------------===// + +llvm::TypeSize +cir::ComplexType::getTypeSizeInBits(const mlir::DataLayout &dataLayout, + mlir::DataLayoutEntryListRef params) const { + // C17 6.2.5p13: + // Each complex type has the same representation and alignment requirements + // as an array type containing exactly two elements of the corresponding + // real type. + + return dataLayout.getTypeSizeInBits(getElementType()) * 2; +} + +uint64_t +cir::ComplexType::getABIAlignment(const mlir::DataLayout &dataLayout, + mlir::DataLayoutEntryListRef params) const { + // C17 6.2.5p13: + // Each complex type has the same representation and alignment requirements + // as an array type containing exactly two elements of the corresponding + // real type. 
+ + return dataLayout.getTypeABIAlignment(getElementType()); +} + //===----------------------------------------------------------------------===// // Floating-point and Float-point Vector type helpers //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 8e82af7e62bc0..d0ae1d64e9afd 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -188,14 +188,15 @@ class CIRAttrToValue { mlir::Value visit(mlir::Attribute attr) { return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr) - .Case<cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr, - cir::ConstVectorAttr, cir::ConstPtrAttr, cir::ZeroAttr>( - [&](auto attrT) { return visitCirAttr(attrT); }) + .Case<cir::IntAttr, cir::FPAttr, cir::ConstComplexAttr, + cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr, + cir::ZeroAttr>([&](auto attrT) { return visitCirAttr(attrT); }) .Default([&](auto attrT) { return mlir::Value(); }); } mlir::Value visitCirAttr(cir::IntAttr intAttr); mlir::Value visitCirAttr(cir::FPAttr fltAttr); + mlir::Value visitCirAttr(cir::ConstComplexAttr complexAttr); mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr); mlir::Value visitCirAttr(cir::ConstArrayAttr attr); mlir::Value visitCirAttr(cir::ConstVectorAttr attr); @@ -226,6 +227,42 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::IntAttr intAttr) { loc, converter->convertType(intAttr.getType()), intAttr.getValue()); } +/// FPAttr visitor. +mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) { + mlir::Location loc = parentOp->getLoc(); + return rewriter.create<mlir::LLVM::ConstantOp>( + loc, converter->convertType(fltAttr.getType()), fltAttr.getValue()); +} + +/// ConstComplexAttr visitor. 
+mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstComplexAttr complexAttr) { + auto complexType = mlir::cast<cir::ComplexType>(complexAttr.getType()); + auto complexElemTy = complexType.getElementType(); + auto complexElemLLVMTy = converter->convertType(complexElemTy); + + mlir::Attribute components[2]; + if (const auto intType = mlir::dyn_cast<cir::IntType>(complexElemTy)) { + components[0] = rewriter.getIntegerAttr( + complexElemLLVMTy, + mlir::cast<cir::IntAttr>(complexAttr.getReal()).getValue()); + components[1] = rewriter.getIntegerAttr( + complexElemLLVMTy, + mlir::cast<cir::IntAttr>(complexAttr.getImag()).getValue()); + } else { + components[0] = rewriter.getFloatAttr( + complexElemLLVMTy, + mlir::cast<cir::FPAttr>(complexAttr.getReal()).getValue()); + components[1] = rewriter.getFloatAttr( + complexElemLLVMTy, + mlir::cast<cir::FPAttr>(complexAttr.getImag()).getValue()); + } + + mlir::Location loc = parentOp->getLoc(); + return rewriter.create<mlir::LLVM::ConstantOp>( + loc, converter->convertType(complexAttr.getType()), + rewriter.getArrayAttr(components)); +} + /// ConstPtrAttr visitor. mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) { mlir::Location loc = parentOp->getLoc(); @@ -241,13 +278,6 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) { loc, converter->convertType(ptrAttr.getType()), ptrVal); } -/// FPAttr visitor. 
-mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) { - mlir::Location loc = parentOp->getLoc(); - return rewriter.create<mlir::LLVM::ConstantOp>( - loc, converter->convertType(fltAttr.getType()), fltAttr.getValue()); -} - // ConstArrayAttr visitor mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) { mlir::Type llvmTy = converter->convertType(attr.getType()); @@ -341,9 +371,11 @@ class GlobalInitAttrRewriter { mlir::Attribute visitCirAttr(cir::IntAttr attr) { return rewriter.getIntegerAttr(llvmType, attr.getValue()); } + mlir::Attribute visitCirAttr(cir::FPAttr attr) { return rewriter.getFloatAttr(llvmType, attr.getValue()); } + mlir::Attribute visitCirAttr(cir::BoolAttr attr) { return rewriter.getBoolAttr(attr.getValue()); } @@ -990,7 +1022,7 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal( mlir::ConversionPatternRewriter &rewriter) const { // TODO: Generalize this handling when more types are needed here. assert((isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr, - cir::ZeroAttr>(init))); + cir::ConstComplexAttr, cir::ZeroAttr>(init))); // TODO(cir): once LLVM's dialect has proper equivalent attributes this // should be updated. For now, we use a custom op to initialize globals @@ -1043,7 +1075,8 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( return mlir::failure(); } } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr, - cir::ConstPtrAttr, cir::ZeroAttr>(init.value())) { + cir::ConstPtrAttr, cir::ConstComplexAttr, + cir::ZeroAttr>(init.value())) { // TODO(cir): once LLVM's dialect has proper equivalent attributes this // should be updated. For now, we use a custom op to initialize globals // to the appropriate value. 
@@ -1559,6 +1592,14 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter &converter, converter.addConversion([&](cir::BF16Type type) -> mlir::Type { return mlir::BFloat16Type::get(type.getContext()); }); + converter.addConversion([&](cir::ComplexType type) -> mlir::Type { + // A complex type is lowered to an LLVM struct that contains the real and + // imaginary part as data fields. + mlir::Type elementTy = converter.convertType(type.getElementType()); + mlir::Type structFields[2] = {elementTy, elementTy}; + return mlir::LLVM::LLVMStructType::getLiteral(type.getContext(), + structFields); + }); converter.addConversion([&](cir::FuncType type) -> std::optional<mlir::Type> { auto result = converter.convertType(type.getReturnType()); llvm::SmallVector<mlir::Type> arguments; diff --git a/clang/test/CIR/CodeGen/complex.cpp b/clang/test/CIR/CodeGen/complex.cpp new file mode 100644 index 0000000000000..1e0c9fcf08ef0 --- /dev/null +++ b/clang/test/CIR/CodeGen/complex.cpp @@ -0,0 +1,29 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-cir %s -o %t.cir +// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll +// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -emit-llvm %s -o %t.ll +// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG + +int _Complex ci; + +float _Complex cf; + +int _Complex ci2 = { 1, 2 }; + +float _Complex cf2 = { 1.0f, 2.0f }; + +// CIR: cir.global external @ci = #cir.zero : !cir.complex<!s32i> +// CIR: cir.global external @cf = #cir.zero : !cir.complex<!cir.float> +// CIR: cir.global external @ci2 = #cir.const_complex<#cir.int<1> : !s32i, #cir.int<2> : !s32i> : !cir.complex<!s32i> +// CIR: cir.global external @cf2 = #cir.const_complex<#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00> : !cir.float> : 
!cir.complex<!cir.float> + +// LLVM: {{.*}} = dso_local global { i32, i32 } zeroinitializer, align 4 +// LLVM: {{.*}} = dso_local global { float, float } zeroinitializer, align 4 +// LLVM: {{.*}} = dso_local global { i32, i32 } { i32 1, i32 2 }, align 4 +// LLVM: {{.*}} = dso_local global { float, float } { float 1.000000e+00, float 2.000000e+00 }, align 4 + +// OGCG: {{.*}} = global { i32, i32 } zeroinitializer, align 4 +// OGCG: {{.*}} = global { float, float } zeroinitializer, align 4 +// OGCG: {{.*}} = global { i32, i32 } { i32 1, i32 2 }, align 4 +// OGCG: {{.*}} = global { float, float } { float 1.000000e+00, float 2.000000e+00 }, align 4 diff --git a/clang/test/CIR/IR/complex.cir b/clang/test/CIR/IR/complex.cir new file mode 100644 index 0000000000000..a73a8654ca274 --- /dev/null +++ b/clang/test/CIR/IR/complex.cir @@ -0,0 +1,16 @@ +// RUN: cir-opt %s | FileCheck %s + +!s32i = !cir.int<s, 32> + +module { + +cir.global external @ci = #cir.zero : !cir.complex<!s32i> +// CHECK: cir.global external {{.*}} = #cir.zero : !cir.complex<!s32i> + +cir.global external @cf = #cir.zero : !cir.complex<!cir.float> +// CHECK: cir.global external {{.*}} = #cir.zero : !cir.complex<!cir.float> + +cir.global external @ci2 = #cir.const_complex<#cir.int<1> : !s32i, #cir.int<2> : !s32i> : !cir.complex<!s32i> +// CHECK: cir.global external {{.*}} = #cir.const_complex<#cir.int<1> : !s32i, #cir.int<2> : !s32i> : !cir.complex<!s32i> + +} _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits