https://github.com/mmha created https://github.com/llvm/llvm-project/pull/142165
We used to insert a continue block at the end of a flattened ternary op that only contained a branch to the block holding the remaining operations. This patch removes that continue block and changes the true/false blocks to jump directly to the remaining ops. With this patch, CIR now generates exactly the same LLVM IR as the original codegen. This upstreams llvm/clangir#1651. >From f62994df24f912a3815cabb7fc4a47fa8c8c948e Mon Sep 17 00:00:00 2001 From: Morris Hafner <mhaf...@nvidia.com> Date: Fri, 30 May 2025 16:29:05 +0200 Subject: [PATCH] [CIR] Skip generation of a continue block when flattening TernaryOp We used to insert a continue block at the end of a flattened ternary op that only contained a branch to the block holding the remaining operations. This patch removes that continue block and changes the true/false blocks to jump directly to the remaining ops. With this patch, CIR now generates exactly the same LLVM IR as the original codegen. This upstreams llvm/clangir#1651. 
--- .../lib/CIR/Dialect/Transforms/FlattenCFG.cpp | 31 +++++++++---------- clang/test/CIR/Lowering/ternary.cir | 2 -- clang/test/CIR/Transforms/ternary.cir | 4 --- 3 files changed, 15 insertions(+), 22 deletions(-) diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp index 26e5c0572f12e..6081a436f5c29 100644 --- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp +++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp @@ -16,9 +16,11 @@ #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "clang/AST/DeclBase.h" #include "clang/CIR/Dialect/IR/CIRDialect.h" #include "clang/CIR/Dialect/Passes.h" #include "clang/CIR/MissingFeatures.h" @@ -491,15 +493,7 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> { Location loc = op->getLoc(); Block *condBlock = rewriter.getInsertionBlock(); Block::iterator opPosition = rewriter.getInsertionPoint(); - Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); - llvm::SmallVector<mlir::Location, 2> locs; - // Ternary result is optional, make sure to populate the location only - // when relevant. 
- if (op->getResultTypes().size()) - locs.push_back(loc); - Block *continueBlock = - rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs); - rewriter.create<cir::BrOp>(loc, remainingOpsBlock); + auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); Region &trueRegion = op.getTrueRegion(); Block *trueBlock = &trueRegion.front(); @@ -508,24 +502,29 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> { auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator); rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(), - continueBlock); - rewriter.inlineRegionBefore(trueRegion, continueBlock); + remainingOpsBlock); + rewriter.inlineRegionBefore(trueRegion, remainingOpsBlock); - Block *falseBlock = continueBlock; Region &falseRegion = op.getFalseRegion(); + Block *falseBlock = &falseRegion.front(); - falseBlock = &falseRegion.front(); mlir::Operation *falseTerminator = falseRegion.back().getTerminator(); rewriter.setInsertionPointToEnd(&falseRegion.back()); auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator); rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(), - continueBlock); - rewriter.inlineRegionBefore(falseRegion, continueBlock); + remainingOpsBlock); + rewriter.inlineRegionBefore(falseRegion, remainingOpsBlock); rewriter.setInsertionPointToEnd(condBlock); rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock); - rewriter.replaceOp(op, continueBlock->getArguments()); + if (auto rt = op.getResultTypes(); rt.size()) { + iterator_range args = remainingOpsBlock->addArguments(rt, op.getLoc()); + SmallVector<mlir::Value, 2> values; + llvm::copy(args, std::back_inserter(values)); + rewriter.replaceOpUsesWithinBlock(op, values, remainingOpsBlock); + } + rewriter.eraseOp(op); // Ok, we're done! 
return mlir::success(); diff --git a/clang/test/CIR/Lowering/ternary.cir b/clang/test/CIR/Lowering/ternary.cir index 247c6ae3a1e17..c2226cd92ece7 100644 --- a/clang/test/CIR/Lowering/ternary.cir +++ b/clang/test/CIR/Lowering/ternary.cir @@ -25,6 +25,4 @@ module { // LLVM: br label %[[M]] // LLVM: [[M]]: // LLVM: [[R:%[[:alnum:]]+]] = phi i32 [ 1, %[[B2]] ], [ 0, %[[B1]] ] -// LLVM: br label %[[B3:[[:alnum:]]+]] -// LLVM: [[B3]]: // LLVM: ret i32 [[R]] diff --git a/clang/test/CIR/Transforms/ternary.cir b/clang/test/CIR/Transforms/ternary.cir index 67ef7f95a6b52..0c22268495697 100644 --- a/clang/test/CIR/Transforms/ternary.cir +++ b/clang/test/CIR/Transforms/ternary.cir @@ -37,8 +37,6 @@ module { // CHECK: %6 = cir.const #cir.int<5> : !s32i // CHECK: cir.br ^bb3(%6 : !s32i) // CHECK: ^bb3(%7: !s32i): // 2 preds: ^bb1, ^bb2 -// CHECK: cir.br ^bb4 -// CHECK: ^bb4: // pred: ^bb3 // CHECK: cir.store %7, %1 : !s32i, !cir.ptr<!s32i> // CHECK: %8 = cir.load %1 : !cir.ptr<!s32i>, !s32i // CHECK: cir.return %8 : !s32i @@ -60,8 +58,6 @@ module { // CHECK: ^bb2: // pred: ^bb0 // CHECK: cir.br ^bb3 // CHECK: ^bb3: // 2 preds: ^bb1, ^bb2 -// CHECK: cir.br ^bb4 -// CHECK: ^bb4: // pred: ^bb3 // CHECK: cir.return // CHECK: } _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits