This revision was automatically updated to reflect the committed changes. Closed by commit rGbcfc0a905101: [MLIR][GPU] Replace fdiv on fp16 with promoted (fp32) multiplication with reciprocal plus one (conditional) Newton iteration. (authored by csigg).
Repository: rG LLVM Github Monorepo
Changes since last action: https://reviews.llvm.org/D126158/new/
Revision: https://reviews.llvm.org/D126158
Files: mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td, mlir/include/mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h, mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h, mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td, mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt, mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp, mlir/test/Dialect/LLVMIR/nvvm.mlir, mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir, mlir/test/Target/LLVMIR/nvvmir.mlir, utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Index: utils/bazel/llvm-project-overlay/mlir/BUILD.bazel =================================================================== --- utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -3386,7 +3386,9 @@ ":IR", ":LLVMDialect", ":LLVMPassIncGen", + ":NVVMDialect", ":Pass", + ":Transforms", ], ) Index: mlir/test/Target/LLVMIR/nvvmir.mlir =================================================================== --- mlir/test/Target/LLVMIR/nvvmir.mlir +++ mlir/test/Target/LLVMIR/nvvmir.mlir @@ -33,6 +33,13 @@ llvm.return %1 : i32 } +// CHECK-LABEL: @nvvm_rcp +llvm.func @nvvm_rcp(%0: f32) -> f32 { + // CHECK: call float @llvm.nvvm.rcp.approx.ftz.f + %1 = nvvm.rcp.approx.ftz.f %0 : f32 + llvm.return %1 : f32 +} + // CHECK-LABEL: @llvm_nvvm_barrier0 llvm.func @llvm_nvvm_barrier0() { // CHECK: call void @llvm.nvvm.barrier0() Index: mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt %s -llvm-optimize-for-nvvm-target | FileCheck %s + +// CHECK-LABEL: llvm.func @fdiv_fp16 +llvm.func @fdiv_fp16(%arg0 : f16, %arg1 : f16) -> f16 { + // CHECK-DAG: %[[c0:.*]] = llvm.mlir.constant(0 : ui32) : i32 + // CHECK-DAG: %[[mask:.*]] = llvm.mlir.constant(2139095040 : ui32) : i32 + // CHECK-DAG: %[[lhs:.*]] = llvm.fpext %arg0 : f16 to f32 + // CHECK-DAG: %[[rhs:.*]] = llvm.fpext %arg1 : f16 to f32 + // CHECK-DAG: %[[rcp:.*]] = nvvm.rcp.approx.ftz.f %[[rhs]] : f32 + // CHECK-DAG: %[[approx:.*]] = llvm.fmul %[[lhs]], %[[rcp]] : f32 + // CHECK-DAG: %[[neg:.*]] = llvm.fneg %[[rhs]] : f32 + // CHECK-DAG: %[[err:.*]] = "llvm.intr.fma"(%[[approx]], %[[neg]], %[[lhs]]) : (f32, f32, f32) -> f32 + // CHECK-DAG: %[[refined:.*]] = "llvm.intr.fma"(%[[err]], %[[rcp]], %[[approx]]) : (f32, f32, f32) -> f32 + // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[approx]] : f32 to i32 + // CHECK-DAG: 
%[[exp:.*]] = llvm.and %[[cast]], %[[mask]] : i32 + // CHECK-DAG: %[[is_zero:.*]] = llvm.icmp "eq" %[[exp]], %[[c0]] : i32 + // CHECK-DAG: %[[is_mask:.*]] = llvm.icmp "eq" %[[exp]], %[[mask]] : i32 + // CHECK-DAG: %[[pred:.*]] = llvm.or %[[is_zero]], %[[is_mask]] : i1 + // CHECK-DAG: %[[select:.*]] = llvm.select %[[pred]], %[[approx]], %[[refined]] : i1, f32 + // CHECK-DAG: %[[result:.*]] = llvm.fptrunc %[[select]] : f32 to f16 + %result = llvm.fdiv %arg0, %arg1 : f16 + // CHECK: llvm.return %[[result]] : f16 + llvm.return %result : f16 +} Index: mlir/test/Dialect/LLVMIR/nvvm.mlir =================================================================== --- mlir/test/Dialect/LLVMIR/nvvm.mlir +++ mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -29,6 +29,13 @@ llvm.return %0 : i32 } +// CHECK-LABEL: @nvvm_rcp +func.func @nvvm_rcp(%arg0: f32) -> f32 { + // CHECK: nvvm.rcp.approx.ftz.f %arg0 : f32 + %0 = nvvm.rcp.approx.ftz.f %arg0 : f32 + llvm.return %0 : f32 +} + // CHECK-LABEL: @llvm_nvvm_barrier0 func.func @llvm_nvvm_barrier0() { // CHECK: nvvm.barrier0 Index: mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp @@ -0,0 +1,97 @@ +//===- OptimizeForNVVM.cpp - Optimize LLVM IR for NVVM ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h" +#include "PassDetail.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace { +// Replaces fdiv on fp16 with fp32 multiplication with reciprocal plus one +// (conditional) Newton iteration. +// +// This is as accurate as promoting the division to fp32 in the NVPTX backend, +// but faster because it performs fewer Newton iterations, avoids the slow path +// for e.g. denormals, and allows reuse of the reciprocal for multiple divisions +// by the same divisor. +struct ExpandDivF16 : public OpRewritePattern<LLVM::FDivOp> { + using OpRewritePattern<LLVM::FDivOp>::OpRewritePattern; + +private: + LogicalResult matchAndRewrite(LLVM::FDivOp op, + PatternRewriter &rewriter) const override; +}; + +struct NVVMOptimizeForTarget + : public NVVMOptimizeForTargetBase<NVVMOptimizeForTarget> { + void runOnOperation() override; + + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert<NVVM::NVVMDialect>(); + } +}; +} // namespace + +LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op, + PatternRewriter &rewriter) const { + if (!op.getType().isF16()) + return rewriter.notifyMatchFailure(op, "not f16"); + Location loc = op.getLoc(); + + Type f32Type = rewriter.getF32Type(); + Type i32Type = rewriter.getI32Type(); + + // Extend lhs and rhs to fp32. + Value lhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getLhs()); + Value rhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getRhs()); + + // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp. 
+ Value rcp = rewriter.create<NVVM::RcpApproxFtzF32Op>(loc, f32Type, rhs); + Value approx = rewriter.create<LLVM::FMulOp>(loc, lhs, rcp); + + // Refine the approximation with one Newton iteration: + // float refined = approx + (lhs - approx * rhs) * rcp; + Value err = rewriter.create<LLVM::FMAOp>( + loc, approx, rewriter.create<LLVM::FNegOp>(loc, rhs), lhs); + Value refined = rewriter.create<LLVM::FMAOp>(loc, err, rcp, approx); + + // Use refined value if approx is normal (exponent neither all 0 nor all 1). + Value mask = rewriter.create<LLVM::ConstantOp>( + loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000)); + Value cast = rewriter.create<LLVM::BitcastOp>(loc, i32Type, approx); + Value exp = rewriter.create<LLVM::AndOp>(loc, i32Type, cast, mask); + Value zero = rewriter.create<LLVM::ConstantOp>( + loc, i32Type, rewriter.getUI32IntegerAttr(0)); + Value pred = rewriter.create<LLVM::OrOp>( + loc, + rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, zero), + rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, mask)); + Value result = + rewriter.create<LLVM::SelectOp>(loc, f32Type, pred, approx, refined); + + // Replace with truncation back to fp16. 
+ rewriter.replaceOpWithNewOp<LLVM::FPTruncOp>(op, op.getType(), result); + + return success(); +} + +void NVVMOptimizeForTarget::runOnOperation() { + MLIRContext *ctx = getOperation()->getContext(); + RewritePatternSet patterns(ctx); + patterns.add<ExpandDivF16>(ctx); + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); +} + +std::unique_ptr<Pass> NVVM::createOptimizeForTargetPass() { + return std::make_unique<NVVMOptimizeForTarget>(); +} Index: mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt +++ mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRLLVMIRTransforms LegalizeForExport.cpp + OptimizeForNVVM.cpp DEPENDS MLIRLLVMPassIncGen Index: mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td +++ mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td @@ -16,4 +16,9 @@ let constructor = "mlir::LLVM::createLegalizeForExportPass()"; } +def NVVMOptimizeForTarget : Pass<"llvm-optimize-for-nvvm-target"> { + let summary = "Optimize NVVM IR"; + let constructor = "mlir::NVVM::createOptimizeForTargetPass()"; +} + #endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES Index: mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h +++ mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h @@ -10,6 +10,7 @@ #define MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES_H #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h" +#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h" #include "mlir/Pass/Pass.h" namespace mlir { Index: mlir/include/mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h 
=================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h @@ -0,0 +1,25 @@ +//===- OptimizeForNVVM.h - Optimize LLVM IR for NVVM -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_OPTIMIZENVVM_H +#define MLIR_DIALECT_LLVMIR_TRANSFORMS_OPTIMIZENVVM_H + +#include <memory> + +namespace mlir { +class Pass; + +namespace NVVM { + +/// Creates a pass that optimizes LLVM IR for the NVVM target. +std::unique_ptr<Pass> createOptimizeForTargetPass(); + +} // namespace NVVM +} // namespace mlir + +#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_OPTIMIZENVVM_H Index: mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -51,21 +51,21 @@ // NVVM intrinsic operations //===----------------------------------------------------------------------===// -class NVVM_IntrOp<string mnem, list<int> overloadedResults, - list<int> overloadedOperands, list<Trait> traits, +class NVVM_IntrOp<string mnem, list<Trait> traits, int numResults> : LLVM_IntrOpBase<NVVM_Dialect, mnem, "nvvm_" # !subst(".", "_", mnem), - overloadedResults, overloadedOperands, traits, numResults>; + /*list<int> overloadedResults=*/[], + /*list<int> overloadedOperands=*/[], + traits, numResults>; //===----------------------------------------------------------------------===// // NVVM special register op definitions //===----------------------------------------------------------------------===// -class NVVM_SpecialRegisterOp<string mnemonic, - list<Trait> traits = []> : - 
NVVM_IntrOp<mnemonic, [], [], !listconcat(traits, [NoSideEffect]), 1>, - Arguments<(ins)> { +class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> : + NVVM_IntrOp<mnemonic, !listconcat(traits, [NoSideEffect]), 1> { + let arguments = (ins); let assemblyFormat = "attr-dict `:` type($res)"; } @@ -92,6 +92,16 @@ def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">; def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">; +//===----------------------------------------------------------------------===// +// NVVM approximate op definitions +//===----------------------------------------------------------------------===// + +def NVVM_RcpApproxFtzF32Op : NVVM_IntrOp<"rcp.approx.ftz.f", [NoSideEffect], 1> { + let arguments = (ins F32:$arg); + let results = (outs F32:$res); + let assemblyFormat = "$arg attr-dict `:` type($res)"; +} + //===----------------------------------------------------------------------===// // NVVM synchronization op definitions //===----------------------------------------------------------------------===//
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits