================
@@ -171,6 +171,232 @@ class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> {
   }
 };
 
+class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
+public:
+  using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;
+
+  /// Replace the terminator \p yieldOp with an unconditional `cir.br` to
+  /// \p destination, passing the yielded operands along as branch arguments.
+  void rewriteYieldOp(mlir::PatternRewriter &rewriter, cir::YieldOp yieldOp,
+                      mlir::Block *destination) const {
+    // Build the branch exactly where the yield currently sits.
+    rewriter.setInsertionPoint(yieldOp);
+    rewriter.replaceOpWithNewOp<cir::BrOp>(
+        yieldOp, yieldOp.getOperands(), destination);
+  }
+
+  /// Emit a conditional branch that jumps to \p rangeDestination when the
+  /// switch condition lies in the inclusive range [lowerBound, upperBound]
+  /// and to \p defaultDestination otherwise.
+  ///
+  /// The check goes into a newly created block placed before
+  /// \p defaultDestination. It uses a single unsigned comparison:
+  ///   (unsigned)(cond - lowerBound) <= (unsigned)(upperBound - lowerBound)
+  /// This holds under modular arithmetic for signed and unsigned conditions.
+  ///
+  /// Every intermediate value uses the width of the switch condition.
+  /// Hard-coding 32-bit types here would produce ill-typed IR (a `cir.binop`
+  /// whose operand and result types differ) for `long`, `short` or unsigned
+  /// conditions.
+  ///
+  /// \returns the newly created block, which callers use as the new default
+  ///          destination.
+  mlir::Block *condBrToRangeDestination(cir::SwitchOp op,
+                                        mlir::PatternRewriter &rewriter,
+                                        mlir::Block *rangeDestination,
+                                        mlir::Block *defaultDestination,
+                                        const APInt &lowerBound,
+                                        const APInt &upperBound) const {
+    mlir::Location loc = op.getLoc();
+    mlir::MLIRContext *ctx = op.getContext();
+
+    // Derive the integer types from the condition so that the subtraction
+    // and the comparison below are well-typed for any width or signedness.
+    auto condType = mlir::cast<cir::IntType>(op.getCondition().getType());
+    const unsigned width = condType.getWidth();
+    const bool isSigned = condType.isSigned();
+    cir::IntType uIntType = cir::IntType::get(ctx, width, /*isSigned=*/false);
+
+    // cir::IntAttr requires the APInt width to match its type. Normalize the
+    // bounds to the condition width, extending according to its signedness.
+    auto fitToCond = [&](const APInt &value) {
+      return isSigned ? value.sextOrTrunc(width) : value.zextOrTrunc(width);
+    };
+    const APInt lo = fitToCond(lowerBound);
+    const APInt hi = fitToCond(upperBound);
+    // Compare bounds the same way the condition is interpreted. A signed
+    // check would reject valid unsigned ranges such as [0x7fffffff,
+    // 0x80000000].
+    assert((isSigned ? lo.sle(hi) : lo.ule(hi)) && "Invalid range");
+
+    // createBlock also moves the insertion point into the new block.
+    mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
+
+    cir::ConstantOp lowerBoundValue = rewriter.create<cir::ConstantOp>(
+        loc, cir::IntAttr::get(condType, lo));
+    cir::BinOp diffValue =
+        rewriter.create<cir::BinOp>(loc, condType, cir::BinOpKind::Sub,
+                                    op.getCondition(), lowerBoundValue);
+
+    // Use an unsigned comparison so that values below lowerBound wrap around
+    // to large numbers and fail the check. The range length is materialized
+    // directly as an unsigned constant. The difference needs an integral cast
+    // only when the condition is signed.
+    cir::ConstantOp rangeLength = rewriter.create<cir::ConstantOp>(
+        loc, cir::IntAttr::get(uIntType, hi - lo));
+    mlir::Value uDiffValue = diffValue;
+    if (isSigned)
+      uDiffValue = rewriter.create<cir::CastOp>(
+          loc, uIntType, cir::CastKind::integral, diffValue);
+
+    cir::CmpOp cmpResult = rewriter.create<cir::CmpOp>(
+        loc, cir::BoolType::get(ctx), cir::CmpOpKind::le, uDiffValue,
+        rangeLength);
+    rewriter.create<cir::BrCondOp>(loc, cmpResult, rangeDestination,
+                                   defaultDestination);
+    return resBlock;
+  }
+
+  mlir::LogicalResult
+  matchAndRewrite(cir::SwitchOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    llvm::SmallVector<CaseOp> cases;
+    op.collectCases(cases);
+
+    // Empty switch statement: just erase it.
+    if (cases.empty()) {
+      rewriter.eraseOp(op);
+      return mlir::success();
+    }
+
+    // Create exit block from the next node of cir.switch op.
+    mlir::Block *exitBlock = rewriter.splitBlock(
+        rewriter.getBlock(), op->getNextNode()->getIterator());
+
+    // We lower cir.switch op in the following process:
+    // 1. Inline the region from the switch op after switch op.
+    // 2. Traverse each cir.case op:
+    //    a. Record the entry block, block arguments and condition for every
+    //    case. b. Inline the case region after the case op.
+    // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the
+    //    recorded block and conditions.
+
+    // inline everything from switch body between the switch op and the exit
+    // block.
+    {
+      cir::YieldOp switchYield = nullptr;
+      // Clear switch operation.
+      for (auto &block : llvm::make_early_inc_range(op.getBody().getBlocks()))
----------------
andykaylor wrote:

```suggestion
      for (mlir::Block &block :
           llvm::make_early_inc_range(op.getBody().getBlocks()))
```

https://github.com/llvm/llvm-project/pull/139154
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to