csigg created this revision. csigg added a reviewer: bkramer. Herald added subscribers: bzcheeseman, mattd, gchakrabarti, awarzynski, sdasgup3, asavonic, wenzhicui, wrengr, Chia-hungDuan, dcaballe, cota, teijeong, rdzhabarov, tatianashp, msifontes, jurahul, Kayjukh, grosul1, Joonsoo, liufengdb, aartbik, mgester, arpith-jacob, antiagainst, shauheen, rriddle, mehdi_amini, sanjoy.google, hiraditya, jholewinski. Herald added a reviewer: ftynse. Herald added a reviewer: bondhugula. Herald added a reviewer: ThomasRaoux. Herald added a project: All. csigg requested review of this revision. Herald added subscribers: llvm-commits, cfe-commits, stephenneuendorffer, nicolasvasilache, jdoerfert. Herald added a reviewer: herhut. Herald added projects: clang, MLIR, LLVM.
This is correct for all values, i.e. it produces the same results as promoting the division to fp32 in the NVPTX backend. But it is faster (~10% on average, sometimes more) because:

- it performs fewer Newton iterations
- it avoids the slow path for e.g. denormals
- it allows reuse of the reciprocal for multiple divisions by the same divisor

Test program:

#include <stdio.h>
#include "cuda_fp16.h"

// This is a variant of CUDA's own __hdiv which is faster than hdiv_promote
// below and doesn't suffer from the perf cliff of div.rn.fp32 with 'special'
// values.
__device__ half hdiv_newton(half a, half b) {
  float fa = __half2float(a);
  float fb = __half2float(b);

  float rcp;
  asm("{rcp.approx.ftz.f32 %0, %1;\n}" : "=f"(rcp) : "f"(fb));

  float result = fa * rcp;
  auto exponent = reinterpret_cast<const unsigned&>(result) & 0x7f800000;
  if (exponent != 0 && exponent != 0x7f800000) {
    float err = __fmaf_rn(-fb, result, fa);
    result = __fmaf_rn(rcp, err, result);
  }

  return __float2half(result);
}

// Surprisingly, this is faster than CUDA's own __hdiv.
__device__ half hdiv_promote(half a, half b) {
  return __float2half(__half2float(a) / __half2float(b));
}

// This is an approximation that is accurate up to 1 ulp.
__device__ half hdiv_approx(half a, half b) {
  float fa = __half2float(a);
  float fb = __half2float(b);

  float result;
  asm("{div.approx.ftz.f32 %0, %1, %2;\n}" : "=f"(result) : "f"(fa), "f"(fb));
  return __float2half(result);
}

__global__ void CheckCorrectness() {
  int i = threadIdx.x + blockIdx.x * blockDim.x;
  half x = reinterpret_cast<const half&>(i);
  for (int j = 0; j < 65536; ++j) {
    half y = reinterpret_cast<const half&>(j);
    half d1 = hdiv_newton(x, y);
    half d2 = hdiv_promote(x, y);
    auto s1 = reinterpret_cast<const short&>(d1);
    auto s2 = reinterpret_cast<const short&>(d2);
    if (s1 != s2) {
      printf("%f (%u) / %f (%u), got %f (%hu), expected: %f (%hu)\n",
             __half2float(x), i, __half2float(y), j, __half2float(d1), s1,
             __half2float(d2), s2);
      //__trap();
    }
  }
}

__device__ half dst;

__global__ void ProfileBuiltin(half x) {
  #pragma unroll 1
  for (int i = 0; i < 10000000; ++i) {
    x = x / x;
  }
  dst = x;
}

__global__ void ProfilePromote(half x) {
  #pragma unroll 1
  for (int i = 0; i < 10000000; ++i) {
    x = hdiv_promote(x, x);
  }
  dst = x;
}

__global__ void ProfileNewton(half x) {
  #pragma unroll 1
  for (int i = 0; i < 10000000; ++i) {
    x = hdiv_newton(x, x);
  }
  dst = x;
}

__global__ void ProfileApprox(half x) {
  #pragma unroll 1
  for (int i = 0; i < 10000000; ++i) {
    x = hdiv_approx(x, x);
  }
  dst = x;
}

int main() {
  CheckCorrectness<<<256, 256>>>();
  half one = __float2half(1.0f);
  ProfileBuiltin<<<1, 1>>>(one);  // 1.001s
  ProfilePromote<<<1, 1>>>(one);  // 0.560s
  ProfileNewton<<<1, 1>>>(one);   // 0.508s
  ProfileApprox<<<1, 1>>>(one);   // 0.304s
  auto status = cudaDeviceSynchronize();
  printf("%s\n", cudaGetErrorString(status));
}

Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D126158

Files:
  clang/include/clang/Basic/BuiltinsNVPTX.def
  llvm/include/llvm/IR/IntrinsicsNVVM.td
  llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
  mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
  mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
  mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
  mlir/test/Dialect/LLVMIR/nvvm.mlir
  mlir/test/Target/LLVMIR/nvvmir.mlir
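For reviewers not familiar with the trick: the two fused multiply-adds in hdiv_newton (and in the lowering pattern in the diff) are one Newton-Raphson refinement of the quotient, written in residual-correction form. A sketch of the algebra, not part of the patch, just standard numerics:

  r   ~= 1/b                         rcp.approx.ftz.f32
  q0  = a * r                        initial quotient
  err = fma(-b, q0, a) = a - q0*b    residual, computed with a single rounding
  q1  = fma(r, err, q0)
      = q0 + (a - q0*b) * r
      = a * r * (2 - b*r)            quotient form of r' = r * (2 - b*r)

One such step roughly doubles the number of correct bits of the approximate reciprocal, which is why a single iteration is enough for the fp16 result to match the fp32-promoted division (the exhaustive CheckCorrectness kernel above confirms this). The exponent test simply skips the refinement when the initial quotient is not a normal fp32 value (zero, denormal, infinity, or NaN).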
Index: mlir/test/Target/LLVMIR/nvvmir.mlir =================================================================== --- mlir/test/Target/LLVMIR/nvvmir.mlir +++ mlir/test/Target/LLVMIR/nvvmir.mlir @@ -1,5 +1,6 @@ // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s +// CHECK-LABEL: @nvvm_special_regs llvm.func @nvvm_special_regs() -> i32 { // CHECK: %1 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() %1 = nvvm.read.ptx.sreg.tid.x : i32 @@ -32,12 +33,21 @@ llvm.return %1 : i32 } +// CHECK-LABEL: @nvvm_rcp +llvm.func @nvvm_rcp(%0: f32) -> f32 { + // CHECK: call float @llvm.nvvm.rcp.approx.ftz.f + %1 = nvvm.rcp.approx.ftz.f %0 : f32 + llvm.return %1 : f32 +} + +// CHECK-LABEL: @llvm_nvvm_barrier0 llvm.func @llvm_nvvm_barrier0() { // CHECK: call void @llvm.nvvm.barrier0() nvvm.barrier0 llvm.return } +// CHECK-LABEL: @nvvm_shfl llvm.func @nvvm_shfl( %0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : f32) -> i32 { @@ -60,6 +70,7 @@ llvm.return %6 : i32 } +// CHECK-LABEL: @nvvm_shfl_pred llvm.func @nvvm_shfl_pred( %0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : f32) -> !llvm.struct<(i32, i1)> { @@ -82,6 +93,7 @@ llvm.return %6 : !llvm.struct<(i32, i1)> } +// CHECK-LABEL: @nvvm_vote llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 { // CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}}) %3 = nvvm.vote.ballot.sync %0, %1 : i32 @@ -99,6 +111,7 @@ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> } +// CHECK-LABEL: @nvvm_mma_m16n8k16_f16_f16 llvm.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>, %b0 : vector<2xf16>, %b1 : vector<2xf16>, @@ -111,6 +124,7 @@ } // f32 return type, f16 accumulate type +// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f16 llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>, %b0 : vector<2xf16>, %b1 : vector<2xf16>, @@ -123,6 +137,7 @@ } // f16 return type, f32 accumulate type +// CHECK-LABEL: @nvvm_mma_m16n8k16_f16_f32 llvm.func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>, %b0 : vector<2xf16>, %b1 : vector<2xf16>, @@ -135,6 +150,7 @@ } // f32 return type, f32 accumulate type +// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f32 llvm.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>, %b0 : vector<2xf16>, %b1 : vector<2xf16>, @@ -146,7 +162,8 @@ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)> } -llvm.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, +// CHECK-LABEL: @nvvm_mma_m16n8k16_s8_s8 +llvm.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> { // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.s8 @@ -158,7 +175,8 @@ llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> } -llvm.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32, +// CHECK-LABEL: @nvvm_mma_m16n8k16_s8_u8 +llvm.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> { // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.satfinite.s8.u8 @@ -170,7 +188,8 @@ llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> } -llvm.func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32, +// CHECK-LABEL: @nvvm_mma_m16n8k128_b1_b1 +llvm.func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> 
!llvm.struct<(i32,i32,i32,i32)> { // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.xor.popc.m16n8k128.row.col.b1 @@ -181,6 +200,7 @@ llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> } +// CHECK-LABEL: @nvvm_mma_m16n8k32_s4_s4 llvm.func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> { @@ -193,6 +213,7 @@ llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> } +// CHECK-LABEL: @nvvm_mma_m8n8k4_f64_f64 llvm.func @nvvm_mma_m8n8k4_f64_f64(%a0 : f64, %b0 : f64, %c0 : f64, %c1 : f64) -> !llvm.struct<(f64, f64)> { @@ -203,6 +224,7 @@ llvm.return %0 : !llvm.struct<(f64, f64)> } +// CHECK-LABEL: @nvvm_mma_m16n8k4_tf32_f32 llvm.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> { @@ -228,6 +250,7 @@ // The test below checks the correct mapping of the nvvm.wmma.*.store.* op to the correct intrinsic // in the LLVM NVPTX backend. +// CHECK-LABEL: @gpu_wmma_store_op llvm.func @gpu_wmma_store_op(%arg0: !llvm.ptr<i32, 3>, %arg1: i32, %arg2: vector<2 x f16>, %arg3: vector<2 x f16>, %arg4: vector<2 xf16>, %arg5: vector<2 x f16>) { @@ -240,6 +263,7 @@ // The test below checks the correct mapping of the nvvm.wmma.*.mma.* op to the correct intrinsic // in the LLVM NVPTX backend. +// CHECK-LABEL: @gpu_wmma_mma_op llvm.func @gpu_wmma_mma_op(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>, %arg2: vector<2 x f16>, %arg3: vector<2 x f16>, %arg4: vector<2 x f16>, %arg5: vector<2 x f16>, @@ -261,6 +285,7 @@ llvm.return } +// CHECK-LABEL: @nvvm_wmma_load_tf32 llvm.func @nvvm_wmma_load_tf32(%arg0: !llvm.ptr<i32>, %arg1 : i32) { // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32.p0i32(i32* %{{.*}}, i32 %{{.*}}) %0 = nvvm.wmma.load %arg0, %arg1 @@ -269,6 +294,7 @@ llvm.return } +// CHECK-LABEL: @nvvm_wmma_mma llvm.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 : i32, %6 : i32, %7 : i32, %8 : f32, %9 : f32, %10 : f32, %11 : f32, %12 : f32, %13 : f32, %14 : f32, %15 : f32) { @@ -280,6 +306,7 @@ llvm.return } +// CHECK-LABEL: @cp_async llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) { // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}}) nvvm.cp.async.shared.global %arg0, %arg1, 4 @@ -296,7 +323,7 @@ llvm.return } -// CHECK-LABEL: @ld_matrix( +// CHECK-LABEL: @ld_matrix llvm.func @ld_matrix(%arg0: !llvm.ptr<i32, 3>) { // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3i32(i32 addrspace(3)* %{{.*}}) %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> i32 Index: mlir/test/Dialect/LLVMIR/nvvm.mlir =================================================================== --- mlir/test/Dialect/LLVMIR/nvvm.mlir +++ mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -1,5 +1,6 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s +// CHECK-LABEL: @nvvm_special_regs func.func @nvvm_special_regs() -> i32 { // CHECK: nvvm.read.ptx.sreg.tid.x : i32 %0 = nvvm.read.ptx.sreg.tid.x : i32 @@ -28,12 +29,21 @@ llvm.return %0 : i32 } -func.func @llvm.nvvm.barrier0() { +// CHECK-LABEL: @nvvm_rcp +func.func @nvvm_rcp(%arg0: f32) -> f32 { + // CHECK: nvvm.rcp.approx.ftz.f %arg0 : f32 + %0 = nvvm.rcp.approx.ftz.f %arg0 : f32 + llvm.return %0 : f32 +} + +// CHECK-LABEL: @llvm_nvvm_barrier0 +func.func @llvm_nvvm_barrier0() { // CHECK: nvvm.barrier0 nvvm.barrier0 
llvm.return } +// CHECK-LABEL: @nvvm_shfl func.func @nvvm_shfl( %arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : f32) -> i32 { @@ -50,6 +60,7 @@ llvm.return %0 : i32 } +// CHECK-LABEL: @nvvm_shfl_pred func.func @nvvm_shfl_pred( %arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : f32) -> !llvm.struct<(i32, i1)> { @@ -60,6 +71,7 @@ llvm.return %0 : !llvm.struct<(i32, i1)> } +// CHECK-LABEL: @nvvm_vote( func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 { // CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32 %0 = nvvm.vote.ballot.sync %arg0, %arg1 : i32 @@ -77,6 +89,7 @@ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> } +// CHECK-LABEL: @nvvm_mma_m8n8k4_f16_f16 func.func @nvvm_mma_m8n8k4_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>, %b0 : vector<2xf16>, %b1 : vector<2xf16>, %c0 : vector<2xf16>, %c1 : vector<2xf16>, %c2 : vector<2xf16>, %c3 : vector<2xf16>) { @@ -87,6 +100,7 @@ llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> } +// CHECK-LABEL: @nvvm_mma_m8n8k16_s8_s8 func.func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32, %c0 : i32, %c1 : i32) { // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)> @@ -98,7 +112,8 @@ llvm.return %0 : !llvm.struct<(i32, i32)> } -func.func @nvvm_mma_m16n8k8_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>, +// CHECK-LABEL: @nvvm_mma_m16n8k8_f16_f16 +func.func @nvvm_mma_m16n8k8_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>, %b0 : vector<2xf16>, %c0 : vector<2xf16>, %c1 : vector<2xf16>) { // CHECK: nvvm.mma.sync A[%{{.*}}, %{{.*}}] B[%{{.*}}] C[%{{.*}}, %{{.*}}] {{{.*}}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> @@ -108,6 +123,7 @@ llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)> } +// CHECK-LABEL: @nvvm_mma_m16n8k16_f16_f16 func.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>, %b0 : vector<2xf16>, %b1 : vector<2xf16>, @@ -119,6 +135,7 @@ llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)> } +// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f16 func.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>, %b0 : vector<2xf16>, %b1 : vector<2xf16>, @@ -130,6 +147,7 @@ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)> } +// CHECK-LABEL: @nvvm_mma_m16n8k16_f16_f32 func.func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>, %b0 : vector<2xf16>, %b1 : vector<2xf16>, @@ -141,6 +159,7 @@ llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)> } +// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f32 func.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>, %b0 : vector<2xf16>, %b1 : vector<2xf16>, @@ -152,7 +171,8 @@ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)> } -func.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32, +// CHECK-LABEL: @nvvm_mma_m16n8k4_tf32_f32 +func.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) { // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, 
{{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> @@ -163,7 +183,8 @@ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)> } -func.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32, +// CHECK-LABEL: @nvvm_mma_m16n8k16_s8_s8 +func.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) { // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] @@ -174,7 +195,8 @@ llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> } -func.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32, +// CHECK-LABEL: @nvvm_mma_m16n8k16_s8_u8 +func.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) { // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> @@ -186,6 +208,7 @@ llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> } +// CHECK-LABEL: @nvvm_mma_m16n8k256_b1_b1 func.func @nvvm_mma_m16n8k256_b1_b1(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, %b0 : i32, %b1 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) { @@ -197,6 +220,7 @@ llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> } +// CHECK-LABEL: @nvvm_mma_m16n8k128_b1_b1 func.func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) { @@ -243,6 +267,7 @@ llvm.return %0 : !llvm.struct<(i32, i32, i32, i32)> } +// CHECK-LABEL: @nvvm_wmma_mma func.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 : i32, %6 : i32, %7 : i32, %8 : f32, %9 : f32, %10 : f32, %11 : f32, %12 : f32, %13 : f32, %14 : f32, %15 : f32) @@ -255,6 +280,7 @@ llvm.return %r : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> } +// CHECK-LABEL: @cp_async llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) { // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16 nvvm.cp.async.shared.global %arg0, %arg1, 16 Index: mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir =================================================================== --- mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -488,3 +488,30 @@ } } +// ----- + +gpu.module @test_module { + // CHECK-LABEL: func @gpu_divf_fp16 + func.func @gpu_divf_fp16(%arg0 : f16, %arg1 : f16) -> f16 { + // CHECK: %[[lhs:.*]] = llvm.fpext %arg0 : f16 to f32 + // CHECK: %[[rhs:.*]] = llvm.fpext %arg1 : f16 to f32 + // CHECK: %[[rcp:.*]] = nvvm.rcp.approx.ftz.f %1 : f32 + // CHECK: %[[approx:.*]] = llvm.fmul %[[lhs]], %[[rcp]] : f32 + // CHECK: %[[neg:.*]] = llvm.fneg %[[rhs]] : f32 + // CHECK: %[[err:.*]] = "llvm.intr.fma"(%[[approx]], %[[neg]], 
%[[lhs]]) : (f32, f32, f32) -> f32
+    // CHECK: %[[refined:.*]] = "llvm.intr.fma"(%[[err]], %[[rcp]], %[[approx]]) : (f32, f32, f32) -> f32
+    // CHECK: %[[mask:.*]] = llvm.mlir.constant(2139095040 : ui32) : i32
+    // CHECK: %[[cast:.*]] = llvm.bitcast %[[approx]] : f32 to i32
+    // CHECK: %[[exp:.*]] = llvm.and %[[cast]], %[[mask]] : i32
+    // CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : ui32) : i32
+    // CHECK: %[[is_zero:.*]] = llvm.icmp "eq" %[[exp]], %[[c0]] : i32
+    // CHECK: %[[is_mask:.*]] = llvm.icmp "eq" %[[exp]], %[[mask]] : i32
+    // CHECK: %[[pred:.*]] = llvm.or %[[is_zero]], %[[is_mask]] : i1
+    // CHECK: %[[select:.*]] = llvm.select %[[pred]], %[[approx]], %[[refined]] : i1, f32
+    // CHECK: %[[result:.*]] = llvm.fptrunc %[[select]] : f32 to f16
+    %result = arith.divf %arg0, %arg1 : f16
+    // CHECK: llvm.return %[[result]] : f16
+    func.return %result : f16
+  }
+}
+
Index: mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
===================================================================
--- mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -148,6 +148,62 @@
  }
};

+// Replaces fdiv on fp16 with fp32 multiplication with reciprocal plus one
+// (conditional) Newton iteration.
+//
+// This is as accurate as promoting the division to fp32 in the NVPTX backend,
+// but faster because it performs fewer Newton iterations, avoids the slow path
+// for e.g. denormals, and allows reuse of the reciprocal for multiple
+// divisions by the same divisor.
+struct ExpandDivF16 : public ConvertOpToLLVMPattern<LLVM::FDivOp> {
+  using ConvertOpToLLVMPattern<LLVM::FDivOp>::ConvertOpToLLVMPattern;
+
+private:
+  LogicalResult
+  matchAndRewrite(LLVM::FDivOp op, LLVM::FDivOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op.getType().isF16())
+      return rewriter.notifyMatchFailure(op, "not f16");
+    Location loc = op.getLoc();
+
+    Type f32Type = rewriter.getF32Type();
+    Type i32Type = rewriter.getI32Type();
+
+    // Extend lhs and rhs to fp32.
+    Value lhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, adaptor.getLhs());
+    Value rhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, adaptor.getRhs());
+
+    // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp.
+    Value rcp = rewriter.create<NVVM::RcpApproxFtzF32Op>(loc, f32Type, rhs);
+    Value approx = rewriter.create<LLVM::FMulOp>(loc, lhs, rcp);
+
+    // Refine the approximation with one Newton iteration:
+    // float refined = approx + (lhs - approx * rhs) * rcp;
+    Value err = rewriter.create<LLVM::FMAOp>(
+        loc, approx, rewriter.create<LLVM::FNegOp>(loc, rhs), lhs);
+    Value refined = rewriter.create<LLVM::FMAOp>(loc, err, rcp, approx);
+
+    // Use the refined value if approx is normal (exponent is neither all 0
+    // nor all 1).
+    Value mask = rewriter.create<LLVM::ConstantOp>(
+        loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000));
+    Value cast = rewriter.create<LLVM::BitcastOp>(loc, i32Type, approx);
+    Value exp = rewriter.create<LLVM::AndOp>(loc, i32Type, cast, mask);
+    Value zero = rewriter.create<LLVM::ConstantOp>(
+        loc, i32Type, rewriter.getUI32IntegerAttr(0));
+    Value pred = rewriter.create<LLVM::OrOp>(
+        loc,
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, zero),
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, mask));
+    Value result =
+        rewriter.create<LLVM::SelectOp>(loc, f32Type, pred, approx, refined);
+
+    // Replace with truncation back to fp16.
+    rewriter.replaceOpWithNewOp<LLVM::FPTruncOp>(op, op.getType(), result);
+
+    return success();
+  }
+};
+
/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"

@@ -222,6 +278,10 @@
                     LLVM::FCeilOp, LLVM::FFloorOp, LLVM::LogOp, LLVM::Log10Op,
                     LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp, LLVM::SqrtOp>();

+  // Expand fdiv on fp16 to faster code than NVPTX backend's fp32 promotion.
+  target.addDynamicallyLegalOp<LLVM::FDivOp>(
+      [&](LLVM::FDivOp op) { return !op.getType().isF16(); });
+
  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
}

@@ -241,6 +301,8 @@
               GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
      converter);

+  patterns.add<ExpandDivF16>(converter);
+
  // Explicitly drop memory space when lowering private memory
  // attributions since NVVM models it as `alloca`s in the default
  // memory space and does not support `alloca`s with addrspace(5).

Index: mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -51,21 +51,21 @@
// NVVM intrinsic operations
//===----------------------------------------------------------------------===//

-class NVVM_IntrOp<string mnem, list<int> overloadedResults,
-                  list<int> overloadedOperands, list<Trait> traits,
+class NVVM_IntrOp<string mnem, list<Trait> traits,
                  int numResults>
  : LLVM_IntrOpBase<NVVM_Dialect, mnem, "nvvm_" # !subst(".", "_", mnem),
-                    overloadedResults, overloadedOperands, traits, numResults>;
+                    /*list<int> overloadedResults=*/[],
+                    /*list<int> overloadedOperands=*/[],
+                    traits, numResults>;

//===----------------------------------------------------------------------===//
// NVVM special register op definitions
//===----------------------------------------------------------------------===//

-class NVVM_SpecialRegisterOp<string mnemonic,
-                             list<Trait> traits = []> :
-  NVVM_IntrOp<mnemonic, [], [], !listconcat(traits, [NoSideEffect]), 1>,
-  Arguments<(ins)> {
+class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
+  NVVM_IntrOp<mnemonic, !listconcat(traits, [NoSideEffect]), 1> {
+  let arguments = (ins);
  let assemblyFormat = "attr-dict `:` type($res)";
}

@@ -92,6 +92,16 @@
def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">;
def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">;

+//===----------------------------------------------------------------------===//
+// NVVM approximate op definitions
+//===----------------------------------------------------------------------===//
+
+def NVVM_RcpApproxFtzF32Op : NVVM_IntrOp<"rcp.approx.ftz.f", [NoSideEffect], 1> {
+  let arguments = (ins F32:$arg);
+  let results = (outs F32:$res);
+  let assemblyFormat = "$arg attr-dict `:` type($res)";
+}
+
//===----------------------------------------------------------------------===//
// NVVM synchronization op definitions
//===----------------------------------------------------------------------===//

Index: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1034,6 +1034,8 @@
def INT_NVVM_RCP_RP_D : F_MATH_1<"rcp.rp.f64 \t$dst, $src0;", Float64Regs,
  Float64Regs, int_nvvm_rcp_rp_d>;

+def INT_NVVM_RCP_APPROX_FTZ_F : F_MATH_1<"rcp.approx.ftz.f32 \t$dst, $src0;",
+  Float32Regs, Float32Regs, int_nvvm_rcp_approx_ftz_f>;
def INT_NVVM_RCP_APPROX_FTZ_D : F_MATH_1<"rcp.approx.ftz.f64 \t$dst, $src0;",
  Float64Regs, Float64Regs, int_nvvm_rcp_approx_ftz_d>;

Index: llvm/include/llvm/IR/IntrinsicsNVVM.td
===================================================================
--- llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -933,6 +933,8 @@
  def int_nvvm_rcp_rp_d : GCCBuiltin<"__nvvm_rcp_rp_d">,
      DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty], [IntrNoMem]>;

+  def int_nvvm_rcp_approx_ftz_f : GCCBuiltin<"__nvvm_rcp_approx_ftz_f">,
+      DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
  def int_nvvm_rcp_approx_ftz_d : GCCBuiltin<"__nvvm_rcp_approx_ftz_d">,
      DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty], [IntrNoMem]>;

Index: clang/include/clang/Basic/BuiltinsNVPTX.def
===================================================================
--- clang/include/clang/Basic/BuiltinsNVPTX.def
+++ clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -343,6 +343,8 @@
BUILTIN(__nvvm_rcp_rz_d, "dd", "")
BUILTIN(__nvvm_rcp_rm_d, "dd", "")
BUILTIN(__nvvm_rcp_rp_d, "dd", "")
+
+BUILTIN(__nvvm_rcp_approx_ftz_f, "ff", "")
BUILTIN(__nvvm_rcp_approx_ftz_d, "dd", "")

// Sqrt
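Usage note: with the __nvvm_rcp_approx_ftz_f builtin declared above, the inline asm in hdiv_newton from the test program can be replaced by a plain builtin call. A minimal sketch, not part of the patch; hdiv_newton_builtin is a hypothetical name, and it assumes the CUDA source is compiled with a clang that includes this change (nvcc does not provide __nvvm_* builtins):

// Same algorithm as hdiv_newton above, with the new clang builtin replacing
// the rcp.approx.ftz.f32 inline asm. Uses the same "cuda_fp16.h" include as
// the test program.
__device__ half hdiv_newton_builtin(half a, half b) {
  float fa = __half2float(a);
  float fb = __half2float(b);
  float rcp = __nvvm_rcp_approx_ftz_f(fb);        // emits rcp.approx.ftz.f32
  float result = fa * rcp;                        // initial quotient
  auto exponent = reinterpret_cast<const unsigned&>(result) & 0x7f800000;
  if (exponent != 0 && exponent != 0x7f800000) {  // refine only if normal
    float err = __fmaf_rn(-fb, result, fa);       // residual a - q*b
    result = __fmaf_rn(rcp, err, result);         // q + err*rcp
  }
  return __float2half(result);
}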