[Lldb-commits] [llvm] [lldb] [clang-tools-extra] [clang] [libunwind] [libc] [compiler-rt] [flang] [lld] [libclc] [mlir] [libcxx] [mlir][tensor] Enhance pack/unpack simplification for identity outer_di
https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/77409

From e74b859897cdf1b1effbbd48a4e5376a231f7132 Mon Sep 17 00:00:00 2001
From: hanhanW
Date: Mon, 8 Jan 2024 20:17:30 -0800
Subject: [PATCH 1/2] [mlir][tensor] Enhance pack/unpack simplification patterns for identity outer_dims_perm cases.

They can be simplified to reshape ops if outer_dims_perm is an identity
permutation. The revision adds an `isIdentityPermutation` method to
IndexingUtils.
---
 .../mlir/Dialect/Utils/IndexingUtils.h        |  3 +++
 .../Dialect/Tensor/Transforms/CMakeLists.txt  |  1 +
 .../Transforms/PackAndUnpackPatterns.cpp      | 17 +
 mlir/lib/Dialect/Utils/IndexingUtils.cpp      |  8 +++
 .../Dialect/Tensor/simplify-pack-unpack.mlir  | 24 +++
 5 files changed, 48 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index f51a8b28b7548e..2453d841f633e4 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -228,6 +228,9 @@ void applyPermutationToVector(SmallVector &inVec,
 /// Helper method to apply to inverse a permutation.
 SmallVector invertPermutationVector(ArrayRef permutation);
 
+/// Returns true if `permutation` is an identity permutation.
+bool isIdentityPermutation(ArrayRef permutation);
+
 /// Method to check if an interchange vector is a permutation.
 bool isPermutationVector(ArrayRef interchange);

diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index cbc0d499d9d52c..c6ef6ed86e0d9d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -27,6 +27,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   MLIRArithUtils
   MLIRBufferizationDialect
   MLIRBufferizationTransforms
+  MLIRDialectUtils
   MLIRIR
   MLIRLinalgDialect
   MLIRMemRefDialect
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index cfd838e85c1b80..2853cb8fe77a3b 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "llvm/Support/Debug.h"
 
@@ -38,8 +39,12 @@ struct SimplifyPackToExpandShape : public OpRewritePattern {
     if (packOp.getPaddingValue())
       return rewriter.notifyMatchFailure(packOp, "expects no padding value");
 
-    if (!packOp.getOuterDimsPerm().empty())
-      return rewriter.notifyMatchFailure(packOp, "expects no outer_dims_perm");
+    auto outerDimsPerm = packOp.getOuterDimsPerm();
+    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
+      return rewriter.notifyMatchFailure(
+          packOp,
+          "expects outer_dims_perm is empty or an identity permutation");
+    }
 
     RankedTensorType sourceType = packOp.getSourceType();
     RankedTensorType destType = packOp.getDestType();
@@ -74,9 +79,11 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern {
 
   LogicalResult matchAndRewrite(UnPackOp unpackOp,
                                 PatternRewriter &rewriter) const override {
-    if (!unpackOp.getOuterDimsPerm().empty()) {
-      return rewriter.notifyMatchFailure(unpackOp,
-                                         "expects no outer_dims_perm");
+    auto outerDimsPerm = unpackOp.getOuterDimsPerm();
+    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
+      return rewriter.notifyMatchFailure(
+          unpackOp,
+          "expects outer_dims_perm is empty or an identity permutation");
     }
 
     RankedTensorType sourceType = unpackOp.getSourceType();
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index bb8a0d5912d7c1..f3de454dc4b81a 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -213,6 +213,14 @@ mlir::invertPermutationVector(ArrayRef permutation) {
   return inversion;
 }
 
+bool mlir::isIdentityPermutation(ArrayRef permutation) {
+  int n = permutation.size();
+  for (int i = 0; i < n; ++i)
+    if (permutation[i] != i)
+      return false;
+  return true;
+}
+
 bool mlir::isPermutationVector(ArrayRef interchange) {
   assert(llvm::all_of(interchange, [](int64_t s) { return s >= 0; }) &&
          "permutation must be non-negative");
diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index b78ab9bb3fd87e..ffbb2278a2e327 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tenso
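With this patch, both patterns reject only an `outer_dims_perm` that is non-empty and not the identity. The C++ sketch below is a standalone approximation, not MLIR code: it assumes `std::vector<int64_t>` in place of `ArrayRef<int64_t>`, and `packLastDim` is a hypothetical helper that tiles the last dimension of a row-major M x N buffer and optionally swaps the two outer dimensions. It illustrates why the identity case is a plain reshape: the packed data keeps the source element order, so an `expand_shape` can stand in for the pack, whereas swapping the outer dimensions reorders elements.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Standalone mirror of the isIdentityPermutation helper added in
// IndexingUtils.cpp, with std::vector standing in for ArrayRef.
// An empty permutation has no mismatching entry, so it counts as identity.
bool isIdentityPermutation(const std::vector<int64_t> &permutation) {
  for (std::size_t i = 0; i < permutation.size(); ++i)
    if (permutation[i] != static_cast<int64_t>(i))
      return false;
  return true;
}

// Hypothetical helper (not an MLIR API): tiles dim 1 of a row-major M x N
// buffer by T, giving shape [M, N/T, T]. outerPerm is assumed to be empty,
// {0, 1}, or {1, 0}; {1, 0} swaps the outer dims, giving [N/T, M, T].
// Returns the packed elements in row-major order.
std::vector<int> packLastDim(const std::vector<int> &src, int64_t M, int64_t N,
                             int64_t T, const std::vector<int64_t> &outerPerm) {
  assert(N % T == 0 && "tile size must divide N");
  assert(src.size() == static_cast<std::size_t>(M * N));
  const int64_t tiles = N / T;
  const bool swapOuter = !isIdentityPermutation(outerPerm);
  std::vector<int> dst(src.size());
  for (int64_t m = 0; m < M; ++m)
    for (int64_t n = 0; n < tiles; ++n)
      for (int64_t t = 0; t < T; ++t) {
        // Linear index of the outer (m, n) tile in the packed result.
        const int64_t outer = swapOuter ? n * M + m : m * tiles + n;
        dst[outer * T + t] = src[m * N + n * T + t];
      }
  return dst;
}
```

With src = {0, ..., 7} viewed as 2 x 4 and T = 2, an empty or {0, 1} permutation returns src unchanged, while {1, 0} yields {0, 1, 4, 5, 2, 3, 6, 7}. Because `isIdentityPermutation` already returns true for an empty permutation, the `!outerDimsPerm.empty()` guard in the patch only short-circuits; the condition is equivalent to `!isIdentityPermutation(outerDimsPerm)`.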
[Lldb-commits] [libcxx] [libc] [lld] [clang-tools-extra] [libclc] [lldb] [llvm] [clang] [compiler-rt] [flang] [mlir] [libunwind] [mlir][tensor] Enhance pack/unpack simplification for identity outer_di
https://github.com/hanhanW closed https://github.com/llvm/llvm-project/pull/77409 ___ lldb-commits mailing list lldb-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits
[Lldb-commits] [lldb] Reland "[mlir][arith] Canonicalization patterns for `arith.select` (#67809)" (PR #68941)
https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/68941

From 877111a139b2f01037fdbe7c0cb120a2f4e64562 Mon Sep 17 00:00:00 2001
From: hanhanW
Date: Thu, 12 Oct 2023 17:14:29 -0700
Subject: [PATCH 1/2] Reland "[mlir][arith] Canonicalization patterns for `arith.select` (#67809)"

This cherry-picks the changes in
llvm-project/5bf701a6687a46fd898621f5077959ff202d716b, with the types
limited to i1.
---
 .../Dialect/Arith/IR/ArithCanonicalization.td | 46 +++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        |  4 +-
 mlir/test/Dialect/Arith/canonicalize.mlir     | 76 +++
 3 files changed, 125 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index f3d84d0b261e8dd..817de0e06c661b9 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -233,6 +233,52 @@ def CmpIExtUI :
     CPred<"$0.getValue() == arith::CmpIPredicate::eq || "
           "$0.getValue() == arith::CmpIPredicate::ne">> $pred)]>;
 
+//===--===//
+// SelectOp
+//===--===//
+
+// select(not(pred), a, b) => select(pred, b, a)
+def SelectNotCond :
+    Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
+        (SelectOp $pred, $b, $a),
+        [(IsScalarOrSplatNegativeOne $ones)]>;
+
+// select(pred, select(pred, a, b), c) => select(pred, a, c)
+def RedundantSelectTrue :
+    Pat<(SelectOp $pred, (SelectOp $pred, $a, $b), $c),
+        (SelectOp $pred, $a, $c)>;
+
+// select(pred, a, select(pred, b, c)) => select(pred, a, c)
+def RedundantSelectFalse :
+    Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
+        (SelectOp $pred, $a, $c)>;
+
+// select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
+def SelectAndCond :
+    Pat<(SelectOp $predA, (SelectOp $predB, $x, $y), $y),
+        (SelectOp (Arith_AndIOp $predA, $predB), $x, $y)>;
+
+// select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
+def SelectAndNotCond :
+    Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
+        (SelectOp (Arith_AndIOp $predA,
+                    (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr))),
+            $x, $y),
+        [(Constraint> $predB)]>;
+
+// select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
+def SelectOrCond :
+    Pat<(SelectOp $predA, $x, (SelectOp $predB, $x, $y)),
+        (SelectOp (Arith_OrIOp $predA, $predB), $x, $y)>;
+
+// select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)
+def SelectOrNotCond :
+    Pat<(SelectOp $predA, $x, (SelectOp $predB, $y, $x)),
+        (SelectOp (Arith_OrIOp $predA,
+                    (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr))),
+            $x, $y),
+        [(Constraint> $predB)]>;
+
 //===--===//
 // IndexCastOp
 //===--===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ae8a6ef350ce191..0ecc288f3b07701 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2212,7 +2212,9 @@ struct SelectToExtUI : public OpRewritePattern {
 
 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add(context);
+  results.add(context);
 }
 
 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index f697f3d01458eee..1b0547c9e8f804a 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -128,6 +128,82 @@ func.func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
   return %res : i1
 }
 
+// CHECK-LABEL: @redundantSelectTrue
+// CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg1, %arg3
+// CHECK-NEXT: return %[[res]]
+func.func @redundantSelectTrue(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
+  %0 = arith.select %arg0, %arg1, %arg2 : i32
+  %res = arith.select %arg0, %0, %arg3 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @redundantSelectFalse
+// CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg3, %arg2
+// CHECK-NEXT: return %[[res]]
+func.func @redundantSelectFalse(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
+  %0 = arith.select %arg0, %arg1, %arg2 : i32
+  %res = arith.select %arg0, %arg3, %0 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @selNotCond
+// CHECK-NEXT: %[[res1:.+]] = arith.select %arg0, %arg2, %arg1
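Each of the new `arith.select` patterns is a Boolean identity on an i1 condition. As a sanity check of the algebra only (not of the TableGen patterns or their type constraints), this sketch models `arith.select` with a C++ ternary and tests every rewrite over both values of each predicate. `arithSelect` and `allSelectRewritesHold` are hypothetical names, not MLIR APIs.

```cpp
// Scalar stand-in for arith.select on an i1 condition (hypothetical name).
int arithSelect(bool pred, int a, int b) { return pred ? a : b; }

// Exhaustively checks the rewrites from the patch over both values of each
// i1 predicate; distinct integers stand in for the other operands.
bool allSelectRewritesHold() {
  const int a = 1, b = 2, c = 3, x = 4, y = 5;
  const bool bools[] = {false, true};
  for (bool p : bools) {
    // SelectNotCond: select(not(pred), a, b) => select(pred, b, a)
    if (arithSelect(!p, a, b) != arithSelect(p, b, a)) return false;
    // RedundantSelectTrue: select(pred, select(pred, a, b), c) => select(pred, a, c)
    if (arithSelect(p, arithSelect(p, a, b), c) != arithSelect(p, a, c))
      return false;
    // RedundantSelectFalse: select(pred, a, select(pred, b, c)) => select(pred, a, c)
    if (arithSelect(p, a, arithSelect(p, b, c)) != arithSelect(p, a, c))
      return false;
    for (bool q : bools) {  // p plays predA, q plays predB.
      // xori(q, true) behaves as not(q) on i1; the *NotCond patterns rely on it.
      if ((q ^ true) != !q) return false;
      // SelectAndCond
      if (arithSelect(p, arithSelect(q, x, y), y) != arithSelect(p && q, x, y))
        return false;
      // SelectAndNotCond
      if (arithSelect(p, arithSelect(q, y, x), y) != arithSelect(p && !q, x, y))
        return false;
      // SelectOrCond
      if (arithSelect(p, x, arithSelect(q, x, y)) != arithSelect(p || q, x, y))
        return false;
      // SelectOrNotCond
      if (arithSelect(p, x, arithSelect(q, y, x)) != arithSelect(p || !q, x, y))
        return false;
    }
  }
  return true;
}
```

The check is exhaustive because an i1 predicate has only two values. It says nothing about vector or other non-i1 conditions; per the commit message, the reland limits these patterns to i1.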
[Lldb-commits] [lldb] Reland "[mlir][arith] Canonicalization patterns for `arith.select` (#67809)" (PR #68941)
@@ -233,6 +233,52 @@ def CmpIExtUI :
     CPred<"$0.getValue() == arith::CmpIPredicate::eq || "
           "$0.getValue() == arith::CmpIPredicate::ne">> $pred)]>;
 
+//===--===//
+// SelectOp
+//===--===//
+
+// select(not(pred), a, b) => select(pred, b, a)
+def SelectNotCond :
+    Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
+        (SelectOp $pred, $b, $a),
+        [(IsScalarOrSplatNegativeOne $ones)]>;
+
+// select(pred, select(pred, a, b), c) => select(pred, a, c)
+def RedundantSelectTrue :
+    Pat<(SelectOp $pred, (SelectOp $pred, $a, $b), $c),
+        (SelectOp $pred, $a, $c)>;
+
+// select(pred, a, select(pred, b, c)) => select(pred, a, c)
+def RedundantSelectFalse :
+    Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
+        (SelectOp $pred, $a, $c)>;
+
+// select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
+def SelectAndCond :
+    Pat<(SelectOp $predA, (SelectOp $predB, $x, $y), $y),
+        (SelectOp (Arith_AndIOp $predA, $predB), $x, $y)>;
+
+// select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
+def SelectAndNotCond :
+    Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
+        (SelectOp (Arith_AndIOp $predA,
+                    (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr))),

hanhanW wrote:

done

https://github.com/llvm/llvm-project/pull/68941
[Lldb-commits] [lldb] Reland "[mlir][arith] Canonicalization patterns for `arith.select` (#67809)" (PR #68941)
https://github.com/hanhanW edited https://github.com/llvm/llvm-project/pull/68941
[Lldb-commits] [lldb] Reland "[mlir][arith] Canonicalization patterns for `arith.select` (#67809)" (PR #68941)
hanhanW wrote: thanks for the review! https://github.com/llvm/llvm-project/pull/68941
[Lldb-commits] [lldb] Reland "[mlir][arith] Canonicalization patterns for `arith.select` (#67809)" (PR #68941)
https://github.com/hanhanW closed https://github.com/llvm/llvm-project/pull/68941