sammccall created this revision.
sammccall added reviewers: nridge, Trass3r.
Herald added subscribers: usaxena95, kadircet, arphaman.
sammccall requested review of this revision.
Herald added subscribers: cfe-commits, MaskRay, ilya-biryukov.
Herald added a project: clang-tools-extra.
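This adds inlay type hints for `auto` parameters of abbreviated function
templates (e.g. `int foo(auto x)`) that have exactly one instantiation.
It takes a similar approach to b9b6938183e
<https://reviews.llvm.org/rGb9b6938183e837e66ff7450fb2b8a73dce5889c0> and
shares some code with it.
The code sharing is limited because inlay hints want to deduce the type of the
variable rather than the type of the `auto` per se.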

It also drops support (in both places) for deducing a type when multiple
instantiations yield the same type, as this is rare and hard to build a nice
API around.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D120258

Files:
  clang-tools-extra/clangd/AST.cpp
  clang-tools-extra/clangd/AST.h
  clang-tools-extra/clangd/InlayHints.cpp
  clang-tools-extra/clangd/unittests/ASTTests.cpp
  clang-tools-extra/clangd/unittests/InlayHintTests.cpp
  clang-tools-extra/clangd/unittests/TestTU.cpp

Index: clang-tools-extra/clangd/unittests/TestTU.cpp
===================================================================
--- clang-tools-extra/clangd/unittests/TestTU.cpp
+++ clang-tools-extra/clangd/unittests/TestTU.cpp
@@ -247,7 +247,7 @@
   Visitor.F = Filter;
   Visitor.TraverseDecl(AST.getASTContext().getTranslationUnitDecl());
   if (Visitor.Decls.size() != 1) {
-    llvm::errs() << Visitor.Decls.size() << " symbols matched.";
+    llvm::errs() << Visitor.Decls.size() << " symbols matched.\n";
     assert(Visitor.Decls.size() == 1);
   }
   return *Visitor.Decls.front();
Index: clang-tools-extra/clangd/unittests/InlayHintTests.cpp
===================================================================
--- clang-tools-extra/clangd/unittests/InlayHintTests.cpp
+++ clang-tools-extra/clangd/unittests/InlayHintTests.cpp
@@ -676,6 +676,15 @@
                   ExpectedHint{": int", "var"});
 }
 
+TEST(TypeHints, SinglyInstantiatedTemplate) {
+  assertTypeHints(R"cpp(
+    auto $lambda[[x]] = [](auto *$param[[y]]) { return 42; };
+    int m = x("foo");
+  )cpp",
+                  ExpectedHint{": (lambda)", "lambda"},
+                  ExpectedHint{": const char *", "param"});
+}
+
 TEST(DesignatorHints, Basic) {
   assertDesignatorHints(R"cpp(
     struct S { int x, y, z; };
Index: clang-tools-extra/clangd/unittests/ASTTests.cpp
===================================================================
--- clang-tools-extra/clangd/unittests/ASTTests.cpp
+++ clang-tools-extra/clangd/unittests/ASTTests.cpp
@@ -30,6 +30,7 @@
 namespace {
 using testing::Contains;
 using testing::Each;
+using testing::IsEmpty;
 
 TEST(GetDeducedType, KwAutoKwDecltypeExpansion) {
   struct Test {
@@ -192,12 +193,12 @@
           R"cpp(
             // Generic lambda instantiated twice, matching deduction.
             struct Foo{};
-            using Bar = Foo;
             auto Generic = [](^auto x, auto y) { return 0; };
-            int m = Generic(Bar{}, "one");
+            int m = Generic(Foo{}, "one");
             int n = Generic(Foo{}, 2);
           )cpp",
-          "struct Foo",
+          // No deduction although both instantiations yield the same result :-(
+          nullptr,
       },
       {
           R"cpp(
@@ -253,6 +254,118 @@
   }
 }
 
+TEST(ClangdAST, GetOnlyInstantiation) {
+  struct {
+    const char *Code;
+    llvm::StringLiteral NodeType;
+    const char *Name;
+  } Cases[] = {
+      {
+          R"cpp(
+            template <typename> class X {};
+            X<int> x;
+          )cpp",
+          "CXXRecord",
+          "template<> class X<int> {}",
+      },
+      {
+          R"cpp(
+            template <typename T> T X = T{};
+            int y = X<char>;
+          )cpp",
+          "Var",
+          // VarTemplateSpecializationDecl doesn't print as template<>...
+          "char X = char{}",
+      },
+      {
+          R"cpp(
+            template <typename T> int X(T) { return 42; }
+            int y = X("text");
+          )cpp",
+          "Function",
+          "template<> int X<const char *>(const char *)",
+      },
+      {
+          R"cpp(
+            int X(auto *x) { return 42; }
+            int y = X("text");
+          )cpp",
+          "Function",
+          "template<> int X<const char>(const char *x)",
+      },
+  };
+
+  for (const auto &Case : Cases) {
+    SCOPED_TRACE(Case.Code);
+    auto TU = TestTU::withCode(Case.Code);
+    TU.ExtraArgs.push_back("-std=c++20");
+    auto AST = TU.build();
+    PrintingPolicy PP = AST.getASTContext().getPrintingPolicy();
+    PP.TerseOutput = true;
+    std::string Name;
+    if (auto *Result = getOnlyInstantiation(
+            const_cast<NamedDecl *>(&findDecl(AST, [&](const NamedDecl &D) {
+              return D.getDescribedTemplate() != nullptr &&
+                     D.getDeclKindName() == Case.NodeType;
+            })))) {
+      llvm::raw_string_ostream OS(Name);
+      Result->print(OS, PP);
+    }
+
+    if (Case.Name)
+      EXPECT_EQ(Case.Name, Name);
+    else
+      EXPECT_THAT(Name, IsEmpty());
+  }
+}
+
+TEST(ClangdAST, GetContainedAutoParamType) {
+  auto TU = TestTU::withCode(R"cpp(
+    int withAuto(
+       auto a,
+       auto *b,
+       const auto *c,
+       auto &&d,
+       auto *&e,
+       auto (*f)(int)
+    ){};
+
+    int withoutAuto(
+      int a,
+      int *b,
+      const int *c,
+      int &&d,
+      int *&e,
+      int (*f)(int)
+    ){};
+  )cpp");
+  TU.ExtraArgs.push_back("-std=c++20");
+  auto AST = TU.build();
+
+  const auto &WithAuto =
+      llvm::cast<FunctionTemplateDecl>(findDecl(AST, "withAuto"));
+  auto ParamsWithAuto = WithAuto.getTemplatedDecl()->parameters();
+  auto *TemplateParamsWithAuto = WithAuto.getTemplateParameters();
+  ASSERT_EQ(ParamsWithAuto.size(), TemplateParamsWithAuto->size());
+
+  for (unsigned I = 0; I < ParamsWithAuto.size(); ++I) {
+    SCOPED_TRACE(ParamsWithAuto[I]->getNameAsString());
+    auto Loc = getContainedAutoParamType(
+        ParamsWithAuto[I]->getTypeSourceInfo()->getTypeLoc());
+    ASSERT_TRUE(Loc.hasValue() && !Loc->isNull());
+    EXPECT_EQ(Loc->getTypePtr()->getDecl(),
+              TemplateParamsWithAuto->getParam(I));
+  }
+
+  const auto &WithoutAuto =
+      llvm::cast<FunctionDecl>(findDecl(AST, "withoutAuto"));
+  for (auto *ParamWithoutAuto : WithoutAuto.parameters()) {
+    ASSERT_FALSE(getContainedAutoParamType(
+                     ParamWithoutAuto->getTypeSourceInfo()->getTypeLoc())
+                     .hasValue());
+  }
+}
+
 TEST(ClangdAST, GetQualification) {
   // Tries to insert the decl `Foo` into position of each decl named `insert`.
   // This is done to get an appropriate DeclContext for the insertion location.
Index: clang-tools-extra/clangd/InlayHints.cpp
===================================================================
--- clang-tools-extra/clangd/InlayHints.cpp
+++ clang-tools-extra/clangd/InlayHints.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 #include "InlayHints.h"
+#include "AST.h"
 #include "Config.h"
 #include "HeuristicResolver.h"
 #include "ParsedAST.h"
@@ -286,9 +287,52 @@
         addTypeHint(D->getLocation(), D->getType(), /*Prefix=*/": ");
       }
     }
+
+    // Handle templates like `int foo(auto x)` with exactly one instantiation.
+    // Unnamed parameters are skipped: there is no name to attach a hint to.
+    if (auto *PVD = llvm::dyn_cast<ParmVarDecl>(D)) {
+      const TypeSourceInfo *TSI = PVD->getTypeSourceInfo();
+      if (PVD->getIdentifier() && TSI && PVD->getType()->isDependentType() &&
+          getContainedAutoParamType(TSI->getTypeLoc()).hasValue()) {
+        if (auto *IPVD = getOnlyParamInstantiation(PVD))
+          addTypeHint(D->getLocation(), IPVD->getType(), /*Prefix=*/": ");
+      }
+    }
+
     return true;
   }
 
+  // If D is a parameter of a function template with exactly one visible
+  // instantiation, returns the corresponding parameter of that instantiation.
+  // Returns nullptr if there's no unique instantiation, or if D's position
+  // can't be mapped onto the instantiation (see the pack handling below).
+  ParmVarDecl *getOnlyParamInstantiation(ParmVarDecl *D) {
+    auto *TemplateFunction = llvm::dyn_cast<FunctionDecl>(D->getDeclContext());
+    if (!TemplateFunction)
+      return nullptr;
+    auto *InstantiatedFunction = llvm::dyn_cast_or_null<FunctionDecl>(
+        getOnlyInstantiation(TemplateFunction));
+    if (!InstantiatedFunction)
+      return nullptr;
+
+    // Map D to its index in the instantiation. A parameter pack may expand to
+    // zero or several parameters there, so the index of a pack (or of any
+    // parameter after one) doesn't line up with the template's: bail out.
+    unsigned ParamIdx = 0;
+    for (const ParmVarDecl *Param : TemplateFunction->parameters()) {
+      if (Param->isParameterPack())
+        return nullptr;
+      if (Param == D)
+        break;
+      ++ParamIdx;
+    }
+    assert(ParamIdx < TemplateFunction->getNumParams() &&
+           "Couldn't find param in list?");
+    assert(ParamIdx < InstantiatedFunction->getNumParams() &&
+           "Instantiated function has fewer (non-pack) parameters?");
+    return InstantiatedFunction->getParamDecl(ParamIdx);
+  }
+
   bool VisitInitListExpr(InitListExpr *Syn) {
     // We receive the syntactic form here (shouldVisitImplicitCode() is false).
     // This is the one we will ultimately attach designators to.
Index: clang-tools-extra/clangd/AST.h
===================================================================
--- clang-tools-extra/clangd/AST.h
+++ clang-tools-extra/clangd/AST.h
@@ -17,9 +17,9 @@
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclObjC.h"
 #include "clang/AST/NestedNameSpecifier.h"
+#include "clang/AST/TypeLoc.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Lex/MacroInfo.h"
-#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/StringRef.h"
 #include <string>
 #include <vector>
@@ -128,6 +128,17 @@
 /// If the type is an undeduced auto, returns the type itself.
 llvm::Optional<QualType> getDeducedType(ASTContext &, SourceLocation Loc);
 
+// Find the abbreviated-function-template `auto` within a type.
+// Similar to getContainedAutoTypeLoc, but these `auto`s are
+// TemplateTypeParmTypes for implicit TTPs, instead of AutoTypes.
+// Also we don't look very hard, just stripping const, references, pointers.
+// FIXME: handle more type patterns.
+llvm::Optional<TemplateTypeParmTypeLoc> getContainedAutoParamType(TypeLoc TL);
+
+// If TemplatedDecl is the generic body of a template, and the template has
+// exactly one visible instantiation, return the instantiated body.
+NamedDecl *getOnlyInstantiation(NamedDecl *TemplatedDecl);
+
 /// Return attributes attached directly to a node.
 std::vector<const Attr *> getAttributes(const DynTypedNode &);
 
Index: clang-tools-extra/clangd/AST.cpp
===================================================================
--- clang-tools-extra/clangd/AST.cpp
+++ clang-tools-extra/clangd/AST.cpp
@@ -487,76 +487,42 @@
   }
 
   // Handle functions/lambdas with `auto` typed parameters.
-  // We'll examine visible specializations and see if they yield a unique type.
+  // We deduce the type if there's exactly one instantiation visible.
   bool VisitParmVarDecl(ParmVarDecl *PVD) {
     if (!PVD->getType()->isDependentType())
       return true;
     // 'auto' here does not name an AutoType, but an implicit template param.
-    TemplateTypeParmTypeLoc Auto =
-        findContainedAutoTTPLoc(PVD->getTypeSourceInfo()->getTypeLoc());
-    if (Auto.isNull() || Auto.getNameLoc() != SearchedLocation)
+    llvm::Optional<TemplateTypeParmTypeLoc> Auto =
+        getContainedAutoParamType(PVD->getTypeSourceInfo()->getTypeLoc());
+    if (!Auto || Auto->getNameLoc() != SearchedLocation)
       return true;
+
     // We expect the TTP to be attached to this function template.
     // Find the template and the param index.
-    auto *FD = llvm::dyn_cast<FunctionDecl>(PVD->getDeclContext());
-    if (!FD)
+    auto *Templated = llvm::dyn_cast<FunctionDecl>(PVD->getDeclContext());
+    if (!Templated)
       return true;
-    auto *FTD = FD->getDescribedFunctionTemplate();
+    auto *FTD = Templated->getDescribedFunctionTemplate();
     if (!FTD)
       return true;
-    int ParamIndex = paramIndex(*FTD, *Auto.getDecl());
+    int ParamIndex = paramIndex(*FTD, *Auto->getDecl());
     if (ParamIndex < 0) {
       assert(false && "auto TTP is not from enclosing function?");
       return true;
     }
 
-    // Now determine the unique type arg among the implicit specializations.
-    const ASTContext &Ctx = PVD->getASTContext();
-    QualType UniqueType;
-    CanQualType CanUniqueType;
-    for (const FunctionDecl *Spec : FTD->specializations()) {
-      // Meaning `auto` is a bit overloaded if the function is specialized.
-      if (Spec->getTemplateSpecializationKind() == TSK_ExplicitSpecialization)
-        return true;
-      // Find the type for this specialization.
-      const auto *Args = Spec->getTemplateSpecializationArgs();
-      if (Args->size() != FTD->getTemplateParameters()->size())
-        continue; // no weird variadic stuff
-      QualType SpecType = Args->get(ParamIndex).getAsType();
-      if (SpecType.isNull())
-        continue;
-
-      // Deduced types need only be *canonically* equal.
-      CanQualType CanSpecType = Ctx.getCanonicalType(SpecType);
-      if (CanUniqueType.isNull()) {
-        CanUniqueType = CanSpecType;
-        UniqueType = SpecType;
-        continue;
-      }
-      if (CanUniqueType != CanSpecType)
-        return true; // deduced type is not unique
-    }
-    DeducedType = UniqueType;
+    // Now find the instantiation and the deduced template type arg.
+    auto *Instantiation =
+        llvm::dyn_cast_or_null<FunctionDecl>(getOnlyInstantiation(Templated));
+    if (!Instantiation)
+      return true;
+    const auto *Args = Instantiation->getTemplateSpecializationArgs();
+    if (Args->size() != FTD->getTemplateParameters()->size())
+      return true; // no weird variadic stuff
+    DeducedType = Args->get(ParamIndex).getAsType();
     return true;
   }
 
-  // Find the abbreviated-function-template `auto` within a type.
-  // Similar to getContainedAutoTypeLoc, but these `auto`s are
-  // TemplateTypeParmTypes for implicit TTPs, instead of AutoTypes.
-  // Also we don't look very hard, just stripping const, references, pointers.
-  // FIXME: handle more types: vector<auto>?
-  static TemplateTypeParmTypeLoc findContainedAutoTTPLoc(TypeLoc TL) {
-    if (auto QTL = TL.getAs<QualifiedTypeLoc>())
-      return findContainedAutoTTPLoc(QTL.getUnqualifiedLoc());
-    if (llvm::isa<PointerType, ReferenceType>(TL.getTypePtr()))
-      return findContainedAutoTTPLoc(TL.getNextTypeLoc());
-    if (auto TTPTL = TL.getAs<TemplateTypeParmTypeLoc>()) {
-      if (TTPTL.getTypePtr()->getDecl()->isImplicit())
-        return TTPTL;
-    }
-    return {};
-  }
-
   static int paramIndex(const TemplateDecl &TD, NamedDecl &Param) {
     unsigned I = 0;
     for (auto *ND : *TD.getTemplateParameters()) {
@@ -582,6 +548,44 @@
   return V.DeducedType;
 }
 
+llvm::Optional<TemplateTypeParmTypeLoc> getContainedAutoParamType(TypeLoc TL) {
+  if (auto QTL = TL.getAs<QualifiedTypeLoc>())
+    return getContainedAutoParamType(QTL.getUnqualifiedLoc());
+  if (llvm::isa<PointerType, ReferenceType, ParenType>(TL.getTypePtr()))
+    return getContainedAutoParamType(TL.getNextTypeLoc());
+  if (auto FTL = TL.getAs<FunctionTypeLoc>())
+    return getContainedAutoParamType(FTL.getReturnLoc());
+  if (auto TTPTL = TL.getAs<TemplateTypeParmTypeLoc>()) {
+    if (TTPTL.getTypePtr()->getDecl()->isImplicit())
+      return TTPTL;
+  }
+  return {};
+}
+
+NamedDecl *getOnlyInstantiation(NamedDecl *TemplatedDecl) {
+  TemplateDecl *TD = TemplatedDecl->getDescribedTemplate();
+  if (!TD)
+    return nullptr;
+
+  NamedDecl *Only = nullptr;
+#define TEMPLATE_TYPE(SomeTemplateDecl)                                        \
+  if (auto *STD = llvm::dyn_cast<SomeTemplateDecl>(TD)) {                      \
+    for (auto *Spec : STD->specializations()) {                                \
+      if (Spec->getTemplateSpecializationKind() == TSK_ExplicitSpecialization) \
+        continue;                                                              \
+      if (Only != nullptr)                                                     \
+        return nullptr;                                                        \
+      Only = Spec;                                                             \
+    }                                                                          \
+  }
+  TEMPLATE_TYPE(FunctionTemplateDecl);
+  TEMPLATE_TYPE(VarTemplateDecl);
+  TEMPLATE_TYPE(ClassTemplateDecl);
+#undef TEMPLATE_TYPE
+
+  return Only;
+}
+
 std::vector<const Attr *> getAttributes(const DynTypedNode &N) {
   std::vector<const Attr *> Result;
   if (const auto *TL = N.get<TypeLoc>()) {
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to