================
@@ -159,6 +159,107 @@ struct SimplifySelect : public OpRewritePattern<SelectOp> 
{
   }
 };
 
+/// Simplify `cir.switch` operations by folding cascading cases
+/// into a single `cir.case` with the `anyof` kind.
+///
+/// This pattern identifies cascading cases within a `cir.switch` operation.
+/// Cascading cases are defined as consecutive `cir.case` operations of kind
+/// `equal`, each containing a single `cir.yield` operation in their body.
+///
+/// The pattern merges these cascading cases into a single `cir.case` operation
+/// with kind `anyof`, aggregating all the case values.
+///
+/// The merging process continues until a `cir.case` with a different body
+/// (e.g., one containing `cir.break` or a compound statement) is encountered,
+/// which breaks the chain.
+///
+/// Example:
+///
+/// Before:
+///   cir.case equal, [#cir.int<0> : !s32i] {
+///     cir.yield
+///   }
+///   cir.case equal, [#cir.int<1> : !s32i] {
+///     cir.yield
+///   }
+///   cir.case equal, [#cir.int<2> : !s32i] {
+///     cir.break
+///   }
+///
+/// After applying SimplifySwitch:
+///   cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> :
+///   !s32i] {
+///     cir.break
+///   }
+struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
+  using OpRewritePattern<SwitchOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(SwitchOp op,
+                                PatternRewriter &rewriter) const override {
+
+    LogicalResult changed = mlir::failure();
+    llvm::SmallVector<CaseOp, 8> cases;
+    SmallVector<CaseOp, 4> cascadingCases;
+    SmallVector<mlir::Attribute, 4> cascadingCaseValues;
+
+    op.collectCases(cases);
+    if (cases.empty())
+      return mlir::failure();
+
+    auto flushMergedOps = [&]() {
+      for (CaseOp &c : cascadingCases) {
+        rewriter.eraseOp(c);
+      }
+      cascadingCases.clear();
+      cascadingCaseValues.clear();
+    };
+
+    auto mergeCascadingInto = [&](CaseOp &target) {
+      rewriter.modifyOpInPlace(target, [&]() {
+        target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
+        target.setKind(CaseOpKind::Anyof);
+      });
+      changed = mlir::success();
+    };
+
+    for (CaseOp c : cases) {
+      cir::CaseOpKind kind = c.getKind();
+      if (kind == cir::CaseOpKind::Equal &&
+          isa<YieldOp>(c.getCaseRegion().front().front())) {
+        // If the case contains only a YieldOp, collect it for cascading merge
+        cascadingCases.push_back(c);
+        cascadingCaseValues.push_back(c.getValue()[0]);
+
+      } else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
+        // merge previously collected cascading cases
+        cascadingCaseValues.push_back(c.getValue()[0]);
+        mergeCascadingInto(c);
+        flushMergedOps();
+      } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
+        // If a Default, Anyof or Range case is found and there are previous
+        // cascading cases, merge all of them into the last cascading case.
+        CaseOp lastCascadingCase = cascadingCases.back();
+        mergeCascadingInto(lastCascadingCase);
+        cascadingCases.pop_back();
+        flushMergedOps();
+      } else {
+        cascadingCases.clear();
----------------
andykaylor wrote:

Yes, you're right. I had an 'off-by-one' error in my mental processing of the 
`cascadingCases.size() > 1` condition.

https://github.com/llvm/llvm-project/pull/140649
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to