EricWF updated this revision to Diff 93508.
EricWF added a comment.

- Fix the incorrect definition of `hasDependentPromiseType()`


https://reviews.llvm.org/D31487

Files:
  include/clang/AST/StmtCXX.h
  lib/AST/StmtCXX.cpp
  lib/Sema/CoroutineBuilder.h
  lib/Sema/SemaCoroutine.cpp
  lib/Sema/TreeTransform.h
  test/SemaCXX/coroutines.cpp

Index: test/SemaCXX/coroutines.cpp
===================================================================
--- test/SemaCXX/coroutines.cpp
+++ test/SemaCXX/coroutines.cpp
@@ -534,6 +534,12 @@
   co_await a;
 }
 
+template <class T>
+coro<T> bad_implicit_return_dependent(T) { // expected-error {{'bad_promise_6' declares both 'return_value' and 'return_void'}}
+  co_await a;
+}
+template coro<bad_promise_6> bad_implicit_return_dependent(bad_promise_6); // expected-note {{in instantiation}}
+
 struct bad_promise_7 {
   coro<bad_promise_7> get_return_object();
   suspend_always initial_suspend();
@@ -544,25 +550,38 @@
   co_await a;
 }
 
+template <class T>
+coro<T> no_unhandled_exception_dependent(T) { // expected-error {{'bad_promise_7' is required to declare the member 'unhandled_exception()'}}
+  co_await a;
+}
+template coro<bad_promise_7> no_unhandled_exception_dependent(bad_promise_7); // expected-note {{in instantiation}}
+
 struct bad_promise_base {
 private:
   void return_void();
 };
 struct bad_promise_8 : bad_promise_base {
   coro<bad_promise_8> get_return_object();
   suspend_always initial_suspend();
   suspend_always final_suspend();
-  void unhandled_exception() __attribute__((unavailable)); // expected-note {{made unavailable}}
-  void unhandled_exception() const;                        // expected-note {{candidate}}
-  void unhandled_exception(void *) const;                  // expected-note {{requires 1 argument, but 0 were provided}}
+  void unhandled_exception() __attribute__((unavailable)); // expected-note 2 {{made unavailable}}
+  void unhandled_exception() const;                        // expected-note 2 {{candidate}}
+  void unhandled_exception(void *) const;                  // expected-note 2 {{requires 1 argument, but 0 were provided}}
 };
 coro<bad_promise_8> calls_unhandled_exception() {
   // expected-error@-1 {{call to unavailable member function 'unhandled_exception'}}
   // FIXME: also warn about private 'return_void' here. Even though building
   // the call to unhandled_exception has already failed.
   co_await a;
 }
 
+template <class T>
+coro<T> calls_unhandled_exception_dependent(T) {
+  // expected-error@-1 {{call to unavailable member function 'unhandled_exception'}}
+  co_await a;
+}
+template coro<bad_promise_8> calls_unhandled_exception_dependent(bad_promise_8); // expected-note {{in instantiation}}
+
 struct bad_promise_9 {
   coro<bad_promise_9> get_return_object();
   suspend_always initial_suspend();
@@ -652,3 +671,28 @@
 extern "C" int f(promise_on_alloc_failure_tag) {
   co_return; //expected-note {{function is a coroutine due to use of 'co_return' here}}
 }
+
+// A private get_return_object_on_allocation_failure must be diagnosed both in
+// non-dependent coroutines and in instantiations of dependent ones.
+struct bad_promise_11 {
+  coro<bad_promise_11> get_return_object();
+  suspend_always initial_suspend();
+  suspend_always final_suspend();
+  void unhandled_exception();
+  void return_void();
+
+private:
+  static coro<bad_promise_11> get_return_object_on_allocation_failure(); // expected-note 2 {{declared private here}}
+};
+coro<bad_promise_11> private_alloc_failure_handler() {
+  // expected-error@-1 {{'get_return_object_on_allocation_failure' is a private member of 'bad_promise_11'}}
+  co_return; // FIXME: Add a "declared coroutine here" note.
+}
+
+template <class T>
+coro<T> dependent_private_alloc_failure_handler(T) {
+  // expected-error@-1 {{'get_return_object_on_allocation_failure' is a private member of 'bad_promise_11'}}
+  co_return; // FIXME: Add a "declared coroutine here" note.
+}
+template coro<bad_promise_11> dependent_private_alloc_failure_handler(bad_promise_11);
+// expected-note@-1 {{requested here}}
Index: lib/Sema/TreeTransform.h
===================================================================
--- lib/Sema/TreeTransform.h
+++ lib/Sema/TreeTransform.h
@@ -14,6 +14,7 @@
 #ifndef LLVM_CLANG_LIB_SEMA_TREETRANSFORM_H
 #define LLVM_CLANG_LIB_SEMA_TREETRANSFORM_H
 
+#include "CoroutineBuilder.h"
 #include "TypeLocBuilder.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclObjC.h"
@@ -6849,11 +6850,10 @@
 TreeTransform<Derived>::TransformCoroutineBodyStmt(CoroutineBodyStmt *S) {
   // The coroutine body should be re-formed by the caller if necessary.
   // FIXME: The coroutine body is always rebuilt by ActOnFinishFunctionBody
-  CoroutineBodyStmt::CtorArgs BodyArgs;
 
   auto *ScopeInfo = SemaRef.getCurFunction();
   auto *FD = cast<FunctionDecl>(SemaRef.CurContext);
-  assert(ScopeInfo && !ScopeInfo->CoroutinePromise &&
+  assert(FD && ScopeInfo && !ScopeInfo->CoroutinePromise &&
          ScopeInfo->NeedsCoroutineSuspends &&
          ScopeInfo->CoroutineSuspends.first == nullptr &&
          ScopeInfo->CoroutineSuspends.second == nullptr &&
@@ -6865,17 +6865,11 @@
 
   // The new CoroutinePromise object needs to be built and put into the current
   // FunctionScopeInfo before any transformations or rebuilding occurs.
-  auto *Promise = S->getPromiseDecl();
-  auto *NewPromise = SemaRef.buildCoroutinePromise(FD->getLocation());
-  if (!NewPromise)
+  auto *Promise = SemaRef.buildCoroutinePromise(FD->getLocation());
+  if (!Promise)
     return StmtError();
-  getDerived().transformedLocalDecl(Promise, NewPromise);
-  ScopeInfo->CoroutinePromise = NewPromise;
-  StmtResult PromiseStmt = SemaRef.ActOnDeclStmt(
-          SemaRef.ConvertDeclToDeclGroup(NewPromise),
-          FD->getLocation(), FD->getLocation());
-  assert(!PromiseStmt.isInvalid());
-  BodyArgs.Promise = PromiseStmt.get();
+  getDerived().transformedLocalDecl(S->getPromiseDecl(), Promise);
+  ScopeInfo->CoroutinePromise = Promise;
 
   // Transform the implicit coroutine statements we built during the initial
   // parse.
@@ -6888,52 +6882,70 @@
     return StmtError();
   ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get());
   assert(isa<Expr>(InitSuspend.get()) && isa<Expr>(FinalSuspend.get()));
-  BodyArgs.InitialSuspend = cast<Expr>(InitSuspend.get());
-  BodyArgs.FinalSuspend = cast<Expr>(FinalSuspend.get());
 
   StmtResult BodyRes = getDerived().TransformStmt(S->getBody());
   if (BodyRes.isInvalid())
     return StmtError();
-  BodyArgs.Body = BodyRes.get();
 
-  if (S->getFallthroughHandler()) {
-    StmtResult Res = getDerived().TransformStmt(S->getFallthroughHandler());
-    if (Res.isInvalid())
-      return StmtError();
-    BodyArgs.OnFallthrough = Res.get();
-  }
+  CoroutineStmtBuilder Builder(SemaRef, *FD, *ScopeInfo, BodyRes.get());
+  if (Builder.isInvalid())
+    return StmtError();
 
-  if (S->getExceptionHandler()) {
-    StmtResult Res = getDerived().TransformStmt(S->getExceptionHandler());
-    if (Res.isInvalid())
+  Expr *ReturnObject = S->getReturnValueInit();
+  assert(ReturnObject && "the return object is expected to be valid");
+  ExprResult Res = getDerived().TransformInitializer(ReturnObject,
+                                                     /*NoCopyInit*/ false);
+  if (Res.isInvalid())
+    return StmtError();
+  Builder.ReturnValue = Res.get();
+
+  if (S->hasDependentPromiseType()) {
+    assert(!Promise->getType()->isDependentType() &&
+           "the promise type must no longer be dependent");
+    assert(!S->getFallthroughHandler() && !S->getExceptionHandler() &&
+           !S->getReturnStmtOnAllocFailure() && !S->getDeallocate() &&
+           "these nodes should not have been built yet");
+    if (!Builder.buildDependentStatements())
       return StmtError();
-    BodyArgs.OnException = Res.get();
-  }
+  } else {
+    if (S->getFallthroughHandler()) {
+      StmtResult Res = getDerived().TransformStmt(S->getFallthroughHandler());
+      if (Res.isInvalid())
+        return StmtError();
+      Builder.OnFallthrough = Res.get();
+    }
 
-  // Transform any additional statements we may have already built
-  if (S->getAllocate() && S->getDeallocate()) {
+    if (S->getExceptionHandler()) {
+      StmtResult Res = getDerived().TransformStmt(S->getExceptionHandler());
+      if (Res.isInvalid())
+        return StmtError();
+      Builder.OnException = Res.get();
+    }
+
+    if (S->getReturnStmtOnAllocFailure()) {
+      StmtResult Res =
+          getDerived().TransformStmt(S->getReturnStmtOnAllocFailure());
+      if (Res.isInvalid())
+        return StmtError();
+      Builder.ReturnStmtOnAllocFailure = Res.get();
+    }
+
+    // Transform any additional statements we may have already built
+    assert(S->getAllocate() && S->getDeallocate() &&
+           "allocation and deallocation calls must already be built");
     ExprResult AllocRes = getDerived().TransformExpr(S->getAllocate());
     if (AllocRes.isInvalid())
       return StmtError();
-    BodyArgs.Allocate = AllocRes.get();
+    Builder.Allocate = AllocRes.get();
 
     ExprResult DeallocRes = getDerived().TransformExpr(S->getDeallocate());
     if (DeallocRes.isInvalid())
       return StmtError();
-    BodyArgs.Deallocate = DeallocRes.get();
-  }
-
-  Expr *ReturnObject = S->getReturnValueInit();
-  if (ReturnObject) {
-    ExprResult Res = getDerived().TransformInitializer(ReturnObject,
-            /*NoCopyInit*/false);
-    if (Res.isInvalid())
-      return StmtError();
-    BodyArgs.ReturnValue = Res.get();
+    Builder.Deallocate = DeallocRes.get();
   }
 
   // Do a partial rebuild of the coroutine body and stash it in the ScopeInfo
-  return getDerived().RebuildCoroutineBodyStmt(BodyArgs);
+  return getDerived().RebuildCoroutineBodyStmt(Builder);
 }
 
 template<typename Derived>
Index: lib/Sema/SemaCoroutine.cpp
===================================================================
--- lib/Sema/SemaCoroutine.cpp
+++ lib/Sema/SemaCoroutine.cpp
@@ -11,13 +11,15 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "clang/Sema/SemaInternal.h"
+#include "CoroutineBuilder.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/StmtCXX.h"
 #include "clang/Lex/Preprocessor.h"
 #include "clang/Sema/Initialization.h"
 #include "clang/Sema/Overload.h"
+#include "clang/Sema/SemaInternal.h"
+
 using namespace clang;
 using namespace sema;
 
@@ -683,47 +685,6 @@
   return OperatorDelete;
 }
 
-namespace {
-class SubStmtBuilder : public CoroutineBodyStmt::CtorArgs {
-  Sema &S;
-  FunctionDecl &FD;
-  FunctionScopeInfo &Fn;
-  bool IsValid;
-  SourceLocation Loc;
-  QualType RetType;
-  SmallVector<Stmt *, 4> ParamMovesVector;
-  const bool IsPromiseDependentType;
-  CXXRecordDecl *PromiseRecordDecl = nullptr;
-
-public:
-  SubStmtBuilder(Sema &S, FunctionDecl &FD, FunctionScopeInfo &Fn, Stmt *Body)
-      : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()),
-        IsPromiseDependentType(
-            !Fn.CoroutinePromise ||
-            Fn.CoroutinePromise->getType()->isDependentType()) {
-    this->Body = Body;
-    if (!IsPromiseDependentType) {
-      PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl();
-      assert(PromiseRecordDecl && "Type should have already been checked");
-    }
-    this->IsValid = makePromiseStmt() && makeInitialAndFinalSuspend() &&
-                    makeOnException() && makeOnFallthrough() &&
-                    makeReturnOnAllocFailure() && makeNewAndDeleteExpr() &&
-                    makeReturnObject() && makeParamMoves();
-  }
-
-  bool isInvalid() const { return !this->IsValid; }
-
-  bool makePromiseStmt();
-  bool makeInitialAndFinalSuspend();
-  bool makeNewAndDeleteExpr();
-  bool makeOnFallthrough();
-  bool makeOnException();
-  bool makeReturnObject();
-  bool makeReturnOnAllocFailure();
-  bool makeParamMoves();
-};
-}
 
 void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
   FunctionScopeInfo *Fn = getCurFunction();
@@ -750,15 +711,15 @@
     Diag(Fn->FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
             << Fn->getFirstCoroutineStmtKeyword();
   }
-  SubStmtBuilder Builder(*this, *FD, *Fn, Body);
-  if (Builder.isInvalid())
+  CoroutineStmtBuilder Builder(*this, *FD, *Fn, Body);
+  if (Builder.isInvalid() || !Builder.buildStatements())
     return FD->setInvalidDecl();
 
   // Build body for the coroutine wrapper statement.
   Body = CoroutineBodyStmt::Create(Context, Builder);
 }
 
-bool SubStmtBuilder::makePromiseStmt() {
+bool CoroutineStmtBuilder::makePromiseStmt() {
   // Form a declaration statement for the promise declaration, so that AST
   // visitors can more easily find it.
   StmtResult PromiseStmt =
@@ -770,7 +731,7 @@
   return true;
 }
 
-bool SubStmtBuilder::makeInitialAndFinalSuspend() {
+bool CoroutineStmtBuilder::makeInitialAndFinalSuspend() {
   if (Fn.hasInvalidCoroutineSuspends())
     return false;
   this->InitialSuspend = cast<Expr>(Fn.CoroutineSuspends.first);
@@ -801,8 +762,9 @@
   return false;
 }
 
-bool SubStmtBuilder::makeReturnOnAllocFailure() {
-  if (!PromiseRecordDecl) return true;
+bool CoroutineStmtBuilder::makeReturnOnAllocFailure() {
+  assert(!IsPromiseDependentType &&
+         "cannot make statement while the promise type is dependent");
 
   // [dcl.fct.def.coroutine]/8
   // The unqualified-id get_return_object_on_allocation_failure is looked up in
@@ -813,41 +775,42 @@
   DeclarationName DN =
       S.PP.getIdentifierInfo("get_return_object_on_allocation_failure");
   LookupResult Found(S, DN, Loc, Sema::LookupMemberName);
-  // Suppress diagnostics when a private member is selected. The same warnings
-  // will be produced again when building the call.
-  Found.suppressDiagnostics();
-  if (!S.LookupQualifiedName(Found, PromiseRecordDecl)) return true;
+  if (!S.LookupQualifiedName(Found, PromiseRecordDecl))
+    return true;
 
   CXXScopeSpec SS;
   ExprResult DeclNameExpr =
       S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false);
-  if (DeclNameExpr.isInvalid()) return false;
+  if (DeclNameExpr.isInvalid())
+    return false;
 
   if (!diagReturnOnAllocFailure(S, DeclNameExpr.get(), PromiseRecordDecl, Fn))
     return false;
 
   ExprResult ReturnObjectOnAllocationFailure =
       S.ActOnCallExpr(nullptr, DeclNameExpr.get(), Loc, {}, Loc);
-  if (ReturnObjectOnAllocationFailure.isInvalid()) return false;
+  if (ReturnObjectOnAllocationFailure.isInvalid())
+    return false;
 
   // FIXME: ActOnReturnStmt expects a scope that is inside of the function, due
   //   to CheckJumpOutOfSEHFinally(*this, ReturnLoc, *CurScope->getFnParent());
   //   S.getCurScope()->getFnParent() == nullptr at ActOnFinishFunctionBody when
   //   CoroutineBodyStmt is built. Figure it out and fix it.
   //   Use BuildReturnStmt here to unbreak sanitized tests. (Gor:3/27/2017)
   StmtResult ReturnStmt =
       S.BuildReturnStmt(Loc, ReturnObjectOnAllocationFailure.get());
-  if (ReturnStmt.isInvalid()) return false;
+  if (ReturnStmt.isInvalid())
+    return false;
 
   this->ReturnStmtOnAllocFailure = ReturnStmt.get();
   return true;
 }
 
-bool SubStmtBuilder::makeNewAndDeleteExpr() {
+bool CoroutineStmtBuilder::makeNewAndDeleteExpr() {
   // Form and check allocation and deallocation calls.
+  assert(!IsPromiseDependentType &&
+         "cannot make statement while the promise type is dependent");
   QualType PromiseType = Fn.CoroutinePromise->getType();
-  if (PromiseType->isDependentType())
-    return true;
 
   if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type))
     return false;
@@ -920,9 +883,9 @@
   return true;
 }
 
-bool SubStmtBuilder::makeOnFallthrough() {
-  if (!PromiseRecordDecl)
-    return true;
+bool CoroutineStmtBuilder::makeOnFallthrough() {
+  assert(!IsPromiseDependentType &&
+         "cannot make statement while the promise type is dependent");
 
   // [dcl.fct.def.coroutine]/4
   // The unqualified-ids 'return_void' and 'return_value' are looked up in
@@ -951,11 +914,10 @@
   return true;
 }
 
-bool SubStmtBuilder::makeOnException() {
+bool CoroutineStmtBuilder::makeOnException() {
   // Try to form 'p.unhandled_exception();'
-
-  if (!PromiseRecordDecl)
-    return true;
+  assert(!IsPromiseDependentType &&
+         "cannot make statement while the promise type is dependent");
 
   const bool RequireUnhandledException = S.getLangOpts().CXXExceptions;
 
@@ -983,7 +945,7 @@
   return true;
 }
 
-bool SubStmtBuilder::makeReturnObject() {
+bool CoroutineStmtBuilder::makeReturnObject() {
 
   // Build implicit 'p.get_return_object()' expression and form initialization
   // of return type from it.
@@ -1008,7 +970,7 @@
   return true;
 }
 
-bool SubStmtBuilder::makeParamMoves() {
+bool CoroutineStmtBuilder::makeParamMoves() {
   // FIXME: Perform move-initialization of parameters into frame-local copies.
   return true;
 }
Index: lib/Sema/CoroutineBuilder.h
===================================================================
--- /dev/null
+++ lib/Sema/CoroutineBuilder.h
@@ -0,0 +1,106 @@
+//===----- CoroutineBuilder.h - Coroutine Semantic checking -----*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+//  This file defines CoroutineStmtBuilder, which checks and builds the
+//  implicit sub-statements of a coroutine body (CoroutineBodyStmt).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_LIB_SEMA_COROUTINEBUILDER_H
+#define LLVM_CLANG_LIB_SEMA_COROUTINEBUILDER_H
+
+#include "clang/AST/Decl.h"
+#include "clang/AST/ExprCXX.h"
+#include "clang/AST/StmtCXX.h"
+#include "clang/Lex/Preprocessor.h"
+#include "clang/Sema/SemaInternal.h"
+
+namespace clang {
+
+/// Checks and builds the implicit sub-statements of a coroutine body,
+/// collecting them in the inherited CoroutineBodyStmt::CtorArgs so they can
+/// be passed to CoroutineBodyStmt::Create.
+///
+/// The constructor builds the promise declaration statement and the initial
+/// and final suspend expressions; the remaining statements are built by
+/// buildStatements() and buildDependentStatements().
+class CoroutineStmtBuilder : public CoroutineBodyStmt::CtorArgs {
+  Sema &S;
+  FunctionDecl &FD;
+  sema::FunctionScopeInfo &Fn;
+  /// False once any sub-statement has failed to build.
+  bool IsValid = true;
+  SourceLocation Loc;
+  QualType RetType;
+  SmallVector<Stmt *, 4> ParamMovesVector;
+  /// True if there is no promise yet or the promise's type is dependent.
+  const bool IsPromiseDependentType;
+  /// The promise's class type; null while the promise type is dependent.
+  CXXRecordDecl *PromiseRecordDecl = nullptr;
+
+public:
+  /// Builds the promise statement and the initial/final suspend expressions
+  /// for the coroutine \p FD, using its scope info \p Fn and body \p Body.
+  /// Check isInvalid() before building any further statements.
+  CoroutineStmtBuilder(Sema &S, FunctionDecl &FD, sema::FunctionScopeInfo &Fn,
+                       Stmt *Body)
+      : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()),
+        IsPromiseDependentType(
+            !Fn.CoroutinePromise ||
+            Fn.CoroutinePromise->getType()->isDependentType()) {
+    this->Body = Body;
+    if (!IsPromiseDependentType) {
+      PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl();
+      assert(PromiseRecordDecl && "Type should have already been checked");
+    }
+    this->IsValid = makePromiseStmt() && makeInitialAndFinalSuspend();
+  }
+
+  /// Builds the return object and parameter moves and, if the promise type
+  /// is not dependent, the statements made by buildDependentStatements().
+  /// \returns false, leaving the builder invalid, if any of them fails.
+  bool buildStatements() {
+    assert(this->IsValid && "coroutine already invalid");
+    this->IsValid = makeReturnObject() && makeParamMoves();
+    if (this->IsValid && !IsPromiseDependentType)
+      buildDependentStatements();
+    return this->IsValid;
+  }
+
+  /// Builds the statements that need a non-dependent promise type: the
+  /// unhandled-exception and fallthrough handlers, the optional return
+  /// statement for allocation failure, and the allocation/deallocation calls.
+  /// Also called directly when instantiating a coroutine whose promise type
+  /// was dependent. \returns false, leaving the builder invalid, on failure.
+  bool buildDependentStatements() {
+    assert(this->IsValid && "coroutine already invalid");
+    assert(!this->IsPromiseDependentType &&
+           "coroutine cannot have a dependent promise type");
+    this->IsValid = makeOnException() && makeOnFallthrough() &&
+                    makeReturnOnAllocFailure() && makeNewAndDeleteExpr();
+    return this->IsValid;
+  }
+
+  /// \returns true if building any sub-statement has failed.
+  bool isInvalid() const { return !this->IsValid; }
+
+private:
+  bool makePromiseStmt();
+  bool makeInitialAndFinalSuspend();
+  bool makeNewAndDeleteExpr();
+  bool makeOnFallthrough();
+  bool makeOnException();
+  bool makeReturnObject();
+  bool makeReturnOnAllocFailure();
+  bool makeParamMoves();
+};
+
+} // end namespace clang
+
+#endif // LLVM_CLANG_LIB_SEMA_COROUTINEBUILDER_H
Index: lib/AST/StmtCXX.cpp
===================================================================
--- lib/AST/StmtCXX.cpp
+++ lib/AST/StmtCXX.cpp
@@ -112,4 +112,4 @@
       Args.ReturnStmtOnAllocFailure;
   std::copy(Args.ParamMoves.begin(), Args.ParamMoves.end(),
             const_cast<Stmt **>(getParamMoves().data()));
-}
\ No newline at end of file
+}
Index: include/clang/AST/StmtCXX.h
===================================================================
--- include/clang/AST/StmtCXX.h
+++ include/clang/AST/StmtCXX.h
@@ -344,6 +344,11 @@
 public:
   static CoroutineBodyStmt *Create(const ASTContext &C, CtorArgs const &Args);
 
+  /// \brief Returns true if the promise declaration's type is dependent.
+  bool hasDependentPromiseType() const {
+    return getPromiseDecl()->getType()->isDependentType();
+  }
+
   /// \brief Retrieve the body of the coroutine as written. This will be either
   /// a CompoundStmt or a TryStmt.
   Stmt *getBody() const {
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to