https://github.com/AmrDeveloper created https://github.com/llvm/llvm-project/pull/137511

This change adds support for initializing global variables of VectorType with constant values.


Issue: https://github.com/llvm/llvm-project/issues/136487

>From 153f0c0daa33b1c71ced4a0f050d49656e72f505 Mon Sep 17 00:00:00 2001
From: AmrDeveloper <am...@programmer.net>
Date: Sat, 26 Apr 2025 18:43:00 +0200
Subject: [PATCH] [CIR] Upstream global initialization for VectorType

---
 .../include/clang/CIR/Dialect/IR/CIRAttrs.td  | 33 ++++++-
 clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp  | 23 ++++-
 clang/lib/CIR/Dialect/IR/CIRAttrs.cpp         | 88 +++++++++++++++++++
 clang/lib/CIR/Dialect/IR/CIRDialect.cpp       |  2 +-
 .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 40 +++++++--
 clang/test/CIR/CodeGen/vector-ext.cpp         | 11 ++-
 clang/test/CIR/CodeGen/vector.cpp             |  9 ++
 7 files changed, 196 insertions(+), 10 deletions(-)

diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
index fb3f7b1632436..624a82762ab18 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
@@ -204,7 +204,7 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]>
     }]>
   ];
 
-  // Printing and parsing available in CIRDialect.cpp
+  // Printing and parsing available in CIRAttrs.cpp
   let hasCustomAssemblyFormat = 1;
 
   // Enable verifier.
@@ -215,6 +215,37 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]>
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// ConstVectorAttr
+//===----------------------------------------------------------------------===//
+
+def ConstVectorAttr : CIR_Attr<"ConstVector", "const_vector",
+                               [TypedAttrInterface]> {
+  let summary = "A constant vector from ArrayAttr";
+  let description = [{
+    A constant vector attribute holds one typed literal per vector element;
+    each literal's type must match the vector's element type.
+  }];
+
+  let parameters = (ins AttributeSelfTypeParameter<"">:$type,
+                       "mlir::ArrayAttr":$elts);
+
+  // Define a custom builder for the type; that removes the need to pass in an
+  // MLIRContext instance, as it can be inferred from the `type`.
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "cir::VectorType":$type,
+                                       "mlir::ArrayAttr":$elts), [{
+      return $_get(type.getContext(), type, elts);
+    }]>
+  ];
+
+  // Printing and parsing available in CIRAttrs.cpp
+  let hasCustomAssemblyFormat = 1;
+
+  // Enable verifier.
+  let genVerifyDecl = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // ConstPtrAttr
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
index b9a74e90a5960..6e5c7b8fb51f8 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
@@ -373,8 +373,27 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &value,
                               elements, typedFiller);
   }
   case APValue::Vector: {
-    cgm.errorNYI("ConstExprEmitter::tryEmitPrivate vector");
-    return {};
+    const QualType elementType =
+        destType->castAs<VectorType>()->getElementType();
+    const unsigned numElements = value.getVectorLength();
+
+    SmallVector<mlir::Attribute, 16> elements;
+    elements.reserve(numElements);
+
+    for (unsigned i = 0; i < numElements; ++i) {
+      const mlir::Attribute element =
+          tryEmitPrivateForMemory(value.getVectorElt(i), elementType);
+      if (!element)
+        return {};
+      elements.push_back(element);
+    }
+
+    const auto desiredVecTy =
+        mlir::cast<cir::VectorType>(cgm.convertType(destType));
+
+    return cir::ConstVectorAttr::get(
+        desiredVecTy,
+        mlir::ArrayAttr::get(cgm.getBuilder().getContext(), elements));
   }
   case APValue::MemberPointer: {
     cgm.errorNYI("ConstExprEmitter::tryEmitPrivate member pointer");
diff --git a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
index a8d9f6a0e6e9b..b9b27f33207b8 100644
--- a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
@@ -299,6 +299,94 @@ void ConstArrayAttr::print(AsmPrinter &printer) const {
   printer << ">";
 }
 
+//===----------------------------------------------------------------------===//
+// CIR ConstVectorAttr
+//===----------------------------------------------------------------------===//
+
+LogicalResult cir::ConstVectorAttr::verify(
+    function_ref<::mlir::InFlightDiagnostic()> emitError, Type type,
+    ArrayAttr elts) {
+
+  if (!mlir::isa<cir::VectorType>(type)) {
+    return emitError() << "type of cir::ConstVectorAttr is not a "
+                          "cir::VectorType: "
+                       << type;
+  }
+
+  const auto vecType = mlir::cast<cir::VectorType>(type);
+
+  if (vecType.getSize() != elts.size()) {
+    return emitError()
+           << "number of constant elements should match vector size";
+  }
+
+  // Check if the types of the elements match
+  LogicalResult elementTypeCheck = success();
+  elts.walkImmediateSubElements(
+      [&](Attribute element) {
+        if (elementTypeCheck.failed()) {
+          // An earlier element didn't match
+          return;
+        }
+        auto typedElement = mlir::dyn_cast<TypedAttr>(element);
+        if (!typedElement ||
+            typedElement.getType() != vecType.getElementType()) {
+          elementTypeCheck = failure();
+          emitError() << "constant type should match vector element type";
+        }
+      },
+      [&](Type) {});
+
+  return elementTypeCheck;
+}
+
+Attribute cir::ConstVectorAttr::parse(AsmParser &parser, Type type) {
+  FailureOr<Type> resultType;
+  FailureOr<ArrayAttr> resultValue;
+
+  const SMLoc loc = parser.getCurrentLocation();
+
+  // Parse literal '<'
+  if (parser.parseLess()) {
+    return {};
+  }
+
+  // Parse variable 'value'
+  resultValue = FieldParser<ArrayAttr>::parse(parser);
+  if (failed(resultValue)) {
+    parser.emitError(parser.getCurrentLocation(),
+                     "failed to parse ConstVectorAttr parameter 'value' as "
+                     "an attribute");
+    return {};
+  }
+
+  if (parser.parseOptionalColon().failed()) {
+    resultType = type;
+  } else {
+    resultType = ::mlir::FieldParser<Type>::parse(parser);
+    if (failed(resultType)) {
+      parser.emitError(parser.getCurrentLocation(),
+                       "failed to parse ConstVectorAttr parameter 'type' as "
+                       "an MLIR type");
+      return {};
+    }
+  }
+
+  // Parse literal '>'
+  if (parser.parseGreater()) {
+    return {};
+  }
+
+  return parser.getChecked<ConstVectorAttr>(
+      loc, parser.getContext(), resultType.value(), resultValue.value());
+}
+
+void cir::ConstVectorAttr::print(AsmPrinter &printer) const {
+  printer << "<";
+  printer.printStrippedAttrOrType(getElts());
+  printer << ">";
+}
+
 //===----------------------------------------------------------------------===//
 // CIR Dialect
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 939802a3af680..07847d62feadd 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -242,7 +242,7 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
     return success();
   }
 
-  if (mlir::isa<cir::ConstArrayAttr>(attrType))
+  if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr>(attrType))
     return success();
 
   assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 102438c2ded02..db331691154e6 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -188,8 +188,9 @@ class CIRAttrToValue {
 
   mlir::Value visit(mlir::Attribute attr) {
     return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
-        .Case<cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr, cir::ConstPtrAttr,
-              cir::ZeroAttr>([&](auto attrT) { return visitCirAttr(attrT); })
+        .Case<cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr,
+              cir::ConstVectorAttr, cir::ConstPtrAttr, cir::ZeroAttr>(
+            [&](auto attrT) { return visitCirAttr(attrT); })
         .Default([&](auto attrT) { return mlir::Value(); });
   }
 
@@ -197,6 +198,7 @@ class CIRAttrToValue {
   mlir::Value visitCirAttr(cir::FPAttr fltAttr);
   mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
   mlir::Value visitCirAttr(cir::ConstArrayAttr attr);
+  mlir::Value visitCirAttr(cir::ConstVectorAttr attr);
   mlir::Value visitCirAttr(cir::ZeroAttr attr);
 
 private:
@@ -275,6 +277,33 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) {
   return result;
 }
 
+/// ConstVectorAttr visitor: emits an LLVM::ConstantOp with a dense value.
+mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstVectorAttr attr) {
+  const mlir::Type llvmTy = converter->convertType(attr.getType());
+  const mlir::Location loc = parentOp->getLoc();
+
+  SmallVector<mlir::Attribute> mlirValues;
+  for (const mlir::Attribute elementAttr : attr.getElts()) {
+    mlir::Attribute mlirAttr;
+    if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(elementAttr)) {
+      mlirAttr = rewriter.getIntegerAttr(
+          converter->convertType(intAttr.getType()), intAttr.getValue());
+    } else if (auto floatAttr = mlir::dyn_cast<cir::FPAttr>(elementAttr)) {
+      mlirAttr = rewriter.getFloatAttr(
+          converter->convertType(floatAttr.getType()), floatAttr.getValue());
+    } else {
+      llvm_unreachable(
+          "vector constant with an element that is neither an int nor a float");
+    }
+    mlirValues.push_back(mlirAttr);
+  }
+
+  return rewriter.create<mlir::LLVM::ConstantOp>(
+      loc, llvmTy,
+      mlir::DenseElementsAttr::get(mlir::cast<mlir::ShapedType>(llvmTy),
+                                   mlirValues));
+}
+
 /// ZeroAttr visitor.
 mlir::Value CIRAttrToValue::visitCirAttr(cir::ZeroAttr attr) {
   mlir::Location loc = parentOp->getLoc();
@@ -888,7 +917,8 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
     cir::GlobalOp op, mlir::Attribute init,
     mlir::ConversionPatternRewriter &rewriter) const {
   // TODO: Generalize this handling when more types are needed here.
-  assert((isa<cir::ConstArrayAttr, cir::ConstPtrAttr, cir::ZeroAttr>(init)));
+  assert((isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
+              cir::ZeroAttr>(init)));
 
   // TODO(cir): once LLVM's dialect has proper equivalent attributes this
   // should be updated. For now, we use a custom op to initialize globals
@@ -941,8 +971,8 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
         op.emitError() << "unsupported initializer '" << init.value() << "'";
         return mlir::failure();
       }
-    } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstPtrAttr, cir::ZeroAttr>(
-                   init.value())) {
+    } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
+                         cir::ConstPtrAttr, cir::ZeroAttr>(init.value())) {
       // TODO(cir): once LLVM's dialect has proper equivalent attributes this
       // should be updated. For now, we use a custom op to initialize globals
       // to the appropriate value.
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index 13726edf3d259..7759a32fc1378 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -31,7 +31,7 @@ vi2 vec_c;
 
 // OGCG: @[[VEC_C:.*]] = global <2 x i32> zeroinitializer
 
-vd2 d;
+vd2 vec_d;
 
 // CIR: cir.global external @[[VEC_D:.*]] = #cir.zero : !cir.vector<2 x !cir.double>
 
@@ -39,6 +39,15 @@ vd2 d;
 
 // OGCG: @[[VEC_D:.*]] = global <2 x double> zeroinitializer
 
+vi4 vec_e = { 1, 2, 3, 4 };
+
+// CIR: cir.global external @[[VEC_E:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> :
+// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
+
+// LLVM: @[[VEC_E:.*]] = dso_local global <4 x i32> <i32 1, i32 2, i32 3, i32 4>
+
+// OGCG: @[[VEC_E:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4>
+
 void foo() {
   vi4 a;
   vi3 b;
diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp
index 8f9e98fb6b3c0..4c1850141a21c 100644
--- a/clang/test/CIR/CodeGen/vector.cpp
+++ b/clang/test/CIR/CodeGen/vector.cpp
@@ -30,6 +30,15 @@ vll2 c;
 
 // OGCG: @[[VEC_C:.*]] = global <2 x i64> zeroinitializer
 
+vi4 d = { 1, 2, 3, 4 };
+
+// CIR: cir.global external @[[VEC_D:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> :
+// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
+
+// LLVM: @[[VEC_D:.*]] = dso_local global <4 x i32> <i32 1, i32 2, i32 3, i32 4>
+
+// OGCG: @[[VEC_D:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4>
+
 void vec_int_test() {
   vi4 a;
   vd2 b;

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to