https://github.com/Andres-Salamanca updated https://github.com/llvm/llvm-project/pull/139154
>From f77464dbe1de51c89fbde5e5decfc2314cd7dba6 Mon Sep 17 00:00:00 2001 From: Andres Salamanca <andrealebarbari...@gmail.com> Date: Thu, 8 May 2025 15:39:49 -0500 Subject: [PATCH 1/5] Add support for FlattenCFG switch and introduce SwitchFlatOp --- clang/include/clang/CIR/Dialect/IR/CIROps.td | 46 +++ clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 97 ++++++ .../Dialect/Transforms/CIRCanonicalize.cpp | 15 +- .../lib/CIR/Dialect/Transforms/FlattenCFG.cpp | 235 ++++++++++++++- clang/test/CIR/IR/switch-flat.cir | 68 +++++ clang/test/CIR/Transforms/switch.cir | 278 ++++++++++++++++++ 6 files changed, 734 insertions(+), 5 deletions(-) create mode 100644 clang/test/CIR/IR/switch-flat.cir create mode 100644 clang/test/CIR/Transforms/switch.cir diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 7ffa10464dcd3..914af6d1dc6bd 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -971,6 +971,52 @@ def SwitchOp : CIR_Op<"switch", }]; } +//===----------------------------------------------------------------------===// +// SwitchFlatOp +//===----------------------------------------------------------------------===// + +def SwitchFlatOp : CIR_Op<"switch.flat", [AttrSizedOperandSegments, + Terminator]> { + + let description = [{ + The `cir.switch.flat` operation is a region-less and simplified + version of the `cir.switch`. + It's representation is closer to LLVM IR dialect + than the C/C++ language feature. 
+ }]; + + let arguments = (ins + CIR_IntType:$condition, + Variadic<AnyType>:$defaultOperands, + VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands, + ArrayAttr:$case_values, + DenseI32ArrayAttr:$case_operand_segments + ); + + let successors = (successor + AnySuccessor:$defaultDestination, + VariadicSuccessor<AnySuccessor>:$caseDestinations + ); + + let assemblyFormat = [{ + $condition `:` type($condition) `,` + $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)? + custom<SwitchFlatOpCases>(ref(type($condition)), $case_values, + $caseDestinations, $caseOperands, + type($caseOperands)) + attr-dict + }]; + + let builders = [ + OpBuilder<(ins "mlir::Value":$condition, + "mlir::Block *":$defaultDestination, + "mlir::ValueRange":$defaultOperands, + CArg<"llvm::ArrayRef<llvm::APInt>", "{}">:$caseValues, + CArg<"mlir::BlockRange", "{}">:$caseDestinations, + CArg<"llvm::ArrayRef<mlir::ValueRange>", "{}">:$caseOperands)> + ]; +} + //===----------------------------------------------------------------------===// // BrOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index b131edaf403ed..ca03013edb485 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -22,6 +22,7 @@ #include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc" #include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc" #include "clang/CIR/MissingFeatures.h" +#include <numeric> using namespace mlir; using namespace cir; @@ -962,6 +963,102 @@ bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) { }); } +//===----------------------------------------------------------------------===// +// SwitchFlatOp +//===----------------------------------------------------------------------===// + +void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result, + Value value, Block 
*defaultDestination, + ValueRange defaultOperands, + ArrayRef<APInt> caseValues, + BlockRange caseDestinations, + ArrayRef<ValueRange> caseOperands) { + + std::vector<mlir::Attribute> caseValuesAttrs; + for (auto &val : caseValues) { + caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val)); + } + mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs); + + build(builder, result, value, defaultOperands, caseOperands, attrs, + defaultDestination, caseDestinations); +} + +/// <cases> ::= `[` (case (`,` case )* )? `]` +/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? +static ParseResult parseSwitchFlatOpCases( + OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues, + SmallVectorImpl<Block *> &caseDestinations, + SmallVectorImpl<llvm::SmallVector<OpAsmParser::UnresolvedOperand>> + &caseOperands, + SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) { + if (failed(parser.parseLSquare())) + return failure(); + if (succeeded(parser.parseOptionalRSquare())) + return success(); + llvm::SmallVector<mlir::Attribute> values; + + auto parseCase = [&]() { + int64_t value = 0; + if (failed(parser.parseInteger(value))) + return failure(); + + values.push_back(cir::IntAttr::get(flagType, value)); + + Block *destination; + llvm::SmallVector<OpAsmParser::UnresolvedOperand> operands; + llvm::SmallVector<Type> operandTypes; + if (parser.parseColon() || parser.parseSuccessor(destination)) + return failure(); + if (!parser.parseOptionalLParen()) { + if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None, + /*allowResultNumber=*/false) || + parser.parseColonTypeList(operandTypes) || parser.parseRParen()) + return failure(); + } + caseDestinations.push_back(destination); + caseOperands.emplace_back(operands); + caseOperandTypes.emplace_back(operandTypes); + return success(); + }; + if (failed(parser.parseCommaSeparatedList(parseCase))) + return failure(); + + caseValues = ArrayAttr::get(flagType.getContext(), 
values); + + return parser.parseRSquare(); +} + +static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op, + Type flagType, mlir::ArrayAttr caseValues, + SuccessorRange caseDestinations, + OperandRangeRange caseOperands, + const TypeRangeRange &caseOperandTypes) { + p << '['; + p.printNewline(); + if (!caseValues) { + p << ']'; + return; + } + + size_t index = 0; + llvm::interleave( + llvm::zip(caseValues, caseDestinations), + [&](auto i) { + p << " "; + mlir::Attribute a = std::get<0>(i); + p << mlir::cast<cir::IntAttr>(a).getValue(); + p << ": "; + p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]); + }, + [&] { + p << ','; + p.printNewline(); + }); + p.printNewline(); + p << ']'; +} + //===----------------------------------------------------------------------===// // GlobalOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp index 3b4c7bc613133..edbb848322d41 100644 --- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp +++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp @@ -84,6 +84,19 @@ struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> { } }; +struct RemoveEmptySwitch : public OpRewritePattern<SwitchOp> { + using OpRewritePattern<SwitchOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(SwitchOp op, + PatternRewriter &rewriter) const final { + if (!(op.getBody().empty() || isa<YieldOp>(op.getBody().front().front()))) + return failure(); + + rewriter.eraseOp(op); + return success(); + } +}; + //===----------------------------------------------------------------------===// // CIRCanonicalizePass //===----------------------------------------------------------------------===// @@ -127,7 +140,7 @@ void CIRCanonicalizePass::runOnOperation() { assert(!cir::MissingFeatures::callOp()); // CastOp and UnaryOp are here to perform a manual `fold` in // 
applyOpPatternsGreedily. - if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp>(op)) + if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp>(op)) ops.push_back(op); }); diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp index 4a936d33b022a..70f383b556567 100644 --- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp +++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp @@ -171,6 +171,232 @@ class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> { } }; +class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> { +public: + using OpRewritePattern<cir::SwitchOp>::OpRewritePattern; + + inline void rewriteYieldOp(mlir::PatternRewriter &rewriter, + cir::YieldOp yieldOp, + mlir::Block *destination) const { + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(), + destination); + } + + // Return the new defaultDestination block. + Block *condBrToRangeDestination(cir::SwitchOp op, + mlir::PatternRewriter &rewriter, + mlir::Block *rangeDestination, + mlir::Block *defaultDestination, + const APInt &lowerBound, + const APInt &upperBound) const { + assert(lowerBound.sle(upperBound) && "Invalid range"); + mlir::Block *resBlock = rewriter.createBlock(defaultDestination); + cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true); + cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false); + + cir::ConstantOp rangeLength = rewriter.create<cir::ConstantOp>( + op.getLoc(), cir::IntAttr::get(sIntType, upperBound - lowerBound)); + + cir::ConstantOp lowerBoundValue = rewriter.create<cir::ConstantOp>( + op.getLoc(), cir::IntAttr::get(sIntType, lowerBound)); + cir::BinOp diffValue = + rewriter.create<cir::BinOp>(op.getLoc(), sIntType, cir::BinOpKind::Sub, + op.getCondition(), lowerBoundValue); + + // Use unsigned comparison to check if the condition is in the range. 
+ cir::CastOp uDiffValue = rewriter.create<cir::CastOp>( + op.getLoc(), uIntType, CastKind::integral, diffValue); + cir::CastOp uRangeLength = rewriter.create<cir::CastOp>( + op.getLoc(), uIntType, CastKind::integral, rangeLength); + + cir::CmpOp cmpResult = rewriter.create<cir::CmpOp>( + op.getLoc(), cir::BoolType::get(op.getContext()), cir::CmpOpKind::le, + uDiffValue, uRangeLength); + rewriter.create<cir::BrCondOp>(op.getLoc(), cmpResult, rangeDestination, + defaultDestination); + return resBlock; + } + + mlir::LogicalResult + matchAndRewrite(cir::SwitchOp op, + mlir::PatternRewriter &rewriter) const override { + llvm::SmallVector<CaseOp> cases; + op.collectCases(cases); + + // Empty switch statement: just erase it. + if (cases.empty()) { + rewriter.eraseOp(op); + return mlir::success(); + } + + // Create exit block from the next node of cir.switch op. + mlir::Block *exitBlock = rewriter.splitBlock( + rewriter.getBlock(), op->getNextNode()->getIterator()); + + // We lower cir.switch op in the following process: + // 1. Inline the region from the switch op after switch op. + // 2. Traverse each cir.case op: + // a. Record the entry block, block arguments and condition for every + // case. b. Inline the case region after the case op. + // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the + // recorded block and conditions. + + // inline everything from switch body between the switch op and the exit + // block. + { + cir::YieldOp switchYield = nullptr; + // Clear switch operation. 
+ for (auto &block : llvm::make_early_inc_range(op.getBody().getBlocks())) + if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator())) + switchYield = yieldOp; + + assert(!op.getBody().empty()); + mlir::Block *originalBlock = op->getBlock(); + mlir::Block *swopBlock = + rewriter.splitBlock(originalBlock, op->getIterator()); + rewriter.inlineRegionBefore(op.getBody(), exitBlock); + + if (switchYield) + rewriteYieldOp(rewriter, switchYield, exitBlock); + + rewriter.setInsertionPointToEnd(originalBlock); + rewriter.create<cir::BrOp>(op.getLoc(), swopBlock); + } + + // Allocate required data structures (disconsider default case in + // vectors). + llvm::SmallVector<mlir::APInt, 8> caseValues; + llvm::SmallVector<mlir::Block *, 8> caseDestinations; + llvm::SmallVector<mlir::ValueRange, 8> caseOperands; + + llvm::SmallVector<std::pair<APInt, APInt>> rangeValues; + llvm::SmallVector<mlir::Block *> rangeDestinations; + llvm::SmallVector<mlir::ValueRange> rangeOperands; + + // Initialize default case as optional. + mlir::Block *defaultDestination = exitBlock; + mlir::ValueRange defaultOperands = exitBlock->getArguments(); + + // Digest the case statements values and bodies. + for (auto caseOp : cases) { + mlir::Region ®ion = caseOp.getCaseRegion(); + + // Found default case: save destination and operands. + switch (caseOp.getKind()) { + case cir::CaseOpKind::Default: + defaultDestination = ®ion.front(); + defaultOperands = defaultDestination->getArguments(); + break; + case cir::CaseOpKind::Range: + assert(caseOp.getValue().size() == 2 && + "Case range should have 2 case value"); + rangeValues.push_back( + {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(), + cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()}); + rangeDestinations.push_back(®ion.front()); + rangeOperands.push_back(rangeDestinations.back()->getArguments()); + break; + case cir::CaseOpKind::Anyof: + case cir::CaseOpKind::Equal: + // AnyOf cases kind can have multiple values, hence the loop below. 
+ for (auto &value : caseOp.getValue()) { + caseValues.push_back(cast<cir::IntAttr>(value).getValue()); + caseDestinations.push_back(®ion.front()); + caseOperands.push_back(caseDestinations.back()->getArguments()); + } + break; + } + + // Handle break statements. + walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>( + region, [&](mlir::Operation *op) { + if (!isa<cir::BreakOp>(op)) + return mlir::WalkResult::advance(); + + lowerTerminator(op, exitBlock, rewriter); + return mlir::WalkResult::skip(); + }); + + // Track fallthrough in cases. + for (auto &blk : region.getBlocks()) { + if (blk.getNumSuccessors()) + continue; + + if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) { + mlir::Operation *nextOp = caseOp->getNextNode(); + assert(nextOp && "caseOp is not expected to be the last op"); + mlir::Block *oldBlock = nextOp->getBlock(); + mlir::Block *newBlock = + rewriter.splitBlock(oldBlock, nextOp->getIterator()); + rewriter.setInsertionPointToEnd(oldBlock); + rewriter.create<cir::BrOp>(nextOp->getLoc(), mlir::ValueRange(), + newBlock); + rewriteYieldOp(rewriter, yieldOp, newBlock); + } + } + + mlir::Block *oldBlock = caseOp->getBlock(); + mlir::Block *newBlock = + rewriter.splitBlock(oldBlock, caseOp->getIterator()); + + mlir::Block &entryBlock = caseOp.getCaseRegion().front(); + rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock); + + // Create a branch to the entry of the inlined region. + rewriter.setInsertionPointToEnd(oldBlock); + rewriter.create<cir::BrOp>(caseOp.getLoc(), &entryBlock); + } + + // Remove all cases since we've inlined the regions. + for (auto caseOp : cases) { + mlir::Block *caseBlock = caseOp->getBlock(); + // Erase the block with no predecessors here to make the generated code + // simpler a little bit. 
+ if (caseBlock->hasNoPredecessors()) + rewriter.eraseBlock(caseBlock); + else + rewriter.eraseOp(caseOp); + } + + for (size_t index = 0; index < rangeValues.size(); ++index) { + APInt lowerBound = rangeValues[index].first; + APInt upperBound = rangeValues[index].second; + + // The case range is unreachable, skip it. + if (lowerBound.sgt(upperBound)) + continue; + + // If range is small, add multiple switch instruction cases. + // This magical number is from the original CGStmt code. + constexpr int kSmallRangeThreshold = 64; + if ((upperBound - lowerBound) + .ult(llvm::APInt(32, kSmallRangeThreshold))) { + for (APInt iValue = lowerBound; iValue.sle(upperBound); + (void)iValue++) { + caseValues.push_back(iValue); + caseOperands.push_back(rangeOperands[index]); + caseDestinations.push_back(rangeDestinations[index]); + } + continue; + } + + defaultDestination = + condBrToRangeDestination(op, rewriter, rangeDestinations[index], + defaultDestination, lowerBound, upperBound); + defaultOperands = rangeOperands[index]; + } + + // Set switch op to branch to the newly created blocks. 
+ rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>( + op, op.getCondition(), defaultDestination, defaultOperands, caseValues, + caseDestinations, caseOperands); + + return mlir::success(); + } +}; + class CIRLoopOpInterfaceFlattening : public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> { public: @@ -306,9 +532,10 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> { }; void populateFlattenCFGPatterns(RewritePatternSet &patterns) { - patterns.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, - CIRScopeOpFlattening, CIRTernaryOpFlattening>( - patterns.getContext()); + patterns + .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening, + CIRSwitchOpFlattening, CIRTernaryOpFlattening>( + patterns.getContext()); } void CIRFlattenCFGPass::runOnOperation() { @@ -321,7 +548,7 @@ void CIRFlattenCFGPass::runOnOperation() { assert(!cir::MissingFeatures::ifOp()); assert(!cir::MissingFeatures::switchOp()); assert(!cir::MissingFeatures::tryOp()); - if (isa<IfOp, ScopeOp, LoopOpInterface, TernaryOp>(op)) + if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp>(op)) ops.push_back(op); }); diff --git a/clang/test/CIR/IR/switch-flat.cir b/clang/test/CIR/IR/switch-flat.cir new file mode 100644 index 0000000000000..b072c224b4a2c --- /dev/null +++ b/clang/test/CIR/IR/switch-flat.cir @@ -0,0 +1,68 @@ +// RUN: cir-opt %s | FileCheck %s +!s32i = !cir.int<s, 32> + +cir.func @FlatSwitchWithoutDefault(%arg0: !s32i) { + cir.switch.flat %arg0 : !s32i, ^bb2 [ + 1: ^bb1 + ] + ^bb1: + cir.br ^bb2 + ^bb2: + cir.return +} + +// CHECK: cir.switch.flat %arg0 : !s32i, ^bb2 [ +// CHECK-NEXT: 1: ^bb1 +// CHECK-NEXT: ] +// CHECK-NEXT: ^bb1: +// CHECK-NEXT: cir.br ^bb2 +// CHECK-NEXT: ^bb2: +//CHECK-NEXT: cir.return + +cir.func @FlatSwitchWithDefault(%arg0: !s32i) { + cir.switch.flat %arg0 : !s32i, ^bb2 [ + 1: ^bb1 + ] + ^bb1: + cir.br ^bb3 + ^bb2: + cir.br ^bb3 + ^bb3: + cir.return +} + +// CHECK: 
cir.switch.flat %arg0 : !s32i, ^bb2 [ +// CHECK-NEXT: 1: ^bb1 +// CHECK-NEXT: ] +// CHECK-NEXT: ^bb1: +// CHECK-NEXT: cir.br ^bb3 +// CHECK-NEXT: ^bb2: +// CHECK-NEXT: cir.br ^bb3 +// CHECK-NEXT: ^bb3: +// CHECK-NEXT: cir.return + +cir.func @switchWithOperands(%arg0: !s32i, %arg1: !s32i, %arg2: !s32i) { + cir.switch.flat %arg0 : !s32i, ^bb3 [ + 0: ^bb1(%arg1, %arg2 : !s32i, !s32i), + 1: ^bb2(%arg2, %arg1 : !s32i, !s32i) + ] +^bb1: + cir.br ^bb3 + +^bb2: + cir.br ^bb3 + +^bb3: + cir.return +} + +// CHECK: cir.switch.flat %arg0 : !s32i, ^bb3 [ +// CHECK-NEXT: 0: ^bb1(%arg1, %arg2 : !s32i, !s32i), +// CHECK-NEXT: 1: ^bb2(%arg2, %arg1 : !s32i, !s32i) +// CHECK-NEXT: ] +// CHECK-NEXT: ^bb1: +// CHECK-NEXT: cir.br ^bb3 +// CHECK-NEXT: ^bb2: +// CHECK-NEXT: cir.br ^bb3 +// CHECK-NEXT: ^bb3: +// CHECK-NEXT: cir.return diff --git a/clang/test/CIR/Transforms/switch.cir b/clang/test/CIR/Transforms/switch.cir new file mode 100644 index 0000000000000..a05cf37e39728 --- /dev/null +++ b/clang/test/CIR/Transforms/switch.cir @@ -0,0 +1,278 @@ +// RUN: cir-opt %s -cir-flatten-cfg -o - | FileCheck %s + +!s8i = !cir.int<s, 8> +!s32i = !cir.int<s, 32> +!s64i = !cir.int<s, 64> + +module { + cir.func @shouldFlatSwitchWithDefault(%arg0: !s8i) { + cir.switch (%arg0 : !s8i) { + cir.case (equal, [#cir.int<1> : !s8i]) { + cir.break + } + cir.case (default, []) { + cir.break + } + cir.yield + } + cir.return + } +// CHECK: cir.func @shouldFlatSwitchWithDefault(%arg0: !s8i) { +// CHECK: cir.switch.flat %arg0 : !s8i, ^bb[[#DEFAULT:]] [ +// CHECK: 1: ^bb[[#CASE1:]] +// CHECK: ] +// CHECK: ^bb[[#CASE1]]: +// CHECK: cir.br ^bb[[#EXIT:]] +// CHECK: ^bb[[#DEFAULT]]: +// CHECK: cir.br ^bb[[#EXIT]] +// CHECK: ^bb[[#EXIT]]: +// CHECK: cir.return +// CHECK: } + + cir.func @shouldFlatSwitchWithoutDefault(%arg0: !s32i) { + cir.switch (%arg0 : !s32i) { + cir.case (equal, [#cir.int<1> : !s32i]) { + cir.break + } + cir.yield + } + cir.return + } +// CHECK: cir.func @shouldFlatSwitchWithoutDefault(%arg0: !s32i) 
{ +// CHECK: cir.switch.flat %arg0 : !s32i, ^bb[[#EXIT:]] [ +// CHECK: 1: ^bb[[#CASE1:]] +// CHECK: ] +// CHECK: ^bb[[#CASE1]]: +// CHECK: cir.br ^bb[[#EXIT]] +// CHECK: ^bb[[#EXIT]]: +// CHECK: cir.return +// CHECK: } + + + cir.func @shouldFlatSwitchWithImplicitFallthrough(%arg0: !s64i) { + cir.switch (%arg0 : !s64i) { + cir.case (anyof, [#cir.int<1> : !s64i, #cir.int<2> : !s64i]) { + cir.break + } + cir.yield + } + cir.return + } +// CHECK: cir.func @shouldFlatSwitchWithImplicitFallthrough(%arg0: !s64i) { +// CHECK: cir.switch.flat %arg0 : !s64i, ^bb[[#EXIT:]] [ +// CHECK: 1: ^bb[[#CASE1N2:]], +// CHECK: 2: ^bb[[#CASE1N2]] +// CHECK: ] +// CHECK: ^bb[[#CASE1N2]]: +// CHECK: cir.br ^bb[[#EXIT]] +// CHECK: ^bb[[#EXIT]]: +// CHECK: cir.return +// CHECK: } + + + + cir.func @shouldFlatSwitchWithExplicitFallthrough(%arg0: !s64i) { + cir.switch (%arg0 : !s64i) { + cir.case (equal, [#cir.int<1> : !s64i]) { // case 1 has its own region + cir.yield // fallthrough to case 2 + } + cir.case (equal, [#cir.int<2> : !s64i]) { + cir.break + } + cir.yield + } + cir.return + } +// CHECK: cir.func @shouldFlatSwitchWithExplicitFallthrough(%arg0: !s64i) { +// CHECK: cir.switch.flat %arg0 : !s64i, ^bb[[#EXIT:]] [ +// CHECK: 1: ^bb[[#CASE1:]], +// CHECK: 2: ^bb[[#CASE2:]] +// CHECK: ] +// CHECK: ^bb[[#CASE1]]: +// CHECK: cir.br ^bb[[#CASE2]] +// CHECK: ^bb[[#CASE2]]: +// CHECK: cir.br ^bb[[#EXIT]] +// CHECK: ^bb[[#EXIT]]: +// CHECK: cir.return +// CHECK: } + + cir.func @shouldFlatSwitchWithFallthroughToExit(%arg0: !s64i) { + cir.switch (%arg0 : !s64i) { + cir.case (equal, [#cir.int<1> : !s64i]) { + cir.yield // fallthrough to exit + } + cir.yield + } + cir.return + } +// CHECK: cir.func @shouldFlatSwitchWithFallthroughToExit(%arg0: !s64i) { +// CHECK: cir.switch.flat %arg0 : !s64i, ^bb[[#EXIT:]] [ +// CHECK: 1: ^bb[[#CASE1:]] +// CHECK: ] +// CHECK: ^bb[[#CASE1]]: +// CHECK: cir.br ^bb[[#EXIT]] +// CHECK: ^bb[[#EXIT]]: +// CHECK: cir.return +// CHECK: } + + cir.func 
@shouldDropEmptySwitch(%arg0: !s64i) { + cir.switch (%arg0 : !s64i) { + cir.yield + } + // CHECK-NOT: llvm.switch + cir.return + } +// CHECK: cir.func @shouldDropEmptySwitch(%arg0: !s64i) +// CHECK-NOT: cir.switch.flat + + + cir.func @shouldFlatMultiBlockCase(%arg0: !s32i) { + %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64} + cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> + cir.scope { + %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i + cir.switch (%1 : !s32i) { + cir.case (equal, [#cir.int<3> : !s32i]) { + cir.return + ^bb1: // no predecessors + cir.break + } + cir.yield + } + } + cir.return + } + +// CHECK: cir.func @shouldFlatMultiBlockCase(%arg0: !s32i) { +// CHECK: %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64} +// CHECK: cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> +// CHECK: cir.br ^bb1 +// CHECK: ^bb1: // pred: ^bb0 +// CHECK: %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i +// CHECK: cir.switch.flat %1 : !s32i, ^bb[[#DEFAULT:]] [ +// CHECK: 3: ^bb[[#BB1:]] +// CHECK: ] +// CHECK: ^bb[[#BB1]]: +// CHECK: cir.return +// CHECK: ^bb[[#DEFAULT]]: +// CHECK: cir.br ^bb[[#RET_BB:]] +// CHECK: ^bb[[#RET_BB]]: // pred: ^bb[[#DEFAULT]] +// CHECK: cir.return +// CHECK: } + + + cir.func @shouldFlatNestedBreak(%arg0: !s32i, %arg1: !s32i) -> !s32i { + %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init] {alignment = 4 : i64} + %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64} + %2 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64} + cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> + cir.store %arg1, %1 : !s32i, !cir.ptr<!s32i> + cir.scope { + %5 = cir.load %0 : !cir.ptr<!s32i>, !s32i + cir.switch (%5 : !s32i) { + cir.case (equal, [#cir.int<0> : !s32i]) { + cir.scope { + %6 = cir.load %1 : !cir.ptr<!s32i>, !s32i + %7 = cir.const #cir.int<0> : !s32i + %8 = cir.cmp(ge, %6, %7) : !s32i, !cir.bool + cir.if %8 { + cir.break + } + } + cir.break + } + cir.yield + } + } + %3 = cir.const #cir.int<3> : 
!s32i + cir.store %3, %2 : !s32i, !cir.ptr<!s32i> + %4 = cir.load %2 : !cir.ptr<!s32i>, !s32i + cir.return %4 : !s32i + } +// CHECK: cir.func @shouldFlatNestedBreak(%arg0: !s32i, %arg1: !s32i) -> !s32i { +// CHECK: cir.switch.flat %[[COND:.*]] : !s32i, ^bb[[#DEFAULT_BB:]] [ +// CHECK: 0: ^bb[[#BB1:]] +// CHECK: ] +// CHECK: ^bb[[#BB1]]: +// CHECK: cir.br ^bb[[#COND_BB:]] +// CHECK: ^bb[[#COND_BB]]: +// CHECK: cir.brcond {{%.*}} ^bb[[#TRUE_BB:]], ^bb[[#FALSE_BB:]] +// CHECK: ^bb[[#TRUE_BB]]: +// CHECK: cir.br ^bb[[#DEFAULT_BB]] +// CHECK: ^bb[[#FALSE_BB]]: +// CHECK: cir.br ^bb[[#PRED_BB:]] +// CHECK: ^bb[[#PRED_BB]]: +// CHECK: cir.br ^bb[[#DEFAULT_BB]] +// CHECK: ^bb[[#DEFAULT_BB]]: +// CHECK: cir.br ^bb[[#RET_BB:]] +// CHECK: ^bb[[#RET_BB]]: +// CHECK: cir.return +// CHECK: } + + + cir.func @flatCaseRange(%arg0: !s32i) -> !s32i { + %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init] {alignment = 4 : i64} + %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64} + %2 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64} + cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> + %3 = cir.const #cir.int<0> : !s32i + cir.store %3, %2 : !s32i, !cir.ptr<!s32i> + cir.scope { + %6 = cir.load %0 : !cir.ptr<!s32i>, !s32i + cir.switch (%6 : !s32i) { + cir.case (equal, [#cir.int<-100> : !s32i]) { + %7 = cir.const #cir.int<1> : !s32i + cir.store %7, %2 : !s32i, !cir.ptr<!s32i> + cir.break + } + cir.case (range, [#cir.int<1> : !s32i, #cir.int<100> : !s32i]) { + %7 = cir.const #cir.int<2> : !s32i + cir.store %7, %2 : !s32i, !cir.ptr<!s32i> + cir.break + } + cir.case (default, []) { + %7 = cir.const #cir.int<3> : !s32i + cir.store %7, %2 : !s32i, !cir.ptr<!s32i> + cir.break + } + cir.yield + } + } + %4 = cir.load %2 : !cir.ptr<!s32i>, !s32i + cir.store %4, %1 : !s32i, !cir.ptr<!s32i> + %5 = cir.load %1 : !cir.ptr<!s32i>, !s32i + cir.return %5 : !s32i + } +// CHECK: cir.func @flatCaseRange(%arg0: !s32i) -> !s32i { +// CHECK: cir.switch.flat %[[X:[0-9]+]] 
: !s32i, ^[[JUDGE_RANGE:bb[0-9]+]] [ +// CHECK-NEXT: -100: ^[[CASE_EQUAL:bb[0-9]+]] +// CHECK-NEXT: ] +// CHECK-NEXT: ^[[UNRACHABLE_BB:.+]]: // no predecessors +// CHECK-NEXT: cir.br ^[[CASE_EQUAL]] +// CHECK-NEXT: ^[[CASE_EQUAL]]: +// CHECK-NEXT: cir.int<1> +// CHECK-NEXT: cir.store +// CHECK-NEXT: cir.br ^[[EPILOG:bb[0-9]+]] +// CHECK-NEXT: ^[[CASE_RANGE:bb[0-9]+]]: +// CHECK-NEXT: cir.int<2> +// CHECK-NEXT: cir.store +// CHECK-NEXT: cir.br ^[[EPILOG]] +// CHECK-NEXT: ^[[JUDGE_RANGE]]: +// CHECK-NEXT: %[[RANGE:[0-9]+]] = cir.const #cir.int<99> +// CHECK-NEXT: %[[LOWER_BOUND:[0-9]+]] = cir.const #cir.int<1> +// CHECK-NEXT: %[[DIFF:[0-9]+]] = cir.binop(sub, %[[X]], %[[LOWER_BOUND]]) +// CHECK-NEXT: %[[U_DIFF:[0-9]+]] = cir.cast(integral, %[[DIFF]] : !s32i), !u32i +// CHECK-NEXT: %[[U_RANGE:[0-9]+]] = cir.cast(integral, %[[RANGE]] : !s32i), !u32i +// CHECK-NEXT: %[[CMP_RESULT:[0-9]+]] = cir.cmp(le, %[[U_DIFF]], %[[U_RANGE]]) +// CHECK-NEXT: cir.brcond %[[CMP_RESULT]] ^[[CASE_RANGE]], ^[[CASE_DEFAULT:bb[0-9]+]] +// CHECK-NEXT: ^[[CASE_DEFAULT]]: +// CHECK-NEXT: cir.int<3> +// CHECK-NEXT: cir.store +// CHECK-NEXT: cir.br ^[[EPILOG]] +// CHECK-NEXT: ^[[EPILOG]]: +// CHECK-NEXT: cir.br ^[[EPILOG_END:bb[0-9]+]] +// CHECK-NEXT: ^[[EPILOG_END]]: +// CHECK: cir.return +// CHECK: } + +} >From 5f5e3a81446c3e740dc20c9379492141f1661294 Mon Sep 17 00:00:00 2001 From: Andres Salamanca <andrealebarbari...@gmail.com> Date: Tue, 13 May 2025 20:46:38 -0500 Subject: [PATCH 2/5] Remove auto, add log-range test, and end-to-end test for switch flat op --- clang/include/clang/CIR/Dialect/IR/CIROps.td | 6 +- .../lib/CIR/Dialect/Transforms/FlattenCFG.cpp | 17 ++-- clang/test/CIR/CodeGen/switch_flat_op.cpp | 81 +++++++++++++++++++ clang/test/CIR/IR/switch-flat.cir | 2 +- clang/test/CIR/Transforms/switch.cir | 40 +++++++++ 5 files changed, 134 insertions(+), 12 deletions(-) create mode 100644 clang/test/CIR/CodeGen/switch_flat_op.cpp diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td 
b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 914af6d1dc6bd..abacee47d694e 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -981,7 +981,7 @@ def SwitchFlatOp : CIR_Op<"switch.flat", [AttrSizedOperandSegments, let description = [{ The `cir.switch.flat` operation is a region-less and simplified version of the `cir.switch`. - It's representation is closer to LLVM IR dialect + Its representation is closer to LLVM IR dialect than the C/C++ language feature. }]; @@ -989,7 +989,7 @@ def SwitchFlatOp : CIR_Op<"switch.flat", [AttrSizedOperandSegments, CIR_IntType:$condition, Variadic<AnyType>:$defaultOperands, VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands, - ArrayAttr:$case_values, + ArrayAttr:$caseValues, DenseI32ArrayAttr:$case_operand_segments ); @@ -1001,7 +1001,7 @@ def SwitchFlatOp : CIR_Op<"switch.flat", [AttrSizedOperandSegments, let assemblyFormat = [{ $condition `:` type($condition) `,` $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)? - custom<SwitchFlatOpCases>(ref(type($condition)), $case_values, + custom<SwitchFlatOpCases>(ref(type($condition)), $caseValues, $caseDestinations, $caseOperands, type($caseOperands)) attr-dict diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp index 70f383b556567..46e25719abafb 100644 --- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp +++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp @@ -247,7 +247,8 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> { { cir::YieldOp switchYield = nullptr; // Clear switch operation. 
- for (auto &block : llvm::make_early_inc_range(op.getBody().getBlocks())) + for (mlir::Block &block : + llvm::make_early_inc_range(op.getBody().getBlocks())) if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator())) switchYield = yieldOp; @@ -279,7 +280,7 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> { mlir::ValueRange defaultOperands = exitBlock->getArguments(); // Digest the case statements values and bodies. - for (auto caseOp : cases) { + for (cir::CaseOp caseOp : cases) { mlir::Region ®ion = caseOp.getCaseRegion(); // Found default case: save destination and operands. @@ -300,7 +301,7 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> { case cir::CaseOpKind::Anyof: case cir::CaseOpKind::Equal: // AnyOf cases kind can have multiple values, hence the loop below. - for (auto &value : caseOp.getValue()) { + for (const mlir::Attribute &value : caseOp.getValue()) { caseValues.push_back(cast<cir::IntAttr>(value).getValue()); caseDestinations.push_back(®ion.front()); caseOperands.push_back(caseDestinations.back()->getArguments()); @@ -319,7 +320,7 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> { }); // Track fallthrough in cases. - for (auto &blk : region.getBlocks()) { + for (mlir::Block &blk : region.getBlocks()) { if (blk.getNumSuccessors()) continue; @@ -349,7 +350,7 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> { } // Remove all cases since we've inlined the regions. - for (auto caseOp : cases) { + for (cir::CaseOp caseOp : cases) { mlir::Block *caseBlock = caseOp->getBlock(); // Erase the block with no predecessors here to make the generated code // simpler a little bit. 
@@ -359,9 +360,9 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> { rewriter.eraseOp(caseOp); } - for (size_t index = 0; index < rangeValues.size(); ++index) { - APInt lowerBound = rangeValues[index].first; - APInt upperBound = rangeValues[index].second; + for (auto [index, rangeVal] : llvm::enumerate(rangeValues)) { + APInt lowerBound = rangeVal.first; + APInt upperBound = rangeVal.second; // The case range is unreachable, skip it. if (lowerBound.sgt(upperBound)) diff --git a/clang/test/CIR/CodeGen/switch_flat_op.cpp b/clang/test/CIR/CodeGen/switch_flat_op.cpp new file mode 100644 index 0000000000000..e6ed1db2c9e19 --- /dev/null +++ b/clang/test/CIR/CodeGen/switch_flat_op.cpp @@ -0,0 +1,81 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir +// RUN: cir-opt --mlir-print-ir-before=cir-flatten-cfg --cir-flatten-cfg %t.cir -o %t.flattened.before.cir 2> %t.before +// RUN: FileCheck --input-file=%t.before %s --check-prefix=BEFORE +// RUN: cir-opt --mlir-print-ir-after=cir-flatten-cfg --cir-flatten-cfg %t.cir -o %t.flattened.after.cir 2> %t.after +// RUN: FileCheck --input-file=%t.after %s --check-prefix=AFTER + + + + + +void swf(int a) { + switch (int b = 3; a) { + case 3: + b = b * 2; + break; + case 4 ... 
5: + b = b * 3; + break; + default: + break; + } + +} + +// BEFORE: cir.func @_Z3swfi +// BEFORE: %[[VAR_B:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["b", init] {alignment = 4 : i64} +// BEFORE: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i +// BEFORE: cir.switch (%[[COND:.*]] : !s32i) { +// BEFORE: cir.case(equal, [#cir.int<3> : !s32i]) { +// BEFORE: %[[LOAD_B_EQ:.*]] = cir.load %[[VAR_B]] : !cir.ptr<!s32i>, !s32i +// BEFORE: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i +// BEFORE: %[[MUL_EQ:.*]] = cir.binop(mul, %[[LOAD_B_EQ]], %[[CONST_2]]) nsw : !s32i +// BEFORE: cir.store %[[MUL_EQ]], %[[VAR_B]] : !s32i, !cir.ptr<!s32i> +// BEFORE: cir.break +// BEFORE: } +// BEFORE: cir.case(range, [#cir.int<4> : !s32i, #cir.int<5> : !s32i]) { +// BEFORE: %[[LOAD_B_RANGE:.*]] = cir.load %[[VAR_B]] : !cir.ptr<!s32i>, !s32i +// BEFORE: %[[CONST_3_RANGE:.*]] = cir.const #cir.int<3> : !s32i +// BEFORE: %[[MUL_RANGE:.*]] = cir.binop(mul, %[[LOAD_B_RANGE]], %[[CONST_3_RANGE]]) nsw : !s32i +// BEFORE: cir.store %[[MUL_RANGE]], %[[VAR_B]] : !s32i, !cir.ptr<!s32i> +// BEFORE: cir.break +// BEFORE: } +// BEFORE: cir.case(default, []) { +// BEFORE: cir.break +// BEFORE: } +// BEFORE: cir.yield +// BEFORE: } +// BEFORE: } +// BEFORE: cir.return + +// AFTER: cir.func @_Z3swfi +// AFTER: %[[VAR_A:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64} +// AFTER: cir.store %arg0, %[[VAR_A]] : !s32i, !cir.ptr<!s32i> +// AFTER: %[[VAR_B:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["b", init] {alignment = 4 : i64} +// AFTER: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i +// AFTER: cir.store %[[CONST_3]], %[[VAR_B]] : !s32i, !cir.ptr<!s32i> +// AFTER: cir.switch.flat %[[COND:.*]] : !s32i, ^bb[[#BB6:]] [ +// AFTER: 3: ^bb[[#BB4:]], +// AFTER: 4: ^bb[[#BB5:]], +// AFTER: 5: ^bb[[#BB5]] +// AFTER: ] +// AFTER: ^bb[[#BB4]]: +// AFTER: %[[LOAD_B_EQ:.*]] = cir.load %[[VAR_B]] : !cir.ptr<!s32i>, !s32i +// AFTER: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i +// AFTER:
%[[MUL_EQ:.*]] = cir.binop(mul, %[[LOAD_B_EQ]], %[[CONST_2]]) nsw : !s32i +// AFTER: cir.store %[[MUL_EQ]], %[[VAR_B]] : !s32i, !cir.ptr<!s32i> +// AFTER: cir.br ^bb[[#BB7:]] +// AFTER: ^bb[[#BB5]]: +// AFTER: %[[LOAD_B_RANGE:.*]] = cir.load %[[VAR_B]] : !cir.ptr<!s32i>, !s32i +// AFTER: %[[CONST_3_AGAIN:.*]] = cir.const #cir.int<3> : !s32i +// AFTER: %[[MUL_RANGE:.*]] = cir.binop(mul, %[[LOAD_B_RANGE]], %[[CONST_3_AGAIN]]) nsw : !s32i +// AFTER: cir.store %[[MUL_RANGE]], %[[VAR_B]] : !s32i, !cir.ptr<!s32i> +// AFTER: cir.br ^bb[[#BB7]] +// AFTER: ^bb[[#BB6]]: +// AFTER: cir.br ^bb[[#BB7]] +// AFTER: ^bb[[#BB7]]: +// AFTER: cir.br ^bb[[#BB8:]] +// AFTER: ^bb[[#BB8]]: +// AFTER: cir.return +// AFTER: } + diff --git a/clang/test/CIR/IR/switch-flat.cir b/clang/test/CIR/IR/switch-flat.cir index b072c224b4a2c..8c11a74484d39 100644 --- a/clang/test/CIR/IR/switch-flat.cir +++ b/clang/test/CIR/IR/switch-flat.cir @@ -17,7 +17,7 @@ cir.func @FlatSwitchWithoutDefault(%arg0: !s32i) { // CHECK-NEXT: ^bb1: // CHECK-NEXT: cir.br ^bb2 // CHECK-NEXT: ^bb2: -//CHECK-NEXT: cir.return +// CHECK-NEXT: cir.return cir.func @FlatSwitchWithDefault(%arg0: !s32i) { cir.switch.flat %arg0 : !s32i, ^bb2 [ diff --git a/clang/test/CIR/Transforms/switch.cir b/clang/test/CIR/Transforms/switch.cir index a05cf37e39728..00b462a6075c9 100644 --- a/clang/test/CIR/Transforms/switch.cir +++ b/clang/test/CIR/Transforms/switch.cir @@ -275,4 +275,44 @@ module { // CHECK: cir.return // CHECK: } + cir.func @_Z8bigRangei(%arg0: !s32i) { + %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64} + cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> + cir.scope { + %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i + cir.switch (%1 : !s32i) { + cir.case(range, [#cir.int<3> : !s32i, #cir.int<100> : !s32i]) { + cir.break + } + cir.case(default, []) { + cir.break + } + cir.yield + } + } + cir.return + } + +// CHECK: cir.func @_Z8bigRangei(%arg0: !s32i) { +// CHECK: cir.switch.flat %[[COND:.*]] : !s32i, 
^bb[[#RANGE_BR:]] [ +// CHECK: ] +// CHECK: ^bb[[#NO_PRED_BB:]]: // no predecessors +// CHECK: cir.br ^bb[[#DEFAULT_BB:]] +// CHECK: ^bb[[#DEFAULT_BB]]: // 2 preds: ^bb[[#NO_PRED_BB]], ^bb[[#RANGE_BR]] +// CHECK: cir.br ^bb[[#EXIT:]] +// CHECK: ^bb[[#RANGE_BR]]: // pred: ^bb[[#BB2:]] +// CHECK: %[[CONST97:.*]] = cir.const #cir.int<97> : !s32i +// CHECK: %[[CONST3:.*]] = cir.const #cir.int<3> : !s32i +// CHECK: %[[SUB:.*]] = cir.binop(sub, %[[COND]], %[[CONST3]]) : !s32i +// CHECK: %[[CAST1:.*]] = cir.cast(integral, %[[SUB]] : !s32i), !u32i +// CHECK: %[[CAST2:.*]] = cir.cast(integral, %[[CONST97]] : !s32i), !u32i +// CHECK: %[[CMP:.*]] = cir.cmp(le, %[[CAST1]], %[[CAST2]]) : !u32i, !cir.bool +// CHECK: cir.brcond %[[CMP]] ^bb[[#DEFAULT_BB]], ^bb[[#RANGE_BB:]] +// CHECK: ^bb[[#RANGE_BB]]: // pred: ^bb[[#RANGE_BR]] +// CHECK: cir.br ^bb[[#EXIT]] +// CHECK: ^bb[[#EXIT]]: // 2 preds: ^bb[[#DEFAULT_BB]], ^bb[[#RANGE_BB]] +// CHECK: cir.br ^bb[[#RET_BB:]] +// CHECK: ^bb[[#RET_BB]]: // pred: ^bb[[#EXIT]] +// CHECK: cir.return +// CHECK: } } >From e7ed599c7d2d5bc5939f5d74dd5ab74a2b973c2e Mon Sep 17 00:00:00 2001 From: Andres Salamanca <andrealebarbari...@gmail.com> Date: Tue, 13 May 2025 20:48:26 -0500 Subject: [PATCH 3/5] Fix formatting for switch_flat_op --- clang/test/CIR/CodeGen/switch_flat_op.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/clang/test/CIR/CodeGen/switch_flat_op.cpp b/clang/test/CIR/CodeGen/switch_flat_op.cpp index e6ed1db2c9e19..a9fc095025eb0 100644 --- a/clang/test/CIR/CodeGen/switch_flat_op.cpp +++ b/clang/test/CIR/CodeGen/switch_flat_op.cpp @@ -4,10 +4,6 @@ // RUN: cir-opt --mlir-print-ir-after=cir-flatten-cfg --cir-flatten-cfg %t.cir -o %t.flattened.after.cir 2> %t.after // RUN: FileCheck --input-file=%t.after %s --check-prefix=AFTER - - - - void swf(int a) { switch (int b = 3; a) { case 3: >From ad9dfc6e8f6fe617a785d7d61500bd1e6262aa29 Mon Sep 17 00:00:00 2001 From: Andres Salamanca <andrealebarbari...@gmail.com> Date: Tue, 13 May 2025 20:54:04
-0500 Subject: [PATCH 4/5] remove auto keyword --- clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index ca03013edb485..57c7e275137f8 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -22,6 +22,7 @@ #include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc" #include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc" #include "clang/CIR/MissingFeatures.h" +#include "llvm/ADT/APInt.h" #include <numeric> using namespace mlir; @@ -975,7 +976,7 @@ void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result, ArrayRef<ValueRange> caseOperands) { std::vector<mlir::Attribute> caseValuesAttrs; - for (auto &val : caseValues) { + for (const APInt &val : caseValues) { caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val)); } mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs); >From 9e40189c27ab6dbbd4e9c223e44528a210ddf047 Mon Sep 17 00:00:00 2001 From: Andres Salamanca <andrealebarbari...@gmail.com> Date: Wed, 14 May 2025 11:12:18 -0500 Subject: [PATCH 5/5] change enumerate to zip --- clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp index 46e25719abafb..71a45d3c84eea 100644 --- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp +++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp @@ -360,7 +360,8 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> { rewriter.eraseOp(caseOp); } - for (auto [index, rangeVal] : llvm::enumerate(rangeValues)) { + for (auto [rangeVal, operand, destination] : + llvm::zip(rangeValues, rangeOperands, rangeDestinations)) { APInt lowerBound = rangeVal.first; APInt upperBound = rangeVal.second; @@ -376,16 +377,16 @@ class 
CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> { for (APInt iValue = lowerBound; iValue.sle(upperBound); (void)iValue++) { caseValues.push_back(iValue); - caseOperands.push_back(rangeOperands[index]); - caseDestinations.push_back(rangeDestinations[index]); + caseOperands.push_back(operand); + caseDestinations.push_back(destination); } continue; } defaultDestination = - condBrToRangeDestination(op, rewriter, rangeDestinations[index], + condBrToRangeDestination(op, rewriter, destination, defaultDestination, lowerBound, upperBound); - defaultOperands = rangeOperands[index]; + defaultOperands = operand; } // Set switch op to branch to the newly created blocks. _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits