EricWF retitled this revision from "[coroutines] Add CoawaitDependentExpr AST
node and use it to properly build await_transform." to "[coroutines] Add
DependentCoawaitExpr and fix re-building CoroutineBodyStmt.".
EricWF updated the summary for this revision.
EricWF updated this revision to Diff 76952.
EricWF marked an inline comment as done.
EricWF added a comment.
- Address review comments about `DependentCoawaitExpr` and using
`UnresolvedLookupExpr`.
- Fix building of the initial/final coroutine suspend points.
- Fix transformation of `CoroutineBodyStmt` so that it transforms the
final/initial suspend points instead of rebuilding them fully.
@rsmith: This change is a little big, but it's not trivial for me to split it
up. Please let me know if you would prefer this submitted as multiple patches.
https://reviews.llvm.org/D26057
Files:
include/clang/AST/ExprCXX.h
include/clang/AST/RecursiveASTVisitor.h
include/clang/AST/StmtCXX.h
include/clang/Basic/DiagnosticSemaKinds.td
include/clang/Basic/StmtNodes.td
include/clang/Sema/ScopeInfo.h
include/clang/Sema/Sema.h
lib/AST/Expr.cpp
lib/AST/ExprClassification.cpp
lib/AST/ExprConstant.cpp
lib/AST/ItaniumMangle.cpp
lib/AST/StmtPrinter.cpp
lib/AST/StmtProfile.cpp
lib/Parse/ParseStmt.cpp
lib/Sema/ScopeInfo.cpp
lib/Sema/SemaCoroutine.cpp
lib/Sema/SemaDecl.cpp
lib/Sema/SemaExceptionSpec.cpp
lib/Sema/SemaTemplateInstantiateDecl.cpp
lib/Sema/TreeTransform.h
lib/Serialization/ASTReaderStmt.cpp
lib/Serialization/ASTWriterStmt.cpp
lib/StaticAnalyzer/Core/ExprEngine.cpp
test/SemaCXX/coroutines.cpp
tools/libclang/CXCursor.cpp
Index: tools/libclang/CXCursor.cpp
===================================================================
--- tools/libclang/CXCursor.cpp
+++ tools/libclang/CXCursor.cpp
@@ -231,6 +231,7 @@
case Stmt::TypeTraitExprClass:
case Stmt::CoroutineBodyStmtClass:
case Stmt::CoawaitExprClass:
+ case Stmt::DependentCoawaitExprClass:
case Stmt::CoreturnStmtClass:
case Stmt::CoyieldExprClass:
case Stmt::CXXBindTemporaryExprClass:
Index: test/SemaCXX/coroutines.cpp
===================================================================
--- test/SemaCXX/coroutines.cpp
+++ test/SemaCXX/coroutines.cpp
@@ -59,25 +59,25 @@
template <typename... T>
struct std::experimental::coroutine_traits<int, T...> {};
-int no_promise_type() {
- co_await a; // expected-error {{this function cannot be a coroutine: 'std::experimental::coroutine_traits<int>' has no member named 'promise_type'}}
+int no_promise_type() { // expected-error {{this function cannot be a coroutine: 'std::experimental::coroutine_traits<int>' has no member named 'promise_type'}}
+ co_await a;
}
template <>
struct std::experimental::coroutine_traits<double, double> { typedef int promise_type; };
-double bad_promise_type(double) {
- co_await a; // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits<double, double>::promise_type' (aka 'int') is not a class}}
+double bad_promise_type(double) { // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits<double, double>::promise_type' (aka 'int') is not a class}}
+ co_await a;
}
template <>
struct std::experimental::coroutine_traits<double, int> {
struct promise_type {};
};
-double bad_promise_type_2(int) {
- co_yield 0; // expected-error {{no member named 'yield_value' in 'std::experimental::coroutine_traits<double, int>::promise_type'}}
+double bad_promise_type_2(int) { // expected-error {{no member named 'initial_suspend'}}
+ co_yield 0; // expected-error {{no member named 'yield_value'}}
}
-struct promise; // expected-note 2{{forward declaration}}
+struct promise; // expected-note {{forward declaration}}
struct promise_void;
struct void_tag {};
template <typename... T>
@@ -94,9 +94,7 @@
}
// FIXME: This diagnostic is terrible.
-void undefined_promise() { // expected-error {{variable has incomplete type 'promise_type'}}
- // FIXME: This diagnostic doesn't make any sense.
- // expected-error@-2 {{incomplete definition of type 'promise'}}
+void undefined_promise() { // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits<void>::promise_type' (aka 'promise') is an incomplete type}}
co_await a;
}
@@ -217,6 +215,13 @@
}
struct outer {};
+struct await_arg_1 {};
+struct await_arg_2 {};
+
+namespace adl_ns {
+struct coawait_arg_type {};
+awaitable operator co_await(coawait_arg_type);
+}
namespace dependent_operator_co_await_lookup {
template<typename T> void await_template(T t) {
@@ -239,6 +244,94 @@
};
template void await_template(outer); // expected-note {{instantiation}}
template void await_template_2(outer);
+
+ struct transform_awaitable {};
+ struct transformed {};
+
+ struct transform_promise {
+ typedef transform_awaitable await_arg;
+ coro<transform_promise> get_return_object();
+ transformed initial_suspend();
+ ::adl_ns::coawait_arg_type final_suspend();
+ transformed await_transform(transform_awaitable);
+ };
+ template <class AwaitArg>
+ struct basic_promise {
+ typedef AwaitArg await_arg;
+ coro<basic_promise> get_return_object();
+ awaitable initial_suspend();
+ awaitable final_suspend();
+ };
+
+ awaitable operator co_await(await_arg_1);
+
+ template <typename T, typename U>
+ coro<T> await_template_3(U t) {
+ co_await t;
+ }
+
+ template coro<basic_promise<await_arg_1>> await_template_3<basic_promise<await_arg_1>>(await_arg_1);
+
+ template <class T, int I = 0>
+ struct dependent_member {
+ coro<T> mem_fn() const {
+ co_await typename T::await_arg{}; // expected-error {{call to function 'operator co_await'}}
+ }
+ template <class U>
+ coro<T> dep_mem_fn(U t) {
+ co_await t;
+ }
+ };
+
+ template <>
+ struct dependent_member<long> {
+ // FIXME this diagnostic is terrible
+ coro<transform_promise> mem_fn() const { // expected-error {{no member named 'await_ready' in 'dependent_operator_co_await_lookup::transformed'}}
+ // expected-note@-1 {{call to 'initial_suspend' implicitly required by the initial suspend point}}
+ // expected-note@+1 {{function is a coroutine due to use of 'co_await' here}}
+ co_await transform_awaitable{};
+ // expected-error@-1 {{no member named 'await_ready'}}
+ }
+ template <class R, class U>
+ coro<R> dep_mem_fn(U u) { co_await u; }
+ };
+
+ awaitable operator co_await(await_arg_2); // expected-note {{'operator co_await' should be declared prior to the call site}}
+
+ template struct dependent_member<basic_promise<await_arg_1>, 0>;
+ template struct dependent_member<basic_promise<await_arg_2>, 0>; // expected-note {{in instantiation}}
+
+ template <>
+ coro<transform_promise>
+ // FIXME this diagnostic is terrible
+ dependent_member<long>::dep_mem_fn<transform_promise>(int) { // expected-error {{no member named 'await_ready' in 'dependent_operator_co_await_lookup::transformed'}}
+ //expected-note@-1 {{call to 'initial_suspend' implicitly required by the initial suspend point}}
+ //expected-note@+1 {{function is a coroutine due to use of 'co_await' here}}
+ co_await transform_awaitable{};
+ // expected-error@-1 {{no member named 'await_ready'}}
+ }
+
+ void operator co_await(transform_awaitable) = delete;
+ awaitable operator co_await(transformed);
+
+ template coro<transform_promise>
+ dependent_member<long>::dep_mem_fn<transform_promise>(transform_awaitable);
+
+ template <>
+ coro<transform_promise> dependent_member<long>::dep_mem_fn<transform_promise>(long) {
+ co_await transform_awaitable{};
+ }
+
+ template <>
+ struct dependent_member<int> {
+ coro<transform_promise> mem_fn() const {
+ co_await transform_awaitable{};
+ }
+ };
+
+ template coro<transform_promise> await_template_3<transform_promise>(transform_awaitable);
+ template struct dependent_member<transform_promise>;
+ template coro<transform_promise> dependent_member<transform_promise>::dep_mem_fn(transform_awaitable);
}
struct yield_fn_tag {};
@@ -314,7 +407,8 @@
};
// FIXME: This diagnostic is terrible.
coro<bad_promise_4> bad_initial_suspend() { // expected-error {{no member named 'await_ready' in 'not_awaitable'}}
- co_await a;
+ // expected-note@-1 {{'initial_suspend' implicitly required}}
+ co_await a; // expected-note {{use of 'co_await' here}}
}
struct bad_promise_5 {
@@ -324,7 +418,8 @@
};
// FIXME: This diagnostic is terrible.
coro<bad_promise_5> bad_final_suspend() { // expected-error {{no member named 'await_ready' in 'not_awaitable'}}
- co_await a;
+ // expected-note@-1 {{'final_suspend' implicitly required}}
+ co_await a; // expected-note {{use of 'co_await' here}}
}
struct bad_promise_6 {
@@ -355,20 +450,69 @@
int *current_exception();
}
-struct bad_promise_8 {
+struct bad_promise_base {
+private:
+ void return_void(); // expected-note {{declared private here}}
+};
+struct bad_promise_8 : bad_promise_base {
coro<bad_promise_8> get_return_object();
suspend_always initial_suspend();
suspend_always final_suspend();
- void return_void();
void set_exception(); // expected-note {{function not viable}}
void set_exception(int *) __attribute__((unavailable)); // expected-note {{explicitly made unavailable}}
void set_exception(void *); // expected-note {{candidate function}}
};
coro<bad_promise_8> calls_set_exception() {
// expected-error@-1 {{call to unavailable member function 'set_exception'}}
+ // expected-error@-2 {{'return_void' is a private member of 'bad_promise_base'}}
co_await a;
}
+struct bad_promise_9 {
+ coro<bad_promise_9> get_return_object();
+ suspend_always initial_suspend();
+ suspend_always final_suspend();
+ void await_transform(void *); // expected-note {{candidate}}
+ awaitable await_transform(int) __attribute__((unavailable)); // expected-note {{explicitly made unavailable}}
+ void return_void();
+};
+coro<bad_promise_9> calls_await_transform() {
+ co_await 42; // expected-error {{call to unavailable member function 'await_transform'}}
+ // expected-note@-1 {{call to 'await_transform' implicitly required by 'co_await' here}}
+}
+
+struct bad_promise_10 {
+ coro<bad_promise_10> get_return_object();
+ suspend_always initial_suspend();
+ suspend_always final_suspend();
+ int await_transform;
+ void return_void();
+};
+coro<bad_promise_10> bad_coawait() {
+ // FIXME this diagnostic is terrible
+ co_await 42; // expected-error {{called object type 'int' is not a function or function pointer}}
+ // expected-note@-1 {{call to 'await_transform' implicitly required by 'co_await' here}}
+}
+
+struct call_operator {
+ template <class... Args>
+ awaitable operator()(Args...) const { return a; }
+};
+void ret_void();
+struct good_promise_1 {
+ coro<good_promise_1> get_return_object();
+ suspend_always initial_suspend();
+ suspend_always final_suspend();
+ static const call_operator await_transform;
+ using Fn = void (*)();
+ Fn return_void = ret_void;
+};
+const call_operator good_promise_1::await_transform;
+coro<good_promise_1> ok_static_coawait() {
+ // OK: 'await_transform' may be a static data member with a call operator.
+ co_await 42;
+}
+
template<> struct std::experimental::coroutine_traits<int, int, const char**>
{ using promise_type = promise; };
Index: lib/StaticAnalyzer/Core/ExprEngine.cpp
===================================================================
--- lib/StaticAnalyzer/Core/ExprEngine.cpp
+++ lib/StaticAnalyzer/Core/ExprEngine.cpp
@@ -774,6 +774,7 @@
case Stmt::FunctionParmPackExprClass:
case Stmt::CoroutineBodyStmtClass:
case Stmt::CoawaitExprClass:
+ case Stmt::DependentCoawaitExprClass:
case Stmt::CoreturnStmtClass:
case Stmt::CoyieldExprClass:
case Stmt::SEHTryStmtClass:
Index: lib/Serialization/ASTWriterStmt.cpp
===================================================================
--- lib/Serialization/ASTWriterStmt.cpp
+++ lib/Serialization/ASTWriterStmt.cpp
@@ -315,6 +315,11 @@
llvm_unreachable("unimplemented");
}
+void ASTStmtWriter::VisitDependentCoawaitExpr(DependentCoawaitExpr *S) {
+ // FIXME: Implement coroutine serialization.
+ llvm_unreachable("unimplemented");
+}
+
void ASTStmtWriter::VisitCoyieldExpr(CoyieldExpr *S) {
// FIXME: Implement coroutine serialization.
llvm_unreachable("unimplemented");
Index: lib/Serialization/ASTReaderStmt.cpp
===================================================================
--- lib/Serialization/ASTReaderStmt.cpp
+++ lib/Serialization/ASTReaderStmt.cpp
@@ -400,6 +400,11 @@
llvm_unreachable("unimplemented");
}
+void ASTStmtReader::VisitDependentCoawaitExpr(DependentCoawaitExpr *S) {
+ // FIXME: Implement coroutine serialization.
+ llvm_unreachable("unimplemented");
+}
+
void ASTStmtReader::VisitCoyieldExpr(CoyieldExpr *S) {
// FIXME: Implement coroutine serialization.
llvm_unreachable("unimplemented");
Index: lib/Sema/TreeTransform.h
===================================================================
--- lib/Sema/TreeTransform.h
+++ lib/Sema/TreeTransform.h
@@ -1306,16 +1306,29 @@
///
/// By default, performs semantic analysis to build the new statement.
/// Subclasses may override this routine to provide different behavior.
- StmtResult RebuildCoreturnStmt(SourceLocation CoreturnLoc, Expr *Result) {
- return getSema().BuildCoreturnStmt(CoreturnLoc, Result);
+ StmtResult RebuildCoreturnStmt(SourceLocation CoreturnLoc, Expr *Result,
+ bool IsImplicitlyCreated) {
+ return getSema().BuildCoreturnStmt(CoreturnLoc, Result,
+ IsImplicitlyCreated);
}
/// \brief Build a new co_await expression.
///
/// By default, performs semantic analysis to build the new expression.
/// Subclasses may override this routine to provide different behavior.
- ExprResult RebuildCoawaitExpr(SourceLocation CoawaitLoc, Expr *Result) {
- return getSema().BuildCoawaitExpr(CoawaitLoc, Result);
+ ExprResult RebuildCoawaitExpr(SourceLocation CoawaitLoc, Expr *Result,
+ bool IsImplicitlyCreated) {
+ return getSema().BuildCoawaitExpr(CoawaitLoc, Result, IsImplicitlyCreated);
+ }
+
+ /// \brief Build a new dependent co_await expression.
+ ///
+ /// By default, performs semantic analysis to build the new expression.
+ /// Subclasses may override this routine to provide different behavior.
+ ExprResult RebuildDependentCoawaitExpr(SourceLocation CoawaitLoc,
+ Expr *Result,
+ UnresolvedLookupExpr *Lookup) {
+ return getSema().BuildDependentCoawaitExpr(CoawaitLoc, Result, Lookup);
}
/// \brief Build a new co_yield expression.
@@ -1326,6 +1339,15 @@
return getSema().BuildCoyieldExpr(CoyieldLoc, Result);
}
+ StmtResult RebuildCoroutineBodyStmt(Stmt *Body, VarDecl *Promise, Stmt *InitSuspend,
+ Stmt *FinalSuspend, Stmt *OnException,
+ Stmt *OnFallthrough,
+ Expr *Allocation,
+ Stmt *Deallocation, Expr *ReturnObject) {
+ return getSema().BuildCoroutineBodyStmt(
+ Body, Promise, InitSuspend, FinalSuspend, OnException, OnFallthrough,
+ Allocation, Deallocation, ReturnObject);
+ }
/// \brief Build a new Objective-C \@try statement.
///
/// By default, performs semantic analysis to build the new statement.
@@ -6655,7 +6677,87 @@
TreeTransform<Derived>::TransformCoroutineBodyStmt(CoroutineBodyStmt *S) {
- // The coroutine body should be re-formed by the caller if necessary.
- // FIXME: The coroutine body is always rebuilt by ActOnFinishFunctionBody
- return getDerived().TransformStmt(S->getBody());
+ // Transform the promise, the implicit suspend points, the body and the
+ // remaining implicit statements, then rebuild the CoroutineBodyStmt.
+
+ auto *ScopeInfo = SemaRef.getCurFunction();
+ auto *FD = cast<FunctionDecl>(SemaRef.CurContext);
+ assert(ScopeInfo && !ScopeInfo->CoroutinePromise &&
+ !ScopeInfo->HasCoroutineSuspends &&
+ ScopeInfo->CoroutineStmts.empty() && "expected clean scope info");
+
+ // Set that we have (possibly-invalid) suspend points before we do anything
+ // that may fail.
+ ScopeInfo->setCoroutineSuspendsInvalid();
+
+ // The new CoroutinePromise object needs to be built and put into the current
+ // FunctionScopeInfo before any transformations or rebuilding occurs.
+ auto *Promise = S->getPromiseDecl();
+ auto *NewPromise = SemaRef.buildCoroutinePromise(FD->getLocation());
+ if (!NewPromise)
+ return StmtError();
+ getDerived().transformedLocalDecl(Promise, NewPromise);
+ ScopeInfo->CoroutinePromise = NewPromise;
+
+ // Transform the implicit coroutine statements we built during the initial
+ // parse.
+ StmtResult InitSuspend = getDerived().TransformStmt(S->getInitSuspendStmt());
+ if (InitSuspend.isInvalid())
+ return StmtError();
+ StmtResult FinalSuspend =
+ getDerived().TransformStmt(S->getFinalSuspendStmt());
+ if (FinalSuspend.isInvalid())
+ return StmtError();
+ ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get());
+
+ StmtResult BodyRes = getDerived().TransformStmt(S->getBody());
+ if (BodyRes.isInvalid())
+ return StmtError();
+
+ Stmt *SetException = S->getExceptionHandler();
+ Stmt *Fallthrough = S->getFallthroughHandler();
+ if (Fallthrough) {
+ StmtResult Res = getDerived().TransformStmt(Fallthrough);
+ if (Res.isInvalid())
+ return StmtError();
+ Fallthrough = Res.get();
+ }
+
+ if (SetException) {
+ StmtResult Res = getDerived().TransformStmt(SetException);
+ if (Res.isInvalid())
+ return StmtError();
+ SetException = Res.get();
+ }
+
+ // Transform any additional statements we may have already built.
+ Expr *Allocation = nullptr;
+ Stmt *Deallocation = nullptr;
+ if (S->getAllocate() && S->getDeallocate()) {
+ ExprResult AllocRes = getDerived().TransformExpr(S->getAllocate());
+ if (AllocRes.isInvalid())
+ return StmtError();
+ Allocation = AllocRes.get();
+
+ StmtResult DeallocRes = getDerived().TransformStmt(S->getDeallocate());
+ if (DeallocRes.isInvalid())
+ return StmtError();
+ Deallocation = DeallocRes.get();
+ }
+
+ Expr *ReturnObject = S->getReturnValueInit();
+ if (ReturnObject) {
+ ExprResult Res = getDerived().TransformInitializer(ReturnObject,
+ /*NoCopyInit*/false);
+ if (Res.isInvalid())
+ return StmtError();
+ ReturnObject = Res.get();
+ }
+
+ // Do a partial rebuild of the coroutine body and stash it in the ScopeInfo
+ return getDerived().RebuildCoroutineBodyStmt(
+ BodyRes.get(), NewPromise, InitSuspend.get(), FinalSuspend.get(),
+ SetException, Fallthrough, Allocation,
+ Deallocation, ReturnObject);
+
}
template<typename Derived>
@@ -6668,7 +6770,8 @@
// Always rebuild; we don't know if this needs to be injected into a new
// context or if the promise type has changed.
- return getDerived().RebuildCoreturnStmt(S->getKeywordLoc(), Result.get());
+ return getDerived().RebuildCoreturnStmt(S->getKeywordLoc(), Result.get(),
+ S->isImplicitlyCreated());
}
template<typename Derived>
@@ -6681,12 +6784,26 @@
// Always rebuild; we don't know if this needs to be injected into a new
// context or if the promise type has changed.
- return getDerived().RebuildCoawaitExpr(E->getKeywordLoc(), Result.get());
+ return getDerived().RebuildCoawaitExpr(E->getKeywordLoc(), Result.get(),
+ E->isImplicitlyCreated());
}
-template<typename Derived>
+template <typename Derived>
ExprResult
-TreeTransform<Derived>::TransformCoyieldExpr(CoyieldExpr *E) {
+TreeTransform<Derived>::TransformDependentCoawaitExpr(DependentCoawaitExpr *E) {
+ ExprResult Result = getDerived().TransformInitializer(E->getOperand(),
+ /*NotCopyInit*/ false);
+ ExprResult Lookup =
+ getDerived().TransformUnresolvedLookupExpr(E->getOperatorCoawaitLookup());
+ if (Result.isInvalid() || Lookup.isInvalid())
+ return ExprError();
+ // Always rebuild; the context or the promise type may have changed.
+ return getDerived().RebuildDependentCoawaitExpr(
+ E->getKeywordLoc(), Result.get(), cast<UnresolvedLookupExpr>(Lookup.get()));
+}
+
+template <typename Derived>
+ExprResult TreeTransform<Derived>::TransformCoyieldExpr(CoyieldExpr *E) {
ExprResult Result = getDerived().TransformInitializer(E->getOperand(),
/*NotCopyInit*/false);
if (Result.isInvalid())
Index: lib/Sema/SemaTemplateInstantiateDecl.cpp
===================================================================
--- lib/Sema/SemaTemplateInstantiateDecl.cpp
+++ lib/Sema/SemaTemplateInstantiateDecl.cpp
@@ -3714,6 +3714,8 @@
if (Body.isInvalid())
Function->setInvalidDecl();
+ assert((Body.isInvalid() || Body.get()) &&
+ "successful instantiation must produce a function body");
ActOnFinishFunctionBody(Function, Body.get(),
/*IsInstantiation=*/true);
Index: lib/Sema/SemaExceptionSpec.cpp
===================================================================
--- lib/Sema/SemaExceptionSpec.cpp
+++ lib/Sema/SemaExceptionSpec.cpp
@@ -1146,6 +1146,7 @@
case Expr::ArraySubscriptExprClass:
case Expr::OMPArraySectionExprClass:
case Expr::BinaryOperatorClass:
+ case Expr::DependentCoawaitExprClass:
case Expr::CompoundAssignOperatorClass:
case Expr::CStyleCastExprClass:
case Expr::CXXStaticCastExprClass:
Index: lib/Sema/SemaDecl.cpp
===================================================================
--- lib/Sema/SemaDecl.cpp
+++ lib/Sema/SemaDecl.cpp
@@ -11383,7 +11383,7 @@
if (canRedefineFunction(Definition, getLangOpts()))
return;
// If we don't have a visible definition of the function, and it's inline or
// a template, skip the new definition.
if (SkipBody && !hasVisibleDefinition(Definition) &&
(Definition->getFormalLinkage() == InternalLinkage ||
@@ -11675,7 +11675,7 @@
sema::AnalysisBasedWarnings::Policy WP = AnalysisWarnings.getDefaultPolicy();
sema::AnalysisBasedWarnings::Policy *ActivePolicy = nullptr;
- if (getLangOpts().CoroutinesTS && !getCurFunction()->CoroutineStmts.empty())
+ if (getLangOpts().CoroutinesTS && getCurFunction()->CoroutinePromise)
CheckCompletedCoroutineBody(FD, Body);
if (FD) {
Index: lib/Sema/SemaCoroutine.cpp
===================================================================
--- lib/Sema/SemaCoroutine.cpp
+++ lib/Sema/SemaCoroutine.cpp
@@ -21,21 +21,32 @@
using namespace clang;
using namespace sema;
+static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD,
+ SourceLocation Loc) {
+ DeclarationName DN = S.PP.getIdentifierInfo(Name);
+ LookupResult LR(S, DN, Loc, Sema::LookupMemberName);
+ // Suppress diagnostics when a private member is selected. The same warnings
+ // will be produced again when building the call.
+ LR.suppressDiagnostics();
+ return S.LookupQualifiedName(LR, RD);
+}
+
/// Look up the std::coroutine_traits<...>::promise_type for the given
/// function type.
static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType,
- SourceLocation Loc) {
+ SourceLocation KwLoc,
+ SourceLocation FuncLoc) {
// FIXME: Cache std::coroutine_traits once we've found it.
NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
if (!StdExp) {
- S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
+ S.Diag(KwLoc, diag::err_implied_std_coroutine_traits_not_found);
return QualType();
}
LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"),
- Loc, Sema::LookupOrdinaryName);
+ FuncLoc, Sema::LookupOrdinaryName);
if (!S.LookupQualifiedName(Result, StdExp)) {
- S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
+ S.Diag(KwLoc, diag::err_implied_std_coroutine_traits_not_found);
return QualType();
}
@@ -49,52 +60,58 @@
}
// Form template argument list for coroutine_traits<R, P1, P2, ...>.
- TemplateArgumentListInfo Args(Loc, Loc);
+ TemplateArgumentListInfo Args(KwLoc, KwLoc);
Args.addArgument(TemplateArgumentLoc(
TemplateArgument(FnType->getReturnType()),
- S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), Loc)));
+ S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), KwLoc)));
// FIXME: If the function is a non-static member function, add the type
// of the implicit object parameter before the formal parameters.
for (QualType T : FnType->getParamTypes())
Args.addArgument(TemplateArgumentLoc(
- TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, Loc)));
+ TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, KwLoc)));
// Build the template-id.
QualType CoroTrait =
- S.CheckTemplateIdType(TemplateName(CoroTraits), Loc, Args);
+ S.CheckTemplateIdType(TemplateName(CoroTraits), KwLoc, Args);
if (CoroTrait.isNull())
return QualType();
- if (S.RequireCompleteType(Loc, CoroTrait,
+ if (S.RequireCompleteType(KwLoc, CoroTrait,
diag::err_coroutine_traits_missing_specialization))
return QualType();
- CXXRecordDecl *RD = CoroTrait->getAsCXXRecordDecl();
+ auto *RD = CoroTrait->getAsCXXRecordDecl();
assert(RD && "specialization of class template is not a class?");
// Look up the ::promise_type member.
- LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), Loc,
+ LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), KwLoc,
Sema::LookupOrdinaryName);
S.LookupQualifiedName(R, RD);
auto *Promise = R.getAsSingle<TypeDecl>();
if (!Promise) {
- S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_found)
+ S.Diag(FuncLoc,
+ diag::err_implied_std_coroutine_traits_promise_type_not_found)
<< RD;
return QualType();
}
-
// The promise type is required to be a class type.
QualType PromiseType = S.Context.getTypeDeclType(Promise);
- if (!PromiseType->getAsCXXRecordDecl()) {
- // Use the fully-qualified name of the type.
+
+ auto buildNNS = [&]() {
auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, StdExp);
NNS = NestedNameSpecifier::Create(S.Context, NNS, false,
CoroTrait.getTypePtr());
- PromiseType = S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
+ return S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
+ };
- S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_class)
- << PromiseType;
+ if (!PromiseType->getAsCXXRecordDecl()) {
+ S.Diag(FuncLoc,
+ diag::err_implied_std_coroutine_traits_promise_type_not_class)
+ << buildNNS();
return QualType();
}
+ if (S.RequireCompleteType(FuncLoc, buildNNS(),
+ diag::err_coroutine_promise_type_incomplete))
+ return QualType();
return PromiseType;
}
@@ -160,41 +177,49 @@
return !Diagnosed;
}
-/// Check that this is a context in which a coroutine suspension can appear.
-static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc,
- StringRef Keyword) {
- if (!isValidCoroutineContext(S, Loc, Keyword))
- return nullptr;
+static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S,
+ SourceLocation Loc) {
+ DeclarationName OpName =
+ SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait);
+ LookupResult Operators(SemaRef, OpName, SourceLocation(),
+ Sema::LookupOperatorName);
+ SemaRef.LookupName(Operators, S);
+
+ assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous");
+ const auto &Functions = Operators.asUnresolvedSet();
+ bool IsOverloaded =
+ Functions.size() > 1 ||
+ (Functions.size() == 1 && isa<FunctionTemplateDecl>(*Functions.begin()));
+ Expr *CoawaitOp = UnresolvedLookupExpr::Create(
+ SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(),
+ DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded,
+ Functions.begin(), Functions.end());
+ assert(CoawaitOp);
+ return CoawaitOp;
+}
- assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope");
- auto *FD = cast<FunctionDecl>(S.CurContext);
- auto *ScopeInfo = S.getCurFunction();
- assert(ScopeInfo && "missing function scope for function");
+/// Build a call to 'operator co_await' if there is a suitable operator for
+/// the given expression.
+static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc,
+ Expr *E,
+ UnresolvedLookupExpr *Lookup) {
- // If we don't have a promise variable, build one now.
- if (!ScopeInfo->CoroutinePromise) {
- QualType T = FD->getType()->isDependentType()
- ? S.Context.DependentTy
- : lookupPromiseType(
- S, FD->getType()->castAs<FunctionProtoType>(), Loc);
- if (T.isNull())
- return nullptr;
-
- // Create and default-initialize the promise.
- ScopeInfo->CoroutinePromise =
- VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(),
- &S.PP.getIdentifierTable().get("__promise"), T,
- S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
- S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise);
- if (!ScopeInfo->CoroutinePromise->isInvalidDecl())
- S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise, false);
- }
+ UnresolvedSet<16> Functions;
+ Functions.append(Lookup->decls_begin(), Lookup->decls_end());
+ return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
+}
- return ScopeInfo;
+static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
+ SourceLocation Loc, Expr *E) {
+ ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc);
+ if (R.isInvalid())
+ return ExprError();
+ return buildOperatorCoawaitCall(SemaRef, Loc, E,
+ cast<UnresolvedLookupExpr>(R.get()));
}
static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id,
- MutableArrayRef<Expr *> CallArgs) {
+ MultiExprArg CallArgs) {
StringRef Name = S.Context.BuiltinInfo.getName(Id);
LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName);
S.LookupName(R, S.TUScope, /*AllowBuiltinCreation=*/true);
@@ -213,24 +238,14 @@
return Call.get();
}
-/// Build a call to 'operator co_await' if there is a suitable operator for
-/// the given expression.
-static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
- SourceLocation Loc, Expr *E) {
- UnresolvedSet<16> Functions;
- SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(),
- Functions);
- return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
-}
struct ReadySuspendResumeResult {
bool IsInvalid;
Expr *Results[3];
};
static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
- StringRef Name,
- MutableArrayRef<Expr *> Args) {
+ StringRef Name, MultiExprArg Args) {
DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc);
// FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
@@ -268,25 +283,174 @@
return Calls;
}
+/// Build a call to the member function \p Name of the coroutine promise
+/// variable \p Promise, passing \p Args. Returns ExprError() if the reference
+/// to the promise or the member call cannot be formed.
+static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise,
+ SourceLocation Loc, StringRef Name,
+ MultiExprArg Args) {
+ assert(Promise && "no promise for coroutine");
+
+ // Form a reference to the promise.
+ ExprResult PromiseRef = S.BuildDeclRefExpr(
+ Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
+ if (PromiseRef.isInvalid())
+ return ExprError();
+
+ // Call the promise member named by 'Name', passing in 'Args'.
+ return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
+}
+
+/// Create the implicit '__promise' variable for the coroutine in CurContext.
+///
+/// The promise type is DependentTy when the function type is dependent and
+/// is computed by lookupPromiseType otherwise. Returns nullptr if the type
+/// cannot be determined, or if the variable is invalid after its type check
+/// or its default-initialization.
+VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) {
+ assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
+ auto *FD = cast<FunctionDecl>(CurContext);
+
+ QualType T =
+ FD->getType()->isDependentType()
+ ? Context.DependentTy
+ : lookupPromiseType(*this, FD->getType()->castAs<FunctionProtoType>(),
+ Loc, FD->getLocation());
+ if (T.isNull())
+ return nullptr;
+
+ auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(),
+ &PP.getIdentifierTable().get("__promise"), T,
+ Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
+ CheckVariableDeclarationType(VD);
+ if (VD->isInvalidDecl())
+ return nullptr;
+ ActOnUninitializedDecl(VD, false);
+ // Default-initialization can fail for ill-formed user code (for example a
+ // promise type that is not default constructible), which may leave VD
+ // invalid. That is a user error, not an internal invariant, so report it
+ // the same way as the type-check failure above instead of asserting.
+ if (VD->isInvalidDecl())
+ return nullptr;
+ return VD;
+}
+
+/// Check that this is a context in which a coroutine suspension can appear,
+/// creating the coroutine promise the first time one is seen.
+///
+/// Returns the current function scope, or nullptr if the context is invalid
+/// or the promise variable could not be built.
+static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc,
+ StringRef Keyword) {
+ if (!isValidCoroutineContext(S, Loc, Keyword))
+ return nullptr;
+
+ assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope");
+
+ auto *ScopeInfo = S.getCurFunction();
+ assert(ScopeInfo && "missing function scope for function");
+
+ // The promise is built lazily by the first suspension point; later ones
+ // reuse it.
+ if (!ScopeInfo->CoroutinePromise)
+ ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc);
+
+ return ScopeInfo->CoroutinePromise ? ScopeInfo : nullptr;
+}
+
+/// Called for each co_await, co_yield and co_return. Verifies the coroutine
+/// context (creating the promise if needed). The first time only, it also
+/// builds the implicit 'co_await p.initial_suspend()' and
+/// 'co_await p.final_suspend()' suspend points.
+///
+/// Returns false if this is not a valid coroutine context. Returns true
+/// otherwise, including when the suspend points failed to build. That
+/// failure is recorded via setCoroutineSuspendsInvalid() and is reported
+/// later through hasInvalidCoroutineSuspends().
+static bool actOnCoroutineBodyStart(Sema &S, Scope *SC, SourceLocation KWLoc,
+ StringRef Keyword) {
+ if (!checkCoroutineContext(S, KWLoc, Keyword))
+ return false;
+ auto *ScopeInfo = S.getCurFunction();
+ assert(ScopeInfo->CoroutinePromise);
+
+ // HasCoroutineSuspends is set once we have attempted to build the initial
+ // and final suspend points (successfully or not); only attempt it once.
+ if (ScopeInfo->HasCoroutineSuspends)
+ return true;
+
+ // Record the attempt up front as invalid. setCoroutineSuspends() below
+ // replaces this only if both suspend points are built successfully.
+ ScopeInfo->setCoroutineSuspendsInvalid();
+
+ auto *Fn = cast<FunctionDecl>(S.CurContext);
+ SourceLocation Loc = Fn->getLocation();
+ // Build 'co_await p.<Name>()' as a full-expression; used for both the
+ // initial and the final suspend point.
+ auto buildSuspends = [&](StringRef Name) mutable -> StmtResult {
+ ExprResult Suspend =
+ buildPromiseCall(S, ScopeInfo->CoroutinePromise, Loc, Name, None);
+ if (Suspend.isInvalid())
+ return StmtError();
+ Suspend = buildOperatorCoawaitCall(S, SC, Loc, Suspend.get());
+ if (Suspend.isInvalid())
+ return StmtError();
+ Suspend = S.BuildCoawaitExpr(Loc, Suspend.get(),
+ /*IsImplicitlyCreated*/ true);
+ Suspend = S.ActOnFinishFullExpr(Suspend.get());
+ if (Suspend.isInvalid()) {
+ S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
+ << ((Name == "initial_suspend") ? 0 : 1);
+ S.Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword;
+ return StmtError();
+ }
+ return cast<Stmt>(Suspend.get());
+ };
+
+ // On failure, still report this as a coroutine; the invalid state set
+ // above is checked when the body is completed.
+ StmtResult InitSuspend = buildSuspends("initial_suspend");
+ if (InitSuspend.isInvalid())
+ return true;
+
+ StmtResult FinalSuspend = buildSuspends("final_suspend");
+ if (FinalSuspend.isInvalid())
+ return true;
+
+ ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get());
+
+ return true;
+}
+
+/// Parse-time entry point for 'co_await E'. Marks the function as a coroutine
+/// (building its promise and implicit suspend points on first use) and
+/// resolves a placeholder operand. It then looks up 'operator co_await' from
+/// scope S and forwards to BuildDependentCoawaitExpr.
ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
- auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await");
- if (!Coroutine) {
+ if (!actOnCoroutineBodyStart(*this, S, Loc, "co_await")) {
CorrectDelayedTyposInExpr(E);
return ExprError();
}
+
if (E->getType()->isPlaceholderType()) {
ExprResult R = CheckPlaceholderExpr(E);
if (R.isInvalid()) return ExprError();
E = R.get();
}
+ // Perform the 'operator co_await' lookup from scope S here:
+ // BuildDependentCoawaitExpr takes the lookup result, not a Scope.
+ ExprResult Lookup = buildOperatorCoawaitLookupExpr(*this, S, Loc);
+ if (Lookup.isInvalid())
+ return ExprError();
+ return BuildDependentCoawaitExpr(Loc, E,
+ cast<UnresolvedLookupExpr>(Lookup.get()));
+}
+
+/// Build a co_await of E, given the result of the 'operator co_await' lookup.
+///
+/// While the promise type is dependent, a DependentCoawaitExpr holding E and
+/// Lookup is recorded in CoroutineStmts and returned. Otherwise, E is first
+/// replaced by 'p.await_transform(E)' if the promise declares that member.
+/// Then 'operator co_await' is resolved via Lookup, and the result goes to
+/// BuildCoawaitExpr.
+ExprResult Sema::BuildDependentCoawaitExpr(SourceLocation Loc, Expr *E,
+ UnresolvedLookupExpr *Lookup) {
+ auto *FSI = checkCoroutineContext(*this, Loc, "co_await");
+ if (!FSI)
+ return ExprError();
+
+ if (E->getType()->isPlaceholderType()) {
+ ExprResult R = CheckPlaceholderExpr(E);
+ if (R.isInvalid())
+ return ExprError();
+ E = R.get();
+ }
- ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E);
+ // With a dependent promise type we cannot tell whether 'await_transform'
+ // applies, so keep the operand and lookup unresolved for now.
+ auto *Promise = FSI->CoroutinePromise;
+ if (Promise->getType()->isDependentType()) {
+ Expr *Res =
+ new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup);
+ FSI->CoroutineStmts.push_back(Res);
+ return Res;
+ }
+
+ // If the promise declares 'await_transform', the awaited operand becomes
+ // 'p.await_transform(E)'.
+ auto *RD = Promise->getType()->getAsCXXRecordDecl();
+ if (lookupMember(*this, "await_transform", RD, Loc)) {
+ ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", E);
+ if (R.isInvalid()) {
+ Diag(Loc,
+ diag::note_coroutine_promise_implicit_await_transform_required_here)
+ << E->getSourceRange();
+ return ExprError();
+ }
+ E = R.get();
+ }
+ ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup);
if (Awaitable.isInvalid())
return ExprError();
return BuildCoawaitExpr(Loc, Awaitable.get());
}
-ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) {
+
+ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E,
+ bool IsImplicitlyCreated) {
auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await");
if (!Coroutine)
return ExprError();
@@ -298,8 +462,10 @@
}
if (E->getType()->isDependentType()) {
- Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E);
- Coroutine->CoroutineStmts.push_back(Res);
+ Expr *Res = new (Context)
+ CoawaitExpr(Loc, Context.DependentTy, E, IsImplicitlyCreated);
+ if (!IsImplicitlyCreated)
+ Coroutine->CoroutineStmts.push_back(Res);
return Res;
}
@@ -314,37 +480,21 @@
return ExprError();
Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1],
- RSS.Results[2]);
- Coroutine->CoroutineStmts.push_back(Res);
+ RSS.Results[2], IsImplicitlyCreated);
+ if (!IsImplicitlyCreated)
+ Coroutine->CoroutineStmts.push_back(Res);
return Res;
}
-static ExprResult buildPromiseCall(Sema &S, FunctionScopeInfo *Coroutine,
- SourceLocation Loc, StringRef Name,
- MutableArrayRef<Expr *> Args) {
- assert(Coroutine->CoroutinePromise && "no promise for coroutine");
-
- // Form a reference to the promise.
- auto *Promise = Coroutine->CoroutinePromise;
- ExprResult PromiseRef = S.BuildDeclRefExpr(
- Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
- if (PromiseRef.isInvalid())
- return ExprError();
-
- // Call 'yield_value', passing in E.
- return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
-}
-
ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
- auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
- if (!Coroutine) {
+ if (!actOnCoroutineBodyStart(*this, S, Loc, "co_yield")) {
CorrectDelayedTyposInExpr(E);
return ExprError();
}
// Build yield_value call.
- ExprResult Awaitable =
- buildPromiseCall(*this, Coroutine, Loc, "yield_value", E);
+ ExprResult Awaitable = buildPromiseCall(
+ *this, getCurFunction()->CoroutinePromise, Loc, "yield_value", E);
if (Awaitable.isInvalid())
return ExprError();
@@ -388,18 +538,18 @@
return Res;
}
-StmtResult Sema::ActOnCoreturnStmt(SourceLocation Loc, Expr *E) {
- auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return");
- if (!Coroutine) {
+/// Parse-time entry point for 'co_return [E]'. Marks the function as a
+/// coroutine (building its promise and implicit suspend points on first use)
+/// and forwards to BuildCoreturnStmt. S is the scope used to look up
+/// 'operator co_await' for those implicit suspend points.
+StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) {
+ if (!actOnCoroutineBodyStart(*this, S, Loc, "co_return")) {
CorrectDelayedTyposInExpr(E);
return StmtError();
}
return BuildCoreturnStmt(Loc, E);
}
-StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) {
- auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return");
- if (!Coroutine)
+StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E,
+ bool IsImplicitlyCreated) {
+ auto *FSI = checkCoroutineContext(*this, Loc, "co_return");
+ if (!FSI)
return StmtError();
if (E && E->getType()->isPlaceholderType() &&
@@ -412,20 +562,22 @@
// FIXME: If the operand is a reference to a variable that's about to go out
// of scope, we should treat the operand as an xvalue for this overload
// resolution.
+ VarDecl *Promise = FSI->CoroutinePromise;
ExprResult PC;
if (E && (isa<InitListExpr>(E) || !E->getType()->isVoidType())) {
- PC = buildPromiseCall(*this, Coroutine, Loc, "return_value", E);
+ PC = buildPromiseCall(*this, Promise, Loc, "return_value", E);
} else {
E = MakeFullDiscardedValueExpr(E).get();
- PC = buildPromiseCall(*this, Coroutine, Loc, "return_void", None);
+ PC = buildPromiseCall(*this, Promise, Loc, "return_void", None);
}
if (PC.isInvalid())
return StmtError();
Expr *PCE = ActOnFinishFullExpr(PC.get()).get();
- Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE);
- Coroutine->CoroutineStmts.push_back(Res);
+ Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE, IsImplicitlyCreated);
+ if (!IsImplicitlyCreated)
+ FSI->CoroutineStmts.push_back(Res);
return Res;
}
@@ -482,14 +634,82 @@
return OperatorDelete;
}
+/// Build the implicit 'co_return;' used when control flows off the end of a
+/// coroutine whose promise declares 'return_void'.
+///
+/// Returns false on error, including when the promise declares both
+/// 'return_void' and 'return_value'. On success OnFallthrough is set, or is
+/// left null if no statement is needed (dependent promise type, or no
+/// 'return_void').
+static bool buildFallthrough(Sema &S, SourceLocation Loc, FunctionDecl *FD,
+ FunctionScopeInfo *FSI, Stmt *&OnFallthrough) {
+ assert(!OnFallthrough && "rebuilding existing OnFallthrough");
+ auto *Promise = FSI->CoroutinePromise;
+ if (Promise->getType()->isDependentType())
+ return true;
+
+ CXXRecordDecl *RD = Promise->getType()->getAsCXXRecordDecl();
+ assert(RD && "Type should have already been checked");
+
+ // [dcl.fct.def.coroutine]/4
+ // The unqualified-ids 'return_void' and 'return_value' are looked up in
+ // the scope of class P. If both are found, the program is ill-formed.
+ const bool HasRVoid = lookupMember(S, "return_void", RD, Loc);
+ const bool HasRValue = lookupMember(S, "return_value", RD, Loc);
+ if (HasRVoid && HasRValue) {
+ // FIXME Improve this diagnostic
+ S.Diag(FD->getLocation(), diag::err_coroutine_promise_return_ill_formed)
+ << RD;
+ return false;
+ }
+
+ // If the unqualified-id return_void is found, flowing off the end of a
+ // coroutine is equivalent to a co_return with no operand. Otherwise,
+ // flowing off the end of a coroutine results in undefined behavior.
+ if (!HasRVoid)
+ return true;
+
+ StmtResult Fallthrough = S.BuildCoreturnStmt(FD->getLocation(), nullptr,
+ /*IsImplicitlyCreated*/ true);
+ if (!Fallthrough.isInvalid())
+ Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get());
+ if (Fallthrough.isInvalid())
+ return false;
+ OnFallthrough = Fallthrough.get();
+ return true;
+}
+
+/// Build 'p.set_exception(std::current_exception())' as a full-expression if
+/// the promise type declares 'set_exception'.
+///
+/// Returns false on error. On success OnException is set, or is left null if
+/// no call is needed (dependent promise type, or no 'set_exception'). FD is
+/// currently unused; it is kept so the signature matches buildFallthrough.
+static bool buildSetException(Sema &S, SourceLocation Loc, FunctionDecl *FD,
+ FunctionScopeInfo *FSI, Stmt *&OnException) {
+ assert(!OnException && "rebuilding existing set_exception");
+ auto *Promise = FSI->CoroutinePromise;
+ if (Promise->getType()->isDependentType())
+ return true;
+
+ CXXRecordDecl *RD = Promise->getType()->getAsCXXRecordDecl();
+ assert(RD && "Type should have already been checked");
+
+ // [dcl.fct.def.coroutine]/3
+ // The unqualified-id set_exception is found in the scope of P by class
+ // member access lookup (3.4.5).
+ if (!lookupMember(S, "set_exception", RD, Loc))
+ return true;
+
+ // Form the call 'p.set_exception(std::current_exception())'
+ ExprResult SetException = buildStdCurrentExceptionCall(S, Loc);
+ if (SetException.isInvalid())
+ return false;
+ Expr *E = SetException.get();
+ SetException = buildPromiseCall(S, Promise, Loc, "set_exception", E);
+ if (SetException.isInvalid())
+ return false;
+ SetException = S.ActOnFinishFullExpr(SetException.get(), Loc);
+ if (SetException.isInvalid())
+ return false;
+ OnException = SetException.get();
+ return true;
+}
+
+
// Builds allocation and deallocation for the coroutine. Returns false on
// failure.
static bool buildAllocationAndDeallocation(Sema &S, SourceLocation Loc,
FunctionScopeInfo *Fn,
Expr *&Allocation,
Stmt *&Deallocation) {
- TypeSourceInfo *TInfo = Fn->CoroutinePromise->getTypeSourceInfo();
- QualType PromiseType = TInfo->getType();
+ assert(!Allocation && !Deallocation && "alloc/dealloc statements have already been built");
+ QualType PromiseType = Fn->CoroutinePromise->getType();
if (PromiseType->isDependentType())
return true;
@@ -532,8 +752,6 @@
if (NewExpr.isInvalid())
return false;
- Allocation = NewExpr.get();
-
// Make delete call.
QualType OpDeleteQualType = OperatorDelete->getType();
@@ -559,149 +777,137 @@
if (DeleteExpr.isInvalid())
return false;
+ Allocation = NewExpr.get();
Deallocation = DeleteExpr.get();
return true;
}
-void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
- FunctionScopeInfo *Fn = getCurFunction();
- assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine");
-
- // Coroutines [stmt.return]p1:
- // A return statement shall not appear in a coroutine.
- if (Fn->FirstReturnLoc.isValid()) {
- Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
- auto *First = Fn->CoroutineStmts[0];
- Diag(First->getLocStart(), diag::note_declared_coroutine_here)
- << (isa<CoawaitExpr>(First) ? 0 :
- isa<CoyieldExpr>(First) ? 1 : 2);
- }
-
- bool AnyCoawaits = false;
- bool AnyCoyields = false;
- for (auto *CoroutineStmt : Fn->CoroutineStmts) {
- AnyCoawaits |= isa<CoawaitExpr>(CoroutineStmt);
- AnyCoyields |= isa<CoyieldExpr>(CoroutineStmt);
- }
-
- if (!AnyCoawaits && !AnyCoyields)
- Diag(Fn->CoroutineStmts.front()->getLocStart(),
- diag::ext_coroutine_without_co_await_co_yield);
-
- SourceLocation Loc = FD->getLocation();
-
+/// Build the CoroutineBodyStmt that wraps the user-written \p Body.
+///
+/// \p Promise, \p InitSuspend and \p FinalSuspend must already be built. Each
+/// of \p SetException, \p OnFallthrough, \p Allocation / \p Deallocation and
+/// \p ReturnObjectInit is built here only if it is passed as null. This lets
+/// a caller that is rebuilding an existing body pass in the parts it has
+/// already transformed. Returns StmtError() if any part fails to build.
+StmtResult Sema::BuildCoroutineBodyStmt(Stmt *Body, VarDecl *Promise, Stmt *InitSuspend,
+ Stmt *FinalSuspend, Stmt *SetException,
+ Stmt *OnFallthrough, Expr *Allocation,
+ Stmt *Deallocation, Expr *ReturnObjectInit) {
+ assert(Promise && InitSuspend && FinalSuspend && "these nodes must already be built");
// Form a declaration statement for the promise declaration, so that AST
// visitors can more easily find it.
+ // FIXME: Use a real source location for the implicit parts built below
+ // instead of the function's location.
+ auto *FSI = getCurFunction();
+ assert(FSI->CoroutinePromise);
+ auto *FD = cast<FunctionDecl>(CurContext);
+ auto Loc = FD->getLocation();
+
+ // Resolve placeholder types (other than unresolved overload sets) on the
+ // already-built suspend expressions.
+ auto checkPlaceholders = [&](Stmt *&S) mutable {
+ Expr *E = cast_or_null<Expr>(S);
+ if (E && E->getType()->isPlaceholderType() &&
+ !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) {
+ ExprResult R = CheckPlaceholderExpr(E);
+ if (R.isInvalid())
+ return false;
+ S = cast<Stmt>(R.get());
+ }
+ return true;
+ };
+ if (!checkPlaceholders(InitSuspend) || !checkPlaceholders(FinalSuspend))
+ return StmtError();
+
StmtResult PromiseStmt =
- ActOnDeclStmt(ConvertDeclToDeclGroup(Fn->CoroutinePromise), Loc, Loc);
+ ActOnDeclStmt(ConvertDeclToDeclGroup(Promise), Promise->getLocStart(),
+ Promise->getLocEnd());
if (PromiseStmt.isInvalid())
- return FD->setInvalidDecl();
-
- // Form and check implicit 'co_await p.initial_suspend();' statement.
- ExprResult InitialSuspend =
- buildPromiseCall(*this, Fn, Loc, "initial_suspend", None);
- // FIXME: Support operator co_await here.
- if (!InitialSuspend.isInvalid())
- InitialSuspend = BuildCoawaitExpr(Loc, InitialSuspend.get());
- InitialSuspend = ActOnFinishFullExpr(InitialSuspend.get());
- if (InitialSuspend.isInvalid())
- return FD->setInvalidDecl();
-
- // Form and check implicit 'co_await p.final_suspend();' statement.
- ExprResult FinalSuspend =
- buildPromiseCall(*this, Fn, Loc, "final_suspend", None);
- // FIXME: Support operator co_await here.
- if (!FinalSuspend.isInvalid())
- FinalSuspend = BuildCoawaitExpr(Loc, FinalSuspend.get());
- FinalSuspend = ActOnFinishFullExpr(FinalSuspend.get());
- if (FinalSuspend.isInvalid())
- return FD->setInvalidDecl();
+ return StmtError();
- // Form and check allocation and deallocation calls.
- Expr *Allocation = nullptr;
- Stmt *Deallocation = nullptr;
- if (!buildAllocationAndDeallocation(*this, Loc, Fn, Allocation, Deallocation))
- return FD->setInvalidDecl();
+ // Build whichever optional parts the caller did not supply.
+ // NOTE(review): the helpers below read FSI->CoroutinePromise rather than
+ // the Promise parameter; presumably they are the same declaration -- confirm
+ // this holds when an existing body is being rebuilt.
+ if (!OnFallthrough && !buildFallthrough(*this, Loc, FD, FSI, OnFallthrough))
+ return StmtError();
- // control flowing off the end of the coroutine.
- // Also try to form 'p.set_exception(std::current_exception());' to handle
- // uncaught exceptions.
- ExprResult SetException;
- StmtResult Fallthrough;
- if (Fn->CoroutinePromise &&
- !Fn->CoroutinePromise->getType()->isDependentType()) {
- CXXRecordDecl *RD = Fn->CoroutinePromise->getType()->getAsCXXRecordDecl();
- assert(RD && "Type should have already been checked");
- // [dcl.fct.def.coroutine]/4
- // The unqualified-ids 'return_void' and 'return_value' are looked up in
- // the scope of class P. If both are found, the program is ill-formed.
- DeclarationName RVoidDN = PP.getIdentifierInfo("return_void");
- LookupResult RVoidResult(*this, RVoidDN, Loc, Sema::LookupMemberName);
- const bool HasRVoid = LookupQualifiedName(RVoidResult, RD);
-
- DeclarationName RValueDN = PP.getIdentifierInfo("return_value");
- LookupResult RValueResult(*this, RValueDN, Loc, Sema::LookupMemberName);
- const bool HasRValue = LookupQualifiedName(RValueResult, RD);
-
- if (HasRVoid && HasRValue) {
- // FIXME Improve this diagnostic
- Diag(FD->getLocation(), diag::err_coroutine_promise_return_ill_formed)
- << RD;
- return FD->setInvalidDecl();
- } else if (HasRVoid) {
- // If the unqualified-id return_void is found, flowing off the end of a
- // coroutine is equivalent to a co_return with no operand. Otherwise,
- // flowing off the end of a coroutine results in undefined behavior.
- Fallthrough = BuildCoreturnStmt(FD->getLocation(), nullptr);
- Fallthrough = ActOnFinishFullStmt(Fallthrough.get());
- if (Fallthrough.isInvalid())
- return FD->setInvalidDecl();
- }
+ if (!SetException && !buildSetException(*this, Loc, FD, FSI, SetException))
+ return StmtError();
- // [dcl.fct.def.coroutine]/3
- // The unqualified-id set_exception is found in the scope of P by class
- // member access lookup (3.4.5).
- DeclarationName SetExDN = PP.getIdentifierInfo("set_exception");
- LookupResult SetExResult(*this, SetExDN, Loc, Sema::LookupMemberName);
- if (LookupQualifiedName(SetExResult, RD)) {
- // Form the call 'p.set_exception(std::current_exception())'
- SetException = buildStdCurrentExceptionCall(*this, Loc);
- if (SetException.isInvalid())
- return FD->setInvalidDecl();
- Expr *E = SetException.get();
- SetException = buildPromiseCall(*this, Fn, Loc, "set_exception", E);
- SetException = ActOnFinishFullExpr(SetException.get(), Loc);
- if (SetException.isInvalid())
- return FD->setInvalidDecl();
- }
+ if (!Allocation || !Deallocation) {
+ assert(!Allocation && !Deallocation && "These should be a package deal");
+ if (!buildAllocationAndDeallocation(*this, Loc, FSI, Allocation,
+ Deallocation))
+ return StmtError();
}
// Build implicit 'p.get_return_object()' expression and form initialization
// of return type from it.
- ExprResult ReturnObject =
- buildPromiseCall(*this, Fn, Loc, "get_return_object", None);
- if (ReturnObject.isInvalid())
- return FD->setInvalidDecl();
- QualType RetType = FD->getReturnType();
- if (!RetType->isDependentType()) {
- InitializedEntity Entity =
- InitializedEntity::InitializeResult(Loc, RetType, false);
- ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType,
- ReturnObject.get());
+ if (!ReturnObjectInit) {
+ ExprResult ReturnObject =
+ buildPromiseCall(*this, Promise, Loc, "get_return_object", None);
if (ReturnObject.isInvalid())
- return FD->setInvalidDecl();
+ return StmtError();
+ QualType RetType = FD->getReturnType();
+ if (!RetType->isDependentType()) {
+ InitializedEntity Entity =
+ InitializedEntity::InitializeResult(Loc, RetType, false);
+ ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType,
+ ReturnObject.get());
+ if (ReturnObject.isInvalid())
+ return StmtError();
+ }
+ ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc);
+ if (ReturnObject.isInvalid())
+ return StmtError();
+ ReturnObjectInit = ReturnObject.get();
}
- ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc);
- if (ReturnObject.isInvalid())
+
+ return new (Context) CoroutineBodyStmt(
+ Body, PromiseStmt.get(), InitSuspend, FinalSuspend,
+ SetException, OnFallthrough, Allocation, Deallocation, ReturnObjectInit, None);
+}
+
+/// Finish semantic analysis of a coroutine body: wrap \p Body in a
+/// CoroutineBodyStmt (unless it already is one), then diagnose 'return'
+/// statements and coroutines that never co_await or co_yield. Marks \p FD
+/// invalid if the implicit suspend points or the wrapper statement failed to
+/// build.
+void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
+ FunctionScopeInfo *FSI = getCurFunction();
+ assert(FSI && FSI->CoroutinePromise && FSI->HasCoroutineSuspends &&
+ "not a coroutine");
+ VarDecl *Promise = FSI->CoroutinePromise;
+
+ // Check if we failed to build the initial/final suspend points during the
+ // initial parse.
+ if (FSI->hasInvalidCoroutineSuspends())
return FD->setInvalidDecl();
// FIXME: Perform move-initialization of parameters into frame-local copies.
- SmallVector<Expr*, 16> ParamMoves;
+ // Until then, BuildCoroutineBodyStmt passes an empty parameter-move list.
+ if (Body && !isa<CoroutineBodyStmt>(Body)) {
+ StmtResult BodyRes = BuildCoroutineBodyStmt(
+ Body, Promise, FSI->CoroutineSuspends.first,
+ FSI->CoroutineSuspends.second, nullptr, nullptr, nullptr, nullptr,
+ nullptr);
+ if (BodyRes.isInvalid())
+ return FD->setInvalidDecl();
+ Body = BodyRes.get();
+ }
+
+ // Coroutines [stmt.return]p1:
+ // A return statement shall not appear in a coroutine.
+ if (FSI->FirstReturnLoc.isValid()) {
+ Diag(FSI->FirstReturnLoc, diag::err_return_in_coroutine);
+ // CoroutineStmts can be empty here. A co_await/co_yield/co_return can fail
+ // to build after the implicit suspend points were created, and failed ones
+ // are never recorded. Only point at the first one if there is one.
+ if (!FSI->CoroutineStmts.empty()) {
+ auto *First = FSI->CoroutineStmts[0];
+ Diag(First->getLocStart(), diag::note_declared_coroutine_here)
+ << ((isa<CoawaitExpr>(First) || isa<DependentCoawaitExpr>(First))
+ ? "co_await"
+ : isa<CoyieldExpr>(First) ? "co_yield" : "co_return");
+ }
+ }
+
+ bool AnyCoawaits = false;
+ bool AnyDependentCoawaits = false;
+ bool AnyCoyields = false;
+ for (auto *CoroutineStmt : FSI->CoroutineStmts) {
+ // Don't count the implicitly generated initial/final suspend points
+ if (auto *CA = dyn_cast<CoawaitExpr>(CoroutineStmt))
+ AnyCoawaits |= !CA->isImplicitlyCreated();
+ AnyDependentCoawaits |= isa<DependentCoawaitExpr>(CoroutineStmt);
+ AnyCoyields |= isa<CoyieldExpr>(CoroutineStmt);
+ }
+
+ if (!FD->isInvalidDecl() && !FSI->CoroutineStmts.empty() && !AnyCoawaits &&
+ !AnyCoyields && !AnyDependentCoawaits)
+ Diag(FSI->CoroutineStmts.front()->getLocStart(),
+ diag::ext_coroutine_without_co_await_co_yield);
+
+ assert((!AnyDependentCoawaits || Promise->getType()->isDependentType()) &&
+ "All dependent coawait expressions should already be resolved");
- // Build body for the coroutine wrapper statement.
- Body = new (Context) CoroutineBodyStmt(
- Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(),
- SetException.get(), Fallthrough.get(), Allocation, Deallocation,
- ReturnObject.get(), ParamMoves);
}
Index: lib/Sema/ScopeInfo.cpp
===================================================================
--- lib/Sema/ScopeInfo.cpp
+++ lib/Sema/ScopeInfo.cpp
@@ -42,6 +42,9 @@
SwitchStack.clear();
Returns.clear();
CoroutinePromise = nullptr;
+ HasCoroutineSuspends = false;
+ CoroutineSuspends.first = nullptr;
+ CoroutineSuspends.second = nullptr;
CoroutineStmts.clear();
ErrorTrap.reset();
PossiblyUnreachableDiags.clear();
Index: lib/Parse/ParseStmt.cpp
===================================================================
--- lib/Parse/ParseStmt.cpp
+++ lib/Parse/ParseStmt.cpp
@@ -1898,7 +1898,7 @@
}
}
if (IsCoreturn)
- return Actions.ActOnCoreturnStmt(ReturnLoc, R.get());
+ return Actions.ActOnCoreturnStmt(getCurScope(), ReturnLoc, R.get());
return Actions.ActOnReturnStmt(ReturnLoc, R.get(), getCurScope());
}
Index: lib/AST/StmtProfile.cpp
===================================================================
--- lib/AST/StmtProfile.cpp
+++ lib/AST/StmtProfile.cpp
@@ -1552,6 +1552,10 @@
VisitExpr(S);
}
+/// Profile a DependentCoawaitExpr using only the generic expression
+/// profiling, as the neighbouring VisitCoyieldExpr does.
+void StmtProfiler::VisitDependentCoawaitExpr(const DependentCoawaitExpr *S) {
+ VisitExpr(S);
+}
+
void StmtProfiler::VisitCoyieldExpr(const CoyieldExpr *S) {
VisitExpr(S);
}
Index: lib/AST/StmtPrinter.cpp
===================================================================
--- lib/AST/StmtPrinter.cpp
+++ lib/AST/StmtPrinter.cpp
@@ -2422,6 +2422,11 @@
PrintExpr(S->getOperand());
}
+/// Print a not-yet-resolved co_await as 'co_await <operand>'.
+void StmtPrinter::VisitDependentCoawaitExpr(DependentCoawaitExpr *S) {
+ OS << "co_await ";
+ PrintExpr(S->getOperand());
+}
+
void StmtPrinter::VisitCoyieldExpr(CoyieldExpr *S) {
OS << "co_yield ";
PrintExpr(S->getOperand());
Index: lib/AST/ItaniumMangle.cpp
===================================================================
--- lib/AST/ItaniumMangle.cpp
+++ lib/AST/ItaniumMangle.cpp
@@ -3299,6 +3299,8 @@
// These all can only appear in local or variable-initialization
// contexts and so should never appear in a mangling.
case Expr::AddrLabelExprClass:
+ // Like the cases above, co_await only appears inside function bodies, so
+ // a DependentCoawaitExpr never needs to be mangled.
+ case Expr::DependentCoawaitExprClass:
case Expr::DesignatedInitUpdateExprClass:
case Expr::ImplicitValueInitExprClass:
case Expr::NoInitExprClass:
Index: lib/AST/ExprConstant.cpp
===================================================================
--- lib/AST/ExprConstant.cpp
+++ lib/AST/ExprConstant.cpp
@@ -9442,6 +9442,7 @@
case Expr::LambdaExprClass:
case Expr::CXXFoldExprClass:
case Expr::CoawaitExprClass:
+ case Expr::DependentCoawaitExprClass:
case Expr::CoyieldExprClass:
return ICEDiag(IK_NotICE, E->getLocStart());
Index: lib/AST/ExprClassification.cpp
===================================================================
--- lib/AST/ExprClassification.cpp
+++ lib/AST/ExprClassification.cpp
@@ -188,6 +188,9 @@
case Expr::CXXFoldExprClass:
case Expr::NoInitExprClass:
case Expr::DesignatedInitUpdateExprClass:
+ // FIXME How should we classify co_await expressions while they're still
+ // dependent?
+ case Expr::DependentCoawaitExprClass:
case Expr::CoyieldExprClass:
return Cl::CL_PRValue;
Index: lib/AST/Expr.cpp
===================================================================
--- lib/AST/Expr.cpp
+++ lib/AST/Expr.cpp
@@ -2923,6 +2923,7 @@
case CXXNewExprClass:
case CXXDeleteExprClass:
case CoawaitExprClass:
+ case DependentCoawaitExprClass:
case CoyieldExprClass:
// These always have a side-effect.
return true;
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -8032,12 +8032,21 @@
//
ExprResult ActOnCoawaitExpr(Scope *S, SourceLocation KwLoc, Expr *E);
ExprResult ActOnCoyieldExpr(Scope *S, SourceLocation KwLoc, Expr *E);
- StmtResult ActOnCoreturnStmt(SourceLocation KwLoc, Expr *E);
+ StmtResult ActOnCoreturnStmt(Scope *S, SourceLocation KwLoc, Expr *E);
- ExprResult BuildCoawaitExpr(SourceLocation KwLoc, Expr *E);
+ ExprResult BuildCoawaitExpr(SourceLocation KwLoc, Expr *E,
+ bool IsImplicitlyCreated = false);
+ ExprResult BuildDependentCoawaitExpr(SourceLocation KwLoc, Expr *E,
+ UnresolvedLookupExpr *Lookup);
ExprResult BuildCoyieldExpr(SourceLocation KwLoc, Expr *E);
- StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E);
-
+ StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E,
+ bool IsImplicitlyCreated = false);
+ /// Build the CoroutineBodyStmt that wraps \p Body. \p Promise,
+ /// \p InitSuspend and \p FinalSuspend must already be built; any of the
+ /// remaining components passed as null are built here.
+ StmtResult BuildCoroutineBodyStmt(Stmt *Body, VarDecl *Promise,
+ Stmt *InitSuspend, Stmt *FinalSuspend,
+ Stmt *SetException, Stmt *OnFallthrough,
+ Expr *Allocation, Stmt *Deallocation,
+ Expr *ReturnObjectInit);
+
+ /// Create the implicit '__promise' variable for the current coroutine, or
+ /// return nullptr if it cannot be built.
+ VarDecl *buildCoroutinePromise(SourceLocation Loc);
void CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body);
//===--------------------------------------------------------------------===//
Index: include/clang/Sema/ScopeInfo.h
===================================================================
--- include/clang/Sema/ScopeInfo.h
+++ include/clang/Sema/ScopeInfo.h
@@ -16,12 +16,14 @@
#define LLVM_CLANG_SEMA_SCOPEINFO_H
#include "clang/AST/Expr.h"
+#include "clang/AST/StmtCXX.h"
#include "clang/AST/Type.h"
#include "clang/Basic/CapturedStmt.h"
#include "clang/Basic/PartialDiagnostic.h"
#include "clang/Sema/CleanupInfo.h"
#include "clang/Sema/Ownership.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
@@ -135,6 +137,10 @@
/// false if there is an invocation of an initializer on 'self'.
bool ObjCWarnForNoInitDelegation : 1;
+ /// true iff we have attempted to build the initial and final coroutine
+ /// suspend points.
+ bool HasCoroutineSuspends : 1;
+
/// First 'return' statement in the current function.
SourceLocation FirstReturnLoc;
@@ -159,6 +165,9 @@
/// \brief The promise object for this coroutine, if any.
VarDecl *CoroutinePromise;
+ /// \brief The initial and final coroutine suspend points.
+ std::pair<Stmt *, Stmt *> CoroutineSuspends;
+
/// \brief The list of coroutine control flow constructs (co_await, co_yield,
/// co_return) that occur within the function or block. Empty if and only if
/// this function or block is not (yet known to be) a coroutine.
@@ -376,22 +385,33 @@
(HasIndirectGoto ||
(HasBranchProtectedScope && HasBranchIntoScope));
}
-
+
+ /// Record that building the initial/final suspend points was attempted,
+ /// with no valid result stored yet. Must be called at most once, before
+ /// any suspend points are stored.
+ void setCoroutineSuspendsInvalid() {
+ assert(!HasCoroutineSuspends && CoroutineSuspends.first == nullptr &&
+ "we already have valid suspend points");
+ HasCoroutineSuspends = true;
+ }
+
+ /// True iff building the suspend points was attempted but no initial
+ /// suspend point was stored.
+ bool hasInvalidCoroutineSuspends() const {
+ return HasCoroutineSuspends && CoroutineSuspends.first == nullptr;
+ }
+
+ /// Store the successfully built initial and final suspend points; both must
+ /// be non-null.
+ void setCoroutineSuspends(Stmt *Initial, Stmt *Final) {
+ assert(Initial && Final && "suspend points cannot be null");
+ HasCoroutineSuspends = true;
+ CoroutineSuspends.first = Initial;
+ CoroutineSuspends.second = Final;
+ }
+
+ // CoroutineSuspends is not listed below; std::pair's default constructor
+ // value-initializes both pointers to null.
FunctionScopeInfo(DiagnosticsEngine &Diag)
- : Kind(SK_Function),
- HasBranchProtectedScope(false),
- HasBranchIntoScope(false),
- HasIndirectGoto(false),
- HasDroppedStmt(false),
- HasOMPDeclareReductionCombiner(false),
- HasFallthroughStmt(false),
- HasPotentialAvailabilityViolations(false),
- ObjCShouldCallSuper(false),
- ObjCIsDesignatedInit(false),
- ObjCWarnForNoDesignatedInitChain(false),
- ObjCIsSecondaryInit(false),
- ObjCWarnForNoInitDelegation(false),
- ErrorTrap(Diag) { }
+ : Kind(SK_Function), HasBranchProtectedScope(false),
+ HasBranchIntoScope(false), HasIndirectGoto(false),
+ HasDroppedStmt(false), HasOMPDeclareReductionCombiner(false),
+ HasFallthroughStmt(false), HasPotentialAvailabilityViolations(false),
+ ObjCShouldCallSuper(false), ObjCIsDesignatedInit(false),
+ ObjCWarnForNoDesignatedInitChain(false), ObjCIsSecondaryInit(false),
+ ObjCWarnForNoInitDelegation(false), HasCoroutineSuspends(false),
+ CoroutinePromise(nullptr), ErrorTrap(Diag) {}
virtual ~FunctionScopeInfo();
Index: include/clang/Basic/StmtNodes.td
===================================================================
--- include/clang/Basic/StmtNodes.td
+++ include/clang/Basic/StmtNodes.td
@@ -148,6 +148,7 @@
// C++ Coroutines TS expressions
def CoroutineSuspendExpr : DStmt<Expr, 1>;
def CoawaitExpr : DStmt<CoroutineSuspendExpr>;
+def DependentCoawaitExpr : DStmt<Expr>;
def CoyieldExpr : DStmt<CoroutineSuspendExpr>;
// Obj-C Expressions.
Index: include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- include/clang/Basic/DiagnosticSemaKinds.td
+++ include/clang/Basic/DiagnosticSemaKinds.td
@@ -8634,8 +8634,7 @@
def err_return_in_coroutine : Error<
"return statement not allowed in coroutine; did you mean 'co_return'?">;
def note_declared_coroutine_here : Note<
- "function is a coroutine due to use of "
- "'%select{co_await|co_yield|co_return}0' here">;
+ "function is a coroutine due to use of '%0' here">;
def err_coroutine_objc_method : Error<
"Objective-C methods as coroutines are not yet supported">;
def err_coroutine_unevaluated_context : Error<
@@ -8659,6 +8658,8 @@
"this function cannot be a coroutine: %q0 has no member named 'promise_type'">;
def err_implied_std_coroutine_traits_promise_type_not_class : Error<
"this function cannot be a coroutine: %0 is not a class">;
+def err_coroutine_promise_type_incomplete : Error<
+ "this function cannot be a coroutine: %0 is an incomplete type">;
def err_coroutine_traits_missing_specialization : Error<
"this function cannot be a coroutine: missing definition of "
"specialization %q0">;
@@ -8669,6 +8670,11 @@
"'std::current_exception' must be a function">;
def err_coroutine_promise_return_ill_formed : Error<
"%0 declares both 'return_value' and 'return_void'">;
+def note_coroutine_promise_implicit_await_transform_required_here : Note<
+ "call to 'await_transform' implicitly required by 'co_await' here">;
+// Both %select groups read argument 0. Value 0 selects initial_suspend /
+// "initial suspend point"; value 1 selects final_suspend / "final suspend point".
+def note_coroutine_promise_call_implicitly_required : Note<
+ "call to '%select{initial_suspend|final_suspend}0' implicitly "
+ "required by the %select{initial suspend point|final suspend point}0">;
}
let CategoryName = "Documentation Issue" in {
Index: include/clang/AST/StmtCXX.h
===================================================================
--- include/clang/AST/StmtCXX.h
+++ include/clang/AST/StmtCXX.h
@@ -327,6 +327,8 @@
SubStmts[CoroutineBodyStmt::Allocate] = Allocate;
SubStmts[CoroutineBodyStmt::Deallocate] = Deallocate;
SubStmts[CoroutineBodyStmt::ReturnValue] = ReturnValue;
+ assert(Promise && InitSuspend && FinalSuspend &&
+ "these members must never be null");
// FIXME: Tail-allocate space for parameter move expressions and store them.
assert(ParamMoves.empty() && "not implemented yet");
}
@@ -336,32 +338,54 @@
Stmt *getBody() const {
return SubStmts[SubStmt::Body];
}
-
+ /// Replace the body. B may be null, but it must not itself be a
+ /// CoroutineBodyStmt (asserted).
+ void setBody(Stmt *B) {
+ assert(!B || !isa<CoroutineBodyStmt>(B));
+ SubStmts[SubStmt::Body] = B;
+ }
Stmt *getPromiseDeclStmt() const { return SubStmts[SubStmt::Promise]; }
VarDecl *getPromiseDecl() const {
return cast<VarDecl>(cast<DeclStmt>(getPromiseDeclStmt())->getSingleDecl());
}
+ /// Install the promise declaration. Asserts the slot is still empty, so this
+ /// may be called at most once per object.
+ // NOTE(review): the constructor asserts Promise/InitSuspend/FinalSuspend are
+ // non-null. These one-shot setters can therefore only succeed on an object
+ // built some other way (presumably an empty shell). Confirm intended use.
+ void setPromiseDeclStmt(Stmt *S) {
+ assert(SubStmts[SubStmt::Promise] == nullptr);
+ SubStmts[SubStmt::Promise] = S;
+ }
+
Stmt *getInitSuspendStmt() const { return SubStmts[SubStmt::InitSuspend]; }
Stmt *getFinalSuspendStmt() const { return SubStmts[SubStmt::FinalSuspend]; }
+ /// One-shot setters for the initial and final suspend points. Each asserts
+ /// that its slot has not been set yet.
+ void setInitialSuspendStmt(Stmt *Suspend) {
+ assert(SubStmts[SubStmt::InitSuspend] == nullptr);
+ SubStmts[SubStmt::InitSuspend] = Suspend;
+ }
+ void setFinalSuspendStmt(Stmt *Suspend) {
+ assert(SubStmts[SubStmt::FinalSuspend] == nullptr);
+ SubStmts[SubStmt::FinalSuspend] = Suspend;
+ }
+
Stmt *getExceptionHandler() const { return SubStmts[SubStmt::OnException]; }
+ void setExceptionHandler(Stmt *S) { SubStmts[SubStmt::OnException] = S; }
Stmt *getFallthroughHandler() const {
return SubStmts[SubStmt::OnFallthrough];
}
-
- Expr *getAllocate() const { return cast<Expr>(SubStmts[SubStmt::Allocate]); }
+ /// Set the OnFallthrough sub-statement returned by getFallthroughHandler().
+ void setFallthroughHandler(Stmt *S) { SubStmts[SubStmt::OnFallthrough] = S; }
+ /// Misspelled alias kept so existing callers still compile; prefer
+ /// setFallthroughHandler().
+ void setFalltroughHandler(Stmt *S) { setFallthroughHandler(S); }
+ /// May return null: the Allocate slot is read with cast_or_null.
+ Expr *getAllocate() const { return cast_or_null<Expr>(SubStmts[SubStmt::Allocate]); }
Stmt *getDeallocate() const { return SubStmts[SubStmt::Deallocate]; }
+ void setAllocate(Expr *E) { SubStmts[SubStmt::Allocate] = E; }
+ void setDeallocate(Stmt *S) { SubStmts[SubStmt::Deallocate] = S; }
Expr *getReturnValueInit() const {
- return cast<Expr>(SubStmts[SubStmt::ReturnValue]);
+ return cast_or_null<Expr>(SubStmts[SubStmt::ReturnValue]);
}
+ void setReturnValueInit(Expr *E) { SubStmts[SubStmt::ReturnValue] = E; }
+
SourceLocation getLocStart() const LLVM_READONLY {
- return getBody()->getLocStart();
+ return getBody() ? getBody()->getLocStart() : getPromiseDecl()->getLocStart();
}
SourceLocation getLocEnd() const LLVM_READONLY {
- return getBody()->getLocEnd();
+ return getBody() ? getBody()->getLocEnd() : getPromiseDecl()->getLocStart();
}
child_range children() {
@@ -390,10 +414,14 @@
enum SubStmt { Operand, PromiseCall, Count };
Stmt *SubStmts[SubStmt::Count];
+ /// Initialized from the constructor's IsImplicit argument. Presumably marks a
+ /// co_return synthesized by the compiler (cf. CoawaitExpr's flag of the same
+ /// name). TODO confirm.
+ bool IsImplicitlyCreated : 1;
+
friend class ASTStmtReader;
public:
- CoreturnStmt(SourceLocation CoreturnLoc, Stmt *Operand, Stmt *PromiseCall)
- : Stmt(CoreturnStmtClass), CoreturnLoc(CoreturnLoc) {
+ /// IsImplicit defaults to false, so existing three-argument callers are
+ /// unaffected.
+ CoreturnStmt(SourceLocation CoreturnLoc, Stmt *Operand, Stmt *PromiseCall,
+ bool IsImplicit = false)
+ : Stmt(CoreturnStmtClass), CoreturnLoc(CoreturnLoc),
+ IsImplicitlyCreated(IsImplicit) {
SubStmts[SubStmt::Operand] = Operand;
SubStmts[SubStmt::PromiseCall] = PromiseCall;
}
@@ -410,6 +438,8 @@
Expr *getPromiseCall() const {
return static_cast<Expr*>(SubStmts[PromiseCall]);
}
+ /// Accessors for the implicit-creation flag. The setter defaults to marking
+ /// the statement implicit.
+ // NOTE(review): CoawaitExpr spells its equivalent setter
+ // setIsImplicitlyCreated; consider unifying the names.
+ bool isImplicitlyCreated() const { return IsImplicitlyCreated; }
+ void setImplicitlyCreated(bool value = true) { IsImplicitlyCreated = value; }
SourceLocation getLocStart() const LLVM_READONLY { return CoreturnLoc; }
SourceLocation getLocEnd() const LLVM_READONLY {
Index: include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- include/clang/AST/RecursiveASTVisitor.h
+++ include/clang/AST/RecursiveASTVisitor.h
@@ -2471,6 +2471,12 @@
ShouldVisitChildren = false;
}
})
+// Unless implicit code is requested, traverse only the written operand and
+// suppress the remaining children (the unresolved 'operator co_await' lookup).
+// This mirrors the CoawaitExpr/CoyieldExpr traversals.
+DEF_TRAVERSE_STMT(DependentCoawaitExpr, {
+ if (!getDerived().shouldVisitImplicitCode()) {
+ TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getOperand());
+ ShouldVisitChildren = false;
+ }
+})
DEF_TRAVERSE_STMT(CoyieldExpr, {
if (!getDerived().shouldVisitImplicitCode()) {
TRY_TO_TRAVERSE_OR_ENQUEUE_STMT(S->getOperand());
Index: include/clang/AST/ExprCXX.h
===================================================================
--- include/clang/AST/ExprCXX.h
+++ include/clang/AST/ExprCXX.h
@@ -4231,26 +4231,82 @@
/// \brief Represents a 'co_await' expression.
class CoawaitExpr : public CoroutineSuspendExpr {
friend class ASTStmtReader;
+
+ /// \brief True if this co_await expression was implicitly generated by the
+ /// compiler.
+ bool IsImplicitlyCreated : 1;
+
public:
CoawaitExpr(SourceLocation CoawaitLoc, Expr *Operand, Expr *Ready,
- Expr *Suspend, Expr *Resume)
+ Expr *Suspend, Expr *Resume, bool IsImplicit = false)
: CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Operand, Ready,
- Suspend, Resume) {}
- CoawaitExpr(SourceLocation CoawaitLoc, QualType Ty, Expr *Operand)
- : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Ty, Operand) {}
+ Suspend, Resume),
+ IsImplicitlyCreated(IsImplicit) {}
+ CoawaitExpr(SourceLocation CoawaitLoc, QualType Ty, Expr *Operand,
+ bool IsImplicit = false)
+ : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Ty, Operand),
+ IsImplicitlyCreated(IsImplicit) {}
- CoawaitExpr(EmptyShell Empty)
- : CoroutineSuspendExpr(CoawaitExprClass, Empty) {}
+ /// Empty-shell constructor. Initialize the bit-field explicitly; otherwise an
+ /// empty shell would carry an indeterminate IsImplicitlyCreated value.
+ CoawaitExpr(EmptyShell Empty)
+ : CoroutineSuspendExpr(CoawaitExprClass, Empty),
+ IsImplicitlyCreated(false) {}
Expr *getOperand() const {
// FIXME: Dig out the actual operand or store it.
return getCommonExpr();
}
+ bool isImplicitlyCreated() const { return IsImplicitlyCreated; }
+ void setIsImplicitlyCreated(bool value = true) {
+ IsImplicitlyCreated = value;
+ }
+ /// Same as setIsImplicitlyCreated(); spelled like CoreturnStmt's setter.
+ void setImplicitlyCreated(bool value = true) { IsImplicitlyCreated = value; }
+
static bool classof(const Stmt *T) {
return T->getStmtClass() == CoawaitExprClass;
}
};
+/// \brief Represents a 'co_await' expression while the type of the promise
+/// is dependent.
+///
+/// The node is always type-, value- and instantiation-dependent. The operand
+/// itself need not be type-dependent: the dependence can come solely from the
+/// coroutine's promise type.
+class DependentCoawaitExpr : public Expr {
+ SourceLocation KeywordLoc;
+ /// [0] = the operand, [1] = the unresolved 'operator co_await' lookup.
+ Stmt *SubExprs[2];
+
+ friend class ASTStmtReader;
+
+public:
+ DependentCoawaitExpr(SourceLocation KeywordLoc, QualType Ty, Expr *Op,
+ UnresolvedLookupExpr *OpCoawait)
+ : Expr(DependentCoawaitExprClass, Ty, VK_RValue, OK_Ordinary,
+ /*TypeDependent*/ true, /*ValueDependent*/ true,
+ /*InstantiationDependent*/ true,
+ Op->containsUnexpandedParameterPack()),
+ KeywordLoc(KeywordLoc) {
+ // A co_await depends on the coroutine's promise type, so it can be dependent
+ // even when Op is not. Only the result type is required to be dependent.
+ assert(Ty->isDependentType() &&
+ "wrong constructor for non-dependent co_await/co_yield expression");
+ SubExprs[0] = Op;
+ SubExprs[1] = OpCoawait;
+ }
+
+ DependentCoawaitExpr(EmptyShell Empty)
+ : Expr(DependentCoawaitExprClass, Empty) {}
+
+ Expr *getOperand() const { return cast<Expr>(SubExprs[0]); }
+ UnresolvedLookupExpr *getOperatorCoawaitLookup() const {
+ return cast<UnresolvedLookupExpr>(SubExprs[1]);
+ }
+ SourceLocation getKeywordLoc() const { return KeywordLoc; }
+
+ SourceLocation getLocStart() const LLVM_READONLY { return KeywordLoc; }
+ SourceLocation getLocEnd() const LLVM_READONLY {
+ return getOperand()->getLocEnd();
+ }
+
+ child_range children() { return child_range(SubExprs, SubExprs + 2); }
+
+ static bool classof(const Stmt *T) {
+ return T->getStmtClass() == DependentCoawaitExprClass;
+ }
+};
+
/// \brief Represents a 'co_yield' expression.
class CoyieldExpr : public CoroutineSuspendExpr {
friend class ASTStmtReader;
_______________________________________________
cfe-commits mailing list
[email protected]
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits