llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) <details> <summary>Changes</summary> Add support for `arith.negf`. --- Full diff: https://github.com/llvm/llvm-project/pull/169759.diff 4 Files Affected: - (modified) mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp (+44) - (modified) mlir/lib/ExecutionEngine/APFloatWrappers.cpp (+9) - (modified) mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir (+10) - (modified) mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir (+4) ``````````diff diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp index 566632bd8707f..230abb51e8158 100644 --- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp +++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp @@ -449,6 +449,49 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> { SymbolOpInterface symTable; }; +struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> { + NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable) {} + + LogicalResult matchAndRewrite(arith::NegFOp op, + PatternRewriter &rewriter) const override { + // Get APFloat function from runtime library. + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr<FuncOp> fn = + lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type}); + if (failed(fn)) + return fn; + + // Cast operand to 64-bit integer. 
+ rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + auto floatTy = cast<FloatType>(op.getOperand().getType()); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + Value operandBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, arith::BitcastOp::create(rewriter, loc, intWType, op.getOperand())); + + // Call APFloat function. + Value semValue = getSemanticsValue(rewriter, loc, floatTy); + SmallVector<Value> params = {semValue, operandBits}; + Value negatedBits = + func::CallOp::create(rewriter, loc, TypeRange(i64Type), + SymbolRefAttr::get(*fn), params) + ->getResult(0); + + // Truncate result to the original width. + Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType, + negatedBits); + Value result = + arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits); + rewriter.replaceOp(op, result); + return success(); + } + + SymbolOpInterface symTable; +}; + namespace { struct ArithToAPFloatConversionPass final : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> { @@ -482,6 +525,7 @@ void ArithToAPFloatConversionPass::runOnOperation() { patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(), /*isUnsigned=*/true); patterns.add<CmpFOpToAPFloatConversion>(context, getOperation()); + patterns.add<NegFOpToAPFloatConversion>(context, getOperation()); LogicalResult result = success(); ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) { if (diag.getSeverity() == DiagnosticSeverity::Error) { diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp index 77f7137264888..f2d5254be6b57 100644 --- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp +++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp @@ -142,4 +142,13 @@ MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics, llvm::APFloat y(sem, llvm::APInt(bitWidth, b)); return static_cast<int8_t>(x.compare(y)); } + +MLIR_APFLOAT_WRAPPERS_EXPORT 
+ uint64_t _mlir_apfloat_neg(int32_t semantics, uint64_t a) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat x(sem, llvm::APInt(bitWidth, a)); + x.changeSign(); + return x.bitcastToAPInt().getZExtValue(); +} } diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir index 78ce3640ecc67..775cb5ea60f22 100644 --- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir +++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir @@ -213,3 +213,13 @@ func.func @cmpf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { %0 = arith.cmpf "ult", %arg0, %arg1 : f4E2M1FN return } + +// ----- + +// CHECK: func.func private @_mlir_apfloat_neg(i32, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 2 : i32 +// CHECK: %[[res:.*]] = call @_mlir_apfloat_neg(%[[sem]], %{{.*}}) : (i32, i64) -> i64 +func.func @negf(%arg0: f32) { + %0 = arith.negf %arg0 : f32 + return +} diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir index 433d058d025cf..555cc9a531966 100644 --- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir +++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir @@ -43,6 +43,10 @@ func.func @entry() { %cvt = arith.truncf %b2 : f32 to f8E4M3FN vector.print %cvt : f8E4M3FN + // CHECK-NEXT: -2.25 + %negated = arith.negf %cvt : f8E4M3FN + vector.print %negated : f8E4M3FN + // CHECK-NEXT: 1 %cmp1 = arith.cmpf "olt", %cvt, %c1 : f8E4M3FN vector.print %cmp1 : i1 `````````` </details> https://github.com/llvm/llvm-project/pull/169759 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
