Author: Srinivasa Ravi Date: 2025-05-14T14:39:59+05:30 New Revision: 155e188d94c95b9f389912db2fb180ac8dd75a28
URL: https://github.com/llvm/llvm-project/commit/155e188d94c95b9f389912db2fb180ac8dd75a28 DIFF: https://github.com/llvm/llvm-project/commit/155e188d94c95b9f389912db2fb180ac8dd75a28.diff LOG: [NVPTX] Add intrinsics and clang builtins for conversions of f4x2 type (#139244) This change adds intrinsics and clang builtins for the cvt instruction variants of type (FP4) `.e2m1x2`, introduced in PTX 8.6 for `sm_100a`, `sm_101a`, and `sm_120a`. Tests are added in `NVPTX/convert-sm100a.ll` and `clang/test/CodeGen/builtins-nvptx.c` and verified through ptxas 12.8.0. PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt Added: Modified: clang/include/clang/Basic/BuiltinsNVPTX.td clang/test/CodeGen/builtins-nvptx.c llvm/include/llvm/IR/IntrinsicsNVVM.td llvm/lib/Target/NVPTX/NVPTXInstrInfo.td llvm/lib/Target/NVPTX/NVPTXIntrinsics.td llvm/test/CodeGen/NVPTX/convert-sm100a.ll Removed: ################################################################################ diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td index f797e29fe66a3..2cea44e224674 100644 --- a/clang/include/clang/Basic/BuiltinsNVPTX.td +++ b/clang/include/clang/Basic/BuiltinsNVPTX.td @@ -620,6 +620,12 @@ def __nvvm_e2m3x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(sh def __nvvm_e3m2x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>; def __nvvm_e3m2x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>; +def __nvvm_ff_to_e2m1x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>; +def __nvvm_ff_to_e2m1x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>; + +def __nvvm_e2m1x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, 
PTX86>; +def __nvvm_e2m1x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>; + def __nvvm_ff_to_ue8m0x2_rz : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>; def __nvvm_ff_to_ue8m0x2_rz_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>; def __nvvm_ff_to_ue8m0x2_rp : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>; diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c index 639c18190f436..7904762709df6 100644 --- a/clang/test/CodeGen/builtins-nvptx.c +++ b/clang/test/CodeGen/builtins-nvptx.c @@ -1127,6 +1127,26 @@ __device__ void nvvm_cvt_sm100a_sm101a_sm120a() { // CHECK_PTX86_SM120a: call <2 x half> @llvm.nvvm.e3m2x2.to.f16x2.rn.relu(i16 19532) __nvvm_e3m2x2_to_f16x2_rn_relu(0x4C4C); + // CHECK_PTX86_SM100a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float 1.000000e+00, float 1.000000e+00) + // CHECK_PTX86_SM101a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float 1.000000e+00, float 1.000000e+00) + // CHECK_PTX86_SM120a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float 1.000000e+00, float 1.000000e+00) + __nvvm_ff_to_e2m1x2_rn_satfinite(1.0f, 1.0f); + + // CHECK_PTX86_SM100a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float 1.000000e+00, float 1.000000e+00) + // CHECK_PTX86_SM101a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float 1.000000e+00, float 1.000000e+00) + // CHECK_PTX86_SM120a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float 1.000000e+00, float 1.000000e+00) + __nvvm_ff_to_e2m1x2_rn_relu_satfinite(1.0f, 1.0f); + + // CHECK_PTX86_SM100a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 76) + // CHECK_PTX86_SM101a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 76) + // CHECK_PTX86_SM120a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 76) + __nvvm_e2m1x2_to_f16x2_rn(0x004C); + + // CHECK_PTX86_SM100a: call <2 x 
half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 76) + // CHECK_PTX86_SM101a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 76) + // CHECK_PTX86_SM120a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 76) + __nvvm_e2m1x2_to_f16x2_rn_relu(0x004C); + // CHECK_PTX86_SM100a: call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float 1.000000e+00, float 1.000000e+00) // CHECK_PTX86_SM101a: call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float 1.000000e+00, float 1.000000e+00) // CHECK_PTX86_SM120a: call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float 1.000000e+00, float 1.000000e+00) diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index 5be1a915a06a7..0b26bb9829005 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -1293,10 +1293,19 @@ let TargetPrefix = "nvvm" in { } } + // FP4 conversions. + foreach relu = ["", "_relu"] in { + def int_nvvm_ff_to_e2m1x2_rn # relu # _satfinite : NVVMBuiltin, + DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>; + + def int_nvvm_e2m1x2_to_f16x2_rn # relu : NVVMBuiltin, + DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>; + } + // UE8M0x2 conversions. foreach rmode = ["_rz", "_rp"] in { foreach satmode = ["", "_satfinite"] in { - defvar suffix = !strconcat(rmode, satmode); + defvar suffix = rmode # satmode; def int_nvvm_ff_to_ue8m0x2 # suffix : NVVMBuiltin, DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>; diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index b6104a5aed0d1..2c65ee6d484d5 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -714,6 +714,23 @@ let hasSideEffects = false in { # type # " \t$dst, $src;", []>; } + // FP4 conversions. 
+ def CVT_e2m1x2_f32_sf : NVPTXInst<(outs Int16Regs:$dst), + (ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode), + !strconcat("{{ \n\t", + ".reg .b8 \t%e2m1x2_out; \n\t", + "cvt${mode:base}.satfinite${mode:relu}.e2m1x2.f32 \t%e2m1x2_out, $src1, $src2; \n\t", + "cvt.u16.u8 \t$dst, %e2m1x2_out; \n\t", + "}}"), []>; + + def CVT_f16x2_e2m1x2 : NVPTXInst<(outs Int32Regs:$dst), + (ins Int16Regs:$src, CvtMode:$mode), + !strconcat("{{ \n\t", + ".reg .b8 \t%e2m1x2_in; \n\t", + "cvt.u8.u16 \t%e2m1x2_in, $src; \n\t", + "cvt${mode:base}${mode:relu}.f16x2.e2m1x2 \t$dst, %e2m1x2_in; \n\t", + "}}"), []>; + // UE8M0x2 conversions. class CVT_f32_to_ue8m0x2<string sat = ""> : NVPTXInst<(outs Int16Regs:$dst), diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 8110ba1b2b37b..d3cfce76c666e 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -2003,6 +2003,20 @@ def : Pat<(int_nvvm_e3m2x2_to_f16x2_rn i16:$a), def : Pat<(int_nvvm_e3m2x2_to_f16x2_rn_relu i16:$a), (CVT_f16x2_e3m2x2 $a, CvtRN_RELU)>, Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>; + +def : Pat<(int_nvvm_ff_to_e2m1x2_rn_satfinite f32:$a, f32:$b), + (CVT_e2m1x2_f32_sf $a, $b, CvtRN)>, + Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>; +def : Pat<(int_nvvm_ff_to_e2m1x2_rn_relu_satfinite f32:$a, f32:$b), + (CVT_e2m1x2_f32_sf $a, $b, CvtRN_RELU)>, + Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>; + +def : Pat<(int_nvvm_e2m1x2_to_f16x2_rn Int16Regs:$a), + (CVT_f16x2_e2m1x2 $a, CvtRN)>, + Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>; +def : Pat<(int_nvvm_e2m1x2_to_f16x2_rn_relu Int16Regs:$a), + (CVT_f16x2_e2m1x2 $a, CvtRN_RELU)>, + Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>; def : Pat<(int_nvvm_ff_to_ue8m0x2_rz f32:$a, f32:$b), (CVT_ue8m0x2_f32 $a, $b, CvtRZ)>, diff --git a/llvm/test/CodeGen/NVPTX/convert-sm100a.ll b/llvm/test/CodeGen/NVPTX/convert-sm100a.ll index 
def2575deb042..9acbb7984638a 100644 --- a/llvm/test/CodeGen/NVPTX/convert-sm100a.ll +++ b/llvm/test/CodeGen/NVPTX/convert-sm100a.ll @@ -288,3 +288,85 @@ define <2 x bfloat> @cvt_bf16x2_ue8m0x2(i16 %in) { %val = call <2 x bfloat> @llvm.nvvm.ue8m0x2.to.bf16x2(i16 %in) ret <2 x bfloat> %val } + +define i16 @cvt_rn_sf_e2m1x2_f32(float %f1, float %f2) { +; CHECK-LABEL: cvt_rn_sf_e2m1x2_f32( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<2>; +; CHECK-NEXT: .reg .b32 %r<2>; +; CHECK-NEXT: .reg .b32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %f1, [cvt_rn_sf_e2m1x2_f32_param_0]; +; CHECK-NEXT: ld.param.b32 %f2, [cvt_rn_sf_e2m1x2_f32_param_1]; +; CHECK-NEXT: { +; CHECK-NEXT: .reg .b8 %e2m1x2_out; +; CHECK-NEXT: cvt.rn.satfinite.e2m1x2.f32 %e2m1x2_out, %f1, %f2; +; CHECK-NEXT: cvt.u16.u8 %rs1, %e2m1x2_out; +; CHECK-NEXT: } +; CHECK-NEXT: cvt.u32.u16 %r1, %rs1; +; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +; CHECK-NEXT: ret; + %val = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %f1, float %f2) + ret i16 %val +} + +define i16 @cvt_rn_relu_sf_e2m1x2_f32(float %f1, float %f2) { +; CHECK-LABEL: cvt_rn_relu_sf_e2m1x2_f32( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<2>; +; CHECK-NEXT: .reg .b32 %r<2>; +; CHECK-NEXT: .reg .b32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %f1, [cvt_rn_relu_sf_e2m1x2_f32_param_0]; +; CHECK-NEXT: ld.param.b32 %f2, [cvt_rn_relu_sf_e2m1x2_f32_param_1]; +; CHECK-NEXT: { +; CHECK-NEXT: .reg .b8 %e2m1x2_out; +; CHECK-NEXT: cvt.rn.satfinite.relu.e2m1x2.f32 %e2m1x2_out, %f1, %f2; +; CHECK-NEXT: cvt.u16.u8 %rs1, %e2m1x2_out; +; CHECK-NEXT: } +; CHECK-NEXT: cvt.u32.u16 %r1, %rs1; +; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +; CHECK-NEXT: ret; + %val = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %f1, float %f2) + ret i16 %val +} + +define <2 x half> @cvt_rn_f16x2_e2m1x2(i16 %in) { +; CHECK-LABEL: cvt_rn_f16x2_e2m1x2( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<2>; +; CHECK-NEXT: .reg 
.b32 %r<2>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [cvt_rn_f16x2_e2m1x2_param_0]; +; CHECK-NEXT: { +; CHECK-NEXT: .reg .b8 %e2m1x2_in; +; CHECK-NEXT: cvt.u8.u16 %e2m1x2_in, %rs1; +; CHECK-NEXT: cvt.rn.f16x2.e2m1x2 %r1, %e2m1x2_in; +; CHECK-NEXT: } +; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +; CHECK-NEXT: ret; + %val = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 %in) + ret <2 x half> %val +} + +define <2 x half> @cvt_rn_relu_f16x2_e2m1x2(i16 %in) { +; CHECK-LABEL: cvt_rn_relu_f16x2_e2m1x2( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<2>; +; CHECK-NEXT: .reg .b32 %r<2>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [cvt_rn_relu_f16x2_e2m1x2_param_0]; +; CHECK-NEXT: { +; CHECK-NEXT: .reg .b8 %e2m1x2_in; +; CHECK-NEXT: cvt.u8.u16 %e2m1x2_in, %rs1; +; CHECK-NEXT: cvt.rn.relu.f16x2.e2m1x2 %r1, %e2m1x2_in; +; CHECK-NEXT: } +; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +; CHECK-NEXT: ret; + %val = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 %in) + ret <2 x half> %val +} _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits