Author: Amr Hesham
Date: 2025-05-01T08:51:55+02:00
New Revision: 93ff19c27a0ad21068f431c0380f5a6dd5c4abc3
URL: https://github.com/llvm/llvm-project/commit/93ff19c27a0ad21068f431c0380f5a6dd5c4abc3
DIFF: https://github.com/llvm/llvm-project/commit/93ff19c27a0ad21068f431c0380f5a6dd5c4abc3.diff

LOG: [CIR] Upstream global initialization for VectorType (#137511)

This change adds global initialization for VectorType.

Issue: https://github.com/llvm/llvm-project/issues/136487

Added:

Modified:
    clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
    clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
    clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
    clang/lib/CIR/Dialect/IR/CIRDialect.cpp
    clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
    clang/test/CIR/CodeGen/vector-ext.cpp
    clang/test/CIR/CodeGen/vector.cpp
    clang/test/CIR/IR/vector.cir

Removed:


################################################################################
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td index fb3f7b1632436..8152535930095 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td @@ -204,7 +204,7 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]> }]> ]; - // Printing and parsing available in CIRDialect.cpp + // Printing and parsing available in CIRAttrs.cpp let hasCustomAssemblyFormat = 1; // Enable verifier. @@ -215,6 +215,38 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]> }]; } +//===----------------------------------------------------------------------===// +// ConstVectorAttr +//===----------------------------------------------------------------------===// + +def ConstVectorAttr : CIR_Attr<"ConstVector", "const_vector", + [TypedAttrInterface]> { + let summary = "A constant vector from ArrayAttr"; + let description = [{ + A CIR vector attribute is an array of literals of the specified attribute + types.
+ }]; + + let parameters = (ins AttributeSelfTypeParameter<"">:$type, + "mlir::ArrayAttr":$elts); + + // Define a custom builder for the type; that removes the need to pass in an + // MLIRContext instance, as it can be inferred from the `type`. + let builders = [ + AttrBuilderWithInferredContext<(ins "cir::VectorType":$type, + "mlir::ArrayAttr":$elts), [{ + return $_get(type.getContext(), type, elts); + }]> + ]; + + let assemblyFormat = [{ + `<` $elts `>` + }]; + + // Enable verifier. + let genVerifyDecl = 1; +} + //===----------------------------------------------------------------------===// // ConstPtrAttr //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp index b9a74e90a5960..6e5c7b8fb51f8 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp @@ -373,8 +373,27 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &value, elements, typedFiller); } case APValue::Vector: { - cgm.errorNYI("ConstExprEmitter::tryEmitPrivate vector"); - return {}; + const QualType elementType = + destType->castAs<VectorType>()->getElementType(); + const unsigned numElements = value.getVectorLength(); + + SmallVector<mlir::Attribute, 16> elements; + elements.reserve(numElements); + + for (unsigned i = 0; i < numElements; ++i) { + const mlir::Attribute element = + tryEmitPrivateForMemory(value.getVectorElt(i), elementType); + if (!element) + return {}; + elements.push_back(element); + } + + const auto desiredVecTy = + mlir::cast<cir::VectorType>(cgm.convertType(destType)); + + return cir::ConstVectorAttr::get( + desiredVecTy, + mlir::ArrayAttr::get(cgm.getBuilder().getContext(), elements)); } case APValue::MemberPointer: { cgm.errorNYI("ConstExprEmitter::tryEmitPrivate member pointer"); diff --git a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp index 
a940651f1e9eb..6f41cd4388ac2 100644 --- a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp @@ -299,6 +299,47 @@ void ConstArrayAttr::print(AsmPrinter &printer) const { printer << ">"; } +//===----------------------------------------------------------------------===// +// CIR ConstVectorAttr +//===----------------------------------------------------------------------===// + +LogicalResult cir::ConstVectorAttr::verify( + function_ref<::mlir::InFlightDiagnostic()> emitError, Type type, + ArrayAttr elts) { + + if (!mlir::isa<cir::VectorType>(type)) { + return emitError() << "type of cir::ConstVectorAttr is not a " + "cir::VectorType: " + << type; + } + + const auto vecType = mlir::cast<cir::VectorType>(type); + + if (vecType.getSize() != elts.size()) { + return emitError() + << "number of constant elements should match vector size"; + } + + // Check if the types of the elements match + LogicalResult elementTypeCheck = success(); + elts.walkImmediateSubElements( + [&](Attribute element) { + if (elementTypeCheck.failed()) { + // An earlier element didn't match + return; + } + auto typedElement = mlir::dyn_cast<TypedAttr>(element); + if (!typedElement || + typedElement.getType() != vecType.getElementType()) { + elementTypeCheck = failure(); + emitError() << "constant type should match vector element type"; + } + }, + [&](Type) {}); + + return elementTypeCheck; +} + //===----------------------------------------------------------------------===// // CIR Dialect //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 21b77b5327ca7..7ca012fbd73c6 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -244,7 +244,7 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, return success(); } - if (mlir::isa<cir::ConstArrayAttr>(attrType)) + if 
(mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr>(attrType)) return success(); assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?"); diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 1afcee6857de1..ea8125026a8c6 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -188,8 +188,9 @@ class CIRAttrToValue { mlir::Value visit(mlir::Attribute attr) { return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr) - .Case<cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr, cir::ConstPtrAttr, - cir::ZeroAttr>([&](auto attrT) { return visitCirAttr(attrT); }) + .Case<cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr, + cir::ConstVectorAttr, cir::ConstPtrAttr, cir::ZeroAttr>( + [&](auto attrT) { return visitCirAttr(attrT); }) .Default([&](auto attrT) { return mlir::Value(); }); } @@ -197,6 +198,7 @@ class CIRAttrToValue { mlir::Value visitCirAttr(cir::FPAttr fltAttr); mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr); mlir::Value visitCirAttr(cir::ConstArrayAttr attr); + mlir::Value visitCirAttr(cir::ConstVectorAttr attr); mlir::Value visitCirAttr(cir::ZeroAttr attr); private: @@ -275,6 +277,33 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) { return result; } +/// ConstVectorAttr visitor. 
+mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstVectorAttr attr) { + const mlir::Type llvmTy = converter->convertType(attr.getType()); + const mlir::Location loc = parentOp->getLoc(); + + SmallVector<mlir::Attribute> mlirValues; + for (const mlir::Attribute elementAttr : attr.getElts()) { + mlir::Attribute mlirAttr; + if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(elementAttr)) { + mlirAttr = rewriter.getIntegerAttr( + converter->convertType(intAttr.getType()), intAttr.getValue()); + } else if (auto floatAttr = mlir::dyn_cast<cir::FPAttr>(elementAttr)) { + mlirAttr = rewriter.getFloatAttr( + converter->convertType(floatAttr.getType()), floatAttr.getValue()); + } else { + llvm_unreachable( + "vector constant with an element that is neither an int nor a float"); + } + mlirValues.push_back(mlirAttr); + } + + return rewriter.create<mlir::LLVM::ConstantOp>( + loc, llvmTy, + mlir::DenseElementsAttr::get(mlir::cast<mlir::ShapedType>(llvmTy), + mlirValues)); +} + /// ZeroAttr visitor. mlir::Value CIRAttrToValue::visitCirAttr(cir::ZeroAttr attr) { mlir::Location loc = parentOp->getLoc(); @@ -888,7 +917,8 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal( cir::GlobalOp op, mlir::Attribute init, mlir::ConversionPatternRewriter &rewriter) const { // TODO: Generalize this handling when more types are needed here. - assert((isa<cir::ConstArrayAttr, cir::ConstPtrAttr, cir::ZeroAttr>(init))); + assert((isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr, + cir::ZeroAttr>(init))); // TODO(cir): once LLVM's dialect has proper equivalent attributes this // should be updated. 
For now, we use a custom op to initialize globals @@ -941,8 +971,8 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( op.emitError() << "unsupported initializer '" << init.value() << "'"; return mlir::failure(); } - } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstPtrAttr, cir::ZeroAttr>( - init.value())) { + } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr, + cir::ConstPtrAttr, cir::ZeroAttr>(init.value())) { // TODO(cir): once LLVM's dialect has proper equivalent attributes this // should be updated. For now, we use a custom op to initialize globals // to the appropriate value. diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp index 13726edf3d259..7759a32fc1378 100644 --- a/clang/test/CIR/CodeGen/vector-ext.cpp +++ b/clang/test/CIR/CodeGen/vector-ext.cpp @@ -31,7 +31,7 @@ vi2 vec_c; // OGCG: @[[VEC_C:.*]] = global <2 x i32> zeroinitializer -vd2 d; +vd2 vec_d; // CIR: cir.global external @[[VEC_D:.*]] = #cir.zero : !cir.vector<2 x !cir.double> @@ -39,6 +39,15 @@ vd2 d; // OGCG: @[[VEC_D:.*]] = global <2 x double> zeroinitializer +vi4 vec_e = { 1, 2, 3, 4 }; + +// CIR: cir.global external @[[VEC_E:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : +// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i> + +// LLVM: @[[VEC_E:.*]] = dso_local global <4 x i32> <i32 1, i32 2, i32 3, i32 4> + +// OGCG: @[[VEC_E:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4> + void foo() { vi4 a; vi3 b; diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp index 8f9e98fb6b3c0..4c1850141a21c 100644 --- a/clang/test/CIR/CodeGen/vector.cpp +++ b/clang/test/CIR/CodeGen/vector.cpp @@ -30,6 +30,15 @@ vll2 c; // OGCG: @[[VEC_C:.*]] = global <2 x i64> zeroinitializer +vi4 d = { 1, 2, 3, 4 }; + +// CIR: cir.global external @[[VEC_D:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : +// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : 
!s32i]> : !cir.vector<4 x !s32i> + +// LLVM: @[[VEC_D:.*]] = dso_local global <4 x i32> <i32 1, i32 2, i32 3, i32 4> + +// OGCG: @[[VEC_D:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4> + void vec_int_test() { vi4 a; vd2 b; diff --git a/clang/test/CIR/IR/vector.cir b/clang/test/CIR/IR/vector.cir index 74ddf7691e7d4..bc70a8b55fa5c 100644 --- a/clang/test/CIR/IR/vector.cir +++ b/clang/test/CIR/IR/vector.cir @@ -13,6 +13,12 @@ cir.global external @vec_b = #cir.zero : !cir.vector<3 x !s32i> cir.global external @vec_c = #cir.zero : !cir.vector<2 x !s32i> // CHECK: cir.global external @vec_c = #cir.zero : !cir.vector<2 x !s32i> +cir.global external @vec_d = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> +: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i> + +// CIR: cir.global external @vec_d = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : +// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i> + cir.func @vec_int_test() { %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"] %1 = cir.alloca !cir.vector<3 x !s32i>, !cir.ptr<!cir.vector<3 x !s32i>>, ["b"] _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits