https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/85684
>From 08de54f02038795924a6e5fdbcf51a496fcedf56 Mon Sep 17 00:00:00 2001 From: Yuxuan Chen <y...@meta.com> Date: Mon, 18 Mar 2024 10:45:20 -0700 Subject: [PATCH] Check if Coroutine await_suspend type returns the right type --- .../clang/Basic/DiagnosticSemaKinds.td | 2 +- clang/include/clang/Sema/Sema.h | 2 + clang/lib/Sema/SemaCoroutine.cpp | 75 +++++++++++------ clang/lib/Sema/SemaExprCXX.cpp | 84 +++++++++---------- clang/test/SemaCXX/coroutines.cpp | 28 +++++-- 5 files changed, 116 insertions(+), 75 deletions(-) diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index 8e97902564af08..f99170445c76b6 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -11701,7 +11701,7 @@ def err_coroutine_promise_new_requires_nothrow : Error< def note_coroutine_promise_call_implicitly_required : Note< "call to %0 implicitly required by coroutine function here">; def err_await_suspend_invalid_return_type : Error< - "return type of 'await_suspend' is required to be 'void' or 'bool' (have %0)" + "return type of 'await_suspend' is required to be 'void' or 'bool' or convertible to 'std::coroutine_handle<>' (have %0)" >; def note_await_ready_no_bool_conversion : Note< "return type of 'await_ready' is required to be contextually convertible to 'bool'" diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index 95ea5ebc7f1ac1..4976ff96b03d5b 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -7011,6 +7011,8 @@ class Sema final { ExprResult BuildTypeTrait(TypeTrait Kind, SourceLocation KWLoc, ArrayRef<TypeSourceInfo *> Args, SourceLocation RParenLoc); + bool EvaluateBinaryTypeTrait(TypeTrait BTT, QualType LhsT, QualType RhsT, + SourceLocation KeyLoc); /// ActOnArrayTypeTrait - Parsed one of the binary type trait support /// pseudo-functions. 
diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp index 736632857efc36..fbe230737404fa 100644 --- a/clang/lib/Sema/SemaCoroutine.cpp +++ b/clang/lib/Sema/SemaCoroutine.cpp @@ -331,16 +331,12 @@ static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc, // coroutine. static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E, SourceLocation Loc) { - if (RetType->isReferenceType()) - return nullptr; + assert(!RetType->isReferenceType() && + "Should have diagnosed reference types."); Type const *T = RetType.getTypePtr(); if (!T->isClassType() && !T->isStructureType()) return nullptr; - // FIXME: Add convertability check to coroutine_handle<>. Possibly via - // EvaluateBinaryTypeTrait(BTT_IsConvertible, ...) which is at the moment - // a private function in SemaExprCXX.cpp - ExprResult AddressExpr = buildMemberCall(S, E, Loc, "address", std::nullopt); if (AddressExpr.isInvalid()) return nullptr; @@ -358,6 +354,14 @@ static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E, return S.MaybeCreateExprWithCleanups(JustAddress); } +static bool isConvertibleToCoroutineHandle(Sema &S, QualType Ty, + SourceLocation Loc) { + // Bail out on a null handle type; EvaluateBinaryTypeTrait dereferences it. + QualType ErasedTy = lookupCoroutineHandleType(S, S.Context.VoidTy, Loc); + return !ErasedTy.isNull() && + S.EvaluateBinaryTypeTrait(BTT_IsConvertible, Ty, ErasedTy, Loc); +} + /// Build calls to await_ready, await_suspend, and await_resume for a co_await /// expression.
/// The generated AST tries to clean up temporary objects as early as @@ -418,39 +422,60 @@ static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise, return Calls; } Expr *CoroHandle = CoroHandleRes.get(); - CallExpr *AwaitSuspend = cast_or_null<CallExpr>( - BuildSubExpr(ACT::ACT_Suspend, "await_suspend", CoroHandle)); + auto *AwaitSuspend = [&]() -> CallExpr * { + auto *SubExpr = BuildSubExpr(ACT::ACT_Suspend, "await_suspend", CoroHandle); + if (!SubExpr) + return nullptr; + if (auto *E = dyn_cast<CXXBindTemporaryExpr>(SubExpr)) { + // This happens when await_suspend return type is not trivially + // destructible. This doesn't happen for the permitted return types of + // such function. Diagnose it later. + return cast_or_null<CallExpr>(E->getSubExpr()); + } else { + return cast_or_null<CallExpr>(SubExpr); + } + }(); + if (!AwaitSuspend) return Calls; + if (!AwaitSuspend->getType()->isDependentType()) { + auto InvalidAwaitSuspendReturnType = [&](QualType RetType) { + // non-class prvalues always have cv-unqualified types + S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(), + diag::err_await_suspend_invalid_return_type) + << RetType; + S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) + << AwaitSuspend->getDirectCallee(); + Calls.IsInvalid = true; + }; + // [expr.await]p3 [...] // - await-suspend is the expression e.await_suspend(h), which shall be // a prvalue of type void, bool, or std::coroutine_handle<Z> for some // type Z. QualType RetType = AwaitSuspend->getCallReturnType(S.Context); - // Support for coroutine_handle returning await_suspend. 
- if (Expr *TailCallSuspend = - maybeTailCall(S, RetType, AwaitSuspend, Loc)) + if (RetType->isReferenceType()) { + InvalidAwaitSuspendReturnType(RetType); + } else if (RetType->isBooleanType() || RetType->isVoidType()) { + Calls.Results[ACT::ACT_Suspend] = + S.MaybeCreateExprWithCleanups(AwaitSuspend); + } else if (isConvertibleToCoroutineHandle(S, RetType, Loc)) { + // Support for coroutine_handle returning await_suspend. + // // Note that we don't wrap the expression with ExprWithCleanups here // because that might interfere with tailcall contract (e.g. inserting // clean up instructions in-between tailcall and return). Instead // ExprWithCleanups is wrapped within maybeTailCall() prior to the resume // call. - Calls.Results[ACT::ACT_Suspend] = TailCallSuspend; - else { - // non-class prvalues always have cv-unqualified types - if (RetType->isReferenceType() || - (!RetType->isBooleanType() && !RetType->isVoidType())) { - S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(), - diag::err_await_suspend_invalid_return_type) - << RetType; - S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) - << AwaitSuspend->getDirectCallee(); - Calls.IsInvalid = true; - } else - Calls.Results[ACT::ACT_Suspend] = - S.MaybeCreateExprWithCleanups(AwaitSuspend); + Expr *TailCallSuspend = maybeTailCall(S, RetType, AwaitSuspend, Loc); + if (TailCallSuspend) + Calls.Results[ACT::ACT_Suspend] = TailCallSuspend; + else + InvalidAwaitSuspendReturnType(RetType); + } else { + InvalidAwaitSuspendReturnType(RetType); } } diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp index c34a40fa7c81ac..db04e59a91332d 100644 --- a/clang/lib/Sema/SemaExprCXX.cpp +++ b/clang/lib/Sema/SemaExprCXX.cpp @@ -5559,9 +5559,6 @@ static bool EvaluateUnaryTypeTrait(Sema &Self, TypeTrait UTT, } } -static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT, - QualType RhsT, SourceLocation KeyLoc); - static bool EvaluateBooleanTypeTrait(Sema &S, TypeTrait Kind, 
SourceLocation KWLoc, ArrayRef<TypeSourceInfo *> Args, @@ -5576,8 +5573,8 @@ static bool EvaluateBooleanTypeTrait(Sema &S, TypeTrait Kind, // Evaluate ReferenceBindsToTemporary and ReferenceConstructsFromTemporary // alongside the IsConstructible traits to avoid duplication. if (Kind <= BTT_Last && Kind != BTT_ReferenceBindsToTemporary && Kind != BTT_ReferenceConstructsFromTemporary) - return EvaluateBinaryTypeTrait(S, Kind, Args[0]->getType(), - Args[1]->getType(), RParenLoc); + return S.EvaluateBinaryTypeTrait(Kind, Args[0]->getType(), + Args[1]->getType(), RParenLoc); switch (Kind) { case clang::BTT_ReferenceBindsToTemporary: @@ -5674,7 +5671,8 @@ static bool EvaluateBooleanTypeTrait(Sema &S, TypeTrait Kind, QualType TPtr = S.Context.getPointerType(S.BuiltinRemoveReference(T, UnaryTransformType::RemoveCVRef, {})); QualType UPtr = S.Context.getPointerType(S.BuiltinRemoveReference(U, UnaryTransformType::RemoveCVRef, {})); - return EvaluateBinaryTypeTrait(S, TypeTrait::BTT_IsConvertibleTo, UPtr, TPtr, RParenLoc); + return S.EvaluateBinaryTypeTrait(TypeTrait::BTT_IsConvertibleTo, UPtr, + TPtr, RParenLoc); } if (Kind == clang::TT_IsNothrowConstructible) @@ -5807,8 +5805,8 @@ ExprResult Sema::ActOnTypeTrait(TypeTrait Kind, SourceLocation KWLoc, return BuildTypeTrait(Kind, KWLoc, ConvertedArgs, RParenLoc); } -static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT, - QualType RhsT, SourceLocation KeyLoc) { +bool Sema::EvaluateBinaryTypeTrait(TypeTrait BTT, QualType LhsT, QualType RhsT, + SourceLocation KeyLoc) { assert(!LhsT->isDependentType() && !RhsT->isDependentType() && "Cannot evaluate traits of dependent types"); @@ -5832,15 +5830,15 @@ static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT, if (!BaseInterface || !DerivedInterface) return false; - if (Self.RequireCompleteType( + if (RequireCompleteType( KeyLoc, RhsT, diag::err_incomplete_type_used_in_type_trait_expr)) return false; return 
BaseInterface->isSuperClassOf(DerivedInterface); } - assert(Self.Context.hasSameUnqualifiedType(LhsT, RhsT) - == (lhsRecord == rhsRecord)); + assert(Context.hasSameUnqualifiedType(LhsT, RhsT) == + (lhsRecord == rhsRecord)); // Unions are never base classes, and never have base classes. // It doesn't matter if they are complete or not. See PR#41843 @@ -5856,21 +5854,21 @@ static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT, // If Base and Derived are class types and are different types // (ignoring possible cv-qualifiers) then Derived shall be a // complete type. - if (Self.RequireCompleteType(KeyLoc, RhsT, - diag::err_incomplete_type_used_in_type_trait_expr)) + if (RequireCompleteType(KeyLoc, RhsT, + diag::err_incomplete_type_used_in_type_trait_expr)) return false; return cast<CXXRecordDecl>(rhsRecord->getDecl()) ->isDerivedFrom(cast<CXXRecordDecl>(lhsRecord->getDecl())); } case BTT_IsSame: - return Self.Context.hasSameType(LhsT, RhsT); + return Context.hasSameType(LhsT, RhsT); case BTT_TypeCompatible: { // GCC ignores cv-qualifiers on arrays for this builtin. Qualifiers LhsQuals, RhsQuals; - QualType Lhs = Self.getASTContext().getUnqualifiedArrayType(LhsT, LhsQuals); - QualType Rhs = Self.getASTContext().getUnqualifiedArrayType(RhsT, RhsQuals); - return Self.Context.typesAreCompatible(Lhs, Rhs); + QualType Lhs = getASTContext().getUnqualifiedArrayType(LhsT, LhsQuals); + QualType Rhs = getASTContext().getUnqualifiedArrayType(RhsT, RhsQuals); + return Context.typesAreCompatible(Lhs, Rhs); } case BTT_IsConvertible: case BTT_IsConvertibleTo: @@ -5909,16 +5907,16 @@ static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT, return LhsT->isVoidType(); // A function definition requires a complete, non-abstract return type. - if (!Self.isCompleteType(KeyLoc, RhsT) || Self.isAbstractType(KeyLoc, RhsT)) + if (!isCompleteType(KeyLoc, RhsT) || isAbstractType(KeyLoc, RhsT)) return false; // Compute the result of add_rvalue_reference. 
if (LhsT->isObjectType() || LhsT->isFunctionType()) - LhsT = Self.Context.getRValueReferenceType(LhsT); + LhsT = Context.getRValueReferenceType(LhsT); // Build a fake source and destination for initialization. InitializedEntity To(InitializedEntity::InitializeTemporary(RhsT)); - OpaqueValueExpr From(KeyLoc, LhsT.getNonLValueExprType(Self.Context), + OpaqueValueExpr From(KeyLoc, LhsT.getNonLValueExprType(Context), Expr::getValueKindForType(LhsT)); Expr *FromPtr = &From; InitializationKind Kind(InitializationKind::CreateCopy(KeyLoc, @@ -5927,21 +5925,21 @@ static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT, // Perform the initialization in an unevaluated context within a SFINAE // trap at translation unit scope. EnterExpressionEvaluationContext Unevaluated( - Self, Sema::ExpressionEvaluationContext::Unevaluated); - Sema::SFINAETrap SFINAE(Self, /*AccessCheckingSFINAE=*/true); - Sema::ContextRAII TUContext(Self, Self.Context.getTranslationUnitDecl()); - InitializationSequence Init(Self, To, Kind, FromPtr); + *this, Sema::ExpressionEvaluationContext::Unevaluated); + Sema::SFINAETrap SFINAE(*this, /*AccessCheckingSFINAE=*/true); + Sema::ContextRAII TUContext(*this, Context.getTranslationUnitDecl()); + InitializationSequence Init(*this, To, Kind, FromPtr); if (Init.Failed()) return false; - ExprResult Result = Init.Perform(Self, To, Kind, FromPtr); + ExprResult Result = Init.Perform(*this, To, Kind, FromPtr); if (Result.isInvalid() || SFINAE.hasErrorOccurred()) return false; if (BTT != BTT_IsNothrowConvertible) return true; - return Self.canThrow(Result.get()) == CT_Cannot; + return canThrow(Result.get()) == CT_Cannot; } case BTT_IsAssignable: @@ -5959,12 +5957,12 @@ static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT, // For both, T and U shall be complete types, (possibly cv-qualified) // void, or arrays of unknown bound. 
if (!LhsT->isVoidType() && !LhsT->isIncompleteArrayType() && - Self.RequireCompleteType(KeyLoc, LhsT, - diag::err_incomplete_type_used_in_type_trait_expr)) + RequireCompleteType(KeyLoc, LhsT, + diag::err_incomplete_type_used_in_type_trait_expr)) return false; if (!RhsT->isVoidType() && !RhsT->isIncompleteArrayType() && - Self.RequireCompleteType(KeyLoc, RhsT, - diag::err_incomplete_type_used_in_type_trait_expr)) + RequireCompleteType(KeyLoc, RhsT, + diag::err_incomplete_type_used_in_type_trait_expr)) return false; // cv void is never assignable. @@ -5974,27 +5972,27 @@ static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT, // Build expressions that emulate the effect of declval<T>() and // declval<U>(). if (LhsT->isObjectType() || LhsT->isFunctionType()) - LhsT = Self.Context.getRValueReferenceType(LhsT); + LhsT = Context.getRValueReferenceType(LhsT); if (RhsT->isObjectType() || RhsT->isFunctionType()) - RhsT = Self.Context.getRValueReferenceType(RhsT); - OpaqueValueExpr Lhs(KeyLoc, LhsT.getNonLValueExprType(Self.Context), + RhsT = Context.getRValueReferenceType(RhsT); + OpaqueValueExpr Lhs(KeyLoc, LhsT.getNonLValueExprType(Context), Expr::getValueKindForType(LhsT)); - OpaqueValueExpr Rhs(KeyLoc, RhsT.getNonLValueExprType(Self.Context), + OpaqueValueExpr Rhs(KeyLoc, RhsT.getNonLValueExprType(Context), Expr::getValueKindForType(RhsT)); // Attempt the assignment in an unevaluated context within a SFINAE // trap at translation unit scope. 
EnterExpressionEvaluationContext Unevaluated( - Self, Sema::ExpressionEvaluationContext::Unevaluated); - Sema::SFINAETrap SFINAE(Self, /*AccessCheckingSFINAE=*/true); - Sema::ContextRAII TUContext(Self, Self.Context.getTranslationUnitDecl()); - ExprResult Result = Self.BuildBinOp(/*S=*/nullptr, KeyLoc, BO_Assign, &Lhs, - &Rhs); + *this, Sema::ExpressionEvaluationContext::Unevaluated); + Sema::SFINAETrap SFINAE(*this, /*AccessCheckingSFINAE=*/true); + Sema::ContextRAII TUContext(*this, Context.getTranslationUnitDecl()); + ExprResult Result = + BuildBinOp(/*S=*/nullptr, KeyLoc, BO_Assign, &Lhs, &Rhs); if (Result.isInvalid()) return false; // Treat the assignment as unused for the purpose of -Wdeprecated-volatile. - Self.CheckUnusedVolatileAssignment(Result.get()); + CheckUnusedVolatileAssignment(Result.get()); if (SFINAE.hasErrorOccurred()) return false; @@ -6003,7 +6001,7 @@ static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT, return true; if (BTT == BTT_IsNothrowAssignable) - return Self.canThrow(Result.get()) == CT_Cannot; + return canThrow(Result.get()) == CT_Cannot; if (BTT == BTT_IsTriviallyAssignable) { // Under Objective-C ARC and Weak, if the destination has non-trivial @@ -6011,14 +6009,14 @@ static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT, if (LhsT.getNonReferenceType().hasNonTrivialObjCLifetime()) return false; - return !Result.get()->hasNonTrivialCall(Self.Context); + return !Result.get()->hasNonTrivialCall(Context); } llvm_unreachable("unhandled type trait"); return false; } case BTT_IsLayoutCompatible: { - return Self.IsLayoutCompatible(LhsT, RhsT); + return IsLayoutCompatible(LhsT, RhsT); } default: llvm_unreachable("not a BTT"); } diff --git a/clang/test/SemaCXX/coroutines.cpp b/clang/test/SemaCXX/coroutines.cpp index 2292932583fff6..cdd9be4c201d3f 100644 --- a/clang/test/SemaCXX/coroutines.cpp +++ b/clang/test/SemaCXX/coroutines.cpp @@ -1005,12 +1005,24 @@ coro<promise_no_return_func> 
no_return_value_or_return_void_3() { co_return 43; // expected-error {{no member named 'return_value'}} } -struct bad_await_suspend_return { +struct non_trivial_destruction_type { + ~non_trivial_destruction_type(); +}; + +struct bad_await_suspend_return_1 { bool await_ready(); - // expected-error@+1 {{return type of 'await_suspend' is required to be 'void' or 'bool' (have 'char')}} + // expected-error@+1 {{return type of 'await_suspend' is required to be 'void' or 'bool' or convertible to 'std::coroutine_handle<>' (have 'char')}} char await_suspend(std::coroutine_handle<>); void await_resume(); }; + +struct bad_await_suspend_return_2 { + bool await_ready(); + // expected-error@+1 {{return type of 'await_suspend' is required to be 'void' or 'bool' or convertible to 'std::coroutine_handle<>' (have 'non_trivial_destruction_type')}} + non_trivial_destruction_type await_suspend(std::coroutine_handle<>); + void await_resume(); +}; + struct bad_await_ready_return { // expected-note@+1 {{return type of 'await_ready' is required to be contextually convertible to 'bool'}} void await_ready(); @@ -1028,8 +1040,8 @@ struct await_ready_explicit_bool { template <class SuspendTy> struct await_suspend_type_test { bool await_ready(); - // expected-error@+2 {{return type of 'await_suspend' is required to be 'void' or 'bool' (have 'bool &')}} - // expected-error@+1 {{return type of 'await_suspend' is required to be 'void' or 'bool' (have 'bool &&')}} + // expected-error@+2 {{return type of 'await_suspend' is required to be 'void' or 'bool' or convertible to 'std::coroutine_handle<>' (have 'bool &')}} + // expected-error@+1 {{return type of 'await_suspend' is required to be 'void' or 'bool' or convertible to 'std::coroutine_handle<>' (have 'bool &&')}} SuspendTy await_suspend(std::coroutine_handle<>); // cxx20_23-warning@-1 {{volatile-qualified return type 'const volatile bool' is deprecated}} void await_resume(); @@ -1042,8 +1054,12 @@ void test_bad_suspend() { co_await a; // 
expected-note {{call to 'await_ready' implicitly required by coroutine function here}} } { - bad_await_suspend_return b; - co_await b; // expected-note {{call to 'await_suspend' implicitly required by coroutine function here}} + bad_await_suspend_return_1 b1; + co_await b1; // expected-note {{call to 'await_suspend' implicitly required by coroutine function here}} + } + { + bad_await_suspend_return_2 b2; + co_await b2; // expected-note {{call to 'await_suspend' implicitly required by coroutine function here}} } { await_ready_explicit_bool c; _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits