================ @@ -171,6 +171,232 @@ class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> { } }; +class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> { +public: + using OpRewritePattern<cir::SwitchOp>::OpRewritePattern; + + inline void rewriteYieldOp(mlir::PatternRewriter &rewriter, + cir::YieldOp yieldOp, + mlir::Block *destination) const { + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(), + destination); + } + + // Return the new defaultDestination block. + Block *condBrToRangeDestination(cir::SwitchOp op, + mlir::PatternRewriter &rewriter, + mlir::Block *rangeDestination, + mlir::Block *defaultDestination, + const APInt &lowerBound, + const APInt &upperBound) const { + assert(lowerBound.sle(upperBound) && "Invalid range"); + mlir::Block *resBlock = rewriter.createBlock(defaultDestination); + cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true); + cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false); + + cir::ConstantOp rangeLength = rewriter.create<cir::ConstantOp>( + op.getLoc(), cir::IntAttr::get(sIntType, upperBound - lowerBound)); + + cir::ConstantOp lowerBoundValue = rewriter.create<cir::ConstantOp>( + op.getLoc(), cir::IntAttr::get(sIntType, lowerBound)); + cir::BinOp diffValue = + rewriter.create<cir::BinOp>(op.getLoc(), sIntType, cir::BinOpKind::Sub, + op.getCondition(), lowerBoundValue); + + // Use unsigned comparison to check if the condition is in the range. 
+ cir::CastOp uDiffValue = rewriter.create<cir::CastOp>( + op.getLoc(), uIntType, CastKind::integral, diffValue); + cir::CastOp uRangeLength = rewriter.create<cir::CastOp>( + op.getLoc(), uIntType, CastKind::integral, rangeLength); + + cir::CmpOp cmpResult = rewriter.create<cir::CmpOp>( + op.getLoc(), cir::BoolType::get(op.getContext()), cir::CmpOpKind::le, + uDiffValue, uRangeLength); + rewriter.create<cir::BrCondOp>(op.getLoc(), cmpResult, rangeDestination, + defaultDestination); + return resBlock; + } + + mlir::LogicalResult + matchAndRewrite(cir::SwitchOp op, + mlir::PatternRewriter &rewriter) const override { + llvm::SmallVector<CaseOp> cases; + op.collectCases(cases); + + // Empty switch statement: just erase it. + if (cases.empty()) { + rewriter.eraseOp(op); + return mlir::success(); + } + + // Create exit block from the next node of cir.switch op. + mlir::Block *exitBlock = rewriter.splitBlock( + rewriter.getBlock(), op->getNextNode()->getIterator()); + + // We lower cir.switch op in the following process: + // 1. Inline the region from the switch op after switch op. + // 2. Traverse each cir.case op: + // a. Record the entry block, block arguments and condition for every + // case. b. Inline the case region after the case op. + // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the + // recorded block and conditions. + + // inline everything from switch body between the switch op and the exit + // block. + { + cir::YieldOp switchYield = nullptr; + // Clear switch operation. 
+ for (auto &block : llvm::make_early_inc_range(op.getBody().getBlocks())) + if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator())) + switchYield = yieldOp; + + assert(!op.getBody().empty()); + mlir::Block *originalBlock = op->getBlock(); + mlir::Block *swopBlock = + rewriter.splitBlock(originalBlock, op->getIterator()); + rewriter.inlineRegionBefore(op.getBody(), exitBlock); + + if (switchYield) + rewriteYieldOp(rewriter, switchYield, exitBlock); + + rewriter.setInsertionPointToEnd(originalBlock); + rewriter.create<cir::BrOp>(op.getLoc(), swopBlock); + } + + // Allocate required data structures (disconsider default case in + // vectors). + llvm::SmallVector<mlir::APInt, 8> caseValues; + llvm::SmallVector<mlir::Block *, 8> caseDestinations; + llvm::SmallVector<mlir::ValueRange, 8> caseOperands; + + llvm::SmallVector<std::pair<APInt, APInt>> rangeValues; + llvm::SmallVector<mlir::Block *> rangeDestinations; + llvm::SmallVector<mlir::ValueRange> rangeOperands; + + // Initialize default case as optional. + mlir::Block *defaultDestination = exitBlock; + mlir::ValueRange defaultOperands = exitBlock->getArguments(); + + // Digest the case statements values and bodies. + for (auto caseOp : cases) { ---------------- andykaylor wrote:
```suggestion
    for (cir::CaseOp caseOp : cases) {
```

https://github.com/llvm/llvm-project/pull/139154
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits