https://github.com/AmrDeveloper created https://github.com/llvm/llvm-project/pull/137511
This change adds global initialization for VectorType. Issue: https://github.com/llvm/llvm-project/issues/136487 >From 153f0c0daa33b1c71ced4a0f050d49656e72f505 Mon Sep 17 00:00:00 2001 From: AmrDeveloper <am...@programmer.net> Date: Sat, 26 Apr 2025 18:43:00 +0200 Subject: [PATCH] [CIR] Upstream global initialization for VectorType --- .../include/clang/CIR/Dialect/IR/CIRAttrs.td | 33 ++++++- clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp | 23 ++++- clang/lib/CIR/Dialect/IR/CIRAttrs.cpp | 88 +++++++++++++++++++ clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 2 +- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 40 +++++++-- clang/test/CIR/CodeGen/vector-ext.cpp | 11 ++- clang/test/CIR/CodeGen/vector.cpp | 9 ++ 7 files changed, 196 insertions(+), 10 deletions(-) diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td index fb3f7b1632436..624a82762ab18 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td @@ -204,7 +204,7 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]> }]> ]; - // Printing and parsing available in CIRDialect.cpp + // Printing and parsing available in CIRAttrs.cpp let hasCustomAssemblyFormat = 1; // Enable verifier. @@ -215,6 +215,37 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]> }]; } +//===----------------------------------------------------------------------===// +// ConstVectorAttr +//===----------------------------------------------------------------------===// + +def ConstVectorAttr : CIR_Attr<"ConstVector", "const_vector", + [TypedAttrInterface]> { + let summary = "A constant vector from ArrayAttr"; + let description = [{ + A CIR vector attribute is an array of literals of the specified attribute + types. 
+ }]; + + let parameters = (ins AttributeSelfTypeParameter<"">:$type, + "mlir::ArrayAttr":$elts); + + // Define a custom builder for the type; that removes the need to pass in an + // MLIRContext instance, as it can be inferred from the `type`. + let builders = [ + AttrBuilderWithInferredContext<(ins "cir::VectorType":$type, + "mlir::ArrayAttr":$elts), [{ + return $_get(type.getContext(), type, elts); + }]> + ]; + + // Printing and parsing available in CIRAttrs.cpp + let hasCustomAssemblyFormat = 1; + + // Enable verifier. + let genVerifyDecl = 1; +} + //===----------------------------------------------------------------------===// // ConstPtrAttr //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp index b9a74e90a5960..6e5c7b8fb51f8 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp @@ -373,8 +373,27 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &value, elements, typedFiller); } case APValue::Vector: { - cgm.errorNYI("ConstExprEmitter::tryEmitPrivate vector"); - return {}; + const QualType elementType = + destType->castAs<VectorType>()->getElementType(); + const unsigned numElements = value.getVectorLength(); + + SmallVector<mlir::Attribute, 16> elements; + elements.reserve(numElements); + + for (unsigned i = 0; i < numElements; ++i) { + const mlir::Attribute element = + tryEmitPrivateForMemory(value.getVectorElt(i), elementType); + if (!element) + return {}; + elements.push_back(element); + } + + const auto desiredVecTy = + mlir::cast<cir::VectorType>(cgm.convertType(destType)); + + return cir::ConstVectorAttr::get( + desiredVecTy, + mlir::ArrayAttr::get(cgm.getBuilder().getContext(), elements)); } case APValue::MemberPointer: { cgm.errorNYI("ConstExprEmitter::tryEmitPrivate member pointer"); diff --git a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp 
b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp index a8d9f6a0e6e9b..b9b27f33207b8 100644 --- a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp @@ -299,6 +299,94 @@ void ConstArrayAttr::print(AsmPrinter &printer) const { printer << ">"; } +//===----------------------------------------------------------------------===// +// CIR ConstVectorAttr +//===----------------------------------------------------------------------===// + +LogicalResult cir::ConstVectorAttr::verify( + function_ref<::mlir::InFlightDiagnostic()> emitError, Type type, + ArrayAttr elts) { + + if (!mlir::isa<cir::VectorType>(type)) { + return emitError() << "type of cir::ConstVectorAttr is not a " + "cir::VectorType: " + << type; + } + + const auto vecType = mlir::cast<cir::VectorType>(type); + + if (vecType.getSize() != elts.size()) { + return emitError() + << "number of constant elements should match vector size"; + } + + // Check if the types of the elements match + LogicalResult elementTypeCheck = success(); + elts.walkImmediateSubElements( + [&](Attribute element) { + if (elementTypeCheck.failed()) { + // An earlier element didn't match + return; + } + auto typedElement = mlir::dyn_cast<TypedAttr>(element); + if (!typedElement || + typedElement.getType() != vecType.getElementType()) { + elementTypeCheck = failure(); + emitError() << "constant type should match vector element type"; + } + }, + [&](Type) {}); + + return elementTypeCheck; +} + +Attribute cir::ConstVectorAttr::parse(AsmParser &parser, Type type) { + FailureOr<Type> resultType; + FailureOr<ArrayAttr> resultValue; + + const SMLoc loc = parser.getCurrentLocation(); + + // Parse literal '<' + if (parser.parseLess()) { + return {}; + } + + // Parse variable 'value' + resultValue = FieldParser<ArrayAttr>::parse(parser); + if (failed(resultValue)) { + parser.emitError(parser.getCurrentLocation(), + "failed to parse ConstVectorAttr parameter 'value' as " + "an attribute"); + return {}; + } + + if 
(parser.parseOptionalColon().failed()) { + resultType = type; + } else { + resultType = ::mlir::FieldParser<Type>::parse(parser); + if (failed(resultType)) { + parser.emitError(parser.getCurrentLocation(), + "failed to parse ConstVectorAttr parameter 'type' as " + "an MLIR type"); + return {}; + } + } + + // Parse literal '>' + if (parser.parseGreater()) { + return {}; + } + + return parser.getChecked<ConstVectorAttr>( + loc, parser.getContext(), resultType.value(), resultValue.value()); +} + +void cir::ConstVectorAttr::print(AsmPrinter &printer) const { + printer << "<"; + printer.printStrippedAttrOrType(getElts()); + printer << ">"; +} + //===----------------------------------------------------------------------===// // CIR Dialect //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 939802a3af680..07847d62feadd 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -242,7 +242,7 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, return success(); } - if (mlir::isa<cir::ConstArrayAttr>(attrType)) + if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr>(attrType)) return success(); assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?"); diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 102438c2ded02..db331691154e6 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -188,8 +188,9 @@ class CIRAttrToValue { mlir::Value visit(mlir::Attribute attr) { return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr) - .Case<cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr, cir::ConstPtrAttr, - cir::ZeroAttr>([&](auto attrT) { return visitCirAttr(attrT); }) + .Case<cir::IntAttr, cir::FPAttr, 
cir::ConstArrayAttr, + cir::ConstVectorAttr, cir::ConstPtrAttr, cir::ZeroAttr>( + [&](auto attrT) { return visitCirAttr(attrT); }) .Default([&](auto attrT) { return mlir::Value(); }); } @@ -197,6 +198,7 @@ class CIRAttrToValue { mlir::Value visitCirAttr(cir::FPAttr fltAttr); mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr); mlir::Value visitCirAttr(cir::ConstArrayAttr attr); + mlir::Value visitCirAttr(cir::ConstVectorAttr attr); mlir::Value visitCirAttr(cir::ZeroAttr attr); private: @@ -275,6 +277,33 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) { return result; } +/// ConstVectorAttr visitor. +mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstVectorAttr attr) { + const mlir::Type llvmTy = converter->convertType(attr.getType()); + const mlir::Location loc = parentOp->getLoc(); + + SmallVector<mlir::Attribute> mlirValues; + for (const mlir::Attribute elementAttr : attr.getElts()) { + mlir::Attribute mlirAttr; + if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(elementAttr)) { + mlirAttr = rewriter.getIntegerAttr( + converter->convertType(intAttr.getType()), intAttr.getValue()); + } else if (auto floatAttr = mlir::dyn_cast<cir::FPAttr>(elementAttr)) { + mlirAttr = rewriter.getFloatAttr( + converter->convertType(floatAttr.getType()), floatAttr.getValue()); + } else { + llvm_unreachable( + "vector constant with an element that is neither an int nor a float"); + } + mlirValues.push_back(mlirAttr); + } + + return rewriter.create<mlir::LLVM::ConstantOp>( + loc, llvmTy, + mlir::DenseElementsAttr::get(mlir::cast<mlir::ShapedType>(llvmTy), + mlirValues)); +} + /// ZeroAttr visitor. mlir::Value CIRAttrToValue::visitCirAttr(cir::ZeroAttr attr) { mlir::Location loc = parentOp->getLoc(); @@ -888,7 +917,8 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal( cir::GlobalOp op, mlir::Attribute init, mlir::ConversionPatternRewriter &rewriter) const { // TODO: Generalize this handling when more types are needed here. 
- assert((isa<cir::ConstArrayAttr, cir::ConstPtrAttr, cir::ZeroAttr>(init))); + assert((isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr, + cir::ZeroAttr>(init))); // TODO(cir): once LLVM's dialect has proper equivalent attributes this // should be updated. For now, we use a custom op to initialize globals @@ -941,8 +971,8 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( op.emitError() << "unsupported initializer '" << init.value() << "'"; return mlir::failure(); } - } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstPtrAttr, cir::ZeroAttr>( - init.value())) { + } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr, + cir::ConstPtrAttr, cir::ZeroAttr>(init.value())) { // TODO(cir): once LLVM's dialect has proper equivalent attributes this // should be updated. For now, we use a custom op to initialize globals // to the appropriate value. diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp index 13726edf3d259..7759a32fc1378 100644 --- a/clang/test/CIR/CodeGen/vector-ext.cpp +++ b/clang/test/CIR/CodeGen/vector-ext.cpp @@ -31,7 +31,7 @@ vi2 vec_c; // OGCG: @[[VEC_C:.*]] = global <2 x i32> zeroinitializer -vd2 d; +vd2 vec_d; // CIR: cir.global external @[[VEC_D:.*]] = #cir.zero : !cir.vector<2 x !cir.double> @@ -39,6 +39,15 @@ vd2 d; // OGCG: @[[VEC_D:.*]] = global <2 x double> zeroinitializer +vi4 vec_e = { 1, 2, 3, 4 }; + +// CIR: cir.global external @[[VEC_E:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : +// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i> + +// LLVM: @[[VEC_E:.*]] = dso_local global <4 x i32> <i32 1, i32 2, i32 3, i32 4> + +// OGCG: @[[VEC_E:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4> + void foo() { vi4 a; vi3 b; diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp index 8f9e98fb6b3c0..4c1850141a21c 100644 --- a/clang/test/CIR/CodeGen/vector.cpp +++ 
b/clang/test/CIR/CodeGen/vector.cpp @@ -30,6 +30,15 @@ vll2 c; // OGCG: @[[VEC_C:.*]] = global <2 x i64> zeroinitializer +vi4 d = { 1, 2, 3, 4 }; + +// CIR: cir.global external @[[VEC_D:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : +// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i> + +// LLVM: @[[VEC_D:.*]] = dso_local global <4 x i32> <i32 1, i32 2, i32 3, i32 4> + +// OGCG: @[[VEC_D:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4> + void vec_int_test() { vi4 a; vd2 b; _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits