https://github.com/hekota updated https://github.com/llvm/llvm-project/pull/95849
>From b10aa2317f566febdf4cd3630a972be58fea515b Mon Sep 17 00:00:00 2001 From: Helena Kotas <heko...@microsoft.com> Date: Mon, 17 Jun 2024 14:03:03 -0700 Subject: [PATCH 1/2] [SPIRV][HLSL] Add lowering of `rsqrt` to SPIRV --- clang/lib/CodeGen/CGBuiltin.cpp | 2 +- clang/lib/CodeGen/CGHLSLRuntime.h | 1 + llvm/include/llvm/IR/IntrinsicsSPIRV.td | 1 + .../Target/SPIRV/SPIRVInstructionSelector.cpp | 22 ++++++++++++++ .../CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll | 29 +++++++++++++++++++ 5 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index 511e1fd4016d7..3c233eb3f2dbf 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -18331,7 +18331,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, if (!E->getArg(0)->getType()->hasFloatingRepresentation()) llvm_unreachable("rsqrt operand must have a float representation"); return Builder.CreateIntrinsic( - /*ReturnType=*/Op0->getType(), Intrinsic::dx_rsqrt, + /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getRsqrtIntrinsic(), ArrayRef<Value *>{Op0}, nullptr, "dx.rsqrt"); } case Builtin::BI__builtin_hlsl_wave_get_lane_index: { diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h index 0abe39dedcb96..4036ce711bea1 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.h +++ b/clang/lib/CodeGen/CGHLSLRuntime.h @@ -75,6 +75,7 @@ class CGHLSLRuntime { GENERATE_HLSL_INTRINSIC_FUNCTION(All, all) GENERATE_HLSL_INTRINSIC_FUNCTION(Any, any) GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp) + GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt) GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id) //===----------------------------------------------------------------------===// diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td index 90f12674d0470..683acf4a6ffa9 100644 --- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td +++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td @@ -62,4 +62,5 @@ let TargetPrefix = "spv" in { def int_spv_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>; def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>], [IntrNoMem, IntrWillReturn] >; + def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>; } diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index db83172f7fa9c..b9e5569029cfd 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -173,6 +173,9 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectFmix(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; + bool selectRsqrt(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I, int OpIdx) const; void renderFImm32(MachineInstrBuilder &MIB, const MachineInstr &I, @@ -1315,6 +1318,23 @@ bool SPIRVInstructionSelector::selectFmix(Register ResVReg, .constrainAllUses(TII, TRI, RBI); } +bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + + assert(I.getNumOperands() == 3); + assert(I.getOperand(2).isReg()); + MachineBasicBlock &BB = *I.getParent(); + + return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450)) + .addImm(GL::InverseSqrt) + .addUse(I.getOperand(2).getReg()) + .constrainAllUses(TII, TRI, RBI); +} + bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const { @@ -1992,6 +2012,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, return selectAny(ResVReg, ResType, I); case Intrinsic::spv_lerp: return selectFmix(ResVReg, ResType, I); + case Intrinsic::spv_rsqrt: + return selectRsqrt(ResVReg, ResType, I); case Intrinsic::spv_lifetime_start: case Intrinsic::spv_lifetime_end: { unsigned Op = IID == Intrinsic::spv_lifetime_start ? SPIRV::OpLifetimeStart diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll new file mode 100644 index 0000000000000..1541a5715b952 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll @@ -0,0 +1,29 @@ +; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK: OpExtInstImport "GLSL.std.450" + +define noundef float @rsqrt_float(float noundef %a) { +entry: +; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] InverseSqrt %[[#]] + %elt.rsqrt = call float @llvm.spv.rsqrt.f32(float %a) + ret float %elt.rsqrt +} + +define noundef half @rsqrt_half(half noundef %a) { +entry: +; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] InverseSqrt %[[#]] + %elt.rsqrt = call half @llvm.spv.rsqrt.f16(half %a) + ret half %elt.rsqrt +} + +define noundef double @rsqrt_double(double noundef %a) { +entry: +; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] InverseSqrt %[[#]] + %elt.rsqrt = call double @llvm.spv.rsqrt.f64(double %a) + ret double %elt.rsqrt +} + +declare half @llvm.spv.sqrt.f16(half) +declare float @llvm.spv.sqrt.f32(float) +declare float @llvm.spv.sqrt.f64(float) >From 6a1b0e40045bc5652bc88b0766aec6727893d02e Mon Sep 17 00:00:00 2001 From: Helena Kotas <heko...@microsoft.com> Date: Mon, 17 Jun 2024 17:18:06 -0700 Subject: [PATCH 2/2] Improve test - add vectors and check types --- .../CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll | 53 ++++++++++++++++--- 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll index 1541a5715b952..650b32910d65e 100644 --- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll @@ -1,29 +1,68 @@ ; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %} -; CHECK: OpExtInstImport "GLSL.std.450" +; CHECK-DAG: %[[#op_ext_glsl:]] = OpExtInstImport "GLSL.std.450" + +; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32 +; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16 +; CHECK-DAG: %[[#float_64:]] = OpTypeFloat 64 + +; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4 +; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4 +; CHECK-DAG: %[[#vec4_float_64:]] = OpTypeVector %[[#float_64]] 4 define noundef float @rsqrt_float(float noundef %a) { entry: -; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] InverseSqrt %[[#]] +; CHECK: %[[#float_32_arg:]] = OpFunctionParameter %[[#float_32]] +; CHECK: %[[#]] = OpExtInst %[[#float_32]] %[[#op_ext_glsl]] InverseSqrt %[[#float_32_arg]] %elt.rsqrt = call float @llvm.spv.rsqrt.f32(float %a) ret float %elt.rsqrt } define noundef half @rsqrt_half(half noundef %a) { entry: -; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] InverseSqrt %[[#]] +; CHECK: %[[#float_16_arg:]] = OpFunctionParameter %[[#float_16]] +; CHECK: %[[#]] = OpExtInst %[[#float_16]] %[[#op_ext_glsl]] InverseSqrt %[[#float_16_arg]] %elt.rsqrt = call half @llvm.spv.rsqrt.f16(half %a) ret half %elt.rsqrt } define noundef double @rsqrt_double(double noundef %a) { entry: -; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] InverseSqrt %[[#]] +; CHECK: %[[#float_64_arg:]] = OpFunctionParameter %[[#float_64]] +; CHECK: %[[#]] = OpExtInst %[[#float_64]] %[[#op_ext_glsl]] InverseSqrt %[[#float_64_arg]] %elt.rsqrt = call double @llvm.spv.rsqrt.f64(double %a) ret double %elt.rsqrt } -declare half @llvm.spv.sqrt.f16(half) -declare float @llvm.spv.sqrt.f32(float) -declare float @llvm.spv.sqrt.f64(float) +define noundef <4 x float> @rsqrt_float_vector(<4 x float> noundef %a) { +entry: +; CHECK: %[[#vec4_float_32_arg:]] = OpFunctionParameter %[[#vec4_float_32]] +; CHECK: %[[#]] = OpExtInst %[[#vec4_float_32]] %[[#op_ext_glsl]] InverseSqrt %[[#vec4_float_32_arg]] + %elt.rsqrt = call <4 x float> @llvm.spv.rsqrt.v4f32(<4 x float> %a) + ret <4 x float> %elt.rsqrt +} + +define noundef <4 x half> @rsqrt_half_vector(<4 x half> noundef %a) { +entry: +; CHECK: %[[#vec4_float_16_arg:]] = OpFunctionParameter %[[#vec4_float_16]] +; CHECK: %[[#]] = OpExtInst %[[#vec4_float_16]] %[[#op_ext_glsl]] InverseSqrt %[[#vec4_float_16_arg]] + %elt.rsqrt = call <4 x half> @llvm.spv.rsqrt.v4f16(<4 x half> %a) + ret <4 x half> %elt.rsqrt +} + +define noundef <4 x double> @rsqrt_double_vector(<4 x double> noundef %a) { +entry: +; CHECK: %[[#vec4_float_64_arg:]] = OpFunctionParameter %[[#vec4_float_64]] +; CHECK: %[[#]] = OpExtInst %[[#vec4_float_64]] %[[#op_ext_glsl]] InverseSqrt %[[#vec4_float_64_arg]] + %elt.rsqrt = call <4 x double> @llvm.spv.rsqrt.v4f64(<4 x double> %a) + ret <4 x double> %elt.rsqrt +} + +declare half @llvm.spv.rsqrt.f16(half) +declare float @llvm.spv.rsqrt.f32(float) +declare double @llvm.spv.rsqrt.f64(double) + +declare <4 x float> @llvm.spv.rsqrt.v4f32(<4 x float>) +declare <4 x half> @llvm.spv.rsqrt.v4f16(<4 x half>) +declare <4 x double> @llvm.spv.rsqrt.v4f64(<4 x double>) _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits