https://github.com/rikhuijzer updated https://github.com/llvm/llvm-project/pull/68900
>From ddbde18e483d12485ba25c715e8a94480b9d6dcf Mon Sep 17 00:00:00 2001 From: Rik Huijzer <git...@huijzer.xyz> Date: Thu, 12 Oct 2023 16:55:22 +0200 Subject: [PATCH 1/3] [mlir][arith] Fix canon pattern for large ints in chained arith --- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 25 +++++++++++++++-------- mlir/test/Dialect/Arith/canonicalize.mlir | 10 +++++++++ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 0ecc288f3b07701..25578b1c52f331b 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -39,26 +39,35 @@ using namespace mlir::arith; static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, - function_ref<int64_t(int64_t, int64_t)> binFn) { - return builder.getIntegerAttr(res.getType(), - binFn(llvm::cast<IntegerAttr>(lhs).getInt(), - llvm::cast<IntegerAttr>(rhs).getInt())); + function_ref<APInt(APInt, APInt&)> binFn) { + auto lhsVal = llvm::cast<IntegerAttr>(lhs).getValue(); + auto rhsVal = llvm::cast<IntegerAttr>(rhs).getValue(); + auto value = binFn(lhsVal, rhsVal); + return IntegerAttr::get(res.getType(), value); } static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { - return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<int64_t>()); + auto binFn = [](APInt a, APInt& b) -> APInt { + return std::move(a) + b; + }; + return applyToIntegerAttrs(builder, res, lhs, rhs, binFn); } static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { - return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<int64_t>()); + auto binFn = [](APInt a, APInt& b) -> APInt { + return std::move(a) - b; + }; + return applyToIntegerAttrs(builder, res, lhs, rhs, binFn); } static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { - return applyToIntegerAttrs(builder, res, lhs, rhs, - std::multiplies<int64_t>()); + auto binFn = [](APInt a, APInt& b) -> APInt { + return std::move(a) * b; + }; + return applyToIntegerAttrs(builder, res, lhs, rhs, binFn); } /// Invert an integer comparison predicate. diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 1b0547c9e8f804a..b18f5cfcb3f9a12 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -985,6 +985,16 @@ func.func @tripleMulIMulII32(%arg0: i32) -> i32 { return %mul2 : i32 } +// CHECK-LABEL: @tripleMulLargeInt +// CHECK: return +func.func @tripleMulLargeInt(%arg0: i256) -> i256 { + %0 = arith.constant 3618502788666131213697322783095070105623107215331596699973092056135872020481 : i256 + %c5 = arith.constant 5 : i256 + %mul1 = arith.muli %arg0, %0 : i256 + %mul2 = arith.muli %mul1, %c5 : i256 + return %mul2 : i256 +} + // CHECK-LABEL: @addiMuliToSubiRhsI32 // CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32) // CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32 >From c0f3efe78fa6e71d1acc4d38f526ca2ec194ddf8 Mon Sep 17 00:00:00 2001 From: Rik Huijzer <git...@huijzer.xyz> Date: Fri, 13 Oct 2023 10:14:16 +0200 Subject: [PATCH 2/3] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Markus Böck <markus.boec...@gmail.com> --- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 25578b1c52f331b..b749a4444f256e7 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -39,7 +39,7 @@ using namespace mlir::arith; static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, - function_ref<APInt(APInt, APInt&)> binFn) { + function_ref<APInt(const APInt&, const APInt&)> binFn) { auto lhsVal = llvm::cast<IntegerAttr>(lhs).getValue(); auto rhsVal = llvm::cast<IntegerAttr>(rhs).getValue(); auto value = binFn(lhsVal, rhsVal); @@ -49,7 +49,7 @@ applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { auto binFn = [](APInt a, APInt& b) -> APInt { - return std::move(a) + b; + return a + b; }; return applyToIntegerAttrs(builder, res, lhs, rhs, binFn); } >From 30e1ce11d567452dcd7481e999109d1f25164065 Mon Sep 17 00:00:00 2001 From: Rik Huijzer <git...@huijzer.xyz> Date: Fri, 13 Oct 2023 10:49:20 +0200 Subject: [PATCH 3/3] Use `const`s and check result of fold --- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 20 +++++++------------- mlir/test/Dialect/Arith/canonicalize.mlir | 12 +++++++----- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index b749a4444f256e7..5fe7a256cce07d1 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -39,34 +39,28 @@ using namespace mlir::arith; static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, - function_ref<APInt(const APInt&, const APInt&)> binFn) { - auto lhsVal = llvm::cast<IntegerAttr>(lhs).getValue(); - auto rhsVal = llvm::cast<IntegerAttr>(rhs).getValue(); - auto value = binFn(lhsVal, rhsVal); + function_ref<APInt(const APInt &, const APInt &)> binFn) { + APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue(); + APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue(); + APInt value = binFn(lhsVal, rhsVal); return IntegerAttr::get(res.getType(), value); } static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { - auto binFn = [](APInt a, APInt& b) -> APInt { - return a + b; - }; + auto binFn = [](const APInt &a, const APInt &b) -> APInt { return a + b; }; return applyToIntegerAttrs(builder, res, lhs, rhs, binFn); } static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { - auto binFn = [](APInt a, APInt& b) -> APInt { - return std::move(a) - b; - }; + auto binFn = [](const APInt &a, const APInt &b) -> APInt { return a - b; }; return applyToIntegerAttrs(builder, res, lhs, rhs, binFn); } static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { - auto binFn = [](APInt a, APInt& b) -> APInt { - return std::move(a) * b; - }; + auto binFn = [](const APInt &a, const APInt &b) -> APInt { return a * b; }; return applyToIntegerAttrs(builder, res, lhs, rhs, binFn); } diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index b18f5cfcb3f9a12..98788536980f939 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -986,13 +986,15 @@ func.func @tripleMulIMulII32(%arg0: i32) -> i32 { } // CHECK-LABEL: @tripleMulLargeInt -// CHECK: return +// CHECK: %[[cres:.+]] = arith.constant 3618502788666131213697322783095070105623107215331596699973092056135872020482 : i256 +// CHECK: %[[addi:.+]] = arith.addi %arg0, %[[cres]] : i256 +// CHECK: return %[[addi]] func.func @tripleMulLargeInt(%arg0: i256) -> i256 { %0 = arith.constant 3618502788666131213697322783095070105623107215331596699973092056135872020481 : i256 - %c5 = arith.constant 5 : i256 - %mul1 = arith.muli %arg0, %0 : i256 - %mul2 = arith.muli %mul1, %c5 : i256 - return %mul2 : i256 + %1 = arith.constant 1 : i256 + %2 = arith.addi %arg0, %0 : i256 + %3 = arith.addi %2, %1 : i256 + return %3 : i256 } // CHECK-LABEL: @addiMuliToSubiRhsI32 _______________________________________________ lldb-commits mailing list lldb-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits