https://github.com/bob80905 updated https://github.com/llvm/llvm-project/pull/178056
>From fe1f4f48f6c5f6e0a41fe3ef1b1c77dc8b93c809 Mon Sep 17 00:00:00 2001 From: Joshua Batista <[email protected]> Date: Fri, 23 Jan 2026 12:20:11 -0800 Subject: [PATCH 1/8] handle arg promotion with customtypechecking --- clang/include/clang/Basic/Builtins.td | 6 ++ clang/lib/CodeGen/CGHLSLBuiltins.cpp | 58 ++++++++++++++++ .../lib/Headers/hlsl/hlsl_alias_intrinsics.h | 24 +++++++ clang/lib/Sema/SemaHLSL.cpp | 26 +++++++ .../CodeGenHLSL/builtins/WaveActiveBitOr.hlsl | 67 +++++++++++++++++++ .../CodeGenHLSL/builtins/WaveActiveSum.hlsl | 11 +-- .../BuiltIns/WaveActiveBitOr-errors.hlsl | 30 +++++++++ llvm/include/llvm/IR/IntrinsicsDirectX.td | 1 + llvm/include/llvm/IR/IntrinsicsSPIRV.td | 1 + llvm/lib/Target/DirectX/DXIL.td | 25 +++++++ .../DirectX/DirectXTargetTransformInfo.cpp | 1 + .../Target/SPIRV/SPIRVInstructionSelector.cpp | 32 +++++++++ 12 files changed, 278 insertions(+), 4 deletions(-) create mode 100644 clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl create mode 100644 clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td index bc8f1474493b0..e2c46634081a0 100644 --- a/clang/include/clang/Basic/Builtins.td +++ b/clang/include/clang/Basic/Builtins.td @@ -5091,6 +5091,12 @@ def HLSLWaveActiveBallot : LangBuiltin<"HLSL_LANG"> { let Prototype = "_ExtVector<4, unsigned int>(bool)"; } +def HLSLWaveActiveBitOr : LangBuiltin<"HLSL_LANG"> { + let Spellings = ["__builtin_hlsl_wave_active_bit_or"]; + let Attributes = [NoThrow, Const, CustomTypeChecking]; + let Prototype = "void(...)"; +} + def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> { let Spellings = ["__builtin_hlsl_wave_active_count_bits"]; let Attributes = [NoThrow, Const]; diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp index 36691c7b72efe..d3d925e6fc8f4 100644 --- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp +++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp @@ -276,6 +276,51 @@ static Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) { return RT.getFirstBitUHighIntrinsic(); } +// select and return a specific wave bit op intrinsic, +// based on the provided op kind. +// OpKinds: +// And = 0, bitwise and of values +// Or = 1, bitwise or of values +// Xor = 2, bitwise xor of values +static Intrinsic::ID getWaveBitOpIntrinsic(int OpKind, + llvm::Triple::ArchType Arch, + CGHLSLRuntime &RT, QualType QT) { + switch (Arch) { + case llvm::Triple::spirv: + switch (OpKind) { + + case 0: + case 2: { + llvm_unreachable("Not implemented yet!"); + } + case 1: { + return Intrinsic::spv_wave_bit_or; + } + default: { + llvm_unreachable("Unexpected SubOp ID"); + } + } + case llvm::Triple::dxil: { + switch (OpKind) { + + case 0: + case 2: { + llvm_unreachable("Not implemented yet!"); + } + case 1: { + return Intrinsic::dx_wave_bit_or; + } + default: { + llvm_unreachable("Unexpected SubOp ID"); + } + } + } + default: + llvm_unreachable("Intrinsic WaveActiveBitOr" + " not supported by target architecture"); + } +} + // Return wave active sum that corresponds to the QT scalar type static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch, CGHLSLRuntime &RT, QualType QT) { @@ -872,6 +917,19 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, return handleHlslWaveActiveBallot(*this, E); } + case Builtin::BI__builtin_hlsl_wave_active_bit_or: { + Value *Op = EmitScalarExpr(E->getArg(0)); + assert(Op->getType()->isIntegerTy() && + "Intrinsic WaveActiveBitOr operand must be an integer type"); + + Intrinsic::ID IID = getWaveBitOpIntrinsic( + /* OpKind */ 1, getTarget().getTriple().getArch(), CGM.getHLSLRuntime(), + E->getArg(0)->getType()); + + return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration( + &CGM.getModule(), IID, {Op->getType()}), + ArrayRef{Op}, "hlsl.wave.active.bit.or"); + } case Builtin::BI__builtin_hlsl_wave_active_count_bits: { Value *OpExpr = EmitScalarExpr(E->getArg(0)); Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic(); diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h index e9a41b94d6c03..73481334b0abf 100644 --- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h @@ -2422,6 +2422,30 @@ _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_ballot) __attribute__((convergent)) uint4 WaveActiveBallot(bool Val); +/// \brief Returns the bitwise OR of all the values of Expr across all +/// active non-helper lanes in the current wave, and replicates it back +/// to all active non-helper lanes. +/// \param Expr The integer expression to evaluate. +/// \return The bitwise OR value of Expr across all active threads. +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or) +__attribute__((convergent)) int16_t WaveActiveBitOr(int16_t Expr); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or) +__attribute__((convergent)) int WaveActiveBitOr(int Expr); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or) +__attribute__((convergent)) int64_t WaveActiveBitOr(int64_t Expr); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or) +__attribute__((convergent)) uint16_t WaveActiveBitOr(uint16_t Expr); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or) +__attribute__((convergent)) uint WaveActiveBitOr(uint Expr); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or) +__attribute__((convergent)) uint64_t WaveActiveBitOr(uint64_t Expr); + /// \brief Counts the number of boolean variables which evaluate to true across /// all active lanes in the current wave. /// diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 4d31e26d56e6b..1df66df1f3e70 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -3583,6 +3583,32 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { TheCall->setType(ArgTyExpr); break; } + case Builtin::BI__builtin_hlsl_wave_active_bit_or: { + if (SemaRef.checkArgCount(TheCall, 1)) + return true; + + // Ensure input expr type is a scalar/vector and then + // set the return type to the arg type + QualType ArgType = TheCall->getArg(0)->getType(); + auto *VTy = ArgType->getAs<VectorType>(); + // not the scalar or vector<scalar> + if (!(ArgType->isScalarType())) { + SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(), + diag::err_typecheck_expect_any_scalar_or_vector) + << ArgType << 0; + return true; + } + + if (!(ArgType->isIntegerType())) { + SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(), + diag::err_typecheck_expect_any_scalar_or_vector) + << ArgType << 0; + return true; + } + + TheCall->setType(ArgType); + break; + } // Note these are llvm builtins that we want to catch invalid intrinsic // generation. Normal handling of these builtins will occur elsewhere. case Builtin::BI__builtin_elementwise_bitreverse: { diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl new file mode 100644 index 0000000000000..6b757ce60b389 --- /dev/null +++ b/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl @@ -0,0 +1,67 @@ +// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ +// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -fnative-int16-type -fnative-half-type \ +// RUN: -fmath-errno -ffp-contract=on -fno-rounding-math -finclude-default-header \ +// RUN: -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL + +// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ +// RUN: spirv-pc-vulkan-compute %s -emit-llvm -fnative-int16-type -fnative-half-type \ +// RUN: -fmath-errno -ffp-contract=on -fno-rounding-math -finclude-default-header \ +// RUN: -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV + +// Test basic lowering to runtime function call. + +// CHECK-LABEL: test_int +int test_int(int expr) { + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i32([[TY]] %[[#]]) + // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.i32([[TY]] %[[#]]) + // CHECK: ret [[TY]] %[[RET]] + return WaveActiveBitOr(expr); +} + +// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.bit.or.i32([[TY]]) #[[#attr:]] +// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.bit.or.i32([[TY]]) #[[#attr:]] + +// CHECK-LABEL: test_int16 +int16_t test_int16_t(int16_t expr) { + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i16([[TY]] %[[#]]) + // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.i16([[TY]] %[[#]]) + // CHECK: ret [[TY]] %[[RET]] + return WaveActiveBitOr(expr); +} + +// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.bit.or.i16([[TY]]) #[[#attr:]] +// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.bit.or.i16([[TY]]) #[[#attr:]] + +// CHECK-LABEL: test_int64 +int64_t test_int64_t(int64_t expr) { + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i64([[TY]] %[[#]]) + // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.i64([[TY]] %[[#]]) + // CHECK: ret [[TY]] %[[RET]] + return WaveActiveBitOr(expr); +} + +// CHECK-LABEL: test_uint +uint test_uint(uint expr) { + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i32([[TY]] %[[#]]) + // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.i32([[TY]] %[[#]]) + // CHECK: ret [[TY]] %[[RET]] + return WaveActiveBitOr(expr); +} + +// CHECK-LABEL: test_uint16 +uint16_t test_uint16_t(uint16_t expr) { + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i16([[TY]] %[[#]]) + // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.i16([[TY]] %[[#]]) + // CHECK: ret [[TY]] %[[RET]] + return WaveActiveBitOr(expr); +} + +// CHECK-LABEL: test_uint64 +uint64_t test_uint64_t(uint64_t expr) { + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i64([[TY]] %[[#]]) + // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.i64([[TY]] %[[#]]) + // CHECK: ret [[TY]] %[[RET]] + return WaveActiveBitOr(expr); +} + +// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}} diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl index 1fc93c62c8db0..87ddd96e8368c 100644 --- a/clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl +++ b/clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl @@ -1,9 +1,12 @@ // RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ -// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \ -// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL +// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -fnative-int16-type -fnative-half-type \ +// RUN: -fmath-errno -ffp-contract=on -fno-rounding-math -finclude-default-header \ +// RUN: -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL + // RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ -// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \ -// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV +// RUN: spirv-pc-vulkan-compute %s -emit-llvm -fnative-int16-type -fnative-half-type \ +// RUN: -fmath-errno -ffp-contract=on -fno-rounding-math -finclude-default-header \ +// RUN: -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV // Test basic lowering to runtime function call. diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl new file mode 100644 index 0000000000000..e3fd2eac28159 --- /dev/null +++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl @@ -0,0 +1,30 @@ +// RUN: %clang_cc1 -finclude-default-header -fnative-int16-type -fnative-half-type \ +// RUN: -fmath-errno -ffp-contract=on -fno-rounding-math -finclude-default-header \ +// RUN: -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify + +int test_too_few_arg() { + return __builtin_hlsl_wave_active_bit_or(); + // expected-error@-1 {{too few arguments to function call, expected 1, have 0}} +} + +int test_too_many_arg(int p0) { + return __builtin_hlsl_wave_active_bit_or(p0, p0); + // expected-error@-1 {{too many arguments to function call, expected 1, have 2}} +} + +float test_expr_bool_type_check(float p0) { + return __builtin_hlsl_wave_active_bit_or(p0); + // expected-error@-1 {{invalid operand of type 'float'}} +} + +float2 test_expr_bool_vec_type_check(float2 p0) { + return __builtin_hlsl_wave_active_bit_or(p0); + // expected-error@-1 {{invalid operand of type 'float2' (aka 'vector<float, 2>')}} +} + +struct S { float f; }; + +S test_expr_struct_type_check(S p0) { + return __builtin_hlsl_wave_active_bit_or(p0); + // expected-error@-1 {{invalid operand of type 'S'}} +} diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index f79945785566c..acbabb128258d 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -154,6 +154,7 @@ def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1 def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; +def int_dx_wave_bit_or : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>; def int_dx_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; def int_dx_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td index 293cb750cea98..303ee8e1a61bf 100644 --- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td +++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td @@ -122,6 +122,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty] def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_spv_subgroup_ballot : ClangBuiltin<"__builtin_spirv_subgroup_ballot">, DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; + def int_spv_wave_bit_or : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; def int_spv_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; def int_spv_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; def int_spv_wave_reduce_min : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 3a40d2c36139d..7f170faef204c 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -215,6 +215,7 @@ defset list<DXILOpClass> OpClasses = { def waveActiveAllEqual : DXILOpClass; def waveActiveBit : DXILOpClass; def waveActiveOp : DXILOpClass; + def waveBitOp : DXILOpClass; def waveAllOp : DXILOpClass; def waveAllTrue : DXILOpClass; def waveAnyTrue : DXILOpClass; @@ -317,6 +318,10 @@ defvar WaveOpKind_Product = 1; defvar WaveOpKind_Min = 2; defvar WaveOpKind_Max = 3; +defvar WaveBitOpKind_And = 0; +defvar WaveBitOpKind_Or = 1; +defvar WaveBitOpKind_Xor = 2; + defvar SignedOpKind_Signed = 0; defvar SignedOpKind_Unsigned = 1; @@ -1124,6 +1129,26 @@ def WaveActiveOp : DXILOp<119, waveActiveOp> { let attributes = [Attributes<DXIL1_0, []>]; } +// we may not need the third argument to intrinselect. +def WaveBitOp : DXILOp<120, waveBitOp> { + let Doc = "returns the result of the bitwise operation across waves"; + let intrinsics = [ + IntrinSelect<int_dx_wave_bit_or, + [ + IntrinArgIndex<0>, IntrinArgI8<WaveBitOpKind_Or>, + IntrinArgI8<SignedOpKind_Signed> + ]> + ]; + + let arguments = [OverloadTy, Int8Ty, Int8Ty]; + let result = OverloadTy; + let overloads = [ + Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]> + ]; + let stages = [Stages<DXIL1_0, [all_stages]>]; + let attributes = [Attributes<DXIL1_0, []>]; +} + def LegacyF16ToF32 : DXILOp<131, legacyF16ToF32> { let Doc = "returns the float16 stored in the low-half of the uint converted " "to a float"; diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp index f54b48b91265e..4b69b5c5ce239 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp +++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp @@ -55,6 +55,7 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable( case Intrinsic::dx_rsqrt: case Intrinsic::dx_saturate: case Intrinsic::dx_splitdouble: + case Intrinsic::dx_wave_bit_or: case Intrinsic::dx_wave_readlane: case Intrinsic::dx_wave_reduce_max: case Intrinsic::dx_wave_reduce_min: diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 626393d4ecb40..09e0d25b5f639 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -224,6 +224,9 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectDot4AddPackedExpansion(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; + bool selectWaveBitOr(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + bool selectWaveReduceMax(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, bool IsUnsigned) const; @@ -2711,6 +2714,33 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits( return Result; } +bool SPIRVInstructionSelector::selectWaveBitOr(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + + assert(I.getNumOperands() == 3); + assert(I.getOperand(1).isReg()); + MachineBasicBlock &BB = *I.getParent(); + Register InputRegister = I.getOperand(1).getReg(); + SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister); + + if (!InputType) + report_fatal_error("Input Type could not be determined."); + if (!GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeInt)) + report_fatal_error("WaveActiveBitOr requires integer input"); + + SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII); + + return BuildMI(BB, I, I.getDebugLoc(), + TII.get(SPIRV::OpGroupNonUniformBitwiseOr)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII)) + .addImm(SPIRV::GroupOperation::Reduce) + .addUse(InputRegister) + .constrainAllUses(TII, TRI, RBI); +} + bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, @@ -3815,6 +3845,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAll); case Intrinsic::spv_wave_any: return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny); + case Intrinsic::spv_wave_bit_or: + return selectWaveBitOr(ResVReg, ResType, I); case Intrinsic::spv_subgroup_ballot: return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformBallot); >From b8401525e8f117165f76c4b1b6366eb040799839 Mon Sep 17 00:00:00 2001 From: Joshua Batista <[email protected]> Date: Mon, 26 Jan 2026 14:41:42 -0800 Subject: [PATCH 2/8] remove unused var --- clang/lib/Sema/SemaHLSL.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 1df66df1f3e70..97f833e921239 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -3590,8 +3590,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { // Ensure input expr type is a scalar/vector and then // set the return type to the arg type QualType ArgType = TheCall->getArg(0)->getType(); - auto *VTy = ArgType->getAs<VectorType>(); - // not the scalar or vector<scalar> + if (!(ArgType->isScalarType())) { SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(), diag::err_typecheck_expect_any_scalar_or_vector) >From 89a3bf9ab27717f5aabc3de1b2db431b65c0a12b Mon Sep 17 00:00:00 2001 From: Joshua Batista <[email protected]> Date: Mon, 26 Jan 2026 16:01:40 -0800 Subject: [PATCH 3/8] add enable 16bit preprocess flag --- clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h index 0887a4c5d7a64..5760c251aac56 100644 --- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h @@ -2448,18 +2448,22 @@ __attribute__((convergent)) uint4 WaveActiveBallot(bool Val); /// to all active non-helper lanes. /// \param Expr The integer expression to evaluate. /// \return The bitwise OR value of Expr across all active threads. +#ifdef __HLSL_ENABLE_16_BIT _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or) __attribute__((convergent)) int16_t WaveActiveBitOr(int16_t Expr); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or) -__attribute__((convergent)) int WaveActiveBitOr(int Expr); +__attribute__((convergent)) uint16_t WaveActiveBitOr(uint16_t Expr); +#endif + _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or) -__attribute__((convergent)) int64_t WaveActiveBitOr(int64_t Expr); +__attribute__((convergent)) int WaveActiveBitOr(int Expr); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or) -__attribute__((convergent)) uint16_t WaveActiveBitOr(uint16_t Expr); +__attribute__((convergent)) int64_t WaveActiveBitOr(int64_t Expr); + _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or) __attribute__((convergent)) uint WaveActiveBitOr(uint Expr); >From 41ea8b0a7cbe49f24641ab0f9ae0a5bd23c349f8 Mon Sep 17 00:00:00 2001 From: Joshua Batista <[email protected]> Date: Tue, 27 Jan 2026 16:04:48 -0800 Subject: [PATCH 4/8] add some more tests --- .../Target/SPIRV/SPIRVInstructionSelector.cpp | 4 +-- llvm/test/CodeGen/DirectX/WaveActiveBitOr.ll | 19 ++++++++++++ .../SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll | 30 +++++++++++++++++++ 3 files changed, 51 insertions(+), 2 deletions(-) create mode 100644 llvm/test/CodeGen/DirectX/WaveActiveBitOr.ll create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index ca63bf3e17782..151f8da509735 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -2723,9 +2723,9 @@ bool SPIRVInstructionSelector::selectWaveBitOr(Register ResVReg, MachineInstr &I) const { assert(I.getNumOperands() == 3); - assert(I.getOperand(1).isReg()); + assert(I.getOperand(2).isReg()); MachineBasicBlock &BB = *I.getParent(); - Register InputRegister = I.getOperand(1).getReg(); + Register InputRegister = I.getOperand(2).getReg(); SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister); if (!InputType) diff --git a/llvm/test/CodeGen/DirectX/WaveActiveBitOr.ll b/llvm/test/CodeGen/DirectX/WaveActiveBitOr.ll new file mode 100644 index 0000000000000..e7bc6a5292c3d --- /dev/null +++ b/llvm/test/CodeGen/DirectX/WaveActiveBitOr.ll @@ -0,0 +1,19 @@ +; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s + +define noundef i32 @wave_bitor_simple(i32 noundef %p1) { +entry: +; CHECK: call i32 @dx.op.waveBitOp.i32(i32 120, i32 %p1, i8 1, i8 0) + %ret = call i32 @llvm.dx.wave.bit.or.i32(i32 %p1) + ret i32 %ret +} + +declare i32 @llvm.dx.wave.bit.or.i32(i32) + +define noundef i64 @wave_bitor_simple64(i64 noundef %p1) { +entry: +; CHECK: call i64 @dx.op.waveBitOp.i64(i32 120, i64 %p1, i8 1, i8 0) + %ret = call i64 @llvm.dx.wave.bit.or.i64(i64 %p1) + ret i64 %ret +} + +declare i64 @llvm.dx.wave.bit.or.i64(i64) diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll new file mode 100644 index 0000000000000..81b0bfe03dbe7 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll @@ -0,0 +1,30 @@ +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val --target-env spv1.4 %} + +; Test lowering to spir-v backend for various types and scalar/vector + +; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#uint64:]] = OpTypeInt 64 0 +; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3 + +; CHECK-LABEL: Begin function test_uint +; CHECK: %[[#iexpr:]] = OpFunctionParameter %[[#uint]] +define i32 @test_uint(i32 %iexpr) { +entry: +; CHECK: %[[#iret:]] = OpGroupNonUniformBitwiseOr %[[#uint]] %[[#scope]] Reduce %[[#iexpr]] + %0 = call i32 @llvm.spv.wave.bit.or.i32(i32 %iexpr) + ret i32 %0 +} + +declare i32 @llvm.spv.wave.bit.or.i32(i32) + +; CHECK-LABEL: Begin function test_uint64 +; CHECK: %[[#iexpr64:]] = OpFunctionParameter %[[#uint64]] +define i64 @test_uint64(i64 %iexpr64) { +entry: +; CHECK: %[[#iret:]] = OpGroupNonUniformBitwiseOr %[[#uint64]] %[[#scope]] Reduce %[[#iexpr64]] + %0 = call i64 @llvm.spv.wave.bit.or.i64(i64 %iexpr64) + ret i64 %0 +} + +declare i64 @llvm.spv.wave.bit.or.i64(i64) >From fcf01d2a7d7e8ac186c415378d0e5a219b72180f Mon Sep 17 00:00:00 2001 From: Joshua Batista <[email protected]> Date: Thu, 29 Jan 2026 17:32:31 -0800 Subject: [PATCH 5/8] address Tex and Farzon --- clang/lib/CodeGen/CGHLSLBuiltins.cpp | 55 +++++-------------- clang/lib/Sema/SemaHLSL.cpp | 21 +++---- .../CodeGenHLSL/builtins/WaveActiveBitOr.hlsl | 33 +++++++++++ .../CodeGenHLSL/builtins/WaveActiveSum.hlsl | 11 ++-- .../BuiltIns/WaveActiveBitOr-errors.hlsl | 6 +- .../Target/SPIRV/SPIRVInstructionSelector.cpp | 43 +++++++-------- 6 files changed, 82 insertions(+), 87 deletions(-) diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp index 15491368caf82..24b1c2caefe7d 100644 --- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp +++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp @@ -329,45 +329,15 @@ static Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) { return RT.getFirstBitUHighIntrinsic(); } -// select and return a specific wave bit op intrinsic, -// based on the provided op kind. -// OpKinds: -// And = 0, bitwise and of values -// Or = 1, bitwise or of values -// Xor = 2, bitwise xor of values -static Intrinsic::ID getWaveBitOpIntrinsic(int OpKind, - llvm::Triple::ArchType Arch, - CGHLSLRuntime &RT, QualType QT) { +static Intrinsic::ID getWaveBitOpOrIntrinsic(llvm::Triple::ArchType Arch, + CGHLSLRuntime &RT, QualType QT) { switch (Arch) { case llvm::Triple::spirv: - switch (OpKind) { + return Intrinsic::spv_wave_bit_or; - case 0: - case 2: { - llvm_unreachable("Not implemented yet!"); - } - case 1: { - return Intrinsic::spv_wave_bit_or; - } - default: { - llvm_unreachable("Unexpected SubOp ID"); - } - } - case llvm::Triple::dxil: { - switch (OpKind) { + case llvm::Triple::dxil: + return Intrinsic::dx_wave_bit_or; - case 0: - case 2: { - llvm_unreachable("Not implemented yet!"); - } - case 1: { - return Intrinsic::dx_wave_bit_or; - } - default: { - llvm_unreachable("Unexpected SubOp ID"); - } - } - } default: llvm_unreachable("Intrinsic WaveActiveBitOr" " not supported by target architecture"); @@ -975,12 +945,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, } case Builtin::BI__builtin_hlsl_wave_active_bit_or: { Value *Op = EmitScalarExpr(E->getArg(0)); - assert(Op->getType()->isIntegerTy() && - "Intrinsic WaveActiveBitOr operand must be an integer type"); - - Intrinsic::ID IID = getWaveBitOpIntrinsic( - /* OpKind */ 1, getTarget().getTriple().getArch(), CGM.getHLSLRuntime(), - E->getArg(0)->getType()); + llvm::Type *Ty = Op->getType(); + assert(Ty->isIntegerTy() || + (Ty->isVectorTy() && Ty->getScalarType()->isIntegerTy()) && + "Intrinsic WaveActiveBitOr operand must be integer or " + "vector of integers"); + + Intrinsic::ID IID = + getWaveBitOpOrIntrinsic(getTarget().getTriple().getArch(), + CGM.getHLSLRuntime(), E->getArg(0)->getType()); return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration( &CGM.getModule(), IID, {Op->getType()}), diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index d5ec41bedd242..349c6a967e534 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -3601,24 +3601,21 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { if (SemaRef.checkArgCount(TheCall, 1)) return true; - // Ensure input expr type is a scalar/vector and then - // set the return type to the arg type QualType ArgType = TheCall->getArg(0)->getType(); - if (!(ArgType->isScalarType())) { + // Ensure input expr type is a scalar/vector + if (!ArgType->hasIntegerRepresentation()) { SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(), - diag::err_typecheck_expect_any_scalar_or_vector) - << ArgType << 0; - return true; - } - - if (!(ArgType->isIntegerType())) { - SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(), - diag::err_typecheck_expect_any_scalar_or_vector) - << ArgType << 0; + diag::err_builtin_invalid_arg_type) + << 1 // %ordinal0: 1st argument + << 5 // %select1: scalar or vector of + << 1 // %select2: integer + << 0 // %select3: no floating-point + << TheCall->getArg(0)->getType(); return true; } + // Set the return type to the arg type TheCall->setType(ArgType); break; } diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl index 6b757ce60b389..f9966bc9ebf63 100644 --- a/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl +++ b/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl @@ -21,6 +21,39 @@ int test_int(int expr) { // CHECK-DXIL: declare [[TY]] @llvm.dx.wave.bit.or.i32([[TY]]) #[[#attr:]] // CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.bit.or.i32([[TY]]) #[[#attr:]] +// CHECK-LABEL: test_int2 +int2 test_int2(int2 expr) { + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.v2i32([[TY]] %[[#]]) + // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.v2i32([[TY]] %[[#]]) + // CHECK: ret [[TY]] %[[RET]] + return WaveActiveBitOr(expr); +} + +// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.bit.or.v2i32([[TY]]) #[[#attr:]] +// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.bit.or.v2i32([[TY]]) #[[#attr:]] + +// CHECK-LABEL: test_int3 +int3 test_int3(int3 expr) { + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.v3i32([[TY]] %[[#]]) + // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.v3i32([[TY]] %[[#]]) + // CHECK: ret [[TY]] %[[RET]] + return WaveActiveBitOr(expr); +} + +// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.bit.or.v3i32([[TY]]) #[[#attr:]] +// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.bit.or.v3i32([[TY]]) #[[#attr:]] + +// CHECK-LABEL: test_int4 +int4 test_int4(int4 expr) { + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.v4i32([[TY]] %[[#]]) + // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.v4i32([[TY]] %[[#]]) + // CHECK: ret [[TY]] %[[RET]] + return WaveActiveBitOr(expr); +} + +// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.bit.or.v4i32([[TY]]) #[[#attr:]] +// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.bit.or.v4i32([[TY]]) #[[#attr:]] + // CHECK-LABEL: test_int16 int16_t test_int16_t(int16_t expr) { // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i16([[TY]] %[[#]]) diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl index 87ddd96e8368c..1fc93c62c8db0 100644 --- a/clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl +++ b/clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl @@ -1,12 +1,9 @@ // RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ -// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -fnative-int16-type -fnative-half-type \ -// RUN: -fmath-errno -ffp-contract=on -fno-rounding-math -finclude-default-header \ -// RUN: -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL - +// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \ +// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL // RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ -// RUN: spirv-pc-vulkan-compute %s -emit-llvm -fnative-int16-type -fnative-half-type \ -// RUN: -fmath-errno -ffp-contract=on -fno-rounding-math -finclude-default-header \ -// RUN: -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV +// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \ +// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV // Test basic lowering to runtime function call. diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl index e3fd2eac28159..19c1f0ede4765 100644 --- a/clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl +++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl @@ -14,17 +14,17 @@ int test_too_many_arg(int p0) { float test_expr_bool_type_check(float p0) { return __builtin_hlsl_wave_active_bit_or(p0); - // expected-error@-1 {{invalid operand of type 'float'}} + // expected-error@-1 {{1st argument must be a scalar or vector of integer types (was 'float')}} } float2 test_expr_bool_vec_type_check(float2 p0) { return __builtin_hlsl_wave_active_bit_or(p0); - // expected-error@-1 {{invalid operand of type 'float2' (aka 'vector<float, 2>')}} + // expected-error@-1 {{1st argument must be a scalar or vector of integer types (was 'float2' (aka 'vector<float, 2>'))}} } struct S { float f; }; S test_expr_struct_type_check(S p0) { return __builtin_hlsl_wave_active_bit_or(p0); - // expected-error@-1 {{invalid operand of type 'S'}} + // expected-error@-1 {{1st argument must be a scalar or vector of integer types (was 'S')}} } diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 151f8da509735..8b0a06231ccca 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -224,8 +224,8 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectDot4AddPackedExpansion(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; - bool selectWaveBitOr(Register ResVReg, const SPIRVType *ResType, - MachineInstr &I) const; + bool selectWaveBitOpInst(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I, unsigned Opcode) const; bool selectWaveReduceMax(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, bool IsUnsigned) const; @@ -2718,31 +2718,25 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits( return Result; } -bool SPIRVInstructionSelector::selectWaveBitOr(Register ResVReg, - const SPIRVType *ResType, - MachineInstr &I) const { - - assert(I.getNumOperands() == 3); - assert(I.getOperand(2).isReg()); +bool SPIRVInstructionSelector::selectWaveBitOpInst(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, + unsigned Opcode) const { MachineBasicBlock &BB = *I.getParent(); - Register InputRegister = I.getOperand(2).getReg(); - SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister); + SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII); - if (!InputType) - report_fatal_error("Input Type could not be determined."); - if (!GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeInt)) - report_fatal_error("WaveActiveBitOr requires integer input"); + auto BMI = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, + IntTy, TII, !STI.isShader())); + BMI.addImm(SPIRV::GroupOperation::Reduce); - SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII); + for (unsigned J = 2; J < I.getNumOperands(); J++) { + BMI.addUse(I.getOperand(J).getReg()); + } - return BuildMI(BB, I, I.getDebugLoc(), - TII.get(SPIRV::OpGroupNonUniformBitwiseOr)) - .addDef(ResVReg) - .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII)) - .addImm(SPIRV::GroupOperation::Reduce) - .addUse(InputRegister) - .constrainAllUses(TII, TRI, RBI); + return BMI.constrainAllUses(TII, TRI, RBI); } bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg, @@ -3896,7 +3890,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, case Intrinsic::spv_wave_any: return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny); case Intrinsic::spv_wave_bit_or: - return selectWaveBitOr(ResVReg, ResType, I); + return selectWaveBitOpInst(ResVReg, ResType, I, + SPIRV::OpGroupNonUniformBitwiseOr); case Intrinsic::spv_subgroup_ballot: return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformBallot); >From ac22964f653268326511b671715d77fe04aa8c7b Mon Sep 17 00:00:00 2001 From: Joshua Batista <[email protected]> Date: Thu, 29 Jan 2026 18:13:44 -0800 Subject: [PATCH 6/8] clang-format --- llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index c15fbd36a67fe..09353a127e033 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -226,7 +226,7 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectWaveBitOpInst(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, unsigned Opcode) const; - + bool selectWavePrefixBitCount(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; >From 998c1cb10a15ea65f797722d11aaae70352ff26b Mon Sep 17 00:00:00 2001 From: Joshua Batista <[email protected]> Date: Thu, 29 Jan 2026 22:45:33 -0800 Subject: [PATCH 7/8] add missing stages + attributes --- llvm/lib/Target/DirectX/DXIL.td | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 5931a5218b811..43b4c47522cf2 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -1144,6 +1144,8 @@ def WaveBitOp : DXILOp<120, waveBitOp> { let overloads = [ Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]> ]; + let stages = [Stages<DXIL1_0, [all_stages]>]; + let attributes = [Attributes<DXIL1_0, []>]; } def WavePrefixBitCount : DXILOp<136, wavePrefixOp> { >From 63911c660f0f26dbf38a282df41644ab69ce852d Mon Sep 17 00:00:00 2001 From: Joshua Batista <[email protected]> Date: Fri, 30 Jan 2026 17:59:51 -0800 Subject: [PATCH 8/8] create macro for spirv intrinsics to use subgroup instead of wave --- clang/lib/CodeGen/CGHLSLBuiltins.cpp | 20 +------------ clang/lib/CodeGen/CGHLSLRuntime.h | 28 ++++++++++++++++++- .../CodeGenHLSL/builtins/WaveActiveBitOr.hlsl | 28 +++++++++---------- llvm/include/llvm/IR/IntrinsicsSPIRV.td | 2 +- .../Target/SPIRV/SPIRVInstructionSelector.cpp | 2 +- .../SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll | 8 +++--- 6 files changed, 48 insertions(+), 40 deletions(-) diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp index 9c72e2f6108c7..bf5b01ff43a45 100644 --- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp +++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp @@ -329,21 +329,6 @@ static Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) { return RT.getFirstBitUHighIntrinsic(); } -static Intrinsic::ID getWaveBitOpOrIntrinsic(llvm::Triple::ArchType Arch, - CGHLSLRuntime &RT, QualType QT) { - switch (Arch) { - case llvm::Triple::spirv: - return Intrinsic::spv_wave_bit_or; - - case llvm::Triple::dxil: - return Intrinsic::dx_wave_bit_or; - - default: - llvm_unreachable("Intrinsic WaveActiveBitOr" - " not supported by target architecture"); - } -} - // Return wave active sum that corresponds to the QT scalar type static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch, CGHLSLRuntime &RT, QualType QT) { @@ -976,10 +961,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, "Intrinsic WaveActiveBitOr operand must be integer or " "vector of integers"); - Intrinsic::ID IID = - getWaveBitOpOrIntrinsic(getTarget().getTriple().getArch(), - CGM.getHLSLRuntime(), E->getArg(0)->getType()); - + Intrinsic::ID IID = CGM.getHLSLRuntime().getWaveActiveBitOpOrIntrinsic(); return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration( &CGM.getModule(), IID, {Op->getType()}), ArrayRef{Op}, "hlsl.wave.active.bit.or"); diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h index ba2ca2c358388..683c1e0eb5e3a 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.h +++ b/clang/lib/CodeGen/CGHLSLRuntime.h @@ -34,7 +34,15 @@ // A function generator macro for picking the right intrinsic // for the target backend -#define GENERATE_HLSL_INTRINSIC_FUNCTION(FunctionName, IntrinsicPostfix) \ + +#define _GEN_INTRIN_CHOOSER(_1, _2, _3, NAME, ...) NAME + +#define GENERATE_HLSL_INTRINSIC_FUNCTION(...) \ + _GEN_INTRIN_CHOOSER(__VA_ARGS__, GENERATE_HLSL_INTRINSIC_FUNCTION3, \ + GENERATE_HLSL_INTRINSIC_FUNCTION2)(__VA_ARGS__) + +// 2-arg form: same postfix for both backends (uses the identity) +#define GENERATE_HLSL_INTRINSIC_FUNCTION2(FunctionName, IntrinsicPostfix) \ llvm::Intrinsic::ID get##FunctionName##Intrinsic() { \ llvm::Triple::ArchType Arch = getArch(); \ switch (Arch) { \ @@ -48,6 +56,22 @@ } \ } +// 3-arg form: explicit SPIR-V postfix override (perfect for wave->subgroup) +#define GENERATE_HLSL_INTRINSIC_FUNCTION3(FunctionName, DxilPostfix, \ + SpirvPostfix) \ + llvm::Intrinsic::ID get##FunctionName##Intrinsic() { \ + llvm::Triple::ArchType Arch = getArch(); \ + switch (Arch) { \ + case llvm::Triple::dxil: \ + return llvm::Intrinsic::dx_##DxilPostfix; \ + case llvm::Triple::spirv: \ + return llvm::Intrinsic::spv_##SpirvPostfix; \ + default: \ + llvm_unreachable("Intrinsic " #DxilPostfix \ + " not supported by target architecture"); \ + } \ + } + using ResourceClass = llvm::dxil::ResourceClass; namespace llvm { @@ -146,6 +170,8 @@ class CGHLSLRuntime { GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any) + GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveBitOpOr, wave_bit_or, + subgroup_bit_or) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveGetLaneCount, wave_get_lane_count) diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl index f9966bc9ebf63..7c57d3298f71f 100644 --- a/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl +++ b/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl @@ -12,62 +12,62 @@ // CHECK-LABEL: test_int int test_int(int expr) { - // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i32([[TY]] %[[#]]) + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.subgroup.bit.or.i32([[TY]] %[[#]]) // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.i32([[TY]] %[[#]]) // CHECK: ret [[TY]] %[[RET]] return WaveActiveBitOr(expr); } // CHECK-DXIL: declare [[TY]] @llvm.dx.wave.bit.or.i32([[TY]]) #[[#attr:]] -// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.bit.or.i32([[TY]]) #[[#attr:]] +// CHECK-SPIRV: declare [[TY]] @llvm.spv.subgroup.bit.or.i32([[TY]]) #[[#attr:]] // CHECK-LABEL: test_int2 int2 test_int2(int2 expr) { - // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.v2i32([[TY]] %[[#]]) + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.subgroup.bit.or.v2i32([[TY]] %[[#]]) // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.v2i32([[TY]] %[[#]]) // CHECK: ret [[TY]] %[[RET]] return WaveActiveBitOr(expr); } // CHECK-DXIL: declare [[TY]] @llvm.dx.wave.bit.or.v2i32([[TY]]) #[[#attr:]] -// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.bit.or.v2i32([[TY]]) #[[#attr:]] +// CHECK-SPIRV: declare [[TY]] @llvm.spv.subgroup.bit.or.v2i32([[TY]]) #[[#attr:]] // CHECK-LABEL: test_int3 int3 test_int3(int3 expr) { - // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.v3i32([[TY]] %[[#]]) + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.subgroup.bit.or.v3i32([[TY]] %[[#]]) // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.v3i32([[TY]] %[[#]]) // CHECK: ret [[TY]] %[[RET]] return WaveActiveBitOr(expr); } // CHECK-DXIL: declare [[TY]] @llvm.dx.wave.bit.or.v3i32([[TY]]) #[[#attr:]] -// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.bit.or.v3i32([[TY]]) #[[#attr:]] +// CHECK-SPIRV: declare [[TY]] @llvm.spv.subgroup.bit.or.v3i32([[TY]]) #[[#attr:]] // CHECK-LABEL: test_int4 int4 test_int4(int4 expr) { - // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.v4i32([[TY]] %[[#]]) + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.subgroup.bit.or.v4i32([[TY]] %[[#]]) // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.v4i32([[TY]] %[[#]]) // CHECK: ret [[TY]] %[[RET]] return WaveActiveBitOr(expr); } // CHECK-DXIL: declare [[TY]] @llvm.dx.wave.bit.or.v4i32([[TY]]) #[[#attr:]] -// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.bit.or.v4i32([[TY]]) #[[#attr:]] +// CHECK-SPIRV: declare [[TY]] @llvm.spv.subgroup.bit.or.v4i32([[TY]]) #[[#attr:]] // CHECK-LABEL: test_int16 int16_t test_int16_t(int16_t expr) { - // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i16([[TY]] %[[#]]) + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.subgroup.bit.or.i16([[TY]] %[[#]]) // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.i16([[TY]] %[[#]]) // CHECK: ret [[TY]] %[[RET]] return WaveActiveBitOr(expr); } // CHECK-DXIL: declare [[TY]] @llvm.dx.wave.bit.or.i16([[TY]]) #[[#attr:]] -// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.bit.or.i16([[TY]]) #[[#attr:]] +// CHECK-SPIRV: declare [[TY]] @llvm.spv.subgroup.bit.or.i16([[TY]]) #[[#attr:]] // CHECK-LABEL: test_int64 int64_t test_int64_t(int64_t expr) { - // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i64([[TY]] %[[#]]) + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.subgroup.bit.or.i64([[TY]] %[[#]]) // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.i64([[TY]] %[[#]]) // CHECK: ret [[TY]] %[[RET]] return WaveActiveBitOr(expr); @@ -75,7 +75,7 @@ int64_t test_int64_t(int64_t expr) { // CHECK-LABEL: test_uint uint test_uint(uint expr) { - // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i32([[TY]] %[[#]]) + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.subgroup.bit.or.i32([[TY]] %[[#]]) // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.i32([[TY]] %[[#]]) // CHECK: ret [[TY]] %[[RET]] return WaveActiveBitOr(expr); @@ -83,7 +83,7 @@ uint test_uint(uint expr) { // CHECK-LABEL: test_uint16 uint16_t test_uint16_t(uint16_t expr) { - // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i16([[TY]] %[[#]]) + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.subgroup.bit.or.i16([[TY]] %[[#]]) // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.i16([[TY]] %[[#]]) // CHECK: ret [[TY]] %[[RET]] return WaveActiveBitOr(expr); @@ -91,7 +91,7 @@ uint16_t test_uint16_t(uint16_t expr) { // CHECK-LABEL: test_uint64 uint64_t test_uint64_t(uint64_t expr) { - // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.bit.or.i64([[TY]] %[[#]]) + // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.subgroup.bit.or.i64([[TY]] %[[#]]) // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.bit.or.i64([[TY]] %[[#]]) // CHECK: ret [[TY]] %[[RET]] return WaveActiveBitOr(expr); diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td index 9c1cab67b580e..81fdaf6c71f4c 100644 --- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td +++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td @@ -123,7 +123,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty] def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_spv_subgroup_ballot : ClangBuiltin<"__builtin_spirv_subgroup_ballot">, DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; - def int_spv_wave_bit_or : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; + def int_spv_subgroup_bit_or : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; def int_spv_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; def int_spv_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; def int_spv_wave_reduce_min : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 09353a127e033..0f99697e96c37 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -3944,7 +3944,7 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAll); case Intrinsic::spv_wave_any: return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny); - case Intrinsic::spv_wave_bit_or: + case Intrinsic::spv_subgroup_bit_or: return selectWaveBitOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformBitwiseOr); case Intrinsic::spv_subgroup_ballot: diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll index 81b0bfe03dbe7..2f871094ab328 100644 --- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll @@ -12,19 +12,19 @@ define i32 @test_uint(i32 %iexpr) { entry: ; CHECK: %[[#iret:]] = OpGroupNonUniformBitwiseOr %[[#uint]] %[[#scope]] Reduce %[[#iexpr]] - %0 = call i32 @llvm.spv.wave.bit.or.i32(i32 %iexpr) + %0 = call i32 @llvm.spv.subgroup.bit.or.i32(i32 %iexpr) ret i32 %0 } -declare i32 @llvm.spv.wave.bit.or.i32(i32) +declare i32 @llvm.spv.subgroup.bit.or.i32(i32) ; CHECK-LABEL: Begin function test_uint64 ; CHECK: %[[#iexpr64:]] = OpFunctionParameter %[[#uint64]] define i64 @test_uint64(i64 %iexpr64) { entry: ; CHECK: %[[#iret:]] = OpGroupNonUniformBitwiseOr %[[#uint64]] %[[#scope]] Reduce %[[#iexpr64]] - %0 = call i64 @llvm.spv.wave.bit.or.i64(i64 %iexpr64) + %0 = call i64 @llvm.spv.subgroup.bit.or.i64(i64 %iexpr64) ret i64 %0 } -declare i64 @llvm.spv.wave.bit.or.i64(i64) +declare i64 @llvm.spv.subgroup.bit.or.i64(i64) _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
