[Lldb-commits] [llvm] [lldb] [clang-tools-extra] [clang] [libunwind] [libc] [compiler-rt] [flang] [lld] [libclc] [mlir] [libcxx] [mlir][tensor] Enhance pack/unpack simplification for identity outer_di

2024-01-10 Thread Han-Chung Wang via lldb-commits

https://github.com/hanhanW updated 
https://github.com/llvm/llvm-project/pull/77409

>From e74b859897cdf1b1effbbd48a4e5376a231f7132 Mon Sep 17 00:00:00 2001
From: hanhanW 
Date: Mon, 8 Jan 2024 20:17:30 -0800
Subject: [PATCH 1/2] [mlir][tensor] Enhance pack/unpack simplification patterns
 for identity outer_dims_perm cases.

Pack and unpack ops can be simplified to reshape ops (tensor.expand_shape /
tensor.collapse_shape) when outer_dims_perm is an identity permutation, just
as when it is empty. The revision adds an `isIdentityPermutation` method to
IndexingUtils.
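The reasoning behind the relaxation can be sketched in standalone C++. Permuting the outer (tile) dimensions with the identity permutation leaves them unchanged, so such a pack or unpack lays out data exactly like one without `outer_dims_perm`. This is an illustration only, not the MLIR code: `std::vector` replaces `ArrayRef<int64_t>`, and `permuteOuterDims` is a hypothetical helper that assumes `result[i] = outerDims[perm[i]]`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Standalone mirror of the check added to IndexingUtils, with std::vector
// standing in for mlir::ArrayRef<int64_t>. A permutation is the identity when
// position i holds i for every i (an empty permutation qualifies).
bool isIdentityPermutation(const std::vector<int64_t> &permutation) {
  for (std::size_t i = 0; i < permutation.size(); ++i)
    if (permutation[i] != static_cast<int64_t>(i))
      return false;
  return true;
}

// Hypothetical helper (not an MLIR API): reorders the outer (tile) dimensions
// of a packed shape so that result[i] = outerDims[perm[i]].
std::vector<int64_t> permuteOuterDims(const std::vector<int64_t> &outerDims,
                                      const std::vector<int64_t> &perm) {
  assert(perm.size() == outerDims.size() && "one entry per outer dimension");
  std::vector<int64_t> result;
  result.reserve(perm.size());
  for (int64_t p : perm) {
    assert(p >= 0 && static_cast<std::size_t>(p) < outerDims.size() &&
           "perm entries must index into outerDims");
    result.push_back(outerDims[static_cast<std::size_t>(p)]);
  }
  return result;
}
```

For example, `permuteOuterDims({4, 8}, {0, 1})` returns `{4, 8}` unchanged, while `permuteOuterDims({4, 8}, {1, 0})` returns `{8, 4}`; only the latter reorders the tiles, which a reshape alone does not do.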
---
 .../mlir/Dialect/Utils/IndexingUtils.h|  3 +++
 .../Dialect/Tensor/Transforms/CMakeLists.txt  |  1 +
 .../Transforms/PackAndUnpackPatterns.cpp  | 17 +
 mlir/lib/Dialect/Utils/IndexingUtils.cpp  |  8 +++
 .../Dialect/Tensor/simplify-pack-unpack.mlir  | 24 +++
 5 files changed, 48 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index f51a8b28b7548e..2453d841f633e4 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -228,6 +228,9 @@ void applyPermutationToVector(SmallVector &inVec,
 /// Helper method to apply to inverse a permutation.
 SmallVector<int64_t> invertPermutationVector(ArrayRef<int64_t> permutation);
 
+/// Returns true if `permutation` is an identity permutation.
+bool isIdentityPermutation(ArrayRef<int64_t> permutation);
+
 /// Method to check if an interchange vector is a permutation.
 bool isPermutationVector(ArrayRef<int64_t> interchange);
 
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index cbc0d499d9d52c..c6ef6ed86e0d9d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -27,6 +27,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   MLIRArithUtils
   MLIRBufferizationDialect
   MLIRBufferizationTransforms
+  MLIRDialectUtils
   MLIRIR
   MLIRLinalgDialect
   MLIRMemRefDialect
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index cfd838e85c1b80..2853cb8fe77a3b 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "llvm/Support/Debug.h"
 
@@ -38,8 +39,12 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
     if (packOp.getPaddingValue())
       return rewriter.notifyMatchFailure(packOp, "expects no padding value");
 
-    if (!packOp.getOuterDimsPerm().empty())
-      return rewriter.notifyMatchFailure(packOp, "expects no outer_dims_perm");
+    auto outerDimsPerm = packOp.getOuterDimsPerm();
+    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
+      return rewriter.notifyMatchFailure(
+          packOp,
+          "expects outer_dims_perm is empty or an identity permutation");
+    }
 
     RankedTensorType sourceType = packOp.getSourceType();
     RankedTensorType destType = packOp.getDestType();
@@ -74,9 +79,11 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
 
   LogicalResult matchAndRewrite(UnPackOp unpackOp,
                                 PatternRewriter &rewriter) const override {
-    if (!unpackOp.getOuterDimsPerm().empty()) {
-      return rewriter.notifyMatchFailure(unpackOp,
-                                         "expects no outer_dims_perm");
+    auto outerDimsPerm = unpackOp.getOuterDimsPerm();
+    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
+      return rewriter.notifyMatchFailure(
+          unpackOp,
+          "expects outer_dims_perm is empty or an identity permutation");
     }
 
     RankedTensorType sourceType = unpackOp.getSourceType();
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index bb8a0d5912d7c1..f3de454dc4b81a 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -213,6 +213,14 @@ mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
   return inversion;
 }
 
+bool mlir::isIdentityPermutation(ArrayRef<int64_t> permutation) {
+  int n = permutation.size();
+  for (int i = 0; i < n; ++i)
+    if (permutation[i] != i)
+      return false;
+  return true;
+}
+
 bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
   assert(llvm::all_of(interchange, [](int64_t s) { return s >= 0; }) &&
          "permutation must be non-negative");
diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index b78ab9bb3fd87e..ffbb2278a2e327 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tenso

[Lldb-commits] [libcxx] [libc] [lld] [clang-tools-extra] [libclc] [lldb] [llvm] [clang] [compiler-rt] [flang] [mlir] [libunwind] [mlir][tensor] Enhance pack/unpack simplification for identity outer_di

2024-01-10 Thread Han-Chung Wang via lldb-commits

https://github.com/hanhanW closed 
https://github.com/llvm/llvm-project/pull/77409
___
lldb-commits mailing list
lldb-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits


[Lldb-commits] [lldb] Reland "[mlir][arith] Canonicalization patterns for `arith.select` (#67809)" (PR #68941)

2023-10-13 Thread Han-Chung Wang via lldb-commits

https://github.com/hanhanW updated 
https://github.com/llvm/llvm-project/pull/68941

>From 877111a139b2f01037fdbe7c0cb120a2f4e64562 Mon Sep 17 00:00:00 2001
From: hanhanW 
Date: Thu, 12 Oct 2023 17:14:29 -0700
Subject: [PATCH 1/2] Reland "[mlir][arith] Canonicalization patterns for
 `arith.select` (#67809)"

This cherry-picks the changes in
llvm-project/5bf701a6687a46fd898621f5077959ff202d716b, limiting the
types to i1.
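Each new pattern is a Boolean identity on the `i1` select condition. The following standalone C++ sketch (an illustration, not MLIR code) models `arith.select` with a scalar `bool` condition and `int` operands, models `arith.xori` with all ones as logical negation, and asserts that both sides of every rewrite agree for all combinations of predicate values. `arithSelect`, `notI1`, and `checkSelectRewrites` are hypothetical names.

```cpp
#include <cassert>
#include <initializer_list>

// Model of arith.select with a scalar i1 condition: yields `t` when `c` is
// true and `f` otherwise. Plain ints stand in for the selected SSA values.
int arithSelect(bool c, int t, int f) { return c ? t : f; }

// On i1, arith.xori with all ones (the constant 1) is logical negation.
bool notI1(bool p) { return p ^ true; }

// Asserts both sides of every rewrite for all four (predA, predB) pairs.
// Distinct operand values make a swapped or dropped operand detectable.
// The single-predicate rules use predA as `pred`.
void checkSelectRewrites() {
  const int a = 10, b = 20, c = 30; // operands for single-predicate rules
  const int x = 1, y = 2;           // operands for two-predicate rules
  for (bool predA : {false, true}) {
    for (bool predB : {false, true}) {
      // SelectNotCond: select(not(pred), a, b) => select(pred, b, a)
      assert(arithSelect(notI1(predA), a, b) == arithSelect(predA, b, a));
      // RedundantSelectTrue: select(pred, select(pred, a, b), c)
      assert(arithSelect(predA, arithSelect(predA, a, b), c) ==
             arithSelect(predA, a, c));
      // RedundantSelectFalse: select(pred, a, select(pred, b, c))
      assert(arithSelect(predA, a, arithSelect(predA, b, c)) ==
             arithSelect(predA, a, c));
      // SelectAndCond: => select(and(predA, predB), x, y)
      assert(arithSelect(predA, arithSelect(predB, x, y), y) ==
             arithSelect(predA && predB, x, y));
      // SelectAndNotCond: => select(and(predA, not(predB)), x, y)
      assert(arithSelect(predA, arithSelect(predB, y, x), y) ==
             arithSelect(predA && notI1(predB), x, y));
      // SelectOrCond: => select(or(predA, predB), x, y)
      assert(arithSelect(predA, x, arithSelect(predB, x, y)) ==
             arithSelect(predA || predB, x, y));
      // SelectOrNotCond: => select(or(predA, not(predB)), x, y)
      assert(arithSelect(predA, x, arithSelect(predB, y, x)) ==
             arithSelect(predA || notI1(predB), x, y));
    }
  }
}
```

Because the conditions here are scalar `bool`s, the sketch only covers the scalar i1 case.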
---
 .../Dialect/Arith/IR/ArithCanonicalization.td | 46 +++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp|  4 +-
 mlir/test/Dialect/Arith/canonicalize.mlir | 76 +++
 3 files changed, 125 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index f3d84d0b261e8dd..817de0e06c661b9 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -233,6 +233,52 @@ def CmpIExtUI :
 CPred<"$0.getValue() == arith::CmpIPredicate::eq || "
   "$0.getValue() == arith::CmpIPredicate::ne">> $pred)]>;
 
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+// select(not(pred), a, b) => select(pred, b, a)
+def SelectNotCond :
+    Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
+        (SelectOp $pred, $b, $a),
+        [(IsScalarOrSplatNegativeOne $ones)]>;
+
+// select(pred, select(pred, a, b), c) => select(pred, a, c)
+def RedundantSelectTrue :
+    Pat<(SelectOp $pred, (SelectOp $pred, $a, $b), $c),
+        (SelectOp $pred, $a, $c)>;
+
+// select(pred, a, select(pred, b, c)) => select(pred, a, c)
+def RedundantSelectFalse :
+    Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
+        (SelectOp $pred, $a, $c)>;
+
+// select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
+def SelectAndCond :
+    Pat<(SelectOp $predA, (SelectOp $predB, $x, $y), $y),
+        (SelectOp (Arith_AndIOp $predA, $predB), $x, $y)>;
+
+// select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
+def SelectAndNotCond :
+    Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
+        (SelectOp (Arith_AndIOp $predA,
+                                (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr))),
+            $x, $y),
+        [(Constraint> $predB)]>;
+
+// select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
+def SelectOrCond :
+    Pat<(SelectOp $predA, $x, (SelectOp $predB, $x, $y)),
+        (SelectOp (Arith_OrIOp $predA, $predB), $x, $y)>;
+
+// select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)
+def SelectOrNotCond :
+    Pat<(SelectOp $predA, $x, (SelectOp $predB, $y, $x)),
+        (SelectOp (Arith_OrIOp $predA,
+                               (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr))),
+            $x, $y),
+        [(Constraint> $predB)]>;
+
 //===----------------------------------------------------------------------===//
 // IndexCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ae8a6ef350ce191..0ecc288f3b07701 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2212,7 +2212,9 @@ struct SelectToExtUI : public OpRewritePattern {
 
 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
   MLIRContext *context) {
-  results.add(context);
+  results.add(context);
 }
 
 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index f697f3d01458eee..1b0547c9e8f804a 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -128,6 +128,82 @@ func.func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
   return %res : i1
 }
 
+// CHECK-LABEL: @redundantSelectTrue
+//   CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg1, %arg3
+//   CHECK-NEXT: return %[[res]]
+func.func @redundantSelectTrue(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
+  %0 = arith.select %arg0, %arg1, %arg2 : i32
+  %res = arith.select %arg0, %0, %arg3 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @redundantSelectFalse
+//   CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg3, %arg2
+//   CHECK-NEXT: return %[[res]]
+func.func @redundantSelectFalse(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
+  %0 = arith.select %arg0, %arg1, %arg2 : i32
+  %res = arith.select %arg0, %arg3, %0 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @selNotCond
+//   CHECK-NEXT: %[[res1:.+]] = arith.select %arg0, %arg2, %arg1

[Lldb-commits] [lldb] Reland "[mlir][arith] Canonicalization patterns for `arith.select` (#67809)" (PR #68941)

2023-10-13 Thread Han-Chung Wang via lldb-commits


@@ -233,6 +233,52 @@ def CmpIExtUI :
 CPred<"$0.getValue() == arith::CmpIPredicate::eq || "
   "$0.getValue() == arith::CmpIPredicate::ne">> $pred)]>;
 
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+// select(not(pred), a, b) => select(pred, b, a)
+def SelectNotCond :
+    Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
+        (SelectOp $pred, $b, $a),
+        [(IsScalarOrSplatNegativeOne $ones)]>;
+
+// select(pred, select(pred, a, b), c) => select(pred, a, c)
+def RedundantSelectTrue :
+    Pat<(SelectOp $pred, (SelectOp $pred, $a, $b), $c),
+        (SelectOp $pred, $a, $c)>;
+
+// select(pred, a, select(pred, b, c)) => select(pred, a, c)
+def RedundantSelectFalse :
+    Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
+        (SelectOp $pred, $a, $c)>;
+
+// select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
+def SelectAndCond :
+    Pat<(SelectOp $predA, (SelectOp $predB, $x, $y), $y),
+        (SelectOp (Arith_AndIOp $predA, $predB), $x, $y)>;
+
+// select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
+def SelectAndNotCond :
+    Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
+        (SelectOp (Arith_AndIOp $predA,
+                                (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr))),

hanhanW wrote:

done

https://github.com/llvm/llvm-project/pull/68941


[Lldb-commits] [lldb] Reland "[mlir][arith] Canonicalization patterns for `arith.select` (#67809)" (PR #68941)

2023-10-13 Thread Han-Chung Wang via lldb-commits

https://github.com/hanhanW edited 
https://github.com/llvm/llvm-project/pull/68941


[Lldb-commits] [lldb] Reland "[mlir][arith] Canonicalization patterns for `arith.select` (#67809)" (PR #68941)

2023-10-13 Thread Han-Chung Wang via lldb-commits

hanhanW wrote:

thanks for the review!

https://github.com/llvm/llvm-project/pull/68941


[Lldb-commits] [lldb] Reland "[mlir][arith] Canonicalization patterns for `arith.select` (#67809)" (PR #68941)

2023-10-13 Thread Han-Chung Wang via lldb-commits

https://github.com/hanhanW closed 
https://github.com/llvm/llvm-project/pull/68941