https://github.com/Andres-Salamanca created https://github.com/llvm/llvm-project/pull/134333
This patch adds support for if statements in the CIR dialect. Additionally, multiple RUN lines were introduced to improve codegen test coverage. >From 89f0f528f981223273b2c1548c9a71f2ceeca329 Mon Sep 17 00:00:00 2001 From: Andres Salamanca <andrealebarbari...@gmail.com> Date: Thu, 3 Apr 2025 12:07:25 -0500 Subject: [PATCH] [CIR] Add if statement support Upstream if statement support Formatted code and added test cases. added multiple RUN lines for the codegen test --- .../include/clang/CIR/Dialect/IR/CIRDialect.h | 4 + clang/include/clang/CIR/Dialect/IR/CIROps.td | 60 ++++- clang/include/clang/CIR/MissingFeatures.h | 4 + clang/lib/CIR/CodeGen/CIRGenExpr.cpp | 100 +++++++ clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp | 14 + clang/lib/CIR/CodeGen/CIRGenFunction.cpp | 49 ++++ clang/lib/CIR/CodeGen/CIRGenFunction.h | 34 +++ clang/lib/CIR/CodeGen/CIRGenStmt.cpp | 71 ++++- clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 128 +++++++++ .../lib/CIR/Dialect/Transforms/FlattenCFG.cpp | 70 ++++- clang/test/CIR/CodeGen/if.cpp | 254 ++++++++++++++++++ clang/test/CIR/Lowering/if.cir | 99 +++++++ clang/test/CIR/Transforms/if.cir | 48 ++++ 13 files changed, 926 insertions(+), 9 deletions(-) create mode 100644 clang/test/CIR/CodeGen/if.cpp create mode 100644 clang/test/CIR/Lowering/if.cir create mode 100644 clang/test/CIR/Transforms/if.cir diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.h b/clang/include/clang/CIR/Dialect/IR/CIRDialect.h index 4d7f537418a90..4d7f0bfd1c253 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.h +++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.h @@ -35,6 +35,10 @@ using BuilderCallbackRef = llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>; +namespace cir { +void buildTerminatedBody(mlir::OpBuilder &builder, mlir::Location loc); +} // namespace cir + // TableGen'erated files for MLIR dialects require that a macro be defined when // they are included.
GET_OP_CLASSES tells the file to define the classes for // the operations of that dialect. diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 3965372755685..e181a5db3e1b9 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -424,8 +424,8 @@ def StoreOp : CIR_Op<"store", [ // ReturnOp //===----------------------------------------------------------------------===// -def ReturnOp : CIR_Op<"return", [ParentOneOf<["FuncOp", "ScopeOp", "DoWhileOp", - "WhileOp", "ForOp"]>, +def ReturnOp : CIR_Op<"return", [ParentOneOf<["FuncOp", "ScopeOp", "IfOp", + "DoWhileOp", "WhileOp", "ForOp"]>, Terminator]> { let summary = "Return from function"; let description = [{ @@ -462,6 +462,58 @@ def ReturnOp : CIR_Op<"return", [ParentOneOf<["FuncOp", "ScopeOp", "DoWhileOp", let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// IfOp +//===----------------------------------------------------------------------===// + +def IfOp : CIR_Op<"if", + [DeclareOpInterfaceMethods<RegionBranchOpInterface>, + RecursivelySpeculatable, AutomaticAllocationScope, NoRegionArguments]>{ + + let summary = "the if-then-else operation"; + let description = [{ + The `cir.if` operation represents an if-then-else construct for + conditionally executing two regions of code. The operand is a `cir.bool` + type. + + Examples: + + ```mlir + cir.if %b { + ... + } else { + ... + } + + cir.if %c { + ... + } + + cir.if %c { + ... + cir.br ^a + ^a: + cir.yield + } + ``` + + `cir.if` defines no values and the 'else' can be omitted. The if/else + regions must be terminated. If the region has only one block, the terminator + can be left out, and `cir.yield` terminator will be inserted implictly. + Otherwise, the region must be explicitly terminated. 
+ }]; + let arguments = (ins CIR_BoolType:$condition); + let regions = (region AnyRegion:$thenRegion, AnyRegion:$elseRegion); + let hasCustomAssemblyFormat=1; + let hasVerifier=1; + let skipDefaultBuilders=1; + let builders = [ + OpBuilder<(ins "mlir::Value":$cond, "bool":$withElseRegion, + CArg<"BuilderCallbackRef", "buildTerminatedBody">:$thenBuilder, + CArg<"BuilderCallbackRef", "nullptr">:$elseBuilder)> + ]; +} + //===----------------------------------------------------------------------===// // ConditionOp //===----------------------------------------------------------------------===// @@ -512,8 +564,8 @@ def ConditionOp : CIR_Op<"condition", [ //===----------------------------------------------------------------------===// def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator, - ParentOneOf<["ScopeOp", "WhileOp", "ForOp", - "DoWhileOp"]>]> { + ParentOneOf<["IfOp", "ScopeOp", "WhileOp", + "ForOp", "DoWhileOp"]>]> { let summary = "Represents the default branching behaviour of a region"; let description = [{ The `cir.yield` operation terminates regions on different CIR operations, diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h index 3a102d90aba8f..1d53d094fa4e7 100644 --- a/clang/include/clang/CIR/MissingFeatures.h +++ b/clang/include/clang/CIR/MissingFeatures.h @@ -81,6 +81,7 @@ struct MissingFeatures { // Clang early optimizations or things defered to LLVM lowering. 
static bool mayHaveIntegerOverflow() { return false; } + static bool shouldReverseUnaryCondOnBoolExpr() { return false; } // Misc static bool cxxABI() { return false; } @@ -109,6 +110,9 @@ struct MissingFeatures { static bool cgFPOptionsRAII() { return false; } static bool metaDataNode() { return false; } static bool fastMathFlags() { return false; } + static bool constantFoldsToSimpleInteger() { return false; } + static bool incrementProfileCounter() { return false; } + static bool insertBuiltinUnpredictable() { return false; } // Missing types static bool dataMemberType() { return false; } diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp index f01e03a89981d..a12ec878e3656 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp @@ -316,6 +316,106 @@ void CIRGenFunction::emitIgnoredExpr(const Expr *e) { emitLValue(e); } +/// Emit an `if` on a boolean condition, filling `then` and `else` into +/// appropriated regions. +mlir::LogicalResult CIRGenFunction::emitIfOnBoolExpr(const Expr *cond, + const Stmt *thenS, + const Stmt *elseS) { + // Attempt to be more accurate as possible with IfOp location, generate + // one fused location that has either 2 or 4 total locations, depending + // on else's availability. 
+ auto getStmtLoc = [this](const Stmt &s) { + return mlir::FusedLoc::get(&getMLIRContext(), + {getLoc(s.getSourceRange().getBegin()), + getLoc(s.getSourceRange().getEnd())}); + }; + mlir::Location thenLoc = getStmtLoc(*thenS); + std::optional<mlir::Location> elseLoc; + if (elseS) + elseLoc = getStmtLoc(*elseS); + + mlir::LogicalResult resThen = mlir::success(), resElse = mlir::success(); + emitIfOnBoolExpr( + cond, /*thenBuilder=*/ + [&](mlir::OpBuilder &, mlir::Location) { + LexicalScope lexScope{*this, thenLoc, builder.getInsertionBlock()}; + resThen = emitStmt(thenS, /*useCurrentScope=*/true); + }, + thenLoc, + /*elseBuilder=*/ + [&](mlir::OpBuilder &, mlir::Location) { + assert(elseLoc && "Invalid location for elseS."); + LexicalScope lexScope{*this, *elseLoc, builder.getInsertionBlock()}; + resElse = emitStmt(elseS, /*useCurrentScope=*/true); + }, + elseLoc); + + return mlir::LogicalResult::success(resThen.succeeded() && + resElse.succeeded()); +} + +/// Emit an `if` on a boolean condition, filling `then` and `else` into +/// appropriated regions. +cir::IfOp CIRGenFunction::emitIfOnBoolExpr( + const clang::Expr *cond, BuilderCallbackRef thenBuilder, + mlir::Location thenLoc, BuilderCallbackRef elseBuilder, + std::optional<mlir::Location> elseLoc) { + + SmallVector<mlir::Location, 2> ifLocs{thenLoc}; + if (elseLoc) + ifLocs.push_back(*elseLoc); + mlir::Location loc = mlir::FusedLoc::get(&getMLIRContext(), ifLocs); + + // Emit the code with the fully general case. + mlir::Value condV = emitOpOnBoolExpr(loc, cond); + return builder.create<cir::IfOp>(loc, condV, elseLoc.has_value(), + /*thenBuilder=*/thenBuilder, + /*elseBuilder=*/elseBuilder); +} + +/// TODO(cir): PGO data +/// TODO(cir): see EmitBranchOnBoolExpr for extra ideas). +mlir::Value CIRGenFunction::emitOpOnBoolExpr(mlir::Location loc, + const Expr *cond) { + // TODO(CIR): scoped ApplyDebugLocation DL(*this, Cond); + // TODO(CIR): __builtin_unpredictable and profile counts? 
+ cond = cond->IgnoreParens(); + + // if (const BinaryOperator *CondBOp = dyn_cast<BinaryOperator>(cond)) { + // llvm_unreachable("binaryoperator ifstmt NYI"); + // } + + if (const UnaryOperator *CondUOp = dyn_cast<UnaryOperator>(cond)) { + // In LLVM the condition is reversed here for efficient codegen. + // This should be done in CIR prior to LLVM lowering, if we do now + // we can make CIR based diagnostics misleading. + // cir.ternary(!x, t, f) -> cir.ternary(x, f, t) + assert(!cir::MissingFeatures::shouldReverseUnaryCondOnBoolExpr()); + } + + if (const ConditionalOperator *CondOp = dyn_cast<ConditionalOperator>(cond)) { + + cgm.errorNYI(cond->getExprLoc(), "Ternary NYI"); + assert(!cir::MissingFeatures::ternaryOp()); + return createDummyValue(loc, cond->getType()); + } + + // if (const CXXThrowExpr *Throw = dyn_cast<CXXThrowExpr>(cond)) { + // llvm_unreachable("NYI"); + // } + + // If the branch has a condition wrapped by __builtin_unpredictable, + // create metadata that specifies that the branch is unpredictable. + // Don't bother if not optimizing because that metadata would not be used. + auto *Call = dyn_cast<CallExpr>(cond->IgnoreImpCasts()); + if (Call && cgm.getCodeGenOpts().OptimizationLevel != 0) { + assert(!cir::MissingFeatures::insertBuiltinUnpredictable()); + } + + // Emit the code with the fully general case. 
+ return evaluateExprAsBool(cond); +} + mlir::Value CIRGenFunction::emitAlloca(StringRef name, mlir::Type ty, mlir::Location loc, CharUnits alignment, bool insertIntoFnEntryBlock, diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index 2cf92dfbf3a5b..5d85bc6267e8e 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -1358,6 +1358,20 @@ mlir::Value CIRGenFunction::emitScalarConversion(mlir::Value src, .emitScalarConversion(src, srcTy, dstTy, loc); } +/// If the specified expression does not fold +/// to a constant, or if it does but contains a label, return false. If it +/// constant folds return true and set the boolean result in Result. +bool CIRGenFunction::ConstantFoldsToSimpleInteger(const Expr *Cond, + bool &ResultBool, + bool AllowLabels) { + llvm::APSInt ResultInt; + if (!ConstantFoldsToSimpleInteger(Cond, ResultInt, AllowLabels)) + return false; + + ResultBool = ResultInt.getBoolValue(); + return true; +} + /// Return the size or alignment of the type of argument of the sizeof /// expression as an integer. mlir::Value ScalarExprEmitter::VisitUnaryExprOrTypeTraitExpr( diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp index 47fc90836fca6..6510ce7985ead 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp @@ -135,6 +135,55 @@ mlir::Location CIRGenFunction::getLoc(mlir::Location lhs, mlir::Location rhs) { return mlir::FusedLoc::get(locs, metadata, &getMLIRContext()); } +bool CIRGenFunction::ContainsLabel(const Stmt *s, bool ignoreCaseStmts) { + // Null statement, not a label! + if (!s) + return false; + + // If this is a label, we have to emit the code, consider something like: + // if (0) { ... foo: bar(); } goto foo; + // + // TODO: If anyone cared, we could track __label__'s, since we know that you + // can't jump to one from outside their declared region. 
+ if (isa<LabelStmt>(s)) + return true; + + // If this is a case/default statement, and we haven't seen a switch, we + // have to emit the code. + if (isa<SwitchCase>(s) && !ignoreCaseStmts) + return true; + + // If this is a switch statement, we want to ignore cases below it. + if (isa<SwitchStmt>(s)) + ignoreCaseStmts = true; + + // Scan subexpressions for verboten labels. + return std::any_of(s->child_begin(), s->child_end(), + [=](const Stmt *subStmt) { + return ContainsLabel(subStmt, ignoreCaseStmts); + }); +} + +/// If the specified expression does not fold +/// to a constant, or if it does but contains a label, return false. If it +/// constant folds return true and set the folded value. +bool CIRGenFunction::ConstantFoldsToSimpleInteger(const Expr *cond, + llvm::APSInt &resultInt, + bool allowLabels) { + // FIXME: Rename and handle conversion of other evaluatable things + // to bool. + Expr::EvalResult result; + if (!cond->EvaluateAsInt(result, getContext())) + return false; // Not foldable, not integer or not fully evaluatable. + + llvm::APSInt intValue = result.Val.getInt(); + if (!allowLabels && ContainsLabel(cond)) + return false; // Contains a label. + + resultInt = intValue; + return true; +} + void CIRGenFunction::emitAndUpdateRetAlloca(QualType type, mlir::Location loc, CharUnits alignment) { if (!type->isVoidType()) { diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h index 5cae4d5da9516..15b25d8a81522 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunction.h +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h @@ -24,6 +24,7 @@ #include "clang/AST/ASTContext.h" #include "clang/AST/CharUnits.h" #include "clang/AST/Decl.h" +#include "clang/AST/Stmt.h" #include "clang/AST/Type.h" #include "clang/CIR/Dialect/IR/CIRDialect.h" #include "clang/CIR/MissingFeatures.h" @@ -164,6 +165,20 @@ class CIRGenFunction : public CIRGenTypeCache { /// that it requires no code to be generated. 
bool isTrivialInitializer(const Expr *init); + /// If the specified expression does not fold to a constant, or if it does but + /// contains a label, return false. If it constant folds return true and set + /// the boolean result in Result. + bool ConstantFoldsToSimpleInteger(const clang::Expr *Cond, bool &ResultBool, + bool AllowLabels = false); + bool ConstantFoldsToSimpleInteger(const clang::Expr *Cond, + llvm::APSInt &ResultInt, + bool AllowLabels = false); + + /// Return true if the statement contains a label in it. If + /// this statement is not executed normally, it not containing a label means + /// that we can just remove the code. + bool ContainsLabel(const clang::Stmt *s, bool IgnoreCaseStmts = false); + struct AutoVarEmission { const clang::VarDecl *Variable; /// The address of the alloca for languages with explicit address space @@ -442,6 +457,25 @@ class CIRGenFunction : public CIRGenTypeCache { mlir::LogicalResult emitDeclStmt(const clang::DeclStmt &s); LValue emitDeclRefLValue(const clang::DeclRefExpr *e); + /// Emit an if on a boolean condition to the specified blocks. + /// FIXME: Based on the condition, this might try to simplify the codegen of + /// the conditional based on the branch. TrueCount should be the number of + /// times we expect the condition to evaluate to true based on PGO data. We + /// might decide to leave this as a separate pass (see EmitBranchOnBoolExpr + /// for extra ideas). + mlir::LogicalResult emitIfOnBoolExpr(const clang::Expr *cond, + const clang::Stmt *thenS, + const clang::Stmt *elseS); + cir::IfOp emitIfOnBoolExpr(const clang::Expr *cond, + BuilderCallbackRef thenBuilder, + mlir::Location thenLoc, + BuilderCallbackRef elseBuilder, + std::optional<mlir::Location> elseLoc = {}); + + mlir::Value emitOpOnBoolExpr(mlir::Location loc, const clang::Expr *cond); + + mlir::LogicalResult emitIfStmt(const clang::IfStmt &s); + /// Emit code to compute the specified expression, /// ignoring the result. 
void emitIgnoredExpr(const clang::Expr *e); diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp index b5c1f0ae2a7ef..00a745b7196a0 100644 --- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/Builders.h" #include "clang/AST/ExprCXX.h" #include "clang/AST/Stmt.h" +#include "clang/CIR/MissingFeatures.h" using namespace clang; using namespace clang::CIRGen; @@ -72,7 +73,8 @@ mlir::LogicalResult CIRGenFunction::emitStmt(const Stmt *s, assert(outgoing && "expression emission cleared block!"); return mlir::success(); } - + case Stmt::IfStmtClass: + return emitIfStmt(cast<IfStmt>(*s)); case Stmt::ForStmtClass: return emitForStmt(cast<ForStmt>(*s)); case Stmt::WhileStmtClass: @@ -99,7 +101,6 @@ mlir::LogicalResult CIRGenFunction::emitStmt(const Stmt *s, case Stmt::CaseStmtClass: case Stmt::SEHLeaveStmtClass: case Stmt::SYCLKernelCallStmtClass: - case Stmt::IfStmtClass: case Stmt::SwitchStmtClass: case Stmt::CoroutineBodyStmtClass: case Stmt::CoreturnStmtClass: @@ -263,6 +264,72 @@ static void terminateBody(CIRGenBuilderTy &builder, mlir::Region &r, b->erase(); } +mlir::LogicalResult CIRGenFunction::emitIfStmt(const IfStmt &s) { + mlir::LogicalResult res = mlir::success(); + // The else branch of a consteval if statement is always the only branch + // that can be runtime evaluated. + const Stmt *ConstevalExecuted; + if (s.isConsteval()) { + ConstevalExecuted = s.isNegatedConsteval() ? s.getThen() : s.getElse(); + if (!ConstevalExecuted) { + // No runtime code execution required + return res; + } + } + + // C99 6.8.4.1: The first substatement is executed if the expression + // compares unequal to 0. The condition must be a scalar type. 
+ auto ifStmtBuilder = [&]() -> mlir::LogicalResult { + if (s.isConsteval()) + return emitStmt(ConstevalExecuted, /*useCurrentScope=*/true); + + if (s.getInit()) + if (emitStmt(s.getInit(), /*useCurrentScope=*/true).failed()) + return mlir::failure(); + + if (s.getConditionVariable()) + emitDecl(*s.getConditionVariable()); + + // During LLVM codegen, if the condition constant folds and can be elided, + // it tries to avoid emitting the condition and the dead arm of the if/else. + // TODO(cir): we skip this in CIRGen, but should implement this as part of + // SSCP or a specific CIR pass. + bool CondConstant; + if (ConstantFoldsToSimpleInteger(s.getCond(), CondConstant, + s.isConstexpr())) { + if (s.isConstexpr()) { + // Handle "if constexpr" explicitly here to avoid generating some + // ill-formed code since in CIR the "if" is no longer simplified + // in this lambda like in Clang but postponed to other MLIR + // passes. + if (const Stmt *Executed = CondConstant ? s.getThen() : s.getElse()) + return emitStmt(Executed, /*useCurrentScope=*/true); + // There is nothing to execute at runtime. + // TODO(cir): there is still an empty cir.scope generated by the caller. + return mlir::success(); + } + assert(!cir::MissingFeatures::constantFoldsToSimpleInteger()); + } + + assert(!cir::MissingFeatures::emitCondLikelihoodViaExpectIntrinsic()); + assert(!cir::MissingFeatures::incrementProfileCounter()); + return emitIfOnBoolExpr(s.getCond(), s.getThen(), s.getElse()); + }; + + // TODO: Add a new scoped symbol table. + // LexicalScope ConditionScope(*this, S.getCond()->getSourceRange()); + // The if scope contains the full source range for IfStmt. 
+ mlir::Location scopeLoc = getLoc(s.getSourceRange()); + builder.create<cir::ScopeOp>( + scopeLoc, /*scopeBuilder=*/ + [&](mlir::OpBuilder &b, mlir::Location loc) { + LexicalScope lexScope{*this, scopeLoc, builder.getInsertionBlock()}; + res = ifStmtBuilder(); + }); + + return res; +} + mlir::LogicalResult CIRGenFunction::emitDeclStmt(const DeclStmt &s) { assert(builder.getInsertionBlock() && "expected valid insertion point"); diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 4ace083e3c081..7877f2601a245 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -14,6 +14,7 @@ #include "clang/CIR/Dialect/IR/CIRTypes.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/FunctionImplementation.h" #include "mlir/Support/LogicalResult.h" @@ -447,6 +448,133 @@ mlir::LogicalResult cir::ReturnOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// IfOp +//===----------------------------------------------------------------------===// + +ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) { + // create the regions for 'then'. + result.regions.reserve(2); + Region *thenRegion = result.addRegion(); + Region *elseRegion = result.addRegion(); + + mlir::Builder &builder = parser.getBuilder(); + OpAsmParser::UnresolvedOperand cond; + Type boolType = cir::BoolType::get(builder.getContext()); + + if (parser.parseOperand(cond) || + parser.resolveOperand(cond, boolType, result.operands)) + return failure(); + + // Parse 'then' region. + mlir::SMLoc parseThenLoc = parser.getCurrentLocation(); + if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) + return failure(); + + if (ensureRegionTerm(parser, *thenRegion, parseThenLoc).failed()) + return failure(); + + // If we find an 'else' keyword, parse the 'else' region. 
+ if (!parser.parseOptionalKeyword("else")) { + mlir::SMLoc parseElseLoc = parser.getCurrentLocation(); + if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{})) + return failure(); + if (ensureRegionTerm(parser, *elseRegion, parseElseLoc).failed()) + return failure(); + } + + // Parse the optional attribute list. + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + return success(); +} + +void cir::IfOp::print(OpAsmPrinter &p) { + + p << " " << getCondition() << " "; + mlir::Region &thenRegion = this->getThenRegion(); + p.printRegion(thenRegion, + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/!omitRegionTerm(thenRegion)); + + // Print the 'else' regions if it exists and has a block. + mlir::Region &elseRegion = this->getElseRegion(); + if (!elseRegion.empty()) { + p << " else "; + p.printRegion(elseRegion, + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/!omitRegionTerm(elseRegion)); + } + + p.printOptionalAttrDict(getOperation()->getAttrs()); +} + +/// Default callback for IfOp builders. +void cir::buildTerminatedBody(OpBuilder &builder, Location loc) { + // add cir.yield to end of the block + builder.create<cir::YieldOp>(loc); +} + +/// Given the region at `index`, or the parent operation if `index` is None, +/// return the successor regions. These are the regions that may be selected +/// during the flow of control. `operands` is a set of optional attributes that +/// correspond to a constant value for each operand, or null if that operand is +/// not a constant. +void cir::IfOp::getSuccessorRegions(mlir::RegionBranchPoint point, + SmallVectorImpl<RegionSuccessor> ®ions) { + // The `then` and the `else` region branch back to the parent operation. + if (!point.isParent()) { + regions.push_back(RegionSuccessor()); + return; + } + + // Don't consider the else region if it is empty. 
+ Region *elseRegion = &this->getElseRegion(); + if (elseRegion->empty()) + elseRegion = nullptr; + + // Otherwise, the successor is dependent on the condition. + // bool condition; + // if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) { + // assert(0 && "not implemented"); + // condition = condAttr.getValue().isOneValue(); + // Add the successor regions using the condition. + // regions.push_back(RegionSuccessor(condition ? &thenRegion() : + // elseRegion)); + // return; + // } + + // If the condition isn't constant, both regions may be executed. + regions.push_back(RegionSuccessor(&getThenRegion())); + // If the else region does not exist, it is not a viable successor. + if (elseRegion) + regions.push_back(RegionSuccessor(elseRegion)); + + return; +} + +void cir::IfOp::build(OpBuilder &builder, OperationState &result, Value cond, + bool withElseRegion, BuilderCallbackRef thenBuilder, + BuilderCallbackRef elseBuilder) { + assert(thenBuilder && "the builder callback for 'then' must be present"); + result.addOperands(cond); + + OpBuilder::InsertionGuard guard(builder); + Region *thenRegion = result.addRegion(); + builder.createBlock(thenRegion); + thenBuilder(builder, result.location); + + Region *elseRegion = result.addRegion(); + if (!withElseRegion) { + return; + } + + builder.createBlock(elseRegion); + elseBuilder(builder, result.location); +} + +LogicalResult cir::IfOp::verify() { return success(); } + //===----------------------------------------------------------------------===// // ScopeOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp index 52f4b2241505d..ea2b46a6a67f9 100644 --- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp +++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp @@ -13,6 +13,8 @@ #include "PassDetail.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Block.h" 
+#include "mlir/IR/Builders.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" @@ -54,6 +56,67 @@ struct CIRFlattenCFGPass : public CIRFlattenCFGBase<CIRFlattenCFGPass> { void runOnOperation() override; }; +struct CIRIfFlattening : public mlir::OpRewritePattern<cir::IfOp> { + using OpRewritePattern<IfOp>::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(cir::IfOp ifOp, + mlir::PatternRewriter &rewriter) const override { + mlir::OpBuilder::InsertionGuard guard(rewriter); + mlir::Location loc = ifOp.getLoc(); + bool emptyElse = ifOp.getElseRegion().empty(); + mlir::Block *currentBlock = rewriter.getInsertionBlock(); + mlir::Block *remainingOpsBlock = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + mlir::Block *continueBlock; + if (ifOp->getResults().empty()) + continueBlock = remainingOpsBlock; + else + llvm_unreachable("NYI"); + + // Inline the region + mlir::Block *thenBeforeBody = &ifOp.getThenRegion().front(); + mlir::Block *thenAfterBody = &ifOp.getThenRegion().back(); + rewriter.inlineRegionBefore(ifOp.getThenRegion(), continueBlock); + + rewriter.setInsertionPointToEnd(thenAfterBody); + if (auto thenYieldOp = + dyn_cast<cir::YieldOp>(thenAfterBody->getTerminator())) { + rewriter.replaceOpWithNewOp<cir::BrOp>(thenYieldOp, thenYieldOp.getArgs(), + continueBlock); + } + + rewriter.setInsertionPointToEnd(continueBlock); + + // Has else region: inline it. 
+ mlir::Block *elseBeforeBody = nullptr; + mlir::Block *elseAfterBody = nullptr; + if (!emptyElse) { + elseBeforeBody = &ifOp.getElseRegion().front(); + elseAfterBody = &ifOp.getElseRegion().back(); + rewriter.inlineRegionBefore(ifOp.getElseRegion(), continueBlock); + } else { + elseBeforeBody = elseAfterBody = continueBlock; + } + + rewriter.setInsertionPointToEnd(currentBlock); + rewriter.create<cir::BrCondOp>(loc, ifOp.getCondition(), thenBeforeBody, + elseBeforeBody); + + if (!emptyElse) { + rewriter.setInsertionPointToEnd(elseAfterBody); + if (auto elseYieldOP = + dyn_cast<cir::YieldOp>(elseAfterBody->getTerminator())) { + rewriter.replaceOpWithNewOp<cir::BrOp>( + elseYieldOP, elseYieldOP.getArgs(), continueBlock); + } + } + + rewriter.replaceOp(ifOp, continueBlock->getArguments()); + return mlir::success(); + } +}; + class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> { public: using OpRewritePattern<cir::ScopeOp>::OpRewritePattern; @@ -191,8 +254,9 @@ class CIRLoopOpInterfaceFlattening }; void populateFlattenCFGPatterns(RewritePatternSet &patterns) { - patterns.add<CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening>( - patterns.getContext()); + patterns + .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening>( + patterns.getContext()); } void CIRFlattenCFGPass::runOnOperation() { @@ -206,7 +270,7 @@ void CIRFlattenCFGPass::runOnOperation() { assert(!cir::MissingFeatures::switchOp()); assert(!cir::MissingFeatures::ternaryOp()); assert(!cir::MissingFeatures::tryOp()); - if (isa<ScopeOp, LoopOpInterface>(op)) + if (isa<IfOp, ScopeOp, LoopOpInterface>(op)) ops.push_back(op); }); diff --git a/clang/test/CIR/CodeGen/if.cpp b/clang/test/CIR/CodeGen/if.cpp new file mode 100644 index 0000000000000..d1be063bb529b --- /dev/null +++ b/clang/test/CIR/CodeGen/if.cpp @@ -0,0 +1,254 @@ +// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir +// RUN: FileCheck --input-file=%t.cir %s 
--check-prefix=CIR +// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t-cir.ll +// RUN: FileCheck --input-file=%t-cir.ll %s --check-prefix=LLVM +// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll +// RUN: FileCheck --input-file=%t.ll %s --check-prefix=OGCG + +int if0(bool a) { + + if (a) + return 2; + + return 3; + +} + +// CIR: cir.func @if0(%arg0: !cir.bool loc({{.*}})) -> !s32i +// CIR: cir.scope { +// CIR: %4 = cir.load %0 : !cir.ptr<!cir.bool>, !cir.bool +// CIR-NEXT: cir.if %4 { +// CIR-NEXT: %5 = cir.const #cir.int<2> : !s32i +// CIR-NEXT: cir.store %5, %1 : !s32i, !cir.ptr<!s32i> +// CIR-NEXT: %6 = cir.load %1 : !cir.ptr<!s32i>, !s32i +// CIR-NEXT: cir.return %6 : !s32i +// CIR-NEXT: } +// CIR-NEXT: } + + +// LLVM: define i32 @if0(i1 %0) +// LLVM: br label %[[ENTRY:.*]] +// LLVM: [[ENTRY]]: +// LLVM: %6 = load i8, ptr %2, align 1 +// LLVM: %7 = trunc i8 %6 to i1 +// LLVM: br i1 %7, label %[[THEN:.*]], label %[[END:.*]] +// LLVM: [[THEN]]: +// LLVM: store i32 2, ptr %3, align 4 +// LLVM: %9 = load i32, ptr %3, align 4 +// LLVM: ret i32 %9 +// LLVM: [[END]]: +// LLVM: br label %[[LABEL4:.*]] +// LLVM: [[LABEL4]]: +// LLVM: store i32 3, ptr %3, align 4 +// LLVM: %12 = load i32, ptr %3, align 4 +// LLVM: ret i32 %12 + +// OGCG: define dso_local noundef i32 @_Z3if0b(i1 noundef zeroext %a) +// OGCG: entry: +// OGCG: %[[RETVAL:.*]] = alloca i32, align 4 +// OGCG: %[[A_ADDR:.*]] = alloca i8, align 1 +// OGCG: %[[STOREDV:.*]] = zext i1 %a to i8 +// OGCG: store i8 %[[STOREDV]], ptr %[[A_ADDR]], align 1 +// OGCG: %[[LOADTMP:.*]] = load i8, ptr %[[A_ADDR]], align 1 +// OGCG: %[[LOADEDV:.*]] = trunc i8 %[[LOADTMP]] to i1 +// OGCG: br i1 %[[LOADEDV]], label %[[THEN_LABEL:.*]], label %[[END_LABEL:.*]] +// OGCG: [[THEN_LABEL]]: +// OGCG: store i32 2, ptr %[[RETVAL]], align 4 +// OGCG: br label %[[RETURN_LABEL:.*]] +// OGCG: [[END_LABEL]]: +// OGCG: store i32 3, ptr %[[RETVAL]], align 4 +// OGCG: br 
label %[[RETURN_LABEL]] +// OGCG: [[RETURN_LABEL]]: +// OGCG: %[[FINALLOAD:.*]] = load i32, ptr %[[RETVAL]], align 4 +// OGCG: ret i32 %[[FINALLOAD]] + +void if1(int a) { + int x = 0; + if (a) { + x = 3; + } else { + x = 4; + } +} + +// CIR: cir.func @if1(%arg0: !s32i loc({{.*}})) +// CIR: cir.scope { +// CIR: %3 = cir.load %0 : !cir.ptr<!s32i>, !s32i +// CIR: %4 = cir.cast(int_to_bool, %3 : !s32i), !cir.bool +// CIR-NEXT: cir.if %4 { +// CIR-NEXT: %5 = cir.const #cir.int<3> : !s32i +// CIR-NEXT: cir.store %5, %1 : !s32i, !cir.ptr<!s32i> +// CIR-NEXT: } else { +// CIR-NEXT: %5 = cir.const #cir.int<4> : !s32i +// CIR-NEXT: cir.store %5, %1 : !s32i, !cir.ptr<!s32i> +// CIR-NEXT: } +// CIR: } + +// LLVM: define void @if1(i32 %0) +// LLVM: %[[A:.*]] = alloca i32, i64 1, align 4 +// LLVM: %[[X:.*]] = alloca i32, i64 1, align 4 +// LLVM: store i32 %0, ptr %[[A]], align 4 +// LLVM: store i32 0, ptr %[[X]], align 4 +// LLVM: br label %[[ENTRY:.*]] +// LLVM: [[ENTRY]]: +// LLVM: %[[LOADED:.*]] = load i32, ptr %[[A]], align 4 +// LLVM: %[[COND:.*]] = icmp ne i32 %[[LOADED]], 0 +// LLVM: br i1 %[[COND]], label %[[THEN:.*]], label %[[ELSE:.*]] +// LLVM: [[THEN]]: +// LLVM: store i32 3, ptr %[[X]], align 4 +// LLVM: br label %[[END:.*]] +// LLVM: [[ELSE]]: +// LLVM: store i32 4, ptr %[[X]], align 4 +// LLVM: br label %[[END]] +// LLVM: [[END]]: +// LLVM: br label %[[EXIT:.*]] +// LLVM: [[EXIT]]: +// LLVM: ret void + +// OGCG: define dso_local void @_Z3if1i(i32 noundef %[[A:.*]]) +// OGCG: entry: +// OGCG: %[[A_ADDR:.*]] = alloca i32, align 4 +// OGCG: %[[X:.*]] = alloca i32, align 4 +// OGCG: store i32 %[[A]], ptr %[[A_ADDR]], align 4 +// OGCG: store i32 0, ptr %[[X]], align 4 +// OGCG: %[[LOADED_A:.*]] = load i32, ptr %[[A_ADDR]], align 4 +// OGCG: %[[TOBOOL:.*]] = icmp ne i32 %[[LOADED_A]], 0 +// OGCG: br i1 %[[TOBOOL]], label %[[THEN_LABEL:.*]], label %[[ELSE_LABEL:.*]] +// OGCG: [[THEN_LABEL]]: +// OGCG: store i32 3, ptr %[[X]], align 4 +// OGCG: br label %[[END_LABEL:.*]] 
+// OGCG: [[ELSE_LABEL]]: +// OGCG: store i32 4, ptr %[[X]], align 4 +// OGCG: br label %[[END_LABEL]] +// OGCG: [[END_LABEL]]: +// OGCG: ret void + +void if2(int a, bool b, bool c) { + int x = 0; + if (a) { + x = 3; + if (b) { + x = 8; + } + } else { + if (c) { + x = 14; + } + x = 4; + } +} + +// CIR: cir.func @if2(%arg0: !s32i loc({{.*}}), %arg1: !cir.bool loc({{.*}}), %arg2: !cir.bool loc({{.*}})) +// CIR: cir.scope { +// CIR: %5 = cir.load %0 : !cir.ptr<!s32i>, !s32i +// CIR: %6 = cir.cast(int_to_bool, %5 : !s32i), !cir.bool +// CIR: cir.if %6 { +// CIR: %7 = cir.const #cir.int<3> : !s32i +// CIR: cir.store %7, %3 : !s32i, !cir.ptr<!s32i> +// CIR: cir.scope { +// CIR: %8 = cir.load %1 : !cir.ptr<!cir.bool>, !cir.bool +// CIR-NEXT: cir.if %8 { +// CIR-NEXT: %9 = cir.const #cir.int<8> : !s32i +// CIR-NEXT: cir.store %9, %3 : !s32i, !cir.ptr<!s32i> +// CIR-NEXT: } +// CIR: } +// CIR: } else { +// CIR: cir.scope { +// CIR: %8 = cir.load %2 : !cir.ptr<!cir.bool>, !cir.bool +// CIR-NEXT: cir.if %8 { +// CIR-NEXT: %9 = cir.const #cir.int<14> : !s32i +// CIR-NEXT: cir.store %9, %3 : !s32i, !cir.ptr<!s32i> +// CIR-NEXT: } +// CIR: } +// CIR: %7 = cir.const #cir.int<4> : !s32i +// CIR: cir.store %7, %3 : !s32i, !cir.ptr<!s32i> +// CIR: } +// CIR: } + +// LLVM: define void @if2(i32 %[[A:.*]], i1 %[[B:.*]], i1 %[[C:.*]]) +// LLVM: %[[VARA:.*]] = alloca i32, i64 1, align 4 +// LLVM: %[[VARB:.*]] = alloca i8, i64 1, align 1 +// LLVM: %[[VARC:.*]] = alloca i8, i64 1, align 1 +// LLVM: %[[VARX:.*]] = alloca i32, i64 1, align 4 +// LLVM: store i32 %[[A]], ptr %[[VARA]], align 4 +// LLVM: %[[B_EXT:.*]] = zext i1 %[[B]] to i8 +// LLVM: store i8 %[[B_EXT]], ptr %[[VARB]], align 1 +// LLVM: %[[C_EXT:.*]] = zext i1 %[[C]] to i8 +// LLVM: store i8 %[[C_EXT]], ptr %[[VARC]], align 1 +// LLVM: store i32 0, ptr %[[VARX]], align 4 +// LLVM: br label %[[ENTRY:.*]] +// LLVM: [[ENTRY]]: +// LLVM: %[[LOAD_A:.*]] = load i32, ptr %[[VARA]], align 4 +// LLVM: %[[CMP_A:.*]] = icmp ne i32 
%[[LOAD_A]], 0 +// LLVM: br i1 %[[CMP_A]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]] +// LLVM: [[IF_THEN]]: +// LLVM: store i32 3, ptr %[[VARX]], align 4 +// LLVM: br label %[[LABEL14:.*]] +// LLVM: [[LABEL14]]: +// LLVM: %[[LOAD_B:.*]] = load i8, ptr %[[VARB]], align 1 +// LLVM: %[[TRUNC_B:.*]] = trunc i8 %[[LOAD_B]] to i1 +// LLVM: br i1 %[[TRUNC_B]], label %[[IF_THEN2:.*]], label %[[IF_END2:.*]] +// LLVM: [[IF_THEN2]]: +// LLVM: store i32 8, ptr %[[VARX]], align 4 +// LLVM: br label %[[IF_END2]] +// LLVM: [[IF_END2]]: +// LLVM: br label %[[LABEL19:.*]] +// LLVM: [[LABEL19]]: +// LLVM: br label %[[LABEL27:.*]] +// LLVM: [[IF_ELSE]]: +// LLVM: br label %[[LABEL21:.*]] +// LLVM: [[LABEL21]]: +// LLVM: %[[LOAD_C:.*]] = load i8, ptr %[[VARC]], align 1 +// LLVM: %[[TRUNC_C:.*]] = trunc i8 %[[LOAD_C]] to i1 +// LLVM: br i1 %[[TRUNC_C]], label %[[IF_THEN3:.*]], label %[[IF_END3:.*]] +// LLVM: [[IF_THEN3]]: +// LLVM: store i32 14, ptr %[[VARX]], align 4 +// LLVM: br label %[[IF_END3]] +// LLVM: [[IF_END3]]: +// LLVM: br label %[[LABEL26:.*]] +// LLVM: [[LABEL26]]: +// LLVM: store i32 4, ptr %[[VARX]], align 4 +// LLVM: br label %[[LABEL27]] +// LLVM: [[LABEL27]]: +// LLVM: br label %[[LABEL28:.*]] +// LLVM: [[LABEL28]]: +// LLVM: ret void + +// OGCG: define dso_local void @_Z3if2ibb(i32 noundef %[[A:.*]], i1 noundef zeroext %[[B:.*]], i1 noundef zeroext %[[C:.*]]) +// OGCG: entry: +// OGCG: %[[A_ADDR:.*]] = alloca i32, align 4 +// OGCG: %[[B_ADDR:.*]] = alloca i8, align 1 +// OGCG: %[[C_ADDR:.*]] = alloca i8, align 1 +// OGCG: %[[X:.*]] = alloca i32, align 4 +// OGCG: store i32 %[[A]], ptr %[[A_ADDR]], align 4 +// OGCG: %[[B_EXT:.*]] = zext i1 %[[B]] to i8 +// OGCG: store i8 %[[B_EXT]], ptr %[[B_ADDR]], align 1 +// OGCG: %[[C_EXT:.*]] = zext i1 %[[C]] to i8 +// OGCG: store i8 %[[C_EXT]], ptr %[[C_ADDR]], align 1 +// OGCG: store i32 0, ptr %[[X]], align 4 +// OGCG: %[[A_VAL:.*]] = load i32, ptr %[[A_ADDR]], align 4 +// OGCG: %[[A_BOOL:.*]] = icmp ne i32 %[[A_VAL]], 0 
+// OGCG: br i1 %[[A_BOOL]], label %[[IF_THEN:.*]], label %[[IF_ELSE:.*]] +// OGCG: [[IF_THEN]]: +// OGCG: store i32 3, ptr %[[X]], align 4 +// OGCG: %[[B_LOAD:.*]] = load i8, ptr %[[B_ADDR]], align 1 +// OGCG: %[[B_TRUNC:.*]] = trunc i8 %[[B_LOAD]] to i1 +// OGCG: br i1 %[[B_TRUNC]], label %[[IF_THEN2:.*]], label %[[IF_END:.*]] +// OGCG: [[IF_THEN2]]: +// OGCG: store i32 8, ptr %[[X]], align 4 +// OGCG: br label %[[IF_END]] +// OGCG: [[IF_END]]: +// OGCG: br label %[[IF_END6:.*]] +// OGCG: [[IF_ELSE]]: +// OGCG: %[[C_LOAD:.*]] = load i8, ptr %[[C_ADDR]], align 1 +// OGCG: %[[C_TRUNC:.*]] = trunc i8 %[[C_LOAD]] to i1 +// OGCG: br i1 %[[C_TRUNC]], label %[[IF_THEN4:.*]], label %[[IF_END5:.*]] +// OGCG: [[IF_THEN4]]: +// OGCG: store i32 14, ptr %[[X]], align 4 +// OGCG: br label %[[IF_END5]] +// OGCG: [[IF_END5]]: +// OGCG: store i32 4, ptr %[[X]], align 4 +// OGCG: br label %[[IF_END6]] +// OGCG: [[IF_END6]]: +// OGCG: ret void + diff --git a/clang/test/CIR/Lowering/if.cir b/clang/test/CIR/Lowering/if.cir new file mode 100644 index 0000000000000..3a077aa9ef057 --- /dev/null +++ b/clang/test/CIR/Lowering/if.cir @@ -0,0 +1,99 @@ +// RUN: cir-opt %s -cir-to-llvm -o - | FileCheck %s -check-prefix=MLIR +// RUN: cir-translate %s -cir-to-llvmir --target x86_64-unknown-linux-gnu --disable-cc-lowering | FileCheck %s -check-prefix=LLVM +!s32i = !cir.int<s, 32> + +module { + cir.func @foo(%arg0: !s32i) -> !s32i { + %4 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool + cir.if %4 { + %5 = cir.const #cir.int<1> : !s32i + cir.return %5 : !s32i + } else { + %5 = cir.const #cir.int<0> : !s32i + cir.return %5 : !s32i + } + cir.return %arg0 : !s32i + } + +// MLIR: llvm.func @foo(%arg0: i32) -> i32 +// MLIR-NEXT: %0 = llvm.mlir.constant(0 : i32) : i32 +// MLIR-NEXT: %1 = llvm.icmp "ne" %arg0, %0 : i32 +// MLIR-NEXT: llvm.cond_br %1, ^bb1, ^bb2 +// MLIR-NEXT: ^bb1: // pred: ^bb0 +// MLIR-NEXT: %2 = llvm.mlir.constant(1 : i32) : i32 +// MLIR-NEXT: llvm.return %2 : i32 +// MLIR-NEXT: 
^bb2: // pred: ^bb0 +// MLIR-NEXT: %3 = llvm.mlir.constant(0 : i32) : i32 +// MLIR-NEXT: llvm.return %3 : i32 +// MLIR-NEXT: ^bb3: // no predecessors +// MLIR-NEXT: llvm.return %arg0 : i32 +// MLIR-NEXT: } + +// LLVM: define i32 @foo(i32 %0) +// LLVM-NEXT: %2 = icmp ne i32 %0, 0 +// LLVM-NEXT: br i1 %2, label %3, label %4 +// LLVM-EMPTY: +// LLVM-NEXT: 3: +// LLVM-NEXT: ret i32 1 +// LLVM-EMPTY: +// LLVM-NEXT: 4: +// LLVM-NEXT: ret i32 0 +// LLVM-EMPTY: +// LLVM-NEXT: 5: +// LLVM-NEXT: ret i32 %0 +// LLVM-NEXT: } + + cir.func @onlyIf(%arg0: !s32i) -> !s32i { + %4 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool + cir.if %4 { + %5 = cir.const #cir.int<1> : !s32i + cir.return %5 : !s32i + } + cir.return %arg0 : !s32i + } + + // MLIR: llvm.func @onlyIf(%arg0: i32) -> i32 + // MLIR-NEXT: %0 = llvm.mlir.constant(0 : i32) : i32 + // MLIR-NEXT: %1 = llvm.icmp "ne" %arg0, %0 : i32 + // MLIR-NEXT: llvm.cond_br %1, ^bb1, ^bb2 + // MLIR-NEXT: ^bb1: // pred: ^bb0 + // MLIR-NEXT: %2 = llvm.mlir.constant(1 : i32) : i32 + // MLIR-NEXT: llvm.return %2 : i32 + // MLIR-NEXT: ^bb2: // pred: ^bb0 + // MLIR-NEXT: llvm.return %arg0 : i32 + // MLIR-NEXT: } + + // Verify empty if clause is properly lowered to empty block + cir.func @emptyIfClause(%arg0: !s32i) -> !s32i { + // MLIR-LABEL: llvm.func @emptyIfClause + %4 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool + // MLIR: llvm.cond_br {{%.*}}, ^[[T:.*]], ^[[PHI:.*]] + cir.if %4 { + // MLIR-NEXT: ^[[T]]: + // MLIR-NEXT: llvm.br ^[[PHI]] + } + // MLIR-NEXT: ^[[PHI]]: + // MLIR-NEXT: llvm.return + cir.return %arg0 : !s32i + } + + // Verify empty if-else clauses are properly lowered to empty blocks + // TODO: Fix reversed order of blocks in the test once Issue clangir/#1094 is + // addressed + cir.func @emptyIfElseClause(%arg0: !s32i) -> !s32i { + // MLIR-LABEL: llvm.func @emptyIfElseClause + %4 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool + // MLIR: llvm.cond_br {{%.*}}, ^[[T:.*]], ^[[F:.*]] + cir.if %4 { + // MLIR-NEXT: ^[[T]]: + 
// MLIR-NEXT: llvm.br ^[[PHI:.*]] + } else { + // MLIR-NEXT: ^[[F]]: + // MLIR-NEXT: llvm.br ^[[PHI]] + } + // MLIR-NEXT: ^[[PHI]]: + // MLIR-NEXT: llvm.return + cir.return %arg0 : !s32i + } + +} diff --git a/clang/test/CIR/Transforms/if.cir b/clang/test/CIR/Transforms/if.cir new file mode 100644 index 0000000000000..03848bf8d0633 --- /dev/null +++ b/clang/test/CIR/Transforms/if.cir @@ -0,0 +1,48 @@ +// RUN: cir-opt %s -cir-flatten-cfg -o - | FileCheck %s + +!s32i = !cir.int<s, 32> + +module { + cir.func @foo(%arg0: !s32i) -> !s32i { + %4 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool + cir.if %4 { + %5 = cir.const #cir.int<1> : !s32i + cir.return %5 : !s32i + } else { + %5 = cir.const #cir.int<0> : !s32i + cir.return %5 : !s32i + } + cir.return %arg0 : !s32i + } +// CHECK: cir.func @foo(%arg0: !s32i) -> !s32i { +// CHECK-NEXT: %0 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool +// CHECK-NEXT: cir.brcond %0 ^bb1, ^bb2 +// CHECK-NEXT: ^bb1: // pred: ^bb0 +// CHECK-NEXT: %1 = cir.const #cir.int<1> : !s32i +// CHECK-NEXT: cir.return %1 : !s32i +// CHECK-NEXT: ^bb2: // pred: ^bb0 +// CHECK-NEXT: %2 = cir.const #cir.int<0> : !s32i +// CHECK-NEXT: cir.return %2 : !s32i +// CHECK-NEXT: ^bb3: // no predecessors +// CHECK-NEXT: cir.return %arg0 : !s32i +// CHECK-NEXT: } + + cir.func @onlyIf(%arg0: !s32i) -> !s32i { + %4 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool + cir.if %4 { + %5 = cir.const #cir.int<1> : !s32i + cir.return %5 : !s32i + } + cir.return %arg0 : !s32i + } +// CHECK: cir.func @onlyIf(%arg0: !s32i) -> !s32i { +// CHECK-NEXT: %0 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool +// CHECK-NEXT: cir.brcond %0 ^bb1, ^bb2 +// CHECK-NEXT: ^bb1: // pred: ^bb0 +// CHECK-NEXT: %1 = cir.const #cir.int<1> : !s32i +// CHECK-NEXT: cir.return %1 : !s32i +// CHECK-NEXT: ^bb2: // pred: ^bb0 +// CHECK-NEXT: cir.return %arg0 : !s32i +// CHECK-NEXT: } + +} _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org 
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits