https://github.com/kmclaughlin-arm updated https://github.com/llvm/llvm-project/pull/154761
>From 625925797e8e7a76471aeaa01150dbee8cf69de5 Mon Sep 17 00:00:00 2001 From: Kerry McLaughlin <kerry.mclaugh...@arm.com> Date: Wed, 20 Aug 2025 09:33:17 +0000 Subject: [PATCH 1/7] RDSVL tests --- .../CodeGen/AArch64/sme-intrinsics-rdsvl.ll | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll index 5d10d7e13da14..b799f98981520 100644 --- a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll +++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll @@ -40,6 +40,55 @@ define i64 @sme_cntsd() { ret i64 %v } +define i64 @sme_cntsb_mul() { +; CHECK-LABEL: sme_cntsb_mul: +; CHECK: // %bb.0: +; CHECK-NEXT: rdsvl x8, #1 +; CHECK-NEXT: lsl x0, x8, #1 +; CHECK-NEXT: ret + %v = call i64 @llvm.aarch64.sme.cntsb() + %res = mul i64 %v, 2 + ret i64 %res +} + +define i64 @sme_cntsh_mul() { +; CHECK-LABEL: sme_cntsh_mul: +; CHECK: // %bb.0: +; CHECK-NEXT: rdsvl x8, #1 +; CHECK-NEXT: lsr x8, x8, #1 +; CHECK-NEXT: add x0, x8, x8, lsl #2 +; CHECK-NEXT: ret + %v = call i64 @llvm.aarch64.sme.cntsh() + %res = mul i64 %v, 5 + ret i64 %res +} + +define i64 @sme_cntsw_mul() { +; CHECK-LABEL: sme_cntsw_mul: +; CHECK: // %bb.0: +; CHECK-NEXT: rdsvl x8, #1 +; CHECK-NEXT: lsr x8, x8, #2 +; CHECK-NEXT: lsl x9, x8, #3 +; CHECK-NEXT: sub x0, x9, x8 +; CHECK-NEXT: ret + %v = call i64 @llvm.aarch64.sme.cntsw() + %res = mul i64 %v, 7 + ret i64 %res +} + +define i64 @sme_cntsd_mul() { +; CHECK-LABEL: sme_cntsd_mul: +; CHECK: // %bb.0: +; CHECK-NEXT: rdsvl x8, #1 +; CHECK-NEXT: lsr x8, x8, #3 +; CHECK-NEXT: add x8, x8, x8, lsl #1 +; CHECK-NEXT: lsl x0, x8, #2 +; CHECK-NEXT: ret + %v = call i64 @llvm.aarch64.sme.cntsd() + %res = mul i64 %v, 12 + ret i64 %res +} + declare i64 @llvm.aarch64.sme.cntsb() declare i64 @llvm.aarch64.sme.cntsh() declare i64 @llvm.aarch64.sme.cntsw() >From fd8ff8ab97a5d876811b11c43d1e6d6c19100399 Mon Sep 17 00:00:00 2001 From: Kerry McLaughlin <kerry.mclaugh...@arm.com> Date: Wed, 13 Aug 2025 14:13:12 +0000 Subject: [PATCH 2/7] [AArch64][SME] Improve codegen for aarch64.sme.cnts* when not in streaming mode --- .../Target/AArch64/AArch64ISelLowering.cpp | 36 ++++++++++--------- .../lib/Target/AArch64/AArch64SMEInstrInfo.td | 23 ++++++++++++ .../CodeGen/AArch64/sme-intrinsics-rdsvl.ll | 20 +++++------ 3 files changed, 51 insertions(+), 28 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 23328ed57fb36..6d65d6354b462 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -6266,25 +6266,26 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, case Intrinsic::aarch64_sve_clz: return DAG.getNode(AArch64ISD::CTLZ_MERGE_PASSTHRU, DL, Op.getValueType(), Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); - case Intrinsic::aarch64_sme_cntsb: - return DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(), - DAG.getConstant(1, DL, MVT::i32)); + case Intrinsic::aarch64_sme_cntsb: { + SDValue Cntd = DAG.getNode( + ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(), + DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64)); + return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd, + DAG.getConstant(8, DL, MVT::i64)); + } case Intrinsic::aarch64_sme_cntsh: { - SDValue One = DAG.getConstant(1, DL, MVT::i32); - SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(), One); - return DAG.getNode(ISD::SRL, DL, Op.getValueType(), Bytes, One); + SDValue Cntd = DAG.getNode( + ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(), + DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64)); + return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd, + DAG.getConstant(4, DL, MVT::i64)); } case Intrinsic::aarch64_sme_cntsw: { - SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(), - DAG.getConstant(1, DL, MVT::i32)); - return DAG.getNode(ISD::SRL, DL, Op.getValueType(), Bytes, - DAG.getConstant(2, DL, MVT::i32)); - } - case Intrinsic::aarch64_sme_cntsd: { - SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(), - DAG.getConstant(1, DL, MVT::i32)); - return DAG.getNode(ISD::SRL, DL, Op.getValueType(), Bytes, - DAG.getConstant(3, DL, MVT::i32)); + SDValue Cntd = DAG.getNode( + ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(), + DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64)); + return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd, + DAG.getConstant(2, DL, MVT::i64)); } case Intrinsic::aarch64_sve_cnt: { SDValue Data = Op.getOperand(3); @@ -19200,6 +19201,9 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG, if (ConstValue.sge(1) && ConstValue.sle(16)) return SDValue(); + if (getIntrinsicID(N0.getNode()) == Intrinsic::aarch64_sme_cntsd) + return SDValue(); + // Multiplication of a power of two plus/minus one can be done more // cheaply as shift+add/sub. For now, this is true unilaterally. If // future CPUs have a cheaper MADD instruction, this may need to be diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td index 0d8cb3a76d0be..aecfe37cad823 100644 --- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td @@ -127,12 +127,35 @@ def : Pat<(AArch64_requires_za_save), (RequiresZASavePseudo)>; def SDT_AArch64RDSVL : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>; def AArch64rdsvl : SDNode<"AArch64ISD::RDSVL", SDT_AArch64RDSVL>; +def sme_cntsb_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 8>">; +def sme_cntsh_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 4>">; +def sme_cntsw_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 2>">; +def sme_cntsd_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 1>">; + let Predicates = [HasSMEandIsNonStreamingSafe] in { def RDSVLI_XI : sve_int_read_vl_a<0b0, 0b11111, "rdsvl", /*streaming_sve=*/0b1>; def ADDSPL_XXI : sve_int_arith_vl<0b1, "addspl", /*streaming_sve=*/0b1>; def ADDSVL_XXI : sve_int_arith_vl<0b0, "addsvl", /*streaming_sve=*/0b1>; def : Pat<(AArch64rdsvl (i32 simm6_32b:$imm)), (RDSVLI_XI simm6_32b:$imm)>; + +// e.g. cntsb() * imm +def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsb_imm i64:$imm))), + (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm))>; +def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsh_imm i64:$imm))), + (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 1, 63)>; +def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsw_imm i64:$imm))), + (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 2, 63)>; +def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsd_imm i64:$imm))), + (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 3, 63)>; + +// e.g. cntsb() +def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 1))), (UBFMXri (RDSVLI_XI 1), 2, 63)>; +def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 2))), (UBFMXri (RDSVLI_XI 1), 1, 63)>; +def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 3))), (RDSVLI_XI 1)>; + +// Generic pattern for cntsd (RDSVL #1 >> 3) +def : Pat<(i64 (int_aarch64_sme_cntsd)), (UBFMXri (RDSVLI_XI 1), 3, 63)>; } let Predicates = [HasSME] in { diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll index b799f98981520..8253db1d488e7 100644 --- a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll +++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll @@ -44,7 +44,8 @@ define i64 @sme_cntsb_mul() { ; CHECK-LABEL: sme_cntsb_mul: ; CHECK: // %bb.0: ; CHECK-NEXT: rdsvl x8, #1 -; CHECK-NEXT: lsl x0, x8, #1 +; CHECK-NEXT: lsr x8, x8, #3 +; CHECK-NEXT: lsl x0, x8, #4 ; CHECK-NEXT: ret %v = call i64 @llvm.aarch64.sme.cntsb() %res = mul i64 %v, 2 @@ -54,9 +55,8 @@ define i64 @sme_cntsb_mul() { define i64 @sme_cntsh_mul() { ; CHECK-LABEL: sme_cntsh_mul: ; CHECK: // %bb.0: -; CHECK-NEXT: rdsvl x8, #1 -; CHECK-NEXT: lsr x8, x8, #1 -; CHECK-NEXT: add x0, x8, x8, lsl #2 +; CHECK-NEXT: rdsvl x8, #5 +; CHECK-NEXT: lsr x0, x8, #1 ; CHECK-NEXT: ret %v = call i64 @llvm.aarch64.sme.cntsh() %res = mul i64 %v, 5 @@ -66,10 +66,8 @@ define i64 @sme_cntsh_mul() { define i64 @sme_cntsw_mul() { ; CHECK-LABEL: sme_cntsw_mul: ; CHECK: // %bb.0: -; CHECK-NEXT: rdsvl x8, #1 -; CHECK-NEXT: lsr x8, x8, #2 -; CHECK-NEXT: lsl x9, x8, #3 -; CHECK-NEXT: sub x0, x9, x8 +; CHECK-NEXT: rdsvl x8, #7 +; CHECK-NEXT: lsr x0, x8, #2 ; CHECK-NEXT: ret %v = call i64 @llvm.aarch64.sme.cntsw() %res = mul i64 %v, 7 @@ -79,10 +77,8 @@ define i64 @sme_cntsw_mul() { define i64 @sme_cntsd_mul() { ; CHECK-LABEL: sme_cntsd_mul: ; CHECK: // %bb.0: -; CHECK-NEXT: rdsvl x8, #1 -; CHECK-NEXT: lsr x8, x8, #3 -; CHECK-NEXT: add x8, x8, x8, lsl #1 -; CHECK-NEXT: lsl x0, x8, #2 +; CHECK-NEXT: rdsvl x8, #3 +; CHECK-NEXT: lsr x0, x8, #1 ; CHECK-NEXT: ret %v = call i64 @llvm.aarch64.sme.cntsd() %res = mul i64 %v, 12 >From 65da7184ac4d9dfe53989f233e77206c0767215e Mon Sep 17 00:00:00 2001 From: Kerry McLaughlin <kerry.mclaugh...@arm.com> Date: Thu, 21 Aug 2025 14:04:18 +0000 Subject: [PATCH 3/7] - Replace cnts[b|h|w] builtins with cntsd intrinsic in Clang - Remove cnts[b|h|w] intrinsics in LLVM - Add patterns for cntsd --- clang/include/clang/Basic/arm_sme.td | 15 +++-- clang/lib/CodeGen/TargetBuiltins/ARM.cpp | 30 +++++++++- .../AArch64/sme-intrinsics/acle_sme_cnt.c | 42 ++++++++------ llvm/include/llvm/IR/IntrinsicsAArch64.td | 9 +-- .../Target/AArch64/AArch64ISelDAGToDAG.cpp | 20 +++++++ .../Target/AArch64/AArch64ISelLowering.cpp | 21 ------- .../lib/Target/AArch64/AArch64SMEInstrInfo.td | 28 ++++----- .../AArch64/AArch64TargetTransformInfo.cpp | 16 ++---- .../CodeGen/AArch64/sme-intrinsics-rdsvl.ll | 57 ++++++++++--------- .../sme-streaming-interface-remarks.ll | 4 +- .../AArch64/sme-streaming-interface.ll | 7 ++- .../sme-intrinsic-opts-counting-elems.ll | 45 --------------- 12 files changed, 136 insertions(+), 158 deletions(-) diff --git a/clang/include/clang/Basic/arm_sme.td b/clang/include/clang/Basic/arm_sme.td index a4eb92e76968c..f853122994497 100644 --- a/clang/include/clang/Basic/arm_sme.td +++ b/clang/include/clang/Basic/arm_sme.td @@ -156,16 +156,15 @@ let SMETargetGuard = "sme2p1" in { //////////////////////////////////////////////////////////////////////////////// // SME - Counting elements in a streaming vector -multiclass ZACount<string n_suffix> { - def NAME : SInst<"sv" # n_suffix, "nv", "", MergeNone, - "aarch64_sme_" # n_suffix, - [IsOverloadNone, IsStreamingCompatible]>; +multiclass ZACount<string intr, string n_suffix> { + def NAME : SInst<"sv"#n_suffix, "nv", "", MergeNone, + intr, [IsOverloadNone, IsStreamingCompatible]>; } -defm SVCNTSB : ZACount<"cntsb">; -defm SVCNTSH : ZACount<"cntsh">; -defm SVCNTSW : ZACount<"cntsw">; -defm SVCNTSD : ZACount<"cntsd">; +defm SVCNTSB : ZACount<"", "cntsb">; +defm SVCNTSH : ZACount<"", "cntsh">; +defm SVCNTSW : ZACount<"", "cntsw">; +defm SVCNTSD : ZACount<"aarch64_sme_cntsd", "cntsd">; //////////////////////////////////////////////////////////////////////////////// // SME - ADDHA/ADDVA diff --git a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp index 60413e7b18e85..217232db44b6f 100644 --- a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp +++ b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp @@ -4304,9 +4304,10 @@ Value *CodeGenFunction::EmitSMELd1St1(const SVETypeFlags &TypeFlags, // size in bytes. if (Ops.size() == 5) { Function *StreamingVectorLength = - CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsb); + CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsd); llvm::Value *StreamingVectorLengthCall = - Builder.CreateCall(StreamingVectorLength); + Builder.CreateMul(Builder.CreateCall(StreamingVectorLength), + llvm::ConstantInt::get(Int64Ty, 8), "svl"); llvm::Value *Mulvl = Builder.CreateMul(StreamingVectorLengthCall, Ops[4], "mulvl"); // The type of the ptr parameter is void *, so use Int8Ty here. @@ -4918,6 +4919,31 @@ Value *CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID, // Handle builtins which require their multi-vector operands to be swapped swapCommutativeSMEOperands(BuiltinID, Ops); + auto isCntsBuiltin = [&](int64_t &Mul) { + switch (BuiltinID) { + default: + Mul = 0; + return false; + case SME::BI__builtin_sme_svcntsb: + Mul = 8; + return true; + case SME::BI__builtin_sme_svcntsh: + Mul = 4; + return true; + case SME::BI__builtin_sme_svcntsw: + Mul = 2; + return true; + } + }; + + int64_t Mul = 0; + if (isCntsBuiltin(Mul)) { + llvm::Value *Cntd = + Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsd)); + return Builder.CreateMul(Cntd, llvm::ConstantInt::get(Int64Ty, Mul), + "mulsvl", /* HasNUW */ true, /* HasNSW */ true); + } + // Should not happen! if (Builtin->LLVMIntrinsic == 0) return nullptr; diff --git a/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_cnt.c b/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_cnt.c index c0b3e1a06b0ff..049c1742e5a9d 100644 --- a/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_cnt.c +++ b/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_cnt.c @@ -6,49 +6,55 @@ #include <arm_sme.h> -// CHECK-C-LABEL: define dso_local i64 @test_svcntsb( +// CHECK-C-LABEL: define dso_local range(i64 0, -9223372036854775808) i64 @test_svcntsb( // CHECK-C-SAME: ) local_unnamed_addr #[[ATTR0:[0-9]+]] { // CHECK-C-NEXT: entry: -// CHECK-C-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb() -// CHECK-C-NEXT: ret i64 [[TMP0]] +// CHECK-C-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd() +// CHECK-C-NEXT: [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 3 +// CHECK-C-NEXT: ret i64 [[MULSVL]] // -// CHECK-CXX-LABEL: define dso_local noundef i64 @_Z12test_svcntsbv( +// CHECK-CXX-LABEL: define dso_local noundef range(i64 0, -9223372036854775808) i64 @_Z12test_svcntsbv( // CHECK-CXX-SAME: ) local_unnamed_addr #[[ATTR0:[0-9]+]] { // CHECK-CXX-NEXT: entry: -// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb() -// CHECK-CXX-NEXT: ret i64 [[TMP0]] +// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd() +// CHECK-CXX-NEXT: [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 3 +// CHECK-CXX-NEXT: ret i64 [[MULSVL]] // uint64_t test_svcntsb() { return svcntsb(); } -// CHECK-C-LABEL: define dso_local i64 @test_svcntsh( +// CHECK-C-LABEL: define dso_local range(i64 0, -9223372036854775808) i64 @test_svcntsh( // CHECK-C-SAME: ) local_unnamed_addr #[[ATTR0]] { // CHECK-C-NEXT: entry: -// CHECK-C-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsh() -// CHECK-C-NEXT: ret i64 [[TMP0]] +// CHECK-C-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd() +// CHECK-C-NEXT: [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 2 +// CHECK-C-NEXT: ret i64 [[MULSVL]] // -// CHECK-CXX-LABEL: define dso_local noundef i64 @_Z12test_svcntshv( +// CHECK-CXX-LABEL: define dso_local noundef range(i64 0, -9223372036854775808) i64 @_Z12test_svcntshv( // CHECK-CXX-SAME: ) local_unnamed_addr #[[ATTR0]] { // CHECK-CXX-NEXT: entry: -// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsh() -// CHECK-CXX-NEXT: ret i64 [[TMP0]] +// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd() +// CHECK-CXX-NEXT: [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 2 +// CHECK-CXX-NEXT: ret i64 [[MULSVL]] // uint64_t test_svcntsh() { return svcntsh(); } -// CHECK-C-LABEL: define dso_local i64 @test_svcntsw( +// CHECK-C-LABEL: define dso_local range(i64 0, -9223372036854775808) i64 @test_svcntsw( // CHECK-C-SAME: ) local_unnamed_addr #[[ATTR0]] { // CHECK-C-NEXT: entry: -// CHECK-C-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsw() -// CHECK-C-NEXT: ret i64 [[TMP0]] +// CHECK-C-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd() +// CHECK-C-NEXT: [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 1 +// CHECK-C-NEXT: ret i64 [[MULSVL]] // -// CHECK-CXX-LABEL: define dso_local noundef i64 @_Z12test_svcntswv( +// CHECK-CXX-LABEL: define dso_local noundef range(i64 0, -9223372036854775808) i64 @_Z12test_svcntswv( // CHECK-CXX-SAME: ) local_unnamed_addr #[[ATTR0]] { // CHECK-CXX-NEXT: entry: -// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsw() -// CHECK-CXX-NEXT: ret i64 [[TMP0]] +// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd() +// CHECK-CXX-NEXT: [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 1 +// CHECK-CXX-NEXT: ret i64 [[MULSVL]] // uint64_t test_svcntsw() { return svcntsw(); diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td index 6d53bf8b172d8..7c9aef52b3acf 100644 --- a/llvm/include/llvm/IR/IntrinsicsAArch64.td +++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td @@ -3147,13 +3147,8 @@ let TargetPrefix = "aarch64" in { // Counting elements // - class AdvSIMD_SME_CNTSB_Intrinsic - : DefaultAttrsIntrinsic<[llvm_i64_ty], [], [IntrNoMem]>; - - def int_aarch64_sme_cntsb : AdvSIMD_SME_CNTSB_Intrinsic; - def int_aarch64_sme_cntsh : AdvSIMD_SME_CNTSB_Intrinsic; - def int_aarch64_sme_cntsw : AdvSIMD_SME_CNTSB_Intrinsic; - def int_aarch64_sme_cntsd : AdvSIMD_SME_CNTSB_Intrinsic; + def int_aarch64_sme_cntsd + : DefaultAttrsIntrinsic<[llvm_i64_ty], [], [IntrNoMem]>; // // PSTATE Functions diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp index bc786f415b554..4e8255bab9437 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp @@ -71,6 +71,9 @@ class AArch64DAGToDAGISel : public SelectionDAGISel { template <signed Low, signed High, signed Scale> bool SelectRDVLImm(SDValue N, SDValue &Imm); + template <signed Low, signed High> + bool SelectRDSVLShiftImm(SDValue N, SDValue &Imm); + bool SelectArithExtendedRegister(SDValue N, SDValue &Reg, SDValue &Shift); bool SelectArithUXTXRegister(SDValue N, SDValue &Reg, SDValue &Shift); bool SelectArithImmed(SDValue N, SDValue &Val, SDValue &Shift); @@ -937,6 +940,23 @@ bool AArch64DAGToDAGISel::SelectRDVLImm(SDValue N, SDValue &Imm) { return false; } +template <signed Low, signed High> +bool AArch64DAGToDAGISel::SelectRDSVLShiftImm(SDValue N, SDValue &Imm) { + if (!isa<ConstantSDNode>(N)) + return false; + + int64_t ShlImm = cast<ConstantSDNode>(N)->getSExtValue(); + if (ShlImm >= 3) { + int64_t MulImm = 1 << (ShlImm - 3); + if (MulImm >= Low && MulImm <= High) { + Imm = CurDAG->getSignedTargetConstant(MulImm, SDLoc(N), MVT::i32); + return true; + } + } + + return false; +} + /// SelectArithExtendedRegister - Select a "extended register" operand. This /// operand folds in an extend followed by an optional left shift. bool AArch64DAGToDAGISel::SelectArithExtendedRegister(SDValue N, SDValue &Reg, diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 6d65d6354b462..08f0ae0b2f783 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -6266,27 +6266,6 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, case Intrinsic::aarch64_sve_clz: return DAG.getNode(AArch64ISD::CTLZ_MERGE_PASSTHRU, DL, Op.getValueType(), Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); - case Intrinsic::aarch64_sme_cntsb: { - SDValue Cntd = DAG.getNode( - ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(), - DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64)); - return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd, - DAG.getConstant(8, DL, MVT::i64)); - } - case Intrinsic::aarch64_sme_cntsh: { - SDValue Cntd = DAG.getNode( - ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(), - DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64)); - return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd, - DAG.getConstant(4, DL, MVT::i64)); - } - case Intrinsic::aarch64_sme_cntsw: { - SDValue Cntd = DAG.getNode( - ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(), - DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64)); - return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd, - DAG.getConstant(2, DL, MVT::i64)); - } case Intrinsic::aarch64_sve_cnt: { SDValue Data = Op.getOperand(3); // CTPOP only supports integer operands. diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td index aecfe37cad823..3b27203d45585 100644 --- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td @@ -127,10 +127,12 @@ def : Pat<(AArch64_requires_za_save), (RequiresZASavePseudo)>; def SDT_AArch64RDSVL : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>; def AArch64rdsvl : SDNode<"AArch64ISD::RDSVL", SDT_AArch64RDSVL>; -def sme_cntsb_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 8>">; -def sme_cntsh_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 4>">; -def sme_cntsw_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 2>">; -def sme_cntsd_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 1>">; +def sme_cntsb_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 8>">; +def sme_cntsh_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 4>">; +def sme_cntsw_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 2>">; +def sme_cntsd_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 1>">; + +def sme_cnts_shl_imm : ComplexPattern<i64, 1, "SelectRDSVLShiftImm<1, 31>">; let Predicates = [HasSMEandIsNonStreamingSafe] in { def RDSVLI_XI : sve_int_read_vl_a<0b0, 0b11111, "rdsvl", /*streaming_sve=*/0b1>; @@ -140,21 +142,21 @@ def ADDSVL_XXI : sve_int_arith_vl<0b0, "addsvl", /*streaming_sve=*/0b1>; def : Pat<(AArch64rdsvl (i32 simm6_32b:$imm)), (RDSVLI_XI simm6_32b:$imm)>; // e.g. cntsb() * imm -def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsb_imm i64:$imm))), +def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsb_mul_imm i64:$imm))), (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm))>; -def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsh_imm i64:$imm))), +def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsh_mul_imm i64:$imm))), (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 1, 63)>; -def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsw_imm i64:$imm))), +def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsw_mul_imm i64:$imm))), (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 2, 63)>; -def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsd_imm i64:$imm))), +def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsd_mul_imm i64:$imm))), (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 3, 63)>; -// e.g. cntsb() -def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 1))), (UBFMXri (RDSVLI_XI 1), 2, 63)>; -def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 2))), (UBFMXri (RDSVLI_XI 1), 1, 63)>; -def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 3))), (RDSVLI_XI 1)>; +def : Pat<(i64 (shl (int_aarch64_sme_cntsd), (sme_cnts_shl_imm i64:$imm))), + (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm))>; -// Generic pattern for cntsd (RDSVL #1 >> 3) +// cntsh, cntsw, cntsd +def : Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 2))), (UBFMXri (RDSVLI_XI 1), 1, 63)>; +def : Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 1))), (UBFMXri (RDSVLI_XI 1), 2, 63)>; def : Pat<(i64 (int_aarch64_sme_cntsd)), (UBFMXri (RDSVLI_XI 1), 3, 63)>; } diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 490f6391c15a0..38958796e2fe1 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -2102,15 +2102,15 @@ instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) { } static std::optional<Instruction *> -instCombineSMECntsElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts, +instCombineSMECntsElts(InstCombiner &IC, IntrinsicInst &II, const AArch64Subtarget *ST) { if (!ST->isStreaming()) return std::nullopt; - // In streaming-mode, aarch64_sme_cnts is equivalent to aarch64_sve_cnt + // In streaming-mode, aarch64_sme_cntds is equivalent to aarch64_sve_cntd // with SVEPredPattern::all - Value *Cnt = IC.Builder.CreateElementCount( - II.getType(), ElementCount::getScalable(NumElts)); + Value *Cnt = + IC.Builder.CreateElementCount(II.getType(), ElementCount::getScalable(2)); Cnt->takeName(&II); return IC.replaceInstUsesWith(II, Cnt); } @@ -2825,13 +2825,7 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC, case Intrinsic::aarch64_sve_cntb: return instCombineSVECntElts(IC, II, 16); case Intrinsic::aarch64_sme_cntsd: - return instCombineSMECntsElts(IC, II, 2, ST); - case Intrinsic::aarch64_sme_cntsw: - return instCombineSMECntsElts(IC, II, 4, ST); - case Intrinsic::aarch64_sme_cntsh: - return instCombineSMECntsElts(IC, II, 8, ST); - case Intrinsic::aarch64_sme_cntsb: - return instCombineSMECntsElts(IC, II, 16, ST); + return instCombineSMECntsElts(IC, II, ST); case Intrinsic::aarch64_sve_ptest_any: case Intrinsic::aarch64_sve_ptest_first: case Intrinsic::aarch64_sve_ptest_last: diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll index 8253db1d488e7..86d3e42deae09 100644 --- a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll +++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll @@ -1,54 +1,56 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py ; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -verify-machineinstrs < %s | FileCheck %s -define i64 @sme_cntsb() { -; CHECK-LABEL: sme_cntsb: +define i64 @cntsb() { +; CHECK-LABEL: cntsb: ; CHECK: // %bb.0: ; CHECK-NEXT: rdsvl x0, #1 ; CHECK-NEXT: ret - %v = call i64 @llvm.aarch64.sme.cntsb() - ret i64 %v + %1 = call i64 @llvm.aarch64.sme.cntsd() + %res = shl nuw nsw i64 %1, 3 + ret i64 %res } -define i64 @sme_cntsh() { -; CHECK-LABEL: sme_cntsh: +define i64 @cntsh() { +; CHECK-LABEL: cntsh: ; CHECK: // %bb.0: ; CHECK-NEXT: rdsvl x8, #1 ; CHECK-NEXT: lsr x0, x8, #1 ; CHECK-NEXT: ret - %v = call i64 @llvm.aarch64.sme.cntsh() - ret i64 %v + %1 = call i64 @llvm.aarch64.sme.cntsd() + %res = shl nuw nsw i64 %1, 2 + ret i64 %res } -define i64 @sme_cntsw() { -; CHECK-LABEL: sme_cntsw: +define i64 @cntsw() { +; CHECK-LABEL: cntsw: ; CHECK: // %bb.0: ; CHECK-NEXT: rdsvl x8, #1 ; CHECK-NEXT: lsr x0, x8, #2 ; CHECK-NEXT: ret - %v = call i64 @llvm.aarch64.sme.cntsw() - ret i64 %v + %1 = call i64 @llvm.aarch64.sme.cntsd() + %res = shl nuw nsw i64 %1, 1 + ret i64 %res } -define i64 @sme_cntsd() { -; CHECK-LABEL: sme_cntsd: +define i64 @cntsd() { +; CHECK-LABEL: cntsd: ; CHECK: // %bb.0: ; CHECK-NEXT: rdsvl x8, #1 ; CHECK-NEXT: lsr x0, x8, #3 ; CHECK-NEXT: ret - %v = call i64 @llvm.aarch64.sme.cntsd() - ret i64 %v + %res = call i64 @llvm.aarch64.sme.cntsd() + ret i64 %res } define i64 @sme_cntsb_mul() { ; CHECK-LABEL: sme_cntsb_mul: ; CHECK: // %bb.0: -; CHECK-NEXT: rdsvl x8, #1 -; CHECK-NEXT: lsr x8, x8, #3 -; CHECK-NEXT: lsl x0, x8, #4 +; CHECK-NEXT: rdsvl x0, #2 ; CHECK-NEXT: ret - %v = call i64 @llvm.aarch64.sme.cntsb() - %res = mul i64 %v, 2 + %v = call i64 @llvm.aarch64.sme.cntsd() + %shl = shl nuw nsw i64 %v, 3 + %res = mul i64 %shl, 2 ret i64 %res } @@ -58,8 +60,9 @@ define i64 @sme_cntsh_mul() { ; CHECK-NEXT: rdsvl x8, #5 ; CHECK-NEXT: lsr x0, x8, #1 ; CHECK-NEXT: ret - %v = call i64 @llvm.aarch64.sme.cntsh() - %res = mul i64 %v, 5 + %v = call i64 @llvm.aarch64.sme.cntsd() + %shl = shl nuw nsw i64 %v, 2 + %res = mul i64 %shl, 5 ret i64 %res } @@ -69,8 +72,9 @@ define i64 @sme_cntsw_mul() { ; CHECK-NEXT: rdsvl x8, #7 ; CHECK-NEXT: lsr x0, x8, #2 ; CHECK-NEXT: ret - %v = call i64 @llvm.aarch64.sme.cntsw() - %res = mul i64 %v, 7 + %v = call i64 @llvm.aarch64.sme.cntsd() + %shl = shl nuw nsw i64 %v, 1 + %res = mul i64 %shl, 7 ret i64 %res } @@ -85,7 +89,4 @@ define i64 @sme_cntsd_mul() { ret i64 %res } -declare i64 @llvm.aarch64.sme.cntsb() -declare i64 @llvm.aarch64.sme.cntsh() -declare i64 @llvm.aarch64.sme.cntsw() declare i64 @llvm.aarch64.sme.cntsd() diff --git a/llvm/test/CodeGen/AArch64/sme-streaming-interface-remarks.ll b/llvm/test/CodeGen/AArch64/sme-streaming-interface-remarks.ll index e1a474d898233..2806f864c7b25 100644 --- a/llvm/test/CodeGen/AArch64/sme-streaming-interface-remarks.ll +++ b/llvm/test/CodeGen/AArch64/sme-streaming-interface-remarks.ll @@ -76,14 +76,14 @@ entry: %Data1 = alloca <vscale x 16 x i8>, align 16 %Data2 = alloca <vscale x 16 x i8>, align 16 %Data3 = alloca <vscale x 16 x i8>, align 16 - %0 = tail call i64 @llvm.aarch64.sme.cntsb() + %0 = tail call i64 @llvm.aarch64.sme.cntsd() call void @foo(ptr noundef nonnull %Data1, ptr noundef nonnull %Data2, ptr noundef nonnull %Data3, i64 noundef %0) %1 = load <vscale x 16 x i8>, ptr %Data1, align 16 %vecext = extractelement <vscale x 16 x i8> %1, i64 0 ret i8 %vecext } -declare i64 @llvm.aarch64.sme.cntsb() +declare i64 @llvm.aarch64.sme.cntsd() declare void @foo(ptr noundef, ptr noundef, ptr noundef, i64 noundef) diff --git a/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll b/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll index 8c4d57e244e03..505a40c16653b 100644 --- a/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll +++ b/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll @@ -366,9 +366,10 @@ define i8 @call_to_non_streaming_pass_sve_objects(ptr nocapture noundef readnone ; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill ; CHECK-NEXT: stp x29, x30, [sp, #64] // 16-byte Folded Spill ; CHECK-NEXT: addvl sp, sp, #-3 -; CHECK-NEXT: rdsvl x3, #1 +; CHECK-NEXT: rdsvl x8, #1 ; CHECK-NEXT: addvl x0, sp, #2 ; CHECK-NEXT: addvl x1, sp, #1 +; CHECK-NEXT: lsr x3, x8, #3 ; CHECK-NEXT: mov x2, sp ; CHECK-NEXT: smstop sm ; CHECK-NEXT: bl foo @@ -386,7 +387,7 @@ entry: %Data1 = alloca <vscale x 16 x i8>, align 16 %Data2 = alloca <vscale x 16 x i8>, align 16 %Data3 = alloca <vscale x 16 x i8>, align 16 - %0 = tail call i64 @llvm.aarch64.sme.cntsb() + %0 = tail call i64 @llvm.aarch64.sme.cntsd() call void @foo(ptr noundef nonnull %Data1, ptr noundef nonnull %Data2, ptr noundef nonnull %Data3, i64 noundef %0) %1 = load <vscale x 16 x i8>, ptr %Data1, align 16 %vecext = extractelement <vscale x 16 x i8> %1, i64 0 @@ -421,7 +422,7 @@ entry: ret void } -declare i64 @llvm.aarch64.sme.cntsb() +declare i64 @llvm.aarch64.sme.cntsd() declare void @foo(ptr noundef, ptr noundef, ptr noundef, i64 noundef) declare void @bar(ptr noundef, i64 noundef, i64 noundef, i32 noundef, i32 noundef, float noundef, float noundef, double noundef, double noundef) diff --git a/llvm/test/Transforms/InstCombine/AArch64/sme-intrinsic-opts-counting-elems.ll b/llvm/test/Transforms/InstCombine/AArch64/sme-intrinsic-opts-counting-elems.ll index f213c0b53f6ef..c1d12b825b72c 100644 --- a/llvm/test/Transforms/InstCombine/AArch64/sme-intrinsic-opts-counting-elems.ll +++ b/llvm/test/Transforms/InstCombine/AArch64/sme-intrinsic-opts-counting-elems.ll @@ -5,48 +5,6 @@ target triple = "aarch64-unknown-linux-gnu" -define i64 @cntsb() { -; CHECK-LABEL: @cntsb( -; CHECK-NEXT: [[OUT:%.*]] = call i64 @llvm.aarch64.sme.cntsb() -; CHECK-NEXT: ret i64 [[OUT]] -; -; CHECK-STREAMING-LABEL: @cntsb( -; CHECK-STREAMING-NEXT: [[TMP1:%.*]] = call i64 @llvm.vscale.i64() -; CHECK-STREAMING-NEXT: [[OUT:%.*]] = shl nuw i64 [[TMP1]], 4 -; CHECK-STREAMING-NEXT: ret i64 [[OUT]] -; - %out = call i64 @llvm.aarch64.sme.cntsb() - ret i64 %out -} - -define i64 @cntsh() { -; CHECK-LABEL: @cntsh( -; CHECK-NEXT: [[OUT:%.*]] = call i64 @llvm.aarch64.sme.cntsh() -; CHECK-NEXT: ret i64 [[OUT]] -; -; CHECK-STREAMING-LABEL: @cntsh( -; CHECK-STREAMING-NEXT: [[TMP1:%.*]] = call i64 @llvm.vscale.i64() -; CHECK-STREAMING-NEXT: [[OUT:%.*]] = shl nuw i64 [[TMP1]], 3 -; CHECK-STREAMING-NEXT: ret i64 [[OUT]] -; - %out = call i64 @llvm.aarch64.sme.cntsh() - ret i64 %out -} - -define i64 @cntsw() { -; CHECK-LABEL: @cntsw( -; CHECK-NEXT: [[OUT:%.*]] = call i64 @llvm.aarch64.sme.cntsw() -; CHECK-NEXT: ret i64 [[OUT]] -; -; CHECK-STREAMING-LABEL: @cntsw( -; CHECK-STREAMING-NEXT: [[TMP1:%.*]] = call i64 @llvm.vscale.i64() -; CHECK-STREAMING-NEXT: [[OUT:%.*]] = shl nuw i64 [[TMP1]], 2 -; CHECK-STREAMING-NEXT: ret i64 [[OUT]] -; - %out = call i64 @llvm.aarch64.sme.cntsw() - ret i64 %out -} - define i64 @cntsd() { ; CHECK-LABEL: @cntsd( ; CHECK-NEXT: [[OUT:%.*]] = call i64 @llvm.aarch64.sme.cntsd() @@ -61,8 +19,5 @@ define i64 @cntsd() { ret i64 %out } -declare i64 @llvm.aarch64.sve.cntsb() -declare i64 @llvm.aarch64.sve.cntsh() -declare i64 @llvm.aarch64.sve.cntsw() declare i64 @llvm.aarch64.sve.cntsd() >From 191a4de7c1bc6a2cde664572709c667a57643715 Mon Sep 17 00:00:00 2001 From: Kerry McLaughlin <kerry.mclaugh...@arm.com> Date: Mon, 1 Sep 2025 13:04:42 +0000 Subject: [PATCH 4/7] - Remove cnts[b,h,w] intrinsics from MLIR and fix tests - Remove ZACount class from arm_sme.td --- clang/include/clang/Basic/arm_sme.td | 13 +++---- .../Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td | 3 -- .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 36 ++++++++++++------- .../ArmSMEToLLVM/arm-sme-to-llvm.mlir | 17 ++++++--- mlir/test/Target/LLVMIR/arm-sme-invalid.mlir | 2 +- mlir/test/Target/LLVMIR/arm-sme.mlir | 6 ---- 6 files changed, 40 insertions(+), 37 deletions(-) diff --git a/clang/include/clang/Basic/arm_sme.td b/clang/include/clang/Basic/arm_sme.td index f853122994497..5f6a6eaab80a3 100644 --- a/clang/include/clang/Basic/arm_sme.td +++ b/clang/include/clang/Basic/arm_sme.td @@ -156,15 +156,10 @@ let SMETargetGuard = "sme2p1" in { //////////////////////////////////////////////////////////////////////////////// // SME - Counting elements in a streaming vector -multiclass ZACount<string intr, string n_suffix> { - def NAME : SInst<"sv"#n_suffix, "nv", "", MergeNone, - intr, [IsOverloadNone, IsStreamingCompatible]>; -} - -defm SVCNTSB : ZACount<"", "cntsb">; -defm SVCNTSH : ZACount<"", "cntsh">; -defm SVCNTSW : ZACount<"", "cntsw">; -defm SVCNTSD : ZACount<"aarch64_sme_cntsd", "cntsd">; +def SVCNTSB : SInst<"svcntsb", "nv", "", MergeNone, "", [IsOverloadNone, IsStreamingCompatible]>; +def SVCNTSH : SInst<"svcntsh", "nv", "", MergeNone, "", [IsOverloadNone, IsStreamingCompatible]>; +def SVCNTSW : SInst<"svcntsw", "nv", "", MergeNone, "", [IsOverloadNone, IsStreamingCompatible]>; +def SVCNTSD : SInst<"svcntsd", "nv", "", MergeNone, "aarch64_sme_cntsd", [IsOverloadNone, IsStreamingCompatible]>; //////////////////////////////////////////////////////////////////////////////// // SME - ADDHA/ADDVA diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td index 06fb8511774e8..4d19fa5415ef0 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td @@ -201,9 +201,6 @@ class ArmSME_IntrCountOp<string mnemonic> /*traits*/[PredOpTrait<"`res` is i64", TypeIsPred<"res", I64>>], /*numResults=*/1, /*overloadedResults=*/[]>; -def LLVM_aarch64_sme_cntsb : ArmSME_IntrCountOp<"cntsb">; -def LLVM_aarch64_sme_cntsh : ArmSME_IntrCountOp<"cntsh">; -def LLVM_aarch64_sme_cntsw : ArmSME_IntrCountOp<"cntsw">; def LLVM_aarch64_sme_cntsd : ArmSME_IntrCountOp<"cntsd">; #endif // ARMSME_INTRINSIC_OPS diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 8a2e3b639aaa7..6b795b18211b2 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -822,7 +822,7 @@ struct OuterProductWideningOpConversion } }; -/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics. +/// Lower `arm_sme.streaming_vl` to SME CNTSD intrinsic. /// /// Example: /// @@ -830,8 +830,10 @@ struct OuterProductWideningOpConversion /// /// is converted to: /// -/// %cnt = "arm_sme.intr.cntsh"() : () -> i64 -/// %0 = arith.index_cast %cnt : i64 to index +/// %cnt = "arm_sme.intr.cntsd"() : () -> i64 +/// %0 = arith.constant 4 : i64 +/// %1 = arith.muli %cnt, %0 : i64 +/// %2 = arith.index_cast %1 : i64 to index /// struct StreamingVLOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp, @@ -845,15 +847,25 @@ struct StreamingVLOpConversion auto loc = streamingVlOp.getLoc(); auto i64Type = rewriter.getI64Type(); auto *intrOp = [&]() -> Operation * { + auto cntsd = arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type); switch (streamingVlOp.getTypeSize()) { - case arm_sme::TypeSize::Byte: - return arm_sme::aarch64_sme_cntsb::create(rewriter, loc, i64Type); - case arm_sme::TypeSize::Half: - return arm_sme::aarch64_sme_cntsh::create(rewriter, loc, i64Type); - case arm_sme::TypeSize::Word: - return arm_sme::aarch64_sme_cntsw::create(rewriter, loc, i64Type); + case arm_sme::TypeSize::Byte: { + auto mul = arith::ConstantIndexOp::create(rewriter, loc, 8); + auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul); + return arith::MulIOp::create(rewriter, loc, cntsd, mul64); + } + case arm_sme::TypeSize::Half: { + auto mul = arith::ConstantIndexOp::create(rewriter, loc, 4); + auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul); + return arith::MulIOp::create(rewriter, loc, cntsd, mul64); + } + case arm_sme::TypeSize::Word: { + auto mul = arith::ConstantIndexOp::create(rewriter, loc, 2); + auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul); + return arith::MulIOp::create(rewriter, loc, cntsd, mul64); + } case arm_sme::TypeSize::Double: - return arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type); + return cntsd; } llvm_unreachable("unknown type size in StreamingVLOpConversion"); }(); @@ -964,9 +976,7 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) { arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32, arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide, arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide, - arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb, - arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw, - arm_sme::aarch64_sme_cntsd>(); + arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsd>(); target.addLegalDialect<arith::ArithDialect, /* The following are used to lower tile spills/fills */ vector::VectorDialect, scf::SCFDialect, diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir index 6a4d77e86ab58..4f3c1dad24b76 100644 --- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir +++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir @@ -586,9 +586,10 @@ func.func @arm_sme_extract_tile_slice_ver_i128(%tile_slice_index : index) -> vec // ----- // CHECK-LABEL: @arm_sme_streaming_vl_bytes -// CHECK: %[[COUNT:.*]] = "arm_sme.intr.cntsb"() : () -> i64 -// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index -// CHECK: return %[[INDEX_COUNT]] : index +// CHECK: %[[CONST:.*]] = arith.constant 8 : i64 +// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64 +// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64 +// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index func.func @arm_sme_streaming_vl_bytes() -> index { %svl_b = arm_sme.streaming_vl <byte> return %svl_b : index @@ -597,7 +598,10 @@ func.func @arm_sme_streaming_vl_bytes() -> index { // ----- // CHECK-LABEL: @arm_sme_streaming_vl_half_words -// CHECK: "arm_sme.intr.cntsh"() : () -> i64 +// CHECK: %[[CONST:.*]] = arith.constant 4 : i64 +// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64 +// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64 +// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index func.func @arm_sme_streaming_vl_half_words() -> index { %svl_h = arm_sme.streaming_vl <half> return %svl_h : index @@ -606,7 +610,10 @@ func.func @arm_sme_streaming_vl_half_words() -> index { // ----- // CHECK-LABEL: @arm_sme_streaming_vl_words -// CHECK: "arm_sme.intr.cntsw"() : () -> i64 +// CHECK: %[[CONST:.*]] = arith.constant 2 : i64 +// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64 +// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64 +// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index func.func @arm_sme_streaming_vl_words() -> index { %svl_w = arm_sme.streaming_vl <word> return %svl_w : index diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir index 14821da838726..6f5b1d8c5d93d 100644 --- a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir +++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir @@ -36,6 +36,6 @@ llvm.func @arm_sme_tile_slice_to_vector_invalid_element_types( llvm.func @arm_sme_streaming_vl_invalid_return_type() -> i32 { // expected-error @+1 {{failed to verify that `res` is i64}} - %res = "arm_sme.intr.cntsb"() : () -> i32 + %res = "arm_sme.intr.cntsd"() : () -> i32 llvm.return %res : i32 } diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir index aedb6730b06bb..0a13a75618a23 100644 --- a/mlir/test/Target/LLVMIR/arm-sme.mlir +++ b/mlir/test/Target/LLVMIR/arm-sme.mlir @@ -419,12 +419,6 @@ llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32, // ----- llvm.func @arm_sme_streaming_vl() { - // CHECK: call i64 @llvm.aarch64.sme.cntsb() - %svl_b = "arm_sme.intr.cntsb"() : () -> i64 - // CHECK: call i64 @llvm.aarch64.sme.cntsh() - %svl_h = "arm_sme.intr.cntsh"() : () -> i64 - // CHECK: call i64 @llvm.aarch64.sme.cntsw() - %svl_w = "arm_sme.intr.cntsw"() : () -> i64 // CHECK: call i64 @llvm.aarch64.sme.cntsd() %svl_d = "arm_sme.intr.cntsd"() : () -> i64 llvm.return >From e2d91e5e53043b916961940a01f2d65e4ac7b752 Mon Sep 17 00:00:00 2001 From: Kerry McLaughlin <kerry.mclaugh...@arm.com> Date: Thu, 4 Sep 2025 10:23:08 +0000 Subject: [PATCH 5/7] - Remove lambda from StreamingVLOpConversion - Add getSizeInBytes helper --- clang/lib/CodeGen/TargetBuiltins/ARM.cpp | 17 ++++------ .../Target/AArch64/AArch64ISelDAGToDAG.cpp | 3 ++ .../AArch64/AArch64TargetTransformInfo.cpp | 6 ++-- .../include/mlir/Dialect/ArmSME/Utils/Utils.h | 3 ++ .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 32 ++++--------------- mlir/lib/Dialect/ArmSME/IR/Utils.cpp | 15 +++++++++ .../ArmSMEToLLVM/arm-sme-to-llvm.mlir | 18 +++++------ 7 files changed, 46 insertions(+), 48 deletions(-) diff --git a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp index 217232db44b6f..de1bdb335469d 100644 --- a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp +++ b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp @@ -4919,25 +4919,20 @@ Value *CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID, // Handle builtins which require their multi-vector operands to be swapped swapCommutativeSMEOperands(BuiltinID, Ops); - auto isCntsBuiltin = [&](int64_t &Mul) { + auto isCntsBuiltin = [&]() { switch (BuiltinID) { default: - Mul = 0; - return false; + return 0; case SME::BI__builtin_sme_svcntsb: - Mul = 8; - return true; + return 8; case SME::BI__builtin_sme_svcntsh: - Mul = 4; - return true; + return 4; case SME::BI__builtin_sme_svcntsw: - Mul = 2; - return true; + return 2; } }; - int64_t Mul = 0; - if (isCntsBuiltin(Mul)) { + if (auto Mul = isCntsBuiltin()) { llvm::Value *Cntd = Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsd)); return Builder.CreateMul(Cntd, llvm::ConstantInt::get(Int64Ty, Mul), diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp index 4e8255bab9437..8af10ef8dadc9 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp @@ -940,6 +940,9 @@ bool AArch64DAGToDAGISel::SelectRDVLImm(SDValue N, SDValue &Imm) { return false; } +// Given cntsd = (rdsvl, #1) >> 3, attempt to return a suitable multiplier +// for RDSVL to calculate the streaming vector length in bytes * N. i.e. +// rdsvl, #(ShlImm - 3) template <signed Low, signed High> bool AArch64DAGToDAGISel::SelectRDSVLShiftImm(SDValue N, SDValue &Imm) { if (!isa<ConstantSDNode>(N)) diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 38958796e2fe1..d4c7cb11a70a3 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -2102,8 +2102,8 @@ instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) { } static std::optional<Instruction *> -instCombineSMECntsElts(InstCombiner &IC, IntrinsicInst &II, - const AArch64Subtarget *ST) { +instCombineSMECntsd(InstCombiner &IC, IntrinsicInst &II, + const AArch64Subtarget *ST) { if (!ST->isStreaming()) return std::nullopt; @@ -2825,7 +2825,7 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC, case Intrinsic::aarch64_sve_cntb: return instCombineSVECntElts(IC, II, 16); case Intrinsic::aarch64_sme_cntsd: - return instCombineSMECntsElts(IC, II, ST); + return instCombineSMECntsd(IC, II, ST); case Intrinsic::aarch64_sve_ptest_any: case Intrinsic::aarch64_sve_ptest_first: case Intrinsic::aarch64_sve_ptest_last: diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h index 1f40eb6fc693c..b57b27de4e1de 100644 --- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h @@ -32,6 +32,9 @@ namespace mlir::arm_sme { constexpr unsigned MinStreamingVectorLengthInBits = 128; +/// Return the size represented by arm_sme::TypeSize in bytes. +unsigned getSizeInBytes(TypeSize type); + /// Return minimum number of elements for the given element `type` in /// a vector of SVL bits. unsigned getSMETileSliceMinNumElts(Type type); diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 6b795b18211b2..a36f8f09ceada 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -846,31 +846,13 @@ struct StreamingVLOpConversion ConversionPatternRewriter &rewriter) const override { auto loc = streamingVlOp.getLoc(); auto i64Type = rewriter.getI64Type(); - auto *intrOp = [&]() -> Operation * { - auto cntsd = arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type); - switch (streamingVlOp.getTypeSize()) { - case arm_sme::TypeSize::Byte: { - auto mul = arith::ConstantIndexOp::create(rewriter, loc, 8); - auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul); - return arith::MulIOp::create(rewriter, loc, cntsd, mul64); - } - case arm_sme::TypeSize::Half: { - auto mul = arith::ConstantIndexOp::create(rewriter, loc, 4); - auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul); - return arith::MulIOp::create(rewriter, loc, cntsd, mul64); - } - case arm_sme::TypeSize::Word: { - auto mul = arith::ConstantIndexOp::create(rewriter, loc, 2); - auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul); - return arith::MulIOp::create(rewriter, loc, cntsd, mul64); - } - case arm_sme::TypeSize::Double: - return cntsd; - } - llvm_unreachable("unknown type size in StreamingVLOpConversion"); - }(); - rewriter.replaceOpWithNewOp<arith::IndexCastOp>( - streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0)); + auto cntsd = arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type); + auto cntsdIdx = arith::IndexCastOp::create(rewriter, loc, + rewriter.getIndexType(), cntsd); + auto scale = arith::ConstantIndexOp::create( + rewriter, loc, + 8 / arm_sme::getSizeInBytes(streamingVlOp.getTypeSize())); + rewriter.replaceOpWithNewOp<arith::MulIOp>(streamingVlOp, cntsdIdx, scale); return success(); } }; diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp index e5e1312f0eb04..92f4e4f63c200 100644 --- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp +++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp @@ -14,6 +14,21 @@ namespace mlir::arm_sme { +unsigned getSizeInBytes(TypeSize type) { + switch (type) { + case arm_sme::TypeSize::Byte: + return 1; + case arm_sme::TypeSize::Half: + return 2; + case arm_sme::TypeSize::Word: + return 4; + case arm_sme::TypeSize::Double: + return 8; + default: + llvm_unreachable("unknown type size"); + } +} + unsigned getSMETileSliceMinNumElts(Type type) { assert(isValidSMETileElementType(type) && "invalid tile type!"); return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth(); diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir index 4f3c1dad24b76..fd8910265cd89 100644 --- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir +++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir @@ -586,10 +586,10 @@ func.func @arm_sme_extract_tile_slice_ver_i128(%tile_slice_index : index) -> vec // ----- // CHECK-LABEL: @arm_sme_streaming_vl_bytes -// CHECK: %[[CONST:.*]] = arith.constant 8 : i64 +// CHECK: %[[CONST:.*]] = arith.constant 8 : index // CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64 -// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64 -// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index +// CHECK: %[[CNTSD_IDX:.*]] = arith.index_cast %[[CNTSD]] : i64 to index +// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD_IDX]], %[[CONST]] : index func.func @arm_sme_streaming_vl_bytes() -> index { %svl_b = arm_sme.streaming_vl <byte> return %svl_b : index @@ -598,10 +598,10 @@ func.func @arm_sme_streaming_vl_bytes() -> index { // ----- // CHECK-LABEL: @arm_sme_streaming_vl_half_words -// CHECK: %[[CONST:.*]] = arith.constant 4 : i64 +// CHECK: %[[CONST:.*]] = arith.constant 4 : index // CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64 -// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64 -// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index +// CHECK: %[[CNTSD_IDX:.*]] = arith.index_cast %[[CNTSD]] : i64 to index +// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD_IDX]], %[[CONST]] : index func.func @arm_sme_streaming_vl_half_words() -> index { %svl_h = arm_sme.streaming_vl <half> return %svl_h : index @@ -610,10 +610,10 @@ func.func @arm_sme_streaming_vl_half_words() -> index { // ----- // CHECK-LABEL: @arm_sme_streaming_vl_words -// CHECK: %[[CONST:.*]] = arith.constant 2 : i64 +// CHECK: %[[CONST:.*]] = arith.constant 2 : index // CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64 -// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64 -// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index +// CHECK: %[[CNTSD_IDX:.*]] = arith.index_cast %[[CNTSD]] : i64 to index +// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD_IDX]], %[[CONST]] : index func.func @arm_sme_streaming_vl_words() -> index { %svl_w = arm_sme.streaming_vl <word> return %svl_w : index >From f5de33d9e1df0155b96fcc54543bc754ab044907 Mon Sep 17 00:00:00 2001 From: Kerry McLaughlin <kerry.mclaugh...@arm.com> Date: Fri, 5 Sep 2025 10:27:14 +0000 Subject: [PATCH 6/7] - Fix 'default label in switch' build failure --- mlir/lib/Dialect/ArmSME/IR/Utils.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp index 92f4e4f63c200..e64ae42204fa0 100644 --- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp +++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp @@ -24,9 +24,9 @@ unsigned getSizeInBytes(TypeSize type) { return 4; case arm_sme::TypeSize::Double: return 8; - default: - llvm_unreachable("unknown type size"); } + llvm_unreachable("unknown type size"); + return 0; } unsigned getSMETileSliceMinNumElts(Type type) { >From a91144e0f096bcb49bd46eb5286a666dbd45696b Mon Sep 17 00:00:00 2001 From: Kerry McLaughlin <kerry.mclaugh...@arm.com> Date: Fri, 5 Sep 2025 14:02:45 +0000 Subject: [PATCH 7/7] - Fix comments in AArch64ISelDAGToDAG & ArmSMEToLLVM --- llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp | 5 ++--- mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp index 8af10ef8dadc9..8ab313dfed46a 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp @@ -940,9 +940,8 @@ bool AArch64DAGToDAGISel::SelectRDVLImm(SDValue N, SDValue &Imm) { return false; } -// Given cntsd = (rdsvl, #1) >> 3, attempt to return a suitable multiplier -// for RDSVL to calculate the streaming vector length in bytes * N. i.e. -// rdsvl, #(ShlImm - 3) +// Given `cntsd = (rdsvl, #1) >> 3`, attempt to return a suitable multiplier +// for RDSVL to calculate `cntsd << N`, i.e. `rdsvl, #(N - 3)`. template <signed Low, signed High> bool AArch64DAGToDAGISel::SelectRDSVLShiftImm(SDValue N, SDValue &Imm) { if (!isa<ConstantSDNode>(N)) diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index a36f8f09ceada..033e9ae1f4d4c 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -831,9 +831,9 @@ struct OuterProductWideningOpConversion /// is converted to: /// /// %cnt = "arm_sme.intr.cntsd"() : () -> i64 -/// %0 = arith.constant 4 : i64 -/// %1 = arith.muli %cnt, %0 : i64 -/// %2 = arith.index_cast %1 : i64 to index +/// %scale = arith.constant 4 : index +/// %cntIndex = arith.index_cast %cnt : i64 to index +/// %0 = arith.muli %cntIndex, %scale : index /// struct StreamingVLOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp, _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits