llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir-cf Author: Matthias Springer (matthias-springer) <details> <summary>Changes</summary> Basic blocks with a `ub.unreachable` terminator are unreachable. This commit adds a canonicalization pattern that drops all operations preceding the `ub.unreachable` terminator in the same block. This commit also adds a canonicalization pattern that folds `cf.cond_br` to `cf.br` if one of the destination blocks consists solely of a `ub.unreachable` terminator. Depends on #<!-- -->169872. --- Full diff: https://github.com/llvm/llvm-project/pull/169873.diff 8 Files Affected: - (modified) mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td (+1-1) - (modified) mlir/include/mlir/Dialect/UB/IR/UBOps.h (+4) - (modified) mlir/include/mlir/Dialect/UB/IR/UBOps.td (+1) - (modified) mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt (+1) - (modified) mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (+31-1) - (modified) mlir/lib/Dialect/UB/IR/UBOps.cpp (+26) - (modified) mlir/test/Dialect/ControlFlow/canonicalize.mlir (+25) - (modified) mlir/test/Dialect/UB/canonicalize.mlir (+10) ``````````diff diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td index a441fd82546e3..c9b4da44ffa01 100644 --- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td +++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td @@ -22,7 +22,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" def ControlFlow_Dialect : Dialect { let name = "cf"; let cppNamespace = "::mlir::cf"; - let dependentDialects = ["arith::ArithDialect"]; + let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"]; let description = [{ This dialect contains low-level, i.e. non-region based, control flow constructs. 
These constructs generally represent control flow directly diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.h b/mlir/include/mlir/Dialect/UB/IR/UBOps.h index 21de5cb0c182a..02081e2d6d15f 100644 --- a/mlir/include/mlir/Dialect/UB/IR/UBOps.h +++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.h @@ -9,6 +9,10 @@ #ifndef MLIR_DIALECT_UB_IR_OPS_H #define MLIR_DIALECT_UB_IR_OPS_H +namespace mlir { +class PatternRewriter; +} + #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.td b/mlir/include/mlir/Dialect/UB/IR/UBOps.td index 8a354da2db10c..c1d74290ec174 100644 --- a/mlir/include/mlir/Dialect/UB/IR/UBOps.td +++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.td @@ -84,6 +84,7 @@ def UnreachableOp : UB_Op<"unreachable", [Terminator]> { }]; let assemblyFormat = "attr-dict"; + let hasCanonicalizeMethod = 1; } #endif // MLIR_DIALECT_UB_IR_UBOPS_TD diff --git a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt index 58551bb435c86..05a787fa53ec3 100644 --- a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt @@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRControlFlowDialect MLIRControlFlowInterfaces MLIRIR MLIRSideEffectInterfaces + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp index f1da1a125e9ef..aabf8930cf78e 100644 --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" @@ -445,6 +446,35 @@ struct 
CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> { return success(replaced); } }; + +struct DropUnreachableCondBranch : public OpRewritePattern<CondBranchOp> { + using OpRewritePattern<CondBranchOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(CondBranchOp condbr, + PatternRewriter &rewriter) const override { + // If the "true" destination has unreachable an unreachable terminator, + // always branch to the "false" destination. + Block *trueDest = condbr.getTrueDest(); + Block *falseDest = condbr.getFalseDest(); + if (llvm::hasSingleElement(*trueDest) && + isa<ub::UnreachableOp>(trueDest->getTerminator())) { + rewriter.replaceOpWithNewOp<BranchOp>(condbr, falseDest, + condbr.getFalseOperands()); + return success(); + } + + // If the "false" destination has unreachable an unreachable terminator, + // always branch to the "true" destination. + if (llvm::hasSingleElement(*falseDest) && + isa<ub::UnreachableOp>(falseDest->getTerminator())) { + rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, + condbr.getTrueOperands()); + return success(); + } + + return failure(); + } +}; } // namespace void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -452,7 +482,7 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch, SimplifyCondBranchIdenticalSuccessors, SimplifyCondBranchFromCondBranchOnSameCondition, - CondBranchTruthPropagation>(context); + CondBranchTruthPropagation, DropUnreachableCondBranch>(context); } SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) { diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp index ee523f9522953..419e3f9d76fb2 100644 --- a/mlir/lib/Dialect/UB/IR/UBOps.cpp +++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/PatternMatch.h" #include 
"llvm/ADT/TypeSwitch.h" #include "mlir/Dialect/UB/IR/UBOpsDialect.cpp.inc" @@ -57,8 +58,33 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value, return nullptr; } +//===----------------------------------------------------------------------===// +// PoisonOp +//===----------------------------------------------------------------------===// + OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); } +//===----------------------------------------------------------------------===// +// UnreachableOp +//===----------------------------------------------------------------------===// + +LogicalResult UnreachableOp::canonicalize(UnreachableOp unreachableOp, + PatternRewriter &rewriter) { + Block *block = unreachableOp->getBlock(); + if (llvm::hasSingleElement(*block)) + return rewriter.notifyMatchFailure( + unreachableOp, "unreachable op is the only operation in the block"); + + // Erase all other operations in the block. They must be dead. + for (Operation &op : llvm::make_early_inc_range(*block)) { + if (&op == unreachableOp.getOperation()) + continue; + op.dropAllUses(); + rewriter.eraseOp(&op); + } + return success(); +} + #include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc" #define GET_ATTRDEF_CLASSES diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir index 17f7d28ba59fb..75dec6dacde91 100644 --- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir +++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir @@ -634,3 +634,28 @@ func.func @unsimplified_cycle_2(%c : i1) { ^bb7: cf.br ^bb6 } + +// CHECK-LABEL: @drop_unreachable_branch_1 +// CHECK-NEXT: "test.foo"() : () -> () +// CHECK-NEXT: return +func.func @drop_unreachable_branch_1(%c: i1) { + cf.cond_br %c, ^bb1, ^bb2 +^bb1: + "test.foo"() : () -> () + return +^bb2: + "test.bar"() : () -> () + ub.unreachable +} + +// CHECK-LABEL: @drop_unreachable_branch_2 +// CHECK-NEXT: ub.unreachable +func.func 
@drop_unreachable_branch_2(%c: i1) { + cf.cond_br %c, ^bb1, ^bb2 +^bb1: + "test.foo"() : () -> () + ub.unreachable +^bb2: + "test.bar"() : () -> () + ub.unreachable +} diff --git a/mlir/test/Dialect/UB/canonicalize.mlir b/mlir/test/Dialect/UB/canonicalize.mlir index c3f286e49b09b..74ba9f1932384 100644 --- a/mlir/test/Dialect/UB/canonicalize.mlir +++ b/mlir/test/Dialect/UB/canonicalize.mlir @@ -9,3 +9,13 @@ func.func @merge_poison() -> (i32, i32) { %1 = ub.poison : i32 return %0, %1 : i32, i32 } + +// ----- + +// CHECK-LABEL: func @drop_ops_before_unreachable() +// CHECK-NEXT: ub.unreachable +func.func @drop_ops_before_unreachable() { + "test.foo"() : () -> () + "test.bar"() : () -> () + ub.unreachable +} `````````` </details> https://github.com/llvm/llvm-project/pull/169873 _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
