kadircet updated this revision to Diff 487721.
kadircet marked 5 inline comments as done.
kadircet added a comment.

- Address all comments but the ones on tests.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D139921/new/

https://reviews.llvm.org/D139921

Files:
  clang-tools-extra/include-cleaner/include/clang-include-cleaner/Types.h
  clang-tools-extra/include-cleaner/lib/Analysis.cpp
  clang-tools-extra/include-cleaner/lib/AnalysisInternal.h
  clang-tools-extra/include-cleaner/lib/FindHeaders.cpp
  clang-tools-extra/include-cleaner/lib/HTMLReport.cpp
  clang-tools-extra/include-cleaner/lib/LocateSymbol.cpp
  clang-tools-extra/include-cleaner/lib/Types.cpp
  clang-tools-extra/include-cleaner/unittests/AnalysisTest.cpp
  clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp

Index: clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp
===================================================================
--- clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp
+++ clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp
@@ -7,7 +7,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "AnalysisInternal.h"
-#include "clang-include-cleaner/Analysis.h"
 #include "clang-include-cleaner/Record.h"
 #include "clang-include-cleaner/Types.h"
 #include "clang/AST/RecursiveASTVisitor.h"
@@ -15,15 +14,14 @@
 #include "clang/Basic/FileManager.h"
 #include "clang/Frontend/FrontendActions.h"
 #include "clang/Testing/TestAST.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/Support/raw_ostream.h"
-#include "llvm/Testing/Support/Annotations.h"
+#include "llvm/ADT/SmallVector.h"
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 #include <memory>
 
 namespace clang::include_cleaner {
 namespace {
+using testing::ElementsAre;
 using testing::UnorderedElementsAre;
 
 std::string guard(llvm::StringRef Code) {
@@ -53,11 +51,12 @@
   void buildAST() { AST = std::make_unique<TestAST>(Inputs); }
 
   llvm::SmallVector<Header> findHeaders(llvm::StringRef FileName) {
-    return include_cleaner::findHeaders(
+    auto Found = include_cleaner::findHeaders(
         AST->sourceManager().translateFileLineCol(
             AST->fileManager().getFile(FileName).get(),
             /*Line=*/1, /*Col=*/1),
         AST->sourceManager(), &PI);
+    return llvm::SmallVector<Header>(Found.begin(), Found.end());
   }
   const FileEntry *physicalHeader(llvm::StringRef FileName) {
     return AST->fileManager().getFile(FileName).get();
@@ -207,12 +206,166 @@
     CustomVisitor Visitor;
     Visitor.TraverseDecl(AST->context().getTranslationUnitDecl());
 
-    llvm::SmallVector<Header> Headers = clang::include_cleaner::findHeaders(
+    auto Headers = clang::include_cleaner::findHeaders(
         Visitor.Out->getLocation(), AST->sourceManager(),
         /*PragmaIncludes=*/nullptr);
     EXPECT_THAT(Headers, UnorderedElementsAre(physicalHeader("declare.h")));
   }
 }
 
+class HeadersForSymbolTest : public FindHeadersTest {
+protected:
+  llvm::SmallVector<Header> headersForFoo() {
+    struct Visitor : public RecursiveASTVisitor<Visitor> {
+      const NamedDecl *Out = nullptr;
+      bool VisitNamedDecl(const NamedDecl *ND) {
+        if (ND->getIdentifier() && ND->getName() == "foo") {
+          EXPECT_TRUE(Out == nullptr || Out == ND->getCanonicalDecl())
+              << "Found multiple matches for foo.";
+          Out = cast<NamedDecl>(ND->getCanonicalDecl());
+        }
+        return true;
+      }
+    } V;
+    V.TraverseDecl(AST->context().getTranslationUnitDecl());
+    if (!V.Out)
+      ADD_FAILURE() << "Couldn't find any decls named foo.";
+    // Bail out rather than dereference null when assertions are compiled out.
+    return V.Out ? headersForSymbol(*V.Out, AST->sourceManager(), &PI)
+                 : llvm::SmallVector<Header>();
+  }
+};
+
+TEST_F(HeadersForSymbolTest, Deduplicates) {
+  // Two redeclarations in one header must not produce duplicate providers.
+  Inputs.ExtraFiles["foo.h"] = guard(R"cpp(
+    // IWYU pragma: private, include "foo.h"
+    void foo();
+    void foo();
+  )cpp");
+  Inputs.Code = R"cpp(
+    #include "foo.h"
+  )cpp";
+  buildAST();
+  auto Providers = headersForFoo();
+  // FIXME: de-duplicate across different kinds.
+  EXPECT_THAT(Providers, UnorderedElementsAre(physicalHeader("foo.h"),
+                                              Header("\"foo.h\"")));
+}
+
+TEST_F(HeadersForSymbolTest, RankingPreservesDiscoveryOrder) {
+  // Both headers provide identical hints; fox.h is included first.
+  for (const char *Name : {"fox.h", "bar.h"})
+    Inputs.ExtraFiles[Name] = guard(R"cpp(
+    void foo();
+  )cpp");
+  Inputs.Code = R"cpp(
+    #include "fox.h"
+    #include "bar.h"
+  )cpp";
+  buildAST();
+  auto Providers = headersForFoo();
+  EXPECT_THAT(Providers,
+              ElementsAre(physicalHeader("fox.h"), physicalHeader("bar.h")));
+}
+
+TEST_F(HeadersForSymbolTest, Ranking) {
+  // Ordered by (preferred, public, complete) hints, most significant first.
+  Inputs.ExtraFiles["private.h"] = guard(R"cpp(
+    // IWYU pragma: private, include "canonical.h"
+    struct foo;
+  )cpp");
+  Inputs.ExtraFiles["public.h"] = guard(R"cpp(
+    struct foo;
+  )cpp");
+  Inputs.ExtraFiles["public_complete.h"] = guard("struct foo {};");
+  Inputs.Code = R"cpp(
+    #include "private.h"
+    #include "public.h"
+    #include "public_complete.h"
+  )cpp";
+  buildAST();
+  EXPECT_THAT(headersForFoo(), ElementsAre(Header("\"canonical.h\""),
+                                           physicalHeader("public_complete.h"),
+                                           physicalHeader("public.h"),
+                                           physicalHeader("private.h")));
+}
+
+TEST_F(HeadersForSymbolTest, PreferPublicOverComplete) {
+  Inputs.ExtraFiles["public.h"] = guard("struct foo;");
+  Inputs.ExtraFiles["complete_private.h"] = guard(R"cpp(
+    // IWYU pragma: private
+    struct foo {};
+  )cpp");
+  Inputs.Code = R"cpp(
+    #include "complete_private.h"
+    #include "public.h"
+  )cpp";
+  buildAST(); // A public forward decl must outrank a private definition.
+  auto Providers = headersForFoo();
+  EXPECT_THAT(Providers, ElementsAre(physicalHeader("public.h"),
+                                     physicalHeader("complete_private.h")));
+}
+
+TEST_F(HeadersForSymbolTest, PreferNameMatch) {
+  Inputs.ExtraFiles["test/foo.proto.h"] = guard("struct foo;");
+  Inputs.ExtraFiles["public_complete.h"] = guard(R"cpp(
+    struct foo {};
+  )cpp");
+  Inputs.Code = R"cpp(
+    #include "public_complete.h"
+    #include "test/foo.proto.h"
+  )cpp";
+  buildAST(); // foo.proto.h matches the name `foo` despite being incomplete.
+  auto Providers = headersForFoo();
+  EXPECT_THAT(Providers, ElementsAre(physicalHeader("test/foo.proto.h"),
+                                     physicalHeader("public_complete.h")));
+}
+
+TEST_F(HeadersForSymbolTest, MainFile) {
+  Inputs.ExtraFiles["public_complete.h"] = guard(R"cpp(
+    struct foo {};
+  )cpp");
+  Inputs.Code = R"cpp(
+    #include "public_complete.h"
+    struct foo;
+  )cpp";
+  buildAST();
+  const SourceManager &SM = AST->sourceManager();
+  const FileEntry *MainFile = SM.getFileEntryForID(SM.getMainFileID());
+  // FIXME: Symbols provided by main file should be treated specially.
+  EXPECT_THAT(headersForFoo(), ElementsAre(physicalHeader("public_complete.h"),
+                                           Header(MainFile)));
+}
+
+TEST_F(HeadersForSymbolTest, PreferExporterOfPrivate) {
+  Inputs.ExtraFiles["exporter.h"] = guard(R"cpp(
+    #include "private.h" // IWYU pragma: export
+  )cpp");
+  Inputs.ExtraFiles["private.h"] = guard(R"cpp(
+    // IWYU pragma: private
+    struct foo {};
+  )cpp");
+  Inputs.Code = R"cpp(
+    #include "private.h"
+    #include "exporter.h"
+  )cpp";
+  buildAST(); // exporter.h is public, so it outranks the private.h it exports.
+  EXPECT_THAT(headersForFoo(), ElementsAre(physicalHeader("exporter.h"),
+                                           physicalHeader("private.h")));
+}
+
+TEST_F(HeadersForSymbolTest, PreferPublicOverNameMatchOnPrivate) {
+  Inputs.ExtraFiles["foo.h"] = guard(R"cpp(
+    // IWYU pragma: private, include "public.h"
+    struct foo {};
+  )cpp");
+  Inputs.Code = R"cpp(
+    #include "foo.h"
+  )cpp";
+  buildAST(); // foo.h matches the name `foo`, but is private.
+  EXPECT_THAT(headersForFoo(),
+              ElementsAre(Header("\"public.h\""), physicalHeader("foo.h")));
+}
 } // namespace
 } // namespace clang::include_cleaner
Index: clang-tools-extra/include-cleaner/unittests/AnalysisTest.cpp
===================================================================
--- clang-tools-extra/include-cleaner/unittests/AnalysisTest.cpp
+++ clang-tools-extra/include-cleaner/unittests/AnalysisTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "clang-include-cleaner/Analysis.h"
+#include "AnalysisInternal.h"
 #include "clang-include-cleaner/Record.h"
 #include "clang-include-cleaner/Types.h"
 #include "clang/AST/ASTContext.h"
@@ -365,7 +366,7 @@
     FileID MainFID = SM.getMainFileID();
     if (RefLoc.isValid()) {
       EXPECT_THAT(RefLoc, AllOf(expandedAt(MainFID, Main.point("expand"), &SM),
-                                 spelledAt(MainFID, Main.point("spell"), &SM)))
+                                spelledAt(MainFID, Main.point("spell"), &SM)))
           << T.Main;
     } else {
       EXPECT_THAT(Main.points(), testing::IsEmpty());
@@ -373,5 +374,17 @@
   }
 }
 
+TEST(Hints, Ordering) {
+  // The payload is irrelevant; only the attached hints take part in ordering.
+  struct Empty {};
+  auto Make = [](Hint H) { return Hinted<Empty>({}, H); };
+  // A more relevant hint must outrank any combination of less relevant ones.
+  EXPECT_LT(Make(Hint::None), Make(Hint::CompleteSymbol));
+  EXPECT_LT(Make(Hint::CompleteSymbol), Make(Hint::PublicHeader));
+  EXPECT_LT(Make(Hint::PublicHeader), Make(Hint::PreferredHeader));
+  EXPECT_LT(Make(Hint::CompleteSymbol | Hint::PublicHeader),
+            Make(Hint::PreferredHeader));
+}
+
 } // namespace
 } // namespace clang::include_cleaner
Index: clang-tools-extra/include-cleaner/lib/Types.cpp
===================================================================
--- clang-tools-extra/include-cleaner/lib/Types.cpp
+++ clang-tools-extra/include-cleaner/lib/Types.cpp
@@ -106,4 +106,19 @@
   return Result;
 }
 
+// Orders by kind first, then by the kind-specific payload. Physical headers are
+// compared by FileEntry address, so their order may vary from run to run.
+bool Header::operator<(const Header &RHS) const {
+  if (kind() != RHS.kind())
+    return kind() < RHS.kind();
+  switch (kind()) {
+  case Header::Physical:
+    return physical() < RHS.physical();
+  case Header::Standard:
+    return standard().name() < RHS.standard().name();
+  case Header::Verbatim:
+    return verbatim() < RHS.verbatim();
+  }
+  llvm_unreachable("unhandled Header kind");
+}
 } // namespace clang::include_cleaner
Index: clang-tools-extra/include-cleaner/lib/LocateSymbol.cpp
===================================================================
--- clang-tools-extra/include-cleaner/lib/LocateSymbol.cpp
+++ clang-tools-extra/include-cleaner/lib/LocateSymbol.cpp
@@ -7,10 +7,14 @@
 //===----------------------------------------------------------------------===//
 
 #include "AnalysisInternal.h"
+#include "clang/AST/Decl.h"
 #include "clang/AST/DeclBase.h"
+#include "clang/AST/DeclCXX.h"
+#include "clang/AST/DeclTemplate.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Tooling/Inclusions/StandardLibrary.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
 #include <utility>
 #include <vector>
@@ -18,13 +22,36 @@
 namespace clang::include_cleaner {
 namespace {
 
-std::vector<SymbolLocation> locateDecl(const Decl &D) {
-  std::vector<SymbolLocation> Result;
+// CompleteSymbol if D is the defining declaration, None otherwise.
+template <class DeclT> Hint completeIfDefinition(DeclT *D) {
+  return !D->isThisDeclarationADefinition() ? Hint::None : Hint::CompleteSymbol;
+}
+
+Hint declHints(const Decl *D) {
+  // Only tags and class/function templates need a definition to be complete.
+  if (const auto *Tag = llvm::dyn_cast<TagDecl>(D))
+    return completeIfDefinition(Tag);
+  if (const auto *ClassTmpl = llvm::dyn_cast<ClassTemplateDecl>(D))
+    return completeIfDefinition(ClassTmpl);
+  if (const auto *FuncTmpl = llvm::dyn_cast<FunctionTemplateDecl>(D))
+    return completeIfDefinition(FuncTmpl);
+  return Hint::CompleteSymbol; // Everything else is assumed usable as is.
+}
+
+std::vector<Hinted<SymbolLocation>> locateDecl(const Decl &D) {
+  std::vector<Hinted<SymbolLocation>> Result;
   // FIXME: Should we also provide physical locations?
   if (auto SS = tooling::stdlib::Recognizer()(&D))
-    return {SymbolLocation(*SS)};
+    return {Hinted<SymbolLocation>(*SS, Hint::CompleteSymbol)};
+  // FIXME: Nothing flags "foreign" decls yet, e.g. a forward declaration of a
+  // type owned by another library. The DeclContext could provide a signal, as
+  // most incidental forward decls look like:
+  //   namespace clang {
+  //   class SourceManager; // likely an incidental forward decl.
+  //   namespace my_own_ns {}
+  //   }
   for (auto *Redecl : D.redecls())
-    Result.push_back(Redecl->getLocation());
+    Result.emplace_back(Redecl->getLocation(), declHints(Redecl));
   return Result;
 }
 
@@ -46,12 +73,12 @@
   llvm_unreachable("Unhandled Symbol kind");
 }
 
-std::vector<SymbolLocation> locateSymbol(const Symbol &S) {
+std::vector<Hinted<SymbolLocation>> locateSymbol(const Symbol &S) {
   switch (S.kind()) {
   case Symbol::Declaration:
     return locateDecl(S.declaration());
   case Symbol::Macro:
-    return {SymbolLocation(S.macro().Definition)};
+    return {Hinted<SymbolLocation>(S.macro().Definition, Hint::CompleteSymbol)};
   }
   llvm_unreachable("Unknown Symbol::Kind enum");
 }
Index: clang-tools-extra/include-cleaner/lib/HTMLReport.cpp
===================================================================
--- clang-tools-extra/include-cleaner/lib/HTMLReport.cpp
+++ clang-tools-extra/include-cleaner/lib/HTMLReport.cpp
@@ -187,8 +187,7 @@
     // Duplicates logic from walkUsed(), which doesn't expose SymbolLocations.
     for (auto &Loc : locateSymbol(R.Sym))
       R.Locations.push_back(Loc);
-    for (const auto &Loc : R.Locations)
-      R.Headers.append(findHeaders(Loc, SM, PI));
+    R.Headers = headersForSymbol(R.Sym, SM, PI);
 
     for (const auto &H : R.Headers) {
       R.Includes.append(Includes.match(H));
@@ -205,7 +204,6 @@
                      R.Includes.end());
 
     if (!R.Headers.empty())
-      // FIXME: library should tell us which header to use.
       R.Insert = spellHeader(R.Headers.front());
   }
 
Index: clang-tools-extra/include-cleaner/lib/FindHeaders.cpp
===================================================================
--- clang-tools-extra/include-cleaner/lib/FindHeaders.cpp
+++ clang-tools-extra/include-cleaner/lib/FindHeaders.cpp
@@ -8,31 +8,100 @@
 
 #include "AnalysisInternal.h"
 #include "clang-include-cleaner/Record.h"
+#include "clang-include-cleaner/Types.h"
+#include "clang/AST/Decl.h"
+#include "clang/AST/DeclBase.h"
+#include "clang/Basic/FileEntry.h"
+#include "clang/Basic/SourceLocation.h"
 #include "clang/Basic/SourceManager.h"
+#include "clang/Tooling/Inclusions/StandardLibrary.h"
+#include "llvm/ADT/BitmaskEnum.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <string>
+#include <utility>
 
 namespace clang::include_cleaner {
+namespace {
+llvm::SmallVector<Hinted<Header>>
+applyHints(llvm::SmallVector<Hinted<Header>> Headers, Hint H) {
+  // Every header found for a location inherits that location's hints.
+  llvm::for_each(Headers, [H](Hinted<Header> &Entry) { Entry.Hints |= H; });
+  return Headers;
+}
+
+llvm::SmallVector<Header> ranked(llvm::SmallVector<Hinted<Header>> Headers) {
+  // Most relevant first; the sort is stable so ties keep their current order.
+  llvm::stable_sort(Headers, [](const Hinted<Header> &A,
+                                const Hinted<Header> &B) { return B < A; });
+  llvm::SmallVector<Header> Result(Headers.begin(), Headers.end());
+  return Result;
+}
+
+/// Adds Hint::PreferredHeader to headers whose base name, minus quotes/angle
+/// brackets and anything from the first '.', equals DeclName (ignoring case).
+llvm::SmallVector<Hinted<Header>>
+nameMatch(llvm::StringRef DeclName, llvm::SmallVector<Hinted<Header>> Headers) {
+  // An empty name would otherwise match headers such as "dir/.hidden.h".
+  if (DeclName.empty())
+    return Headers;
+  for (auto &H : Headers) {
+    llvm::StringRef SpelledH;
+    switch (H.kind()) {
+    case Header::Physical:
+      SpelledH = H.physical()->getName();
+      break;
+    case Header::Standard:
+      SpelledH = H.standard().name();
+      break;
+    case Header::Verbatim:
+      SpelledH = H.verbatim();
+      break;
+    }
+    SpelledH = SpelledH.trim("<>\"");
+    if (auto LastSlash = SpelledH.rfind('/'); LastSlash != SpelledH.npos)
+      SpelledH = SpelledH.drop_front(LastSlash + 1);
+    // Drop everything from the first dot: foo.h -> foo, foo.cu.h -> foo.
+    SpelledH = SpelledH.substr(0, SpelledH.find('.'));
+    if (SpelledH.equals_insensitive(DeclName))
+      H.Hints |= Hint::PreferredHeader;
+  }
+  return Headers;
+}
 
-llvm::SmallVector<Header> findHeaders(const SymbolLocation &Loc,
-                                      const SourceManager &SM,
-                                      const PragmaIncludes *PI) {
-  llvm::SmallVector<Header> Results;
+} // namespace
+
+llvm::SmallVector<Hinted<Header>> findHeaders(const SymbolLocation &Loc,
+                                              const SourceManager &SM,
+                                              const PragmaIncludes *PI) {
+  llvm::SmallVector<Hinted<Header>> Results;
   switch (Loc.kind()) {
   case SymbolLocation::Physical: {
     FileID FID = SM.getFileID(SM.getExpansionLoc(Loc.physical()));
     const FileEntry *FE = SM.getFileEntryForID(FID);
-    if (!PI) {
-      return FE ? llvm::SmallVector<Header>{Header(FE)}
-                : llvm::SmallVector<Header>();
-    }
+    if (!FE)
+      return {};
+    if (!PI)
+      return {{FE, Hint::PublicHeader}};
+    auto IsPublicHeader = [&PI](const FileEntry *FE) {
+      return (PI->isPrivate(FE) || !PI->isSelfContained(FE))
+                 ? Hint::None
+                 : Hint::PublicHeader;
+    };
     while (FE) {
-      Results.push_back(Header(FE));
+      Hint CurrentHints = IsPublicHeader(FE);
+      Results.emplace_back(FE, CurrentHints);
       // FIXME: compute transitive exporter headers.
       for (const auto *Export : PI->getExporters(FE, SM.getFileManager()))
-        Results.push_back(Header(Export));
+        Results.emplace_back(Export, IsPublicHeader(Export));
 
-      llvm::StringRef VerbatimSpelling = PI->getPublic(FE);
-      if (!VerbatimSpelling.empty()) {
-        Results.push_back(VerbatimSpelling);
+      if (auto Verbatim = PI->getPublic(FE); !Verbatim.empty()) {
+        Results.emplace_back(Verbatim,
+                             Hint::PublicHeader | Hint::PreferredHeader);
         break;
       }
       if (PI->isSelfContained(FE) || FID == SM.getMainFileID())
@@ -46,11 +115,51 @@
   }
   case SymbolLocation::Standard: {
     for (const auto &H : Loc.standard().headers())
-      Results.push_back(H);
+      Results.emplace_back(H, Hint::PreferredHeader | Hint::PublicHeader);
     return Results;
   }
   }
   llvm_unreachable("unhandled SymbolLocation kind!");
 }
 
+llvm::SmallVector<Header> headersForSymbol(const Symbol &S,
+                                           const SourceManager &SM,
+                                           const PragmaIncludes *PI) {
+  // The same header can be reached through several locations (redecls,
+  // exporters, ...). Merge such duplicates into the first occurrence by taking
+  // the union of their hints, keeping discovery order so that ranked()'s
+  // stable sort breaks ties by it rather than by FileEntry address.
+  // Headers that only probably name the same file (e.g. Verbatim("foo.h") and
+  // Physical(/path/to/foo.h)) are not merged, nor are their hints.
+  llvm::SmallVector<Hinted<Header>> Headers;
+  for (auto &Loc : locateSymbol(S)) {
+    for (auto &H : applyHints(findHeaders(Loc, SM, PI), Loc.Hints)) {
+      // Providers per symbol are expected to be few; a linear scan suffices.
+      auto *Seen = llvm::find_if(Headers, [&](const Hinted<Header> &Prev) {
+        return static_cast<const Header &>(Prev) == H;
+      });
+      if (Seen != Headers.end())
+        Seen->Hints |= H.Hints;
+      else
+        Headers.push_back(std::move(H));
+    }
+  }
+
+  // Add name match hints to deduplicated providers.
+  llvm::StringRef SymbolName;
+  switch (S.kind()) {
+  case Symbol::Declaration:
+    // Unnamed decls like operators and anonymous structs won't get any name
+    // match.
+    if (const auto *ND = llvm::dyn_cast<NamedDecl>(&S.declaration()))
+      if (auto *II = ND->getIdentifier())
+        SymbolName = II->getName();
+    break;
+  case Symbol::Macro:
+    SymbolName = S.macro().Name->getName();
+    break;
+  }
+  // FIXME: Introduce a MainFile header kind or signal and boost it.
+  return ranked(nameMatch(SymbolName, std::move(Headers)));
+}
 } // namespace clang::include_cleaner
Index: clang-tools-extra/include-cleaner/lib/AnalysisInternal.h
===================================================================
--- clang-tools-extra/include-cleaner/lib/AnalysisInternal.h
+++ clang-tools-extra/include-cleaner/lib/AnalysisInternal.h
@@ -25,6 +25,7 @@
 #include "clang-include-cleaner/Types.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Tooling/Inclusions/StandardLibrary.h"
+#include "llvm/ADT/BitmaskEnum.h"
 #include "llvm/ADT/STLFunctionalExtras.h"
 #include <variant>
 #include <vector>
@@ -80,11 +81,57 @@
 };
 llvm::raw_ostream &operator<<(llvm::raw_ostream &, const SymbolLocation &);
 
+/// Properties of a symbol provider, i.e. a header that can satisfy a node.
+///
+/// Each hint describes an edge on the path from an AST node to a header that
+/// satisfies it: AST node => symbols => locations => headers.
+///
+/// The same header may be reached through several such paths; its hints are
+/// then merged by taking the union over all paths, which is why every hint is
+/// phrased positively.
+///
+/// Enumerators are declared in ascending order of relevance.
+enum class Hint : uint8_t {
+  None = 0x0,
+  /// The location provides a generally-usable declaration of the symbol,
+  /// e.g. a function declaration or a class definition, as opposed to a
+  /// forward declaration of a class or template.
+  CompleteSymbol = 0x1,
+  /// The header is public. Only absent when the file is explicitly marked
+  /// otherwise, e.g. it is not self-contained or has an IWYU private pragma.
+  PublicHeader = 0x2,
+  /// The header is explicitly preferred, e.g. a standard library header, the
+  /// target of an `IWYU pragma: private, include`, or a header whose file name
+  /// roughly matches the symbol name.
+  PreferredHeader = 0x4,
+  LLVM_MARK_AS_BITMASK_ENUM(PreferredHeader),
+};
+LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
+/// Pairs a value (e.g. a Header or SymbolLocation) with the Hints it carries.
+template <typename T> struct Hinted : public T {
+  Hint Hints;
+  Hinted(T &&Wrapped, Hint H) : T(std::move(Wrapped)), Hints(H) {}
+  /// Compares raw bitmasks, so a hint outweighs all lower hints combined.
+  bool operator<(const Hinted<T> &Other) const {
+    auto AsInt = [](Hint H) { return static_cast<int>(H); };
+    return AsInt(Hints) < AsInt(Other.Hints);
+  }
+};
+
 /// Finds the headers that provide the symbol location.
-// FIXME: expose signals
-llvm::SmallVector<Header> findHeaders(const SymbolLocation &Loc,
-                                      const SourceManager &SM,
-                                      const PragmaIncludes *PI);
+llvm::SmallVector<Hinted<Header>> findHeaders(const SymbolLocation &Loc,
+                                              const SourceManager &SM,
+                                              const PragmaIncludes *PI);
+
+/// Returns the locations that provide the symbol, each carrying its Hints.
+std::vector<Hinted<SymbolLocation>> locateSymbol(const Symbol &S);
+
+/// Gets all providers of a symbol by traversing each of its locations; equal
+/// headers reached via several locations are merged. The result is sorted by
+/// relevance, so the first element is the most likely provider.
+llvm::SmallVector<Header> headersForSymbol(const Symbol &S,
+                                           const SourceManager &SM,
+                                           const PragmaIncludes *PI);
 
 /// Write an HTML summary of the analysis to the given stream.
 void writeHTMLReport(FileID File, const Includes &,
@@ -93,9 +140,6 @@
                      HeaderSearch &HS, PragmaIncludes *PI,
                      llvm::raw_ostream &OS);
 
-/// A set of locations that provides the declaration.
-std::vector<SymbolLocation> locateSymbol(const Symbol &S);
-
 } // namespace include_cleaner
 } // namespace clang
 
Index: clang-tools-extra/include-cleaner/lib/Analysis.cpp
===================================================================
--- clang-tools-extra/include-cleaner/lib/Analysis.cpp
+++ clang-tools-extra/include-cleaner/lib/Analysis.cpp
@@ -12,29 +12,19 @@
 #include "clang-include-cleaner/Types.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Decl.h"
+#include "clang/AST/DeclBase.h"
 #include "clang/Basic/SourceManager.h"
 #include "clang/Format/Format.h"
 #include "clang/Lex/HeaderSearch.h"
 #include "clang/Tooling/Core/Replacement.h"
-#include "clang/Tooling/Inclusions/HeaderIncludes.h"
 #include "clang/Tooling/Inclusions/StandardLibrary.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
 
 namespace clang::include_cleaner {
 
-namespace {
-// Gets all the providers for a symbol by tarversing each location.
-llvm::SmallVector<Header> headersForSymbol(const Symbol &S,
-                                           const SourceManager &SM,
-                                           const PragmaIncludes *PI) {
-  llvm::SmallVector<Header> Headers;
-  for (auto &Loc : locateSymbol(S))
-    Headers.append(findHeaders(Loc, SM, PI));
-  return Headers;
-}
-} // namespace
-
 void walkUsed(llvm::ArrayRef<Decl *> ASTRoots,
               llvm::ArrayRef<SymbolReference> MacroRefs,
               const PragmaIncludes *PI, const SourceManager &SM,
@@ -55,7 +45,7 @@
     assert(MacroRef.Target.kind() == Symbol::Macro);
     if (!SM.isWrittenInMainFile(SM.getSpellingLoc(MacroRef.RefLocation)))
       continue;
-    CB(MacroRef, findHeaders(MacroRef.Target.macro().Definition, SM, PI));
+    CB(MacroRef, headersForSymbol(MacroRef.Target, SM, PI));
   }
 }
 
Index: clang-tools-extra/include-cleaner/include/clang-include-cleaner/Types.h
===================================================================
--- clang-tools-extra/include-cleaner/include/clang-include-cleaner/Types.h
+++ clang-tools-extra/include-cleaner/include/clang-include-cleaner/Types.h
@@ -29,6 +29,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringMap.h"
 #include <memory>
+#include <utility>
 #include <vector>
 
 namespace llvm {
@@ -117,6 +118,7 @@
 
   Kind kind() const { return static_cast<Kind>(Storage.index()); }
   bool operator==(const Header &RHS) const { return Storage == RHS.Storage; }
+  bool operator<(const Header &RHS) const;
 
   const FileEntry *physical() const { return std::get<Physical>(Storage); }
   tooling::stdlib::Header standard() const {
@@ -127,6 +129,10 @@
 private:
   // Order must match Kind enum!
   std::variant<const FileEntry *, tooling::stdlib::Header, StringRef> Storage;
+
+  Header(std::in_place_t, decltype(Storage) Sentinel)
+      : Storage(std::move(Sentinel)) {}
+  friend llvm::DenseMapInfo<Header>;
 };
 llvm::raw_ostream &operator<<(llvm::raw_ostream &, const Header &);
 
@@ -202,6 +208,23 @@
     return Base::isEqual(LHS.Definition, RHS.Definition);
   }
 };
+template <> struct DenseMapInfo<clang::include_cleaner::Header> {
+  using Outer = clang::include_cleaner::Header;
+  using Base = DenseMapInfo<decltype(Outer::Storage)>;
+  // Sentinels wrap the storage's own sentinels via Header's private ctor.
+  static Outer getEmptyKey() {
+    return Outer(std::in_place, Base::getEmptyKey());
+  }
+  static Outer getTombstoneKey() {
+    return Outer(std::in_place, Base::getTombstoneKey());
+  }
+  static unsigned getHashValue(const Outer &Key) {
+    return Base::getHashValue(Key.Storage);
+  }
+  static bool isEqual(const Outer &L, const Outer &R) {
+    return Base::isEqual(L.Storage, R.Storage);
+  }
+};
 } // namespace llvm
 
 #endif
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to