EricWF created this revision.
EricWF added reviewers: rsmith, GorNishanov, majnemer.
EricWF added a subscriber: cfe-commits.
Herald added a subscriber: mehdi_amini.

When TreeTransform is used to rebuild a coroutine, the coroutine's promise
variable (of type `promise_type`) is not transformed, because it is stored in
the `FunctionScopeInfo` rather than as a regular sub-statement of the function.

This patch attempts to fix this by rebuilding the promise variable at the start
of transforming a `CoroutineBodyStmt`. Additionally, this patch changes
`TransformCoroutineBodyStmt` so that it always rebuilds the coroutine body.
Most promise-type lookup errors are now reported at the function's location
rather than at the first coroutine keyword, as reflected in the updated tests.
Is there a better alternative to always rebuilding the body, or is always
rebuilding sufficient for now?


https://reviews.llvm.org/D25303

Files:
  include/clang/Sema/Sema.h
  lib/Sema/SemaCoroutine.cpp
  lib/Sema/SemaDecl.cpp
  lib/Sema/TreeTransform.h
  test/SemaCXX/coroutines.cpp

Index: test/SemaCXX/coroutines.cpp
===================================================================
--- test/SemaCXX/coroutines.cpp
+++ test/SemaCXX/coroutines.cpp
@@ -52,21 +52,21 @@
   using promise_type = Promise;
 };
 
-void no_specialization() {
-  co_await a; // expected-error {{implicit instantiation of undefined template 'std::experimental::coroutine_traits<void>'}}
+void no_specialization() { // expected-error {{implicit instantiation of undefined template 'std::experimental::coroutine_traits<void>'}}
+  co_await a;
 }
 
 template <typename... T>
 struct std::experimental::coroutine_traits<int, T...> {};
 
-int no_promise_type() {
-  co_await a; // expected-error {{this function cannot be a coroutine: 'std::experimental::coroutine_traits<int>' has no member named 'promise_type'}}
+int no_promise_type() { // expected-error {{this function cannot be a coroutine: 'std::experimental::coroutine_traits<int>' has no member named 'promise_type'}}
+  co_await a;
 }
 
 template <>
 struct std::experimental::coroutine_traits<double, double> { typedef int promise_type; };
-double bad_promise_type(double) {
-  co_await a; // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits<double, double>::promise_type' (aka 'int') is not a class}}
+double bad_promise_type(double) { // expected-error {{this function cannot be a coroutine: 'experimental::coroutine_traits<double, double>::promise_type' (aka 'int') is not a class}}
+  co_await a;
 }
 
 template <>
@@ -77,7 +77,7 @@
   co_yield 0; // expected-error {{no member named 'yield_value' in 'std::experimental::coroutine_traits<double, int>::promise_type'}}
 }
 
-struct promise; // expected-note 2{{forward declaration}}
+struct promise; // expected-note 3{{forward declaration}}
 template <typename... T>
 struct std::experimental::coroutine_traits<void, T...> { using promise_type = promise; };
 
@@ -94,6 +94,12 @@
   // expected-error@-2 {{incomplete definition of type 'promise'}}
   co_await a;
 }
+template <class T>
+void undefined_promise_template(T) { // expected-error {{variable has incomplete type 'promise_type'}}
+  // FIXME: Confusing diagnostic; the 'variable' is the implicit coroutine promise.
+  co_await a;
+}
+template void undefined_promise_template(int); // expected-note {{requested here}}
 
 struct yielded_thing { const char *p; short a, b; };
 
@@ -299,6 +305,16 @@
   co_await a;
 }
 
+struct not_class_tag {};
+template <>
+struct std::experimental::coroutine_traits<void, not_class_tag> { using promise_type = int; };
+
+template <class T>
+void promise_type_not_class(T) {
+  // expected-error@-1 {{this function cannot be a coroutine: 'experimental::coroutine_traits<void, not_class_tag>::promise_type' (aka 'int') is not a class}}
+  co_await a;
+}
+template void promise_type_not_class(not_class_tag); // expected-note {{requested here}}
 
 template<> struct std::experimental::coroutine_traits<int, int, const char**>
 { using promise_type = promise; };
Index: lib/Sema/TreeTransform.h
===================================================================
--- lib/Sema/TreeTransform.h
+++ lib/Sema/TreeTransform.h
@@ -1326,6 +1326,16 @@
     return getSema().BuildCoyieldExpr(CoyieldLoc, Result);
   }
 
+  /// \brief Build a new coroutine body wrapping the transformed \p Body.
+  ///
+  /// By default, performs semantic analysis to build the new body.
+  /// Subclasses may override this routine to provide different behavior.
+  StmtResult RebuildCoroutineBodyStmt(Stmt *Body) {
+    auto *FD = dyn_cast<FunctionDecl>(getSema().CurContext);
+    assert(FD); // FIXME: Verify this can never fire (e.g. for ObjC methods).
+    return getSema().ActOnFinishCoroutineBody(FD, Body);
+  }
+
   /// \brief Build a new Objective-C \@try statement.
   ///
   /// By default, performs semantic analysis to build the new statement.
@@ -6651,8 +6661,25 @@
 template<typename Derived>
 StmtResult
 TreeTransform<Derived>::TransformCoroutineBodyStmt(CoroutineBodyStmt *S) {
-  // The coroutine body should be re-formed by the caller if necessary.
-  return getDerived().TransformStmt(S->getBody());
+  // FIXME: Don't rebuild the entire coroutine body.
+  // The coroutine body should only be re-formed by the caller if necessary.
+  FunctionScopeInfo *FS = getSema().getCurFunction();
+  assert(FS);
+  VarDecl *VD =
+      getSema().buildCoroutinePromise(S->getPromiseDecl()->getLocation());
+  if (!VD || VD->isInvalidDecl())
+    return StmtError();
+  getDerived().transformedLocalDecl(S->getPromiseDecl(), VD);
+  // FIXME: Re-setting FS->CoroutinePromise feels like a hack. Is there a better
+  // way to do this? Currently this is needed so the rebuilt body uses the
+  // transformed promise type.
+  FS->CoroutinePromise = VD;
+
+  StmtResult Body = getDerived().TransformStmt(S->getBody());
+  if (Body.isInvalid())
+    return StmtError();
+
+  return getDerived().RebuildCoroutineBodyStmt(Body.get());
 }
 
 template<typename Derived>
Index: lib/Sema/SemaDecl.cpp
===================================================================
--- lib/Sema/SemaDecl.cpp
+++ lib/Sema/SemaDecl.cpp
@@ -11626,13 +11626,27 @@
 
 Decl *Sema::ActOnFinishFunctionBody(Decl *dcl, Stmt *Body,
                                     bool IsInstantiation) {
-  FunctionDecl *FD = dcl ? dcl->getAsFunction() : nullptr;
+  if (!dcl)
+    return nullptr;
+  FunctionDecl *FD = dcl->getAsFunction();
 
   sema::AnalysisBasedWarnings::Policy WP = AnalysisWarnings.getDefaultPolicy();
   sema::AnalysisBasedWarnings::Policy *ActivePolicy = nullptr;
 
-  if (getLangOpts().CoroutinesTS && !getCurFunction()->CoroutineStmts.empty())
-    CheckCompletedCoroutineBody(FD, Body);
+  // A body that is already a CoroutineBodyStmt (as produced by
+  // TreeTransform::TransformCoroutineBodyStmt, e.g. during template
+  // instantiation) must not be wrapped in a second CoroutineBodyStmt.
+  if (getLangOpts().CoroutinesTS &&
+      !getCurFunction()->CoroutineStmts.empty() &&
+      (!Body || !isa<CoroutineBodyStmt>(Body))) {
+    // FIXME: support ObjC methods here
+    assert(FD && "Objective C methods are not supported");
+    StmtResult NewBody = ActOnFinishCoroutineBody(FD, Body);
+    if (NewBody.isInvalid())
+      FD->setInvalidDecl();
+    else
+      Body = NewBody.get();
+  }
 
   if (FD) {
     FD->setBody(Body);
Index: lib/Sema/SemaCoroutine.cpp
===================================================================
--- lib/Sema/SemaCoroutine.cpp
+++ lib/Sema/SemaCoroutine.cpp
@@ -24,18 +24,19 @@
 /// Look up the std::coroutine_traits<...>::promise_type for the given
 /// function type.
 static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType,
-                                  SourceLocation Loc) {
+                                  SourceLocation KWLoc,
+                                  SourceLocation FuncLoc) {
   // FIXME: Cache std::coroutine_traits once we've found it.
   NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
   if (!StdExp) {
-    S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
+    S.Diag(KWLoc, diag::err_implied_std_coroutine_traits_not_found);
     return QualType();
   }
 
   LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"),
-                      Loc, Sema::LookupOrdinaryName);
+                      FuncLoc, Sema::LookupOrdinaryName);
   if (!S.LookupQualifiedName(Result, StdExp)) {
-    S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
+    S.Diag(KWLoc, diag::err_implied_std_coroutine_traits_not_found);
     return QualType();
   }
 
@@ -49,36 +50,37 @@
   }
 
   // Form template argument list for coroutine_traits<R, P1, P2, ...>.
-  TemplateArgumentListInfo Args(Loc, Loc);
+  TemplateArgumentListInfo Args(FuncLoc, FuncLoc);
   Args.addArgument(TemplateArgumentLoc(
       TemplateArgument(FnType->getReturnType()),
-      S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), Loc)));
+      S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), FuncLoc)));
   // FIXME: If the function is a non-static member function, add the type
   // of the implicit object parameter before the formal parameters.
   for (QualType T : FnType->getParamTypes())
     Args.addArgument(TemplateArgumentLoc(
-        TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, Loc)));
+        TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, FuncLoc)));
 
   // Build the template-id.
   QualType CoroTrait =
-      S.CheckTemplateIdType(TemplateName(CoroTraits), Loc, Args);
+      S.CheckTemplateIdType(TemplateName(CoroTraits), FuncLoc, Args);
   if (CoroTrait.isNull())
     return QualType();
-  if (S.RequireCompleteType(Loc, CoroTrait,
+  if (S.RequireCompleteType(FuncLoc, CoroTrait,
                             diag::err_coroutine_traits_missing_specialization))
     return QualType();
 
   CXXRecordDecl *RD = CoroTrait->getAsCXXRecordDecl();
   assert(RD && "specialization of class template is not a class?");
 
   // Look up the ::promise_type member.
-  LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), Loc,
+  LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), FuncLoc,
                  Sema::LookupOrdinaryName);
   S.LookupQualifiedName(R, RD);
   auto *Promise = R.getAsSingle<TypeDecl>();
   if (!Promise) {
-    S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_found)
-      << RD;
+    S.Diag(FuncLoc,
+           diag::err_implied_std_coroutine_traits_promise_type_not_found)
+        << RD;
     return QualType();
   }
 
@@ -91,14 +93,38 @@
                                       CoroTrait.getTypePtr());
     PromiseType = S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
 
-    S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_class)
-      << PromiseType;
+    S.Diag(FuncLoc,
+           diag::err_implied_std_coroutine_traits_promise_type_not_class)
+        << PromiseType;
     return QualType();
   }
 
   return PromiseType;
 }
 
+VarDecl *Sema::buildCoroutinePromise(SourceLocation KWLoc) {
+  auto *FD = dyn_cast<FunctionDecl>(CurContext);
+  assert(FD && "Not inside a function context");
+  SourceLocation FuncLoc = FD->getLocation();
+  QualType T =
+      FD->getType()->isDependentType()
+          ? Context.DependentTy
+          : lookupPromiseType(*this, FD->getType()->castAs<FunctionProtoType>(),
+                              KWLoc, FuncLoc);
+  if (T.isNull())
+    return nullptr;
+
+  // Create the implicit '__promise' variable and default-initialize it.
+  VarDecl *VD =
+      VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(),
+                      &PP.getIdentifierTable().get("__promise"), T,
+                      Context.getTrivialTypeSourceInfo(T, FuncLoc), SC_None);
+  CheckVariableDeclarationType(VD);
+  if (!VD->isInvalidDecl())
+    ActOnUninitializedDecl(VD, false);
+  return VD;
+}
+
 /// Check that this is a context in which a coroutine suspension can appear.
 static FunctionScopeInfo *
 checkCoroutineContext(Sema &S, SourceLocation Loc, StringRef Keyword) {
@@ -138,22 +164,8 @@
 
     // If we don't have a promise variable, build one now.
     if (!ScopeInfo->CoroutinePromise) {
-      QualType T =
-          FD->getType()->isDependentType()
-              ? S.Context.DependentTy
-              : lookupPromiseType(S, FD->getType()->castAs<FunctionProtoType>(),
-                                  Loc);
-      if (T.isNull())
+      if (!(ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc)))
         return nullptr;
-
-      // Create and default-initialize the promise.
-      ScopeInfo->CoroutinePromise =
-          VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(),
-                          &S.PP.getIdentifierTable().get("__promise"), T,
-                          S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
-      S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise);
-      if (!ScopeInfo->CoroutinePromise->isInvalidDecl())
-        S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise, false);
     }
 
     return ScopeInfo;
@@ -378,7 +390,7 @@
   return Res;
 }
 
-void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
+StmtResult Sema::ActOnFinishCoroutineBody(FunctionDecl *FD, Stmt *Body) {
   FunctionScopeInfo *Fn = getCurFunction();
   assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine");
 
@@ -410,7 +422,7 @@
   StmtResult PromiseStmt =
       ActOnDeclStmt(ConvertDeclToDeclGroup(Fn->CoroutinePromise), Loc, Loc);
   if (PromiseStmt.isInvalid())
-    return FD->setInvalidDecl();
+    return StmtError();
 
   // Form and check implicit 'co_await p.initial_suspend();' statement.
   ExprResult InitialSuspend =
@@ -420,7 +432,7 @@
     InitialSuspend = BuildCoawaitExpr(Loc, InitialSuspend.get());
   InitialSuspend = ActOnFinishFullExpr(InitialSuspend.get());
   if (InitialSuspend.isInvalid())
-    return FD->setInvalidDecl();
+    return StmtError();
 
   // Form and check implicit 'co_await p.final_suspend();' statement.
   ExprResult FinalSuspend =
@@ -430,7 +442,7 @@
     FinalSuspend = BuildCoawaitExpr(Loc, FinalSuspend.get());
   FinalSuspend = ActOnFinishFullExpr(FinalSuspend.get());
   if (FinalSuspend.isInvalid())
-    return FD->setInvalidDecl();
+    return StmtError();
 
   // FIXME: Perform analysis of set_exception call.
 
@@ -442,26 +454,26 @@
   ExprResult ReturnObject =
     buildPromiseCall(*this, Fn, Loc, "get_return_object", None);
   if (ReturnObject.isInvalid())
-    return FD->setInvalidDecl();
+    return StmtError();
   QualType RetType = FD->getReturnType();
   if (!RetType->isDependentType()) {
     InitializedEntity Entity =
         InitializedEntity::InitializeResult(Loc, RetType, false);
     ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType,
                                                    ReturnObject.get());
     if (ReturnObject.isInvalid())
-      return FD->setInvalidDecl();
+      return StmtError();
   }
   ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc);
   if (ReturnObject.isInvalid())
-    return FD->setInvalidDecl();
+    return StmtError();
 
   // FIXME: Perform move-initialization of parameters into frame-local copies.
   SmallVector<Expr*, 16> ParamMoves;
 
   // Build body for the coroutine wrapper statement.
-  Body = new (Context) CoroutineBodyStmt(
+  return new (Context) CoroutineBodyStmt(
       Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(),
-      /*SetException*/nullptr, /*Fallthrough*/nullptr,
-      ReturnObject.get(), ParamMoves);
+      /*SetException*/ nullptr, /*Fallthrough*/ nullptr, ReturnObject.get(),
+      ParamMoves);
 }
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -8024,7 +8024,18 @@
   ExprResult BuildCoyieldExpr(SourceLocation KwLoc, Expr *E);
   StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E);
 
-  void CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body);
+  /// Build the implicit '__promise' variable for the coroutine whose
+  /// definition is the current context (CurContext must be a FunctionDecl).
+  /// \p KWLoc is where a missing coroutine_traits template is diagnosed.
+  /// Returns null if the promise type cannot be determined; the returned
+  /// variable may still be invalid (e.g. if its type is incomplete).
+  VarDecl *buildCoroutinePromise(SourceLocation KWLoc);
+
+  /// Wrap \p Body, the body of coroutine \p FD, in a CoroutineBodyStmt that
+  /// also holds the promise declaration and the implicit initial_suspend,
+  /// final_suspend and get_return_object calls. Must be called while \p FD's
+  /// function scope is current. Returns StmtError() if any part is invalid.
+  StmtResult ActOnFinishCoroutineBody(FunctionDecl *FD, Stmt *Body);
 
   //===--------------------------------------------------------------------===//
   // OpenMP directives and clauses.
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to