llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

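Add support for `arith.negf` to the ArithToAPFloat conversion: the op is lowered to a call to a new `_mlir_apfloat_neg` runtime function, which negates the value with `llvm::APFloat::changeSign`.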


---
Full diff: https://github.com/llvm/llvm-project/pull/169759.diff


4 Files Affected:

- (modified) mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp (+44) 
- (modified) mlir/lib/ExecutionEngine/APFloatWrappers.cpp (+9) 
- (modified) mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir (+10) 
- (modified) mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir (+4) 


``````````diff
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
index 566632bd8707f..230abb51e8158 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -449,6 +449,49 @@ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
   SymbolOpInterface symTable;
 };
 
+struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
+  NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+                            PatternBenefit benefit = 1)
+      : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable) {}
+
+  LogicalResult matchAndRewrite(arith::NegFOp op,
+                                PatternRewriter &rewriter) const override {
+    // Get APFloat function from runtime library.
+    auto i32Type = IntegerType::get(symTable->getContext(), 32);
+    auto i64Type = IntegerType::get(symTable->getContext(), 64);
+    FailureOr<FuncOp> fn =
+        lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type});
+    if (failed(fn))
+      return fn;
+
+    // Cast operand to 64-bit integer.
+    rewriter.setInsertionPoint(op);
+    Location loc = op.getLoc();
+    auto floatTy = cast<FloatType>(op.getOperand().getType());
+    auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+    Value operandBits = arith::ExtUIOp::create(
+        rewriter, loc, i64Type, arith::BitcastOp::create(rewriter, loc, intWType, op.getOperand()));
+
+    // Call APFloat function.
+    Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+    SmallVector<Value> params = {semValue, operandBits};
+    Value negatedBits =
+        func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+                             SymbolRefAttr::get(*fn), params)
+            ->getResult(0);
+
+    // Truncate result to the original width.
+    Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
+                                                  negatedBits);
+    Value result =
+        arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+  SymbolOpInterface symTable;
+};
+
 namespace {
 struct ArithToAPFloatConversionPass final
     : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -482,6 +525,7 @@ void ArithToAPFloatConversionPass::runOnOperation() {
   patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
                                                    /*isUnsigned=*/true);
   patterns.add<CmpFOpToAPFloatConversion>(context, getOperation());
+  patterns.add<NegFOpToAPFloatConversion>(context, getOperation());
   LogicalResult result = success();
   ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
     if (diag.getSeverity() == DiagnosticSeverity::Error) {
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index 77f7137264888..f2d5254be6b57 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -142,4 +142,13 @@ MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics,
   llvm::APFloat y(sem, llvm::APInt(bitWidth, b));
   return static_cast<int8_t>(x.compare(y));
 }
+
+MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_neg(int32_t semantics, uint64_t a) {
+  const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+      static_cast<llvm::APFloatBase::Semantics>(semantics));
+  unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+  llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
+  x.changeSign();
+  return x.bitcastToAPInt().getZExtValue();
+}
 }
diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
index 78ce3640ecc67..775cb5ea60f22 100644
--- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
+++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
@@ -213,3 +213,13 @@ func.func @cmpf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
   %0 = arith.cmpf "ult", %arg0, %arg1 : f4E2M1FN
   return
 }
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_neg(i32, i64) -> i64
+// CHECK: %[[sem:.*]] = arith.constant 2 : i32
+// CHECK: %[[res:.*]] = call @_mlir_apfloat_neg(%[[sem]], %{{.*}}) : (i32, i64) -> i64
+func.func @negf(%arg0: f32) {
+  %0 = arith.negf %arg0 : f32
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
index 433d058d025cf..555cc9a531966 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -43,6 +43,10 @@ func.func @entry() {
   %cvt = arith.truncf %b2 : f32 to f8E4M3FN
   vector.print %cvt : f8E4M3FN
 
+  // CHECK-NEXT: -2.25
+  %negated = arith.negf %cvt : f8E4M3FN
+  vector.print %negated : f8E4M3FN
+
   // CHECK-NEXT: 1
   %cmp1 = arith.cmpf "olt", %cvt, %c1 : f8E4M3FN
   vector.print %cmp1 : i1

``````````

</details>


https://github.com/llvm/llvm-project/pull/169759
_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to