https://github.com/kmclaughlin-arm updated 
https://github.com/llvm/llvm-project/pull/154761

>From 625925797e8e7a76471aeaa01150dbee8cf69de5 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaugh...@arm.com>
Date: Wed, 20 Aug 2025 09:33:17 +0000
Subject: [PATCH 1/7] RDSVL tests

---
 .../CodeGen/AArch64/sme-intrinsics-rdsvl.ll   | 49 +++++++++++++++++++
 1 file changed, 49 insertions(+)

diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll 
b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
index 5d10d7e13da14..b799f98981520 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
@@ -40,6 +40,55 @@ define i64 @sme_cntsd() {
   ret i64 %v
 }
 
+define i64 @sme_cntsb_mul() {
+; CHECK-LABEL: sme_cntsb_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    lsl x0, x8, #1
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsb()
+  %res = mul i64 %v, 2
+  ret i64 %res
+}
+
+define i64 @sme_cntsh_mul() {
+; CHECK-LABEL: sme_cntsh_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    lsr x8, x8, #1
+; CHECK-NEXT:    add x0, x8, x8, lsl #2
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsh()
+  %res = mul i64 %v, 5
+  ret i64 %res
+}
+
+define i64 @sme_cntsw_mul() {
+; CHECK-LABEL: sme_cntsw_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    lsr x8, x8, #2
+; CHECK-NEXT:    lsl x9, x8, #3
+; CHECK-NEXT:    sub x0, x9, x8
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsw()
+  %res = mul i64 %v, 7
+  ret i64 %res
+}
+
+define i64 @sme_cntsd_mul() {
+; CHECK-LABEL: sme_cntsd_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    lsr x8, x8, #3
+; CHECK-NEXT:    add x8, x8, x8, lsl #1
+; CHECK-NEXT:    lsl x0, x8, #2
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %res = mul i64 %v, 12
+  ret i64 %res
+}
+
 declare i64 @llvm.aarch64.sme.cntsb()
 declare i64 @llvm.aarch64.sme.cntsh()
 declare i64 @llvm.aarch64.sme.cntsw()

>From fd8ff8ab97a5d876811b11c43d1e6d6c19100399 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaugh...@arm.com>
Date: Wed, 13 Aug 2025 14:13:12 +0000
Subject: [PATCH 2/7] [AArch64][SME] Improve codegen for aarch64.sme.cnts* when
 not in streaming mode

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 36 ++++++++++---------
 .../lib/Target/AArch64/AArch64SMEInstrInfo.td | 23 ++++++++++++
 .../CodeGen/AArch64/sme-intrinsics-rdsvl.ll   | 20 +++++------
 3 files changed, 51 insertions(+), 28 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp 
b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 23328ed57fb36..6d65d6354b462 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6266,25 +6266,26 @@ SDValue 
AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
   case Intrinsic::aarch64_sve_clz:
     return DAG.getNode(AArch64ISD::CTLZ_MERGE_PASSTHRU, DL, Op.getValueType(),
                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
-  case Intrinsic::aarch64_sme_cntsb:
-    return DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(),
-                       DAG.getConstant(1, DL, MVT::i32));
+  case Intrinsic::aarch64_sme_cntsb: {
+    SDValue Cntd = DAG.getNode(
+        ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
+        DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64));
+    return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd,
+                       DAG.getConstant(8, DL, MVT::i64));
+  }
   case Intrinsic::aarch64_sme_cntsh: {
-    SDValue One = DAG.getConstant(1, DL, MVT::i32);
-    SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(), One);
-    return DAG.getNode(ISD::SRL, DL, Op.getValueType(), Bytes, One);
+    SDValue Cntd = DAG.getNode(
+        ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
+        DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64));
+    return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd,
+                       DAG.getConstant(4, DL, MVT::i64));
   }
   case Intrinsic::aarch64_sme_cntsw: {
-    SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(),
-                                DAG.getConstant(1, DL, MVT::i32));
-    return DAG.getNode(ISD::SRL, DL, Op.getValueType(), Bytes,
-                       DAG.getConstant(2, DL, MVT::i32));
-  }
-  case Intrinsic::aarch64_sme_cntsd: {
-    SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(),
-                                DAG.getConstant(1, DL, MVT::i32));
-    return DAG.getNode(ISD::SRL, DL, Op.getValueType(), Bytes,
-                       DAG.getConstant(3, DL, MVT::i32));
+    SDValue Cntd = DAG.getNode(
+        ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
+        DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64));
+    return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd,
+                       DAG.getConstant(2, DL, MVT::i64));
   }
   case Intrinsic::aarch64_sve_cnt: {
     SDValue Data = Op.getOperand(3);
@@ -19200,6 +19201,9 @@ static SDValue performMulCombine(SDNode *N, 
SelectionDAG &DAG,
        if (ConstValue.sge(1) && ConstValue.sle(16))
          return SDValue();
 
+  if (getIntrinsicID(N0.getNode()) == Intrinsic::aarch64_sme_cntsd)
+    return SDValue();
+
   // Multiplication of a power of two plus/minus one can be done more
   // cheaply as shift+add/sub. For now, this is true unilaterally. If
   // future CPUs have a cheaper MADD instruction, this may need to be
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td 
b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 0d8cb3a76d0be..aecfe37cad823 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -127,12 +127,35 @@ def : Pat<(AArch64_requires_za_save), 
(RequiresZASavePseudo)>;
 def SDT_AArch64RDSVL  : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>;
 def AArch64rdsvl : SDNode<"AArch64ISD::RDSVL", SDT_AArch64RDSVL>;
 
+def sme_cntsb_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 8>">;
+def sme_cntsh_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 4>">;
+def sme_cntsw_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 2>">;
+def sme_cntsd_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 1>">;
+
 let Predicates = [HasSMEandIsNonStreamingSafe] in {
 def RDSVLI_XI  : sve_int_read_vl_a<0b0, 0b11111, "rdsvl", 
/*streaming_sve=*/0b1>;
 def ADDSPL_XXI : sve_int_arith_vl<0b1, "addspl", /*streaming_sve=*/0b1>;
 def ADDSVL_XXI : sve_int_arith_vl<0b0, "addsvl", /*streaming_sve=*/0b1>;
 
 def : Pat<(AArch64rdsvl (i32 simm6_32b:$imm)), (RDSVLI_XI simm6_32b:$imm)>;
+
+// e.g. cntsb() * imm
+def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsb_imm i64:$imm))),
+          (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm))>;
+def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsh_imm i64:$imm))),
+          (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 1, 63)>;
+def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsw_imm i64:$imm))),
+          (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 2, 63)>;
+def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsd_imm i64:$imm))),
+          (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 3, 63)>;
+
+// e.g. cntsb()
+def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 1))), (UBFMXri (RDSVLI_XI 1), 
2, 63)>;
+def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 2))), (UBFMXri (RDSVLI_XI 1), 
1, 63)>;
+def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 3))), (RDSVLI_XI 1)>;
+
+// Generic pattern for cntsd (RDSVL #1 >> 3)
+def : Pat<(i64 (int_aarch64_sme_cntsd)), (UBFMXri (RDSVLI_XI 1), 3, 63)>;
 }
 
 let Predicates = [HasSME] in {
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll 
b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
index b799f98981520..8253db1d488e7 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
@@ -44,7 +44,8 @@ define i64 @sme_cntsb_mul() {
 ; CHECK-LABEL: sme_cntsb_mul:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdsvl x8, #1
-; CHECK-NEXT:    lsl x0, x8, #1
+; CHECK-NEXT:    lsr x8, x8, #3
+; CHECK-NEXT:    lsl x0, x8, #4
 ; CHECK-NEXT:    ret
   %v = call i64 @llvm.aarch64.sme.cntsb()
   %res = mul i64 %v, 2
@@ -54,9 +55,8 @@ define i64 @sme_cntsb_mul() {
 define i64 @sme_cntsh_mul() {
 ; CHECK-LABEL: sme_cntsh_mul:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    rdsvl x8, #1
-; CHECK-NEXT:    lsr x8, x8, #1
-; CHECK-NEXT:    add x0, x8, x8, lsl #2
+; CHECK-NEXT:    rdsvl x8, #5
+; CHECK-NEXT:    lsr x0, x8, #1
 ; CHECK-NEXT:    ret
   %v = call i64 @llvm.aarch64.sme.cntsh()
   %res = mul i64 %v, 5
@@ -66,10 +66,8 @@ define i64 @sme_cntsh_mul() {
 define i64 @sme_cntsw_mul() {
 ; CHECK-LABEL: sme_cntsw_mul:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    rdsvl x8, #1
-; CHECK-NEXT:    lsr x8, x8, #2
-; CHECK-NEXT:    lsl x9, x8, #3
-; CHECK-NEXT:    sub x0, x9, x8
+; CHECK-NEXT:    rdsvl x8, #7
+; CHECK-NEXT:    lsr x0, x8, #2
 ; CHECK-NEXT:    ret
   %v = call i64 @llvm.aarch64.sme.cntsw()
   %res = mul i64 %v, 7
@@ -79,10 +77,8 @@ define i64 @sme_cntsw_mul() {
 define i64 @sme_cntsd_mul() {
 ; CHECK-LABEL: sme_cntsd_mul:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    rdsvl x8, #1
-; CHECK-NEXT:    lsr x8, x8, #3
-; CHECK-NEXT:    add x8, x8, x8, lsl #1
-; CHECK-NEXT:    lsl x0, x8, #2
+; CHECK-NEXT:    rdsvl x8, #3
+; CHECK-NEXT:    lsr x0, x8, #1
 ; CHECK-NEXT:    ret
   %v = call i64 @llvm.aarch64.sme.cntsd()
   %res = mul i64 %v, 12

>From 65da7184ac4d9dfe53989f233e77206c0767215e Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaugh...@arm.com>
Date: Thu, 21 Aug 2025 14:04:18 +0000
Subject: [PATCH 3/7] - Replace cnts[b|h|w] builtins with cntsd intrinsic in
 Clang - Remove cnts[b|h|w] intrinsics in LLVM - Add patterns for cntsd

---
 clang/include/clang/Basic/arm_sme.td          | 15 +++--
 clang/lib/CodeGen/TargetBuiltins/ARM.cpp      | 30 +++++++++-
 .../AArch64/sme-intrinsics/acle_sme_cnt.c     | 42 ++++++++------
 llvm/include/llvm/IR/IntrinsicsAArch64.td     |  9 +--
 .../Target/AArch64/AArch64ISelDAGToDAG.cpp    | 20 +++++++
 .../Target/AArch64/AArch64ISelLowering.cpp    | 21 -------
 .../lib/Target/AArch64/AArch64SMEInstrInfo.td | 28 ++++-----
 .../AArch64/AArch64TargetTransformInfo.cpp    | 16 ++----
 .../CodeGen/AArch64/sme-intrinsics-rdsvl.ll   | 57 ++++++++++---------
 .../sme-streaming-interface-remarks.ll        |  4 +-
 .../AArch64/sme-streaming-interface.ll        |  7 ++-
 .../sme-intrinsic-opts-counting-elems.ll      | 45 ---------------
 12 files changed, 136 insertions(+), 158 deletions(-)

diff --git a/clang/include/clang/Basic/arm_sme.td 
b/clang/include/clang/Basic/arm_sme.td
index a4eb92e76968c..f853122994497 100644
--- a/clang/include/clang/Basic/arm_sme.td
+++ b/clang/include/clang/Basic/arm_sme.td
@@ -156,16 +156,15 @@ let SMETargetGuard = "sme2p1" in {
 
////////////////////////////////////////////////////////////////////////////////
 // SME - Counting elements in a streaming vector
 
-multiclass ZACount<string n_suffix> {
-  def NAME : SInst<"sv" # n_suffix, "nv", "", MergeNone,
-                    "aarch64_sme_" # n_suffix,
-                    [IsOverloadNone, IsStreamingCompatible]>;
+multiclass ZACount<string intr, string n_suffix> {
+  def NAME : SInst<"sv"#n_suffix, "nv", "", MergeNone,
+                   intr, [IsOverloadNone, IsStreamingCompatible]>;
 }
 
-defm SVCNTSB : ZACount<"cntsb">;
-defm SVCNTSH : ZACount<"cntsh">;
-defm SVCNTSW : ZACount<"cntsw">;
-defm SVCNTSD : ZACount<"cntsd">;
+defm SVCNTSB : ZACount<"", "cntsb">;
+defm SVCNTSH : ZACount<"", "cntsh">;
+defm SVCNTSW : ZACount<"", "cntsw">;
+defm SVCNTSD : ZACount<"aarch64_sme_cntsd", "cntsd">;
 
 
////////////////////////////////////////////////////////////////////////////////
 // SME - ADDHA/ADDVA
diff --git a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp 
b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
index 60413e7b18e85..217232db44b6f 100644
--- a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
@@ -4304,9 +4304,10 @@ Value *CodeGenFunction::EmitSMELd1St1(const SVETypeFlags 
&TypeFlags,
   // size in bytes.
   if (Ops.size() == 5) {
     Function *StreamingVectorLength =
-        CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsb);
+        CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsd);
     llvm::Value *StreamingVectorLengthCall =
-        Builder.CreateCall(StreamingVectorLength);
+        Builder.CreateMul(Builder.CreateCall(StreamingVectorLength),
+                          llvm::ConstantInt::get(Int64Ty, 8), "svl");
     llvm::Value *Mulvl =
         Builder.CreateMul(StreamingVectorLengthCall, Ops[4], "mulvl");
     // The type of the ptr parameter is void *, so use Int8Ty here.
@@ -4918,6 +4919,31 @@ Value 
*CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID,
   // Handle builtins which require their multi-vector operands to be swapped
   swapCommutativeSMEOperands(BuiltinID, Ops);
 
+  auto isCntsBuiltin = [&](int64_t &Mul) {
+    switch (BuiltinID) {
+    default:
+      Mul = 0;
+      return false;
+    case SME::BI__builtin_sme_svcntsb:
+      Mul = 8;
+      return true;
+    case SME::BI__builtin_sme_svcntsh:
+      Mul = 4;
+      return true;
+    case SME::BI__builtin_sme_svcntsw:
+      Mul = 2;
+      return true;
+    }
+  };
+
+  int64_t Mul = 0;
+  if (isCntsBuiltin(Mul)) {
+    llvm::Value *Cntd =
+        Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsd));
+    return Builder.CreateMul(Cntd, llvm::ConstantInt::get(Int64Ty, Mul),
+                             "mulsvl", /* HasNUW */ true, /* HasNSW */ true);
+  }
+
   // Should not happen!
   if (Builtin->LLVMIntrinsic == 0)
     return nullptr;
diff --git a/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_cnt.c 
b/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_cnt.c
index c0b3e1a06b0ff..049c1742e5a9d 100644
--- a/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_cnt.c
+++ b/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_cnt.c
@@ -6,49 +6,55 @@
 
 #include <arm_sme.h>
 
-// CHECK-C-LABEL: define dso_local i64 @test_svcntsb(
+// CHECK-C-LABEL: define dso_local range(i64 0, -9223372036854775808) i64 
@test_svcntsb(
 // CHECK-C-SAME: ) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // CHECK-C-NEXT:  entry:
-// CHECK-C-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-C-NEXT:    ret i64 [[TMP0]]
+// CHECK-C-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
+// CHECK-C-NEXT:    [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 3
+// CHECK-C-NEXT:    ret i64 [[MULSVL]]
 //
-// CHECK-CXX-LABEL: define dso_local noundef i64 @_Z12test_svcntsbv(
+// CHECK-CXX-LABEL: define dso_local noundef range(i64 0, 
-9223372036854775808) i64 @_Z12test_svcntsbv(
 // CHECK-CXX-SAME: ) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // CHECK-CXX-NEXT:  entry:
-// CHECK-CXX-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-CXX-NEXT:    ret i64 [[TMP0]]
+// CHECK-CXX-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
+// CHECK-CXX-NEXT:    [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 3
+// CHECK-CXX-NEXT:    ret i64 [[MULSVL]]
 //
 uint64_t test_svcntsb() {
   return svcntsb();
 }
 
-// CHECK-C-LABEL: define dso_local i64 @test_svcntsh(
+// CHECK-C-LABEL: define dso_local range(i64 0, -9223372036854775808) i64 
@test_svcntsh(
 // CHECK-C-SAME: ) local_unnamed_addr #[[ATTR0]] {
 // CHECK-C-NEXT:  entry:
-// CHECK-C-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsh()
-// CHECK-C-NEXT:    ret i64 [[TMP0]]
+// CHECK-C-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
+// CHECK-C-NEXT:    [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 2
+// CHECK-C-NEXT:    ret i64 [[MULSVL]]
 //
-// CHECK-CXX-LABEL: define dso_local noundef i64 @_Z12test_svcntshv(
+// CHECK-CXX-LABEL: define dso_local noundef range(i64 0, 
-9223372036854775808) i64 @_Z12test_svcntshv(
 // CHECK-CXX-SAME: ) local_unnamed_addr #[[ATTR0]] {
 // CHECK-CXX-NEXT:  entry:
-// CHECK-CXX-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsh()
-// CHECK-CXX-NEXT:    ret i64 [[TMP0]]
+// CHECK-CXX-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
+// CHECK-CXX-NEXT:    [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 2
+// CHECK-CXX-NEXT:    ret i64 [[MULSVL]]
 //
 uint64_t test_svcntsh() {
   return svcntsh();
 }
 
-// CHECK-C-LABEL: define dso_local i64 @test_svcntsw(
+// CHECK-C-LABEL: define dso_local range(i64 0, -9223372036854775808) i64 
@test_svcntsw(
 // CHECK-C-SAME: ) local_unnamed_addr #[[ATTR0]] {
 // CHECK-C-NEXT:  entry:
-// CHECK-C-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsw()
-// CHECK-C-NEXT:    ret i64 [[TMP0]]
+// CHECK-C-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
+// CHECK-C-NEXT:    [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 1
+// CHECK-C-NEXT:    ret i64 [[MULSVL]]
 //
-// CHECK-CXX-LABEL: define dso_local noundef i64 @_Z12test_svcntswv(
+// CHECK-CXX-LABEL: define dso_local noundef range(i64 0, 
-9223372036854775808) i64 @_Z12test_svcntswv(
 // CHECK-CXX-SAME: ) local_unnamed_addr #[[ATTR0]] {
 // CHECK-CXX-NEXT:  entry:
-// CHECK-CXX-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsw()
-// CHECK-CXX-NEXT:    ret i64 [[TMP0]]
+// CHECK-CXX-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
+// CHECK-CXX-NEXT:    [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 1
+// CHECK-CXX-NEXT:    ret i64 [[MULSVL]]
 //
 uint64_t test_svcntsw() {
   return svcntsw();
diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td 
b/llvm/include/llvm/IR/IntrinsicsAArch64.td
index 6d53bf8b172d8..7c9aef52b3acf 100644
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -3147,13 +3147,8 @@ let TargetPrefix = "aarch64" in {
   // Counting elements
   //
 
-  class AdvSIMD_SME_CNTSB_Intrinsic
-    : DefaultAttrsIntrinsic<[llvm_i64_ty], [], [IntrNoMem]>;
-
-  def int_aarch64_sme_cntsb : AdvSIMD_SME_CNTSB_Intrinsic;
-  def int_aarch64_sme_cntsh : AdvSIMD_SME_CNTSB_Intrinsic;
-  def int_aarch64_sme_cntsw : AdvSIMD_SME_CNTSB_Intrinsic;
-  def int_aarch64_sme_cntsd : AdvSIMD_SME_CNTSB_Intrinsic;
+  def int_aarch64_sme_cntsd
+      : DefaultAttrsIntrinsic<[llvm_i64_ty], [], [IntrNoMem]>;
 
   //
   // PSTATE Functions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp 
b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index bc786f415b554..4e8255bab9437 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -71,6 +71,9 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
   template <signed Low, signed High, signed Scale>
   bool SelectRDVLImm(SDValue N, SDValue &Imm);
 
+  template <signed Low, signed High>
+  bool SelectRDSVLShiftImm(SDValue N, SDValue &Imm);
+
   bool SelectArithExtendedRegister(SDValue N, SDValue &Reg, SDValue &Shift);
   bool SelectArithUXTXRegister(SDValue N, SDValue &Reg, SDValue &Shift);
   bool SelectArithImmed(SDValue N, SDValue &Val, SDValue &Shift);
@@ -937,6 +940,23 @@ bool AArch64DAGToDAGISel::SelectRDVLImm(SDValue N, SDValue 
&Imm) {
   return false;
 }
 
+// Match (shl cntsd, ShlImm) as RDSVL #MulImm, MulImm = 2^(ShlImm-3).
+template <signed Low, signed High>
+bool AArch64DAGToDAGISel::SelectRDSVLShiftImm(SDValue N, SDValue &Imm) {
+  auto *C = dyn_cast<ConstantSDNode>(N);
+  if (!C)
+    return false;
+  // Reject out-of-range shifts up front; '1 << k' on int is UB for k >= 31.
+  int64_t ShlImm = C->getSExtValue();
+  if (ShlImm < 3 || ShlImm - 3 >= 63)
+    return false;
+  int64_t MulImm = int64_t(1) << (ShlImm - 3);
+  if (MulImm < Low || MulImm > High)
+    return false;
+  Imm = CurDAG->getSignedTargetConstant(MulImm, SDLoc(N), MVT::i32);
+  return true;
+}
+
 /// SelectArithExtendedRegister - Select a "extended register" operand.  This
 /// operand folds in an extend followed by an optional left shift.
 bool AArch64DAGToDAGISel::SelectArithExtendedRegister(SDValue N, SDValue &Reg,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp 
b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 6d65d6354b462..08f0ae0b2f783 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6266,27 +6266,6 @@ SDValue 
AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
   case Intrinsic::aarch64_sve_clz:
     return DAG.getNode(AArch64ISD::CTLZ_MERGE_PASSTHRU, DL, Op.getValueType(),
                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
-  case Intrinsic::aarch64_sme_cntsb: {
-    SDValue Cntd = DAG.getNode(
-        ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
-        DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64));
-    return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd,
-                       DAG.getConstant(8, DL, MVT::i64));
-  }
-  case Intrinsic::aarch64_sme_cntsh: {
-    SDValue Cntd = DAG.getNode(
-        ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
-        DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64));
-    return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd,
-                       DAG.getConstant(4, DL, MVT::i64));
-  }
-  case Intrinsic::aarch64_sme_cntsw: {
-    SDValue Cntd = DAG.getNode(
-        ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
-        DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64));
-    return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd,
-                       DAG.getConstant(2, DL, MVT::i64));
-  }
   case Intrinsic::aarch64_sve_cnt: {
     SDValue Data = Op.getOperand(3);
     // CTPOP only supports integer operands.
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td 
b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index aecfe37cad823..3b27203d45585 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -127,10 +127,12 @@ def : Pat<(AArch64_requires_za_save), 
(RequiresZASavePseudo)>;
 def SDT_AArch64RDSVL  : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>;
 def AArch64rdsvl : SDNode<"AArch64ISD::RDSVL", SDT_AArch64RDSVL>;
 
-def sme_cntsb_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 8>">;
-def sme_cntsh_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 4>">;
-def sme_cntsw_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 2>">;
-def sme_cntsd_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 1>">;
+def sme_cntsb_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 8>">;
+def sme_cntsh_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 4>">;
+def sme_cntsw_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 2>">;
+def sme_cntsd_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 1>">;
+
+def sme_cnts_shl_imm : ComplexPattern<i64, 1, "SelectRDSVLShiftImm<1, 31>">;
 
 let Predicates = [HasSMEandIsNonStreamingSafe] in {
 def RDSVLI_XI  : sve_int_read_vl_a<0b0, 0b11111, "rdsvl", 
/*streaming_sve=*/0b1>;
@@ -140,21 +142,21 @@ def ADDSVL_XXI : sve_int_arith_vl<0b0, "addsvl", 
/*streaming_sve=*/0b1>;
 def : Pat<(AArch64rdsvl (i32 simm6_32b:$imm)), (RDSVLI_XI simm6_32b:$imm)>;
 
 // e.g. cntsb() * imm
-def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsb_imm i64:$imm))),
+def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsb_mul_imm i64:$imm))),
           (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm))>;
-def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsh_imm i64:$imm))),
+def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsh_mul_imm i64:$imm))),
           (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 1, 63)>;
-def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsw_imm i64:$imm))),
+def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsw_mul_imm i64:$imm))),
           (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 2, 63)>;
-def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsd_imm i64:$imm))),
+def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsd_mul_imm i64:$imm))),
           (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 3, 63)>;
 
-// e.g. cntsb()
-def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 1))), (UBFMXri (RDSVLI_XI 1), 
2, 63)>;
-def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 2))), (UBFMXri (RDSVLI_XI 1), 
1, 63)>;
-def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 3))), (RDSVLI_XI 1)>;
+def : Pat<(i64 (shl (int_aarch64_sme_cntsd), (sme_cnts_shl_imm i64:$imm))),
+          (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm))>;
 
-// Generic pattern for cntsd (RDSVL #1 >> 3)
+// cntsh, cntsw, cntsd
+def : Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 2))), (UBFMXri (RDSVLI_XI 
1), 1, 63)>;
+def : Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 1))), (UBFMXri (RDSVLI_XI 
1), 2, 63)>;
 def : Pat<(i64 (int_aarch64_sme_cntsd)), (UBFMXri (RDSVLI_XI 1), 3, 63)>;
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp 
b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 490f6391c15a0..38958796e2fe1 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -2102,15 +2102,15 @@ instCombineSVECntElts(InstCombiner &IC, IntrinsicInst 
&II, unsigned NumElts) {
 }
 
 static std::optional<Instruction *>
-instCombineSMECntsElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts,
+instCombineSMECntsElts(InstCombiner &IC, IntrinsicInst &II,
                        const AArch64Subtarget *ST) {
   if (!ST->isStreaming())
     return std::nullopt;
 
-  // In streaming-mode, aarch64_sme_cnts is equivalent to aarch64_sve_cnt
+  // In streaming-mode, aarch64_sme_cntsd is equivalent to aarch64_sve_cntd
   // with SVEPredPattern::all
-  Value *Cnt = IC.Builder.CreateElementCount(
-      II.getType(), ElementCount::getScalable(NumElts));
+  Value *Cnt =
+      IC.Builder.CreateElementCount(II.getType(), 
ElementCount::getScalable(2));
   Cnt->takeName(&II);
   return IC.replaceInstUsesWith(II, Cnt);
 }
@@ -2825,13 +2825,7 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
   case Intrinsic::aarch64_sve_cntb:
     return instCombineSVECntElts(IC, II, 16);
   case Intrinsic::aarch64_sme_cntsd:
-    return instCombineSMECntsElts(IC, II, 2, ST);
-  case Intrinsic::aarch64_sme_cntsw:
-    return instCombineSMECntsElts(IC, II, 4, ST);
-  case Intrinsic::aarch64_sme_cntsh:
-    return instCombineSMECntsElts(IC, II, 8, ST);
-  case Intrinsic::aarch64_sme_cntsb:
-    return instCombineSMECntsElts(IC, II, 16, ST);
+    return instCombineSMECntsElts(IC, II, ST);
   case Intrinsic::aarch64_sve_ptest_any:
   case Intrinsic::aarch64_sve_ptest_first:
   case Intrinsic::aarch64_sve_ptest_last:
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll 
b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
index 8253db1d488e7..86d3e42deae09 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
@@ -1,54 +1,56 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -verify-machineinstrs < %s | 
FileCheck %s
 
-define i64 @sme_cntsb() {
-; CHECK-LABEL: sme_cntsb:
+define i64 @cntsb() {
+; CHECK-LABEL: cntsb:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdsvl x0, #1
 ; CHECK-NEXT:    ret
-  %v = call i64 @llvm.aarch64.sme.cntsb()
-  ret i64 %v
+  %1 = call i64 @llvm.aarch64.sme.cntsd()
+  %res = shl nuw nsw i64 %1, 3
+  ret i64 %res
 }
 
-define i64 @sme_cntsh() {
-; CHECK-LABEL: sme_cntsh:
+define i64 @cntsh() {
+; CHECK-LABEL: cntsh:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdsvl x8, #1
 ; CHECK-NEXT:    lsr x0, x8, #1
 ; CHECK-NEXT:    ret
-  %v = call i64 @llvm.aarch64.sme.cntsh()
-  ret i64 %v
+  %1 = call i64 @llvm.aarch64.sme.cntsd()
+  %res = shl nuw nsw i64 %1, 2
+  ret i64 %res
 }
 
-define i64 @sme_cntsw() {
-; CHECK-LABEL: sme_cntsw:
+define i64 @cntsw() {
+; CHECK-LABEL: cntsw:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdsvl x8, #1
 ; CHECK-NEXT:    lsr x0, x8, #2
 ; CHECK-NEXT:    ret
-  %v = call i64 @llvm.aarch64.sme.cntsw()
-  ret i64 %v
+  %1 = call i64 @llvm.aarch64.sme.cntsd()
+  %res = shl nuw nsw i64 %1, 1
+  ret i64 %res
 }
 
-define i64 @sme_cntsd() {
-; CHECK-LABEL: sme_cntsd:
+define i64 @cntsd() {
+; CHECK-LABEL: cntsd:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdsvl x8, #1
 ; CHECK-NEXT:    lsr x0, x8, #3
 ; CHECK-NEXT:    ret
-  %v = call i64 @llvm.aarch64.sme.cntsd()
-  ret i64 %v
+  %res = call i64 @llvm.aarch64.sme.cntsd()
+  ret i64 %res
 }
 
 define i64 @sme_cntsb_mul() {
 ; CHECK-LABEL: sme_cntsb_mul:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    rdsvl x8, #1
-; CHECK-NEXT:    lsr x8, x8, #3
-; CHECK-NEXT:    lsl x0, x8, #4
+; CHECK-NEXT:    rdsvl x0, #2
 ; CHECK-NEXT:    ret
-  %v = call i64 @llvm.aarch64.sme.cntsb()
-  %res = mul i64 %v, 2
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %shl = shl nuw nsw i64 %v, 3
+  %res = mul i64 %shl, 2
   ret i64 %res
 }
 
@@ -58,8 +60,9 @@ define i64 @sme_cntsh_mul() {
 ; CHECK-NEXT:    rdsvl x8, #5
 ; CHECK-NEXT:    lsr x0, x8, #1
 ; CHECK-NEXT:    ret
-  %v = call i64 @llvm.aarch64.sme.cntsh()
-  %res = mul i64 %v, 5
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %shl = shl nuw nsw i64 %v, 2
+  %res = mul i64 %shl, 5
   ret i64 %res
 }
 
@@ -69,8 +72,9 @@ define i64 @sme_cntsw_mul() {
 ; CHECK-NEXT:    rdsvl x8, #7
 ; CHECK-NEXT:    lsr x0, x8, #2
 ; CHECK-NEXT:    ret
-  %v = call i64 @llvm.aarch64.sme.cntsw()
-  %res = mul i64 %v, 7
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %shl = shl nuw nsw i64 %v, 1
+  %res = mul i64 %shl, 7
   ret i64 %res
 }
 
@@ -85,7 +89,4 @@ define i64 @sme_cntsd_mul() {
   ret i64 %res
 }
 
-declare i64 @llvm.aarch64.sme.cntsb()
-declare i64 @llvm.aarch64.sme.cntsh()
-declare i64 @llvm.aarch64.sme.cntsw()
 declare i64 @llvm.aarch64.sme.cntsd()
diff --git a/llvm/test/CodeGen/AArch64/sme-streaming-interface-remarks.ll 
b/llvm/test/CodeGen/AArch64/sme-streaming-interface-remarks.ll
index e1a474d898233..2806f864c7b25 100644
--- a/llvm/test/CodeGen/AArch64/sme-streaming-interface-remarks.ll
+++ b/llvm/test/CodeGen/AArch64/sme-streaming-interface-remarks.ll
@@ -76,14 +76,14 @@ entry:
   %Data1 = alloca <vscale x 16 x i8>, align 16
   %Data2 = alloca <vscale x 16 x i8>, align 16
   %Data3 = alloca <vscale x 16 x i8>, align 16
-  %0 = tail call i64 @llvm.aarch64.sme.cntsb()
+  %0 = tail call i64 @llvm.aarch64.sme.cntsd()
   call void @foo(ptr noundef nonnull %Data1, ptr noundef nonnull %Data2, ptr 
noundef nonnull %Data3, i64 noundef %0)
   %1 = load <vscale x 16 x i8>, ptr %Data1, align 16
   %vecext = extractelement <vscale x 16 x i8> %1, i64 0
   ret i8 %vecext
 }
 
-declare i64 @llvm.aarch64.sme.cntsb()
+declare i64 @llvm.aarch64.sme.cntsd()
 
 declare void @foo(ptr noundef, ptr noundef, ptr noundef, i64 noundef)
 
diff --git a/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll 
b/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
index 8c4d57e244e03..505a40c16653b 100644
--- a/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
+++ b/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
@@ -366,9 +366,10 @@ define i8 @call_to_non_streaming_pass_sve_objects(ptr 
nocapture noundef readnone
 ; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
 ; CHECK-NEXT:    stp x29, x30, [sp, #64] // 16-byte Folded Spill
 ; CHECK-NEXT:    addvl sp, sp, #-3
-; CHECK-NEXT:    rdsvl x3, #1
+; CHECK-NEXT:    rdsvl x8, #1
 ; CHECK-NEXT:    addvl x0, sp, #2
 ; CHECK-NEXT:    addvl x1, sp, #1
+; CHECK-NEXT:    lsr x3, x8, #3
 ; CHECK-NEXT:    mov x2, sp
 ; CHECK-NEXT:    smstop sm
 ; CHECK-NEXT:    bl foo
@@ -386,7 +387,7 @@ entry:
   %Data1 = alloca <vscale x 16 x i8>, align 16
   %Data2 = alloca <vscale x 16 x i8>, align 16
   %Data3 = alloca <vscale x 16 x i8>, align 16
-  %0 = tail call i64 @llvm.aarch64.sme.cntsb()
+  %0 = tail call i64 @llvm.aarch64.sme.cntsd()
   call void @foo(ptr noundef nonnull %Data1, ptr noundef nonnull %Data2, ptr 
noundef nonnull %Data3, i64 noundef %0)
   %1 = load <vscale x 16 x i8>, ptr %Data1, align 16
   %vecext = extractelement <vscale x 16 x i8> %1, i64 0
@@ -421,7 +422,7 @@ entry:
   ret void
 }
 
-declare i64 @llvm.aarch64.sme.cntsb()
+declare i64 @llvm.aarch64.sme.cntsd()
 
 declare void @foo(ptr noundef, ptr noundef, ptr noundef, i64 noundef)
 declare void @bar(ptr noundef, i64 noundef, i64 noundef, i32 noundef, i32 
noundef, float noundef, float noundef, double noundef, double noundef)
diff --git 
a/llvm/test/Transforms/InstCombine/AArch64/sme-intrinsic-opts-counting-elems.ll 
b/llvm/test/Transforms/InstCombine/AArch64/sme-intrinsic-opts-counting-elems.ll
index f213c0b53f6ef..c1d12b825b72c 100644
--- 
a/llvm/test/Transforms/InstCombine/AArch64/sme-intrinsic-opts-counting-elems.ll
+++ 
b/llvm/test/Transforms/InstCombine/AArch64/sme-intrinsic-opts-counting-elems.ll
@@ -5,48 +5,6 @@
 
 target triple = "aarch64-unknown-linux-gnu"
 
-define i64 @cntsb() {
-; CHECK-LABEL: @cntsb(
-; CHECK-NEXT:    [[OUT:%.*]] = call i64 @llvm.aarch64.sme.cntsb()
-; CHECK-NEXT:    ret i64 [[OUT]]
-;
-; CHECK-STREAMING-LABEL: @cntsb(
-; CHECK-STREAMING-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-STREAMING-NEXT:    [[OUT:%.*]] = shl nuw i64 [[TMP1]], 4
-; CHECK-STREAMING-NEXT:    ret i64 [[OUT]]
-;
-  %out = call i64 @llvm.aarch64.sme.cntsb()
-  ret i64 %out
-}
-
-define i64 @cntsh() {
-; CHECK-LABEL: @cntsh(
-; CHECK-NEXT:    [[OUT:%.*]] = call i64 @llvm.aarch64.sme.cntsh()
-; CHECK-NEXT:    ret i64 [[OUT]]
-;
-; CHECK-STREAMING-LABEL: @cntsh(
-; CHECK-STREAMING-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-STREAMING-NEXT:    [[OUT:%.*]] = shl nuw i64 [[TMP1]], 3
-; CHECK-STREAMING-NEXT:    ret i64 [[OUT]]
-;
-  %out = call i64 @llvm.aarch64.sme.cntsh()
-  ret i64 %out
-}
-
-define i64 @cntsw() {
-; CHECK-LABEL: @cntsw(
-; CHECK-NEXT:    [[OUT:%.*]] = call i64 @llvm.aarch64.sme.cntsw()
-; CHECK-NEXT:    ret i64 [[OUT]]
-;
-; CHECK-STREAMING-LABEL: @cntsw(
-; CHECK-STREAMING-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-STREAMING-NEXT:    [[OUT:%.*]] = shl nuw i64 [[TMP1]], 2
-; CHECK-STREAMING-NEXT:    ret i64 [[OUT]]
-;
-  %out = call i64 @llvm.aarch64.sme.cntsw()
-  ret i64 %out
-}
-
 define i64 @cntsd() {
 ; CHECK-LABEL: @cntsd(
 ; CHECK-NEXT:    [[OUT:%.*]] = call i64 @llvm.aarch64.sme.cntsd()
@@ -61,8 +19,5 @@ define i64 @cntsd() {
   ret i64 %out
 }
 
-declare i64 @llvm.aarch64.sve.cntsb()
-declare i64 @llvm.aarch64.sve.cntsh()
-declare i64 @llvm.aarch64.sve.cntsw()
 declare i64 @llvm.aarch64.sve.cntsd()
 

>From 191a4de7c1bc6a2cde664572709c667a57643715 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaugh...@arm.com>
Date: Mon, 1 Sep 2025 13:04:42 +0000
Subject: [PATCH 4/7] - Remove cnts[b,h,w] intrinsics from MLIR and fix tests -
 Remove ZACount class from arm_sme.td

---
 clang/include/clang/Basic/arm_sme.td          | 13 +++----
 .../Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td   |  3 --
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  | 36 ++++++++++++-------
 .../ArmSMEToLLVM/arm-sme-to-llvm.mlir         | 17 ++++++---
 mlir/test/Target/LLVMIR/arm-sme-invalid.mlir  |  2 +-
 mlir/test/Target/LLVMIR/arm-sme.mlir          |  6 ----
 6 files changed, 40 insertions(+), 37 deletions(-)

diff --git a/clang/include/clang/Basic/arm_sme.td 
b/clang/include/clang/Basic/arm_sme.td
index f853122994497..5f6a6eaab80a3 100644
--- a/clang/include/clang/Basic/arm_sme.td
+++ b/clang/include/clang/Basic/arm_sme.td
@@ -156,15 +156,10 @@ let SMETargetGuard = "sme2p1" in {
 
////////////////////////////////////////////////////////////////////////////////
 // SME - Counting elements in a streaming vector
 
-multiclass ZACount<string intr, string n_suffix> {
-  def NAME : SInst<"sv"#n_suffix, "nv", "", MergeNone,
-                   intr, [IsOverloadNone, IsStreamingCompatible]>;
-}
-
-defm SVCNTSB : ZACount<"", "cntsb">;
-defm SVCNTSH : ZACount<"", "cntsh">;
-defm SVCNTSW : ZACount<"", "cntsw">;
-defm SVCNTSD : ZACount<"aarch64_sme_cntsd", "cntsd">;
+def SVCNTSB : SInst<"svcntsb", "nv", "", MergeNone, "", [IsOverloadNone, 
IsStreamingCompatible]>;
+def SVCNTSH : SInst<"svcntsh", "nv", "", MergeNone, "", [IsOverloadNone, 
IsStreamingCompatible]>;
+def SVCNTSW : SInst<"svcntsw", "nv", "", MergeNone, "", [IsOverloadNone, 
IsStreamingCompatible]>;
+def SVCNTSD : SInst<"svcntsd", "nv", "", MergeNone, "aarch64_sme_cntsd", 
[IsOverloadNone, IsStreamingCompatible]>;
 
 
////////////////////////////////////////////////////////////////////////////////
 // SME - ADDHA/ADDVA
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td 
b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index 06fb8511774e8..4d19fa5415ef0 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -201,9 +201,6 @@ class ArmSME_IntrCountOp<string mnemonic>
                     /*traits*/[PredOpTrait<"`res` is i64", TypeIsPred<"res", 
I64>>],
                     /*numResults=*/1, /*overloadedResults=*/[]>;
 
-def LLVM_aarch64_sme_cntsb : ArmSME_IntrCountOp<"cntsb">;
-def LLVM_aarch64_sme_cntsh : ArmSME_IntrCountOp<"cntsh">;
-def LLVM_aarch64_sme_cntsw : ArmSME_IntrCountOp<"cntsw">;
 def LLVM_aarch64_sme_cntsd : ArmSME_IntrCountOp<"cntsd">;
 
 #endif // ARMSME_INTRINSIC_OPS
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp 
b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 8a2e3b639aaa7..6b795b18211b2 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -822,7 +822,7 @@ struct OuterProductWideningOpConversion
   }
 };
 
-/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
+/// Lower `arm_sme.streaming_vl` to SME CNTSD intrinsic.
 ///
 /// Example:
 ///
@@ -830,8 +830,10 @@ struct OuterProductWideningOpConversion
 ///
 /// is converted to:
 ///
-///   %cnt = "arm_sme.intr.cntsh"() : () -> i64
-///   %0 = arith.index_cast %cnt : i64 to index
+///   %cnt = "arm_sme.intr.cntsd"() : () -> i64
+///   %0 = arith.constant 4 : i64
+///   %1 = arith.muli %cnt, %0 : i64
+///   %2 = arith.index_cast %1 : i64 to index
 ///
 struct StreamingVLOpConversion
     : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
@@ -845,15 +847,25 @@ struct StreamingVLOpConversion
     auto loc = streamingVlOp.getLoc();
     auto i64Type = rewriter.getI64Type();
     auto *intrOp = [&]() -> Operation * {
+      auto cntsd = arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
       switch (streamingVlOp.getTypeSize()) {
-      case arm_sme::TypeSize::Byte:
-        return arm_sme::aarch64_sme_cntsb::create(rewriter, loc, i64Type);
-      case arm_sme::TypeSize::Half:
-        return arm_sme::aarch64_sme_cntsh::create(rewriter, loc, i64Type);
-      case arm_sme::TypeSize::Word:
-        return arm_sme::aarch64_sme_cntsw::create(rewriter, loc, i64Type);
+      case arm_sme::TypeSize::Byte: {
+        auto mul = arith::ConstantIndexOp::create(rewriter, loc, 8);
+        auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul);
+        return arith::MulIOp::create(rewriter, loc, cntsd, mul64);
+      }
+      case arm_sme::TypeSize::Half: {
+        auto mul = arith::ConstantIndexOp::create(rewriter, loc, 4);
+        auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul);
+        return arith::MulIOp::create(rewriter, loc, cntsd, mul64);
+      }
+      case arm_sme::TypeSize::Word: {
+        auto mul = arith::ConstantIndexOp::create(rewriter, loc, 2);
+        auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul);
+        return arith::MulIOp::create(rewriter, loc, cntsd, mul64);
+      }
       case arm_sme::TypeSize::Double:
-        return arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
+        return cntsd;
       }
       llvm_unreachable("unknown type size in StreamingVLOpConversion");
     }();
@@ -964,9 +976,7 @@ void 
mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
       arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
       arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
       arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
-      arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
-      arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
-      arm_sme::aarch64_sme_cntsd>();
+      arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsd>();
   target.addLegalDialect<arith::ArithDialect,
                          /* The following are used to lower tile spills/fills 
*/
                          vector::VectorDialect, scf::SCFDialect,
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir 
b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index 6a4d77e86ab58..4f3c1dad24b76 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -586,9 +586,10 @@ func.func 
@arm_sme_extract_tile_slice_ver_i128(%tile_slice_index : index) -> vec
 // -----
 
 // CHECK-LABEL: @arm_sme_streaming_vl_bytes
-// CHECK: %[[COUNT:.*]] = "arm_sme.intr.cntsb"() : () -> i64
-// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index
-// CHECK: return %[[INDEX_COUNT]] : index
+// CHECK: %[[CONST:.*]] = arith.constant 8 : i64
+// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
+// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64
+// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index
 func.func @arm_sme_streaming_vl_bytes() -> index {
   %svl_b = arm_sme.streaming_vl <byte>
   return %svl_b : index
@@ -597,7 +598,10 @@ func.func @arm_sme_streaming_vl_bytes() -> index {
 // -----
 
 // CHECK-LABEL: @arm_sme_streaming_vl_half_words
-// CHECK: "arm_sme.intr.cntsh"() : () -> i64
+// CHECK: %[[CONST:.*]] = arith.constant 4 : i64
+// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
+// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64
+// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index
 func.func @arm_sme_streaming_vl_half_words() -> index {
   %svl_h = arm_sme.streaming_vl <half>
   return %svl_h : index
@@ -606,7 +610,10 @@ func.func @arm_sme_streaming_vl_half_words() -> index {
 // -----
 
 // CHECK-LABEL: @arm_sme_streaming_vl_words
-// CHECK: "arm_sme.intr.cntsw"() : () -> i64
+// CHECK: %[[CONST:.*]] = arith.constant 2 : i64
+// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
+// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64
+// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index
 func.func @arm_sme_streaming_vl_words() -> index {
   %svl_w = arm_sme.streaming_vl <word>
   return %svl_w : index
diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir 
b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
index 14821da838726..6f5b1d8c5d93d 100644
--- a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
@@ -36,6 +36,6 @@ llvm.func @arm_sme_tile_slice_to_vector_invalid_element_types(
 
 llvm.func @arm_sme_streaming_vl_invalid_return_type() -> i32 {
   // expected-error @+1 {{failed to verify that `res` is i64}}
-  %res = "arm_sme.intr.cntsb"() : () -> i32
+  %res = "arm_sme.intr.cntsd"() : () -> i32
   llvm.return %res : i32
 }
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir 
b/mlir/test/Target/LLVMIR/arm-sme.mlir
index aedb6730b06bb..0a13a75618a23 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -419,12 +419,6 @@ llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : 
i32,
 // -----
 
 llvm.func @arm_sme_streaming_vl() {
-  // CHECK: call i64 @llvm.aarch64.sme.cntsb()
-  %svl_b = "arm_sme.intr.cntsb"() : () -> i64
-  // CHECK: call i64 @llvm.aarch64.sme.cntsh()
-  %svl_h = "arm_sme.intr.cntsh"() : () -> i64
-  // CHECK: call i64 @llvm.aarch64.sme.cntsw()
-  %svl_w = "arm_sme.intr.cntsw"() : () -> i64
   // CHECK: call i64 @llvm.aarch64.sme.cntsd()
   %svl_d = "arm_sme.intr.cntsd"() : () -> i64
   llvm.return

>From e2d91e5e53043b916961940a01f2d65e4ac7b752 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaugh...@arm.com>
Date: Thu, 4 Sep 2025 10:23:08 +0000
Subject: [PATCH 5/7] - Remove lambda from StreamingVLOpConversion - Add
 getSizeInBytes helper

---
 clang/lib/CodeGen/TargetBuiltins/ARM.cpp      | 17 ++++------
 .../Target/AArch64/AArch64ISelDAGToDAG.cpp    |  3 ++
 .../AArch64/AArch64TargetTransformInfo.cpp    |  6 ++--
 .../include/mlir/Dialect/ArmSME/Utils/Utils.h |  3 ++
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  | 32 ++++---------------
 mlir/lib/Dialect/ArmSME/IR/Utils.cpp          | 15 +++++++++
 .../ArmSMEToLLVM/arm-sme-to-llvm.mlir         | 18 +++++------
 7 files changed, 46 insertions(+), 48 deletions(-)

diff --git a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp 
b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
index 217232db44b6f..de1bdb335469d 100644
--- a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
@@ -4919,25 +4919,20 @@ Value 
*CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID,
   // Handle builtins which require their multi-vector operands to be swapped
   swapCommutativeSMEOperands(BuiltinID, Ops);
 
-  auto isCntsBuiltin = [&](int64_t &Mul) {
+  auto isCntsBuiltin = [&]() {
     switch (BuiltinID) {
     default:
-      Mul = 0;
-      return false;
+      return 0;
     case SME::BI__builtin_sme_svcntsb:
-      Mul = 8;
-      return true;
+      return 8;
     case SME::BI__builtin_sme_svcntsh:
-      Mul = 4;
-      return true;
+      return 4;
     case SME::BI__builtin_sme_svcntsw:
-      Mul = 2;
-      return true;
+      return 2;
     }
   };
 
-  int64_t Mul = 0;
-  if (isCntsBuiltin(Mul)) {
+  if (auto Mul = isCntsBuiltin()) {
     llvm::Value *Cntd =
         Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsd));
     return Builder.CreateMul(Cntd, llvm::ConstantInt::get(Int64Ty, Mul),
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp 
b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 4e8255bab9437..8af10ef8dadc9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -940,6 +940,9 @@ bool AArch64DAGToDAGISel::SelectRDVLImm(SDValue N, SDValue 
&Imm) {
   return false;
 }
 
+// Given cntsd = (rdsvl, #1) >> 3, attempt to return a suitable multiplier
+// for RDSVL to calculate the streaming vector length in bytes * N. i.e.
+//   rdsvl, #(ShlImm - 3)
 template <signed Low, signed High>
 bool AArch64DAGToDAGISel::SelectRDSVLShiftImm(SDValue N, SDValue &Imm) {
   if (!isa<ConstantSDNode>(N))
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp 
b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 38958796e2fe1..d4c7cb11a70a3 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -2102,8 +2102,8 @@ instCombineSVECntElts(InstCombiner &IC, IntrinsicInst 
&II, unsigned NumElts) {
 }
 
 static std::optional<Instruction *>
-instCombineSMECntsElts(InstCombiner &IC, IntrinsicInst &II,
-                       const AArch64Subtarget *ST) {
+instCombineSMECntsd(InstCombiner &IC, IntrinsicInst &II,
+                    const AArch64Subtarget *ST) {
   if (!ST->isStreaming())
     return std::nullopt;
 
@@ -2825,7 +2825,7 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
   case Intrinsic::aarch64_sve_cntb:
     return instCombineSVECntElts(IC, II, 16);
   case Intrinsic::aarch64_sme_cntsd:
-    return instCombineSMECntsElts(IC, II, ST);
+    return instCombineSMECntsd(IC, II, ST);
   case Intrinsic::aarch64_sve_ptest_any:
   case Intrinsic::aarch64_sve_ptest_first:
   case Intrinsic::aarch64_sve_ptest_last:
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h 
b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..b57b27de4e1de 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -32,6 +32,9 @@ namespace mlir::arm_sme {
 
 constexpr unsigned MinStreamingVectorLengthInBits = 128;
 
+/// Return the size represented by arm_sme::TypeSize in bytes.
+unsigned getSizeInBytes(TypeSize type);
+
 /// Return minimum number of elements for the given element `type` in
 /// a vector of SVL bits.
 unsigned getSMETileSliceMinNumElts(Type type);
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp 
b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 6b795b18211b2..a36f8f09ceada 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -846,31 +846,13 @@ struct StreamingVLOpConversion
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = streamingVlOp.getLoc();
     auto i64Type = rewriter.getI64Type();
-    auto *intrOp = [&]() -> Operation * {
-      auto cntsd = arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
-      switch (streamingVlOp.getTypeSize()) {
-      case arm_sme::TypeSize::Byte: {
-        auto mul = arith::ConstantIndexOp::create(rewriter, loc, 8);
-        auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul);
-        return arith::MulIOp::create(rewriter, loc, cntsd, mul64);
-      }
-      case arm_sme::TypeSize::Half: {
-        auto mul = arith::ConstantIndexOp::create(rewriter, loc, 4);
-        auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul);
-        return arith::MulIOp::create(rewriter, loc, cntsd, mul64);
-      }
-      case arm_sme::TypeSize::Word: {
-        auto mul = arith::ConstantIndexOp::create(rewriter, loc, 2);
-        auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul);
-        return arith::MulIOp::create(rewriter, loc, cntsd, mul64);
-      }
-      case arm_sme::TypeSize::Double:
-        return cntsd;
-      }
-      llvm_unreachable("unknown type size in StreamingVLOpConversion");
-    }();
-    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
-        streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
+    auto cntsd = arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
+    auto cntsdIdx = arith::IndexCastOp::create(rewriter, loc,
+                                               rewriter.getIndexType(), cntsd);
+    auto scale = arith::ConstantIndexOp::create(
+        rewriter, loc,
+        8 / arm_sme::getSizeInBytes(streamingVlOp.getTypeSize()));
+    rewriter.replaceOpWithNewOp<arith::MulIOp>(streamingVlOp, cntsdIdx, scale);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp 
b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index e5e1312f0eb04..92f4e4f63c200 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -14,6 +14,21 @@
 
 namespace mlir::arm_sme {
 
+unsigned getSizeInBytes(TypeSize type) {
+  switch (type) {
+  case arm_sme::TypeSize::Byte:
+    return 1;
+  case arm_sme::TypeSize::Half:
+    return 2;
+  case arm_sme::TypeSize::Word:
+    return 4;
+  case arm_sme::TypeSize::Double:
+    return 8;
+  default:
+    llvm_unreachable("unknown type size");
+  }
+}
+
 unsigned getSMETileSliceMinNumElts(Type type) {
   assert(isValidSMETileElementType(type) && "invalid tile type!");
   return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir 
b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index 4f3c1dad24b76..fd8910265cd89 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -586,10 +586,10 @@ func.func 
@arm_sme_extract_tile_slice_ver_i128(%tile_slice_index : index) -> vec
 // -----
 
 // CHECK-LABEL: @arm_sme_streaming_vl_bytes
-// CHECK: %[[CONST:.*]] = arith.constant 8 : i64
+// CHECK: %[[CONST:.*]] = arith.constant 8 : index
 // CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
-// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64
-// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index
+// CHECK: %[[CNTSD_IDX:.*]] = arith.index_cast %[[CNTSD]] : i64 to index
+// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD_IDX]], %[[CONST]] : index
 func.func @arm_sme_streaming_vl_bytes() -> index {
   %svl_b = arm_sme.streaming_vl <byte>
   return %svl_b : index
@@ -598,10 +598,10 @@ func.func @arm_sme_streaming_vl_bytes() -> index {
 // -----
 
 // CHECK-LABEL: @arm_sme_streaming_vl_half_words
-// CHECK: %[[CONST:.*]] = arith.constant 4 : i64
+// CHECK: %[[CONST:.*]] = arith.constant 4 : index
 // CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
-// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64
-// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index
+// CHECK: %[[CNTSD_IDX:.*]] = arith.index_cast %[[CNTSD]] : i64 to index
+// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD_IDX]], %[[CONST]] : index
 func.func @arm_sme_streaming_vl_half_words() -> index {
   %svl_h = arm_sme.streaming_vl <half>
   return %svl_h : index
@@ -610,10 +610,10 @@ func.func @arm_sme_streaming_vl_half_words() -> index {
 // -----
 
 // CHECK-LABEL: @arm_sme_streaming_vl_words
-// CHECK: %[[CONST:.*]] = arith.constant 2 : i64
+// CHECK: %[[CONST:.*]] = arith.constant 2 : index
 // CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
-// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64
-// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index
+// CHECK: %[[CNTSD_IDX:.*]] = arith.index_cast %[[CNTSD]] : i64 to index
+// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD_IDX]], %[[CONST]] : index
 func.func @arm_sme_streaming_vl_words() -> index {
   %svl_w = arm_sme.streaming_vl <word>
   return %svl_w : index

>From f5de33d9e1df0155b96fcc54543bc754ab044907 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaugh...@arm.com>
Date: Fri, 5 Sep 2025 10:27:14 +0000
Subject: [PATCH 6/7] - Fix 'default label in switch' build failure

---
 mlir/lib/Dialect/ArmSME/IR/Utils.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp 
b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 92f4e4f63c200..e64ae42204fa0 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -24,9 +24,9 @@ unsigned getSizeInBytes(TypeSize type) {
     return 4;
   case arm_sme::TypeSize::Double:
     return 8;
-  default:
-    llvm_unreachable("unknown type size");
   }
+  llvm_unreachable("unknown type size");
+  return 0;
 }
 
 unsigned getSMETileSliceMinNumElts(Type type) {

>From a91144e0f096bcb49bd46eb5286a666dbd45696b Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaugh...@arm.com>
Date: Fri, 5 Sep 2025 14:02:45 +0000
Subject: [PATCH 7/7] - Fix comments in AArch64ISelDAGToDAG & ArmSMEToLLVM

---
 llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp   | 5 ++---
 mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 6 +++---
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp 
b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 8af10ef8dadc9..8ab313dfed46a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -940,9 +940,8 @@ bool AArch64DAGToDAGISel::SelectRDVLImm(SDValue N, SDValue 
&Imm) {
   return false;
 }
 
-// Given cntsd = (rdsvl, #1) >> 3, attempt to return a suitable multiplier
-// for RDSVL to calculate the streaming vector length in bytes * N. i.e.
-//   rdsvl, #(ShlImm - 3)
+// Given `cntsd = (rdsvl, #1) >> 3`, attempt to return a suitable multiplier
+// for RDSVL to calculate `cntsd << N`, i.e. `rdsvl, #(N - 3)`.
 template <signed Low, signed High>
 bool AArch64DAGToDAGISel::SelectRDSVLShiftImm(SDValue N, SDValue &Imm) {
   if (!isa<ConstantSDNode>(N))
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp 
b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index a36f8f09ceada..033e9ae1f4d4c 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -831,9 +831,9 @@ struct OuterProductWideningOpConversion
 /// is converted to:
 ///
 ///   %cnt = "arm_sme.intr.cntsd"() : () -> i64
-///   %0 = arith.constant 4 : i64
-///   %1 = arith.muli %cnt, %0 : i64
-///   %2 = arith.index_cast %1 : i64 to index
+///   %scale = arith.constant 4 : index
+///   %cntIndex = arith.index_cast %cnt : i64 to index
+///   %0 = arith.muli %cntIndex, %scale : index
 ///
 struct StreamingVLOpConversion
     : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to