https://github.com/yuxuanchen1997 created https://github.com/llvm/llvm-project/pull/108474
None >From a4736c1effa479692157dbe735fa873b233f98bd Mon Sep 17 00:00:00 2001 From: Yuxuan Chen <y...@meta.com> Date: Thu, 12 Sep 2024 17:13:57 -0700 Subject: [PATCH] [Clang] Propagate elide safe context through [[clang::coro_must_await]] --- clang/include/clang/Basic/Attr.td | 8 ++++ clang/include/clang/Basic/AttrDocs.td | 12 +++++ clang/lib/Sema/SemaCoroutine.cpp | 48 ++++++++++++++----- .../CodeGenCoroutines/coro-await-elidable.cpp | 19 ++++++++ 4 files changed, 76 insertions(+), 11 deletions(-) diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index 9a7b163b2c6da8..cd29f08a3320e7 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -1258,6 +1258,14 @@ def CoroAwaitElidable : InheritableAttr { let SimpleHandler = 1; } +def CoroMustAwait : InheritableAttr { + let Spellings = [Clang<"coro_must_await">]; + let Subjects = SubjectList<[ParmVar]>; + let LangOpts = [CPlusPlus]; + let Documentation = [CoroMustAwaitDoc]; + let SimpleHandler = 1; +} + // OSObject-based attributes. def OSConsumed : InheritableParamAttr { let Spellings = [Clang<"os_consumed">]; diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td index 9f72456d2da678..2403cb4d2f994a 100644 --- a/clang/include/clang/Basic/AttrDocs.td +++ b/clang/include/clang/Basic/AttrDocs.td @@ -8301,6 +8301,18 @@ callee coroutine. }]; } +def CoroMustAwaitDoc : Documentation { + let Category = DocCatDecl; + let Content = [{ +A direct call expression that returns a prvalue of a type attributed +``[[clang::coro_await_elidable]]`` is said to be under a SafeElide context +if one of the following is true: +- it is the operand of a ``co_await`` expression inside a coroutine whose return type is also attributed ``[[clang::coro_await_elidable]]``. +- it is an argument to a ``[[clang::coro_must_await]]`` parameter or +  parameter pack of another direct call expression under a SafeElide context. 
+}]; +} + def CountedByDocs : Documentation { let Category = DocCatField; let Content = [{ diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp index a574d56646f3a2..17ffd24a3eb3e4 100644 --- a/clang/lib/Sema/SemaCoroutine.cpp +++ b/clang/lib/Sema/SemaCoroutine.cpp @@ -849,12 +849,40 @@ static bool isAttributedCoroAwaitElidable(const QualType &QT) { return Record && Record->hasAttr<CoroAwaitElidableAttr>(); } -static bool isCoroAwaitElidableCall(Expr *Operand) { +static void applyAwaitElidableContext(Expr *Operand) { if (!Operand->isPRValue()) { - return false; + return; } - return isAttributedCoroAwaitElidable(Operand->getType()); + auto *Call = dyn_cast<CallExpr>(Operand->IgnoreImplicit()); + if (!Call) + return; + + if (!isAttributedCoroAwaitElidable(Call->getType())) + return; + + Call->setCoroElideSafe(); + + // Check parameter + auto *Fn = llvm::dyn_cast_if_present<FunctionDecl>(Call->getCalleeDecl()); + if (!Fn) + return; + + size_t ParmIdx = 0; + for (ParmVarDecl *PD : Fn->parameters()) { + if (PD->hasAttr<CoroMustAwaitAttr>()) { + if (PD->isParameterPack()) { + size_t NumArgs = Call->getNumArgs(); + for (size_t ArgIdx = ParmIdx; ArgIdx < NumArgs; ArgIdx++) { + applyAwaitElidableContext(Call->getArg(ArgIdx)); + } + break; + } else { + applyAwaitElidableContext(Call->getArg(ParmIdx)); + } + } + ParmIdx++; + } } // Attempts to resolve and build a CoawaitExpr from "raw" inputs, bailing out to @@ -880,14 +908,12 @@ ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand, } auto *RD = Promise->getType()->getAsCXXRecordDecl(); - bool AwaitElidable = - isCoroAwaitElidableCall(Operand) && - isAttributedCoroAwaitElidable( - getCurFunctionDecl(/*AllowLambda=*/true)->getReturnType()); - - if (AwaitElidable) - if (auto *Call = dyn_cast<CallExpr>(Operand->IgnoreImplicit())) - Call->setCoroElideSafe(); + + bool CurFnAwaitElidable = isAttributedCoroAwaitElidable( + 
getCurFunctionDecl(/*AllowLambda=*/true)->getReturnType()); + + if (CurFnAwaitElidable) + applyAwaitElidableContext(Operand); Expr *Transformed = Operand; if (lookupMember(*this, "await_transform", RD, Loc)) { diff --git a/clang/test/CodeGenCoroutines/coro-await-elidable.cpp b/clang/test/CodeGenCoroutines/coro-await-elidable.cpp index 8512995dfad45a..9e28f0351875cb 100644 --- a/clang/test/CodeGenCoroutines/coro-await-elidable.cpp +++ b/clang/test/CodeGenCoroutines/coro-await-elidable.cpp @@ -84,4 +84,23 @@ Task<int> nonelidable() { co_return 1; } +// CHECK-LABEL: define{{.*}} @_Z8addTasks4TaskIiES0_{{.*}} { +Task<int> addTasks([[clang::coro_must_await]] Task<int> t1, Task<int> t2) { + int i1 = co_await t1; + int i2 = co_await t2; + co_return i1 + i2; +} + +// CHECK-LABEL: define{{.*}} @_Z10returnSamei{{.*}} { +Task<int> returnSame(int i) { + co_return i; +} + +// CHECK-LABEL: define{{.*}} @_Z21elidableWithMustAwaitv{{.*}} { +Task<int> elidableWithMustAwait() { + // CHECK: call void @_Z10returnSamei(ptr {{.*}}, i32 noundef 2) #[[ELIDE_SAFE]] + // CHECK-NOT: call void @_Z10returnSamei(ptr {{.*}}, i32 noundef 3) #[[ELIDE_SAFE]] + co_return co_await addTasks(returnSame(2), returnSame(3)); +} + // CHECK: attributes #[[ELIDE_SAFE]] = { coro_elide_safe } _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits