llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-backend-nvptx Author: Srinivasa Ravi (Wolfram70) <details> <summary>Changes</summary> This change adds intrinsics and clang builtins for the cvt instruction variants of the FP4 type `.e2m1x2`, introduced in PTX 8.6 for `sm_100a`, `sm_101a`, and `sm_120a`. Tests are added in `NVPTX/convert-sm100a.ll` and `clang/test/CodeGen/builtins-nvptx.c` and verified through ptxas 12.8.0. PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt --- Full diff: https://github.com/llvm/llvm-project/pull/139244.diff 6 Files Affected: - (modified) clang/include/clang/Basic/BuiltinsNVPTX.td (+6) - (modified) clang/test/CodeGen/builtins-nvptx.c (+20) - (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+7) - (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+17) - (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+14) - (modified) llvm/test/CodeGen/NVPTX/convert-sm100a.ll (+82) ``````````diff diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td index f797e29fe66a3..2cea44e224674 100644 --- a/clang/include/clang/Basic/BuiltinsNVPTX.td +++ b/clang/include/clang/Basic/BuiltinsNVPTX.td @@ -620,6 +620,12 @@ def __nvvm_e2m3x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(sh def __nvvm_e3m2x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>; def __nvvm_e3m2x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>; +def __nvvm_ff_to_e2m1x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>; +def __nvvm_ff_to_e2m1x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>; + +def __nvvm_e2m1x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>; +def 
__nvvm_e2m1x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>; + def __nvvm_ff_to_ue8m0x2_rz : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>; def __nvvm_ff_to_ue8m0x2_rz_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>; def __nvvm_ff_to_ue8m0x2_rp : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>; diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c index 639c18190f436..7904762709df6 100644 --- a/clang/test/CodeGen/builtins-nvptx.c +++ b/clang/test/CodeGen/builtins-nvptx.c @@ -1127,6 +1127,26 @@ __device__ void nvvm_cvt_sm100a_sm101a_sm120a() { // CHECK_PTX86_SM120a: call <2 x half> @llvm.nvvm.e3m2x2.to.f16x2.rn.relu(i16 19532) __nvvm_e3m2x2_to_f16x2_rn_relu(0x4C4C); + // CHECK_PTX86_SM100a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float 1.000000e+00, float 1.000000e+00) + // CHECK_PTX86_SM101a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float 1.000000e+00, float 1.000000e+00) + // CHECK_PTX86_SM120a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float 1.000000e+00, float 1.000000e+00) + __nvvm_ff_to_e2m1x2_rn_satfinite(1.0f, 1.0f); + + // CHECK_PTX86_SM100a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float 1.000000e+00, float 1.000000e+00) + // CHECK_PTX86_SM101a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float 1.000000e+00, float 1.000000e+00) + // CHECK_PTX86_SM120a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float 1.000000e+00, float 1.000000e+00) + __nvvm_ff_to_e2m1x2_rn_relu_satfinite(1.0f, 1.0f); + + // CHECK_PTX86_SM100a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 76) + // CHECK_PTX86_SM101a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 76) + // CHECK_PTX86_SM120a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 76) + __nvvm_e2m1x2_to_f16x2_rn(0x004C); + + // CHECK_PTX86_SM100a: call <2 x half> 
@llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 76) + // CHECK_PTX86_SM101a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 76) + // CHECK_PTX86_SM120a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 76) + __nvvm_e2m1x2_to_f16x2_rn_relu(0x004C); + // CHECK_PTX86_SM100a: call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float 1.000000e+00, float 1.000000e+00) // CHECK_PTX86_SM101a: call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float 1.000000e+00, float 1.000000e+00) // CHECK_PTX86_SM120a: call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float 1.000000e+00, float 1.000000e+00) diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index 8b87822d3fdda..60178bf01f266 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -1663,6 +1663,13 @@ let TargetPrefix = "nvvm" in { def int_nvvm_ # type # _to_f16x2 # suffix : CVT_I16_TO_F16X2<type, suffix>; } } + + // FP4 conversions. + foreach relu = ["", "_relu"] in { + defvar suffix = !strconcat("_rn", relu); + def int_nvvm_ff_to_e2m1x2 # suffix # _satfinite : CVT_FF_TO_I16<"e2m1x2", !strconcat(suffix, "_satfinite")>; + def int_nvvm_e2m1x2_to_f16x2 # suffix : CVT_I16_TO_F16X2<"e2m1x2", suffix>; + } // UE8M0x2 conversions. foreach rmode = ["_rz", "_rp"] in { diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 11d77599d4ac3..b127ea66eeb71 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -721,6 +721,23 @@ let hasSideEffects = false in { # type # " \t$dst, $src;", []>; } + // FP4 conversions. 
+ def CVT_e2m1x2_f32_sf : NVPTXInst<(outs Int16Regs:$dst), + (ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode), + !strconcat("{{ \n\t", + ".reg .b8 \t%e2m1x2_out; \n\t", + "cvt${mode:base}.satfinite${mode:relu}.e2m1x2.f32 \t%e2m1x2_out, $src1, $src2; \n\t", + "cvt.u16.u8 \t$dst, %e2m1x2_out; \n\t", + "}}"), []>; + + def CVT_f16x2_e2m1x2 : NVPTXInst<(outs Int32Regs:$dst), + (ins Int16Regs:$src, CvtMode:$mode), + !strconcat("{{ \n\t", + ".reg .b8 \t%e2m1x2_in; \n\t", + "cvt.u8.u16 \t%e2m1x2_in, $src; \n\t", + "cvt${mode:base}${mode:relu}.f16x2.e2m1x2 \t$dst, %e2m1x2_in; \n\t", + "}}"), []>; + // UE8M0x2 conversions. class CVT_f32_to_ue8m0x2<string sat = ""> : NVPTXInst<(outs Int16Regs:$dst), diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 3eedb43e4c81a..3dcf66b793409 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -1976,6 +1976,20 @@ def : Pat<(int_nvvm_e3m2x2_to_f16x2_rn i16:$a), def : Pat<(int_nvvm_e3m2x2_to_f16x2_rn_relu i16:$a), (CVT_f16x2_e3m2x2 $a, CvtRN_RELU)>, Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>; + +def : Pat<(int_nvvm_ff_to_e2m1x2_rn_satfinite f32:$a, f32:$b), + (CVT_e2m1x2_f32_sf $a, $b, CvtRN)>, + Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>; +def : Pat<(int_nvvm_ff_to_e2m1x2_rn_relu_satfinite f32:$a, f32:$b), + (CVT_e2m1x2_f32_sf $a, $b, CvtRN_RELU)>, + Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>; + +def : Pat<(int_nvvm_e2m1x2_to_f16x2_rn Int16Regs:$a), + (CVT_f16x2_e2m1x2 $a, CvtRN)>, + Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>; +def : Pat<(int_nvvm_e2m1x2_to_f16x2_rn_relu Int16Regs:$a), + (CVT_f16x2_e2m1x2 $a, CvtRN_RELU)>, + Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>; def : Pat<(int_nvvm_ff_to_ue8m0x2_rz f32:$a, f32:$b), (CVT_ue8m0x2_f32 $a, $b, CvtRZ)>, diff --git a/llvm/test/CodeGen/NVPTX/convert-sm100a.ll b/llvm/test/CodeGen/NVPTX/convert-sm100a.ll index 
04d7a65f9e40e..e2e7cd9ff2fe6 100644 --- a/llvm/test/CodeGen/NVPTX/convert-sm100a.ll +++ b/llvm/test/CodeGen/NVPTX/convert-sm100a.ll @@ -288,3 +288,85 @@ define <2 x bfloat> @cvt_bf16x2_ue8m0x2(i16 %in) { %val = call <2 x bfloat> @llvm.nvvm.ue8m0x2.to.bf16x2(i16 %in) ret <2 x bfloat> %val } + +define i16 @cvt_rn_sf_e2m1x2_f32(float %f1, float %f2) { +; CHECK-LABEL: cvt_rn_sf_e2m1x2_f32( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<2>; +; CHECK-NEXT: .reg .b32 %r<2>; +; CHECK-NEXT: .reg .b32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_sf_e2m1x2_f32_param_0]; +; CHECK-NEXT: ld.param.f32 %f2, [cvt_rn_sf_e2m1x2_f32_param_1]; +; CHECK-NEXT: { +; CHECK-NEXT: .reg .b8 %e2m1x2_out; +; CHECK-NEXT: cvt.rn.satfinite.e2m1x2.f32 %e2m1x2_out, %f1, %f2; +; CHECK-NEXT: cvt.u16.u8 %rs1, %e2m1x2_out; +; CHECK-NEXT: } +; CHECK-NEXT: cvt.u32.u16 %r1, %rs1; +; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +; CHECK-NEXT: ret; + %val = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %f1, float %f2) + ret i16 %val +} + +define i16 @cvt_rn_relu_sf_e2m1x2_f32(float %f1, float %f2) { +; CHECK-LABEL: cvt_rn_relu_sf_e2m1x2_f32( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<2>; +; CHECK-NEXT: .reg .b32 %r<2>; +; CHECK-NEXT: .reg .b32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_relu_sf_e2m1x2_f32_param_0]; +; CHECK-NEXT: ld.param.f32 %f2, [cvt_rn_relu_sf_e2m1x2_f32_param_1]; +; CHECK-NEXT: { +; CHECK-NEXT: .reg .b8 %e2m1x2_out; +; CHECK-NEXT: cvt.rn.satfinite.relu.e2m1x2.f32 %e2m1x2_out, %f1, %f2; +; CHECK-NEXT: cvt.u16.u8 %rs1, %e2m1x2_out; +; CHECK-NEXT: } +; CHECK-NEXT: cvt.u32.u16 %r1, %rs1; +; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +; CHECK-NEXT: ret; + %val = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %f1, float %f2) + ret i16 %val +} + +define <2 x half> @cvt_rn_f16x2_e2m1x2(i16 %in) { +; CHECK-LABEL: cvt_rn_f16x2_e2m1x2( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<2>; +; CHECK-NEXT: .reg 
.b32 %r<2>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.u16 %rs1, [cvt_rn_f16x2_e2m1x2_param_0]; +; CHECK-NEXT: { +; CHECK-NEXT: .reg .b8 %e2m1x2_in; +; CHECK-NEXT: cvt.u8.u16 %e2m1x2_in, %rs1; +; CHECK-NEXT: cvt.rn.f16x2.e2m1x2 %r1, %e2m1x2_in; +; CHECK-NEXT: } +; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +; CHECK-NEXT: ret; + %val = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 %in) + ret <2 x half> %val +} + +define <2 x half> @cvt_rn_relu_f16x2_e2m1x2(i16 %in) { +; CHECK-LABEL: cvt_rn_relu_f16x2_e2m1x2( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<2>; +; CHECK-NEXT: .reg .b32 %r<2>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.u16 %rs1, [cvt_rn_relu_f16x2_e2m1x2_param_0]; +; CHECK-NEXT: { +; CHECK-NEXT: .reg .b8 %e2m1x2_in; +; CHECK-NEXT: cvt.u8.u16 %e2m1x2_in, %rs1; +; CHECK-NEXT: cvt.rn.relu.f16x2.e2m1x2 %r1, %e2m1x2_in; +; CHECK-NEXT: } +; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +; CHECK-NEXT: ret; + %val = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 %in) + ret <2 x half> %val +} `````````` </details> https://github.com/llvm/llvm-project/pull/139244 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits