EricWF created this revision.
EricWF added reviewers: rsmith, GorNishanov, majnemer.
EricWF added a subscriber: cfe-commits.
Herald added a subscriber: mehdi_amini.
When using TreeTransform to rebuild a coroutine, the coroutine's `promise_type`
variable is not transformed, because it is stored in the `FunctionScopeInfo`
rather than as a regular sub-statement of the function.
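This patch attempts to fix this by rebuilding the promise variable at the start
of transforming a `CoroutineBodyStmt`. Additionally, this patch changes
`TransformCoroutineBodyStmt` so that it always rebuilds the coroutine body. To
support this, promise construction is factored out into
`Sema::buildCoroutinePromise`, and `CheckCompletedCoroutineBody` is replaced by
`ActOnFinishCoroutineBody`, which returns the new body as a `StmtResult`. As a
side effect, most promise-type lookup diagnostics are now reported at the
function declaration instead of at the first coroutine keyword, which is why
several test expectations move. Is there a better alternative to always
rebuilding the body? Or is always rebuilding sufficient for now?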
https://reviews.llvm.org/D25303
Files:
include/clang/Sema/Sema.h
lib/Sema/SemaCoroutine.cpp
lib/Sema/SemaDecl.cpp
lib/Sema/TreeTransform.h
test/SemaCXX/coroutines.cpp
Index: test/SemaCXX/coroutines.cpp
===================================================================
--- test/SemaCXX/coroutines.cpp
+++ test/SemaCXX/coroutines.cpp
@@ -52,21 +52,21 @@
using promise_type = Promise;
};
-void no_specialization() {
- co_await a; // expected-error {{implicit instantiation of undefined template 'std::experimental::coroutine_traits<void>'}}
+void no_specialization() { // expected-error {{implicit instantiation of undefined template 'std::experimental::coroutine_traits<void>'}}
+ co_await a;
}
template <typename... T>
struct std::experimental::coroutine_traits<int, T...> {};
-int no_promise_type() {
- co_await a; // expected-error {{this function cannot be a coroutine: 'std::experimental::coroutine_traits<int>' has no member named 'promise_type'}}
+int no_promise_type() { // expected-error {{this function cannot be a coroutine: 'std::experimental::coroutine_traits<int>' has no member named 'promise_type'}}
+ co_await a;
}
template <>
struct std::experimental::coroutine_traits<double, double> { typedef int promise_type; };
-double bad_promise_type(double) {
- co_await a; // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits<double, double>::promise_type' (aka 'int') is not a class}}
+double bad_promise_type(double) { // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits<double, double>::promise_type' (aka 'int') is not a class}}
+ co_await a;
}
template <>
@@ -77,7 +77,7 @@
co_yield 0; // expected-error {{no member named 'yield_value' in 'std::experimental::coroutine_traits<double, int>::promise_type'}}
}
-struct promise; // expected-note 2{{forward declaration}}
+struct promise; // expected-note 3{{forward declaration}}
template <typename... T>
struct std::experimental::coroutine_traits<void, T...> { using promise_type = promise; };
@@ -94,6 +94,12 @@
// expected-error@-2 {{incomplete definition of type 'promise'}}
co_await a;
}
+template <class T>
+void undefined_promise_template(T) { // expected-error {{variable has incomplete type 'promise_type'}}
+ // FIXME: This diagnostic doesn't make any sense.
+ co_await a;
+}
+template void undefined_promise_template(int); // expected-note {{requested here}}
struct yielded_thing { const char *p; short a, b; };
@@ -299,6 +305,16 @@
co_await a;
}
+struct not_class_tag {};
+template <>
+struct std::experimental::coroutine_traits<void, not_class_tag> { using promise_type = int; };
+
+template <class T>
+void promise_type_not_class(T) {
+ // expected-error@-1 {{this function cannot be a coroutine: 'experimental::coroutine_traits<void, not_class_tag>::promise_type' (aka 'int') is not a class}}
+ co_await a;
+}
+template void promise_type_not_class(not_class_tag); // expected-note {{requested here}}
template<> struct std::experimental::coroutine_traits<int, int, const char**>
{ using promise_type = promise; };
Index: lib/Sema/TreeTransform.h
===================================================================
--- lib/Sema/TreeTransform.h
+++ lib/Sema/TreeTransform.h
@@ -1326,6 +1326,16 @@
return getSema().BuildCoyieldExpr(CoyieldLoc, Result);
}
+ /// \brief Build a new coroutine body.
+ ///
+ /// By default, performs semantic analysis to build the new body.
+ /// Subclasses may override this routine to provide different behavior.
+ StmtResult RebuildCoroutineBodyStmt(Stmt *Body) {
+ auto *FD = dyn_cast<FunctionDecl>(getSema().CurContext);
+ assert(FD); // FIXME this assertion should never fire
+ return getSema().ActOnFinishCoroutineBody(FD, Body);
+ }
+
/// \brief Build a new Objective-C \@try statement.
///
/// By default, performs semantic analysis to build the new statement.
@@ -6651,8 +6661,25 @@
template<typename Derived>
StmtResult
TreeTransform<Derived>::TransformCoroutineBodyStmt(CoroutineBodyStmt *S) {
- // The coroutine body should be re-formed by the caller if necessary.
- return getDerived().TransformStmt(S->getBody());
+ // FIXME: Don't rebuild the entire coroutine body.
+ // The coroutine body should only be re-formed by the caller if necessary.
+ FunctionScopeInfo *FS = getSema().getCurFunction();
+ assert(FS);
+ VarDecl *VD =
+ getSema().buildCoroutinePromise(S->getPromiseDecl()->getLocation());
+ if (!VD || VD->isInvalidDecl())
+ return StmtError();
+ getDerived().transformedLocalDecl(S->getPromiseDecl(), VD);
+ // FIXME: Re-setting FS->CoroutinePromise feels like a hack. Is there a better
+ // way to do this? Currently this is needed so the rebuilt body uses the
+ // transformed promise type.
+ FS->CoroutinePromise = VD;
+
+ StmtResult Body = getDerived().TransformStmt(S->getBody());
+ if (Body.isInvalid())
+ return StmtError();
+
+ return getDerived().RebuildCoroutineBodyStmt(Body.get());
}
template<typename Derived>
Index: lib/Sema/SemaDecl.cpp
===================================================================
--- lib/Sema/SemaDecl.cpp
+++ lib/Sema/SemaDecl.cpp
@@ -11626,13 +11626,22 @@
Decl *Sema::ActOnFinishFunctionBody(Decl *dcl, Stmt *Body,
bool IsInstantiation) {
- FunctionDecl *FD = dcl ? dcl->getAsFunction() : nullptr;
+ if (!dcl)
+ return nullptr;
+ FunctionDecl *FD = dcl->getAsFunction();
sema::AnalysisBasedWarnings::Policy WP = AnalysisWarnings.getDefaultPolicy();
sema::AnalysisBasedWarnings::Policy *ActivePolicy = nullptr;
- if (getLangOpts().CoroutinesTS && !getCurFunction()->CoroutineStmts.empty())
- CheckCompletedCoroutineBody(FD, Body);
+ if (getLangOpts().CoroutinesTS && !getCurFunction()->CoroutineStmts.empty()) {
+ // FIXME: support ObjC methods here
+ assert(FD && "Objective C methods are not supported");
+ StmtResult NewBody = ActOnFinishCoroutineBody(FD, Body);
+ if (NewBody.isInvalid())
+ FD->setInvalidDecl();
+ else
+ Body = NewBody.get();
+ }
if (FD) {
FD->setBody(Body);
Index: lib/Sema/SemaCoroutine.cpp
===================================================================
--- lib/Sema/SemaCoroutine.cpp
+++ lib/Sema/SemaCoroutine.cpp
@@ -24,18 +24,19 @@
/// Look up the std::coroutine_traits<...>::promise_type for the given
/// function type.
static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType,
- SourceLocation Loc) {
+ SourceLocation KWLoc,
+ SourceLocation FuncLoc) {
// FIXME: Cache std::coroutine_traits once we've found it.
NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
if (!StdExp) {
- S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
+ S.Diag(KWLoc, diag::err_implied_std_coroutine_traits_not_found);
return QualType();
}
LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"),
- Loc, Sema::LookupOrdinaryName);
+ FuncLoc, Sema::LookupOrdinaryName);
if (!S.LookupQualifiedName(Result, StdExp)) {
- S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
+ S.Diag(KWLoc, diag::err_implied_std_coroutine_traits_not_found);
return QualType();
}
@@ -49,36 +50,37 @@
}
// Form template argument list for coroutine_traits<R, P1, P2, ...>.
- TemplateArgumentListInfo Args(Loc, Loc);
+ TemplateArgumentListInfo Args(FuncLoc, FuncLoc);
Args.addArgument(TemplateArgumentLoc(
TemplateArgument(FnType->getReturnType()),
- S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), Loc)));
+ S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), FuncLoc)));
// FIXME: If the function is a non-static member function, add the type
// of the implicit object parameter before the formal parameters.
for (QualType T : FnType->getParamTypes())
Args.addArgument(TemplateArgumentLoc(
- TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, Loc)));
+ TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, FuncLoc)));
// Build the template-id.
QualType CoroTrait =
- S.CheckTemplateIdType(TemplateName(CoroTraits), Loc, Args);
+ S.CheckTemplateIdType(TemplateName(CoroTraits), FuncLoc, Args);
if (CoroTrait.isNull())
return QualType();
- if (S.RequireCompleteType(Loc, CoroTrait,
+ if (S.RequireCompleteType(FuncLoc, CoroTrait,
diag::err_coroutine_traits_missing_specialization))
return QualType();
CXXRecordDecl *RD = CoroTrait->getAsCXXRecordDecl();
assert(RD && "specialization of class template is not a class?");
// Look up the ::promise_type member.
- LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), Loc,
+ LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), FuncLoc,
Sema::LookupOrdinaryName);
S.LookupQualifiedName(R, RD);
auto *Promise = R.getAsSingle<TypeDecl>();
if (!Promise) {
- S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_found)
- << RD;
+ S.Diag(FuncLoc,
+ diag::err_implied_std_coroutine_traits_promise_type_not_found)
+ << RD;
return QualType();
}
@@ -91,14 +93,38 @@
CoroTrait.getTypePtr());
PromiseType = S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
- S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_class)
- << PromiseType;
+ S.Diag(FuncLoc,
+ diag::err_implied_std_coroutine_traits_promise_type_not_class)
+ << PromiseType;
return QualType();
}
return PromiseType;
}
+VarDecl *Sema::buildCoroutinePromise(SourceLocation KWLoc) {
+ auto *FD = dyn_cast<FunctionDecl>(CurContext);
+ assert(FD && "Not inside a function context");
+ SourceLocation FuncLoc = FD->getLocation();
+ QualType T =
+ FD->getType()->isDependentType()
+ ? Context.DependentTy
+ : lookupPromiseType(*this, FD->getType()->castAs<FunctionProtoType>(),
+ KWLoc, FuncLoc);
+ if (T.isNull())
+ return nullptr;
+
+ // Create and default-initialize the promise.
+ VarDecl *VD =
+ VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(),
+ &PP.getIdentifierTable().get("__promise"), T,
+ Context.getTrivialTypeSourceInfo(T, FuncLoc), SC_None);
+ CheckVariableDeclarationType(VD);
+ if (!VD->isInvalidDecl())
+ ActOnUninitializedDecl(VD, false);
+ return VD;
+}
+
/// Check that this is a context in which a coroutine suspension can appear.
static FunctionScopeInfo *
checkCoroutineContext(Sema &S, SourceLocation Loc, StringRef Keyword) {
@@ -138,22 +164,8 @@
// If we don't have a promise variable, build one now.
if (!ScopeInfo->CoroutinePromise) {
- QualType T =
- FD->getType()->isDependentType()
- ? S.Context.DependentTy
- : lookupPromiseType(S, FD->getType()->castAs<FunctionProtoType>(),
- Loc);
- if (T.isNull())
+ if (!(ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc)))
return nullptr;
-
- // Create and default-initialize the promise.
- ScopeInfo->CoroutinePromise =
- VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(),
- &S.PP.getIdentifierTable().get("__promise"), T,
- S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
- S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise);
- if (!ScopeInfo->CoroutinePromise->isInvalidDecl())
- S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise, false);
}
return ScopeInfo;
@@ -378,7 +390,7 @@
return Res;
}
-void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
+StmtResult Sema::ActOnFinishCoroutineBody(FunctionDecl *FD, Stmt *Body) {
FunctionScopeInfo *Fn = getCurFunction();
assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine");
@@ -410,7 +422,7 @@
StmtResult PromiseStmt =
ActOnDeclStmt(ConvertDeclToDeclGroup(Fn->CoroutinePromise), Loc, Loc);
if (PromiseStmt.isInvalid())
- return FD->setInvalidDecl();
+ return StmtError();
// Form and check implicit 'co_await p.initial_suspend();' statement.
ExprResult InitialSuspend =
@@ -420,7 +432,7 @@
InitialSuspend = BuildCoawaitExpr(Loc, InitialSuspend.get());
InitialSuspend = ActOnFinishFullExpr(InitialSuspend.get());
if (InitialSuspend.isInvalid())
- return FD->setInvalidDecl();
+ return StmtError();
// Form and check implicit 'co_await p.final_suspend();' statement.
ExprResult FinalSuspend =
@@ -430,7 +442,7 @@
FinalSuspend = BuildCoawaitExpr(Loc, FinalSuspend.get());
FinalSuspend = ActOnFinishFullExpr(FinalSuspend.get());
if (FinalSuspend.isInvalid())
- return FD->setInvalidDecl();
+ return StmtError();
// FIXME: Perform analysis of set_exception call.
@@ -442,26 +454,26 @@
ExprResult ReturnObject =
buildPromiseCall(*this, Fn, Loc, "get_return_object", None);
if (ReturnObject.isInvalid())
- return FD->setInvalidDecl();
+ return StmtError();
QualType RetType = FD->getReturnType();
if (!RetType->isDependentType()) {
InitializedEntity Entity =
InitializedEntity::InitializeResult(Loc, RetType, false);
ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType,
ReturnObject.get());
if (ReturnObject.isInvalid())
- return FD->setInvalidDecl();
+ return StmtError();
}
ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc);
if (ReturnObject.isInvalid())
- return FD->setInvalidDecl();
+ return StmtError();
// FIXME: Perform move-initialization of parameters into frame-local copies.
SmallVector<Expr*, 16> ParamMoves;
// Build body for the coroutine wrapper statement.
- Body = new (Context) CoroutineBodyStmt(
+ return new (Context) CoroutineBodyStmt(
Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(),
- /*SetException*/nullptr, /*Fallthrough*/nullptr,
- ReturnObject.get(), ParamMoves);
+ /*SetException*/ nullptr, /*Fallthrough*/ nullptr, ReturnObject.get(),
+ ParamMoves);
}
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -8024,7 +8024,9 @@
ExprResult BuildCoyieldExpr(SourceLocation KwLoc, Expr *E);
StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E);
- void CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body);
+ VarDecl *buildCoroutinePromise(SourceLocation KWLoc);
+
+ StmtResult ActOnFinishCoroutineBody(FunctionDecl *FD, Stmt *Body);
//===--------------------------------------------------------------------===//
// OpenMP directives and clauses.
_______________________________________________
cfe-commits mailing list
[email protected]
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits