modocache created this revision. modocache added reviewers: rsmith, GorNishanov, eric_niebler. Herald added a subscriber: EricWF.
Use coroutine function arguments to initialize a promise type, but only if the promise type defines a constructor that takes those arguments. Otherwise, fall back to the default constructor. Test Plan: check-clang Repository: rC Clang https://reviews.llvm.org/D41820 Files: include/clang/Sema/ScopeInfo.h include/clang/Sema/Sema.h lib/Sema/CoroutineStmtBuilder.h lib/Sema/ScopeInfo.cpp lib/Sema/SemaCoroutine.cpp lib/Sema/TreeTransform.h test/CodeGenCoroutines/coro-alloc.cpp
Index: test/CodeGenCoroutines/coro-alloc.cpp =================================================================== --- test/CodeGenCoroutines/coro-alloc.cpp +++ test/CodeGenCoroutines/coro-alloc.cpp @@ -193,3 +193,26 @@ // CHECK: ret i32 %[[LoadRet]] co_return; } + +struct promise_matching_constructor {}; + +template<> +struct std::experimental::coroutine_traits<void, promise_matching_constructor, int, float, double> { + struct promise_type { + promise_type(promise_matching_constructor, int, float, double) {} + promise_type() = delete; + void get_return_object() {} + suspend_always initial_suspend() { return {}; } + suspend_always final_suspend() { return {}; } + void return_void() {} + }; +}; + +// CHECK-LABEL: f5( +extern "C" void f5(promise_matching_constructor, int, float, double) { + // CHECK: %[[INT:.+]] = load i32, i32* %.addr, align 4 + // CHECK: %[[FLOAT:.+]] = load float, float* %.addr1, align 4 + // CHECK: %[[DOUBLE:.+]] = load double, double* %.addr2, align 8 + // CHECK: call void @_ZNSt12experimental16coroutine_traitsIJv28promise_matching_constructorifdEE12promise_typeC1ES1_ifd(%"struct.std::experimental::coroutine_traits<void, promise_matching_constructor, int, float, double>::promise_type"* %__promise, i32 %[[INT]], float %[[FLOAT]], double %[[DOUBLE]]) + co_return; +} Index: lib/Sema/TreeTransform.h =================================================================== --- lib/Sema/TreeTransform.h +++ lib/Sema/TreeTransform.h @@ -6956,6 +6956,8 @@ // The new CoroutinePromise object needs to be built and put into the current // FunctionScopeInfo before any transformations or rebuilding occurs. 
+ if (!SemaRef.buildCoroutineParameterMoves(FD->getLocation())) + return StmtError(); auto *Promise = SemaRef.buildCoroutinePromise(FD->getLocation()); if (!Promise) return StmtError(); @@ -7046,8 +7048,6 @@ Builder.ReturnStmt = Res.get(); } } - if (!Builder.buildParameterMoves()) - return StmtError(); return getDerived().RebuildCoroutineBodyStmt(Builder); } Index: lib/Sema/SemaCoroutine.cpp =================================================================== --- lib/Sema/SemaCoroutine.cpp +++ lib/Sema/SemaCoroutine.cpp @@ -472,6 +472,69 @@ return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args); } +// Create a static_cast\<T&&>(expr). +static Expr *castForMoving(Sema &S, Expr *E, QualType T = QualType()) { + if (T.isNull()) + T = E->getType(); + QualType TargetType = S.BuildReferenceType( + T, /*SpelledAsLValue*/ false, SourceLocation(), DeclarationName()); + SourceLocation ExprLoc = E->getLocStart(); + TypeSourceInfo *TargetLoc = + S.Context.getTrivialTypeSourceInfo(TargetType, ExprLoc); + + return S + .BuildCXXNamedCast(ExprLoc, tok::kw_static_cast, TargetLoc, E, + SourceRange(ExprLoc, ExprLoc), E->getSourceRange()) + .get(); +} + +/// \brief Build a variable declaration for move parameter. +static VarDecl *buildVarDecl(Sema &S, SourceLocation Loc, QualType Type, + IdentifierInfo *II) { + TypeSourceInfo *TInfo = S.Context.getTrivialTypeSourceInfo(Type, Loc); + VarDecl *Decl = VarDecl::Create(S.Context, S.CurContext, Loc, Loc, II, Type, + TInfo, SC_None); + Decl->setImplicit(); + return Decl; +} + +// Build statements that move coroutine function parameters to the coroutine +// frame, and store them on the function scope info. 
+bool Sema::buildCoroutineParameterMoves(SourceLocation Loc) { + assert(isa<FunctionDecl>(CurContext) && "not in a function scope"); + auto *FD = cast<FunctionDecl>(CurContext); + + auto *ScopeInfo = getCurFunction(); + assert(ScopeInfo->CoroutineParameterMoves.empty() && + "Should not build parameter moves twice"); + + for (auto *PD : FD->parameters()) { + if (PD->getType()->isDependentType()) + continue; + + // No need to copy scalars, LLVM will take care of them. + if (PD->getType()->getAsCXXRecordDecl()) { + ExprResult PDRefExpr = BuildDeclRefExpr( + PD, PD->getType(), ExprValueKind::VK_LValue, Loc); // FIXME: scope? + if (PDRefExpr.isInvalid()) + return false; + + Expr *CExpr = castForMoving(*this, PDRefExpr.get()); + + auto D = buildVarDecl(*this, Loc, PD->getType(), PD->getIdentifier()); + AddInitializerToDecl(D, CExpr, /*DirectInit=*/true); + + // Convert decl to a statement. + StmtResult Stmt = ActOnDeclStmt(ConvertDeclToDeclGroup(D), Loc, Loc); + if (Stmt.isInvalid()) + return false; + + ScopeInfo->CoroutineParameterMoves.insert(std::make_pair(PD, Stmt.get())); + } + } + return true; +} + VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) { assert(isa<FunctionDecl>(CurContext) && "not in a function scope"); auto *FD = cast<FunctionDecl>(CurContext); @@ -494,7 +557,64 @@ CheckVariableDeclarationType(VD); if (VD->isInvalidDecl()) return nullptr; - ActOnUninitializedDecl(VD); + + auto *ScopeInfo = getCurFunction(); + // Build a list of arguments, based on the coroutine functions arguments, + // that will be passed to the promise type's constructor. + llvm::SmallVector<Expr *, 4> CtorArgExprs; + for (auto *PD : FD->parameters()) { + if (PD->getType()->isDependentType()) + continue; + + auto RefExpr = ExprEmpty(); + auto Moves = ScopeInfo->CoroutineParameterMoves; + if (Moves.find(PD) != Moves.end()) { + // If a reference to the function parameter exists in the coroutine + // frame, use that reference. 
+ auto *VD = cast<VarDecl>(cast<DeclStmt>(Moves[PD])->getSingleDecl()); + RefExpr = BuildDeclRefExpr(VD, VD->getType(), ExprValueKind::VK_LValue, + FD->getLocation()); + } else { + // If the function parameter doesn't exist in the coroutine frame, it + // must be a scalar value. Use it directly. + assert(!PD->getType()->getAsCXXRecordDecl() && + "Non-scalar types should have been moved and inserted into the " + "parameter moves map"); + RefExpr = + BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(), + ExprValueKind::VK_LValue, FD->getLocation()); + } + + if (RefExpr.isInvalid()) + return nullptr; + CtorArgExprs.push_back(RefExpr.get()); + } + + // Create an initialization sequence for the promise type using the + // constructor arguments, wrapped in a parenthesized list expression. + Expr *PLE = new (Context) ParenListExpr(Context, FD->getLocation(), + CtorArgExprs, FD->getLocation()); + InitializedEntity Entity = InitializedEntity::InitializeVariable(VD); + InitializationKind Kind = InitializationKind::CreateForInit( + VD->getLocation(), /*DirectInit=*/true, PLE); + InitializationSequence InitSeq(*this, Entity, Kind, CtorArgExprs, + /*TopLevelOfInitList=*/false, + /*TreatUnavailableAsInvalid=*/false); + + // Attempt to initialize the promise type with the arguments. + // If that fails, fall back to the promise type's default constructor. 
+ if (InitSeq) { + ExprResult Result = InitSeq.Perform(*this, Entity, Kind, CtorArgExprs); + if (Result.isInvalid()) { + VD->setInvalidDecl(); + } else if (Result.get()) { + VD->setInit(MaybeCreateExprWithCleanups(Result.get())); + VD->setInitStyle(VarDecl::CallInit); + CheckCompleteVariableDeclaration(VD); + } + } else + ActOnUninitializedDecl(VD); + FD->addDecl(VD); assert(!VD->isInvalidDecl()); return VD; @@ -518,6 +638,9 @@ if (ScopeInfo->CoroutinePromise) return ScopeInfo; + if (!S.buildCoroutineParameterMoves(Loc)) + return nullptr; + ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc); if (!ScopeInfo->CoroutinePromise) return nullptr; @@ -861,6 +984,11 @@ !Fn.CoroutinePromise || Fn.CoroutinePromise->getType()->isDependentType()) { this->Body = Body; + + for (auto KV : Fn.CoroutineParameterMoves) + this->ParamMovesVector.push_back(KV.second); + this->ParamMoves = this->ParamMovesVector; + if (!IsPromiseDependentType) { PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl(); assert(PromiseRecordDecl && "Type should have already been checked"); @@ -870,7 +998,7 @@ bool CoroutineStmtBuilder::buildStatements() { assert(this->IsValid && "coroutine already invalid"); - this->IsValid = makeReturnObject() && makeParamMoves(); + this->IsValid = makeReturnObject(); if (this->IsValid && !IsPromiseDependentType) buildDependentStatements(); return this->IsValid; @@ -886,12 +1014,6 @@ return this->IsValid; } -bool CoroutineStmtBuilder::buildParameterMoves() { - assert(this->IsValid && "coroutine already invalid"); - assert(this->ParamMoves.empty() && "param moves already built"); - return this->IsValid = makeParamMoves(); -} - bool CoroutineStmtBuilder::makePromiseStmt() { // Form a declaration statement for the promise declaration, so that AST // visitors can more easily find it. @@ -1288,66 +1410,6 @@ return true; } -// Create a static_cast\<T&&>(expr). 
-static Expr *castForMoving(Sema &S, Expr *E, QualType T = QualType()) { - if (T.isNull()) - T = E->getType(); - QualType TargetType = S.BuildReferenceType( - T, /*SpelledAsLValue*/ false, SourceLocation(), DeclarationName()); - SourceLocation ExprLoc = E->getLocStart(); - TypeSourceInfo *TargetLoc = - S.Context.getTrivialTypeSourceInfo(TargetType, ExprLoc); - - return S - .BuildCXXNamedCast(ExprLoc, tok::kw_static_cast, TargetLoc, E, - SourceRange(ExprLoc, ExprLoc), E->getSourceRange()) - .get(); -} - - -/// \brief Build a variable declaration for move parameter. -static VarDecl *buildVarDecl(Sema &S, SourceLocation Loc, QualType Type, - IdentifierInfo *II) { - TypeSourceInfo *TInfo = S.Context.getTrivialTypeSourceInfo(Type, Loc); - VarDecl *Decl = - VarDecl::Create(S.Context, S.CurContext, Loc, Loc, II, Type, TInfo, SC_None); - Decl->setImplicit(); - return Decl; -} - -bool CoroutineStmtBuilder::makeParamMoves() { - for (auto *paramDecl : FD.parameters()) { - auto Ty = paramDecl->getType(); - if (Ty->isDependentType()) - continue; - - // No need to copy scalars, llvm will take care of them. - if (Ty->getAsCXXRecordDecl()) { - ExprResult ParamRef = - S.BuildDeclRefExpr(paramDecl, paramDecl->getType(), - ExprValueKind::VK_LValue, Loc); // FIXME: scope? - if (ParamRef.isInvalid()) - return false; - - Expr *RCast = castForMoving(S, ParamRef.get()); - - auto D = buildVarDecl(S, Loc, Ty, paramDecl->getIdentifier()); - S.AddInitializerToDecl(D, RCast, /*DirectInit=*/true); - - // Convert decl to a statement. - StmtResult Stmt = S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(D), Loc, Loc); - if (Stmt.isInvalid()) - return false; - - ParamMovesVector.push_back(Stmt.get()); - } - } - - // Convert to ArrayRef in CtorArgs structure that builder inherits from. 
- ParamMoves = ParamMovesVector; - return true; -} - StmtResult Sema::BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) { CoroutineBodyStmt *Res = CoroutineBodyStmt::Create(Context, Args); if (!Res) Index: lib/Sema/ScopeInfo.cpp =================================================================== --- lib/Sema/ScopeInfo.cpp +++ lib/Sema/ScopeInfo.cpp @@ -43,6 +43,7 @@ // Coroutine state FirstCoroutineStmtLoc = SourceLocation(); CoroutinePromise = nullptr; + CoroutineParameterMoves.clear(); NeedsCoroutineSuspends = true; CoroutineSuspends.first = nullptr; CoroutineSuspends.second = nullptr; Index: lib/Sema/CoroutineStmtBuilder.h =================================================================== --- lib/Sema/CoroutineStmtBuilder.h +++ lib/Sema/CoroutineStmtBuilder.h @@ -51,9 +51,6 @@ /// name lookup. bool buildDependentStatements(); - /// \brief Build just parameter moves. To use for rebuilding in TreeTransform. - bool buildParameterMoves(); - bool isInvalid() const { return !this->IsValid; } private: @@ -65,7 +62,6 @@ bool makeReturnObject(); bool makeGroDeclAndReturnStmt(); bool makeReturnOnAllocFailure(); - bool makeParamMoves(); }; } // end namespace clang Index: include/clang/Sema/Sema.h =================================================================== --- include/clang/Sema/Sema.h +++ include/clang/Sema/Sema.h @@ -8473,6 +8473,7 @@ StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E, bool IsImplicit = false); StmtResult BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs); + bool buildCoroutineParameterMoves(SourceLocation Loc); VarDecl *buildCoroutinePromise(SourceLocation Loc); void CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body); Index: include/clang/Sema/ScopeInfo.h =================================================================== --- include/clang/Sema/ScopeInfo.h +++ include/clang/Sema/ScopeInfo.h @@ -22,6 +22,7 @@ #include "clang/Sema/CleanupInfo.h" #include "clang/Sema/Ownership.h" #include "llvm/ADT/DenseMap.h" +#include 
"llvm/ADT/MapVector.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSwitch.h" @@ -172,6 +173,10 @@ /// \brief The promise object for this coroutine, if any. VarDecl *CoroutinePromise = nullptr; + /// \brief A mapping between the coroutine function parameters that were moved + /// to the coroutine frame, and their move statements. + llvm::SmallMapVector<ParmVarDecl *, Stmt *, 4> CoroutineParameterMoves; + /// \brief The initial and final coroutine suspend points. std::pair<Stmt *, Stmt *> CoroutineSuspends;
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits