llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Hana Dusíková (hanickadot)

<details>
<summary>Changes</summary>

This change makes `[[clang::musttail]]` work during constant evaluation. Function 
calls marked with this attribute no longer consume the system stack: the tail 
call is deferred and then executed in a loop by the nearest enclosing non-tail 
function call. The attribute is already very strict and rejects all problematic 
cases (non-trivial destructors, references to local variables).

This PR is a work in progress.

---
Full diff: https://github.com/llvm/llvm-project/pull/138477.diff


1 Files Affected:

- (modified) clang/lib/AST/ExprConstant.cpp (+219-41) 


``````````diff
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index b79d8c197fe7d..9ef6b983d196a 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -735,6 +735,13 @@ namespace {
             ScopeKind Scope)
         : Value(Val, Scope), Base(Base), T(T) {}
 
+    Cleanup(Cleanup &&Other) noexcept
+        : Value{Other.Value}, Base{Other.Base}, T{Other.T} {
+      Other.Value = {};
+    }
+
+    Cleanup &operator=(Cleanup &&) = default;
+
     /// Determine whether this cleanup should be performed at the end of the
     /// given kind of scope.
     bool isDestroyedAtEndOf(ScopeKind K) const {
@@ -1006,6 +1013,24 @@ namespace {
       EM_IgnoreSideEffects,
     } EvalMode;
 
+    /// Pointer to last tail recursion enabled return. Enforced with
+    /// [[clang::musttail]]
+    const ReturnStmt *TailRecursionReturnStmt = nullptr;
+
+    struct DeferRecursionFunctionCall {
+      const CallExpr *E{nullptr};
+      const FunctionDecl *Definition{nullptr};
+      bool HasThis{false};
+      APValue ThisVal{}; // can't use LValue here :(
+      llvm::ArrayRef<const clang::Expr *> Args{};
+      CallRef Call{};
+      Stmt *Body{nullptr};
+      SmallVector<QualType, 4> CovariantAdjustmentPath{};
+      SmallVector<Cleanup, 16> ArgumentsStored{};
+    };
+
+    DeferRecursionFunctionCall DeferFunctionCall{};
+
     /// Are we checking whether the expression is a potential constant
     /// expression?
     bool checkingPotentialConstantExpression() const override  {
@@ -1124,6 +1149,21 @@ namespace {
       return Result;
     }
 
+    void EnableTailRecursion(const ReturnStmt *ret) {
+      TailRecursionReturnStmt = ret;
+    }
+
+    void DisableTailRecursion() { TailRecursionReturnStmt = nullptr; }
+
+    bool TailRecursionReady() const { return DeferFunctionCall.E != nullptr; }
+
+    bool IsTailRecursion(const ReturnStmt *ret) {
+      if (TailRecursionReturnStmt != ret)
+        return false;
+      TailRecursionReturnStmt = nullptr;
+      return true;
+    }
+
     /// Get the allocated storage for the given parameter of the given call.
     APValue *getParamSlot(CallRef Call, const ParmVarDecl *PVD) {
       CallStackFrame *Frame = getCallFrameAndDepth(Call.CallIndex).first;
@@ -1439,6 +1479,12 @@ namespace {
       // instances of this class.
       Info.CurrentCall->popTempVersion();
     }
+
+    friend void transferFromCallScope(ScopeRAII &,
+                                      llvm::SmallVectorImpl<Cleanup> &);
+    friend bool transferIntoCallScope(ScopeRAII &,
+                                      llvm::SmallVectorImpl<Cleanup> &);
+
   private:
     static bool cleanup(EvalInfo &Info, bool RunDestructors,
                         unsigned OldStackSize) {
@@ -1457,6 +1503,10 @@ namespace {
         }
       }
 
+      compact(Info, OldStackSize);
+      return Success;
+    }
+    static void compact(EvalInfo &Info, unsigned OldStackSize) {
       // Compact any retained cleanups.
       auto NewEnd = Info.CleanupStack.begin() + OldStackSize;
       if (Kind != ScopeKind::Block)
@@ -1465,12 +1515,47 @@ namespace {
               return C.isDestroyedAtEndOf(Kind);
             });
       Info.CleanupStack.erase(NewEnd, Info.CleanupStack.end());
-      return Success;
     }
   };
   typedef ScopeRAII<ScopeKind::Block> BlockScopeRAII;
   typedef ScopeRAII<ScopeKind::FullExpression> FullExpressionRAII;
   typedef ScopeRAII<ScopeKind::Call> CallScopeRAII;
+
+  static void transferFromCallScope(CallScopeRAII &Scope,
+                                    llvm::SmallVectorImpl<Cleanup> &Backup) {
+    Backup.clear();
+
+    auto CurrentVariables = MutableArrayRef<Cleanup>(Scope.Info.CleanupStack)
+                                .slice(Scope.OldStackSize);
+
+    // Transfer of cleanup informations of tail call outside of current scope.
+    // These variables are going to be destroyed in current scope, which only
+    // prepares the tail call, but is not doing it.
+    Backup.clear();
+
+    for (Cleanup &Lifetime : CurrentVariables) {
+      Backup.push_back(std::move(Lifetime));
+    }
+
+    // Remove lifetime management from this scope.
+    Scope.compact(Scope.Info, Scope.OldStackSize);
+    Scope.Info.CleanupStack.truncate(
+        Scope.OldStackSize); // make sure this is ok
+    assert(Scope.Info.CleanupStack.size() == Scope.OldStackSize);
+  }
+
+  static bool transferIntoCallScope(CallScopeRAII &Scope,
+                                    llvm::SmallVectorImpl<Cleanup> &Backup) {
+    if (!Scope.cleanup(Scope.Info, true, Scope.OldStackSize))
+      return false;
+
+    for (auto &Lifetime : Backup) {
+      Scope.Info.CleanupStack.push_back(std::move(Lifetime));
+    }
+
+    Backup.clear();
+    return true;
+  }
 }
 
 bool SubobjectDesignator::checkSubobject(EvalInfo &Info, const Expr *E,
@@ -5614,10 +5699,14 @@ static EvalStmtResult EvaluateStmt(StmtResult &Result, 
EvalInfo &Info,
       // We know we returned, but we don't know what the value is.
       return ESR_Failed;
     }
-    if (RetExpr &&
-        !(Result.Slot
-              ? EvaluateInPlace(Result.Value, Info, *Result.Slot, RetExpr)
-              : Evaluate(Result.Value, Info, RetExpr)))
+
+    if (!RetExpr || !isa<CallExpr>(RetExpr)) {
+      Info.DisableTailRecursion();
+    }
+
+    if (RetExpr && !(Result.Slot ? EvaluateInPlace(Result.Value, Info,
+                                                   *Result.Slot, RetExpr)
+                                 : Evaluate(Result.Value, Info, RetExpr)))
       return ESR_Failed;
     return Scope.destroy() ? ESR_Returned : ESR_Failed;
   }
@@ -5869,32 +5958,37 @@ static EvalStmtResult EvaluateStmt(StmtResult &Result, 
EvalInfo &Info,
   case Stmt::AttributedStmtClass: {
     const auto *AS = cast<AttributedStmt>(S);
     const auto *SS = AS->getSubStmt();
+    const auto *RS = dyn_cast<ReturnStmt>(SS);
     MSConstexprContextRAII ConstexprContext(
-        *Info.CurrentCall, hasSpecificAttr<MSConstexprAttr>(AS->getAttrs()) &&
-                               isa<ReturnStmt>(SS));
+        *Info.CurrentCall,
+        hasSpecificAttr<MSConstexprAttr>(AS->getAttrs()) && RS != nullptr);
 
     auto LO = Info.getASTContext().getLangOpts();
-    if (LO.CXXAssumptions && !LO.MSVCCompat) {
-      for (auto *Attr : AS->getAttrs()) {
-        auto *AA = dyn_cast<CXXAssumeAttr>(Attr);
-        if (!AA)
-          continue;
-
-        auto *Assumption = AA->getAssumption();
-        if (Assumption->isValueDependent())
-          return ESR_Failed;
+    for (auto *Attr : AS->getAttrs()) {
+      if (auto *AA = dyn_cast<CXXAssumeAttr>(Attr)) {
+        // This branch handles C++'s [[assume(<EXPR>)]]
+        if (LO.CXXAssumptions && !LO.MSVCCompat) {
+          auto *Assumption = AA->getAssumption();
+          if (Assumption->isValueDependent())
+            return ESR_Failed;
 
-        if (Assumption->HasSideEffects(Info.getASTContext()))
-          continue;
+          if (Assumption->HasSideEffects(Info.getASTContext()))
+            continue;
 
-        bool Value;
-        if (!EvaluateAsBooleanCondition(Assumption, Value, Info))
-          return ESR_Failed;
-        if (!Value) {
-          Info.CCEDiag(Assumption->getExprLoc(),
-                       diag::note_constexpr_assumption_failed);
-          return ESR_Failed;
+          bool Value;
+          if (!EvaluateAsBooleanCondition(Assumption, Value, Info))
+            return ESR_Failed;
+          if (!Value) {
+            Info.CCEDiag(Assumption->getExprLoc(),
+                         diag::note_constexpr_assumption_failed);
+            return ESR_Failed;
+          }
         }
+      } else if (isa<MustTailAttr>(Attr) && RS != nullptr) {
+        // This branch handles [[clang::mustttail]] enforcement on
+        // tail-recursion which is strict and already checked, otherwise it 
will
+        // fail to compile.
+        Info.EnableTailRecursion(RS);
       }
     }
 
@@ -6514,16 +6608,16 @@ static bool MaybeHandleUnionActiveMemberChange(EvalInfo 
&Info,
 
 static bool EvaluateCallArg(const ParmVarDecl *PVD, const Expr *Arg,
                             CallRef Call, EvalInfo &Info,
-                            bool NonNull = false) {
+                            CallStackFrame &CallerFrame, bool NonNull = false) 
{
   LValue LV;
   // Create the parameter slot and register its destruction. For a vararg
   // argument, create a temporary.
   // FIXME: For calling conventions that destroy parameters in the callee,
   // should we consider performing destruction when the function returns
   // instead?
-  APValue &V = PVD ? Info.CurrentCall->createParam(Call, PVD, LV)
-                   : Info.CurrentCall->createTemporary(Arg, Arg->getType(),
-                                                       ScopeKind::Call, LV);
+  APValue &V = PVD ? CallerFrame.createParam(Call, PVD, LV)
+                   : CallerFrame.createTemporary(Arg, Arg->getType(),
+                                                 ScopeKind::Call, LV);
   if (!EvaluateInPlace(V, Info, LV, Arg))
     return false;
 
@@ -6539,8 +6633,8 @@ static bool EvaluateCallArg(const ParmVarDecl *PVD, const 
Expr *Arg,
 
 /// Evaluate the arguments to a function call.
 static bool EvaluateArgs(ArrayRef<const Expr *> Args, CallRef Call,
-                         EvalInfo &Info, const FunctionDecl *Callee,
-                         bool RightToLeft = false) {
+                         EvalInfo &Info, CallStackFrame &CallerFrame,
+                         const FunctionDecl *Callee, bool RightToLeft = false) 
{
   bool Success = true;
   llvm::SmallBitVector ForbiddenNullArgs;
   if (Callee->hasAttr<NonNullAttr>()) {
@@ -6563,7 +6657,7 @@ static bool EvaluateArgs(ArrayRef<const Expr *> Args, 
CallRef Call,
     const ParmVarDecl *PVD =
         Idx < Callee->getNumParams() ? Callee->getParamDecl(Idx) : nullptr;
     bool NonNull = !ForbiddenNullArgs.empty() && ForbiddenNullArgs[Idx];
-    if (!EvaluateCallArg(PVD, Args[Idx], Call, Info, NonNull)) {
+    if (!EvaluateCallArg(PVD, Args[Idx], Call, Info, CallerFrame, NonNull)) {
       // If we're checking for a potential constant expression, evaluate all
       // initializers even if some of them fail.
       if (!Info.noteFailure())
@@ -6650,6 +6744,44 @@ static bool HandleFunctionCall(SourceLocation CallLoc,
   return ESR == ESR_Returned;
 }
 
+static void HandleTailCallTransfer(
+    EvalInfo &Info, const CallExpr *E, const FunctionDecl *Definition,
+    const LValue *This, LValue &ThisVal,
+    llvm::ArrayRef<const clang::Expr *> Args, CallRef Call, Stmt *Body,
+    SmallVector<QualType, 4> &CovariantAdjustmentPath, CallScopeRAII &Scope) {
+  auto &defer = Info.DeferFunctionCall;
+
+  defer.E = E;
+  defer.Definition = Definition;
+  defer.HasThis = This != nullptr;
+  ThisVal.moveInto(defer.ThisVal);
+  defer.Args = Args;
+  defer.Call = Call;
+  defer.Body = Body;
+  defer.CovariantAdjustmentPath = std::move(CovariantAdjustmentPath);
+
+  transferFromCallScope(Scope, defer.ArgumentsStored);
+}
+
+static bool HandleTailCallSetup(
+    EvalInfo &Info, const CallExpr *&E, const FunctionDecl *&Definition,
+    LValue *&This, LValue &ThisVal, llvm::ArrayRef<const clang::Expr *> &Args,
+    CallRef &Call, Stmt *&Body,
+    SmallVector<QualType, 4> &CovariantAdjustmentPath, CallScopeRAII &Scope) {
+  auto &defer = Info.DeferFunctionCall;
+  assert(defer.E != nullptr);
+
+  E = std::exchange(defer.E, nullptr);
+  Definition = defer.Definition;
+  ThisVal.setFrom(Info.Ctx, defer.ThisVal);
+  This = defer.HasThis ? &ThisVal : nullptr;
+  Args = defer.Args;
+  Call = defer.Call;
+  Body = defer.Body;
+  CovariantAdjustmentPath = std::move(defer.CovariantAdjustmentPath);
+  return transferIntoCallScope(Scope, defer.ArgumentsStored);
+}
+
 /// Evaluate a constructor call.
 static bool HandleConstructorCall(const Expr *E, const LValue &This,
                                   CallRef Call,
@@ -6871,7 +7003,7 @@ static bool HandleConstructorCall(const Expr *E, const 
LValue &This,
                                   EvalInfo &Info, APValue &Result) {
   CallScopeRAII CallScope(Info);
   CallRef Call = Info.CurrentCall->createCall(Definition);
-  if (!EvaluateArgs(Args, Call, Info, Definition))
+  if (!EvaluateArgs(Args, Call, Info, *Info.CurrentCall, Definition))
     return false;
 
   return HandleConstructorCall(E, This, Call, Definition, Info, Result) &&
@@ -8242,6 +8374,13 @@ class ExprEvaluatorBase
     APValue Result;
     if (!handleCallExpr(E, Result, nullptr))
       return false;
+
+    // When our current call is defered as a tail recursion
+    // we can't change result (yet).
+    if (Info.DeferFunctionCall.E != nullptr) {
+      return true;
+    }
+
     return DerivedSuccess(Result, E);
   }
 
@@ -8257,6 +8396,11 @@ class ExprEvaluatorBase
     auto Args = llvm::ArrayRef(E->getArgs(), E->getNumArgs());
     bool HasQualifier = false;
 
+    // Check for tail recursion, before we start evaluating any internal
+    // expression which can steal tail on their own.
+    const bool TailRecursion =
+        std::exchange(Info.TailRecursionReturnStmt, nullptr) != nullptr;
+
     CallRef Call;
 
     // Extract function decl and 'this' pointer from the callee.
@@ -8317,12 +8461,15 @@ class ExprEvaluatorBase
       auto *OCE = dyn_cast<CXXOperatorCallExpr>(E);
       if (OCE && OCE->isAssignmentOp()) {
         assert(Args.size() == 2 && "wrong number of arguments in assignment");
-        Call = Info.CurrentCall->createCall(FD);
         bool HasThis = false;
         if (const auto *MD = dyn_cast<CXXMethodDecl>(FD))
           HasThis = MD->isImplicitObjectMemberFunction();
-        if (!EvaluateArgs(HasThis ? Args.slice(1) : Args, Call, Info, FD,
-                          /*RightToLeft=*/true))
+
+        CallStackFrame &CallOriginFrame =
+            *(TailRecursion ? Info.CurrentCall->Caller : Info.CurrentCall);
+        Call = CallOriginFrame.createCall(FD);
+        if (!EvaluateArgs(HasThis ? Args.slice(1) : Args, Call, Info,
+                          CallOriginFrame, FD, /*RightToLeft = */ true))
           return false;
       }
 
@@ -8404,8 +8551,10 @@ class ExprEvaluatorBase
 
     // Evaluate the arguments now if we've not already done so.
     if (!Call) {
-      Call = Info.CurrentCall->createCall(FD);
-      if (!EvaluateArgs(Args, Call, Info, FD))
+      CallStackFrame &CallOriginFrame =
+          *(TailRecursion ? Info.CurrentCall->Caller : Info.CurrentCall);
+      Call = CallOriginFrame.createCall(FD);
+      if (!EvaluateArgs(Args, Call, Info, CallOriginFrame, FD))
         return false;
     }
 
@@ -8438,11 +8587,40 @@ class ExprEvaluatorBase
     const FunctionDecl *Definition = nullptr;
     Stmt *Body = FD->getBody(Definition);
 
-    if (!CheckConstexprFunction(Info, E->getExprLoc(), FD, Definition, Body) ||
-        !HandleFunctionCall(E->getExprLoc(), Definition, This, E, Args, Call,
+    if (!CheckConstexprFunction(Info, E->getExprLoc(), FD, Definition, Body)) {
+      return false;
+    }
+
+    // If we are doing tail recursion, we need to store everything needed for
+    // the function call. There is always max one tail recursion prepared 
during
+    // execution of a program.
+    if (TailRecursion) {
+      HandleTailCallTransfer(Info, E, Definition, This, ThisVal, Args, Call,
+                             Body, CovariantAdjustmentPath, CallScope);
+      return true;
+    }
+
+    if (!HandleFunctionCall(E->getExprLoc(), Definition, This, E, Args, Call,
                             Body, Info, Result, ResultSlot))
       return false;
 
+    // If we do tail recursion, we don't have result yet.
+    assert(!Info.TailRecursionReady() || Result.isAbsent());
+
+    // A tail recursion can result in another tail recursion, so we need to 
loop
+    // here.
+    while (Info.TailRecursionReady()) {
+      if (!HandleTailCallSetup(Info, E, Definition, This, ThisVal, Args, Call,
+                               Body, CovariantAdjustmentPath, CallScope))
+        return false;
+
+      if (!HandleFunctionCall(E->getExprLoc(), Definition, This, E, Args, Call,
+                              Body, Info, Result, ResultSlot))
+        return false;
+    }
+
+    // TODO checkme this is correct
+    // We got out of tail recursion, it was just a normal function.
     if (!CovariantAdjustmentPath.empty() &&
         !HandleCovariantReturnAdjustment(Info, E, Result,
                                          CovariantAdjustmentPath))
@@ -17832,7 +18010,7 @@ bool Expr::EvaluateWithSubstitution(APValue &Value, 
ASTContext &Ctx,
       break;
     const ParmVarDecl *PVD = Callee->getParamDecl(Idx);
     if ((*I)->isValueDependent() ||
-        !EvaluateCallArg(PVD, *I, Call, Info) ||
+        !EvaluateCallArg(PVD, *I, Call, Info, *Info.CurrentCall) ||
         Info.EvalStatus.HasSideEffects) {
       // If evaluation fails, throw away the argument entirely.
       if (APValue *Slot = Info.getParamSlot(Call, PVD))

``````````

</details>


https://github.com/llvm/llvm-project/pull/138477
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to