ioeric updated this revision to Diff 163351.
ioeric added a comment.

- Return all possible #includes in CodeComplete.


Repository:
  rCTE Clang Tools Extra

https://reviews.llvm.org/D51291

Files:
  clangd/CodeComplete.cpp
  clangd/CodeComplete.h
  clangd/index/Index.cpp
  clangd/index/Index.h
  clangd/index/Merge.cpp
  clangd/index/SymbolCollector.cpp
  clangd/index/SymbolYAML.cpp
  unittests/clangd/CodeCompleteTests.cpp
  unittests/clangd/FileIndexTests.cpp
  unittests/clangd/IndexTests.cpp
  unittests/clangd/SymbolCollectorTests.cpp

Index: unittests/clangd/SymbolCollectorTests.cpp
===================================================================
--- unittests/clangd/SymbolCollectorTests.cpp
+++ unittests/clangd/SymbolCollectorTests.cpp
@@ -53,7 +53,11 @@
 MATCHER_P(DeclURI, P, "") { return arg.CanonicalDeclaration.FileURI == P; }
 MATCHER_P(DefURI, P, "") { return arg.Definition.FileURI == P; }
 MATCHER_P(IncludeHeader, P, "") {
-  return arg.Detail && arg.Detail->IncludeHeader == P;
+  return (arg.IncludeHeaders.size() == 1) &&
+         (arg.IncludeHeaders.begin()->IncludeHeader == P);
+}
+MATCHER_P2(IncludeHeaderWithRef, IncludeHeader, References,  "") {
+  return (arg.IncludeHeader == IncludeHeader) && (arg.References == References);
 }
 MATCHER_P(DeclRange, Pos, "") {
   return std::tie(arg.CanonicalDeclaration.Start.Line,
@@ -696,6 +700,11 @@
     Line: 1
     Column: 1
 IsIndexedForCodeCompletion:    true
+IncludeHeaders:
+  - Header:    'include1'
+    References:    7
+  - Header:    'include2'
+    References:    3
 Detail:
   Documentation:    'Foo doc'
   ReturnType:    'int'
@@ -730,6 +739,11 @@
                                          Doc("Foo doc"), ReturnType("int"),
                                          DeclURI("file:///path/foo.h"),
                                          ForCodeCompletion(true))));
+  auto &Sym1 = *Symbols1.begin();
+  assert(Sym1.Detail);
+  EXPECT_THAT(Sym1.IncludeHeaders,
+              UnorderedElementsAre(IncludeHeaderWithRef("include1", 7u),
+                                   IncludeHeaderWithRef("include2", 3u)));
   auto Symbols2 = symbolsFromYAML(YAML2);
   EXPECT_THAT(Symbols2, UnorderedElementsAre(AllOf(
                             QName("clang::Foo2"), Labeled("Foo2-sig"),
@@ -751,9 +765,10 @@
 TEST_F(SymbolCollectorTest, IncludeHeaderSameAsFileURI) {
   CollectorOpts.CollectIncludePath = true;
   runSymbolCollector("class Foo {};", /*Main=*/"");
-  EXPECT_THAT(Symbols,
-              UnorderedElementsAre(AllOf(QName("Foo"), DeclURI(TestHeaderURI),
-                                         IncludeHeader(TestHeaderURI))));
+  EXPECT_THAT(Symbols, UnorderedElementsAre(
+                           AllOf(QName("Foo"), DeclURI(TestHeaderURI))));
+  EXPECT_THAT(Symbols.begin()->IncludeHeaders,
+              UnorderedElementsAre(IncludeHeaderWithRef(TestHeaderURI, 1u)));
 }
 
 #ifndef _WIN32
Index: unittests/clangd/IndexTests.cpp
===================================================================
--- unittests/clangd/IndexTests.cpp
+++ unittests/clangd/IndexTests.cpp
@@ -243,6 +243,48 @@
   EXPECT_EQ(M.Name, "right");
 }
 
+MATCHER_P2(IncludeHeaderWithRef, IncludeHeader, References,  "") {
+  return (arg.IncludeHeader == IncludeHeader) && (arg.References == References);
+}
+
+TEST(MergeTest, MergeIncludesOnDifferentDefinitions) {
+  Symbol L, R;
+  L.Name = "left";
+  R.Name = "right";
+  L.ID = R.ID = SymbolID("hello");
+  L.IncludeHeaders.emplace_back("common", 1);
+  R.IncludeHeaders.emplace_back("common", 1);
+  R.IncludeHeaders.emplace_back("new", 1);
+
+  Symbol::Details Scratch;
+
+  // Both have no definition.
+  Symbol M = mergeSymbol(L, R, &Scratch);
+  EXPECT_THAT(M.IncludeHeaders,
+              UnorderedElementsAre(IncludeHeaderWithRef("common", 2u),
+                                   IncludeHeaderWithRef("new", 1u)));
+
+  // Only merge references of the same includes but do not merge new #includes.
+  L.Definition.FileURI = "file:/left.h";
+  M = mergeSymbol(L, R, &Scratch);
+  EXPECT_THAT(M.IncludeHeaders,
+              UnorderedElementsAre(IncludeHeaderWithRef("common", 2u)));
+
+  // Both have definitions.
+  R.Definition.FileURI = "file:/right.h";
+  M = mergeSymbol(L, R, &Scratch);
+  EXPECT_THAT(M.IncludeHeaders,
+              UnorderedElementsAre(IncludeHeaderWithRef("common", 2u),
+                                   IncludeHeaderWithRef("new", 1u)));
+
+  // Definitions are different.
+  R.Definition.FileURI = "file:/right.h";
+  M = mergeSymbol(L, R, &Scratch);
+  EXPECT_THAT(M.IncludeHeaders,
+              UnorderedElementsAre(IncludeHeaderWithRef("common", 2u),
+                                   IncludeHeaderWithRef("new", 1u)));
+}
+
 } // namespace
 } // namespace clangd
 } // namespace clang
Index: unittests/clangd/FileIndexTests.cpp
===================================================================
--- unittests/clangd/FileIndexTests.cpp
+++ unittests/clangd/FileIndexTests.cpp
@@ -177,7 +177,7 @@
   Req.Query = "";
   bool SeenSymbol = false;
   M.fuzzyFind(Req, [&](const Symbol &Sym) {
-    EXPECT_TRUE(Sym.Detail->IncludeHeader.empty());
+    EXPECT_TRUE(Sym.IncludeHeaders.empty());
     SeenSymbol = true;
   });
   EXPECT_TRUE(SeenSymbol);
Index: unittests/clangd/CodeCompleteTests.cpp
===================================================================
--- unittests/clangd/CodeCompleteTests.cpp
+++ unittests/clangd/CodeCompleteTests.cpp
@@ -55,10 +55,16 @@
 MATCHER_P(Kind, K, "") { return arg.Kind == K; }
 MATCHER_P(Doc, D, "") { return arg.Documentation == D; }
 MATCHER_P(ReturnType, D, "") { return arg.ReturnType == D; }
+MATCHER_P(HasInclude, IncludeHeader, "") {
+  return !arg.Includes.empty() && arg.Includes[0].Header == IncludeHeader;
+}
 MATCHER_P(InsertInclude, IncludeHeader, "") {
-  return arg.Header == IncludeHeader && bool(arg.HeaderInsertion);
+  return !arg.Includes.empty() && arg.Includes[0].Header == IncludeHeader &&
+         bool(arg.Includes[0].Insertion);
+}
+MATCHER(InsertInclude, "") {
+  return !arg.Includes.empty() && bool(arg.Includes[0].Insertion);
 }
-MATCHER(InsertInclude, "") { return bool(arg.HeaderInsertion); }
 MATCHER_P(SnippetSuffix, Text, "") { return arg.SnippetSuffix == Text; }
 MATCHER_P(Origin, OriginSet, "") { return arg.Origin == OriginSet; }
 
@@ -568,7 +574,8 @@
   auto BarURI = URI::createFile(BarHeader).toString();
   Symbol Sym = cls("ns::X");
   Sym.CanonicalDeclaration.FileURI = BarURI;
-  Scratch.IncludeHeader = BarURI;
+  Sym.IncludeHeaders.emplace_back(BarURI, 1);
+
   Sym.Detail = &Scratch;
   // Shoten include path based on search dirctory and insert.
   auto Results = completions(Server,
@@ -602,7 +609,10 @@
   auto BarURI = URI::createFile(BarHeader).toString();
   SymX.CanonicalDeclaration.FileURI = BarURI;
   SymY.CanonicalDeclaration.FileURI = BarURI;
-  Scratch.IncludeHeader = "<bar>";
+
+  SymX.IncludeHeaders.emplace_back("<bar>", 1);
+  SymY.IncludeHeaders.emplace_back("<bar>", 1);
+
   SymX.Detail = &Scratch;
   SymY.Detail = &Scratch;
   // Shoten include path based on search dirctory and insert.
@@ -1130,7 +1140,8 @@
   Symbol::Details Detail;
   std::string DeclFile = URI::createFile(testPath("foo")).toString();
   NoArgsGFunc.CanonicalDeclaration.FileURI = DeclFile;
-  Detail.IncludeHeader = "<foo>";
+  NoArgsGFunc.IncludeHeaders.emplace_back("<foo>", 1);
+
   NoArgsGFunc.Detail = &Detail;
   EXPECT_THAT(
       completions(Context + "int y = GFunc^", {NoArgsGFunc}, Opts).Completions,
@@ -1298,7 +1309,9 @@
   C.RequiredQualifier = "Foo::";
   C.Scope = "ns::Foo::";
   C.Documentation = "This is x().";
-  C.Header = "\"foo.h\"";
+  C.Includes.emplace_back();
+  auto &Include = C.Includes.back();
+  Include.Header = "\"foo.h\"";
   C.Kind = CompletionItemKind::Method;
   C.Score.Total = 1.0;
   C.Origin = SymbolOrigin::AST | SymbolOrigin::Static;
@@ -1323,7 +1336,7 @@
   EXPECT_EQ(R.insertText, "Foo::x(${0:bool})");
   EXPECT_EQ(R.insertTextFormat, InsertTextFormat::Snippet);
 
-  C.HeaderInsertion.emplace();
+  Include.Insertion.emplace();
   R = C.render(Opts);
   EXPECT_EQ(R.label, "^Foo::x(bool) const");
   EXPECT_THAT(R.additionalTextEdits, Not(IsEmpty()));
@@ -1782,6 +1795,41 @@
   ASSERT_EQ(Reqs3.size(), 2u);
 }
 
+TEST(CompletionTest, InsertTheMostPopularHeader) {
+  std::string DeclFile = URI::createFile(testPath("foo")).toString();
+  Symbol sym = func("Func");
+  sym.CanonicalDeclaration.FileURI = DeclFile;
+  sym.IncludeHeaders.emplace_back("\"foo.h\"", 2);
+  sym.IncludeHeaders.emplace_back("\"bar.h\"", 1000);
+
+  auto Results = completions("Fun^", {sym}).Completions;
+  assert(!Results.empty());
+  EXPECT_THAT(Results[0], AllOf(Named("Func"), InsertInclude("\"bar.h\"")));
+  EXPECT_EQ(Results[0].Includes.size(), 2u);
+}
+
+TEST(CompletionTest, NoInsertIncludeIfOnePresent) {
+  MockFSProvider FS;
+  MockCompilationDatabase CDB;
+
+  std::string FooHeader = testPath("foo.h");
+  FS.Files[FooHeader] = "";
+
+  IgnoreDiagnostics DiagConsumer;
+  ClangdServer Server(CDB, FS, DiagConsumer, ClangdServer::optsForTest());
+
+  std::string DeclFile = URI::createFile(testPath("foo")).toString();
+  Symbol sym = func("Func");
+  sym.CanonicalDeclaration.FileURI = DeclFile;
+  sym.IncludeHeaders.emplace_back("\"foo.h\"", 2);
+  sym.IncludeHeaders.emplace_back("\"bar.h\"", 1000);
+
+  EXPECT_THAT(
+      completions(Server, "#include \"foo.h\"\nFun^", {sym}).Completions,
+      UnorderedElementsAre(
+          AllOf(Named("Func"), HasInclude("\"foo.h\""), Not(InsertInclude()))));
+}
+
 } // namespace
 } // namespace clangd
 } // namespace clang
Index: clangd/index/SymbolYAML.cpp
===================================================================
--- clangd/index/SymbolYAML.cpp
+++ clangd/index/SymbolYAML.cpp
@@ -10,11 +10,13 @@
 #include "SymbolYAML.h"
 #include "Index.h"
 #include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Errc.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/raw_ostream.h"
 
 LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(clang::clangd::Symbol)
+LLVM_YAML_IS_SEQUENCE_VECTOR(clang::clangd::Symbol::IncludeHeaderWithReferences)
 
 namespace llvm {
 namespace yaml {
@@ -70,7 +72,15 @@
   static void mapping(IO &io, Symbol::Details &Detail) {
     io.mapOptional("Documentation", Detail.Documentation);
     io.mapOptional("ReturnType", Detail.ReturnType);
-    io.mapOptional("IncludeHeader", Detail.IncludeHeader);
+  }
+};
+
+template <>
+struct MappingTraits<clang::clangd::Symbol::IncludeHeaderWithReferences> {
+  static void mapping(IO &io,
+                      clang::clangd::Symbol::IncludeHeaderWithReferences &Inc) {
+    io.mapRequired("Header", Inc.IncludeHeader);
+    io.mapRequired("References", Inc.References);
   }
 };
 
@@ -112,6 +122,7 @@
                    false);
     IO.mapOptional("Signature", Sym.Signature);
     IO.mapOptional("CompletionSnippetSuffix", Sym.CompletionSnippetSuffix);
+    IO.mapOptional("IncludeHeaders", Sym.IncludeHeaders);
     IO.mapOptional("Detail", NDetail->Opt);
   }
 };
Index: clangd/index/SymbolCollector.cpp
===================================================================
--- clangd/index/SymbolCollector.cpp
+++ clangd/index/SymbolCollector.cpp
@@ -403,10 +403,12 @@
                              SM.getExpansionLoc(MI->getDefinitionLoc()), Opts))
       Include = std::move(*Header);
   }
+  if (!Include.empty())
+    S.IncludeHeaders.emplace_back(Include, 1);
+
   S.Signature = Signature;
   S.CompletionSnippetSuffix = SnippetSuffix;
   Symbol::Details Detail;
-  Detail.IncludeHeader = Include;
   S.Detail = &Detail;
   Symbols.insert(S);
   return true;
@@ -489,7 +491,8 @@
   Symbol::Details Detail;
   Detail.Documentation = Documentation;
   Detail.ReturnType = ReturnType;
-  Detail.IncludeHeader = Include;
+  if (!Include.empty())
+    S.IncludeHeaders.emplace_back(Include, 1);
   S.Detail = &Detail;
 
   S.Origin = Opts.Origin;
Index: clangd/index/Merge.cpp
===================================================================
--- clangd/index/Merge.cpp
+++ clangd/index/Merge.cpp
@@ -100,6 +100,10 @@
   // Classes: this is the def itself. Functions: hopefully the header decl.
   // If both did (or both didn't), continue to prefer L over R.
   bool PreferR = R.Definition && !L.Definition;
+  // Merge include headers only if both have definitions or both have no
+  // definition; otherwise, only accumulate references of common includes.
+  bool MergeIncludes =
+      L.Definition.FileURI.empty() == R.Definition.FileURI.empty();
   Symbol S = PreferR ? R : L;        // The target symbol we're merging into.
   const Symbol &O = PreferR ? L : R; // The "other" less-preferred symbol.
 
@@ -114,6 +118,18 @@
     S.Signature = O.Signature;
   if (S.CompletionSnippetSuffix == "")
     S.CompletionSnippetSuffix = O.CompletionSnippetSuffix;
+  for (const auto &OI : O.IncludeHeaders) {
+    bool Found = false;
+    for (auto &SI : S.IncludeHeaders) {
+      if (SI.IncludeHeader == OI.IncludeHeader) {
+        Found = true;
+        SI.References += OI.References;
+        break;
+      }
+    }
+    if (!Found && MergeIncludes)
+      S.IncludeHeaders.emplace_back(OI.IncludeHeader, OI.References);
+  }
 
   if (O.Detail) {
     if (S.Detail) {
@@ -123,8 +139,6 @@
         Scratch->Documentation = O.Detail->Documentation;
       if (Scratch->ReturnType == "")
         Scratch->ReturnType = O.Detail->ReturnType;
-      if (Scratch->IncludeHeader == "")
-        Scratch->IncludeHeader = O.Detail->IncludeHeader;
       S.Detail = Scratch;
     } else
       S.Detail = O.Detail;
Index: clangd/index/Index.h
===================================================================
--- clangd/index/Index.h
+++ clangd/index/Index.h
@@ -16,7 +16,9 @@
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/Support/StringSaver.h"
 #include <array>
 #include <string>
@@ -200,6 +202,30 @@
   /// (When snippets are disabled, the symbol name alone is used).
   llvm::StringRef CompletionSnippetSuffix;
 
+  struct IncludeHeaderWithReferences {
+    IncludeHeaderWithReferences() = default;
+
+    IncludeHeaderWithReferences(llvm::StringRef IncludeHeader,
+                                unsigned References)
+        : IncludeHeader(IncludeHeader), References(References) {}
+
+    /// This can be either a URI of the header to be #include'd
+    /// for this symbol, or a literal header quoted with <> or "" that is
+    /// suitable to be included directly. When it is a URI, the exact #include
+    /// path needs to be calculated according to the URI scheme.
+    ///
+    /// Note that the include header is a canonical include for the symbol and
+    /// can be different from FileURI in the CanonicalDeclaration.
+    llvm::StringRef IncludeHeader = "";
+    /// The number of translation units that reference this symbol and include
+    /// this header. This number is only meaningful if aggregated in an index.
+    unsigned References = 0;
+  };
+  /// One Symbol can potentially be included via different headers.
+  ///   - If we haven't seen a definition, this covers all declarations.
+  ///   - If we have seen a definition, this covers declarations visible from
+  ///   any definition.
+  llvm::SmallVector<IncludeHeaderWithReferences, 1> IncludeHeaders;
   /// Optional symbol details that are not required to be set. For example, an
   /// index fuzzy match can return a large number of symbol candidates, and it
   /// is preferable to send only core symbol information in the batched results
@@ -211,14 +237,6 @@
     /// Type when this symbol is used in an expression. (Short display form).
     /// e.g. return type of a function, or type of a variable.
     llvm::StringRef ReturnType;
-    /// This can be either a URI of the header to be #include'd for this symbol,
-    /// or a literal header quoted with <> or "" that is suitable to be included
-    /// directly. When this is a URI, the exact #include path needs to be
-    /// calculated according to the URI scheme.
-    ///
-    /// This is a canonical include for the symbol and can be different from
-    /// FileURI in the CanonicalDeclaration.
-    llvm::StringRef IncludeHeader;
   };
 
   // Optional details of the symbol.
Index: clangd/index/Index.cpp
===================================================================
--- clangd/index/Index.cpp
+++ clangd/index/Index.cpp
@@ -9,6 +9,7 @@
 
 #include "Index.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/Support/SHA1.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -91,14 +92,16 @@
   Intern(S.Signature);
   Intern(S.CompletionSnippetSuffix);
 
+  for (auto &I : S.IncludeHeaders)
+      Intern(I.IncludeHeader);
+
   if (S.Detail) {
     // Copy values of StringRefs into arena.
     auto *Detail = Arena.Allocate<Symbol::Details>();
     *Detail = *S.Detail;
     // Intern the actual strings.
     Intern(Detail->Documentation);
     Intern(Detail->ReturnType);
-    Intern(Detail->IncludeHeader);
     // Replace the detail pointer with our copy.
     S.Detail = Detail;
   }
Index: clangd/CodeComplete.h
===================================================================
--- clangd/CodeComplete.h
+++ clangd/CodeComplete.h
@@ -26,6 +26,7 @@
 #include "clang/Sema/CodeCompleteOptions.h"
 #include "clang/Tooling/CompilationDatabase.h"
 #include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Error.h"
 #include <future>
@@ -131,12 +132,18 @@
   // Other fields should apply equally to all bundled completions.
   unsigned BundleSize = 1;
   SymbolOrigin Origin = SymbolOrigin::Unknown;
-  // The header through which this symbol could be included.
-  // Quoted string as expected by an #include directive, e.g. "<memory>".
-  // Empty for non-symbol completions, or when not known.
-  std::string Header;
-  // Present if Header is set and should be inserted to use this item.
-  llvm::Optional<TextEdit> HeaderInsertion;
+
+  struct IncludeCandidate {
+    // The header through which this symbol could be included.
+    // Quoted string as expected by an #include directive, e.g. "<memory>".
+    // Empty for non-symbol completions, or when not known.
+    std::string Header;
+    // Present if Header should be inserted to use this item.
+    llvm::Optional<TextEdit> Insertion;
+  };
+  // All possible include headers ranked by preference. By default, the first
+  // include is used.
+  llvm::SmallVector<IncludeCandidate, 1> Includes;
 
   /// Holds information about small corrections that needs to be done. Like
   /// converting '->' to '.' on member access.
Index: clangd/CodeComplete.cpp
===================================================================
--- clangd/CodeComplete.cpp
+++ clangd/CodeComplete.cpp
@@ -44,10 +44,13 @@
 #include "clang/Sema/Sema.h"
 #include "clang/Tooling/Core/Replacement.h"
 #include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Error.h"
 #include "llvm/Support/Format.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/ScopedPrinter.h"
+#include <algorithm>
+#include <iterator>
 #include <queue>
 
 // We log detailed candidate here if you run with -debug-only=codecomplete.
@@ -247,6 +250,7 @@
   // We may have a result from Sema, from the index, or both.
   const CodeCompletionResult *SemaResult = nullptr;
   const Symbol *IndexResult = nullptr;
+  llvm::SmallVector<StringRef, 1> RankedIncludeHeaders;
 
   // States whether this item is an override suggestion.
   bool IsOverride = false;
@@ -267,7 +271,7 @@
         // This could break #include insertion.
         return hash_combine(
             (IndexResult->Scope + IndexResult->Name).toStringRef(Scratch),
-            headerToInsertIfNotPresent().getValueOr(""));
+            headerToInsertIfAllowed().getValueOr(""));
       default:
         return 0;
       }
@@ -281,12 +285,12 @@
       llvm::raw_svector_ostream OS(Scratch);
       D->printQualifiedName(OS);
     }
-    return hash_combine(Scratch, headerToInsertIfNotPresent().getValueOr(""));
+    return hash_combine(Scratch, headerToInsertIfAllowed().getValueOr(""));
   }
 
-  llvm::Optional<llvm::StringRef> headerToInsertIfNotPresent() const {
-    if (!IndexResult || !IndexResult->Detail ||
-        IndexResult->Detail->IncludeHeader.empty())
+  // The best header to include if include insertion is allowed.
+  llvm::Optional<llvm::StringRef> headerToInsertIfAllowed() const {
+    if (RankedIncludeHeaders.empty())
       return llvm::None;
     if (SemaResult && SemaResult->Declaration) {
       // Avoid inserting new #include if the declaration is found in the current
@@ -296,7 +300,7 @@
         if (SM.isInMainFile(SM.getExpansionLoc(RD->getBeginLoc())))
           return llvm::None;
     }
-    return IndexResult->Detail->IncludeHeader;
+    return RankedIncludeHeaders[0];
   }
 
   using Bundle = llvm::SmallVector<CompletionCandidate, 4>;
@@ -359,31 +363,42 @@
       if (Completion.Name.empty())
         Completion.Name = C.IndexResult->Name;
     }
-    if (auto Inserted = C.headerToInsertIfNotPresent()) {
-      // Turn absolute path into a literal string that can be #included.
-      auto Include = [&]() -> Expected<std::pair<std::string, bool>> {
-        auto ResolvedDeclaring =
-            toHeaderFile(C.IndexResult->CanonicalDeclaration.FileURI, FileName);
-        if (!ResolvedDeclaring)
-          return ResolvedDeclaring.takeError();
-        auto ResolvedInserted = toHeaderFile(*Inserted, FileName);
-        if (!ResolvedInserted)
-          return ResolvedInserted.takeError();
-        return std::make_pair(Includes.calculateIncludePath(*ResolvedDeclaring,
-                                                            *ResolvedInserted),
-                              Includes.shouldInsertInclude(*ResolvedDeclaring,
-                                                           *ResolvedInserted));
-      }();
-      if (Include) {
-        Completion.Header = Include->first;
-        if (Include->second)
-          Completion.HeaderInsertion = Includes.insert(Include->first);
+
+    // Turn absolute path into a literal string that can be #included.
+    auto Inserted =
+        [&](StringRef Header) -> Expected<std::pair<std::string, bool>> {
+      auto ResolvedDeclaring =
+          toHeaderFile(C.IndexResult->CanonicalDeclaration.FileURI, FileName);
+      if (!ResolvedDeclaring)
+        return ResolvedDeclaring.takeError();
+      auto ResolvedInserted = toHeaderFile(Header, FileName);
+      if (!ResolvedInserted)
+        return ResolvedInserted.takeError();
+      return std::make_pair(
+          Includes.calculateIncludePath(*ResolvedDeclaring, *ResolvedInserted),
+          Includes.shouldInsertInclude(*ResolvedDeclaring, *ResolvedInserted));
+    };
+    bool ShouldInsert = C.headerToInsertIfAllowed().hasValue();
+    // Calculate include paths and edits for all possible headers.
+    for (const auto &Inc : C.RankedIncludeHeaders) {
+      if (auto ToInclude = Inserted(Inc)) {
+        CodeCompletion::IncludeCandidate Include;
+        Include.Header = ToInclude->first;
+        if (ToInclude->second && ShouldInsert)
+          Include.Insertion = Includes.insert(ToInclude->first);
+        Completion.Includes.push_back(std::move(Include));
       } else
         log("Failed to generate include insertion edits for adding header "
             "(FileURI='{0}', IncludeHeader='{1}') into {2}",
-            C.IndexResult->CanonicalDeclaration.FileURI,
-            C.IndexResult->Detail->IncludeHeader, FileName);
+            C.IndexResult->CanonicalDeclaration.FileURI, Inc, FileName);
     }
+    // Prefer includes that do not need edits (i.e. already exist).
+    std::stable_sort(Completion.Includes.begin(), Completion.Includes.end(),
+                     [](const CodeCompletion::IncludeCandidate &LHS,
+                        const CodeCompletion::IncludeCandidate &RHS) {
+                       return static_cast<int>(LHS.Insertion.hasValue()) <
+                              static_cast<int>(RHS.Insertion.hasValue());
+                     });
   }
 
   void add(const CompletionCandidate &C, CodeCompletionString *SemaCCS) {
@@ -1128,6 +1143,26 @@
   return Result;
 }
 
+// Returns all include headers of \p Sym, most popular first. If two headers
+// are equally popular, the shorter one is ranked higher. Returns an empty
+// list if \p Sym has no include header.
+llvm::SmallVector<StringRef, 1>
+getRankedIncludes(const Symbol &Sym) {
+  auto Includes = Sym.IncludeHeaders;
+  // Sort by descending reference count, then by ascending header length.
+  std::sort(Includes.begin(), Includes.end(),
+            [](const Symbol::IncludeHeaderWithReferences &LHS,
+               const Symbol::IncludeHeaderWithReferences &RHS) {
+              if (LHS.References == RHS.References)
+                return LHS.IncludeHeader.size() < RHS.IncludeHeader.size();
+              return LHS.References > RHS.References;
+            });
+  llvm::SmallVector<StringRef, 1> Headers;
+  for (const auto &Include : Includes)
+    Headers.push_back(Include.IncludeHeader);
+  return Headers;
+}
+
 // Runs Sema-based (AST) and Index-based completion, returns merged results.
 //
 // There are a few tricky considerations:
@@ -1376,6 +1411,8 @@
       CompletionCandidate C;
       C.SemaResult = SemaResult;
       C.IndexResult = IndexResult;
+      if (C.IndexResult)
+        C.RankedIncludeHeaders = getRankedIncludes(*C.IndexResult);
       C.IsOverride = IsOverride;
       C.Name = IndexResult ? IndexResult->Name : Recorder->getName(*SemaResult);
       if (auto OverloadSet = Opts.BundleOverloads ? C.overloadSet() : 0) {
@@ -1569,16 +1606,18 @@
 
 CompletionItem CodeCompletion::render(const CodeCompleteOptions &Opts) const {
   CompletionItem LSP;
-  LSP.label = (HeaderInsertion ? Opts.IncludeIndicator.Insert
-                               : Opts.IncludeIndicator.NoInsert) +
+  const auto *InsertInclude = Includes.empty() ? nullptr : &Includes[0];
+  LSP.label = ((InsertInclude && InsertInclude->Insertion)
+                   ? Opts.IncludeIndicator.Insert
+                   : Opts.IncludeIndicator.NoInsert) +
               (Opts.ShowOrigins ? "[" + llvm::to_string(Origin) + "]" : "") +
               RequiredQualifier + Name + Signature;
 
   LSP.kind = Kind;
   LSP.detail = BundleSize > 1 ? llvm::formatv("[{0} overloads]", BundleSize)
                               : ReturnType;
-  if (!Header.empty())
-    LSP.detail += "\n" + Header;
+  if (InsertInclude)
+    LSP.detail += "\n" + InsertInclude->Header;
   LSP.documentation = Documentation;
   LSP.sortText = sortText(Score.Total, Name);
   LSP.filterText = Name;
@@ -1606,8 +1645,8 @@
   LSP.insertText = LSP.textEdit->newText;
   LSP.insertTextFormat = Opts.EnableSnippets ? InsertTextFormat::Snippet
                                              : InsertTextFormat::PlainText;
-  if (HeaderInsertion)
-    LSP.additionalTextEdits.push_back(*HeaderInsertion);
+  if (InsertInclude && InsertInclude->Insertion)
+    LSP.additionalTextEdits.push_back(*InsertInclude->Insertion);
   return LSP;
 }
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to