llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Srinivasa Ravi (Wolfram70)

<details>
<summary>Changes</summary>

This change adds intrinsics and clang builtins for the FP4 (`.e2m1x2`) variants
of the cvt instruction, introduced in PTX 8.6 for `sm_100a`, `sm_101a`, and
`sm_120a`.

Tests are added in `NVPTX/convert-sm100a.ll` and
`clang/test/CodeGen/builtins-nvptx.c`, and were verified with ptxas 12.8.0.

PTX Spec Reference: 
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt

---
Full diff: https://github.com/llvm/llvm-project/pull/139244.diff


6 Files Affected:

- (modified) clang/include/clang/Basic/BuiltinsNVPTX.td (+6) 
- (modified) clang/test/CodeGen/builtins-nvptx.c (+20) 
- (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+7) 
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+17) 
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+14) 
- (modified) llvm/test/CodeGen/NVPTX/convert-sm100a.ll (+82) 


``````````diff
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td 
b/clang/include/clang/Basic/BuiltinsNVPTX.td
index f797e29fe66a3..2cea44e224674 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -620,6 +620,12 @@ def __nvvm_e2m3x2_to_f16x2_rn_relu : 
NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(sh
 def __nvvm_e3m2x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, 
__fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
 def __nvvm_e3m2x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, 
__fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
 
+def __nvvm_ff_to_e2m1x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, 
float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
+def __nvvm_ff_to_e2m1x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, 
float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
+
+def __nvvm_e2m1x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, 
__fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
+def __nvvm_e2m1x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, 
__fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
+
 def __nvvm_ff_to_ue8m0x2_rz : NVPTXBuiltinSMAndPTX<"short(float, float)", 
SM<"100a", [SM_101a, SM_120a]>, PTX86>;
 def __nvvm_ff_to_ue8m0x2_rz_satfinite : NVPTXBuiltinSMAndPTX<"short(float, 
float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
 def __nvvm_ff_to_ue8m0x2_rp : NVPTXBuiltinSMAndPTX<"short(float, float)", 
SM<"100a", [SM_101a, SM_120a]>, PTX86>;
diff --git a/clang/test/CodeGen/builtins-nvptx.c 
b/clang/test/CodeGen/builtins-nvptx.c
index 639c18190f436..7904762709df6 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -1127,6 +1127,26 @@ __device__ void nvvm_cvt_sm100a_sm101a_sm120a() {
   // CHECK_PTX86_SM120a: call <2 x half> 
@llvm.nvvm.e3m2x2.to.f16x2.rn.relu(i16 19532)
   __nvvm_e3m2x2_to_f16x2_rn_relu(0x4C4C);
 
+  // CHECK_PTX86_SM100a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float 
1.000000e+00, float 1.000000e+00)
+  // CHECK_PTX86_SM101a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float 
1.000000e+00, float 1.000000e+00)
+  // CHECK_PTX86_SM120a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float 
1.000000e+00, float 1.000000e+00)
+  __nvvm_ff_to_e2m1x2_rn_satfinite(1.0f, 1.0f);
+
+  // CHECK_PTX86_SM100a: call i16 
@llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float 1.000000e+00, float 
1.000000e+00)
+  // CHECK_PTX86_SM101a: call i16 
@llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float 1.000000e+00, float 
1.000000e+00)
+  // CHECK_PTX86_SM120a: call i16 
@llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float 1.000000e+00, float 
1.000000e+00)
+  __nvvm_ff_to_e2m1x2_rn_relu_satfinite(1.0f, 1.0f);
+  
+  // CHECK_PTX86_SM100a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 76)
+  // CHECK_PTX86_SM101a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 76)
+  // CHECK_PTX86_SM120a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 76)
+  __nvvm_e2m1x2_to_f16x2_rn(0x004C);
+  
+  // CHECK_PTX86_SM100a: call <2 x half> 
@llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 76)
+  // CHECK_PTX86_SM101a: call <2 x half> 
@llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 76)
+  // CHECK_PTX86_SM120a: call <2 x half> 
@llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 76)
+  __nvvm_e2m1x2_to_f16x2_rn_relu(0x004C);
+
   // CHECK_PTX86_SM100a: call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float 
1.000000e+00, float 1.000000e+00)
   // CHECK_PTX86_SM101a: call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float 
1.000000e+00, float 1.000000e+00)
   // CHECK_PTX86_SM120a: call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float 
1.000000e+00, float 1.000000e+00)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td 
b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 8b87822d3fdda..60178bf01f266 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1663,6 +1663,13 @@ let TargetPrefix = "nvvm" in {
       def int_nvvm_ # type # _to_f16x2 # suffix : CVT_I16_TO_F16X2<type, 
suffix>;
     }
   }
+
+  // FP4 conversions.
+  foreach relu = ["", "_relu"] in {
+    defvar suffix = !strconcat("_rn", relu);
+    def int_nvvm_ff_to_e2m1x2 # suffix # _satfinite : CVT_FF_TO_I16<"e2m1x2", 
!strconcat(suffix, "_satfinite")>;
+    def int_nvvm_e2m1x2_to_f16x2 # suffix : CVT_I16_TO_F16X2<"e2m1x2", suffix>;
+  }
   
   // UE8M0x2 conversions.
   foreach rmode = ["_rz", "_rp"] in {
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td 
b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 11d77599d4ac3..b127ea66eeb71 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -721,6 +721,23 @@ let hasSideEffects = false in {
                                         # type # " \t$dst, $src;", []>;
   }
   
+  // FP4 conversions.
+  def CVT_e2m1x2_f32_sf : NVPTXInst<(outs Int16Regs:$dst),
+      (ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode),
+      !strconcat("{{ \n\t",
+                 ".reg .b8 \t%e2m1x2_out; \n\t",
+                 "cvt${mode:base}.satfinite${mode:relu}.e2m1x2.f32 
\t%e2m1x2_out, $src1, $src2; \n\t",
+                 "cvt.u16.u8 \t$dst, %e2m1x2_out; \n\t",
+                 "}}"), []>;
+
+  def CVT_f16x2_e2m1x2 : NVPTXInst<(outs Int32Regs:$dst),
+      (ins Int16Regs:$src, CvtMode:$mode),
+      !strconcat("{{ \n\t",
+                 ".reg .b8 \t%e2m1x2_in; \n\t",
+                 "cvt.u8.u16 \t%e2m1x2_in, $src; \n\t",
+                 "cvt${mode:base}${mode:relu}.f16x2.e2m1x2 \t$dst, %e2m1x2_in; 
\n\t",
+                 "}}"), []>;
+
   // UE8M0x2 conversions.
   class CVT_f32_to_ue8m0x2<string sat = ""> :
     NVPTXInst<(outs Int16Regs:$dst),
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td 
b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 3eedb43e4c81a..3dcf66b793409 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1976,6 +1976,20 @@ def : Pat<(int_nvvm_e3m2x2_to_f16x2_rn i16:$a),
 def : Pat<(int_nvvm_e3m2x2_to_f16x2_rn_relu i16:$a),
           (CVT_f16x2_e3m2x2 $a, CvtRN_RELU)>,
       Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
+      
+def : Pat<(int_nvvm_ff_to_e2m1x2_rn_satfinite f32:$a, f32:$b),
+          (CVT_e2m1x2_f32_sf $a, $b, CvtRN)>,
+      Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
+def : Pat<(int_nvvm_ff_to_e2m1x2_rn_relu_satfinite f32:$a, f32:$b),
+          (CVT_e2m1x2_f32_sf $a, $b, CvtRN_RELU)>,
+      Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
+      
+def : Pat<(int_nvvm_e2m1x2_to_f16x2_rn Int16Regs:$a),
+          (CVT_f16x2_e2m1x2 $a, CvtRN)>,
+      Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
+def : Pat<(int_nvvm_e2m1x2_to_f16x2_rn_relu Int16Regs:$a),
+          (CVT_f16x2_e2m1x2 $a, CvtRN_RELU)>,
+      Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
 
 def : Pat<(int_nvvm_ff_to_ue8m0x2_rz f32:$a, f32:$b),
           (CVT_ue8m0x2_f32 $a, $b, CvtRZ)>,
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm100a.ll 
b/llvm/test/CodeGen/NVPTX/convert-sm100a.ll
index 04d7a65f9e40e..e2e7cd9ff2fe6 100644
--- a/llvm/test/CodeGen/NVPTX/convert-sm100a.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-sm100a.ll
@@ -288,3 +288,85 @@ define <2 x bfloat> @cvt_bf16x2_ue8m0x2(i16 %in) {
     %val = call <2 x bfloat> @llvm.nvvm.ue8m0x2.to.bf16x2(i16 %in)
     ret <2 x bfloat> %val
 }
+
+define i16 @cvt_rn_sf_e2m1x2_f32(float %f1, float %f2) {
+; CHECK-LABEL: cvt_rn_sf_e2m1x2_f32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rn_sf_e2m1x2_f32_param_0];
+; CHECK-NEXT:    ld.param.f32 %f2, [cvt_rn_sf_e2m1x2_f32_param_1];
+; CHECK-NEXT:    {
+; CHECK-NEXT:    .reg .b8 %e2m1x2_out;
+; CHECK-NEXT:    cvt.rn.satfinite.e2m1x2.f32 %e2m1x2_out, %f1, %f2;
+; CHECK-NEXT:    cvt.u16.u8 %rs1, %e2m1x2_out;
+; CHECK-NEXT:    }
+; CHECK-NEXT:    cvt.u32.u16 %r1, %rs1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+    %val = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %f1, float %f2)
+    ret i16 %val
+}
+
+define i16 @cvt_rn_relu_sf_e2m1x2_f32(float %f1, float %f2) {
+; CHECK-LABEL: cvt_rn_relu_sf_e2m1x2_f32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rn_relu_sf_e2m1x2_f32_param_0];
+; CHECK-NEXT:    ld.param.f32 %f2, [cvt_rn_relu_sf_e2m1x2_f32_param_1];
+; CHECK-NEXT:    {
+; CHECK-NEXT:    .reg .b8 %e2m1x2_out;
+; CHECK-NEXT:    cvt.rn.satfinite.relu.e2m1x2.f32 %e2m1x2_out, %f1, %f2;
+; CHECK-NEXT:    cvt.u16.u8 %rs1, %e2m1x2_out;
+; CHECK-NEXT:    }
+; CHECK-NEXT:    cvt.u32.u16 %r1, %rs1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+    %val = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %f1, float 
%f2)
+    ret i16 %val
+}
+
+define <2 x half> @cvt_rn_f16x2_e2m1x2(i16 %in) {
+; CHECK-LABEL: cvt_rn_f16x2_e2m1x2(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u16 %rs1, [cvt_rn_f16x2_e2m1x2_param_0];
+; CHECK-NEXT:    {
+; CHECK-NEXT:    .reg .b8 %e2m1x2_in;
+; CHECK-NEXT:    cvt.u8.u16 %e2m1x2_in, %rs1;
+; CHECK-NEXT:    cvt.rn.f16x2.e2m1x2 %r1, %e2m1x2_in;
+; CHECK-NEXT:    }
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+    %val = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 %in)
+    ret <2 x half> %val
+}
+
+define <2 x half> @cvt_rn_relu_f16x2_e2m1x2(i16 %in) {
+; CHECK-LABEL: cvt_rn_relu_f16x2_e2m1x2(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u16 %rs1, [cvt_rn_relu_f16x2_e2m1x2_param_0];
+; CHECK-NEXT:    {
+; CHECK-NEXT:    .reg .b8 %e2m1x2_in;
+; CHECK-NEXT:    cvt.u8.u16 %e2m1x2_in, %rs1;
+; CHECK-NEXT:    cvt.rn.relu.f16x2.e2m1x2 %r1, %e2m1x2_in;
+; CHECK-NEXT:    }
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+    %val = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 %in)
+    ret <2 x half> %val
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/139244
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to