https://github.com/mmha updated https://github.com/llvm/llvm-project/pull/137184

>From 1eed90e3859c2ad8d703708f89976cad8f0faeec Mon Sep 17 00:00:00 2001
From: Morris Hafner <mhaf...@nvidia.com>
Date: Thu, 24 Apr 2025 16:12:37 +0200
Subject: [PATCH 1/3] [CIR] Upstream TernaryOp

This patch adds TernaryOp to CIR plus a rewrite pattern that flattens the
operation in the FlattenCFG pass.
---
 clang/include/clang/CIR/Dialect/IR/CIROps.td  | 57 +++++++++++++++-
 clang/lib/CIR/Dialect/IR/CIRDialect.cpp       | 42 ++++++++++++
 .../lib/CIR/Dialect/Transforms/FlattenCFG.cpp | 60 ++++++++++++++--
 clang/test/CIR/IR/ternary.cir                 | 30 ++++++++
 clang/test/CIR/Lowering/ternary.cir           | 30 ++++++++
 clang/test/CIR/Transforms/ternary.cir         | 68 +++++++++++++++++++
 6 files changed, 280 insertions(+), 7 deletions(-)
 create mode 100644 clang/test/CIR/IR/ternary.cir
 create mode 100644 clang/test/CIR/Lowering/ternary.cir
 create mode 100644 clang/test/CIR/Transforms/ternary.cir

diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 81b447f31feca..76ad5c3666c1b 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -609,8 +609,8 @@ def ConditionOp : CIR_Op<"condition", [
 //===----------------------------------------------------------------------===//
 
 def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
-                               ParentOneOf<["IfOp", "ScopeOp", "WhileOp",
-                                            "ForOp", "DoWhileOp"]>]> {
+                               ParentOneOf<["IfOp", "TernaryOp", "ScopeOp",
+                                            "WhileOp", "ForOp", "DoWhileOp"]>]> {
   let summary = "Represents the default branching behaviour of a region";
   let description = [{
     The `cir.yield` operation terminates regions on different CIR operations,
@@ -1246,6 +1246,59 @@ def SelectOp : CIR_Op<"select", [Pure,
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TernaryOp
+//===----------------------------------------------------------------------===//
+
+def TernaryOp : CIR_Op<"ternary",
+      [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+       RecursivelySpeculatable, AutomaticAllocationScope, NoRegionArguments]> {
+  let summary = "The `cond ? a : b` C/C++ ternary operation";
+  let description = [{
+    The `cir.ternary` operation represents C/C++ ternary, much like a `select`
+    operation. The first argument is a `cir.bool` condition to evaluate, followed
+    by two regions to execute (true or false). This is different from `cir.if`
+    since each region is one block sized and the `cir.yield` closing the block
+    scope should have one argument.
+
+    Example:
+
+    ```mlir
+    // x = cond ? a : b;
+
+    %x = cir.ternary (%cond, true_region {
+      ...
+      cir.yield %a : i32
+    }, false_region {
+      ...
+      cir.yield %b : i32
+    }) -> i32
+    ```
+  }];
+  let arguments = (ins CIR_BoolType:$cond);
+  let regions = (region AnyRegion:$trueRegion,
+                        AnyRegion:$falseRegion);
+  let results = (outs Optional<CIR_AnyType>:$result);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "mlir::Value":$cond,
+      "llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$trueBuilder,
+      "llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$falseBuilder)
+      >
+  ];
+
+  // All constraints already verified elsewhere.
+  let hasVerifier = 0;
+
+  let assemblyFormat = [{
+    `(` $cond `,`
+      `true` $trueRegion `,`
+      `false` $falseRegion
+    `)` `:` functional-type(operands, results) attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // GlobalOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 89daf20c5f478..e80d243cb396f 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1058,6 +1058,48 @@ LogicalResult cir::BinOp::verify() {
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// TernaryOp
+//===----------------------------------------------------------------------===//
+
+/// Appends to `regions` the successors that control may reach from `point`.
+/// When `point` is the `true` or `false` region, the only successor is the
+/// parent operation, described by its results. When `point` is the parent
+/// operation itself, both regions are reported; the value of the condition is
+/// not inspected here, so neither region is ever ruled out.
+void cir::TernaryOp::getSuccessorRegions(
+    mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  // Leaving either region hands control back to the parent operation.
+  if (!point.isParent()) {
+    regions.push_back(RegionSuccessor(this->getODSResults(0)));
+    return;
+  }
+
+  // Entering from the parent: either region may run, so list both.
+  regions.push_back(RegionSuccessor(&getTrueRegion()));
+  regions.push_back(RegionSuccessor(&getFalseRegion()));
+}
+
+void cir::TernaryOp::build(
+    OpBuilder &builder, OperationState &result, Value cond,
+    function_ref<void(OpBuilder &, Location)> trueBuilder,
+    function_ref<void(OpBuilder &, Location)> falseBuilder) {
+  result.addOperands(cond); // The condition is the op's sole operand.
+  OpBuilder::InsertionGuard guard(builder);
+  Region *trueRegion = result.addRegion();
+  Block *block = builder.createBlock(trueRegion); // Entry block of true arm.
+  trueBuilder(builder, result.location);
+  Region *falseRegion = result.addRegion();
+  builder.createBlock(falseRegion);
+  falseBuilder(builder, result.location);
+  // The true arm's terminator, asserted to be a yield, fixes the result type.
+  auto yield = dyn_cast<YieldOp>(block->getTerminator());
+  assert((yield && yield.getNumOperands() <= 1) &&
+         "expected zero or one result type");
+  if (yield.getNumOperands() == 1) // No operand means no result.
+    result.addTypes(TypeRange{yield.getOperandTypes().front()});
+}
+
 //===----------------------------------------------------------------------===//
 // ShiftOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
index 72ccfa8d4e14e..295fa748b1624 100644
--- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
@@ -254,10 +254,61 @@ class CIRLoopOpInterfaceFlattening
   }
 };
 
+class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
+public:
+  using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;
+  // Rewrites a cir.ternary into blocks linked by cir.brcond and cir.br.
+  mlir::LogicalResult
+  matchAndRewrite(cir::TernaryOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Block *condBlock = rewriter.getInsertionBlock();
+    Block::iterator opPosition = rewriter.getInsertionPoint();
+    Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
+    llvm::SmallVector<mlir::Location, 2> locs;
+    // Ternary result is optional, make sure to populate the location only
+    // when relevant.
+    if (op->getResultTypes().size())
+      locs.push_back(loc);
+    auto *continueBlock =
+        rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
+    rewriter.create<cir::BrOp>(loc, remainingOpsBlock);
+
+    Region &trueRegion = op.getTrueRegion();
+    Block *trueBlock = &trueRegion.front();
+    mlir::Operation *trueTerminator = trueRegion.back().getTerminator();
+    rewriter.setInsertionPointToEnd(&trueRegion.back());
+    auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator);
+    // The yielded values become the operands of a branch to continueBlock.
+    rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
+                                           continueBlock);
+    rewriter.inlineRegionBefore(trueRegion, continueBlock);
+    // NOTE(review): this initial value is reassigned before it is ever read.
+    Block *falseBlock = continueBlock;
+    Region &falseRegion = op.getFalseRegion();
+
+    falseBlock = &falseRegion.front();
+    mlir::Operation *falseTerminator = falseRegion.back().getTerminator();
+    rewriter.setInsertionPointToEnd(&falseRegion.back());
+    cir::YieldOp falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator);
+    rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, 
falseYieldOp.getArgs(),
+                                           continueBlock);
+    rewriter.inlineRegionBefore(falseRegion, continueBlock);
+    // End the original block with a conditional branch into the two arms.
+    rewriter.setInsertionPointToEnd(condBlock);
+    rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock);
+    // Users of the ternary's result now read continueBlock's arguments.
+    rewriter.replaceOp(op, continueBlock->getArguments());
+
+    // Ok, we're done!
+    return mlir::success();
+  }
+};
+
 void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
-  patterns
-      .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening>(
-          patterns.getContext());
+  patterns.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening,
+               CIRScopeOpFlattening, CIRTernaryOpFlattening>(
+      patterns.getContext());
 }
 
 void CIRFlattenCFGPass::runOnOperation() {
@@ -269,9 +320,8 @@ void CIRFlattenCFGPass::runOnOperation() {
   getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
     assert(!cir::MissingFeatures::ifOp());
     assert(!cir::MissingFeatures::switchOp());
-    assert(!cir::MissingFeatures::ternaryOp());
     assert(!cir::MissingFeatures::tryOp());
-    if (isa<IfOp, ScopeOp, LoopOpInterface>(op))
+    if (isa<IfOp, ScopeOp, LoopOpInterface, TernaryOp>(op))
       ops.push_back(op);
   });
 
diff --git a/clang/test/CIR/IR/ternary.cir b/clang/test/CIR/IR/ternary.cir
new file mode 100644
index 0000000000000..3827dc77726df
--- /dev/null
+++ b/clang/test/CIR/IR/ternary.cir
@@ -0,0 +1,30 @@
+// RUN: cir-opt %s | cir-opt | FileCheck %s
+!u32i = !cir.int<u, 32>
+
+module  {
+  cir.func @blue(%arg0: !cir.bool) -> !u32i {
+    %0 = cir.ternary(%arg0, true {
+      %a = cir.const #cir.int<0> : !u32i
+      cir.yield %a : !u32i
+    }, false {
+      %b = cir.const #cir.int<1> : !u32i
+      cir.yield %b : !u32i
+    }) : (!cir.bool) -> !u32i
+    cir.return %0 : !u32i
+  }
+}
+
+// CHECK: module  {
+
+// CHECK: cir.func @blue(%arg0: !cir.bool) -> !u32i {
+// CHECK:   %0 = cir.ternary(%arg0, true {
+// CHECK:     %1 = cir.const #cir.int<0> : !u32i
+// CHECK:     cir.yield %1 : !u32i
+// CHECK:   }, false {
+// CHECK:     %1 = cir.const #cir.int<1> : !u32i
+// CHECK:     cir.yield %1 : !u32i
+// CHECK:   }) : (!cir.bool) -> !u32i
+// CHECK:   cir.return %0 : !u32i
+// CHECK: }
+
+// CHECK: }
diff --git a/clang/test/CIR/Lowering/ternary.cir b/clang/test/CIR/Lowering/ternary.cir
new file mode 100644
index 0000000000000..247c6ae3a1e17
--- /dev/null
+++ b/clang/test/CIR/Lowering/ternary.cir
@@ -0,0 +1,30 @@
+// RUN: cir-translate -cir-to-llvmir --disable-cc-lowering -o %t.ll %s
+// RUN: FileCheck --input-file=%t.ll -check-prefix=LLVM %s
+
+!u32i = !cir.int<u, 32>
+
+module  {
+  cir.func @blue(%arg0: !cir.bool) -> !u32i {
+    %0 = cir.ternary(%arg0, true {
+      %a = cir.const #cir.int<0> : !u32i
+      cir.yield %a : !u32i
+    }, false {
+      %b = cir.const #cir.int<1> : !u32i
+      cir.yield %b : !u32i
+    }) : (!cir.bool) -> !u32i
+    cir.return %0 : !u32i
+  }
+}
+
+// LLVM-LABEL: define i32 {{.*}}@blue(
+// LLVM-SAME: i1 [[PRED:%[[:alnum:]]+]])
+// LLVM:   br i1 [[PRED]], label %[[B1:[[:alnum:]]+]], label %[[B2:[[:alnum:]]+]]
+// LLVM: [[B1]]:
+// LLVM:   br label %[[M:[[:alnum:]]+]]
+// LLVM: [[B2]]:
+// LLVM:   br label %[[M]]
+// LLVM: [[M]]:
+// LLVM:   [[R:%[[:alnum:]]+]] = phi i32 [ 1, %[[B2]] ], [ 0, %[[B1]] ]
+// LLVM:   br label %[[B3:[[:alnum:]]+]]
+// LLVM: [[B3]]:
+// LLVM:   ret i32 [[R]]
diff --git a/clang/test/CIR/Transforms/ternary.cir b/clang/test/CIR/Transforms/ternary.cir
new file mode 100644
index 0000000000000..67ef7f95a6b52
--- /dev/null
+++ b/clang/test/CIR/Transforms/ternary.cir
@@ -0,0 +1,68 @@
+// RUN: cir-opt %s -cir-flatten-cfg -o - | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+  cir.func @foo(%arg0: !s32i) -> !s32i {
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
+    %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+    %3 = cir.const #cir.int<0> : !s32i
+    %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool
+    %5 = cir.ternary(%4, true {
+      %7 = cir.const #cir.int<3> : !s32i
+      cir.yield %7 : !s32i
+    }, false {
+      %7 = cir.const #cir.int<5> : !s32i
+      cir.yield %7 : !s32i
+    }) : (!cir.bool) -> !s32i
+    cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
+    %6 = cir.load %1 : !cir.ptr<!s32i>, !s32i
+    cir.return %6 : !s32i
+  }
+
+// CHECK: cir.func @foo(%arg0: !s32i) -> !s32i {
+// CHECK:   %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
+// CHECK:   %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
+// CHECK:   cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+// CHECK:   %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+// CHECK:   %3 = cir.const #cir.int<0> : !s32i
+// CHECK:   %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool
+// CHECK:    cir.brcond %4 ^bb1, ^bb2
+// CHECK:  ^bb1:  // pred: ^bb0
+// CHECK:    %5 = cir.const #cir.int<3> : !s32i
+// CHECK:    cir.br ^bb3(%5 : !s32i)
+// CHECK:  ^bb2:  // pred: ^bb0
+// CHECK:    %6 = cir.const #cir.int<5> : !s32i
+// CHECK:    cir.br ^bb3(%6 : !s32i)
+// CHECK:  ^bb3(%7: !s32i):  // 2 preds: ^bb1, ^bb2
+// CHECK:    cir.br ^bb4
+// CHECK:  ^bb4:  // pred: ^bb3
+// CHECK:    cir.store %7, %1 : !s32i, !cir.ptr<!s32i>
+// CHECK:    %8 = cir.load %1 : !cir.ptr<!s32i>, !s32i
+// CHECK:    cir.return %8 : !s32i
+// CHECK:  }
+
+  cir.func @foo2(%arg0: !cir.bool) {
+    cir.ternary(%arg0, true {
+      cir.yield
+    }, false {
+      cir.yield
+    }) : (!cir.bool) -> ()
+    cir.return
+  }
+
+// CHECK: cir.func @foo2(%arg0: !cir.bool) {
+// CHECK:   cir.brcond %arg0 ^bb1, ^bb2
+// CHECK: ^bb1:  // pred: ^bb0
+// CHECK:   cir.br ^bb3
+// CHECK: ^bb2:  // pred: ^bb0
+// CHECK:   cir.br ^bb3
+// CHECK: ^bb3:  // 2 preds: ^bb1, ^bb2
+// CHECK:   cir.br ^bb4
+// CHECK: ^bb4:  // pred: ^bb3
+// CHECK:   cir.return
+// CHECK: }
+
+}

>From 3e9d9d35b52c0b69ac9950f53cad044a958a81d4 Mon Sep 17 00:00:00 2001
From: Morris Hafner <mhaf...@nvidia.com>
Date: Thu, 24 Apr 2025 17:08:17 +0200
Subject: [PATCH 2/3] Reorder YieldOp parents lexicographically

---
 clang/include/clang/CIR/Dialect/IR/CIROps.td | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 76ad5c3666c1b..760149636b23b 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -609,8 +609,8 @@ def ConditionOp : CIR_Op<"condition", [
 //===----------------------------------------------------------------------===//
 
 def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
-                               ParentOneOf<["IfOp", "TernaryOp", "ScopeOp",
-                                            "WhileOp", "ForOp", "DoWhileOp"]>]> {
+                               ParentOneOf<["DoWhileOp", "ForOp", "WhileOp",
+                                            "IfOp", "ScopeOp", "TernaryOp"]>]> {
   let summary = "Represents the default branching behaviour of a region";
   let description = [{
     The `cir.yield` operation terminates regions on different CIR operations,

>From c98ee8457506779451c0f6a531c14c151e8704fd Mon Sep 17 00:00:00 2001
From: Morris Hafner <m...@users.noreply.github.com>
Date: Fri, 25 Apr 2025 15:57:09 +0200
Subject: [PATCH 3/3] Apply suggestions from code review

Co-authored-by: Andy Kaylor <akay...@nvidia.com>
---
 clang/include/clang/CIR/Dialect/IR/CIROps.td    | 2 +-
 clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 760149636b23b..a01fb0aa60844 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1264,7 +1264,7 @@ def TernaryOp : CIR_Op<"ternary",
     Example:
 
     ```mlir
-    // x = cond ? a : b;
+    // cond = a && b;
 
     %x = cir.ternary (%cond, true_region {
       ...
diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
index 295fa748b1624..4a936d33b022a 100644
--- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
@@ -270,7 +270,7 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
     // when relevant.
     if (op->getResultTypes().size())
       locs.push_back(loc);
-    auto *continueBlock =
+    Block *continueBlock =
         rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
     rewriter.create<cir::BrOp>(loc, remainingOpsBlock);
 
@@ -290,7 +290,7 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
     falseBlock = &falseRegion.front();
     mlir::Operation *falseTerminator = falseRegion.back().getTerminator();
     rewriter.setInsertionPointToEnd(&falseRegion.back());
-    cir::YieldOp falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator);
+    auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator);
     rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(),
                                            continueBlock);
     rewriter.inlineRegionBefore(falseRegion, continueBlock);

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to