llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clangir

Author: Morris Hafner (mmha)

<details>
<summary>Changes</summary>

This patch adds TernaryOp to CIR, plus a rewrite pattern in the FlattenCFG pass 
that flattens the operation.

This is the first of (probably) three PRs for TernaryOp. I split the patches up 
to make reviewing them easier. As such, this PR is only about adding the CIR 
operation. The next PR will add the CodeGen for the C++ conditional operator, 
and the final one will add the cir-simplify transform for TernaryOp and 
SelectOp.

---
Full diff: https://github.com/llvm/llvm-project/pull/137184.diff


6 Files Affected:

- (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+55-2) 
- (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+42) 
- (modified) clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp (+55-5) 
- (added) clang/test/CIR/IR/ternary.cir (+30) 
- (added) clang/test/CIR/Lowering/ternary.cir (+30) 
- (added) clang/test/CIR/Transforms/ternary.cir (+68) 


``````````diff
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 81b447f31feca..76ad5c3666c1b 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -609,8 +609,8 @@ def ConditionOp : CIR_Op<"condition", [
 //===----------------------------------------------------------------------===//
 
 def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
-                               ParentOneOf<["IfOp", "ScopeOp", "WhileOp",
-                                            "ForOp", "DoWhileOp"]>]> {
+                               ParentOneOf<["IfOp", "TernaryOp", "ScopeOp",
+                                            "WhileOp", "ForOp", "DoWhileOp"]>]> {
   let summary = "Represents the default branching behaviour of a region";
   let description = [{
     The `cir.yield` operation terminates regions on different CIR operations,
@@ -1246,6 +1246,59 @@ def SelectOp : CIR_Op<"select", [Pure,
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TernaryOp
+//===----------------------------------------------------------------------===//
+
+def TernaryOp : CIR_Op<"ternary",
+      [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+       RecursivelySpeculatable, AutomaticAllocationScope, NoRegionArguments]> {
+  let summary = "The `cond ? a : b` C/C++ ternary operation";
+  let description = [{
+    The `cir.ternary` operation represents C/C++ ternary, much like a `select`
+    operation. The first argument is a `cir.bool` condition to evaluate, followed
+    by two regions to execute (true or false). This is different from `cir.if`
+    since each region is one block sized and the `cir.yield` closing the block
+    scope should have one argument.
+
+    Example:
+
+    ```mlir
+    // x = cond ? a : b;
+
+    %x = cir.ternary (%cond, true_region {
+      ...
+      cir.yield %a : i32
+    }, false_region {
+      ...
+      cir.yield %b : i32
+    }) -> i32
+    ```
+  }];
+  let arguments = (ins CIR_BoolType:$cond);
+  let regions = (region AnyRegion:$trueRegion,
+                        AnyRegion:$falseRegion);
+  let results = (outs Optional<CIR_AnyType>:$result);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "mlir::Value":$cond,
+      "llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$trueBuilder,
+      "llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$falseBuilder)
+      >
+  ];
+
+  // All constraints already verified elsewhere.
+  let hasVerifier = 0;
+
+  let assemblyFormat = [{
+    `(` $cond `,`
+      `true` $trueRegion `,`
+      `false` $falseRegion
+    `)` `:` functional-type(operands, results) attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // GlobalOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 89daf20c5f478..e80d243cb396f 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1058,6 +1058,48 @@ LogicalResult cir::BinOp::verify() {
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// TernaryOp
+//===----------------------------------------------------------------------===//
+
+/// Given the region at `index`, or the parent operation if `index` is None,
+/// return the successor regions. These are the regions that may be selected
+/// during the flow of control. `operands` is a set of optional attributes that
+/// correspond to a constant value for each operand, or null if that operand is
+/// not a constant.
+void cir::TernaryOp::getSuccessorRegions(
+    mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  // The `true` and the `false` region branch back to the parent operation.
+  if (!point.isParent()) {
+    regions.push_back(RegionSuccessor(this->getODSResults(0)));
+    return;
+  }
+
+  // If the condition isn't constant, both regions may be executed.
+  regions.push_back(RegionSuccessor(&getTrueRegion()));
+  regions.push_back(RegionSuccessor(&getFalseRegion()));
+}
+
+void cir::TernaryOp::build(
+    OpBuilder &builder, OperationState &result, Value cond,
+    function_ref<void(OpBuilder &, Location)> trueBuilder,
+    function_ref<void(OpBuilder &, Location)> falseBuilder) {
+  result.addOperands(cond);
+  OpBuilder::InsertionGuard guard(builder);
+  Region *trueRegion = result.addRegion();
+  Block *block = builder.createBlock(trueRegion);
+  trueBuilder(builder, result.location);
+  Region *falseRegion = result.addRegion();
+  builder.createBlock(falseRegion);
+  falseBuilder(builder, result.location);
+
+  auto yield = dyn_cast<YieldOp>(block->getTerminator());
+  assert((yield && yield.getNumOperands() <= 1) &&
+         "expected zero or one result type");
+  if (yield.getNumOperands() == 1)
+    result.addTypes(TypeRange{yield.getOperandTypes().front()});
+}
+
 //===----------------------------------------------------------------------===//
 // ShiftOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
index 72ccfa8d4e14e..295fa748b1624 100644
--- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
@@ -254,10 +254,61 @@ class CIRLoopOpInterfaceFlattening
   }
 };
 
+class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
+public:
+  using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(cir::TernaryOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Block *condBlock = rewriter.getInsertionBlock();
+    Block::iterator opPosition = rewriter.getInsertionPoint();
+    Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
+    llvm::SmallVector<mlir::Location, 2> locs;
+    // Ternary result is optional, make sure to populate the location only
+    // when relevant.
+    if (op->getResultTypes().size())
+      locs.push_back(loc);
+    auto *continueBlock =
+        rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
+    rewriter.create<cir::BrOp>(loc, remainingOpsBlock);
+
+    Region &trueRegion = op.getTrueRegion();
+    Block *trueBlock = &trueRegion.front();
+    mlir::Operation *trueTerminator = trueRegion.back().getTerminator();
+    rewriter.setInsertionPointToEnd(&trueRegion.back());
+    auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator);
+
+    rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
+                                           continueBlock);
+    rewriter.inlineRegionBefore(trueRegion, continueBlock);
+
+    Block *falseBlock = continueBlock;
+    Region &falseRegion = op.getFalseRegion();
+
+    falseBlock = &falseRegion.front();
+    mlir::Operation *falseTerminator = falseRegion.back().getTerminator();
+    rewriter.setInsertionPointToEnd(&falseRegion.back());
+    cir::YieldOp falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator);
+    rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(),
+                                           continueBlock);
+    rewriter.inlineRegionBefore(falseRegion, continueBlock);
+
+    rewriter.setInsertionPointToEnd(condBlock);
+    rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock);
+
+    rewriter.replaceOp(op, continueBlock->getArguments());
+
+    // Ok, we're done!
+    return mlir::success();
+  }
+};
+
 void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
-  patterns
-      .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening>(
-          patterns.getContext());
+  patterns.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening,
+               CIRScopeOpFlattening, CIRTernaryOpFlattening>(
+      patterns.getContext());
 }
 
 void CIRFlattenCFGPass::runOnOperation() {
@@ -269,9 +320,8 @@ void CIRFlattenCFGPass::runOnOperation() {
   getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
     assert(!cir::MissingFeatures::ifOp());
     assert(!cir::MissingFeatures::switchOp());
-    assert(!cir::MissingFeatures::ternaryOp());
     assert(!cir::MissingFeatures::tryOp());
-    if (isa<IfOp, ScopeOp, LoopOpInterface>(op))
+    if (isa<IfOp, ScopeOp, LoopOpInterface, TernaryOp>(op))
       ops.push_back(op);
   });
 
diff --git a/clang/test/CIR/IR/ternary.cir b/clang/test/CIR/IR/ternary.cir
new file mode 100644
index 0000000000000..3827dc77726df
--- /dev/null
+++ b/clang/test/CIR/IR/ternary.cir
@@ -0,0 +1,30 @@
+// RUN: cir-opt %s | cir-opt | FileCheck %s
+!u32i = !cir.int<u, 32>
+
+module  {
+  cir.func @blue(%arg0: !cir.bool) -> !u32i {
+    %0 = cir.ternary(%arg0, true {
+      %a = cir.const #cir.int<0> : !u32i
+      cir.yield %a : !u32i
+    }, false {
+      %b = cir.const #cir.int<1> : !u32i
+      cir.yield %b : !u32i
+    }) : (!cir.bool) -> !u32i
+    cir.return %0 : !u32i
+  }
+}
+
+// CHECK: module  {
+
+// CHECK: cir.func @blue(%arg0: !cir.bool) -> !u32i {
+// CHECK:   %0 = cir.ternary(%arg0, true {
+// CHECK:     %1 = cir.const #cir.int<0> : !u32i
+// CHECK:     cir.yield %1 : !u32i
+// CHECK:   }, false {
+// CHECK:     %1 = cir.const #cir.int<1> : !u32i
+// CHECK:     cir.yield %1 : !u32i
+// CHECK:   }) : (!cir.bool) -> !u32i
+// CHECK:   cir.return %0 : !u32i
+// CHECK: }
+
+// CHECK: }
diff --git a/clang/test/CIR/Lowering/ternary.cir b/clang/test/CIR/Lowering/ternary.cir
new file mode 100644
index 0000000000000..247c6ae3a1e17
--- /dev/null
+++ b/clang/test/CIR/Lowering/ternary.cir
@@ -0,0 +1,30 @@
+// RUN: cir-translate -cir-to-llvmir --disable-cc-lowering -o %t.ll %s
+// RUN: FileCheck --input-file=%t.ll -check-prefix=LLVM %s
+
+!u32i = !cir.int<u, 32>
+
+module  {
+  cir.func @blue(%arg0: !cir.bool) -> !u32i {
+    %0 = cir.ternary(%arg0, true {
+      %a = cir.const #cir.int<0> : !u32i
+      cir.yield %a : !u32i
+    }, false {
+      %b = cir.const #cir.int<1> : !u32i
+      cir.yield %b : !u32i
+    }) : (!cir.bool) -> !u32i
+    cir.return %0 : !u32i
+  }
+}
+
+// LLVM-LABEL: define i32 {{.*}}@blue(
+// LLVM-SAME: i1 [[PRED:%[[:alnum:]]+]])
+// LLVM:   br i1 [[PRED]], label %[[B1:[[:alnum:]]+]], label %[[B2:[[:alnum:]]+]]
+// LLVM: [[B1]]:
+// LLVM:   br label %[[M:[[:alnum:]]+]]
+// LLVM: [[B2]]:
+// LLVM:   br label %[[M]]
+// LLVM: [[M]]:
+// LLVM:   [[R:%[[:alnum:]]+]] = phi i32 [ 1, %[[B2]] ], [ 0, %[[B1]] ]
+// LLVM:   br label %[[B3:[[:alnum:]]+]]
+// LLVM: [[B3]]:
+// LLVM:   ret i32 [[R]]
diff --git a/clang/test/CIR/Transforms/ternary.cir b/clang/test/CIR/Transforms/ternary.cir
new file mode 100644
index 0000000000000..67ef7f95a6b52
--- /dev/null
+++ b/clang/test/CIR/Transforms/ternary.cir
@@ -0,0 +1,68 @@
+// RUN: cir-opt %s -cir-flatten-cfg -o - | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+  cir.func @foo(%arg0: !s32i) -> !s32i {
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
+    %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+    %3 = cir.const #cir.int<0> : !s32i
+    %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool
+    %5 = cir.ternary(%4, true {
+      %7 = cir.const #cir.int<3> : !s32i
+      cir.yield %7 : !s32i
+    }, false {
+      %7 = cir.const #cir.int<5> : !s32i
+      cir.yield %7 : !s32i
+    }) : (!cir.bool) -> !s32i
+    cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
+    %6 = cir.load %1 : !cir.ptr<!s32i>, !s32i
+    cir.return %6 : !s32i
+  }
+
+// CHECK: cir.func @foo(%arg0: !s32i) -> !s32i {
+// CHECK:   %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
+// CHECK:   %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
+// CHECK:   cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+// CHECK:   %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+// CHECK:   %3 = cir.const #cir.int<0> : !s32i
+// CHECK:   %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool
+// CHECK:    cir.brcond %4 ^bb1, ^bb2
+// CHECK:  ^bb1:  // pred: ^bb0
+// CHECK:    %5 = cir.const #cir.int<3> : !s32i
+// CHECK:    cir.br ^bb3(%5 : !s32i)
+// CHECK:  ^bb2:  // pred: ^bb0
+// CHECK:    %6 = cir.const #cir.int<5> : !s32i
+// CHECK:    cir.br ^bb3(%6 : !s32i)
+// CHECK:  ^bb3(%7: !s32i):  // 2 preds: ^bb1, ^bb2
+// CHECK:    cir.br ^bb4
+// CHECK:  ^bb4:  // pred: ^bb3
+// CHECK:    cir.store %7, %1 : !s32i, !cir.ptr<!s32i>
+// CHECK:    %8 = cir.load %1 : !cir.ptr<!s32i>, !s32i
+// CHECK:    cir.return %8 : !s32i
+// CHECK:  }
+
+  cir.func @foo2(%arg0: !cir.bool) {
+    cir.ternary(%arg0, true {
+      cir.yield
+    }, false {
+      cir.yield
+    }) : (!cir.bool) -> ()
+    cir.return
+  }
+
+// CHECK: cir.func @foo2(%arg0: !cir.bool) {
+// CHECK:   cir.brcond %arg0 ^bb1, ^bb2
+// CHECK: ^bb1:  // pred: ^bb0
+// CHECK:   cir.br ^bb3
+// CHECK: ^bb2:  // pred: ^bb0
+// CHECK:   cir.br ^bb3
+// CHECK: ^bb3:  // 2 preds: ^bb1, ^bb2
+// CHECK:   cir.br ^bb4
+// CHECK: ^bb4:  // pred: ^bb3
+// CHECK:   cir.return
+// CHECK: }
+
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/137184
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to