https://github.com/mmha updated https://github.com/llvm/llvm-project/pull/137184
>From 1eed90e3859c2ad8d703708f89976cad8f0faeec Mon Sep 17 00:00:00 2001 From: Morris Hafner <mhaf...@nvidia.com> Date: Thu, 24 Apr 2025 16:12:37 +0200 Subject: [PATCH 1/3] [CIR] Upstream TernaryOp This patch adds TernaryOp to CIR plus a pattern that flattens the operator in FlattenCFG. --- clang/include/clang/CIR/Dialect/IR/CIROps.td | 57 +++++++++++++++- clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 42 ++++++++++++ .../lib/CIR/Dialect/Transforms/FlattenCFG.cpp | 60 ++++++++++++++-- clang/test/CIR/IR/ternary.cir | 30 ++++++++ clang/test/CIR/Lowering/ternary.cir | 30 ++++++++ clang/test/CIR/Transforms/ternary.cir | 68 +++++++++++++++++++ 6 files changed, 280 insertions(+), 7 deletions(-) create mode 100644 clang/test/CIR/IR/ternary.cir create mode 100644 clang/test/CIR/Lowering/ternary.cir create mode 100644 clang/test/CIR/Transforms/ternary.cir diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 81b447f31feca..76ad5c3666c1b 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -609,8 +609,8 @@ def ConditionOp : CIR_Op<"condition", [ //===----------------------------------------------------------------------===// def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator, - ParentOneOf<["IfOp", "ScopeOp", "WhileOp", - "ForOp", "DoWhileOp"]>]> { + ParentOneOf<["IfOp", "TernaryOp", "ScopeOp", + "WhileOp", "ForOp", "DoWhileOp"]>]> { let summary = "Represents the default branching behaviour of a region"; let description = [{ The `cir.yield` operation terminates regions on different CIR operations, @@ -1246,6 +1246,59 @@ def SelectOp : CIR_Op<"select", [Pure, }]; } +//===----------------------------------------------------------------------===// +// TernaryOp +//===----------------------------------------------------------------------===// + +def TernaryOp : CIR_Op<"ternary", + [DeclareOpInterfaceMethods<RegionBranchOpInterface>, + RecursivelySpeculatable, 
AutomaticAllocationScope, NoRegionArguments]> { + let summary = "The `cond ? a : b` C/C++ ternary operation"; + let description = [{ + The `cir.ternary` operation represents C/C++ ternary, much like a `select` + operation. The first argument is a `cir.bool` condition to evaluate, followed + by two regions to execute (true or false). This is different from `cir.if` + since each region is one block sized and the `cir.yield` closing the block + scope should have at most one argument. + + Example: + + ```mlir + // x = cond ? a : b; + + %x = cir.ternary (%cond, true { + ... + cir.yield %a : i32 + }, false { + ... + cir.yield %b : i32 + }) : (!cir.bool) -> i32 + ``` + }]; + let arguments = (ins CIR_BoolType:$cond); + let regions = (region AnyRegion:$trueRegion, + AnyRegion:$falseRegion); + let results = (outs Optional<CIR_AnyType>:$result); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins "mlir::Value":$cond, + "llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$trueBuilder, + "llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$falseBuilder) + > + ]; + + // All constraints already verified elsewhere.
+ let hasVerifier = 0; + + let assemblyFormat = [{ + `(` $cond `,` + `true` $trueRegion `,` + `false` $falseRegion + `)` `:` functional-type(operands, results) attr-dict + }]; +} + //===----------------------------------------------------------------------===// // GlobalOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 89daf20c5f478..e80d243cb396f 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -1058,6 +1058,48 @@ LogicalResult cir::BinOp::verify() { return mlir::success(); } +//===----------------------------------------------------------------------===// +// TernaryOp +//===----------------------------------------------------------------------===// + +/// Given the branch point `point` (the parent operation or one of the two +/// regions), return the successor regions. These are the regions that may be +/// selected during the flow of control: from the parent, either the true or +/// the false region may be entered; from either region, control returns to the +/// parent operation, forwarding the yielded value (if any) as its result. +void cir::TernaryOp::getSuccessorRegions( + mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { + // The `true` and the `false` region branch back to the parent operation. + if (!point.isParent()) { + regions.push_back(RegionSuccessor(this->getODSResults(0))); + return; + } + + // Either region may be entered; constant conditions are not folded here.
+ regions.push_back(RegionSuccessor(&getTrueRegion())); + regions.push_back(RegionSuccessor(&getFalseRegion())); +} + +void cir::TernaryOp::build( + OpBuilder &builder, OperationState &result, Value cond, + function_ref<void(OpBuilder &, Location)> trueBuilder, + function_ref<void(OpBuilder &, Location)> falseBuilder) { + result.addOperands(cond); + OpBuilder::InsertionGuard guard(builder); + Region *trueRegion = result.addRegion(); + Block *block = builder.createBlock(trueRegion); + trueBuilder(builder, result.location); + Region *falseRegion = result.addRegion(); + builder.createBlock(falseRegion); + falseBuilder(builder, result.location); + + auto yield = cast<YieldOp>(block->getTerminator()); + assert(yield.getNumOperands() <= 1 && + "expected zero or one result type"); + if (yield.getNumOperands() == 1) + result.addTypes(TypeRange{yield.getOperandTypes().front()}); +} + //===----------------------------------------------------------------------===// // ShiftOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp index 72ccfa8d4e14e..295fa748b1624 100644 --- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp +++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp @@ -254,10 +254,61 @@ class CIRLoopOpInterfaceFlattening } }; +class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> { +public: + using OpRewritePattern<cir::TernaryOp>::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(cir::TernaryOp op, + mlir::PatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + Block *condBlock = rewriter.getInsertionBlock(); + Block::iterator opPosition = rewriter.getInsertionPoint(); + Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); + llvm::SmallVector<mlir::Location, 2> locs; + // Ternary result is optional, make sure to populate the location only + // 
when relevant. + if (op->getResultTypes().size()) + locs.push_back(loc); + auto *continueBlock = + rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs); + rewriter.create<cir::BrOp>(loc, remainingOpsBlock); + + Region &trueRegion = op.getTrueRegion(); + Block *trueBlock = &trueRegion.front(); + mlir::Operation *trueTerminator = trueRegion.back().getTerminator(); + rewriter.setInsertionPointToEnd(&trueRegion.back()); + auto trueYieldOp = cast<cir::YieldOp>(trueTerminator); + + rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(), + continueBlock); + rewriter.inlineRegionBefore(trueRegion, continueBlock); + + Block *falseBlock = continueBlock; + Region &falseRegion = op.getFalseRegion(); + + falseBlock = &falseRegion.front(); + mlir::Operation *falseTerminator = falseRegion.back().getTerminator(); + rewriter.setInsertionPointToEnd(&falseRegion.back()); + cir::YieldOp falseYieldOp = cast<cir::YieldOp>(falseTerminator); + rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(), + continueBlock); + rewriter.inlineRegionBefore(falseRegion, continueBlock); + + rewriter.setInsertionPointToEnd(condBlock); + rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock); + + rewriter.replaceOp(op, continueBlock->getArguments()); + + // Ok, we're done!
+ return mlir::success(); + } +}; + void populateFlattenCFGPatterns(RewritePatternSet &patterns) { - patterns - .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening>( - patterns.getContext()); + patterns.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, + CIRScopeOpFlattening, CIRTernaryOpFlattening>( + patterns.getContext()); } void CIRFlattenCFGPass::runOnOperation() { @@ -269,9 +320,8 @@ void CIRFlattenCFGPass::runOnOperation() { getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) { assert(!cir::MissingFeatures::ifOp()); assert(!cir::MissingFeatures::switchOp()); - assert(!cir::MissingFeatures::ternaryOp()); assert(!cir::MissingFeatures::tryOp()); - if (isa<IfOp, ScopeOp, LoopOpInterface>(op)) + if (isa<IfOp, ScopeOp, LoopOpInterface, TernaryOp>(op)) ops.push_back(op); }); diff --git a/clang/test/CIR/IR/ternary.cir b/clang/test/CIR/IR/ternary.cir new file mode 100644 index 0000000000000..3827dc77726df --- /dev/null +++ b/clang/test/CIR/IR/ternary.cir @@ -0,0 +1,30 @@ +// RUN: cir-opt %s | cir-opt | FileCheck %s +!u32i = !cir.int<u, 32> + +module { + cir.func @blue(%arg0: !cir.bool) -> !u32i { + %0 = cir.ternary(%arg0, true { + %a = cir.const #cir.int<0> : !u32i + cir.yield %a : !u32i + }, false { + %b = cir.const #cir.int<1> : !u32i + cir.yield %b : !u32i + }) : (!cir.bool) -> !u32i + cir.return %0 : !u32i + } +} + +// CHECK: module { + +// CHECK: cir.func @blue(%arg0: !cir.bool) -> !u32i { +// CHECK: %0 = cir.ternary(%arg0, true { +// CHECK: %1 = cir.const #cir.int<0> : !u32i +// CHECK: cir.yield %1 : !u32i +// CHECK: }, false { +// CHECK: %1 = cir.const #cir.int<1> : !u32i +// CHECK: cir.yield %1 : !u32i +// CHECK: }) : (!cir.bool) -> !u32i +// CHECK: cir.return %0 : !u32i +// CHECK: } + +// CHECK: } diff --git a/clang/test/CIR/Lowering/ternary.cir b/clang/test/CIR/Lowering/ternary.cir new file mode 100644 index 0000000000000..247c6ae3a1e17 --- /dev/null +++ b/clang/test/CIR/Lowering/ternary.cir @@ -0,0 +1,30 @@ +// RUN: 
cir-translate -cir-to-llvmir --disable-cc-lowering -o %t.ll %s +// RUN: FileCheck --input-file=%t.ll -check-prefix=LLVM %s + +!u32i = !cir.int<u, 32> + +module { + cir.func @blue(%arg0: !cir.bool) -> !u32i { + %0 = cir.ternary(%arg0, true { + %a = cir.const #cir.int<0> : !u32i + cir.yield %a : !u32i + }, false { + %b = cir.const #cir.int<1> : !u32i + cir.yield %b : !u32i + }) : (!cir.bool) -> !u32i + cir.return %0 : !u32i + } +} + +// LLVM-LABEL: define i32 {{.*}}@blue( +// LLVM-SAME: i1 [[PRED:%[[:alnum:]]+]]) +// LLVM: br i1 [[PRED]], label %[[B1:[[:alnum:]]+]], label %[[B2:[[:alnum:]]+]] +// LLVM: [[B1]]: +// LLVM: br label %[[M:[[:alnum:]]+]] +// LLVM: [[B2]]: +// LLVM: br label %[[M]] +// LLVM: [[M]]: +// LLVM: [[R:%[[:alnum:]]+]] = phi i32 [ 1, %[[B2]] ], [ 0, %[[B1]] ] +// LLVM: br label %[[B3:[[:alnum:]]+]] +// LLVM: [[B3]]: +// LLVM: ret i32 [[R]] diff --git a/clang/test/CIR/Transforms/ternary.cir b/clang/test/CIR/Transforms/ternary.cir new file mode 100644 index 0000000000000..67ef7f95a6b52 --- /dev/null +++ b/clang/test/CIR/Transforms/ternary.cir @@ -0,0 +1,68 @@ +// RUN: cir-opt %s -cir-flatten-cfg -o - | FileCheck %s + +!s32i = !cir.int<s, 32> + +module { + cir.func @foo(%arg0: !s32i) -> !s32i { + %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64} + %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64} + cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> + %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i + %3 = cir.const #cir.int<0> : !s32i + %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool + %5 = cir.ternary(%4, true { + %7 = cir.const #cir.int<3> : !s32i + cir.yield %7 : !s32i + }, false { + %7 = cir.const #cir.int<5> : !s32i + cir.yield %7 : !s32i + }) : (!cir.bool) -> !s32i + cir.store %5, %1 : !s32i, !cir.ptr<!s32i> + %6 = cir.load %1 : !cir.ptr<!s32i>, !s32i + cir.return %6 : !s32i + } + +// CHECK: cir.func @foo(%arg0: !s32i) -> !s32i { +// CHECK: %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
+// CHECK: %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64} +// CHECK: cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> +// CHECK: %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i +// CHECK: %3 = cir.const #cir.int<0> : !s32i +// CHECK: %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool +// CHECK: cir.brcond %4 ^bb1, ^bb2 +// CHECK: ^bb1: // pred: ^bb0 +// CHECK: %5 = cir.const #cir.int<3> : !s32i +// CHECK: cir.br ^bb3(%5 : !s32i) +// CHECK: ^bb2: // pred: ^bb0 +// CHECK: %6 = cir.const #cir.int<5> : !s32i +// CHECK: cir.br ^bb3(%6 : !s32i) +// CHECK: ^bb3(%7: !s32i): // 2 preds: ^bb1, ^bb2 +// CHECK: cir.br ^bb4 +// CHECK: ^bb4: // pred: ^bb3 +// CHECK: cir.store %7, %1 : !s32i, !cir.ptr<!s32i> +// CHECK: %8 = cir.load %1 : !cir.ptr<!s32i>, !s32i +// CHECK: cir.return %8 : !s32i +// CHECK: } + + cir.func @foo2(%arg0: !cir.bool) { + cir.ternary(%arg0, true { + cir.yield + }, false { + cir.yield + }) : (!cir.bool) -> () + cir.return + } + +// CHECK: cir.func @foo2(%arg0: !cir.bool) { +// CHECK: cir.brcond %arg0 ^bb1, ^bb2 +// CHECK: ^bb1: // pred: ^bb0 +// CHECK: cir.br ^bb3 +// CHECK: ^bb2: // pred: ^bb0 +// CHECK: cir.br ^bb3 +// CHECK: ^bb3: // 2 preds: ^bb1, ^bb2 +// CHECK: cir.br ^bb4 +// CHECK: ^bb4: // pred: ^bb3 +// CHECK: cir.return +// CHECK: } + +} >From 3e9d9d35b52c0b69ac9950f53cad044a958a81d4 Mon Sep 17 00:00:00 2001 From: Morris Hafner <mhaf...@nvidia.com> Date: Thu, 24 Apr 2025 17:08:17 +0200 Subject: [PATCH 2/3] Reorder YieldOp parents (loop ops first, each group sorted alphabetically) --- clang/include/clang/CIR/Dialect/IR/CIROps.td | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 76ad5c3666c1b..760149636b23b 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -609,8 +609,8 @@ def ConditionOp : CIR_Op<"condition", [ //===----------------------------------------------------------------------===//
def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator, - ParentOneOf<["IfOp", "TernaryOp", "ScopeOp", - "WhileOp", "ForOp", "DoWhileOp"]>]> { + ParentOneOf<["DoWhileOp", "ForOp", "WhileOp", + "IfOp", "ScopeOp", "TernaryOp"]>]> { let summary = "Represents the default branching behaviour of a region"; let description = [{ The `cir.yield` operation terminates regions on different CIR operations, >From c98ee8457506779451c0f6a531c14c151e8704fd Mon Sep 17 00:00:00 2001 From: Morris Hafner <m...@users.noreply.github.com> Date: Fri, 25 Apr 2025 15:57:09 +0200 Subject: [PATCH 3/3] Apply suggestions from code review Co-authored-by: Andy Kaylor <akay...@nvidia.com> --- clang/include/clang/CIR/Dialect/IR/CIROps.td | 2 +- clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 760149636b23b..a01fb0aa60844 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -1264,7 +1264,7 @@ def TernaryOp : CIR_Op<"ternary", Example: ```mlir - // x = cond ? a : b; + // cond = a && b; %x = cir.ternary (%cond, true { ... diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp index 295fa748b1624..4a936d33b022a 100644 --- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp +++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp @@ -270,7 +270,7 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> { // when relevant.
if (op->getResultTypes().size()) locs.push_back(loc); - auto *continueBlock = + Block *continueBlock = rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs); rewriter.create<cir::BrOp>(loc, remainingOpsBlock); @@ -290,7 +290,7 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> { falseBlock = &falseRegion.front(); mlir::Operation *falseTerminator = falseRegion.back().getTerminator(); rewriter.setInsertionPointToEnd(&falseRegion.back()); - cir::YieldOp falseYieldOp = cast<cir::YieldOp>(falseTerminator); + auto falseYieldOp = cast<cir::YieldOp>(falseTerminator); rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(), continueBlock); rewriter.inlineRegionBefore(falseRegion, continueBlock); _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits