https://github.com/4vtomat created 
https://github.com/llvm/llvm-project/pull/147173

Currently, a struct containing an array of fixed-length vectors is always
flattened and lowered to a scalable vector. However, only a struct whose
array holds exactly one fixed-length vector should be lowered that way; a
struct whose array holds more than one fixed-length vector should be
lowered to a vector tuple type.
https://github.com/riscv-non-isa/riscv-elf-psabi-doc/pull/418/files#diff-3a934f00cffdb3e509722753126a2cf6082a7648ab3b9ca8cbb0e84f8a6a12edR555-R558


>From 1934543b7ec215312eebefd152f7c9151c2d0e54 Mon Sep 17 00:00:00 2001
From: Brandon Wu <songwu0...@gmail.com>
Date: Sat, 5 Jul 2025 21:32:28 -0700
Subject: [PATCH] [RISCV] Correct type lowering of struct of fixed-vector array
 in VLS

Currently, a struct containing an array of fixed-length vectors is always
flattened and lowered to a scalable vector. However, only a struct whose
array holds exactly one fixed-length vector should be lowered that way; a
struct whose array holds more than one fixed-length vector should be
lowered to a vector tuple type.
https://github.com/riscv-non-isa/riscv-elf-psabi-doc/pull/418/files#diff-3a934f00cffdb3e509722753126a2cf6082a7648ab3b9ca8cbb0e84f8a6a12edR555-R558
---
 clang/lib/CodeGen/Targets/RISCV.cpp           | 144 ++++++++----------
 .../RISCV/riscv-vector-callingconv-llvm-ir.c  |   8 +-
 .../riscv-vector-callingconv-llvm-ir.cpp      |   8 +-
 3 files changed, 68 insertions(+), 92 deletions(-)

diff --git a/clang/lib/CodeGen/Targets/RISCV.cpp 
b/clang/lib/CodeGen/Targets/RISCV.cpp
index cc3d487da83b5..e1603d3095a04 100644
--- a/clang/lib/CodeGen/Targets/RISCV.cpp
+++ b/clang/lib/CodeGen/Targets/RISCV.cpp
@@ -441,98 +441,74 @@ bool RISCVABIInfo::detectVLSCCEligibleStruct(QualType Ty, 
unsigned ABIVLen,
   //     __attribute__((vector_size(64))) int d;
   //   }
   //
-  // Struct of 1 fixed-length vector is passed as a scalable vector.
-  // Struct of >1 fixed-length vectors are passed as vector tuple.
-  // Struct of 1 array of fixed-length vectors is passed as a scalable vector.
-  // Otherwise, pass the struct indirectly.
-
-  if (llvm::StructType *STy = dyn_cast<llvm::StructType>(CGT.ConvertType(Ty))) 
{
-    unsigned NumElts = STy->getStructNumElements();
-    if (NumElts > 8)
-      return false;
+  // 1. Struct of 1 fixed-length vector is passed as a scalable vector.
+  // 2. Struct of >1 fixed-length vectors are passed as vector tuple.
+  // 3. Struct of an array with 1 element of fixed-length vectors is passed as 
a
+  //    scalable vector.
+  // 4. Struct of an array with >1 elements of fixed-length vectors is passed 
as
+  //    vector tuple.
+  // 5. Otherwise, pass the struct indirectly.
+
+  llvm::StructType *STy = dyn_cast<llvm::StructType>(CGT.ConvertType(Ty));
+  if (!STy)
+    return false;
 
-    auto *FirstEltTy = STy->getElementType(0);
-    if (!STy->containsHomogeneousTypes())
-      return false;
+  unsigned NumElts = STy->getStructNumElements();
+  if (NumElts > 8)
+    return false;
 
-    // Check structure of fixed-length vectors and turn them into vector tuple
-    // type if legal.
-    if (auto *FixedVecTy = dyn_cast<llvm::FixedVectorType>(FirstEltTy)) {
-      if (NumElts == 1) {
-        // Handle single fixed-length vector.
-        VLSType = llvm::ScalableVectorType::get(
-            FixedVecTy->getElementType(),
-            llvm::divideCeil(FixedVecTy->getNumElements() *
-                                 llvm::RISCV::RVVBitsPerBlock,
-                             ABIVLen));
-        // Check registers needed <= 8.
-        return llvm::divideCeil(
-                   FixedVecTy->getNumElements() *
-                       FixedVecTy->getElementType()->getScalarSizeInBits(),
-                   ABIVLen) <= 8;
-      }
-      // LMUL
-      // = fixed-length vector size / ABIVLen
-      // = 8 * I8EltCount / RVVBitsPerBlock
-      // =>
-      // I8EltCount
-      // = (fixed-length vector size * RVVBitsPerBlock) / (ABIVLen * 8)
-      unsigned I8EltCount = llvm::divideCeil(
-          FixedVecTy->getNumElements() *
-              FixedVecTy->getElementType()->getScalarSizeInBits() *
-              llvm::RISCV::RVVBitsPerBlock,
-          ABIVLen * 8);
-      VLSType = llvm::TargetExtType::get(
-          getVMContext(), "riscv.vector.tuple",
-          llvm::ScalableVectorType::get(llvm::Type::getInt8Ty(getVMContext()),
-                                        I8EltCount),
-          NumElts);
-      // Check registers needed <= 8.
-      return NumElts *
-                 llvm::divideCeil(
-                     FixedVecTy->getNumElements() *
-                         FixedVecTy->getElementType()->getScalarSizeInBits(),
-                     ABIVLen) <=
-             8;
-    }
+  auto *FirstEltTy = STy->getElementType(0);
+  if (!STy->containsHomogeneousTypes())
+    return false;
 
-    // If elements are not fixed-length vectors, it should be an array.
+  if (auto *ArrayTy = dyn_cast<llvm::ArrayType>(FirstEltTy)) {
+    // Only struct of single array is accepted
     if (NumElts != 1)
       return false;
+    FirstEltTy = ArrayTy->getArrayElementType();
+    NumElts = ArrayTy->getNumElements();
+  }
 
-    // Check array of fixed-length vector and turn it into scalable vector type
-    // if legal.
-    if (auto *ArrTy = dyn_cast<llvm::ArrayType>(FirstEltTy)) {
-      unsigned NumArrElt = ArrTy->getNumElements();
-      if (NumArrElt > 8)
-        return false;
-
-      auto *ArrEltTy = 
dyn_cast<llvm::FixedVectorType>(ArrTy->getElementType());
-      if (!ArrEltTy)
-        return false;
+  auto *FixedVecTy = dyn_cast<llvm::FixedVectorType>(FirstEltTy);
+  if (!FixedVecTy)
+    return false;
 
-      // LMUL
-      // = NumArrElt * fixed-length vector size / ABIVLen
-      // = fixed-length vector elt size * ScalVecNumElts / RVVBitsPerBlock
-      // =>
-      // ScalVecNumElts
-      // = (NumArrElt * fixed-length vector size * RVVBitsPerBlock) /
-      //   (ABIVLen * fixed-length vector elt size)
-      // = NumArrElt * num fixed-length vector elt * RVVBitsPerBlock /
-      //   ABIVLen
-      unsigned ScalVecNumElts = llvm::divideCeil(
-          NumArrElt * ArrEltTy->getNumElements() * 
llvm::RISCV::RVVBitsPerBlock,
-          ABIVLen);
-      VLSType = llvm::ScalableVectorType::get(ArrEltTy->getElementType(),
-                                              ScalVecNumElts);
-      // Check registers needed <= 8.
-      return llvm::divideCeil(
-                 ScalVecNumElts *
-                     ArrEltTy->getElementType()->getScalarSizeInBits(),
-                 llvm::RISCV::RVVBitsPerBlock) <= 8;
-    }
+  // Turn them into scalable vector type or vector tuple type if legal.
+  if (NumElts == 1) {
+    // Handle single fixed-length vector.
+    VLSType = llvm::ScalableVectorType::get(
+        FixedVecTy->getElementType(),
+        llvm::divideCeil(FixedVecTy->getNumElements() *
+                             llvm::RISCV::RVVBitsPerBlock,
+                         ABIVLen));
+    // Check registers needed <= 8.
+    return llvm::divideCeil(
+               FixedVecTy->getNumElements() *
+                   FixedVecTy->getElementType()->getScalarSizeInBits(),
+               ABIVLen) <= 8;
   }
-  return false;
+  // LMUL
+  // = fixed-length vector size / ABIVLen
+  // = 8 * I8EltCount / RVVBitsPerBlock
+  // =>
+  // I8EltCount
+  // = (fixed-length vector size * RVVBitsPerBlock) / (ABIVLen * 8)
+  unsigned I8EltCount =
+      llvm::divideCeil(FixedVecTy->getNumElements() *
+                           FixedVecTy->getElementType()->getScalarSizeInBits() 
*
+                           llvm::RISCV::RVVBitsPerBlock,
+                       ABIVLen * 8);
+  VLSType = llvm::TargetExtType::get(
+      getVMContext(), "riscv.vector.tuple",
+      llvm::ScalableVectorType::get(llvm::Type::getInt8Ty(getVMContext()),
+                                    I8EltCount),
+      NumElts);
+  // Check registers needed <= 8.
+  return NumElts * llvm::divideCeil(
+                       FixedVecTy->getNumElements() *
+                           FixedVecTy->getElementType()->getScalarSizeInBits(),
+                       ABIVLen) <=
+         8;
 }
 
 // Fixed-length RVV vectors are represented as scalable vectors in function
diff --git a/clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c 
b/clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c
index 3044d91f1c31c..82e43fff0c3aa 100644
--- a/clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c
+++ b/clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c
@@ -153,14 +153,14 @@ void __attribute__((riscv_vls_cc)) 
test_st_i32x4_arr1(struct st_i32x4_arr1 arg)
 // CHECK-LLVM: define dso_local riscv_vls_cc(256) void 
@test_st_i32x4_arr1_256(<vscale x 1 x i32> %arg)
 void __attribute__((riscv_vls_cc(256))) test_st_i32x4_arr1_256(struct 
st_i32x4_arr1 arg) {}
 
-// CHECK-LLVM: define dso_local riscv_vls_cc(128) void 
@test_st_i32x4_arr4(<vscale x 8 x i32> %arg)
+// CHECK-LLVM: define dso_local riscv_vls_cc(128) void 
@test_st_i32x4_arr4(target("riscv.vector.tuple", <vscale x 8 x i8>, 4) %arg)
 void __attribute__((riscv_vls_cc)) test_st_i32x4_arr4(struct st_i32x4_arr4 
arg) {}
-// CHECK-LLVM: define dso_local riscv_vls_cc(256) void 
@test_st_i32x4_arr4_256(<vscale x 4 x i32> %arg)
+// CHECK-LLVM: define dso_local riscv_vls_cc(256) void 
@test_st_i32x4_arr4_256(target("riscv.vector.tuple", <vscale x 4 x i8>, 4) %arg)
 void __attribute__((riscv_vls_cc(256))) test_st_i32x4_arr4_256(struct 
st_i32x4_arr4 arg) {}
 
-// CHECK-LLVM: define dso_local riscv_vls_cc(128) void 
@test_st_i32x4_arr8(<vscale x 16 x i32> %arg)
+// CHECK-LLVM: define dso_local riscv_vls_cc(128) void 
@test_st_i32x4_arr8(target("riscv.vector.tuple", <vscale x 8 x i8>, 8) %arg)
 void __attribute__((riscv_vls_cc)) test_st_i32x4_arr8(struct st_i32x4_arr8 
arg) {}
-// CHECK-LLVM: define dso_local riscv_vls_cc(256) void 
@test_st_i32x4_arr8_256(<vscale x 8 x i32> %arg)
+// CHECK-LLVM: define dso_local riscv_vls_cc(256) void 
@test_st_i32x4_arr8_256(target("riscv.vector.tuple", <vscale x 4 x i8>, 8) %arg)
 void __attribute__((riscv_vls_cc(256))) test_st_i32x4_arr8_256(struct 
st_i32x4_arr8 arg) {}
 
 // CHECK-LLVM: define dso_local riscv_vls_cc(128) void 
@test_st_i32x4x2(target("riscv.vector.tuple", <vscale x 8 x i8>, 2) %arg)
diff --git a/clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.cpp 
b/clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.cpp
index 594bfe159b28c..5f6539796c20d 100644
--- a/clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.cpp
+++ b/clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.cpp
@@ -133,14 +133,14 @@ typedef int __attribute__((vector_size(256))) int32x64_t;
 // CHECK-LLVM: define dso_local riscv_vls_cc(256) void 
@_Z22test_st_i32x4_arr1_25613st_i32x4_arr1(<vscale x 1 x i32> %arg)
 [[riscv::vls_cc(256)]] void test_st_i32x4_arr1_256(struct st_i32x4_arr1 arg) {}
 
-// CHECK-LLVM: define dso_local riscv_vls_cc(128) void 
@_Z18test_st_i32x4_arr413st_i32x4_arr4(<vscale x 8 x i32> %arg)
+// CHECK-LLVM: define dso_local riscv_vls_cc(128) void 
@_Z18test_st_i32x4_arr413st_i32x4_arr4(target("riscv.vector.tuple", <vscale x 8 
x i8>, 4) %arg)
 [[riscv::vls_cc]] void test_st_i32x4_arr4(struct st_i32x4_arr4 arg) {}
-// CHECK-LLVM: define dso_local riscv_vls_cc(256) void 
@_Z22test_st_i32x4_arr4_25613st_i32x4_arr4(<vscale x 4 x i32> %arg)
+// CHECK-LLVM: define dso_local riscv_vls_cc(256) void 
@_Z22test_st_i32x4_arr4_25613st_i32x4_arr4(target("riscv.vector.tuple", <vscale 
x 4 x i8>, 4) %arg)
 [[riscv::vls_cc(256)]] void test_st_i32x4_arr4_256(struct st_i32x4_arr4 arg) {}
 
-// CHECK-LLVM: define dso_local riscv_vls_cc(128) void 
@_Z18test_st_i32x4_arr813st_i32x4_arr8(<vscale x 16 x i32> %arg)
+// CHECK-LLVM: define dso_local riscv_vls_cc(128) void 
@_Z18test_st_i32x4_arr813st_i32x4_arr8(target("riscv.vector.tuple", <vscale x 8 
x i8>, 8) %arg)
 [[riscv::vls_cc]] void test_st_i32x4_arr8(struct st_i32x4_arr8 arg) {}
-// CHECK-LLVM: define dso_local riscv_vls_cc(256) void 
@_Z22test_st_i32x4_arr8_25613st_i32x4_arr8(<vscale x 8 x i32> %arg)
+// CHECK-LLVM: define dso_local riscv_vls_cc(256) void 
@_Z22test_st_i32x4_arr8_25613st_i32x4_arr8(target("riscv.vector.tuple", <vscale 
x 4 x i8>, 8) %arg)
 [[riscv::vls_cc(256)]] void test_st_i32x4_arr8_256(struct st_i32x4_arr8 arg) {}
 
 // CHECK-LLVM: define dso_local riscv_vls_cc(128) void 
@_Z15test_st_i32x4x210st_i32x4x2(target("riscv.vector.tuple", <vscale x 8 x 
i8>, 2) %arg)

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to