https://github.com/hanickadot created https://github.com/llvm/llvm-project/pull/138477
This change makes `[[clang::musttail]]` work in the constant evaluator. Function calls marked with this attribute won't use the system stack; instead, evaluation loops in the nearest enclosing (non-tail) function call. The attribute is already very strict, and its checks reject all problematic cases (non-trivial destructors, references to local variables). This PR is a work in progress. From f084366a545f5e2c0ec54fa7cc4dd688950c13af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hana=20Dusi=CC=81kova=CC=81?= <hani...@hanicka.net> Date: Sun, 4 May 2025 23:27:09 +0200 Subject: [PATCH] [clang] Attribute support [[clang::musttail]] in ExprConstant.cpp allows guaranteed tail recursion. --- clang/lib/AST/ExprConstant.cpp | 260 +++++++++++++++++++++++++++------ 1 file changed, 219 insertions(+), 41 deletions(-) diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp index b79d8c197fe7d..9ef6b983d196a 100644 --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -735,6 +735,13 @@ namespace { ScopeKind Scope) : Value(Val, Scope), Base(Base), T(T) {} + Cleanup(Cleanup &&Other) noexcept + : Value{Other.Value}, Base{Other.Base}, T{Other.T} { + Other.Value = {}; + } + + Cleanup &operator=(Cleanup &&) = default; + /// Determine whether this cleanup should be performed at the end of the /// given kind of scope. bool isDestroyedAtEndOf(ScopeKind K) const { @@ -1006,6 +1013,24 @@ namespace { EM_IgnoreSideEffects, } EvalMode; + /// Pointer to last tail recursion enabled return.
Enforced with + /// [[clang::musttail]] + const ReturnStmt *TailRecursionReturnStmt = nullptr; + + struct DeferRecursionFunctionCall { + const CallExpr *E{nullptr}; + const FunctionDecl *Definition{nullptr}; + bool HasThis{false}; + APValue ThisVal{}; // can't use LValue here :( + llvm::ArrayRef<const clang::Expr *> Args{}; + CallRef Call{}; + Stmt *Body{nullptr}; + SmallVector<QualType, 4> CovariantAdjustmentPath{}; + SmallVector<Cleanup, 16> ArgumentsStored{}; + }; + + DeferRecursionFunctionCall DeferFunctionCall{}; + /// Are we checking whether the expression is a potential constant /// expression? bool checkingPotentialConstantExpression() const override { @@ -1124,6 +1149,21 @@ namespace { return Result; } + void EnableTailRecursion(const ReturnStmt *ret) { + TailRecursionReturnStmt = ret; + } + + void DisableTailRecursion() { TailRecursionReturnStmt = nullptr; } + + bool TailRecursionReady() const { return DeferFunctionCall.E != nullptr; } + + bool IsTailRecursion(const ReturnStmt *ret) { + if (TailRecursionReturnStmt != ret) + return false; + TailRecursionReturnStmt = nullptr; + return true; + } + /// Get the allocated storage for the given parameter of the given call. APValue *getParamSlot(CallRef Call, const ParmVarDecl *PVD) { CallStackFrame *Frame = getCallFrameAndDepth(Call.CallIndex).first; @@ -1439,6 +1479,12 @@ namespace { // instances of this class. Info.CurrentCall->popTempVersion(); } + + friend void transferFromCallScope(ScopeRAII &, + llvm::SmallVectorImpl<Cleanup> &); + friend bool transferIntoCallScope(ScopeRAII &, + llvm::SmallVectorImpl<Cleanup> &); + private: static bool cleanup(EvalInfo &Info, bool RunDestructors, unsigned OldStackSize) { @@ -1457,6 +1503,10 @@ namespace { } } + compact(Info, OldStackSize); + return Success; + } + static void compact(EvalInfo &Info, unsigned OldStackSize) { // Compact any retained cleanups. 
auto NewEnd = Info.CleanupStack.begin() + OldStackSize; if (Kind != ScopeKind::Block) @@ -1465,12 +1515,47 @@ namespace { return C.isDestroyedAtEndOf(Kind); }); Info.CleanupStack.erase(NewEnd, Info.CleanupStack.end()); - return Success; } }; typedef ScopeRAII<ScopeKind::Block> BlockScopeRAII; typedef ScopeRAII<ScopeKind::FullExpression> FullExpressionRAII; typedef ScopeRAII<ScopeKind::Call> CallScopeRAII; + + static void transferFromCallScope(CallScopeRAII &Scope, + llvm::SmallVectorImpl<Cleanup> &Backup) { + Backup.clear(); + + auto CurrentVariables = MutableArrayRef<Cleanup>(Scope.Info.CleanupStack) + .slice(Scope.OldStackSize); + + // Transfer of cleanup informations of tail call outside of current scope. + // These variables are going to be destroyed in current scope, which only + // prepares the tail call, but is not doing it. + Backup.clear(); + + for (Cleanup &Lifetime : CurrentVariables) { + Backup.push_back(std::move(Lifetime)); + } + + // Remove lifetime management from this scope. + Scope.compact(Scope.Info, Scope.OldStackSize); + Scope.Info.CleanupStack.truncate( + Scope.OldStackSize); // make sure this is ok + assert(Scope.Info.CleanupStack.size() == Scope.OldStackSize); + } + + static bool transferIntoCallScope(CallScopeRAII &Scope, + llvm::SmallVectorImpl<Cleanup> &Backup) { + if (!Scope.cleanup(Scope.Info, true, Scope.OldStackSize)) + return false; + + for (auto &Lifetime : Backup) { + Scope.Info.CleanupStack.push_back(std::move(Lifetime)); + } + + Backup.clear(); + return true; + } } bool SubobjectDesignator::checkSubobject(EvalInfo &Info, const Expr *E, @@ -5614,10 +5699,14 @@ static EvalStmtResult EvaluateStmt(StmtResult &Result, EvalInfo &Info, // We know we returned, but we don't know what the value is. return ESR_Failed; } - if (RetExpr && - !(Result.Slot - ? 
EvaluateInPlace(Result.Value, Info, *Result.Slot, RetExpr) - : Evaluate(Result.Value, Info, RetExpr))) + + if (!RetExpr || !isa<CallExpr>(RetExpr)) { + Info.DisableTailRecursion(); + } + + if (RetExpr && !(Result.Slot ? EvaluateInPlace(Result.Value, Info, + *Result.Slot, RetExpr) + : Evaluate(Result.Value, Info, RetExpr))) return ESR_Failed; return Scope.destroy() ? ESR_Returned : ESR_Failed; } @@ -5869,32 +5958,37 @@ static EvalStmtResult EvaluateStmt(StmtResult &Result, EvalInfo &Info, case Stmt::AttributedStmtClass: { const auto *AS = cast<AttributedStmt>(S); const auto *SS = AS->getSubStmt(); + const auto *RS = dyn_cast<ReturnStmt>(SS); MSConstexprContextRAII ConstexprContext( - *Info.CurrentCall, hasSpecificAttr<MSConstexprAttr>(AS->getAttrs()) && - isa<ReturnStmt>(SS)); + *Info.CurrentCall, + hasSpecificAttr<MSConstexprAttr>(AS->getAttrs()) && RS != nullptr); auto LO = Info.getASTContext().getLangOpts(); - if (LO.CXXAssumptions && !LO.MSVCCompat) { - for (auto *Attr : AS->getAttrs()) { - auto *AA = dyn_cast<CXXAssumeAttr>(Attr); - if (!AA) - continue; - - auto *Assumption = AA->getAssumption(); - if (Assumption->isValueDependent()) - return ESR_Failed; + for (auto *Attr : AS->getAttrs()) { + if (auto *AA = dyn_cast<CXXAssumeAttr>(Attr)) { + // This branch handles C++'s [[assume(<EXPR>)]] + if (LO.CXXAssumptions && !LO.MSVCCompat) { + auto *Assumption = AA->getAssumption(); + if (Assumption->isValueDependent()) + return ESR_Failed; - if (Assumption->HasSideEffects(Info.getASTContext())) - continue; + if (Assumption->HasSideEffects(Info.getASTContext())) + continue; - bool Value; - if (!EvaluateAsBooleanCondition(Assumption, Value, Info)) - return ESR_Failed; - if (!Value) { - Info.CCEDiag(Assumption->getExprLoc(), - diag::note_constexpr_assumption_failed); - return ESR_Failed; + bool Value; + if (!EvaluateAsBooleanCondition(Assumption, Value, Info)) + return ESR_Failed; + if (!Value) { + Info.CCEDiag(Assumption->getExprLoc(), + 
diag::note_constexpr_assumption_failed); + return ESR_Failed; + } } + } else if (isa<MustTailAttr>(Attr) && RS != nullptr) { + // This branch handles [[clang::mustttail]] enforcement on + // tail-recursion which is strict and already checked, otherwise it will + // fail to compile. + Info.EnableTailRecursion(RS); } } @@ -6514,16 +6608,16 @@ static bool MaybeHandleUnionActiveMemberChange(EvalInfo &Info, static bool EvaluateCallArg(const ParmVarDecl *PVD, const Expr *Arg, CallRef Call, EvalInfo &Info, - bool NonNull = false) { + CallStackFrame &CallerFrame, bool NonNull = false) { LValue LV; // Create the parameter slot and register its destruction. For a vararg // argument, create a temporary. // FIXME: For calling conventions that destroy parameters in the callee, // should we consider performing destruction when the function returns // instead? - APValue &V = PVD ? Info.CurrentCall->createParam(Call, PVD, LV) - : Info.CurrentCall->createTemporary(Arg, Arg->getType(), - ScopeKind::Call, LV); + APValue &V = PVD ? CallerFrame.createParam(Call, PVD, LV) + : CallerFrame.createTemporary(Arg, Arg->getType(), + ScopeKind::Call, LV); if (!EvaluateInPlace(V, Info, LV, Arg)) return false; @@ -6539,8 +6633,8 @@ static bool EvaluateCallArg(const ParmVarDecl *PVD, const Expr *Arg, /// Evaluate the arguments to a function call. static bool EvaluateArgs(ArrayRef<const Expr *> Args, CallRef Call, - EvalInfo &Info, const FunctionDecl *Callee, - bool RightToLeft = false) { + EvalInfo &Info, CallStackFrame &CallerFrame, + const FunctionDecl *Callee, bool RightToLeft = false) { bool Success = true; llvm::SmallBitVector ForbiddenNullArgs; if (Callee->hasAttr<NonNullAttr>()) { @@ -6563,7 +6657,7 @@ static bool EvaluateArgs(ArrayRef<const Expr *> Args, CallRef Call, const ParmVarDecl *PVD = Idx < Callee->getNumParams() ? 
Callee->getParamDecl(Idx) : nullptr; bool NonNull = !ForbiddenNullArgs.empty() && ForbiddenNullArgs[Idx]; - if (!EvaluateCallArg(PVD, Args[Idx], Call, Info, NonNull)) { + if (!EvaluateCallArg(PVD, Args[Idx], Call, Info, CallerFrame, NonNull)) { // If we're checking for a potential constant expression, evaluate all // initializers even if some of them fail. if (!Info.noteFailure()) @@ -6650,6 +6744,44 @@ static bool HandleFunctionCall(SourceLocation CallLoc, return ESR == ESR_Returned; } +static void HandleTailCallTransfer( + EvalInfo &Info, const CallExpr *E, const FunctionDecl *Definition, + const LValue *This, LValue &ThisVal, + llvm::ArrayRef<const clang::Expr *> Args, CallRef Call, Stmt *Body, + SmallVector<QualType, 4> &CovariantAdjustmentPath, CallScopeRAII &Scope) { + auto &defer = Info.DeferFunctionCall; + + defer.E = E; + defer.Definition = Definition; + defer.HasThis = This != nullptr; + ThisVal.moveInto(defer.ThisVal); + defer.Args = Args; + defer.Call = Call; + defer.Body = Body; + defer.CovariantAdjustmentPath = std::move(CovariantAdjustmentPath); + + transferFromCallScope(Scope, defer.ArgumentsStored); +} + +static bool HandleTailCallSetup( + EvalInfo &Info, const CallExpr *&E, const FunctionDecl *&Definition, + LValue *&This, LValue &ThisVal, llvm::ArrayRef<const clang::Expr *> &Args, + CallRef &Call, Stmt *&Body, + SmallVector<QualType, 4> &CovariantAdjustmentPath, CallScopeRAII &Scope) { + auto &defer = Info.DeferFunctionCall; + assert(defer.E != nullptr); + + E = std::exchange(defer.E, nullptr); + Definition = defer.Definition; + ThisVal.setFrom(Info.Ctx, defer.ThisVal); + This = defer.HasThis ? &ThisVal : nullptr; + Args = defer.Args; + Call = defer.Call; + Body = defer.Body; + CovariantAdjustmentPath = std::move(defer.CovariantAdjustmentPath); + return transferIntoCallScope(Scope, defer.ArgumentsStored); +} + /// Evaluate a constructor call. 
static bool HandleConstructorCall(const Expr *E, const LValue &This, CallRef Call, @@ -6871,7 +7003,7 @@ static bool HandleConstructorCall(const Expr *E, const LValue &This, EvalInfo &Info, APValue &Result) { CallScopeRAII CallScope(Info); CallRef Call = Info.CurrentCall->createCall(Definition); - if (!EvaluateArgs(Args, Call, Info, Definition)) + if (!EvaluateArgs(Args, Call, Info, *Info.CurrentCall, Definition)) return false; return HandleConstructorCall(E, This, Call, Definition, Info, Result) && @@ -8242,6 +8374,13 @@ class ExprEvaluatorBase APValue Result; if (!handleCallExpr(E, Result, nullptr)) return false; + + // When our current call is defered as a tail recursion + // we can't change result (yet). + if (Info.DeferFunctionCall.E != nullptr) { + return true; + } + return DerivedSuccess(Result, E); } @@ -8257,6 +8396,11 @@ class ExprEvaluatorBase auto Args = llvm::ArrayRef(E->getArgs(), E->getNumArgs()); bool HasQualifier = false; + // Check for tail recursion, before we start evaluating any internal + // expression which can steal tail on their own. + const bool TailRecursion = + std::exchange(Info.TailRecursionReturnStmt, nullptr) != nullptr; + CallRef Call; // Extract function decl and 'this' pointer from the callee. @@ -8317,12 +8461,15 @@ class ExprEvaluatorBase auto *OCE = dyn_cast<CXXOperatorCallExpr>(E); if (OCE && OCE->isAssignmentOp()) { assert(Args.size() == 2 && "wrong number of arguments in assignment"); - Call = Info.CurrentCall->createCall(FD); bool HasThis = false; if (const auto *MD = dyn_cast<CXXMethodDecl>(FD)) HasThis = MD->isImplicitObjectMemberFunction(); - if (!EvaluateArgs(HasThis ? Args.slice(1) : Args, Call, Info, FD, - /*RightToLeft=*/true)) + + CallStackFrame &CallOriginFrame = + *(TailRecursion ? Info.CurrentCall->Caller : Info.CurrentCall); + Call = CallOriginFrame.createCall(FD); + if (!EvaluateArgs(HasThis ? 
Args.slice(1) : Args, Call, Info, + CallOriginFrame, FD, /*RightToLeft = */ true)) return false; } @@ -8404,8 +8551,10 @@ class ExprEvaluatorBase // Evaluate the arguments now if we've not already done so. if (!Call) { - Call = Info.CurrentCall->createCall(FD); - if (!EvaluateArgs(Args, Call, Info, FD)) + CallStackFrame &CallOriginFrame = + *(TailRecursion ? Info.CurrentCall->Caller : Info.CurrentCall); + Call = CallOriginFrame.createCall(FD); + if (!EvaluateArgs(Args, Call, Info, CallOriginFrame, FD)) return false; } @@ -8438,11 +8587,40 @@ class ExprEvaluatorBase const FunctionDecl *Definition = nullptr; Stmt *Body = FD->getBody(Definition); - if (!CheckConstexprFunction(Info, E->getExprLoc(), FD, Definition, Body) || - !HandleFunctionCall(E->getExprLoc(), Definition, This, E, Args, Call, + if (!CheckConstexprFunction(Info, E->getExprLoc(), FD, Definition, Body)) { + return false; + } + + // If we are doing tail recursion, we need to store everything needed for + // the function call. There is always max one tail recursion prepared during + // execution of a program. + if (TailRecursion) { + HandleTailCallTransfer(Info, E, Definition, This, ThisVal, Args, Call, + Body, CovariantAdjustmentPath, CallScope); + return true; + } + + if (!HandleFunctionCall(E->getExprLoc(), Definition, This, E, Args, Call, Body, Info, Result, ResultSlot)) return false; + // If we do tail recursion, we don't have result yet. + assert(!Info.TailRecursionReady() || Result.isAbsent()); + + // A tail recursion can result in another tail recursion, so we need to loop + // here. + while (Info.TailRecursionReady()) { + if (!HandleTailCallSetup(Info, E, Definition, This, ThisVal, Args, Call, + Body, CovariantAdjustmentPath, CallScope)) + return false; + + if (!HandleFunctionCall(E->getExprLoc(), Definition, This, E, Args, Call, + Body, Info, Result, ResultSlot)) + return false; + } + + // TODO checkme this is correct + // We got out of tail recursion, it was just a normal function. 
if (!CovariantAdjustmentPath.empty() && !HandleCovariantReturnAdjustment(Info, E, Result, CovariantAdjustmentPath)) @@ -17832,7 +18010,7 @@ bool Expr::EvaluateWithSubstitution(APValue &Value, ASTContext &Ctx, break; const ParmVarDecl *PVD = Callee->getParamDecl(Idx); if ((*I)->isValueDependent() || - !EvaluateCallArg(PVD, *I, Call, Info) || + !EvaluateCallArg(PVD, *I, Call, Info, *Info.CurrentCall) || Info.EvalStatus.HasSideEffects) { // If evaluation fails, throw away the argument entirely. if (APValue *Slot = Info.getParamSlot(Call, PVD)) _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits