5chmidti created this revision.
5chmidti added a reviewer: sammccall.
Herald added subscribers: kadircet, arphaman.
Herald added a project: All.
5chmidti requested review of this revision.
Herald added subscribers: cfe-commits, MaskRay, ilya-biryukov.
Herald added a project: clang-tools-extra.

Adds support for hoisting variables that are declared inside the selected
region and used afterwards: they are returned from the extracted function
so that the code following the selection can keep using them.
If only one decl needs hoisting, the extracted function uses that
variable's explicit type as its return type. Otherwise it returns a pair
or tuple using auto return type deduction (requires C++14), and the call
site unpacks the result either with a structured binding (requires C++17)
or by explicitly unpacking the variables with std::get<>.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D138499

Files:
  clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
  clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
  clang-tools-extra/docs/ReleaseNotes.rst

Index: clang-tools-extra/docs/ReleaseNotes.rst
===================================================================
--- clang-tools-extra/docs/ReleaseNotes.rst
+++ clang-tools-extra/docs/ReleaseNotes.rst
@@ -78,6 +78,9 @@
 Miscellaneous
 ^^^^^^^^^^^^^
 
+- The extract function tweak gained support for hoisting, i.e. returning decls declared
+  inside the selection that are used outside of the selection.
+
 Improvements to clang-doc
 -------------------------
 
Index: clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
===================================================================
--- clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
+++ clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
@@ -30,8 +30,9 @@
   EXPECT_EQ(apply("auto lam = [](){ [[int x;]] }; "), "unavailable");
   // Partial statements aren't extracted.
   EXPECT_THAT(apply("int [[x = 0]];"), "unavailable");
-  // FIXME: Support hoisting.
-  EXPECT_THAT(apply(" [[int a = 5;]] a++; "), "unavailable");
+
+  // Extract regions that require hoisting
+  EXPECT_THAT(apply(" [[int a = 5;]] a++; "), HasSubstr("extracted"));
 
   // Ensure that end of Zone and Beginning of PostZone being adjacent doesn't
   // lead to break being included in the extraction zone.
@@ -192,6 +193,202 @@
   EXPECT_EQ(apply(CompoundFailInput), "unavailable");
 }
 
+TEST_F(ExtractFunctionTest, Hoisting) {
+  std::string HoistingInput = R"cpp(
+    int foo() {
+      int a = 3;
+      [[int x = 39 + a;
+      ++x;
+      int y = x * 2;
+      int z = 4;]]
+      return x + y + z;
+    }
+  )cpp";
+  std::string HoistingOutput = R"cpp(
+    auto extracted(int &a) {
+int x = 39 + a;
+      ++x;
+      int y = x * 2;
+      int z = 4;
+return std::tuple{x, y, z};
+}
+int foo() {
+      int a = 3;
+      auto [x, y, z] = extracted(a);
+      return x + y + z;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput), HoistingOutput);
+
+  std::string HoistingInput2 = R"cpp(
+    int foo() {
+      int a{};
+      [[int b = a + 1;]]
+      return b;
+    }
+  )cpp";
+  std::string HoistingOutput2 = R"cpp(
+    int extracted(int &a) {
+int b = a + 1;
+return b;
+}
+int foo() {
+      int a{};
+      auto b = extracted(a);
+      return b;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput2), HoistingOutput2);
+
+  std::string HoistingInput3 = R"cpp(
+    int foo(int b) {
+      int a{};
+      if (b == 42) {
+        [[a = 123;
+        return a + b;]]
+      }
+      a = 456;
+      return a;
+    }
+  )cpp";
+  std::string HoistingOutput3 = R"cpp(
+    int extracted(int &b, int &a) {
+a = 123;
+        return a + b;
+}
+int foo(int b) {
+      int a{};
+      if (b == 42) {
+        return extracted(b, a);
+      }
+      a = 456;
+      return a;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput3), HoistingOutput3);
+
+  std::string HoistingInput4 = R"cpp(
+    struct A {
+      bool flag;
+      int val;
+    };
+    A bar();
+    int foo(int b) {
+      int a = 0;
+      [[auto [flag, val] = bar();
+      int c = 4;
+      val = c + a;]]
+      return a + b + c + val;
+    }
+  )cpp";
+  std::string HoistingOutput4 = R"cpp(
+    struct A {
+      bool flag;
+      int val;
+    };
+    A bar();
+    auto extracted(int &a) {
+auto [flag, val] = bar();
+      int c = 4;
+      val = c + a;
+return std::pair{val, c};
+}
+int foo(int b) {
+      int a = 0;
+      auto [val, c] = extracted(a);
+      return a + b + c + val;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput4), HoistingOutput4);
+}
+
+TEST_F(ExtractFunctionTest, HoistingCXX11) {
+  ExtraArgs.emplace_back("-std=c++11");
+  std::string HoistingInput = R"cpp(
+    int foo() {
+      int a = 3;
+      [[int x = 39 + a;
+      ++x;
+      int y = x * 2;
+      int z = 4;]]
+      return x + y + z;
+    }
+  )cpp";
+  EXPECT_THAT(apply(HoistingInput), HasSubstr("unavailable"));
+
+  std::string HoistingInput2 = R"cpp(
+    int foo() {
+      int a;
+      [[int b = a + 1;]]
+      return b;
+    }
+  )cpp";
+  std::string HoistingOutput2 = R"cpp(
+    int extracted(int &a) {
+int b = a + 1;
+return b;
+}
+int foo() {
+      int a;
+      auto b = extracted(a);
+      return b;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput2), HoistingOutput2);
+}
+
+TEST_F(ExtractFunctionTest, HoistingCXX14) {
+  ExtraArgs.emplace_back("-std=c++14");
+  std::string HoistingInput = R"cpp(
+    int foo() {
+      int a = 3;
+      [[int x = 39 + a;
+      ++x;
+      int y = x * 2;
+      int z = 4;]]
+      return x + y + z;
+    }
+  )cpp";
+  std::string HoistingOutput = R"cpp(
+    auto extracted(int &a) {
+int x = 39 + a;
+      ++x;
+      int y = x * 2;
+      int z = 4;
+return std::tuple{x, y, z};
+}
+int foo() {
+      int a = 3;
+      auto returned = extracted(a);
+auto x = std::get<0>(returned);
+auto y = std::get<1>(returned);
+auto z = std::get<2>(returned);
+      return x + y + z;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput), HoistingOutput);
+
+  std::string HoistingInput2 = R"cpp(
+    int foo() {
+      int a;
+      [[int b = a + 1;]]
+      return b;
+    }
+  )cpp";
+  std::string HoistingOutput2 = R"cpp(
+    int extracted(int &a) {
+int b = a + 1;
+return b;
+}
+int foo() {
+      int a;
+      auto b = extracted(a);
+      return b;
+    }
+  )cpp";
+  EXPECT_EQ(apply(HoistingInput2), HoistingOutput2);
+}
+
 TEST_F(ExtractFunctionTest, DifferentHeaderSourceTest) {
   Header = R"cpp(
     class SomeClass {
Index: clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
===================================================================
--- clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
+++ clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
@@ -56,6 +56,7 @@
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclBase.h"
+#include "clang/AST/DeclCXX.h"
 #include "clang/AST/NestedNameSpecifier.h"
 #include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/Stmt.h"
@@ -72,7 +73,11 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Error.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_os_ostream.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <string>
 
 namespace clang {
 namespace clangd {
@@ -80,6 +85,13 @@
 
 using Node = SelectionTree::Node;
 
+struct HoistSetComparator {
+  bool operator()(const Decl *const Lhs, const Decl *const Rhs) const {
+    return Lhs->getLocation() < Rhs->getLocation();
+  }
+};
+using HoistSet = llvm::SmallSet<const NamedDecl *, 1, HoistSetComparator>;
+
 // ExtractionZone is the part of code that is being extracted.
 // EnclosingFunction is the function/method inside which the zone lies.
 // We split the file into 4 parts relative to extraction zone.
@@ -172,12 +184,13 @@
   // semicolon after the extraction.
   const Node *getLastRootStmt() const { return Parent->Children.back(); }
 
-  // Checks if declarations inside extraction zone are accessed afterwards.
+  // Checks if declarations inside extraction zone are accessed afterwards and
+  // adds these declarations to the returned set.
   //
   // This performs a partial AST traversal proportional to the size of the
   // enclosing function, so it is possibly expensive.
-  bool requiresHoisting(const SourceManager &SM,
-                        const HeuristicResolver *Resolver) const {
+  HoistSet getDeclsToHoist(const SourceManager &SM,
+                           const HeuristicResolver *Resolver) const {
     // First find all the declarations that happened inside extraction zone.
     llvm::SmallSet<const Decl *, 1> DeclsInExtZone;
     for (auto *RootStmt : RootStmts) {
@@ -192,29 +205,31 @@
     }
     // Early exit without performing expensive traversal below.
     if (DeclsInExtZone.empty())
-      return false;
-    // Then make sure they are not used outside the zone.
+      return {};
+    // Add any decl used after the selection to the returned set
+    HoistSet DeclsToHoist{};
     for (const auto *S : EnclosingFunction->getBody()->children()) {
       if (SM.isBeforeInTranslationUnit(S->getSourceRange().getEnd(),
                                        ZoneRange.getEnd()))
         continue;
-      bool HasPostUse = false;
       findExplicitReferences(
           S,
           [&](const ReferenceLoc &Loc) {
-            if (HasPostUse ||
-                SM.isBeforeInTranslationUnit(Loc.NameLoc, ZoneRange.getEnd()))
+            if (SM.isBeforeInTranslationUnit(Loc.NameLoc, ZoneRange.getEnd()))
               return;
-            HasPostUse = llvm::any_of(Loc.Targets,
-                                      [&DeclsInExtZone](const Decl *Target) {
-                                        return DeclsInExtZone.contains(Target);
-                                      });
+            const auto *const PostUseIter = llvm::find_if(
+                Loc.Targets, [&DeclsInExtZone](const Decl *Target) {
+                  return DeclsInExtZone.contains(Target);
+                });
+
+            if (const bool FoundPostUse = PostUseIter != Loc.Targets.end();
+                FoundPostUse) {
+              DeclsToHoist.insert(*PostUseIter);
+            }
           },
           Resolver);
-      if (HasPostUse)
-        return true;
     }
-    return false;
+    return DeclsToHoist;
   }
 };
 
@@ -368,16 +383,20 @@
   bool Static = false;
   ConstexprSpecKind Constexpr = ConstexprSpecKind::Unspecified;
   bool Const = false;
+  const HoistSet &ToHoist;
 
   // Decides whether the extracted function body and the function call need a
   // semicolon after extraction.
   tooling::ExtractionSemicolonPolicy SemicolonPolicy;
   const LangOptions *LangOpts;
-  NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy,
+  NewFunction(const HoistSet &ToHoist,
+              tooling::ExtractionSemicolonPolicy SemicolonPolicy,
               const LangOptions *LangOpts)
-      : SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) {}
+      : ToHoist(ToHoist), SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) {
+  }
   // Render the call for this function.
   std::string renderCall() const;
+  std::string renderHoistedCall() const;
   // Render the definition for this function.
   std::string renderDeclaration(FunctionDeclKind K,
                                 const DeclContext &SemanticDC,
@@ -463,7 +482,58 @@
   return llvm::formatv("{0}{1}", QualifierName, Name);
 }
 
+// Renders the HoistSet to a comma separated list or a single named decl.
+std::string renderHoistSet(const HoistSet &ToHoist) {
+  std::string Res{};
+  bool NeedsComma = false;
+  const auto Render = [&NeedsComma, &Res](const NamedDecl *const NDecl) {
+    if (NeedsComma) {
+      Res += ", ";
+    }
+    Res += NDecl->getNameAsString();
+  };
+  for (const NamedDecl *DeclToHoist : ToHoist) {
+    if (llvm::isa<VarDecl>(DeclToHoist) ||
+        llvm::isa<BindingDecl>(DeclToHoist)) {
+      Render(DeclToHoist);
+    }
+
+    NeedsComma = true;
+  }
+  return Res;
+}
+
+std::string NewFunction::renderHoistedCall() const {
+  auto HoistedVarDecls = std::string{};
+  auto ExplicitUnpacking = std::string{};
+  const auto HasStructuredBinding = LangOpts->CPlusPlus17;
+
+  if (ToHoist.size() > 1) {
+    if (HasStructuredBinding) {
+      HoistedVarDecls = "auto [" + renderHoistSet(ToHoist) + "] = ";
+    } else {
+      HoistedVarDecls = "auto returned = ";
+      auto DeclIter = ToHoist.begin();
+      for (size_t Index = 0U; Index < ToHoist.size(); ++Index, ++DeclIter) {
+        ExplicitUnpacking +=
+            llvm::formatv("\nauto {0} = std::get<{1}>(returned);",
+                          (*DeclIter)->getNameAsString(), Index);
+      }
+    }
+  } else {
+    HoistedVarDecls = "auto " + renderHoistSet(ToHoist) + " = ";
+  }
+
+  return std::string(llvm::formatv(
+      "{0}{1}({2}){3}{4}", HoistedVarDecls, Name, renderParametersForCall(),
+      (SemicolonPolicy.isNeededInOriginalFunction() ? ";" : ""),
+      ExplicitUnpacking));
+}
+
 std::string NewFunction::renderCall() const {
+  if (!ToHoist.empty())
+    return renderHoistedCall();
+
   return std::string(
       llvm::formatv("{0}{1}({2}){3}", CallerReturnsValue ? "return " : "", Name,
                     renderParametersForCall(),
@@ -496,8 +566,20 @@
   // - hoist decls
   // - add return statement
   // - Add semicolon
-  return toSourceCode(SM, BodyRange).str() +
-         (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : "");
+  auto Body = toSourceCode(SM, BodyRange).str() +
+              (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : "");
+  if (!ToHoist.empty()) {
+    if (const bool NeedsTupleOrPair = ToHoist.size() > 1; NeedsTupleOrPair) {
+      const auto NeedsPair = ToHoist.size() == 2;
+
+      Body += "\nreturn " +
+              std::string(NeedsPair ? "std::pair{" : "std::tuple{") +
+              renderHoistSet(ToHoist) + "};";
+    } else {
+      Body += "\nreturn " + renderHoistSet(ToHoist) + ";";
+    }
+  }
+  return Body;
 }
 
 std::string NewFunction::Parameter::render(const DeclContext *Context) const {
@@ -675,10 +757,6 @@
     const auto &DeclInfo = KeyVal.second;
     // If a Decl was Declared in zone and referenced in post zone, it
     // needs to be hoisted (we bail out in that case).
-    // FIXME: Support Decl Hoisting.
-    if (DeclInfo.DeclaredIn == ZoneRelative::Inside &&
-        DeclInfo.IsReferencedInPostZone)
-      return false;
     if (!DeclInfo.IsReferencedInZone)
       continue; // no need to pass as parameter, not referenced
     if (DeclInfo.DeclaredIn == ZoneRelative::Inside ||
@@ -724,6 +802,19 @@
   return SemicolonPolicy;
 }
 
+QualType getReturnTypeForHoisted(const FunctionDecl &EnclosingFunc,
+                                 const HoistSet &ToHoist) {
+  // Hoisting just one variable, use that variable's type instead of auto
+  if (ToHoist.size() == 1) {
+    if (const auto *const VDecl = llvm::dyn_cast<VarDecl>(*ToHoist.begin());
+        VDecl != nullptr) {
+      return VDecl->getType();
+    }
+  }
+
+  return EnclosingFunc.getParentASTContext().getAutoDeductType();
+}
+
 // Generate return type for ExtractedFunc. Return false if unable to do so.
 bool generateReturnProperties(NewFunction &ExtractedFunc,
                               const FunctionDecl &EnclosingFunc,
@@ -745,7 +836,11 @@
     return true;
   }
   // FIXME: Generate new return statement if needed.
-  ExtractedFunc.ReturnType = EnclosingFunc.getParentASTContext().VoidTy;
+  ExtractedFunc.ReturnType =
+      !ExtractedFunc.ToHoist.empty()
+          ? getReturnTypeForHoisted(EnclosingFunc, ExtractedFunc.ToHoist)
+          : EnclosingFunc.getParentASTContext().VoidTy;
+
   return true;
 }
 
@@ -759,6 +854,7 @@
 // FIXME: add support for adding other function return types besides void.
 // FIXME: assign the value returned by non void extracted function.
 llvm::Expected<NewFunction> getExtractedFunction(ExtractionZone &ExtZone,
+                                                 const HoistSet &ToHoist,
                                                  const SourceManager &SM,
                                                  const LangOptions &LangOpts) {
   CapturedZoneInfo CapturedInfo = captureZoneInfo(ExtZone);
@@ -766,7 +862,7 @@
   if (CapturedInfo.BrokenControlFlow)
     return error("Cannot extract break/continue without corresponding "
                  "loop/switch statement.");
-  NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts),
+  NewFunction ExtractedFunc(ToHoist, getSemicolonPolicy(ExtZone, SM, LangOpts),
                             &LangOpts);
 
   ExtractedFunc.SyntacticDC =
@@ -815,6 +911,7 @@
 
 private:
   ExtractionZone ExtZone;
+  HoistSet ToHoist;
 };
 
 REGISTER_TWEAK(ExtractFunction)
@@ -880,8 +977,12 @@
       (hasReturnStmt(*MaybeExtZone) && !alwaysReturns(*MaybeExtZone)))
     return false;
 
-  // FIXME: Get rid of this check once we support hoisting.
-  if (MaybeExtZone->requiresHoisting(SM, Inputs.AST->getHeuristicResolver()))
+  ToHoist =
+      MaybeExtZone->getDeclsToHoist(SM, Inputs.AST->getHeuristicResolver());
+
+  const auto HasAutoReturnTypeDeduction = LangOpts.CPlusPlus14;
+  const auto RequiresPairOrTuple = ToHoist.size() > 1;
+  if (RequiresPairOrTuple && !HasAutoReturnTypeDeduction)
     return false;
 
   ExtZone = std::move(*MaybeExtZone);
@@ -891,7 +992,7 @@
 Expected<Tweak::Effect> ExtractFunction::apply(const Selection &Inputs) {
   const SourceManager &SM = Inputs.AST->getSourceManager();
   const LangOptions &LangOpts = Inputs.AST->getLangOpts();
-  auto ExtractedFunc = getExtractedFunction(ExtZone, SM, LangOpts);
+  auto ExtractedFunc = getExtractedFunction(ExtZone, ToHoist, SM, LangOpts);
   // FIXME: Add more types of errors.
   if (!ExtractedFunc)
     return ExtractedFunc.takeError();
@@ -914,8 +1015,8 @@
 
       tooling::Replacements OtherEdit(
           createForwardDeclaration(*ExtractedFunc, SM));
-      if (auto PathAndEdit = Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc),
-                                                 OtherEdit))
+      if (auto PathAndEdit =
+              Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc), OtherEdit))
         MultiFileEffect->ApplyEdits.try_emplace(PathAndEdit->first,
                                                 PathAndEdit->second);
       else
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to