================ @@ -159,6 +159,107 @@ struct SimplifySelect : public OpRewritePattern<SelectOp> { } }; +/// Simplify `cir.switch` operations by folding cascading cases +/// into a single `cir.case` with the `anyof` kind. +/// +/// This pattern identifies cascading cases within a `cir.switch` operation. +/// Cascading cases are defined as consecutive `cir.case` operations of kind +/// `equal`, each containing a single `cir.yield` operation in their body. +/// +/// The pattern merges these cascading cases into a single `cir.case` operation +/// with kind `anyof`, aggregating all the case values. +/// +/// The merging process continues until a `cir.case` with a different body +/// (e.g., containing `cir.break` or compound stmt) is encountered, which +/// breaks the chain. +/// +/// Example: +/// +/// Before: +/// cir.case equal, [#cir.int<0> : !s32i] { +/// cir.yield +/// } +/// cir.case equal, [#cir.int<1> : !s32i] { +/// cir.yield +/// } +/// cir.case equal, [#cir.int<2> : !s32i] { +/// cir.break +/// } +/// +/// After applying SimplifySwitch: +/// cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : +/// !s32i] { +/// cir.break +/// } +struct SimplifySwitch : public OpRewritePattern<SwitchOp> { + using OpRewritePattern<SwitchOp>::OpRewritePattern; + LogicalResult matchAndRewrite(SwitchOp op, + PatternRewriter &rewriter) const override { + + LogicalResult changed = mlir::failure(); + llvm::SmallVector<CaseOp, 8> cases; + SmallVector<CaseOp, 4> cascadingCases; + SmallVector<mlir::Attribute, 4> cascadingCaseValues; + + op.collectCases(cases); + if (cases.empty()) + return mlir::failure(); + + auto flushMergedOps = [&]() { + for (CaseOp &c : cascadingCases) { + rewriter.eraseOp(c); + } + cascadingCases.clear(); + cascadingCaseValues.clear(); + }; + + auto mergeCascadingInto = [&](CaseOp &target) { + rewriter.modifyOpInPlace(target, [&]() { + target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues)); + target.setKind(CaseOpKind::Anyof); + }); 
+ changed = mlir::success(); + }; + + for (CaseOp c : cases) { + cir::CaseOpKind kind = c.getKind(); + if (kind == cir::CaseOpKind::Equal && + isa<YieldOp>(c.getCaseRegion().front().front())) { + // If the case contains only a YieldOp, collect it for cascading merge + cascadingCases.push_back(c); + cascadingCaseValues.push_back(c.getValue()[0]); + + } else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) { + // merge previously collected cascading cases + cascadingCaseValues.push_back(c.getValue()[0]); + mergeCascadingInto(c); + flushMergedOps(); + } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) { + // If a Default, Anyof or Range case is found and there are previous + // cascading cases, merge all of them into the last cascading case. + CaseOp lastCascadingCase = cascadingCases.back(); + mergeCascadingInto(lastCascadingCase); + cascadingCases.pop_back(); + flushMergedOps(); + } else { + cascadingCases.clear(); ---------------- Andres-Salamanca wrote:
We can also reach this final `else` branch in cases like this: ```mlir cir.case(equal, [#cir.int<1> : !s32i]) { cir.yield } cir.case(default, []) { cir.break } ``` In this scenario, we still need to clear the vectors explicitly because we can't fold the `equal` case into the `default` one. That's also why I check `cascadingCases.size() > 1`: it ensures there is more than one case to fold. https://github.com/llvm/llvm-project/pull/140649 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits