VitaNuo updated this revision to Diff 509015.
VitaNuo added a comment.

Remove redundant code.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D147044/new/

https://reviews.llvm.org/D147044

Files:
  clang-tools-extra/clangd/XRefs.cpp
  clang-tools-extra/clangd/unittests/XRefsTests.cpp

Index: clang-tools-extra/clangd/unittests/XRefsTests.cpp
===================================================================
--- clang-tools-extra/clangd/unittests/XRefsTests.cpp
+++ clang-tools-extra/clangd/unittests/XRefsTests.cpp
@@ -5,8 +5,8 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-#include "Annotations.h"
 #include "AST.h"
+#include "Annotations.h"
 #include "ParsedAST.h"
 #include "Protocol.h"
 #include "SourceCode.h"
@@ -43,6 +43,10 @@
 using ::testing::UnorderedElementsAreArray;
 using ::testing::UnorderedPointwise;
 
+std::string guard(llvm::StringRef Code) {
+  return "#pragma once\n" + Code.str();
+}
+
 MATCHER_P2(FileRange, File, Range, "") {
   return Location{URIForFile::canonicalize(File, testRoot()), Range} == arg;
 }
@@ -1876,8 +1880,8 @@
     ASSERT_GT(A.points().size(), 0u) << Case;
     for (auto Pos : A.points())
       EXPECT_THAT(findType(AST, Pos),
-                  ElementsAre(
-                    sym("Target", HeaderA.range("Target"), HeaderA.range("Target"))))
+                  ElementsAre(sym("Target", HeaderA.range("Target"),
+                                  HeaderA.range("Target"))))
           << Case;
   }
 
@@ -1888,11 +1892,12 @@
     TU.Code = A.code().str();
     ParsedAST AST = TU.build();
 
-    EXPECT_THAT(findType(AST, A.point()),
-                UnorderedElementsAre(
-                  sym("Target", HeaderA.range("Target"), HeaderA.range("Target")),
-                  sym("smart_ptr", HeaderA.range("smart_ptr"), HeaderA.range("smart_ptr"))
-                ))
+    EXPECT_THAT(
+        findType(AST, A.point()),
+        UnorderedElementsAre(
+            sym("Target", HeaderA.range("Target"), HeaderA.range("Target")),
+            sym("smart_ptr", HeaderA.range("smart_ptr"),
+                HeaderA.range("smart_ptr"))))
         << Case;
   }
 }
@@ -1901,6 +1906,25 @@
   Annotations T(Test);
   auto TU = TestTU::withCode(T.code());
   TU.ExtraArgs.push_back("-std=c++20");
+  TU.AdditionalFiles["bar.h"] = guard(R"cpp(
+    #define BAR 5
+    int bar1();
+    int bar2();
+    class Bar {};            
+  )cpp");
+  TU.AdditionalFiles["private.h"] = guard(R"cpp(
+    // IWYU pragma: private, include "public.h"
+    int foo(); 
+  )cpp");
+  TU.AdditionalFiles["public.h"] = guard("");
+  TU.AdditionalFiles["system/vector"] = guard(R"cpp(
+    namespace std {
+      template<typename>
+      class vector{};
+    }
+  )cpp");
+  TU.AdditionalFiles["forward.h"] = guard("class Bar;");
+  TU.ExtraArgs.push_back("-isystem" + testPath("system"));
 
   auto AST = TU.build();
   std::vector<Matcher<ReferencesResult::Reference>> ExpectedLocations;
@@ -2293,6 +2317,42 @@
     checkFindRefs(Test);
 }
 
+TEST(FindReferences, UsedSymbolsFromInclude) {
+  const char *Tests[] = {
+      R"cpp([[#include ^"bar.h"]]
+        int fstBar = [[bar1]]();
+        int sndBar = [[bar2]]();
+        [[Bar]] bar;
+        int macroBar = [[BAR]];
+      )cpp",
+
+      R"cpp([[#in^clude <vector>]]
+        std::[[vector]]<int> vec;
+      )cpp",
+
+      R"cpp([[#in^clude "public.h"]]
+        #include "private.h"
+        int fooVar = [[foo]]();
+      )cpp",
+
+      R"cpp(#include "bar.h"
+        #include "for^ward.h"
+        Bar *x;
+      )cpp",
+
+      R"cpp([[#include "b^ar.h"]]
+        #define DEF(X) const Bar *X
+        [[DEF]](a);
+      )cpp",
+
+      R"cpp([[#in^clude "bar.h"]]
+        #define BAZ(X) const X x
+        BAZ([[Bar]]);
+      )cpp"};
+  for (const char *Test : Tests)
+    checkFindRefs(Test);
+}
+
 TEST(FindReferences, NeedsIndexForSymbols) {
   const char *Header = "int foo();";
   Annotations Main("int main() { [[f^oo]](); }");
Index: clang-tools-extra/clangd/XRefs.cpp
===================================================================
--- clang-tools-extra/clangd/XRefs.cpp
+++ clang-tools-extra/clangd/XRefs.cpp
@@ -10,12 +10,15 @@
 #include "FindSymbols.h"
 #include "FindTarget.h"
 #include "HeuristicResolver.h"
+#include "IncludeCleaner.h"
 #include "ParsedAST.h"
 #include "Protocol.h"
 #include "Quality.h"
 #include "Selection.h"
 #include "SourceCode.h"
 #include "URI.h"
+#include "clang-include-cleaner/Analysis.h"
+#include "clang-include-cleaner/Types.h"
 #include "index/Index.h"
 #include "index/Merge.h"
 #include "index/Relation.h"
@@ -48,6 +51,7 @@
 #include "clang/Index/IndexingAction.h"
 #include "clang/Index/IndexingOptions.h"
 #include "clang/Index/USRGeneration.h"
+#include "clang/Lex/Lexer.h"
 #include "clang/Tooling/Syntax/Tokens.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
@@ -61,6 +65,7 @@
 #include "llvm/Support/Path.h"
 #include "llvm/Support/raw_ostream.h"
 #include <optional>
+#include <string>
 #include <vector>
 
 namespace clang {
@@ -1324,6 +1329,59 @@
     return {};
   }
 
+  const auto &Includes = AST.getIncludeStructure().MainFileIncludes;
+  const auto &ConvertedMainFileIncludes =
+      convertIncludes(SM, Includes);
+  for (auto &Inc : Includes) {
+    if (Inc.HashLine != Pos.line)
+      continue;
+
+    const auto &ReferencedInclude = convertIncludes(SM, Inc);
+    include_cleaner::walkUsed(
+        AST.getLocalTopLevelDecls(), collectMacroReferences(AST),
+        AST.getPragmaIncludes(), SM,
+        [&](const include_cleaner::SymbolReference &Ref,
+            llvm::ArrayRef<include_cleaner::Header> Providers) {
+          if (Ref.RT != include_cleaner::RefType::Explicit)
+            return;
+
+          const auto &Loc = SM.getFileLoc(Ref.RefLocation);
+          for (const auto &H : Providers) {
+            const auto &MatchingIncludes = ConvertedMainFileIncludes.match(H);
+            // No match for this provider in the main file.
+            if (MatchingIncludes.empty())
+              continue;
+
+            // Check if the referenced include matches this provider.
+            if (!ReferencedInclude.match(H).empty()) {
+              ReferencesResult::Reference Result;
+              auto TokLen =
+                  Lexer::MeasureTokenLength(Loc, SM, AST.getLangOpts());
+              Result.Loc.range =
+                  halfOpenToRange(SM, CharSourceRange::getCharRange(
+                                          Loc, Loc.getLocWithOffset(TokLen)));
+              Result.Loc.uri = URIMainFile;
+              Results.References.push_back(std::move(Result));
+            }
+
+            // Don't look at the rest of the providers once we've found a match
+            // in the main file.
+            return;
+          }
+        });
+    if (Results.References.empty())
+      return {};
+
+    // Add the #include line to the references list.
+    auto IncludeLen =
+        std::string{"#include"}.length() + Inc.Written.length() + 1;
+    ReferencesResult::Reference Result;
+    Result.Loc.range = clangd::Range{Position{Inc.HashLine, 0},
+                                     Position{Inc.HashLine, (int)IncludeLen}};
+    Result.Loc.uri = URIMainFile;
+    Results.References.push_back(std::move(Result));
+  }
+
   llvm::DenseSet<SymbolID> IDsToQuery, OverriddenMethods;
 
   const auto *IdentifierAtCursor =
@@ -1944,15 +2002,15 @@
   return QualType();
 }
 
-// Given a type targeted by the cursor, return one or more types that are more interesting
-// to target.
-static void unwrapFindType(
-    QualType T, const HeuristicResolver* H, llvm::SmallVector<QualType>& Out) {
+// Given a type targeted by the cursor, return one or more types that are more
+// interesting to target.
+static void unwrapFindType(QualType T, const HeuristicResolver *H,
+                           llvm::SmallVector<QualType> &Out) {
   if (T.isNull())
     return;
 
   // If there's a specific type alias, point at that rather than unwrapping.
-  if (const auto* TDT = T->getAs<TypedefType>())
+  if (const auto *TDT = T->getAs<TypedefType>())
     return Out.push_back(QualType(TDT, 0));
 
   // Pointers etc => pointee type.
@@ -1968,30 +2026,31 @@
     return unwrapFindType(FT->getReturnType(), H, Out);
   if (auto *CRD = T->getAsCXXRecordDecl()) {
     if (CRD->isLambda())
-      return unwrapFindType(CRD->getLambdaCallOperator()->getReturnType(), H, Out);
+      return unwrapFindType(CRD->getLambdaCallOperator()->getReturnType(), H,
+                            Out);
     // FIXME: more cases we'd prefer the return type of the call operator?
     //        std::function etc?
   }
 
   // For smart pointer types, add the underlying type
   if (H)
-    if (const auto* PointeeType = H->getPointeeType(T.getNonReferenceType().getTypePtr())) {
-        unwrapFindType(QualType(PointeeType, 0), H, Out);
-        return Out.push_back(T);
+    if (const auto *PointeeType =
+            H->getPointeeType(T.getNonReferenceType().getTypePtr())) {
+      unwrapFindType(QualType(PointeeType, 0), H, Out);
+      return Out.push_back(T);
     }
 
   return Out.push_back(T);
 }
 
 // Convenience overload, to allow calling this without the out-parameter
-static llvm::SmallVector<QualType> unwrapFindType(
-    QualType T, const HeuristicResolver* H) {
-    llvm::SmallVector<QualType> Result;
-    unwrapFindType(T, H, Result);
-    return Result;
+static llvm::SmallVector<QualType> unwrapFindType(QualType T,
+                                                  const HeuristicResolver *H) {
+  llvm::SmallVector<QualType> Result;
+  unwrapFindType(T, H, Result);
+  return Result;
 }
 
-
 std::vector<LocatedSymbol> findType(ParsedAST &AST, Position Pos) {
   const SourceManager &SM = AST.getSourceManager();
   auto Offset = positionToOffset(SM.getBufferData(SM.getMainFileID()), Pos);
@@ -2007,11 +2066,13 @@
     std::vector<LocatedSymbol> LocatedSymbols;
 
     // NOTE: unwrapFindType might return duplicates for something like
-    // unique_ptr<unique_ptr<T>>. Let's *not* remove them, because it gives you some
-    // information about the type you may have not known before
-    // (since unique_ptr<unique_ptr<T>> != unique_ptr<T>).
-    for (const QualType& Type : unwrapFindType(typeForNode(N), AST.getHeuristicResolver()))
-        llvm::copy(locateSymbolForType(AST, Type), std::back_inserter(LocatedSymbols));
+    // unique_ptr<unique_ptr<T>>. Let's *not* remove them, because it gives you
+    // some information about the type you may not have known before (since
+    // unique_ptr<unique_ptr<T>> != unique_ptr<T>).
+    for (const QualType &Type :
+         unwrapFindType(typeForNode(N), AST.getHeuristicResolver()))
+      llvm::copy(locateSymbolForType(AST, Type),
+                 std::back_inserter(LocatedSymbols));
 
     return LocatedSymbols;
   };
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to