llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir-execution-engine Author: Matthias Springer (matthias-springer) <details> <summary>Changes</summary> Add support for `arith.fptosi` and `arith.fptoui`. Depends on #<!-- -->169275. --- Full diff: https://github.com/llvm/llvm-project/pull/169277.diff 4 Files Affected: - (modified) mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp (+58) - (modified) mlir/lib/ExecutionEngine/APFloatWrappers.cpp (+14) - (modified) mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir (+26) - (modified) mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir (+10) ``````````diff diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp index 90e6e674da519..1fe698f1c8902 100644 --- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp +++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp @@ -185,6 +185,60 @@ struct FpToFpConversion final : OpRewritePattern<OpTy> { SymbolOpInterface symTable; }; +template <typename OpTy> +struct FpToIntConversion final : OpRewritePattern<OpTy> { + FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable, + bool isUnsigned, PatternBenefit benefit = 1) + : OpRewritePattern<OpTy>(context, benefit), symTable(symTable), + isUnsigned(isUnsigned){}; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Get APFloat function from runtime library. + auto i1Type = IntegerType::get(symTable->getContext(), 1); + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr<FuncOp> fn = + lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int", + {i32Type, i32Type, i1Type, i64Type}); + if (failed(fn)) + return fn; + + rewriter.setInsertionPoint(op); + // Cast operands to 64-bit integers. + Location loc = op.getLoc(); + auto inFloatTy = cast<FloatType>(op.getOperand().getType()); + auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth()); + auto int64Type = rewriter.getI64Type(); + Value operandBits = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, inIntWType, op.getOperand())); + + // Call APFloat function. + Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy); + auto outIntTy = cast<IntegerType>(op.getType()); + Value outWidthValue = arith::ConstantOp::create( + rewriter, loc, i32Type, + rewriter.getIntegerAttr(i32Type, outIntTy.getWidth())); + Value isUnsignedValue = arith::ConstantOp::create( + rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, isUnsigned)); + SmallVector<Value> params = {inSemValue, outWidthValue, isUnsignedValue, + operandBits}; + auto resultOp = + func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*fn), params); + + // Truncate result to the original width. + Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntTy, + resultOp->getResult(0)); + rewriter.replaceOp(op, truncatedBits); + return success(); + } + + SymbolOpInterface symTable; + bool isUnsigned; +}; + namespace { struct ArithToAPFloatConversionPass final : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> { @@ -208,6 +262,10 @@ void ArithToAPFloatConversionPass::runOnOperation() { context, "remainder", getOperation()); patterns.add<FpToFpConversion<arith::ExtFOp>>(context, getOperation()); patterns.add<FpToFpConversion<arith::TruncFOp>>(context, getOperation()); + patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(), + /*isUnsigned=*/false); + patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(), + /*isUnsigned=*/true); LogicalResult result = success(); ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) { if (diag.getSeverity() == DiagnosticSeverity::Error) { diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp index 511b05ea380f0..632fe9cf2269d 100644 --- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp +++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp @@ -20,6 +20,7 @@ // APFloatBase::Semantics enum value. // #include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APSInt.h" #ifdef _WIN32 #ifndef MLIR_APFLOAT_WRAPPERS_EXPORT @@ -101,4 +102,17 @@ _mlir_apfloat_convert(int32_t inSemantics, int32_t outSemantics, uint64_t a) { llvm::APInt result = val.bitcastToAPInt(); return result.getZExtValue(); } + +MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_to_int( + int32_t semantics, int32_t resultWidth, bool isUnsigned, uint64_t a) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(semantics)); + unsigned inputWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat val(sem, llvm::APInt(inputWidth, a)); + llvm::APSInt result(resultWidth, isUnsigned); + bool isExact; + // TODO: Custom rounding modes are not supported yet. + val.convertToInteger(result, llvm::RoundingMode::NearestTiesToEven, &isExact); + return result.getZExtValue(); +} } diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir index 038acbfc965a2..f1acfd5e5618a 100644 --- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir +++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir @@ -148,3 +148,29 @@ func.func @truncf(%arg0: bf16) { %0 = arith.truncf %arg0 : bf16 to f4E2M1FN return } + +// ----- + +// CHECK: func.func private @_mlir_apfloat_convert_to_int(i32, i32, i1, i64) -> i64 +// CHECK: %[[sem_in:.*]] = arith.constant 0 : i32 +// CHECK: %[[out_width:.*]] = arith.constant 4 : i32 +// CHECK: %[[is_unsigned:.*]] = arith.constant false +// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_to_int(%[[sem_in]], %[[out_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64 +// CHECK: arith.trunci %[[res]] : i64 to i4 +func.func @fptosi(%arg0: f16) { + %0 = arith.fptosi %arg0 : f16 to i4 + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_convert_to_int(i32, i32, i1, i64) -> i64 +// CHECK: %[[sem_in:.*]] = arith.constant 0 : i32 +// CHECK: %[[out_width:.*]] = arith.constant 4 : i32 +// CHECK: %[[is_unsigned:.*]] = arith.constant true +// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_to_int(%[[sem_in]], %[[out_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64 +// CHECK: arith.trunci %[[res]] : i64 to i4 +func.func @fptoui(%arg0: f16) { + %0 = arith.fptoui %arg0 : f16 to i4 + return +} diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir index 51976434d2be2..5e93945c3eb60 100644 --- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir +++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir @@ -43,5 +43,15 @@ func.func @entry() { %cvt = arith.truncf %b2 : f32 to f8E4M3FN vector.print %cvt : f8E4M3FN + // CHECK-NEXT: 1 + // Bit pattern: 01, interpreted as signed integer: 1 + %cvt_int_signed = arith.fptosi %cvt : f8E4M3FN to i2 + vector.print %cvt_int_signed : i2 + + // CHECK-NEXT: -2 + // Bit pattern: 10, interpreted as signed integer: -2 + %cvt_int_unsigned = arith.fptoui %cvt : f8E4M3FN to i2 + vector.print %cvt_int_unsigned : i2 + return } `````````` </details> https://github.com/llvm/llvm-project/pull/169277 _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
