5chmidti updated this revision to Diff 477184.
5chmidti added a comment.
Fixup: remove the added includes
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D138499/new/
https://reviews.llvm.org/D138499
Files:
clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
clang-tools-extra/docs/ReleaseNotes.rst
Index: clang-tools-extra/docs/ReleaseNotes.rst
===================================================================
--- clang-tools-extra/docs/ReleaseNotes.rst
+++ clang-tools-extra/docs/ReleaseNotes.rst
@@ -78,6 +78,9 @@
Miscellaneous
^^^^^^^^^^^^^
+- The extract function tweak now supports hoisting: variables declared inside
+  the selection that are still used after it are returned from the extracted
+  function and re-declared at the call site.
+
Improvements to clang-doc
-------------------------
Index: clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
===================================================================
--- clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
+++ clang-tools-extra/clangd/unittests/tweaks/ExtractFunctionTests.cpp
@@ -30,8 +30,9 @@
EXPECT_EQ(apply("auto lam = [](){ [[int x;]] }; "), "unavailable");
// Partial statements aren't extracted.
EXPECT_THAT(apply("int [[x = 0]];"), "unavailable");
- // FIXME: Support hoisting.
- EXPECT_THAT(apply(" [[int a = 5;]] a++; "), "unavailable");
+
+ // Extract regions that require hoisting
+ EXPECT_THAT(apply(" [[int a = 5;]] a++; "), HasSubstr("extracted"));
// Ensure that end of Zone and Beginning of PostZone being adjacent doesn't
// lead to break being included in the extraction zone.
@@ -192,6 +193,202 @@
EXPECT_EQ(apply(CompoundFailInput), "unavailable");
}
+// Extraction zones that declare variables still used after the zone: the
+// hoisted variables are returned from the extracted function and re-declared
+// at the call site.
+TEST_F(ExtractFunctionTest, Hoisting) {
+ // Several hoisted variables are returned as a std::tuple and unpacked with a
+ // structured binding; `a` is declared before the zone and becomes a parameter.
+ std::string HoistingInput = R"cpp(
+ int foo() {
+ int a = 3;
+ [[int x = 39 + a;
+ ++x;
+ int y = x * 2;
+ int z = 4;]]
+ return x + y + z;
+ }
+ )cpp";
+ std::string HoistingOutput = R"cpp(
+ auto extracted(int &a) {
+int x = 39 + a;
+ ++x;
+ int y = x * 2;
+ int z = 4;
+return std::tuple{x, y, z};
+}
+int foo() {
+ int a = 3;
+ auto [x, y, z] = extracted(a);
+ return x + y + z;
+ }
+ )cpp";
+ EXPECT_EQ(apply(HoistingInput), HoistingOutput);
+
+ // A single hoisted variable is returned directly, and the extracted function
+ // uses that variable's type (here `int`) as its return type.
+ std::string HoistingInput2 = R"cpp(
+ int foo() {
+ int a{};
+ [[int b = a + 1;]]
+ return b;
+ }
+ )cpp";
+ std::string HoistingOutput2 = R"cpp(
+ int extracted(int &a) {
+int b = a + 1;
+return b;
+}
+int foo() {
+ int a{};
+ auto b = extracted(a);
+ return b;
+ }
+ )cpp";
+ EXPECT_EQ(apply(HoistingInput2), HoistingOutput2);
+
+ // Nothing is declared inside the zone, so nothing is hoisted. The zone always
+ // returns, so the call to the extracted function is returned instead.
+ std::string HoistingInput3 = R"cpp(
+ int foo(int b) {
+ int a{};
+ if (b == 42) {
+ [[a = 123;
+ return a + b;]]
+ }
+ a = 456;
+ return a;
+ }
+ )cpp";
+ std::string HoistingOutput3 = R"cpp(
+ int extracted(int &b, int &a) {
+a = 123;
+ return a + b;
+}
+int foo(int b) {
+ int a{};
+ if (b == 42) {
+ return extracted(b, a);
+ }
+ a = 456;
+ return a;
+ }
+ )cpp";
+ EXPECT_EQ(apply(HoistingInput3), HoistingOutput3);
+
+ // A binding from a structured binding declaration (`val`) is hoisted together
+ // with `c`, and two hoisted values are returned as a std::pair. `flag` is not
+ // used after the zone, so it is not hoisted.
+ std::string HoistingInput4 = R"cpp(
+ struct A {
+ bool flag;
+ int val;
+ };
+ A bar();
+ int foo(int b) {
+ int a = 0;
+ [[auto [flag, val] = bar();
+ int c = 4;
+ val = c + a;]]
+ return a + b + c + val;
+ }
+ )cpp";
+ std::string HoistingOutput4 = R"cpp(
+ struct A {
+ bool flag;
+ int val;
+ };
+ A bar();
+ auto extracted(int &a) {
+auto [flag, val] = bar();
+ int c = 4;
+ val = c + a;
+return std::pair{val, c};
+}
+int foo(int b) {
+ int a = 0;
+ auto [val, c] = extracted(a);
+ return a + b + c + val;
+ }
+ )cpp";
+ EXPECT_EQ(apply(HoistingInput4), HoistingOutput4);
+}
+
+// Under C++11 functions have no return type deduction. Hoisting several
+// variables needs an `auto`-returned tuple, so it is unavailable; hoisting a
+// single variable still works because its own type is used as the return type.
+TEST_F(ExtractFunctionTest, HoistingCXX11) {
+ ExtraArgs.emplace_back("-std=c++11");
+ // Three variables would have to be hoisted: the tweak must not be offered.
+ std::string HoistingInput = R"cpp(
+ int foo() {
+ int a = 3;
+ [[int x = 39 + a;
+ ++x;
+ int y = x * 2;
+ int z = 4;]]
+ return x + y + z;
+ }
+ )cpp";
+ EXPECT_THAT(apply(HoistingInput), HasSubstr("unavailable"));
+
+ // A single hoisted variable: the extracted function returns `int` explicitly.
+ std::string HoistingInput2 = R"cpp(
+ int foo() {
+ int a;
+ [[int b = a + 1;]]
+ return b;
+ }
+ )cpp";
+ std::string HoistingOutput2 = R"cpp(
+ int extracted(int &a) {
+int b = a + 1;
+return b;
+}
+int foo() {
+ int a;
+ auto b = extracted(a);
+ return b;
+ }
+ )cpp";
+ EXPECT_EQ(apply(HoistingInput2), HoistingOutput2);
+}
+
+// Under C++14 there are no structured bindings, so several hoisted variables
+// are unpacked from the returned value with std::get.
+TEST_F(ExtractFunctionTest, HoistingCXX14) {
+ ExtraArgs.emplace_back("-std=c++14");
+ // NOTE(review): `std::tuple{x, y, z}` relies on class template argument
+ // deduction, which is a C++17 feature; the expected C++14 output likely needs
+ // std::make_tuple instead -- confirm before relying on this expectation.
+ std::string HoistingInput = R"cpp(
+ int foo() {
+ int a = 3;
+ [[int x = 39 + a;
+ ++x;
+ int y = x * 2;
+ int z = 4;]]
+ return x + y + z;
+ }
+ )cpp";
+ std::string HoistingOutput = R"cpp(
+ auto extracted(int &a) {
+int x = 39 + a;
+ ++x;
+ int y = x * 2;
+ int z = 4;
+return std::tuple{x, y, z};
+}
+int foo() {
+ int a = 3;
+ auto returned = extracted(a);
+auto x = std::get<0>(returned);
+auto y = std::get<1>(returned);
+auto z = std::get<2>(returned);
+ return x + y + z;
+ }
+ )cpp";
+ EXPECT_EQ(apply(HoistingInput), HoistingOutput);
+
+ // A single hoisted variable is unaffected by the language version.
+ std::string HoistingInput2 = R"cpp(
+ int foo() {
+ int a;
+ [[int b = a + 1;]]
+ return b;
+ }
+ )cpp";
+ std::string HoistingOutput2 = R"cpp(
+ int extracted(int &a) {
+int b = a + 1;
+return b;
+}
+int foo() {
+ int a;
+ auto b = extracted(a);
+ return b;
+ }
+ )cpp";
+ EXPECT_EQ(apply(HoistingInput2), HoistingOutput2);
+}
+
TEST_F(ExtractFunctionTest, DifferentHeaderSourceTest) {
Header = R"cpp(
class SomeClass {
Index: clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
===================================================================
--- clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
+++ clang-tools-extra/clangd/refactor/tweaks/ExtractFunction.cpp
@@ -80,6 +80,13 @@
using Node = SelectionTree::Node;
+// Orders declarations by their SourceLocation, so the decls collected in a
+// HoistSet are kept ordered by location. That order is the order in which the
+// hoisted values are returned and unpacked.
+struct HoistSetComparator {
+  bool operator()(const Decl *const A, const Decl *const B) const {
+    const auto LocA = A->getLocation();
+    const auto LocB = B->getLocation();
+    return LocA < LocB;
+  }
+};
+// Declarations made inside the extraction zone that are referenced after it.
+using HoistSet = llvm::SmallSet<const NamedDecl *, 1, HoistSetComparator>;
+
// ExtractionZone is the part of code that is being extracted.
// EnclosingFunction is the function/method inside which the zone lies.
// We split the file into 4 parts relative to extraction zone.
@@ -172,12 +179,13 @@
// semicolon after the extraction.
const Node *getLastRootStmt() const { return Parent->Children.back(); }
- // Checks if declarations inside extraction zone are accessed afterwards.
+ // Checks if declarations inside extraction zone are accessed afterwards and
+ // adds these declarations to the returned set.
//
// This performs a partial AST traversal proportional to the size of the
// enclosing function, so it is possibly expensive.
- bool requiresHoisting(const SourceManager &SM,
- const HeuristicResolver *Resolver) const {
+ HoistSet getDeclsToHoist(const SourceManager &SM,
+ const HeuristicResolver *Resolver) const {
// First find all the declarations that happened inside extraction zone.
llvm::SmallSet<const Decl *, 1> DeclsInExtZone;
for (auto *RootStmt : RootStmts) {
@@ -192,29 +200,31 @@
}
// Early exit without performing expensive traversal below.
if (DeclsInExtZone.empty())
- return false;
- // Then make sure they are not used outside the zone.
+ return {};
+ // Add any decl used after the selection to the returned set
+ HoistSet DeclsToHoist{};
for (const auto *S : EnclosingFunction->getBody()->children()) {
if (SM.isBeforeInTranslationUnit(S->getSourceRange().getEnd(),
ZoneRange.getEnd()))
continue;
- bool HasPostUse = false;
findExplicitReferences(
S,
[&](const ReferenceLoc &Loc) {
- if (HasPostUse ||
- SM.isBeforeInTranslationUnit(Loc.NameLoc, ZoneRange.getEnd()))
+ if (SM.isBeforeInTranslationUnit(Loc.NameLoc, ZoneRange.getEnd()))
return;
- HasPostUse = llvm::any_of(Loc.Targets,
- [&DeclsInExtZone](const Decl *Target) {
- return DeclsInExtZone.contains(Target);
- });
+ const auto *const PostUseIter = llvm::find_if(
+ Loc.Targets, [&DeclsInExtZone](const Decl *Target) {
+ return DeclsInExtZone.contains(Target);
+ });
+
+ if (const bool FoundPostUse = PostUseIter != Loc.Targets.end();
+ FoundPostUse) {
+ DeclsToHoist.insert(*PostUseIter);
+ }
},
Resolver);
- if (HasPostUse)
- return true;
}
- return false;
+ return DeclsToHoist;
}
};
@@ -368,16 +378,20 @@
bool Static = false;
ConstexprSpecKind Constexpr = ConstexprSpecKind::Unspecified;
bool Const = false;
+ const HoistSet &ToHoist;
// Decides whether the extracted function body and the function call need a
// semicolon after extraction.
tooling::ExtractionSemicolonPolicy SemicolonPolicy;
const LangOptions *LangOpts;
- NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy,
+ NewFunction(const HoistSet &ToHoist,
+ tooling::ExtractionSemicolonPolicy SemicolonPolicy,
const LangOptions *LangOpts)
- : SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) {}
+ : ToHoist(ToHoist), SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) {
+ }
// Render the call for this function.
std::string renderCall() const;
+ std::string renderHoistedCall() const;
// Render the definition for this function.
std::string renderDeclaration(FunctionDeclKind K,
const DeclContext &SemanticDC,
@@ -463,7 +477,58 @@
return llvm::formatv("{0}{1}", QualifierName, Name);
}
+// Renders the hoisted variables (VarDecls and BindingDecls) of the HoistSet as
+// a comma separated list, e.g. "x, y, z", or just "x" for a single variable.
+// Decls of any other kind cannot be returned as values and are skipped.
+std::string renderHoistSet(const HoistSet &ToHoist) {
+  std::string Res;
+  bool NeedsComma = false;
+  for (const NamedDecl *DeclToHoist : ToHoist) {
+    if (!llvm::isa<VarDecl>(DeclToHoist) &&
+        !llvm::isa<BindingDecl>(DeclToHoist))
+      continue;
+    // Only emit a separator between two *rendered* names; marking it needed
+    // for a skipped decl would leave a leading or doubled ", " behind.
+    if (NeedsComma)
+      Res += ", ";
+    Res += DeclToHoist->getNameAsString();
+    NeedsComma = true;
+  }
+  return Res;
+}
+
+// Renders the call of an extracted function that returns its hoisted decls and
+// re-declares them in the caller: a single variable is bound directly, several
+// via a structured binding (C++17) or, before C++17, by unpacking a local named
+// `returned` with std::get.
+std::string NewFunction::renderHoistedCall() const {
+  std::string HoistedVarDecls;
+  std::string ExplicitUnpacking;
+  if (ToHoist.size() <= 1) {
+    HoistedVarDecls = "auto " + renderHoistSet(ToHoist) + " = ";
+  } else if (LangOpts->CPlusPlus17) {
+    HoistedVarDecls = "auto [" + renderHoistSet(ToHoist) + "] = ";
+  } else {
+    HoistedVarDecls = "auto returned = ";
+    // Number only the decls renderHoistSet renders (VarDecls/BindingDecls) so
+    // each std::get index matches the position of that variable in the
+    // tuple/pair returned by the extracted function.
+    size_t Index = 0;
+    for (const NamedDecl *DeclToHoist : ToHoist) {
+      if (!llvm::isa<VarDecl>(DeclToHoist) &&
+          !llvm::isa<BindingDecl>(DeclToHoist))
+        continue;
+      ExplicitUnpacking +=
+          llvm::formatv("\nauto {0} = std::get<{1}>(returned);",
+                        DeclToHoist->getNameAsString(), Index++)
+              .str();
+    }
+  }
+  return llvm::formatv("{0}{1}({2}){3}{4}", HoistedVarDecls, Name,
+                       renderParametersForCall(),
+                       SemicolonPolicy.isNeededInOriginalFunction() ? ";" : "",
+                       ExplicitUnpacking)
+      .str();
+}
+
std::string NewFunction::renderCall() const {
+ if (!ToHoist.empty())
+ return renderHoistedCall();
+
return std::string(
llvm::formatv("{0}{1}({2}){3}", CallerReturnsValue ? "return " : "", Name,
renderParametersForCall(),
@@ -496,8 +561,20 @@
// - hoist decls
// - add return statement
// - Add semicolon
- return toSourceCode(SM, BodyRange).str() +
- (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : "");
+ auto Body = toSourceCode(SM, BodyRange).str() +
+ (SemicolonPolicy.isNeededInExtractedFunction() ? ";" : "");
+ if (!ToHoist.empty()) {
+ if (const bool NeedsTupleOrPair = ToHoist.size() > 1; NeedsTupleOrPair) {
+ const auto NeedsPair = ToHoist.size() == 2;
+
+ Body += "\nreturn " +
+ std::string(NeedsPair ? "std::pair{" : "std::tuple{") +
+ renderHoistSet(ToHoist) + "};";
+ } else {
+ Body += "\nreturn " + renderHoistSet(ToHoist) + ";";
+ }
+ }
+ return Body;
}
std::string NewFunction::Parameter::render(const DeclContext *Context) const {
@@ -675,10 +752,6 @@
const auto &DeclInfo = KeyVal.second;
// If a Decl was Declared in zone and referenced in post zone, it
// needs to be hoisted (we bail out in that case).
- // FIXME: Support Decl Hoisting.
- if (DeclInfo.DeclaredIn == ZoneRelative::Inside &&
- DeclInfo.IsReferencedInPostZone)
- return false;
if (!DeclInfo.IsReferencedInZone)
continue; // no need to pass as parameter, not referenced
if (DeclInfo.DeclaredIn == ZoneRelative::Inside ||
@@ -724,6 +797,19 @@
return SemicolonPolicy;
}
+// Returns the return type of an extracted function that returns its hoisted
+// decls. A single hoisted VarDecl is returned by value as its own type. Any
+// reference is dropped: the caller re-declares the variable with plain `auto`,
+// which copies anyway, and returning a reference would dangle whenever it is
+// bound to a local of the extracted function. In every other case (several
+// decls, or a BindingDecl) the return type is left to `auto` deduction.
+QualType getReturnTypeForHoisted(const FunctionDecl &EnclosingFunc,
+                                 const HoistSet &ToHoist) {
+  if (ToHoist.size() == 1) {
+    if (const auto *VDecl = llvm::dyn_cast<VarDecl>(*ToHoist.begin()))
+      return VDecl->getType().getNonReferenceType();
+  }
+  return EnclosingFunc.getParentASTContext().getAutoDeductType();
+}
+
// Generate return type for ExtractedFunc. Return false if unable to do so.
bool generateReturnProperties(NewFunction &ExtractedFunc,
const FunctionDecl &EnclosingFunc,
@@ -745,7 +831,11 @@
return true;
}
// FIXME: Generate new return statement if needed.
- ExtractedFunc.ReturnType = EnclosingFunc.getParentASTContext().VoidTy;
+ ExtractedFunc.ReturnType =
+ !ExtractedFunc.ToHoist.empty()
+ ? getReturnTypeForHoisted(EnclosingFunc, ExtractedFunc.ToHoist)
+ : EnclosingFunc.getParentASTContext().VoidTy;
+
return true;
}
@@ -759,6 +849,7 @@
// FIXME: add support for adding other function return types besides void.
// FIXME: assign the value returned by non void extracted function.
llvm::Expected<NewFunction> getExtractedFunction(ExtractionZone &ExtZone,
+ const HoistSet &ToHoist,
const SourceManager &SM,
const LangOptions &LangOpts) {
CapturedZoneInfo CapturedInfo = captureZoneInfo(ExtZone);
@@ -766,7 +857,7 @@
if (CapturedInfo.BrokenControlFlow)
return error("Cannot extract break/continue without corresponding "
"loop/switch statement.");
- NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts),
+ NewFunction ExtractedFunc(ToHoist, getSemicolonPolicy(ExtZone, SM, LangOpts),
&LangOpts);
ExtractedFunc.SyntacticDC =
@@ -815,6 +906,7 @@
private:
ExtractionZone ExtZone;
+ HoistSet ToHoist;
};
REGISTER_TWEAK(ExtractFunction)
@@ -880,8 +972,12 @@
(hasReturnStmt(*MaybeExtZone) && !alwaysReturns(*MaybeExtZone)))
return false;
- // FIXME: Get rid of this check once we support hoisting.
- if (MaybeExtZone->requiresHoisting(SM, Inputs.AST->getHeuristicResolver()))
+ ToHoist =
+ MaybeExtZone->getDeclsToHoist(SM, Inputs.AST->getHeuristicResolver());
+
+ const auto HasAutoReturnTypeDeduction = LangOpts.CPlusPlus14;
+ const auto RequiresPairOrTuple = ToHoist.size() > 1;
+ if (RequiresPairOrTuple && !HasAutoReturnTypeDeduction)
return false;
ExtZone = std::move(*MaybeExtZone);
@@ -891,7 +987,7 @@
Expected<Tweak::Effect> ExtractFunction::apply(const Selection &Inputs) {
const SourceManager &SM = Inputs.AST->getSourceManager();
const LangOptions &LangOpts = Inputs.AST->getLangOpts();
- auto ExtractedFunc = getExtractedFunction(ExtZone, SM, LangOpts);
+ auto ExtractedFunc = getExtractedFunction(ExtZone, ToHoist, SM, LangOpts);
// FIXME: Add more types of errors.
if (!ExtractedFunc)
return ExtractedFunc.takeError();
@@ -914,8 +1010,8 @@
tooling::Replacements OtherEdit(
createForwardDeclaration(*ExtractedFunc, SM));
- if (auto PathAndEdit = Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc),
- OtherEdit))
+ if (auto PathAndEdit =
+ Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc), OtherEdit))
MultiFileEffect->ApplyEdits.try_emplace(PathAndEdit->first,
PathAndEdit->second);
else
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits