llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-cf

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Basic blocks with a `ub.unreachable` terminator are unreachable. This commit
adds a canonicalization pattern that drops all operations preceding the
`ub.unreachable` terminator. This commit also adds a canonicalization pattern
that folds `cf.cond_br` to `cf.br` if one of the destination blocks consists
solely of a `ub.unreachable` terminator.

Depends on #<!-- -->169872.


---
Full diff: https://github.com/llvm/llvm-project/pull/169873.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td (+1-1) 
- (modified) mlir/include/mlir/Dialect/UB/IR/UBOps.h (+4) 
- (modified) mlir/include/mlir/Dialect/UB/IR/UBOps.td (+1) 
- (modified) mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (+31-1) 
- (modified) mlir/lib/Dialect/UB/IR/UBOps.cpp (+26) 
- (modified) mlir/test/Dialect/ControlFlow/canonicalize.mlir (+25) 
- (modified) mlir/test/Dialect/UB/canonicalize.mlir (+10) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
index a441fd82546e3..c9b4da44ffa01 100644
--- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
@@ -22,7 +22,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 def ControlFlow_Dialect : Dialect {
   let name = "cf";
   let cppNamespace = "::mlir::cf";
-  let dependentDialects = ["arith::ArithDialect"];
+  let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"];
   let description = [{
     This dialect contains low-level, i.e. non-region based, control flow
     constructs. These constructs generally represent control flow directly
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.h b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
index 21de5cb0c182a..02081e2d6d15f 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.h
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
@@ -9,6 +9,10 @@
 #ifndef MLIR_DIALECT_UB_IR_OPS_H
 #define MLIR_DIALECT_UB_IR_OPS_H
 
+namespace mlir {
+class PatternRewriter; // Needed by UnreachableOp::canonicalize's signature.
+} // namespace mlir
+
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpImplementation.h"
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.td b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
index 8a354da2db10c..c1d74290ec174 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.td
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
@@ -84,6 +84,7 @@ def UnreachableOp : UB_Op<"unreachable", [Terminator]> {
   }];
 
   let assemblyFormat = "attr-dict";
+  let hasCanonicalizeMethod = 1;
 }
 
 #endif // MLIR_DIALECT_UB_IR_UBOPS_TD
diff --git a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt
index 58551bb435c86..05a787fa53ec3 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt
@@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRControlFlowDialect
   MLIRControlFlowInterfaces
   MLIRIR
   MLIRSideEffectInterfaces
+  MLIRUBDialect
   )
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index f1da1a125e9ef..aabf8930cf78e 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
@@ -445,6 +446,35 @@ struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
     return success(replaced);
   }
 };
+
+struct DropUnreachableCondBranch : public OpRewritePattern<CondBranchOp> {
+  using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+  /// Rewrites to cf.br if one successor holds nothing but ub.unreachable.
+  LogicalResult matchAndRewrite(CondBranchOp condbr,
+                                PatternRewriter &rewriter) const override {
+    // If the "true" destination consists solely of an unreachable
+    // terminator, always branch to the "false" destination.
+    Block *trueDest = condbr.getTrueDest();
+    Block *falseDest = condbr.getFalseDest();
+    if (llvm::hasSingleElement(*trueDest) &&
+        isa<ub::UnreachableOp>(trueDest->getTerminator())) {
+      rewriter.replaceOpWithNewOp<BranchOp>(condbr, falseDest,
+                                            condbr.getFalseOperands());
+      return success();
+    }
+
+    // If the "false" destination consists solely of an unreachable
+    // terminator, always branch to the "true" destination.
+    if (llvm::hasSingleElement(*falseDest) &&
+        isa<ub::UnreachableOp>(falseDest->getTerminator())) {
+      rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest,
+                                            condbr.getTrueOperands());
+      return success();
+    }
+    // Neither successor is a lone ub.unreachable; nothing to fold.
+    return failure();
+  }
+};
 } // namespace
 
 void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -452,7 +482,7 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
               SimplifyCondBranchIdenticalSuccessors,
               SimplifyCondBranchFromCondBranchOnSameCondition,
-              CondBranchTruthPropagation>(context);
+              CondBranchTruthPropagation, DropUnreachableCondBranch>(context);
 }
 
 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp
index ee523f9522953..419e3f9d76fb2 100644
--- a/mlir/lib/Dialect/UB/IR/UBOps.cpp
+++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 #include "mlir/Dialect/UB/IR/UBOpsDialect.cpp.inc"
@@ -57,8 +58,33 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value,
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// PoisonOp
+//===----------------------------------------------------------------------===//
+
 OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
 
+//===----------------------------------------------------------------------===//
+// UnreachableOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult UnreachableOp::canonicalize(UnreachableOp unreachableOp,
+                                          PatternRewriter &rewriter) {
+  Block *block = unreachableOp->getBlock();
+  if (llvm::hasSingleElement(*block))
+    return rewriter.notifyMatchFailure(
+        unreachableOp, "unreachable op is the only operation in the block");
+  // Erase all other ops: a block ending in ub.unreachable is unreachable.
+  // NOTE(review): assumes no preceding op diverges (e.g. calls exit); confirm.
+  for (Operation &op : llvm::make_early_inc_range(*block)) {
+    if (&op == unreachableOp.getOperation())
+      continue;
+    op.dropAllUses(); // eraseOp requires a use-free op.
+    rewriter.eraseOp(&op);
+  }
+  return success();
+}
+
 #include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
index 17f7d28ba59fb..75dec6dacde91 100644
--- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir
+++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
@@ -634,3 +634,28 @@ func.func @unsimplified_cycle_2(%c : i1) {
 ^bb7:
   cf.br ^bb6
 }
+
+// CHECK-LABEL: @drop_unreachable_branch_1
+//  CHECK-NEXT:   "test.foo"() : () -> ()
+//  CHECK-NEXT:   return
+func.func @drop_unreachable_branch_1(%c: i1) {
+  cf.cond_br %c, ^bb1, ^bb2  // ^bb2 reduces to a lone ub.unreachable.
+^bb1:
+  "test.foo"() : () -> ()
+  return
+^bb2:
+  "test.bar"() : () -> ()  // Expected to be erased: it precedes ub.unreachable.
+  ub.unreachable
+}
+
+// CHECK-LABEL: @drop_unreachable_branch_2
+//  CHECK-NEXT:   ub.unreachable
+func.func @drop_unreachable_branch_2(%c: i1) {
+  cf.cond_br %c, ^bb1, ^bb2  // Both successors reduce to a lone ub.unreachable.
+^bb1:
+  "test.foo"() : () -> ()
+  ub.unreachable
+^bb2:
+  "test.bar"() : () -> ()
+  ub.unreachable
+}
diff --git a/mlir/test/Dialect/UB/canonicalize.mlir b/mlir/test/Dialect/UB/canonicalize.mlir
index c3f286e49b09b..74ba9f1932384 100644
--- a/mlir/test/Dialect/UB/canonicalize.mlir
+++ b/mlir/test/Dialect/UB/canonicalize.mlir
@@ -9,3 +9,13 @@ func.func @merge_poison() -> (i32, i32) {
   %1 = ub.poison : i32
   return %0, %1 : i32, i32
 }
+
+// -----
+
+// CHECK-LABEL: func @drop_ops_before_unreachable()
+//  CHECK-NEXT:   ub.unreachable
+func.func @drop_ops_before_unreachable() {
+  "test.foo"() : () -> ()  // Both ops precede ub.unreachable and are erased.
+  "test.bar"() : () -> ()
+  ub.unreachable
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/169873
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to