https://github.com/bob80905 updated https://github.com/llvm/llvm-project/pull/178059
>From 6e20bc6d92b14abd20085589d63eb89136cbf4a6 Mon Sep 17 00:00:00 2001 From: Joshua Batista <[email protected]> Date: Mon, 26 Jan 2026 13:39:12 -0800 Subject: [PATCH 1/2] first attempt --- clang/include/clang/Basic/Builtins.td | 6 ++ clang/lib/CodeGen/CGHLSLBuiltins.cpp | 47 ++++++++++++++++ .../lib/Headers/hlsl/hlsl_alias_intrinsics.h | 10 ++++ clang/lib/Sema/SemaHLSL.cpp | 24 ++++++++ .../builtins/WavePrefixCountBits.hlsl | 27 +++++++++ .../BuiltIns/WavePrefixCountBits-errors.hlsl | 30 ++++++++++ llvm/include/llvm/IR/IntrinsicsDirectX.td | 1 + llvm/include/llvm/IR/IntrinsicsSPIRV.td | 1 + llvm/lib/Target/DirectX/DXIL.td | 26 +++++++++ .../DirectX/DirectXTargetTransformInfo.cpp | 1 + .../Target/SPIRV/SPIRVInstructionSelector.cpp | 55 +++++++++++++++++++ .../CodeGen/DirectX/WavePrefixBitCount.ll | 10 ++++ .../hlsl-intrinsics/WavePrefixCountBits.ll | 17 ++++++ 13 files changed, 255 insertions(+) create mode 100644 clang/test/CodeGenHLSL/builtins/WavePrefixCountBits.hlsl create mode 100644 clang/test/SemaHLSL/BuiltIns/WavePrefixCountBits-errors.hlsl create mode 100644 llvm/test/CodeGen/DirectX/WavePrefixBitCount.ll create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixCountBits.ll diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td index bc8f1474493b0..0ef28ae16c301 100644 --- a/clang/include/clang/Basic/Builtins.td +++ b/clang/include/clang/Basic/Builtins.td @@ -5127,6 +5127,12 @@ def HLSLWaveIsFirstLane : LangBuiltin<"HLSL_LANG"> { let Prototype = "bool()"; } +def HLSLWavePrefixCountBits : LangBuiltin<"HLSL_LANG"> { + let Spellings = ["__builtin_hlsl_wave_prefix_count_bits"]; + let Attributes = [NoThrow, Const, CustomTypeChecking]; + let Prototype = "int(bool)"; +} + def HLSLWaveReadLaneAt : LangBuiltin<"HLSL_LANG"> { let Spellings = ["__builtin_hlsl_wave_read_lane_at"]; let Attributes = [NoThrow, Const]; diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp index 36691c7b72efe..c6998a343f496 100644 --- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp +++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp @@ -331,6 +331,40 @@ static Intrinsic::ID getWaveActiveMinIntrinsic(llvm::Triple::ArchType Arch, } } +// select and return a specific wave prefix op intrinsic, +// based on the provided op kind. +// OpKinds: +// CountBits = 136, count all bits set in previous threads +// This is the only operation in DXIL so far under this class +static Intrinsic::ID getPrefixOpIntrinsic(int OpKind, + llvm::Triple::ArchType Arch, + CGHLSLRuntime &RT, QualType QT) { + switch (Arch) { + case llvm::Triple::spirv: + switch (OpKind) { + case 136: { + return Intrinsic::spv_subgroup_prefix_bit_count; + } + default: { + llvm_unreachable("Unexpected SubOp ID"); + } + } + case llvm::Triple::dxil: { + switch (OpKind) { + case 136: { + return Intrinsic::dx_wave_prefix_bit_count; + } + default: { + llvm_unreachable("Unexpected SubOp ID"); + } + } + } + default: + llvm_unreachable( + "WavePrefixOp instruction not supported by target architecture"); + } +} + // Returns the mangled name for a builtin function that the SPIR-V backend // will expand into a spec Constant. static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType, @@ -808,6 +842,19 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, CGM.getHLSLRuntime().getSaturateIntrinsic(), ArrayRef<Value *>{Op0}, nullptr, "hlsl.saturate"); } + case Builtin::BI__builtin_hlsl_wave_prefix_count_bits: { + Value *Op = EmitScalarExpr(E->getArg(0)); + assert(Op->getType()->isIntegerTy(1) && + "WavePrefixBitCount operand must be a boolean type"); + + Intrinsic::ID IID = getPrefixOpIntrinsic( + /* OpKind */ 136, getTarget().getTriple().getArch(), + CGM.getHLSLRuntime(), E->getArg(0)->getType()); + + return EmitRuntimeCall( + Intrinsic::getOrInsertDeclaration(&CGM.getModule(), IID), ArrayRef{Op}, + "hlsl.wave.prefix.bit.count"); + } case Builtin::BI__builtin_hlsl_select: { Value *OpCond = EmitScalarExpr(E->getArg(0)); RValue RValTrue = EmitAnyExpr(E->getArg(1)); diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h index e9a41b94d6c03..656fa7c7dea82 100644 --- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h @@ -2445,6 +2445,16 @@ _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_get_lane_count) __attribute__((convergent)) uint WaveGetLaneCount(); +//===----------------------------------------------------------------------===// +// WavePrefixOp builtins +//===----------------------------------------------------------------------===// +/// \brief Returns the count of bits of Expr set to 1 on prior lanes. +/// \param Expr The boolean expression to evaluate. +/// \return the count of bits set to 1 on prior lanes. +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_count_bits) +__attribute__((convergent)) int WavePrefixCountBits(bool Expr); + //===----------------------------------------------------------------------===// // WaveReadLaneAt builtins //===----------------------------------------------------------------------===// diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 4d31e26d56e6b..cd9e77f913800 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -3593,6 +3593,30 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { return true; break; } + case Builtin::BI__builtin_hlsl_wave_prefix_count_bits: { + if (SemaRef.checkArgCount(TheCall, 1)) + return true; + + // Ensure input expr type is a scalar/vector and then + // set the return type to the arg type + QualType ArgType = TheCall->getArg(0)->getType(); + // not the scalar or vector<scalar> + if (!(ArgType->isScalarType())) { + SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(), + diag::err_typecheck_expect_any_scalar_or_vector) + << ArgType << 0; + return true; + } + + if (!(ArgType->isBooleanType())) { + SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(), + diag::err_typecheck_expect_any_scalar_or_vector) + << ArgType << 0; + return true; + } + + break; + } case Builtin::BI__builtin_hlsl_wave_read_lane_at: { if (SemaRef.checkArgCount(TheCall, 2)) return true; diff --git a/clang/test/CodeGenHLSL/builtins/WavePrefixCountBits.hlsl b/clang/test/CodeGenHLSL/builtins/WavePrefixCountBits.hlsl new file mode 100644 index 0000000000000..135507b60c8fc --- /dev/null +++ b/clang/test/CodeGenHLSL/builtins/WavePrefixCountBits.hlsl @@ -0,0 +1,27 @@ +// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ +// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm \ +// RUN: -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL + +// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ +// RUN: spirv-pc-vulkan-compute %s -emit-llvm \ +// RUN: -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV + +// Test basic lowering to runtime function call. + +int test_int(bool expr) { +// CHECK-DXIL: define hidden noundef i32 {{.*}}(i1 noundef %[[EXPR:.*]]) #[[CONVATTR:.*]] { +// CHECK-SPIRV: define hidden spir_func noundef i32 {{.*}}(i1 noundef %[[EXPR:.*]]) #[[CONVATTR:.*]] { + // CHECK: entry: + // CHECK: %[[EXPRADDR:.*]] = alloca i32, align 4 + // CHECK: %[[STOREDVAL:.*]] = zext i1 %[[EXPR]] to i32 + // CHECK: store i32 %[[STOREDVAL]], ptr %[[EXPRADDR]], align 4 + // CHECK: %[[LOADEDVAL:.*]] = load i32, ptr %[[EXPRADDR]], align 4 + // CHECK: %[[TRUNCLOADEDVAL:.*]] = trunc i32 %[[LOADEDVAL]] to i1 + + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.subgroup.prefix.bit.count(i1 %[[TRUNCLOADEDVAL]]) + // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.prefix.bit.count(i1 %[[TRUNCLOADEDVAL]]) + // CHECK: ret [[TY]] %[[RET]] + return WavePrefixCountBits(expr); +} + +// CHECK: attributes #[[CONVATTR]] = {{{.*}} convergent {{.*}}} diff --git a/clang/test/SemaHLSL/BuiltIns/WavePrefixCountBits-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WavePrefixCountBits-errors.hlsl new file mode 100644 index 0000000000000..2b2bca82c6b17 --- /dev/null +++ b/clang/test/SemaHLSL/BuiltIns/WavePrefixCountBits-errors.hlsl @@ -0,0 +1,30 @@ +// RUN: %clang_cc1 -finclude-default-header -fnative-int16-type -fnative-half-type \ +// RUN: -fmath-errno -ffp-contract=on -fno-rounding-math -finclude-default-header \ +// RUN: -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify + +int test_too_few_arg() { + return __builtin_hlsl_wave_prefix_count_bits(); + // expected-error@-1 {{too few arguments to function call, expected 1, have 0}} +} + +int test_too_many_arg(bool p0) { + return __builtin_hlsl_wave_prefix_count_bits(p0, p0); + // expected-error@-1 {{too many arguments to function call, expected 1, have 2}} +} + +float test_expr_bool_type_check(float p0) { + return __builtin_hlsl_wave_prefix_count_bits(p0); + // expected-error@-1 {{invalid operand of type 'float'}} +} + +float2 test_expr_bool_vec_type_check(float2 p0) { + return __builtin_hlsl_wave_prefix_count_bits(p0); + // expected-error@-1 {{invalid operand of type 'float2' (aka 'vector<float, 2>')}} +} + +struct S { float f; }; + +S test_expr_struct_type_check(S p0) { + return __builtin_hlsl_wave_prefix_count_bits(p0); + // expected-error@-1 {{invalid operand of type 'S'}} +} diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index 23627848b6214..2aa20ddd5d434 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -160,6 +160,7 @@ def int_dx_lerp : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, L def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>; def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>; def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>; +def int_dx_wave_prefix_bit_count : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrNoMem]>; def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>; def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td index a93e8ad0ce964..1e54bf394c984 100644 --- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td +++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td @@ -117,6 +117,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty] [IntrNoMem, Commutative] >; def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>; def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>; + def int_spv_subgroup_prefix_bit_count : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrNoMem]>; def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 3a40d2c36139d..bbce6bc082ac8 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -235,6 +235,7 @@ defset list<DXILOpClass> OpClasses = { def waveMatrix_StoreRawBuf : DXILOpClass; def waveMultiPrefixBitCount : DXILOpClass; def waveMultiPrefixOp : DXILOpClass; + def wavePrefixBitCount : DXILOpClass; def wavePrefixOp : DXILOpClass; def waveReadLaneAt : DXILOpClass; def waveReadLaneFirst : DXILOpClass; @@ -317,6 +318,8 @@ defvar WaveOpKind_Product = 1; defvar WaveOpKind_Min = 2; defvar WaveOpKind_Max = 3; +defvar WavePrefixOpKind_BitCount = 136; + defvar SignedOpKind_Signed = 0; defvar SignedOpKind_Unsigned = 1; @@ -1124,6 +1127,29 @@ def WaveActiveOp : DXILOp<119, waveActiveOp> { let attributes = [Attributes<DXIL1_0, []>]; } +def WavePrefixOp : DXILOp<121, wavePrefixOp> { + let Doc = "returns the result of the operation on prior lanes"; + + let intrinsics = [ + IntrinSelect<int_dx_wave_prefix_bit_count, + [ + IntrinArgI32<WavePrefixOpKind_BitCount>, IntrinArgIndex<0>, + IntrinArgI8<SignedOpKind_Unsigned> + ]> + ]; + + let arguments = [ + Int32Ty, // prefix op kind + Int1Ty, // value + Int8Ty // signedness + ]; + + let result = Int32Ty; + + let stages = [Stages<DXIL1_0, [all_stages]>]; + let attributes = [Attributes<DXIL1_0, []>]; +} + def LegacyF16ToF32 : DXILOp<131, legacyF16ToF32> { let Doc = "returns the float16 stored in the low-half of the uint converted " "to a float"; diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp index f54b48b91265e..b885b459b5d72 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp +++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp @@ -55,6 +55,7 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable( case Intrinsic::dx_rsqrt: case Intrinsic::dx_saturate: case Intrinsic::dx_splitdouble: + case Intrinsic::dx_wave_prefix_bit_count: case Intrinsic::dx_wave_readlane: case Intrinsic::dx_wave_reduce_max: case Intrinsic::dx_wave_reduce_min: diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 23cfb326bc8d9..a443acceca824 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -224,6 +224,9 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectDot4AddPackedExpansion(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; + bool selectWavePrefixBitCount(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + bool selectWaveReduceMax(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, bool IsUnsigned) const; @@ -2715,6 +2718,56 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits( return Result; } +bool SPIRVInstructionSelector::selectWavePrefixBitCount( + Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const { + + assert(I.getNumOperands() == 3); + + auto Op = I.getOperand(2); + assert(Op.isReg()); + + MachineBasicBlock &BB = *I.getParent(); + DebugLoc DL = I.getDebugLoc(); + + Register InputRegister = Op.getReg(); + SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister); + + if (!InputType) + report_fatal_error("Input Type could not be determined."); + + if (!GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeBool)) + report_fatal_error("WavePrefixBitCount requires boolean input"); + + // Types + SPIRVType *Int32Ty = GR.getOrCreateSPIRVIntegerType(32, I, TII); + + // Ballot result type: vector<uint32> + // Match DXC: %v4uint for Subgroup size + SPIRVType *BallotTy = GR.getOrCreateSPIRVVectorType(Int32Ty, 4, I, TII); + + // Create a vreg for the ballot result + Register BallotVReg = MRI->createVirtualRegister(&SPIRV::IDRegClass); + + // 1. OpGroupNonUniformBallot + BuildMI(BB, I, DL, TII.get(SPIRV::OpGroupNonUniformBallot)) + .addDef(BallotVReg) + .addUse(GR.getSPIRVTypeID(BallotTy)) + .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, Int32Ty, TII)) + .addUse(InputRegister) + .constrainAllUses(TII, TRI, RBI); + + // 2. OpGroupNonUniformBallotBitCount + BuildMI(BB, I, DL, TII.get(SPIRV::OpGroupNonUniformBallotBitCount)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, Int32Ty, TII)) + .addImm(SPIRV::GroupOperation::ExclusiveScan) + .addUse(BallotVReg) + .constrainAllUses(TII, TRI, RBI); + + return true; +} + bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, @@ -3859,6 +3912,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, return selectExtInst(ResVReg, ResType, I, CL::u_clamp, GL::UClamp); case Intrinsic::spv_sclamp: return selectExtInst(ResVReg, ResType, I, CL::s_clamp, GL::SClamp); + case Intrinsic::spv_subgroup_prefix_bit_count: + return selectWavePrefixBitCount(ResVReg, ResType, I); case Intrinsic::spv_wave_active_countbits: return selectWaveActiveCountBits(ResVReg, ResType, I); case Intrinsic::spv_wave_all: diff --git a/llvm/test/CodeGen/DirectX/WavePrefixBitCount.ll b/llvm/test/CodeGen/DirectX/WavePrefixBitCount.ll new file mode 100644 index 0000000000000..67432aa8e3696 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/WavePrefixBitCount.ll @@ -0,0 +1,10 @@ +; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s + +; Test that WavePrefixCountBits maps down to the DirectX op + +define noundef i32 @wave_prefix_count_bits(i1 noundef %expr) { +entry: +; CHECK: call i32 @dx.op.wavePrefixOp(i32 121, i32 136, i1 %expr, i8 1) + %ret = call i32 @llvm.dx.wave.prefix.bit.count(i1 %expr) + ret i32 %ret +} diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixCountBits.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixCountBits.ll new file mode 100644 index 0000000000000..321123ab5a617 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixCountBits.ll @@ -0,0 +1,17 @@ +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %} + +; Test lowering to spir-v backend + +define noundef i32 @wave_prefix_count_bits(i1 noundef %expr) { +entry: + ; CHECK: %[[UINT:.*]] = OpTypeInt 32 0 + ; CHECK: %[[UINT4:.*]] = OpTypeVector %[[UINT]] 4 + ; CHECK: %[[UINT3:.*]] = OpConstant %[[UINT]] 3 + ; CHECK: %[[INPUTREG:.*]] = OpFunctionParameter + ; CHECK: %[[BALLOTRESULT:.*]] = OpGroupNonUniformBallot %[[UINT4]] %[[UINT3]] %[[INPUTREG]] + ; CHECK: %[[RET:.*]] = OpGroupNonUniformBallotBitCount %[[UINT]] %[[UINT3]] ExclusiveScan %[[BALLOTRESULT]] + %ret = call i32 @llvm.spv.subgroup.prefix.bit.count(i1 %expr) + ; CHECK: OpReturnValue %[[RET]] + ret i32 %ret +} >From a0085c2ed1817d3f20ac22bc33b2bb3dbb511151 Mon Sep 17 00:00:00 2001 From: Joshua Batista <[email protected]> Date: Tue, 27 Jan 2026 12:51:33 -0800 Subject: [PATCH 2/2] address Tex, remove waveprefixop intrinsic and make countbits independent, but still named the same --- clang/lib/CodeGen/CGHLSLBuiltins.cpp | 33 ++++--------------- clang/lib/Sema/SemaHLSL.cpp | 2 +- llvm/lib/Target/DirectX/DXIL.td | 20 +++-------- .../CodeGen/DirectX/WavePrefixBitCount.ll | 2 +- 4 files changed, 12 insertions(+), 45 deletions(-) diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp index 54259a3f6aba1..33196b66c576e 100644 --- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp +++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp @@ -384,33 +384,13 @@ static Intrinsic::ID getWaveActiveMinIntrinsic(llvm::Triple::ArchType Arch, } } -// select and return a specific wave prefix op intrinsic, -// based on the provided op kind. -// OpKinds: -// CountBits = 136, count all bits set in previous threads -// This is the only operation in DXIL so far under this class -static Intrinsic::ID getPrefixOpIntrinsic(int OpKind, - llvm::Triple::ArchType Arch, - CGHLSLRuntime &RT, QualType QT) { +static Intrinsic::ID getPrefixCountBitsIntrinsic( + llvm::Triple::ArchType Arch) { switch (Arch) { case llvm::Triple::spirv: - switch (OpKind) { - case 136: { - return Intrinsic::spv_subgroup_prefix_bit_count; - } - default: { - llvm_unreachable("Unexpected SubOp ID"); - } - } + return Intrinsic::spv_subgroup_prefix_bit_count; case llvm::Triple::dxil: { - switch (OpKind) { - case 136: { - return Intrinsic::dx_wave_prefix_bit_count; - } - default: { - llvm_unreachable("Unexpected SubOp ID"); - } - } + return Intrinsic::dx_wave_prefix_bit_count; } default: llvm_unreachable( @@ -903,9 +883,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, assert(Op->getType()->isIntegerTy(1) && "WavePrefixBitCount operand must be a boolean type"); - Intrinsic::ID IID = getPrefixOpIntrinsic( - /* OpKind */ 136, getTarget().getTriple().getArch(), - CGM.getHLSLRuntime(), E->getArg(0)->getType()); + Intrinsic::ID IID = getPrefixCountBitsIntrinsic( + getTarget().getTriple().getArch()); return EmitRuntimeCall( Intrinsic::getOrInsertDeclaration(&CGM.getModule(), IID), ArrayRef{Op}, diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 5d65c19bab234..e6ba492549b12 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -3614,7 +3614,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { // Ensure input expr type is a scalar/vector and then // set the return type to the arg type QualType ArgType = TheCall->getArg(0)->getType(); - // not the scalar or vector<scalar> + if (!(ArgType->isScalarType())) { SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(), diag::err_typecheck_expect_any_scalar_or_vector) diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 3843257ba26d3..d230e3daec55e 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -318,8 +318,6 @@ defvar WaveOpKind_Product = 1; defvar WaveOpKind_Min = 2; defvar WaveOpKind_Max = 3; -defvar WavePrefixOpKind_BitCount = 136; - defvar SignedOpKind_Signed = 0; defvar SignedOpKind_Unsigned = 1; @@ -1127,21 +1125,11 @@ def WaveActiveOp : DXILOp<119, waveActiveOp> { let attributes = [Attributes<DXIL1_0, []>]; } -def WavePrefixOp : DXILOp<121, wavePrefixOp> { - let Doc = "returns the result of the operation on prior lanes"; - - let intrinsics = [ - IntrinSelect<int_dx_wave_prefix_bit_count, - [ - IntrinArgI32<WavePrefixOpKind_BitCount>, IntrinArgIndex<0>, - IntrinArgI8<SignedOpKind_Unsigned> - ]> - ]; - - let arguments = [ Int32Ty, Int1Ty, Int8Ty ]; - +def WavePrefixCountBits : DXILOp<136, wavePrefixOp> { + let Doc = "returns the count of bits of Expr set to 1 on prior lanes"; + let intrinsics = [IntrinSelect<int_dx_wave_prefix_bit_count>]; + let arguments = [Int1Ty]; let result = Int32Ty; - let stages = [Stages<DXIL1_0, [all_stages]>]; let attributes = [Attributes<DXIL1_0, []>]; } diff --git a/llvm/test/CodeGen/DirectX/WavePrefixBitCount.ll b/llvm/test/CodeGen/DirectX/WavePrefixBitCount.ll index 67432aa8e3696..406bfa44a3f47 100644 --- a/llvm/test/CodeGen/DirectX/WavePrefixBitCount.ll +++ b/llvm/test/CodeGen/DirectX/WavePrefixBitCount.ll @@ -4,7 +4,7 @@ define noundef i32 @wave_prefix_count_bits(i1 noundef %expr) { entry: -; CHECK: call i32 @dx.op.wavePrefixOp(i32 121, i32 136, i1 %expr, i8 1) +; CHECK: call i32 @dx.op.wavePrefixOp(i32 136, i1 %expr) %ret = call i32 @llvm.dx.wave.prefix.bit.count(i1 %expr) ret i32 %ret } _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
