================
@@ -171,6 +171,232 @@ class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> {
   }
 };
 
+class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
+public:
+  using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;
+
+  // Replace `yieldOp` with an unconditional `cir.br` to `destination`,
+  // forwarding every value the yield carried as the branch operands.
+  // replaceOpWithNewOp itself moves the builder to `yieldOp` before creating
+  // the branch, so no separate insertion-point update is required here.
+  void rewriteYieldOp(mlir::PatternRewriter &rewriter, cir::YieldOp yieldOp,
+                      mlir::Block *destination) const {
+    auto forwardedValues = yieldOp.getOperands();
+    rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, forwardedValues,
+                                           destination);
+  }
+
+  // Emit a range check for a `case lowerBound ... upperBound:` and return the
+  // new defaultDestination block.
+  //
+  // A fresh block is created immediately before `defaultDestination`
+  // (createBlock also moves the insertion point into it). That block tests
+  // whether the switch condition lies in [lowerBound, upperBound] and
+  // branches to `rangeDestination` if so, or to `defaultDestination`
+  // otherwise.
+  //
+  // The test uses the single-comparison form
+  //   (unsigned)(cond - lowerBound) <= (unsigned)(upperBound - lowerBound)
+  // All arithmetic is done in the condition's own integer type, so switches
+  // over types other than a 32-bit signed int (e.g. `long`, `unsigned`)
+  // produce well-typed IR.
+  // NOTE(review): assumes the case bounds carry the condition's bit width
+  // (cir::IntAttr requires the APInt width to match its type).
+  Block *condBrToRangeDestination(cir::SwitchOp op,
+                                  mlir::PatternRewriter &rewriter,
+                                  mlir::Block *rangeDestination,
+                                  mlir::Block *defaultDestination,
+                                  const APInt &lowerBound,
+                                  const APInt &upperBound) const {
+    auto condType = mlir::cast<cir::IntType>(op.getCondition().getType());
+    // The bounds must be ordered according to the condition's signedness.
+    assert((condType.isSigned() ? lowerBound.sle(upperBound)
+                                : lowerBound.ule(upperBound)) &&
+           "Invalid range");
+    mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
+    mlir::Location loc = op.getLoc();
+    cir::IntType uIntType = cir::IntType::get(
+        op.getContext(), condType.getWidth(), /*isSigned=*/false);
+
+    // Reinterpret a value of the condition type as unsigned. This is a no-op
+    // when the condition is already unsigned, which avoids emitting a
+    // same-type cast.
+    auto asUnsigned = [&](mlir::Value value) -> mlir::Value {
+      if (value.getType() == uIntType)
+        return value;
+      return rewriter.create<cir::CastOp>(loc, uIntType, CastKind::integral,
+                                          value);
+    };
+
+    cir::ConstantOp rangeLength = rewriter.create<cir::ConstantOp>(
+        loc, cir::IntAttr::get(condType, upperBound - lowerBound));
+
+    cir::ConstantOp lowerBoundValue = rewriter.create<cir::ConstantOp>(
+        loc, cir::IntAttr::get(condType, lowerBound));
+    cir::BinOp diffValue =
+        rewriter.create<cir::BinOp>(loc, condType, cir::BinOpKind::Sub,
+                                    op.getCondition(), lowerBoundValue);
+
+    // Use unsigned comparison to check if the condition is in the range:
+    // values below lowerBound wrap around to large unsigned numbers.
+    mlir::Value uDiffValue = asUnsigned(diffValue);
+    mlir::Value uRangeLength = asUnsigned(rangeLength);
+
+    cir::CmpOp cmpResult = rewriter.create<cir::CmpOp>(
+        loc, cir::BoolType::get(op.getContext()), cir::CmpOpKind::le,
+        uDiffValue, uRangeLength);
+    rewriter.create<cir::BrCondOp>(loc, cmpResult, rangeDestination,
+                                   defaultDestination);
+    return resBlock;
+  }
+
+  mlir::LogicalResult
+  matchAndRewrite(cir::SwitchOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    llvm::SmallVector<CaseOp> cases;
+    op.collectCases(cases);
+
+    // Empty switch statement: just erase it.
+    if (cases.empty()) {
+      rewriter.eraseOp(op);
+      return mlir::success();
+    }
+
+    // Create exit block from the next node of cir.switch op.
+    mlir::Block *exitBlock = rewriter.splitBlock(
+        rewriter.getBlock(), op->getNextNode()->getIterator());
+
+    // We lower cir.switch op in the following process:
+    // 1. Inline the region from the switch op after switch op.
+    // 2. Traverse each cir.case op:
+    //    a. Record the entry block, block arguments and condition for every
+    //    case. b. Inline the case region after the case op.
+    // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the
+    //    recorded block and conditions.
+
+    // inline everything from switch body between the switch op and the exit
+    // block.
+    {
+      cir::YieldOp switchYield = nullptr;
+      // Clear switch operation.
+      for (auto &block : llvm::make_early_inc_range(op.getBody().getBlocks()))
+        if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
+          switchYield = yieldOp;
+
+      assert(!op.getBody().empty());
+      mlir::Block *originalBlock = op->getBlock();
+      mlir::Block *swopBlock =
+          rewriter.splitBlock(originalBlock, op->getIterator());
+      rewriter.inlineRegionBefore(op.getBody(), exitBlock);
+
+      if (switchYield)
+        rewriteYieldOp(rewriter, switchYield, exitBlock);
+
+      rewriter.setInsertionPointToEnd(originalBlock);
+      rewriter.create<cir::BrOp>(op.getLoc(), swopBlock);
+    }
+
+    // Allocate required data structures (disconsider default case in
+    // vectors).
+    llvm::SmallVector<mlir::APInt, 8> caseValues;
+    llvm::SmallVector<mlir::Block *, 8> caseDestinations;
+    llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
+
+    llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
+    llvm::SmallVector<mlir::Block *> rangeDestinations;
+    llvm::SmallVector<mlir::ValueRange> rangeOperands;
+
+    // Initialize default case as optional.
+    mlir::Block *defaultDestination = exitBlock;
+    mlir::ValueRange defaultOperands = exitBlock->getArguments();
+
+    // Digest the case statements values and bodies.
+    for (auto caseOp : cases) {
+      mlir::Region &region = caseOp.getCaseRegion();
+
+      // Found default case: save destination and operands.
+      switch (caseOp.getKind()) {
+      case cir::CaseOpKind::Default:
+        defaultDestination = &region.front();
+        defaultOperands = defaultDestination->getArguments();
+        break;
+      case cir::CaseOpKind::Range:
+        assert(caseOp.getValue().size() == 2 &&
+               "Case range should have 2 case value");
+        rangeValues.push_back(
+            {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(),
+             cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()});
+        rangeDestinations.push_back(&region.front());
+        rangeOperands.push_back(rangeDestinations.back()->getArguments());
+        break;
+      case cir::CaseOpKind::Anyof:
+      case cir::CaseOpKind::Equal:
+        // AnyOf cases kind can have multiple values, hence the loop below.
+        for (auto &value : caseOp.getValue()) {
+          caseValues.push_back(cast<cir::IntAttr>(value).getValue());
+          caseDestinations.push_back(&region.front());
+          caseOperands.push_back(caseDestinations.back()->getArguments());
+        }
+        break;
+      }
+
+      // Handle break statements.
+      walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
+          region, [&](mlir::Operation *op) {
+            if (!isa<cir::BreakOp>(op))
+              return mlir::WalkResult::advance();
+
+            lowerTerminator(op, exitBlock, rewriter);
+            return mlir::WalkResult::skip();
+          });
+
+      // Track fallthrough in cases.
+      for (auto &blk : region.getBlocks()) {
+        if (blk.getNumSuccessors())
+          continue;
+
+        if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) {
+          mlir::Operation *nextOp = caseOp->getNextNode();
+          assert(nextOp && "caseOp is not expected to be the last op");
+          mlir::Block *oldBlock = nextOp->getBlock();
+          mlir::Block *newBlock =
+              rewriter.splitBlock(oldBlock, nextOp->getIterator());
+          rewriter.setInsertionPointToEnd(oldBlock);
+          rewriter.create<cir::BrOp>(nextOp->getLoc(), mlir::ValueRange(),
+                                     newBlock);
+          rewriteYieldOp(rewriter, yieldOp, newBlock);
+        }
+      }
+
+      mlir::Block *oldBlock = caseOp->getBlock();
+      mlir::Block *newBlock =
+          rewriter.splitBlock(oldBlock, caseOp->getIterator());
+
+      mlir::Block &entryBlock = caseOp.getCaseRegion().front();
+      rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock);
+
+      // Create a branch to the entry of the inlined region.
+      rewriter.setInsertionPointToEnd(oldBlock);
+      rewriter.create<cir::BrOp>(caseOp.getLoc(), &entryBlock);
+    }
+
+    // Remove all cases since we've inlined the regions.
+    for (auto caseOp : cases) {
+      mlir::Block *caseBlock = caseOp->getBlock();
+      // Erase the block with no predecessors here to make the generated code
+      // simpler a little bit.
+      if (caseBlock->hasNoPredecessors())
+        rewriter.eraseBlock(caseBlock);
+      else
+        rewriter.eraseOp(caseOp);
+    }
+
+    for (size_t index = 0; index < rangeValues.size(); ++index) {
----------------
Andres-Salamanca wrote:

Sorry, just noticed we can use `llvm::zip` here. I’ve made the change 🙂

https://github.com/llvm/llvm-project/pull/139154
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to