csigg updated this revision to Diff 432871.
csigg added a comment.

Rebase.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D126158/new/

https://reviews.llvm.org/D126158

Files:
  mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
  mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
  mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
  mlir/test/Dialect/LLVMIR/nvvm.mlir
  mlir/test/Target/LLVMIR/nvvmir.mlir

Index: mlir/test/Target/LLVMIR/nvvmir.mlir
===================================================================
--- mlir/test/Target/LLVMIR/nvvmir.mlir
+++ mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -33,6 +33,13 @@
   llvm.return %1 : i32
 }
 
+// CHECK-LABEL: @nvvm_rcp
+llvm.func @nvvm_rcp(%0: f32) -> f32 {
+  // CHECK: call float @llvm.nvvm.rcp.approx.ftz.f
+  %1 = nvvm.rcp.approx.ftz.f %0 : f32
+  llvm.return %1 : f32
+}
+
 // CHECK-LABEL: @llvm_nvvm_barrier0
 llvm.func @llvm_nvvm_barrier0() {
   // CHECK: call void @llvm.nvvm.barrier0()
Index: mlir/test/Dialect/LLVMIR/nvvm.mlir
===================================================================
--- mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -29,6 +29,13 @@
   llvm.return %0 : i32
 }
 
+// CHECK-LABEL: @nvvm_rcp
+func.func @nvvm_rcp(%arg0: f32) -> f32 {
+  // CHECK: nvvm.rcp.approx.ftz.f %arg0 : f32
+  %0 = nvvm.rcp.approx.ftz.f %arg0 : f32
+  llvm.return %0 : f32
+}
+
 // CHECK-LABEL: @llvm_nvvm_barrier0
 func.func @llvm_nvvm_barrier0() {
   // CHECK: nvvm.barrier0
Index: mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
===================================================================
--- mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -488,3 +488,30 @@
   }
 }
 
+// -----
+
+gpu.module @test_module {
+  // CHECK-LABEL: func @gpu_divf_fp16
+  func.func @gpu_divf_fp16(%arg0 : f16, %arg1 : f16) -> f16 {
+    // CHECK: %[[lhs:.*]]     = llvm.fpext %arg0 : f16 to f32
+    // CHECK: %[[rhs:.*]]     = llvm.fpext %arg1 : f16 to f32
+    // CHECK: %[[rcp:.*]]     = nvvm.rcp.approx.ftz.f %[[rhs]] : f32
+    // CHECK: %[[approx:.*]]  = llvm.fmul %[[lhs]], %[[rcp]] : f32
+    // CHECK: %[[neg:.*]]     = llvm.fneg %[[rhs]] : f32
+    // CHECK: %[[err:.*]]     = "llvm.intr.fma"(%[[approx]], %[[neg]], %[[lhs]]) : (f32, f32, f32) -> f32
+    // CHECK: %[[refined:.*]] = "llvm.intr.fma"(%[[err]], %[[rcp]], %[[approx]]) : (f32, f32, f32) -> f32
+    // CHECK: %[[mask:.*]]    = llvm.mlir.constant(2139095040 : ui32) : i32
+    // CHECK: %[[cast:.*]]    = llvm.bitcast %[[approx]] : f32 to i32
+    // CHECK: %[[exp:.*]]     = llvm.and %[[cast]], %[[mask]] : i32
+    // CHECK: %[[c0:.*]]      = llvm.mlir.constant(0 : ui32) : i32
+    // CHECK: %[[is_zero:.*]] = llvm.icmp "eq" %[[exp]], %[[c0]] : i32
+    // CHECK: %[[is_mask:.*]] = llvm.icmp "eq" %[[exp]], %[[mask]] : i32
+    // CHECK: %[[pred:.*]]    = llvm.or %[[is_zero]], %[[is_mask]] : i1
+    // CHECK: %[[select:.*]]  = llvm.select %[[pred]], %[[approx]], %[[refined]] : i1, f32
+    // CHECK: %[[result:.*]]  = llvm.fptrunc %[[select]] : f32 to f16
+    %result = arith.divf %arg0, %arg1 : f16
+    // CHECK: llvm.return %[[result]] : f16
+    func.return %result : f16
+  }
+}
+
Index: mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
===================================================================
--- mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -148,6 +148,62 @@
   }
 };
 
+// Replaces fdiv on fp16 with an fp32 multiplication by the reciprocal plus
+// one (conditional) Newton iteration.
+//
+// This is as accurate as promoting the division to fp32 in the NVPTX backend,
+// but faster because it performs fewer Newton iterations, avoids the slow
+// path for, e.g., denormals, and allows reuse of the reciprocal for multiple
+// divisions by the same divisor.
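+//
+// A single iteration suffices here: the refined fp32 quotient carries far
+// more correct bits than the 10 fraction bits that survive the final
+// truncation to fp16.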
+struct ExpandDivF16 : public ConvertOpToLLVMPattern<LLVM::FDivOp> {
+  using ConvertOpToLLVMPattern<LLVM::FDivOp>::ConvertOpToLLVMPattern;
+
+private:
+  LogicalResult
+  matchAndRewrite(LLVM::FDivOp op, LLVM::FDivOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op.getType().isF16())
+      return rewriter.notifyMatchFailure(op, "not f16");
+    Location loc = op.getLoc();
+
+    Type f32Type = rewriter.getF32Type();
+    Type i32Type = rewriter.getI32Type();
+
+    // Extend lhs and rhs to fp32.
+    Value lhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, adaptor.getLhs());
+    Value rhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, adaptor.getRhs());
+
+    // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp.
+    Value rcp = rewriter.create<NVVM::RcpApproxFtzF32Op>(loc, f32Type, rhs);
+    Value approx = rewriter.create<LLVM::FMulOp>(loc, lhs, rcp);
+
+    // Refine the approximation with one Newton iteration:
+    // float refined = approx + (lhs - approx * rhs) * rcp;
+    Value err = rewriter.create<LLVM::FMAOp>(
+        loc, approx, rewriter.create<LLVM::FNegOp>(loc, rhs), lhs);
+    Value refined = rewriter.create<LLVM::FMAOp>(loc, err, rcp, approx);
+
+    // Use the refined value only if approx is normal, i.e. its exponent bits
+    // are neither all 0s nor all 1s.
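+    // 0x7f800000 masks the 8 exponent bits of an IEEE-754 binary32 value: an
+    // all-zero field denotes zero or a denormal, an all-one field denotes
+    // inf or NaN, cases where the fma-based refinement is not reliable.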
+    Value mask = rewriter.create<LLVM::ConstantOp>(
+        loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000));
+    Value cast = rewriter.create<LLVM::BitcastOp>(loc, i32Type, approx);
+    Value exp = rewriter.create<LLVM::AndOp>(loc, i32Type, cast, mask);
+    Value zero = rewriter.create<LLVM::ConstantOp>(
+        loc, i32Type, rewriter.getUI32IntegerAttr(0));
+    Value pred = rewriter.create<LLVM::OrOp>(
+        loc,
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, zero),
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, mask));
+    Value result =
+        rewriter.create<LLVM::SelectOp>(loc, f32Type, pred, approx, refined);
+
+    // Replace the op with a truncation back to fp16.
+    rewriter.replaceOpWithNewOp<LLVM::FPTruncOp>(op, op.getType(), result);
+
+    return success();
+  }
+};
+
 /// Import the GPU Ops to NVVM Patterns.
 #include "GPUToNVVM.cpp.inc"
 
@@ -222,6 +278,10 @@
                       LLVM::FCeilOp, LLVM::FFloorOp, LLVM::LogOp, LLVM::Log10Op,
                       LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp, LLVM::SqrtOp>();
 
+  // Expand fdiv on fp16 into code that is faster than the NVPTX backend's
+  // fp32 promotion.
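+  // Marking f16 FDivOp dynamically illegal makes the conversion framework
+  // apply the ExpandDivF16 pattern to every remaining f16 division.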
+  target.addDynamicallyLegalOp<LLVM::FDivOp>(
+      [&](LLVM::FDivOp op) { return !op.getType().isF16(); });
+
   // TODO: Remove once we support replacing non-root ops.
   target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
 }
@@ -241,6 +301,8 @@
            GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
           converter);
 
+  patterns.add<ExpandDivF16>(converter);
+
   // Explicitly drop memory space when lowering private memory
   // attributions since NVVM models it as `alloca`s in the default
   // memory space and does not support `alloca`s with addrspace(5).
Index: mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -51,21 +51,21 @@
 // NVVM intrinsic operations
 //===----------------------------------------------------------------------===//
 
-class NVVM_IntrOp<string mnem, list<int> overloadedResults,
-                  list<int> overloadedOperands, list<Trait> traits,
+class NVVM_IntrOp<string mnem, list<Trait> traits,
                   int numResults>
   : LLVM_IntrOpBase<NVVM_Dialect, mnem, "nvvm_" # !subst(".", "_", mnem),
-                    overloadedResults, overloadedOperands, traits, numResults>;
+                    /*list<int> overloadedResults=*/[],
+                    /*list<int> overloadedOperands=*/[],
+                    traits, numResults>;
 
 
 //===----------------------------------------------------------------------===//
 // NVVM special register op definitions
 //===----------------------------------------------------------------------===//
 
-class NVVM_SpecialRegisterOp<string mnemonic,
-    list<Trait> traits = []> :
-  NVVM_IntrOp<mnemonic, [], [], !listconcat(traits, [NoSideEffect]), 1>,
-  Arguments<(ins)> {
+class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
+  NVVM_IntrOp<mnemonic, !listconcat(traits, [NoSideEffect]), 1> {
+  let arguments = (ins);
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
@@ -92,6 +92,16 @@
 def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">;
 def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">;
 
+//===----------------------------------------------------------------------===//
+// NVVM approximate op definitions
+//===----------------------------------------------------------------------===//
+
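+// Reciprocal approximation with flush-to-zero for f32 values; lowers to the
+// llvm.nvvm.rcp.approx.ftz.f intrinsic.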
+def NVVM_RcpApproxFtzF32Op : NVVM_IntrOp<"rcp.approx.ftz.f", [NoSideEffect], 1> {
+  let arguments = (ins F32:$arg);
+  let results = (outs F32:$res);
+  let assemblyFormat = "$arg attr-dict `:` type($res)";
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM synchronization op definitions
 //===----------------------------------------------------------------------===//