https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/94693
>From 065f965d6649a56a27c39321a553188b4230f5f6 Mon Sep 17 00:00:00 2001 From: Yuxuan Chen <yuxuanchen1...@outlook.com> Date: Tue, 4 Jun 2024 23:22:00 -0700 Subject: [PATCH] [Clang] Introduce [[clang::coro_inplace_task]] --- clang/include/clang/AST/ExprCXX.h | 26 ++++-- clang/include/clang/Basic/Attr.td | 8 ++ clang/include/clang/Basic/AttrDocs.td | 19 +++++ clang/lib/CodeGen/CGBlocks.cpp | 5 +- clang/lib/CodeGen/CGCUDARuntime.cpp | 5 +- clang/lib/CodeGen/CGCUDARuntime.h | 8 +- clang/lib/CodeGen/CGCXXABI.h | 10 +-- clang/lib/CodeGen/CGClass.cpp | 16 ++-- clang/lib/CodeGen/CGCoroutine.cpp | 29 +++++-- clang/lib/CodeGen/CGExpr.cpp | 41 +++++---- clang/lib/CodeGen/CGExprCXX.cpp | 60 +++++++------ clang/lib/CodeGen/CodeGenFunction.h | 64 ++++++++------ clang/lib/CodeGen/ItaniumCXXABI.cpp | 16 ++-- clang/lib/CodeGen/MicrosoftCXXABI.cpp | 18 ++-- clang/lib/Sema/SemaCoroutine.cpp | 54 +++++++++++- clang/lib/Serialization/ASTReaderStmt.cpp | 10 ++- clang/lib/Serialization/ASTWriterStmt.cpp | 3 +- clang/test/CodeGenCoroutines/Inputs/utility.h | 13 +++ .../coro-structured-concurrency.cpp | 84 +++++++++++++++++++ ...a-attribute-supported-attributes-list.test | 1 + llvm/include/llvm/IR/Intrinsics.td | 3 + .../lib/Transforms/Coroutines/CoroCleanup.cpp | 11 ++- llvm/lib/Transforms/Coroutines/CoroElide.cpp | 58 ++++++++++++- llvm/lib/Transforms/Coroutines/Coroutines.cpp | 1 + .../coro-elide-structured-concurrency.ll | 64 ++++++++++++++ 25 files changed, 493 insertions(+), 134 deletions(-) create mode 100644 clang/test/CodeGenCoroutines/Inputs/utility.h create mode 100644 clang/test/CodeGenCoroutines/coro-structured-concurrency.cpp create mode 100644 llvm/test/Transforms/Coroutines/coro-elide-structured-concurrency.ll diff --git a/clang/include/clang/AST/ExprCXX.h b/clang/include/clang/AST/ExprCXX.h index c2feac525c1ea..0cf62aee41b66 100644 --- a/clang/include/clang/AST/ExprCXX.h +++ b/clang/include/clang/AST/ExprCXX.h @@ -5082,7 +5082,8 @@ class CoroutineSuspendExpr : public Expr 
{ enum SubExpr { Operand, Common, Ready, Suspend, Resume, Count }; Stmt *SubExprs[SubExpr::Count]; - OpaqueValueExpr *OpaqueValue = nullptr; + OpaqueValueExpr *CommonExprOpaqueValue = nullptr; + OpaqueValueExpr *InplaceCallOpaqueValue = nullptr; public: // These types correspond to the three C++ 'await_suspend' return variants @@ -5090,10 +5091,10 @@ class CoroutineSuspendExpr : public Expr { CoroutineSuspendExpr(StmtClass SC, SourceLocation KeywordLoc, Expr *Operand, Expr *Common, Expr *Ready, Expr *Suspend, Expr *Resume, - OpaqueValueExpr *OpaqueValue) + OpaqueValueExpr *CommonExprOpaqueValue) : Expr(SC, Resume->getType(), Resume->getValueKind(), Resume->getObjectKind()), - KeywordLoc(KeywordLoc), OpaqueValue(OpaqueValue) { + KeywordLoc(KeywordLoc), CommonExprOpaqueValue(CommonExprOpaqueValue) { SubExprs[SubExpr::Operand] = Operand; SubExprs[SubExpr::Common] = Common; SubExprs[SubExpr::Ready] = Ready; @@ -5128,7 +5129,16 @@ class CoroutineSuspendExpr : public Expr { } /// getOpaqueValue - Return the opaque value placeholder. 
- OpaqueValueExpr *getOpaqueValue() const { return OpaqueValue; } + OpaqueValueExpr *getCommonExprOpaqueValue() const { + return CommonExprOpaqueValue; + } + + OpaqueValueExpr *getInplaceCallOpaqueValue() const { + return InplaceCallOpaqueValue; + } + void setInplaceCallOpaqueValue(OpaqueValueExpr *E) { + InplaceCallOpaqueValue = E; + } Expr *getReadyExpr() const { return static_cast<Expr*>(SubExprs[SubExpr::Ready]); @@ -5194,9 +5204,9 @@ class CoawaitExpr : public CoroutineSuspendExpr { public: CoawaitExpr(SourceLocation CoawaitLoc, Expr *Operand, Expr *Common, Expr *Ready, Expr *Suspend, Expr *Resume, - OpaqueValueExpr *OpaqueValue, bool IsImplicit = false) + OpaqueValueExpr *CommonExprOpaqueValue, bool IsImplicit = false) : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Operand, Common, - Ready, Suspend, Resume, OpaqueValue) { + Ready, Suspend, Resume, CommonExprOpaqueValue) { CoawaitBits.IsImplicit = IsImplicit; } @@ -5275,9 +5285,9 @@ class CoyieldExpr : public CoroutineSuspendExpr { public: CoyieldExpr(SourceLocation CoyieldLoc, Expr *Operand, Expr *Common, Expr *Ready, Expr *Suspend, Expr *Resume, - OpaqueValueExpr *OpaqueValue) + OpaqueValueExpr *CommonExprOpaqueValue) : CoroutineSuspendExpr(CoyieldExprClass, CoyieldLoc, Operand, Common, - Ready, Suspend, Resume, OpaqueValue) {} + Ready, Suspend, Resume, CommonExprOpaqueValue) {} CoyieldExpr(SourceLocation CoyieldLoc, QualType Ty, Expr *Operand, Expr *Common) : CoroutineSuspendExpr(CoyieldExprClass, CoyieldLoc, Ty, Operand, diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index b70b0c8b836a5..7c291978a27ed 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -1212,6 +1212,14 @@ def CoroDisableLifetimeBound : InheritableAttr { let SimpleHandler = 1; } +def CoroInplaceTask : InheritableAttr { + let Spellings = [Clang<"coro_inplace_task">]; + let Subjects = SubjectList<[CXXRecord]>; + let LangOpts = [CPlusPlus]; + let Documentation = 
[CoroInplaceTaskDoc]; + let SimpleHandler = 1; +} + // OSObject-based attributes. def OSConsumed : InheritableParamAttr { let Spellings = [Clang<"os_consumed">]; diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td index 70d5dfa8aaf86..964ab1bc9c70d 100644 --- a/clang/include/clang/Basic/AttrDocs.td +++ b/clang/include/clang/Basic/AttrDocs.td @@ -8015,6 +8015,25 @@ but do not pass them to the underlying coroutine or pass them by value. }]; } +def CoroInplaceTaskDoc : Documentation { + let Category = DocCatDecl; + let Content = [{ +The ``[[clang::coro_inplace_task]]`` attribute is a class attribute that can be applied +to a coroutine return type. + +When a coroutine function that returns such a type calls another coroutine function, +the compiler performs heap allocation elision when all of the following conditions are met: +- The callee coroutine function returns a type that is annotated with ``[[clang::coro_inplace_task]]``. +- The callee coroutine function is inlined. +- In the caller coroutine, the return value of the callee is a prvalue or an xvalue. +- The temporary expression containing the callee coroutine object is immediately co_awaited. + +The behavior is undefined if any of the following conditions is met: +- The caller coroutine is destroyed earlier than the callee coroutine.
+ + }]; +} + def CountedByDocs : Documentation { let Category = DocCatField; let Content = [{ diff --git a/clang/lib/CodeGen/CGBlocks.cpp b/clang/lib/CodeGen/CGBlocks.cpp index 5dac1cd425bf6..a602dbfcfeac2 100644 --- a/clang/lib/CodeGen/CGBlocks.cpp +++ b/clang/lib/CodeGen/CGBlocks.cpp @@ -1154,7 +1154,8 @@ llvm::Type *CodeGenModule::getGenericBlockLiteralType() { } RValue CodeGenFunction::EmitBlockCallExpr(const CallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { const auto *BPT = E->getCallee()->getType()->castAs<BlockPointerType>(); llvm::Value *BlockPtr = EmitScalarExpr(E->getCallee()); llvm::Type *GenBlockTy = CGM.getGenericBlockLiteralType(); @@ -1211,7 +1212,7 @@ RValue CodeGenFunction::EmitBlockCallExpr(const CallExpr *E, CGCallee Callee(CGCalleeInfo(), Func); // And call the block. - return EmitCall(FnInfo, Callee, ReturnValue, Args); + return EmitCall(FnInfo, Callee, ReturnValue, Args, CallOrInvoke); } Address CodeGenFunction::GetAddrOfBlockDecl(const VarDecl *variable) { diff --git a/clang/lib/CodeGen/CGCUDARuntime.cpp b/clang/lib/CodeGen/CGCUDARuntime.cpp index c14a9d3f2bbbc..1e1da1e2411a7 100644 --- a/clang/lib/CodeGen/CGCUDARuntime.cpp +++ b/clang/lib/CodeGen/CGCUDARuntime.cpp @@ -25,7 +25,8 @@ CGCUDARuntime::~CGCUDARuntime() {} RValue CGCUDARuntime::EmitCUDAKernelCallExpr(CodeGenFunction &CGF, const CUDAKernelCallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { llvm::BasicBlock *ConfigOKBlock = CGF.createBasicBlock("kcall.configok"); llvm::BasicBlock *ContBlock = CGF.createBasicBlock("kcall.end"); @@ -35,7 +36,7 @@ RValue CGCUDARuntime::EmitCUDAKernelCallExpr(CodeGenFunction &CGF, eval.begin(CGF); CGF.EmitBlock(ConfigOKBlock); - CGF.EmitSimpleCallExpr(E, ReturnValue); + CGF.EmitSimpleCallExpr(E, ReturnValue, CallOrInvoke); CGF.EmitBranch(ContBlock); CGF.EmitBlock(ContBlock); diff --git a/clang/lib/CodeGen/CGCUDARuntime.h 
b/clang/lib/CodeGen/CGCUDARuntime.h index 8030d632cc3d2..86f776004ee7c 100644 --- a/clang/lib/CodeGen/CGCUDARuntime.h +++ b/clang/lib/CodeGen/CGCUDARuntime.h @@ -21,6 +21,7 @@ #include "llvm/IR/GlobalValue.h" namespace llvm { +class CallBase; class Function; class GlobalVariable; } @@ -82,9 +83,10 @@ class CGCUDARuntime { CGCUDARuntime(CodeGenModule &CGM) : CGM(CGM) {} virtual ~CGCUDARuntime(); - virtual RValue EmitCUDAKernelCallExpr(CodeGenFunction &CGF, - const CUDAKernelCallExpr *E, - ReturnValueSlot ReturnValue); + virtual RValue + EmitCUDAKernelCallExpr(CodeGenFunction &CGF, const CUDAKernelCallExpr *E, + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke = nullptr); /// Emits a kernel launch stub. virtual void emitDeviceStub(CodeGenFunction &CGF, FunctionArgList &Args) = 0; diff --git a/clang/lib/CodeGen/CGCXXABI.h b/clang/lib/CodeGen/CGCXXABI.h index 104a20db8efaf..b38a3b7602e34 100644 --- a/clang/lib/CodeGen/CGCXXABI.h +++ b/clang/lib/CodeGen/CGCXXABI.h @@ -485,11 +485,11 @@ class CGCXXABI { llvm::PointerUnion<const CXXDeleteExpr *, const CXXMemberCallExpr *>; /// Emit the ABI-specific virtual destructor call. 
- virtual llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF, - const CXXDestructorDecl *Dtor, - CXXDtorType DtorType, - Address This, - DeleteOrMemberCallExpr E) = 0; + virtual llvm::Value * + EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, + CXXDtorType DtorType, Address This, + DeleteOrMemberCallExpr E, + llvm::CallBase **CallOrInvoke) = 0; virtual void adjustCallArgsForDestructorThunk(CodeGenFunction &CGF, GlobalDecl GD, diff --git a/clang/lib/CodeGen/CGClass.cpp b/clang/lib/CodeGen/CGClass.cpp index 5a032bdbf9379..d891d97c59bd6 100644 --- a/clang/lib/CodeGen/CGClass.cpp +++ b/clang/lib/CodeGen/CGClass.cpp @@ -2191,15 +2191,11 @@ static bool canEmitDelegateCallArgs(CodeGenFunction &CGF, return true; } -void CodeGenFunction::EmitCXXConstructorCall(const CXXConstructorDecl *D, - CXXCtorType Type, - bool ForVirtualBase, - bool Delegating, - Address This, - CallArgList &Args, - AggValueSlot::Overlap_t Overlap, - SourceLocation Loc, - bool NewPointerIsChecked) { +void CodeGenFunction::EmitCXXConstructorCall( + const CXXConstructorDecl *D, CXXCtorType Type, bool ForVirtualBase, + bool Delegating, Address This, CallArgList &Args, + AggValueSlot::Overlap_t Overlap, SourceLocation Loc, + bool NewPointerIsChecked, llvm::CallBase **CallOrInvoke) { const CXXRecordDecl *ClassDecl = D->getParent(); if (!NewPointerIsChecked) @@ -2247,7 +2243,7 @@ void CodeGenFunction::EmitCXXConstructorCall(const CXXConstructorDecl *D, const CGFunctionInfo &Info = CGM.getTypes().arrangeCXXConstructorCall( Args, D, Type, ExtraArgs.Prefix, ExtraArgs.Suffix, PassPrototypeArgs); CGCallee Callee = CGCallee::forDirect(CalleePtr, GlobalDecl(D, Type)); - EmitCall(Info, Callee, ReturnValueSlot(), Args, nullptr, false, Loc); + EmitCall(Info, Callee, ReturnValueSlot(), Args, CallOrInvoke, false, Loc); // Generate vtable assumptions if we're constructing a complete object // with a vtable. 
We don't do this for base subobjects for two reasons: diff --git a/clang/lib/CodeGen/CGCoroutine.cpp b/clang/lib/CodeGen/CGCoroutine.cpp index b4c724422c14a..79d38db360393 100644 --- a/clang/lib/CodeGen/CGCoroutine.cpp +++ b/clang/lib/CodeGen/CGCoroutine.cpp @@ -12,9 +12,11 @@ #include "CGCleanup.h" #include "CodeGenFunction.h" -#include "llvm/ADT/ScopeExit.h" +#include "clang/AST/ExprCXX.h" #include "clang/AST/StmtCXX.h" #include "clang/AST/StmtVisitor.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/IR/Intrinsics.h" using namespace clang; using namespace CodeGen; @@ -223,12 +225,21 @@ static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Co CoroutineSuspendExpr const &S, AwaitKind Kind, AggValueSlot aggSlot, bool ignoreResult, bool forLValue) { - auto *E = S.getCommonExpr(); + auto &Builder = CGF.Builder; - auto CommonBinder = - CodeGenFunction::OpaqueValueMappingData::bind(CGF, S.getOpaqueValue(), E); - auto UnbindCommonOnExit = - llvm::make_scope_exit([&] { CommonBinder.unbind(CGF); }); + // If S.getInplaceCallOpaqueValue() is null, we don't have a nested opaque + // value for common expression. + std::optional<CodeGenFunction::OpaqueValueMapping> OperandMapping; + if (auto *CallOV = S.getInplaceCallOpaqueValue()) { + auto *CE = cast<CallExpr>(CallOV->getSourceExpr()); + // TODO: don't use the intrisic coro_safe_elide in the next version. 
+ LValue CallResult = CGF.EmitCallExprLValue(CE, nullptr); + OperandMapping.emplace(CGF, CallOV, CallResult); + llvm::Value *Value = CallResult.getPointer(CGF); + auto SafeElide = CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_safe_elide); + Builder.CreateCall(SafeElide, Value); + } + CodeGenFunction::OpaqueValueMapping BindCommon(CGF, S.getCommonExprOpaqueValue()); auto Prefix = buildSuspendPrefixStr(Coro, Kind); BasicBlock *ReadyBlock = CGF.createBasicBlock(Prefix + Twine(".ready")); @@ -241,7 +252,6 @@ static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Co // Otherwise, emit suspend logic. CGF.EmitBlock(SuspendBlock); - auto &Builder = CGF.Builder; llvm::Function *CoroSave = CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_save); auto *NullPtr = llvm::ConstantPointerNull::get(CGF.CGM.Int8PtrTy); auto *SaveCall = Builder.CreateCall(CoroSave, {NullPtr}); @@ -256,7 +266,8 @@ static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Co SmallVector<llvm::Value *, 3> SuspendIntrinsicCallArgs; SuspendIntrinsicCallArgs.push_back( - CGF.getOrCreateOpaqueLValueMapping(S.getOpaqueValue()).getPointer(CGF)); + CGF.getOrCreateOpaqueLValueMapping(S.getCommonExprOpaqueValue()) + .getPointer(CGF)); SuspendIntrinsicCallArgs.push_back(CGF.CurCoro.Data->CoroBegin); SuspendIntrinsicCallArgs.push_back(SuspendWrapper); @@ -455,7 +466,7 @@ CodeGenFunction::generateAwaitSuspendWrapper(Twine const &CoroName, Builder.CreateLoad(GetAddrOfLocalVar(&FrameDecl)); auto AwaiterBinder = CodeGenFunction::OpaqueValueMappingData::bind( - *this, S.getOpaqueValue(), AwaiterLValue); + *this, S.getCommonExprOpaqueValue(), AwaiterLValue); auto *SuspendRet = EmitScalarExpr(S.getSuspendExpr()); diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index b6718a46e8c50..d20f2ebfcffe5 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -5445,16 +5445,17 @@ RValue CodeGenFunction::EmitRValueForField(LValue LV, 
//===--------------------------------------------------------------------===// RValue CodeGenFunction::EmitCallExpr(const CallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { // Builtins never have block type. if (E->getCallee()->getType()->isBlockPointerType()) - return EmitBlockCallExpr(E, ReturnValue); + return EmitBlockCallExpr(E, ReturnValue, CallOrInvoke); if (const auto *CE = dyn_cast<CXXMemberCallExpr>(E)) - return EmitCXXMemberCallExpr(CE, ReturnValue); + return EmitCXXMemberCallExpr(CE, ReturnValue, CallOrInvoke); if (const auto *CE = dyn_cast<CUDAKernelCallExpr>(E)) - return EmitCUDAKernelCallExpr(CE, ReturnValue); + return EmitCUDAKernelCallExpr(CE, ReturnValue, CallOrInvoke); // A CXXOperatorCallExpr is created even for explicit object methods, but // these should be treated like static function call. @@ -5462,7 +5463,7 @@ RValue CodeGenFunction::EmitCallExpr(const CallExpr *E, if (const auto *MD = dyn_cast_if_present<CXXMethodDecl>(CE->getCalleeDecl()); MD && MD->isImplicitObjectMemberFunction()) - return EmitCXXOperatorMemberCallExpr(CE, MD, ReturnValue); + return EmitCXXOperatorMemberCallExpr(CE, MD, ReturnValue, CallOrInvoke); CGCallee callee = EmitCallee(E->getCallee()); @@ -5475,14 +5476,17 @@ RValue CodeGenFunction::EmitCallExpr(const CallExpr *E, return EmitCXXPseudoDestructorExpr(callee.getPseudoDestructorExpr()); } - return EmitCall(E->getCallee()->getType(), callee, E, ReturnValue); + return EmitCall(E->getCallee()->getType(), callee, E, ReturnValue, + /*Chain=*/nullptr, CallOrInvoke); } /// Emit a CallExpr without considering whether it might be a subclass. 
RValue CodeGenFunction::EmitSimpleCallExpr(const CallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { CGCallee Callee = EmitCallee(E->getCallee()); - return EmitCall(E->getCallee()->getType(), Callee, E, ReturnValue); + return EmitCall(E->getCallee()->getType(), Callee, E, ReturnValue, + /*Chain=*/nullptr, CallOrInvoke); } // Detect the unusual situation where an inline version is shadowed by a @@ -5685,8 +5689,9 @@ LValue CodeGenFunction::EmitBinaryOperatorLValue(const BinaryOperator *E) { llvm_unreachable("bad evaluation kind"); } -LValue CodeGenFunction::EmitCallExprLValue(const CallExpr *E) { - RValue RV = EmitCallExpr(E); +LValue CodeGenFunction::EmitCallExprLValue(const CallExpr *E, + llvm::CallBase **CallOrInvoke) { + RValue RV = EmitCallExpr(E, ReturnValueSlot(), CallOrInvoke); if (!RV.isScalar()) return MakeAddrLValue(RV.getAggregateAddress(), E->getType(), @@ -5809,9 +5814,11 @@ LValue CodeGenFunction::EmitStmtExprLValue(const StmtExpr *E) { AlignmentSource::Decl); } -RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee, - const CallExpr *E, ReturnValueSlot ReturnValue, - llvm::Value *Chain) { +RValue CodeGenFunction::EmitCall(QualType CalleeType, + const CGCallee &OrigCallee, const CallExpr *E, + ReturnValueSlot ReturnValue, + llvm::Value *Chain, + llvm::CallBase **CallOrInvoke) { // Get the actual function type. The callee type will always be a pointer to // function type or a block pointer type. 
assert(CalleeType->isFunctionPointerType() && @@ -6022,8 +6029,8 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee Address(Handle, Handle->getType(), CGM.getPointerAlign())); Callee.setFunctionPointer(Stub); } - llvm::CallBase *CallOrInvoke = nullptr; - RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &CallOrInvoke, + llvm::CallBase *LocalCallOrInvoke = nullptr; + RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &LocalCallOrInvoke, E == MustTailCall, E->getExprLoc()); // Generate function declaration DISuprogram in order to be used @@ -6032,11 +6039,13 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee if (auto *CalleeDecl = dyn_cast_or_null<FunctionDecl>(TargetDecl)) { FunctionArgList Args; QualType ResTy = BuildFunctionArgList(CalleeDecl, Args); - DI->EmitFuncDeclForCallSite(CallOrInvoke, + DI->EmitFuncDeclForCallSite(LocalCallOrInvoke, DI->getFunctionType(CalleeDecl, ResTy, Args), CalleeDecl); } } + if (CallOrInvoke) + *CallOrInvoke = LocalCallOrInvoke; return Call; } diff --git a/clang/lib/CodeGen/CGExprCXX.cpp b/clang/lib/CodeGen/CGExprCXX.cpp index 8eb6ab7381acb..1214bb054fb8d 100644 --- a/clang/lib/CodeGen/CGExprCXX.cpp +++ b/clang/lib/CodeGen/CGExprCXX.cpp @@ -84,23 +84,24 @@ commonEmitCXXMemberOrOperatorCall(CodeGenFunction &CGF, GlobalDecl GD, RValue CodeGenFunction::EmitCXXMemberOrOperatorCall( const CXXMethodDecl *MD, const CGCallee &Callee, - ReturnValueSlot ReturnValue, - llvm::Value *This, llvm::Value *ImplicitParam, QualType ImplicitParamTy, - const CallExpr *CE, CallArgList *RtlArgs) { + ReturnValueSlot ReturnValue, llvm::Value *This, llvm::Value *ImplicitParam, + QualType ImplicitParamTy, const CallExpr *CE, CallArgList *RtlArgs, + llvm::CallBase **CallOrInvoke) { const FunctionProtoType *FPT = MD->getType()->castAs<FunctionProtoType>(); CallArgList Args; MemberCallInfo CallInfo = commonEmitCXXMemberOrOperatorCall( *this, MD, This, ImplicitParam, ImplicitParamTy, 
CE, Args, RtlArgs); auto &FnInfo = CGM.getTypes().arrangeCXXMethodCall( Args, FPT, CallInfo.ReqArgs, CallInfo.PrefixSize); - return EmitCall(FnInfo, Callee, ReturnValue, Args, nullptr, + return EmitCall(FnInfo, Callee, ReturnValue, Args, CallOrInvoke, CE && CE == MustTailCall, CE ? CE->getExprLoc() : SourceLocation()); } RValue CodeGenFunction::EmitCXXDestructorCall( GlobalDecl Dtor, const CGCallee &Callee, llvm::Value *This, QualType ThisTy, - llvm::Value *ImplicitParam, QualType ImplicitParamTy, const CallExpr *CE) { + llvm::Value *ImplicitParam, QualType ImplicitParamTy, const CallExpr *CE, + llvm::CallBase **CallOrInvoke) { const CXXMethodDecl *DtorDecl = cast<CXXMethodDecl>(Dtor.getDecl()); assert(!ThisTy.isNull()); @@ -120,7 +121,8 @@ RValue CodeGenFunction::EmitCXXDestructorCall( commonEmitCXXMemberOrOperatorCall(*this, Dtor, This, ImplicitParam, ImplicitParamTy, CE, Args, nullptr); return EmitCall(CGM.getTypes().arrangeCXXStructorDeclaration(Dtor), Callee, - ReturnValueSlot(), Args, nullptr, CE && CE == MustTailCall, + ReturnValueSlot(), Args, CallOrInvoke, + CE && CE == MustTailCall, CE ? CE->getExprLoc() : SourceLocation{}); } @@ -186,11 +188,12 @@ static CXXRecordDecl *getCXXRecord(const Expr *E) { // Note: This function also emit constructor calls to support a MSVC // extensions allowing explicit constructor function call. 
RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { const Expr *callee = CE->getCallee()->IgnoreParens(); if (isa<BinaryOperator>(callee)) - return EmitCXXMemberPointerCallExpr(CE, ReturnValue); + return EmitCXXMemberPointerCallExpr(CE, ReturnValue, CallOrInvoke); const MemberExpr *ME = cast<MemberExpr>(callee); const CXXMethodDecl *MD = cast<CXXMethodDecl>(ME->getMemberDecl()); @@ -200,7 +203,7 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE, CGCallee callee = CGCallee::forDirect(CGM.GetAddrOfFunction(MD), GlobalDecl(MD)); return EmitCall(getContext().getPointerType(MD->getType()), callee, CE, - ReturnValue); + ReturnValue, /*Chain=*/nullptr, CallOrInvoke); } bool HasQualifier = ME->hasQualifier(); @@ -208,14 +211,15 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE, bool IsArrow = ME->isArrow(); const Expr *Base = ME->getBase(); - return EmitCXXMemberOrOperatorMemberCallExpr( - CE, MD, ReturnValue, HasQualifier, Qualifier, IsArrow, Base); + return EmitCXXMemberOrOperatorMemberCallExpr(CE, MD, ReturnValue, + HasQualifier, Qualifier, IsArrow, + Base, CallOrInvoke); } RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( const CallExpr *CE, const CXXMethodDecl *MD, ReturnValueSlot ReturnValue, bool HasQualifier, NestedNameSpecifier *Qualifier, bool IsArrow, - const Expr *Base) { + const Expr *Base, llvm::CallBase **CallOrInvoke) { assert(isa<CXXMemberCallExpr>(CE) || isa<CXXOperatorCallExpr>(CE)); // Compute the object pointer. 
@@ -300,7 +304,7 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( EmitCXXConstructorCall(Ctor, Ctor_Complete, /*ForVirtualBase=*/false, /*Delegating=*/false, This.getAddress(), Args, AggValueSlot::DoesNotOverlap, CE->getExprLoc(), - /*NewPointerIsChecked=*/false); + /*NewPointerIsChecked=*/false, CallOrInvoke); return RValue::get(nullptr); } @@ -374,9 +378,9 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( "Destructor shouldn't have explicit parameters"); assert(ReturnValue.isNull() && "Destructor shouldn't have return value"); if (UseVirtualCall) { - CGM.getCXXABI().EmitVirtualDestructorCall(*this, Dtor, Dtor_Complete, - This.getAddress(), - cast<CXXMemberCallExpr>(CE)); + CGM.getCXXABI().EmitVirtualDestructorCall( + *this, Dtor, Dtor_Complete, This.getAddress(), + cast<CXXMemberCallExpr>(CE), CallOrInvoke); } else { GlobalDecl GD(Dtor, Dtor_Complete); CGCallee Callee; @@ -393,7 +397,7 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( IsArrow ? 
Base->getType()->getPointeeType() : Base->getType(); EmitCXXDestructorCall(GD, Callee, This.getPointer(*this), ThisTy, /*ImplicitParam=*/nullptr, - /*ImplicitParamTy=*/QualType(), CE); + /*ImplicitParamTy=*/QualType(), CE, CallOrInvoke); } return RValue::get(nullptr); } @@ -435,12 +439,13 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr( return EmitCXXMemberOrOperatorCall( CalleeDecl, Callee, ReturnValue, This.getPointer(*this), - /*ImplicitParam=*/nullptr, QualType(), CE, RtlArgs); + /*ImplicitParam=*/nullptr, QualType(), CE, RtlArgs, CallOrInvoke); } RValue CodeGenFunction::EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E, - ReturnValueSlot ReturnValue) { + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { const BinaryOperator *BO = cast<BinaryOperator>(E->getCallee()->IgnoreParens()); const Expr *BaseExpr = BO->getLHS(); @@ -484,24 +489,25 @@ CodeGenFunction::EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E, EmitCallArgs(Args, FPT, E->arguments()); return EmitCall(CGM.getTypes().arrangeCXXMethodCall(Args, FPT, required, /*PrefixSize=*/0), - Callee, ReturnValue, Args, nullptr, E == MustTailCall, + Callee, ReturnValue, Args, CallOrInvoke, E == MustTailCall, E->getExprLoc()); } -RValue -CodeGenFunction::EmitCXXOperatorMemberCallExpr(const CXXOperatorCallExpr *E, - const CXXMethodDecl *MD, - ReturnValueSlot ReturnValue) { +RValue CodeGenFunction::EmitCXXOperatorMemberCallExpr( + const CXXOperatorCallExpr *E, const CXXMethodDecl *MD, + ReturnValueSlot ReturnValue, llvm::CallBase **CallOrInvoke) { assert(MD->isImplicitObjectMemberFunction() && "Trying to emit a member call expr on a static method!"); return EmitCXXMemberOrOperatorMemberCallExpr( E, MD, ReturnValue, /*HasQualifier=*/false, /*Qualifier=*/nullptr, - /*IsArrow=*/false, E->getArg(0)); + /*IsArrow=*/false, E->getArg(0), CallOrInvoke); } RValue CodeGenFunction::EmitCUDAKernelCallExpr(const CUDAKernelCallExpr *E, - ReturnValueSlot ReturnValue) { - return 
CGM.getCUDARuntime().EmitCUDAKernelCallExpr(*this, E, ReturnValue); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke) { + return CGM.getCUDARuntime().EmitCUDAKernelCallExpr(*this, E, ReturnValue, + CallOrInvoke); } static void EmitNullBaseClassInitialization(CodeGenFunction &CGF, diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index 8525f66082a4e..2e879839d2906 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -3133,7 +3133,8 @@ class CodeGenFunction : public CodeGenTypeCache { bool ForVirtualBase, bool Delegating, Address This, CallArgList &Args, AggValueSlot::Overlap_t Overlap, - SourceLocation Loc, bool NewPointerIsChecked); + SourceLocation Loc, bool NewPointerIsChecked, + llvm::CallBase **CallOrInvoke = nullptr); /// Emit assumption load for all bases. Requires to be called only on /// most-derived class and not under construction of the object. @@ -4247,7 +4248,8 @@ class CodeGenFunction : public CodeGenTypeCache { LValue EmitBinaryOperatorLValue(const BinaryOperator *E); LValue EmitCompoundAssignmentLValue(const CompoundAssignOperator *E); // Note: only available for agg return types - LValue EmitCallExprLValue(const CallExpr *E); + LValue EmitCallExprLValue(const CallExpr *E, + llvm::CallBase **CallOrInvoke = nullptr); // Note: only available for agg return types LValue EmitVAArgExprLValue(const VAArgExpr *E); LValue EmitDeclRefLValue(const DeclRefExpr *E); @@ -4357,20 +4359,26 @@ class CodeGenFunction : public CodeGenTypeCache { /// LLVM arguments and the types they were derived from. 
RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee, ReturnValueSlot ReturnValue, const CallArgList &Args, - llvm::CallBase **callOrInvoke, bool IsMustTail, + llvm::CallBase **CallOrInvoke, bool IsMustTail, SourceLocation Loc); RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee, ReturnValueSlot ReturnValue, const CallArgList &Args, - llvm::CallBase **callOrInvoke = nullptr, + llvm::CallBase **CallOrInvoke = nullptr, bool IsMustTail = false) { - return EmitCall(CallInfo, Callee, ReturnValue, Args, callOrInvoke, + return EmitCall(CallInfo, Callee, ReturnValue, Args, CallOrInvoke, IsMustTail, SourceLocation()); } RValue EmitCall(QualType FnType, const CGCallee &Callee, const CallExpr *E, - ReturnValueSlot ReturnValue, llvm::Value *Chain = nullptr); + ReturnValueSlot ReturnValue, llvm::Value *Chain = nullptr, + llvm::CallBase **CallOrInvoke = nullptr); + + // If a Call or Invoke instruction was emitted for this CallExpr, this method + // writes the pointer to `CallOrInvoke` if it's not null. 
RValue EmitCallExpr(const CallExpr *E, - ReturnValueSlot ReturnValue = ReturnValueSlot()); - RValue EmitSimpleCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue); + ReturnValueSlot ReturnValue = ReturnValueSlot(), + llvm::CallBase **CallOrInvoke = nullptr); + RValue EmitSimpleCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke = nullptr); CGCallee EmitCallee(const Expr *E); void checkTargetFeatures(const CallExpr *E, const FunctionDecl *TargetDecl); @@ -4441,25 +4449,23 @@ class CodeGenFunction : public CodeGenTypeCache { void callCStructCopyAssignmentOperator(LValue Dst, LValue Src); void callCStructMoveAssignmentOperator(LValue Dst, LValue Src); - RValue - EmitCXXMemberOrOperatorCall(const CXXMethodDecl *Method, - const CGCallee &Callee, - ReturnValueSlot ReturnValue, llvm::Value *This, - llvm::Value *ImplicitParam, - QualType ImplicitParamTy, const CallExpr *E, - CallArgList *RtlArgs); + RValue EmitCXXMemberOrOperatorCall( + const CXXMethodDecl *Method, const CGCallee &Callee, + ReturnValueSlot ReturnValue, llvm::Value *This, + llvm::Value *ImplicitParam, QualType ImplicitParamTy, const CallExpr *E, + CallArgList *RtlArgs, llvm::CallBase **CallOrInvoke); RValue EmitCXXDestructorCall(GlobalDecl Dtor, const CGCallee &Callee, llvm::Value *This, QualType ThisTy, llvm::Value *ImplicitParam, - QualType ImplicitParamTy, const CallExpr *E); + QualType ImplicitParamTy, const CallExpr *E, + llvm::CallBase **CallOrInvoke = nullptr); RValue EmitCXXMemberCallExpr(const CXXMemberCallExpr *E, - ReturnValueSlot ReturnValue); - RValue EmitCXXMemberOrOperatorMemberCallExpr(const CallExpr *CE, - const CXXMethodDecl *MD, - ReturnValueSlot ReturnValue, - bool HasQualifier, - NestedNameSpecifier *Qualifier, - bool IsArrow, const Expr *Base); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke = nullptr); + RValue EmitCXXMemberOrOperatorMemberCallExpr( + const CallExpr *CE, const CXXMethodDecl *MD, ReturnValueSlot ReturnValue, + 
bool HasQualifier, NestedNameSpecifier *Qualifier, bool IsArrow, + const Expr *Base, llvm::CallBase **CallOrInvoke); // Compute the object pointer. Address EmitCXXMemberDataPointerAddress(const Expr *E, Address base, llvm::Value *memberPtr, @@ -4467,15 +4473,18 @@ class CodeGenFunction : public CodeGenTypeCache { LValueBaseInfo *BaseInfo = nullptr, TBAAAccessInfo *TBAAInfo = nullptr); RValue EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E, - ReturnValueSlot ReturnValue); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke); RValue EmitCXXOperatorMemberCallExpr(const CXXOperatorCallExpr *E, const CXXMethodDecl *MD, - ReturnValueSlot ReturnValue); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke); RValue EmitCXXPseudoDestructorExpr(const CXXPseudoDestructorExpr *E); RValue EmitCUDAKernelCallExpr(const CUDAKernelCallExpr *E, - ReturnValueSlot ReturnValue); + ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke); RValue EmitNVPTXDevicePrintfCallExpr(const CallExpr *E); RValue EmitAMDGPUDevicePrintfCallExpr(const CallExpr *E); @@ -4498,7 +4507,8 @@ class CodeGenFunction : public CodeGenTypeCache { const analyze_os_log::OSLogBufferLayout &Layout, CharUnits BufferAlignment); - RValue EmitBlockCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue); + RValue EmitBlockCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue, + llvm::CallBase **CallOrInvoke); /// EmitTargetBuiltinExpr - Emit the given builtin call. Returns 0 if the call /// is unhandled by the current target. 
diff --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp index 5a3e83de625c9..0ffb0e04c8ce1 100644 --- a/clang/lib/CodeGen/ItaniumCXXABI.cpp +++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp @@ -314,10 +314,11 @@ class ItaniumCXXABI : public CodeGen::CGCXXABI { Address This, llvm::Type *Ty, SourceLocation Loc) override; - llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF, - const CXXDestructorDecl *Dtor, - CXXDtorType DtorType, Address This, - DeleteOrMemberCallExpr E) override; + llvm::Value * + EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, + CXXDtorType DtorType, Address This, + DeleteOrMemberCallExpr E, + llvm::CallBase **CallOrInvoke) override; void emitVirtualInheritanceTables(const CXXRecordDecl *RD) override; @@ -1254,7 +1255,8 @@ void ItaniumCXXABI::emitVirtualObjectDelete(CodeGenFunction &CGF, // FIXME: Provide a source location here even though there's no // CXXMemberCallExpr for dtor call. CXXDtorType DtorType = UseGlobalDelete ? 
Dtor_Complete : Dtor_Deleting; - EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE); + EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE, + /*CallOrInvoke=*/nullptr); if (UseGlobalDelete) CGF.PopCleanupBlock(); @@ -2054,7 +2056,7 @@ CGCallee ItaniumCXXABI::getVirtualFunctionPointer(CodeGenFunction &CGF, llvm::Value *ItaniumCXXABI::EmitVirtualDestructorCall( CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, CXXDtorType DtorType, - Address This, DeleteOrMemberCallExpr E) { + Address This, DeleteOrMemberCallExpr E, llvm::CallBase **CallOrInvoke) { auto *CE = E.dyn_cast<const CXXMemberCallExpr *>(); auto *D = E.dyn_cast<const CXXDeleteExpr *>(); assert((CE != nullptr) ^ (D != nullptr)); @@ -2075,7 +2077,7 @@ llvm::Value *ItaniumCXXABI::EmitVirtualDestructorCall( } CGF.EmitCXXDestructorCall(GD, Callee, This.emitRawPointer(CGF), ThisTy, - nullptr, QualType(), nullptr); + nullptr, QualType(), nullptr, CallOrInvoke); return nullptr; } diff --git a/clang/lib/CodeGen/MicrosoftCXXABI.cpp b/clang/lib/CodeGen/MicrosoftCXXABI.cpp index 9ab634fa6ce2e..c7e3d5e37bbf8 100644 --- a/clang/lib/CodeGen/MicrosoftCXXABI.cpp +++ b/clang/lib/CodeGen/MicrosoftCXXABI.cpp @@ -334,10 +334,11 @@ class MicrosoftCXXABI : public CGCXXABI { Address This, llvm::Type *Ty, SourceLocation Loc) override; - llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF, - const CXXDestructorDecl *Dtor, - CXXDtorType DtorType, Address This, - DeleteOrMemberCallExpr E) override; + llvm::Value * + EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, + CXXDtorType DtorType, Address This, + DeleteOrMemberCallExpr E, + llvm::CallBase **CallOrInvoke) override; void adjustCallArgsForDestructorThunk(CodeGenFunction &CGF, GlobalDecl GD, CallArgList &CallArgs) override { @@ -899,7 +900,8 @@ void MicrosoftCXXABI::emitVirtualObjectDelete(CodeGenFunction &CGF, // CXXMemberCallExpr for dtor call. bool UseGlobalDelete = DE->isGlobalDelete(); CXXDtorType DtorType = UseGlobalDelete ? 
Dtor_Complete : Dtor_Deleting; - llvm::Value *MDThis = EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE); + llvm::Value *MDThis = EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE, + /*CallOrInvoke=*/nullptr); if (UseGlobalDelete) CGF.EmitDeleteCall(DE->getOperatorDelete(), MDThis, ElementType); } @@ -1683,7 +1685,7 @@ void MicrosoftCXXABI::EmitDestructorCall(CodeGenFunction &CGF, CGF.EmitCXXDestructorCall(GD, Callee, CGF.getAsNaturalPointerTo(This, ThisTy), ThisTy, /*ImplicitParam=*/Implicit, - /*ImplicitParamTy=*/QualType(), nullptr); + /*ImplicitParamTy=*/QualType(), /*E=*/nullptr); if (BaseDtorEndBB) { // Complete object handler should continue to be the remaining CGF.Builder.CreateBr(BaseDtorEndBB); @@ -1999,7 +2001,7 @@ CGCallee MicrosoftCXXABI::getVirtualFunctionPointer(CodeGenFunction &CGF, llvm::Value *MicrosoftCXXABI::EmitVirtualDestructorCall( CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, CXXDtorType DtorType, - Address This, DeleteOrMemberCallExpr E) { + Address This, DeleteOrMemberCallExpr E, llvm::CallBase **CallOrInvoke) { auto *CE = E.dyn_cast<const CXXMemberCallExpr *>(); auto *D = E.dyn_cast<const CXXDeleteExpr *>(); assert((CE != nullptr) ^ (D != nullptr)); @@ -2029,7 +2031,7 @@ llvm::Value *MicrosoftCXXABI::EmitVirtualDestructorCall( This = adjustThisArgumentForVirtualFunctionCall(CGF, GD, This, true); RValue RV = CGF.EmitCXXDestructorCall(GD, Callee, This.emitRawPointer(CGF), ThisTy, - ImplicitParam, Context.IntTy, CE); + ImplicitParam, Context.IntTy, CE, CallOrInvoke); return RV.getScalarVal(); } diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp index 81334c817b2af..66f6fb86a4aa8 100644 --- a/clang/lib/Sema/SemaCoroutine.cpp +++ b/clang/lib/Sema/SemaCoroutine.cpp @@ -15,6 +15,7 @@ #include "CoroutineStmtBuilder.h" #include "clang/AST/ASTLambda.h" +#include "clang/AST/ComputeDependence.h" #include "clang/AST/Decl.h" #include "clang/AST/Expr.h" #include "clang/AST/ExprCXX.h" @@ -825,6 +826,32 @@ 
ExprResult Sema::BuildOperatorCoawaitLookupExpr(Scope *S, SourceLocation Loc) { return CoawaitOp; } +static bool isAttributedCoroInplaceTask(const QualType &QT) { + auto *Record = QT->getAsCXXRecordDecl(); + return Record && Record->hasAttr<CoroInplaceTaskAttr>(); +} + +static bool isCoroInplaceCall(Expr *Operand) { + if (!Operand->isPRValue()) { + return false; + } + + return isAttributedCoroInplaceTask(Operand->getType()); +} + +template <typename DesiredExpr> +DesiredExpr *getExprWrappedByTemporary(Expr *E) { + if (auto *BTE = dyn_cast<CXXBindTemporaryExpr>(E)) { + E = BTE->getSubExpr(); + } + + if (auto *S = dyn_cast<DesiredExpr>(E)) { + return S; + } + + return nullptr; +} + // Attempts to resolve and build a CoawaitExpr from "raw" inputs, bailing out to // DependentCoawaitExpr if needed. ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand, @@ -848,6 +875,25 @@ ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand, } auto *RD = Promise->getType()->getAsCXXRecordDecl(); + bool InplaceCall = + isCoroInplaceCall(Operand) && + isAttributedCoroInplaceTask( + getCurFunctionDecl(/*AllowLambda=*/true)->getReturnType()); + + OpaqueValueExpr *OpaqueCallExpr = nullptr; + + if (InplaceCall) { + if (auto *Temporary = dyn_cast<CXXBindTemporaryExpr>(Operand)) { + auto *SubExpr = Temporary->getSubExpr(); + if (CallExpr *Call = dyn_cast<CallExpr>(SubExpr)) { + OpaqueCallExpr = new (Context) + OpaqueValueExpr(Call->getRParenLoc(), Call->getType(), + Call->getValueKind(), Call->getObjectKind(), Call); + Temporary->setSubExpr(OpaqueCallExpr); + } + } + } + auto *Transformed = Operand; if (lookupMember(*this, "await_transform", RD, Loc)) { ExprResult R = @@ -864,7 +910,13 @@ ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand, if (Awaiter.isInvalid()) return ExprError(); - return BuildResolvedCoawaitExpr(Loc, Operand, Awaiter.get()); + auto Res = BuildResolvedCoawaitExpr(Loc, Operand, Awaiter.get()); + if 
(!Res.isInvalid() && InplaceCall) { + // BuildResolvedCoawaitExpr must return a CoawaitExpr, if valid. + CoawaitExpr *CE = Res.getAs<CoawaitExpr>(); + CE->setInplaceCallOpaqueValue(OpaqueCallExpr); + } + return Res; } ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *Operand, diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp index 67ef170251914..5bef1127d237d 100644 --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -483,7 +483,10 @@ void ASTStmtReader::VisitCoawaitExpr(CoawaitExpr *E) { E->KeywordLoc = readSourceLocation(); for (auto &SubExpr: E->SubExprs) SubExpr = Record.readSubStmt(); - E->OpaqueValue = cast_or_null<OpaqueValueExpr>(Record.readSubStmt()); + E->CommonExprOpaqueValue = + cast_or_null<OpaqueValueExpr>(Record.readSubStmt()); + E->InplaceCallOpaqueValue = + cast_or_null<OpaqueValueExpr>(Record.readSubStmt()); E->setIsImplicit(Record.readInt() != 0); } @@ -492,7 +495,10 @@ void ASTStmtReader::VisitCoyieldExpr(CoyieldExpr *E) { E->KeywordLoc = readSourceLocation(); for (auto &SubExpr: E->SubExprs) SubExpr = Record.readSubStmt(); - E->OpaqueValue = cast_or_null<OpaqueValueExpr>(Record.readSubStmt()); + E->CommonExprOpaqueValue = + cast_or_null<OpaqueValueExpr>(Record.readSubStmt()); + E->InplaceCallOpaqueValue = + cast_or_null<OpaqueValueExpr>(Record.readSubStmt()); } void ASTStmtReader::VisitDependentCoawaitExpr(DependentCoawaitExpr *E) { diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp index 1ba6d5501fd10..236219cd8a62c 100644 --- a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -443,7 +443,8 @@ void ASTStmtWriter::VisitCoroutineSuspendExpr(CoroutineSuspendExpr *E) { Record.AddSourceLocation(E->getKeywordLoc()); for (Stmt *S : E->children()) Record.AddStmt(S); - Record.AddStmt(E->getOpaqueValue()); + 
Record.AddStmt(E->getCommonExprOpaqueValue()); + Record.AddStmt(E->getInplaceCallOpaqueValue()); } void ASTStmtWriter::VisitCoawaitExpr(CoawaitExpr *E) { diff --git a/clang/test/CodeGenCoroutines/Inputs/utility.h b/clang/test/CodeGenCoroutines/Inputs/utility.h new file mode 100644 index 0000000000000..43c6d27823bd4 --- /dev/null +++ b/clang/test/CodeGenCoroutines/Inputs/utility.h @@ -0,0 +1,13 @@ +// This is a mock file for <utility> + +namespace std { + +template <typename T> struct remove_reference { using type = T; }; +template <typename T> struct remove_reference<T &> { using type = T; }; +template <typename T> struct remove_reference<T &&> { using type = T; }; + +template <typename T> +constexpr typename std::remove_reference<T>::type&& move(T &&t) noexcept { + return static_cast<typename std::remove_reference<T>::type &&>(t); +} +} diff --git a/clang/test/CodeGenCoroutines/coro-structured-concurrency.cpp b/clang/test/CodeGenCoroutines/coro-structured-concurrency.cpp new file mode 100644 index 0000000000000..2569643221da0 --- /dev/null +++ b/clang/test/CodeGenCoroutines/coro-structured-concurrency.cpp @@ -0,0 +1,84 @@ +// This file tests the semantics of the [[clang::coro_inplace_task]] attribute.
+// RUN: %clang_cc1 -std=c++20 -disable-llvm-passes -emit-llvm %s -o - | FileCheck %s + +#include "Inputs/coroutine.h" +#include "Inputs/utility.h" + +template <typename T> +struct [[clang::coro_inplace_task]] Task { + struct promise_type { + struct FinalAwaiter { + bool await_ready() const noexcept { return false; } + + template <typename P> + std::coroutine_handle<> await_suspend(std::coroutine_handle<P> coro) noexcept { + if (!coro) + return std::noop_coroutine(); + return coro.promise().continuation; + } + void await_resume() noexcept {} + }; + + Task get_return_object() noexcept { + return std::coroutine_handle<promise_type>::from_promise(*this); + } + + std::suspend_always initial_suspend() noexcept { return {}; } + FinalAwaiter final_suspend() noexcept { return {}; } + void unhandled_exception() noexcept {} + void return_value(T x) noexcept { + value = x; + } + + std::coroutine_handle<> continuation; + T value; + }; + + Task(std::coroutine_handle<promise_type> handle) : handle(handle) {} + ~Task() { + if (handle) + handle.destroy(); + } + + struct Awaiter { + Awaiter(Task *t) : task(t) {} + bool await_ready() const noexcept { return false; } + void await_suspend(std::coroutine_handle<void> continuation) noexcept {} + T await_resume() noexcept { + return task->handle.promise().value; + } + + Task *task; + }; + + auto operator co_await() { + return Awaiter{this}; + } + +private: + std::coroutine_handle<promise_type> handle; +}; + +// CHECK-LABEL: define{{.*}} @_Z6calleev +Task<int> callee() { + co_return 1; +} + +// CHECK-LABEL: define{{.*}} @_Z8elidablev +Task<int> elidable() { + // CHECK: %[[TARK_OBJ:.+]] = alloca %struct.Task + // CHECK: call void @llvm.coro.safe.elide(ptr %[[TARK_OBJ:.+]]) + co_return co_await callee(); +} + +// CHECK-LABEL: define{{.*}} @_Z11nonelidablev +Task<int> nonelidable() { + // CHECK: %[[TARK_OBJ:.+]] = alloca %struct.Task + auto t = callee(); + // Because we aren't co_awaiting a prvalue, we cannot elide here. 
+ // CHECK-NOT: call void @llvm.coro.safe.elide(ptr %[[TARK_OBJ:.+]]) + co_await t; + co_await std::move(t); + + co_return 1; +} diff --git a/clang/test/Misc/pragma-attribute-supported-attributes-list.test b/clang/test/Misc/pragma-attribute-supported-attributes-list.test index 99732694f72a5..068192c173fcd 100644 --- a/clang/test/Misc/pragma-attribute-supported-attributes-list.test +++ b/clang/test/Misc/pragma-attribute-supported-attributes-list.test @@ -59,6 +59,7 @@ // CHECK-NEXT: ConsumableSetOnRead (SubjectMatchRule_record) // CHECK-NEXT: Convergent (SubjectMatchRule_function) // CHECK-NEXT: CoroDisableLifetimeBound (SubjectMatchRule_function) +// CHECK-NEXT: CoroInplaceTask (SubjectMatchRule_record) // CHECK-NEXT: CoroLifetimeBound (SubjectMatchRule_record) // CHECK-NEXT: CoroOnlyDestroyWhenComplete (SubjectMatchRule_record) // CHECK-NEXT: CoroReturnType (SubjectMatchRule_record) diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td index ef500329d1fb9..7b17f3061269c 100644 --- a/llvm/include/llvm/IR/Intrinsics.td +++ b/llvm/include/llvm/IR/Intrinsics.td @@ -1728,6 +1728,9 @@ def int_coro_subfn_addr : DefaultAttrsIntrinsic< [IntrReadMem, IntrArgMemOnly, ReadOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>]>; +def int_coro_safe_elide : DefaultAttrsIntrinsic< + [], [llvm_ptr_ty], []>; + ///===-------------------------- Other Intrinsics --------------------------===// // // TODO: We should introduce a new memory kind fo traps (and other side effects diff --git a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp index 3e3825fcd50e2..71229eae5cb47 100644 --- a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp @@ -8,10 +8,11 @@ #include "llvm/Transforms/Coroutines/CoroCleanup.h" #include "CoroInternal.h" +#include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" +#include "llvm/IR/Intrinsics.h" #include 
"llvm/IR/PassManager.h" -#include "llvm/IR/Function.h" #include "llvm/Transforms/Scalar/SimplifyCFG.h" using namespace llvm; @@ -80,7 +81,7 @@ bool Lowerer::lower(Function &F) { } else continue; break; - case Intrinsic::coro_async_size_replace: + case Intrinsic::coro_async_size_replace: { auto *Target = cast<ConstantStruct>( cast<GlobalVariable>(II->getArgOperand(0)->stripPointerCasts()) ->getInitializer()); @@ -98,6 +99,9 @@ bool Lowerer::lower(Function &F) { Target->replaceAllUsesWith(NewFuncPtrStruct); break; } + case Intrinsic::coro_safe_elide: + break; + } II->eraseFromParent(); Changed = true; } @@ -111,7 +115,8 @@ static bool declaresCoroCleanupIntrinsics(const Module &M) { M, {"llvm.coro.alloc", "llvm.coro.begin", "llvm.coro.subfn.addr", "llvm.coro.free", "llvm.coro.id", "llvm.coro.id.retcon", "llvm.coro.id.async", "llvm.coro.id.retcon.once", - "llvm.coro.async.size.replace", "llvm.coro.async.resume"}); + "llvm.coro.async.size.replace", "llvm.coro.async.resume", + "llvm.coro.safe.elide"}); } PreservedAnalyses CoroCleanupPass::run(Module &M, diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp index 74b5ccb7b9b71..403e3abdebab3 100644 --- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp @@ -7,12 +7,14 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Coroutines/CoroElide.h" +#include "CoroInstr.h" #include "CoroInternal.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/PostDominators.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/InstIterator.h" #include "llvm/Support/ErrorHandling.h" @@ -56,7 +58,8 @@ class FunctionElideInfo { class CoroIdElider { public: CoroIdElider(CoroIdInst *CoroId, FunctionElideInfo &FEI, 
AAResults &AA, - DominatorTree &DT, OptimizationRemarkEmitter &ORE); + DominatorTree &DT, PostDominatorTree &PDT, + OptimizationRemarkEmitter &ORE); void elideHeapAllocations(uint64_t FrameSize, Align FrameAlign); bool lifetimeEligibleForElide() const; bool attemptElide(); @@ -68,6 +71,7 @@ class CoroIdElider { FunctionElideInfo &FEI; AAResults &AA; DominatorTree &DT; + PostDominatorTree &PDT; OptimizationRemarkEmitter &ORE; SmallVector<CoroBeginInst *, 1> CoroBegins; @@ -183,8 +187,9 @@ void FunctionElideInfo::collectPostSplitCoroIds() { CoroIdElider::CoroIdElider(CoroIdInst *CoroId, FunctionElideInfo &FEI, AAResults &AA, DominatorTree &DT, + PostDominatorTree &PDT, OptimizationRemarkEmitter &ORE) - : CoroId(CoroId), FEI(FEI), AA(AA), DT(DT), ORE(ORE) { + : CoroId(CoroId), FEI(FEI), AA(AA), DT(DT), PDT(PDT), ORE(ORE) { // Collect all coro.begin and coro.allocs associated with this coro.id. for (User *U : CoroId->users()) { if (auto *CB = dyn_cast<CoroBeginInst>(U)) @@ -336,6 +341,41 @@ bool CoroIdElider::canCoroBeginEscape( return false; } +// FIXME: This is not accounting for the stores to tasks whose handle is not +// zero offset. +static const StoreInst *getPostDominatingStoreToTask(const CoroBeginInst *CB, + PostDominatorTree &PDT) { + const StoreInst *OnlyStore = nullptr; + + for (auto *U : CB->users()) { + auto *Store = dyn_cast<StoreInst>(U); + if (Store && Store->getValueOperand() == CB) { + if (OnlyStore) { + // Store must be unique. one coro begin getting stored to multiple + // stores is not accepted. 
+ return nullptr; + } + OnlyStore = Store; + } + } + + if (!OnlyStore || !PDT.dominates(OnlyStore, CB)) { + return nullptr; + } + + return OnlyStore; +} + +static bool isMarkedSafeElide(const llvm::Value *V) { + for (auto *U : V->users()) { + auto *II = dyn_cast<IntrinsicInst>(U); + if (II && (II->getIntrinsicID() == Intrinsic::coro_safe_elide)) { + return true; + } + } + return false; +} + bool CoroIdElider::lifetimeEligibleForElide() const { // If no CoroAllocs, we cannot suppress allocation, so elision is not // possible. @@ -364,6 +404,17 @@ bool CoroIdElider::lifetimeEligibleForElide() const { // Filter out the coro.destroy that lie along exceptional paths. for (const auto *CB : CoroBegins) { + // This might be too strong of a condition but should be very safe. + // If the CB is unconditionally stored into a "Task Like Object", + // and such object is "safe elide". + if (FEI.ContainingFunction->isPresplitCoroutine()) { + if (auto *MaybeStoreToTask = getPostDominatingStoreToTask(CB, PDT)) { + auto Dest = MaybeStoreToTask->getPointerOperand(); + if (isMarkedSafeElide(Dest)) + continue; + } + } + auto It = DestroyAddr.find(CB); // FIXME: If we have not found any destroys for this coro.begin, we @@ -476,11 +527,12 @@ PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) { AAResults &AA = AM.getResult<AAManager>(F); DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); + PostDominatorTree &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); bool Changed = false; for (auto *CII : FEI.getCoroIds()) { - CoroIdElider CIE(CII, FEI, AA, DT, ORE); + CoroIdElider CIE(CII, FEI, AA, DT, PDT, ORE); Changed |= CIE.attemptElide(); } diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp index 1a92bc1636257..48c02e5406b75 100644 --- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp +++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp @@ 
-86,6 +86,7 @@ static const char *const CoroIntrinsics[] = { "llvm.coro.prepare.retcon", "llvm.coro.promise", "llvm.coro.resume", + "llvm.coro.safe.elide", "llvm.coro.save", "llvm.coro.size", "llvm.coro.subfn.addr", diff --git a/llvm/test/Transforms/Coroutines/coro-elide-structured-concurrency.ll b/llvm/test/Transforms/Coroutines/coro-elide-structured-concurrency.ll new file mode 100644 index 0000000000000..b19886d549d90 --- /dev/null +++ b/llvm/test/Transforms/Coroutines/coro-elide-structured-concurrency.ll @@ -0,0 +1,64 @@ +; Tests that coro-elide removes the frame allocation for calls to coroutines whose task is marked with llvm.coro.safe.elide. +; RUN: opt < %s -S -passes='inline,coro-elide' | FileCheck %s + +%struct.Task = type { ptr } + +declare void @print(i32) nounwind + +; resume part of the coroutine +define fastcc void @callee.resume(ptr dereferenceable(1)) { + tail call void @print(i32 0) + ret void +} + +; destroy part of the coroutine +define fastcc void @callee.destroy(ptr) { + tail call void @print(i32 1) + ret void +} + +; cleanup part of the coroutine +define fastcc void @callee.cleanup(ptr) { + tail call void @print(i32 2) + ret void +} + +@callee.resumers = internal constant [3 x ptr] [ + ptr @callee.resume, ptr @callee.destroy, ptr @callee.cleanup] + +declare void @alloc(i1) nounwind + +; CHECK: define ptr @callee() +define ptr @callee() { +entry: + %task = alloca %struct.Task, align 8 + %id = call token @llvm.coro.id(i32 0, ptr null, + ptr @callee, + ptr @callee.resumers) + %alloc = call i1 @llvm.coro.alloc(token %id) + %hdl = call ptr @llvm.coro.begin(token %id, ptr null) + store ptr %hdl, ptr %task + ret ptr %task +} + +; CHECK: define ptr @caller() +; Function Attrs: presplitcoroutine +define ptr @caller() #0 { +entry: + %task = call ptr @callee() + + ; CHECK: %[[id:.+]] = call token @llvm.coro.id(i32 0, ptr null, ptr @callee, ptr @callee.resumers) + ; CHECK-NOT: call i1 @llvm.coro.alloc(token %[[id]]) + call void @llvm.coro.safe.elide(ptr %task) + + ret ptr %task +} + +attributes #0 = {
presplitcoroutine } + +declare token @llvm.coro.id(i32, ptr, ptr, ptr) +declare ptr @llvm.coro.begin(token, ptr) +declare ptr @llvm.coro.frame() +declare ptr @llvm.coro.subfn.addr(ptr, i8) +declare i1 @llvm.coro.alloc(token) +declare void @llvm.coro.safe.elide(ptr) _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits