sepavloff updated this revision to Diff 63386.
sepavloff added a comment.

Updated patch


http://reviews.llvm.org/D16579

Files:
  include/clang/AST/TypeInstantiationMatcher.h
  include/clang/AST/TypeMatcher.h
  include/clang/Basic/DiagnosticGroups.td
  include/clang/Basic/DiagnosticSemaKinds.td
  include/clang/Sema/Sema.h
  lib/Sema/Sema.cpp
  lib/Sema/SemaChecking.cpp
  lib/Sema/SemaDecl.cpp
  lib/Sema/SemaDeclCXX.cpp
  lib/Sema/SemaTemplate.cpp
  test/CXX/drs/dr3xx.cpp
  test/CXX/drs/dr5xx.cpp
  test/SemaCXX/friend.cpp
  test/SemaCXX/overload-call.cpp

Index: test/SemaCXX/overload-call.cpp
===================================================================
--- test/SemaCXX/overload-call.cpp
+++ test/SemaCXX/overload-call.cpp
@@ -574,13 +574,17 @@
   // Ensure that overload resolution attempts to complete argument types when
   // performing ADL.
   template<typename T> struct S {
-    friend int f(const S&);
+    friend int f(const S&);  // expected-warning{{friend declaration 'IncompleteArg::f' depends on template parameter but is not a function template}}
+                             // expected-note@-1{{declare function outside class template to suppress this warning}}
+                             // expected-note@-2{{to befriend a template specialization, make sure the function template has already been declared and use '<>'}}
   };
   extern S<int> s;
   int k = f(s);
 
   template<typename T> struct Op {
-    friend bool operator==(const Op &, const Op &);
+    friend bool operator==(const Op &, const Op &);  // expected-warning{{friend declaration 'IncompleteArg::operator==' depends on template parameter but is not a function template}}
+                             // expected-note@-1{{declare function outside class template to suppress this warning}}
+                             // expected-note@-2{{to befriend a template specialization, make sure the function template has already been declared and use '<>'}}
   };
   extern Op<char> op;
   bool b = op == op;
Index: test/SemaCXX/friend.cpp
===================================================================
--- test/SemaCXX/friend.cpp
+++ test/SemaCXX/friend.cpp
@@ -379,3 +379,111 @@
     X *q = p;
   }
 }
+
+
+template<typename T> void pr23342_func(T x);
+template<typename T>
+struct pr23342_C1 {
+  friend void pr23342_func<>(T x);
+  friend bool func(T x);  // expected-warning{{friend declaration 'func' depends on template parameter but is not a function template}}
+                          // expected-note@-1{{declare function outside class template to suppress this warning}}
+                          // expected-note@-2{{to befriend a template specialization, make sure the function template has already been declared and use '<>'}}
+  friend bool func2(int x);
+  template<typename T2> friend bool func3(T2 x);
+  friend T func4();  // expected-warning{{friend declaration 'func4' depends on template parameter but is not a function template}}
+                     // expected-note@-1{{declare function outside class template to suppress this warning}}
+                     // expected-note@-2{{to befriend a template specialization, make sure the function template has already been declared and use '<>'}}
+};
+
+namespace pr23342 {
+
+template<typename T>
+struct C1 {
+  friend void pr23342_func<>(T x);
+  friend bool func(T x);  // expected-warning{{friend declaration 'pr23342::func' depends on template parameter but is not a function template}}
+                          // expected-note@-1{{declare function outside class template to suppress this warning}}
+                          // expected-note@-2{{to befriend a template specialization, make sure the function template has already been declared and use '<>'}}
+  friend bool func2(int x);
+  template<typename T2> friend bool func3(T2 x);
+  friend T func4();    // expected-warning{{friend declaration 'pr23342::func4' depends on template parameter but is not a function template}}
+                       // expected-note@-1{{declare function outside class template to suppress this warning}}
+                       // expected-note@-2{{to befriend a template specialization, make sure the function template has already been declared and use '<>'}}
+};
+
+template <typename T>
+struct Arg {
+  friend bool operator==(const Arg& lhs, T rhs) {
+    return false;
+  }
+  friend bool operator!=(const Arg& lhs, T rhs);  // expected-warning{{friend declaration 'pr23342::operator!=' depends on template parameter but is not a function template}}
+                                                  // expected-note@-1{{to befriend a template specialization, make sure the function template has already been declared and use '<>'}}
+};
+template <typename T>
+bool operator!=(const Arg<T>& lhs, T rhs) {
+  return true;
+}
+bool foo() {
+  Arg<int> arg;
+  return (arg == 42) || (arg != 42);
+}
+
+
+template<typename T> class C0 {
+  friend void func0(C0<T> &);  // expected-warning{{friend declaration 'pr23342::func0' depends on template parameter but is not a function template}}
+                               // expected-note@-1{{declare function outside class template to suppress this warning}}
+                               // expected-note@-2{{to befriend a template specialization, make sure the function template has already been declared and use '<>'}}
+};
+
+template<typename T> class C0a {
+  friend void func0a(C0a<T> &);  // expected-warning{{friend declaration 'pr23342::func0a' depends on template parameter but is not a function template}}
+                                 // expected-note@-1{{declare function outside class template to suppress this warning}}
+                                 // expected-note@-2{{to befriend a template specialization, make sure the function template has already been declared and use '<>'}}
+};
+void func0a(C0a<int> x);
+void func0a(C0a<int> *x);
+void func0a(const C0a<int> &x);
+int func0a(C0a<int> &x);
+
+template<typename T> class C0b {
+  friend void func0b(C0b<T> &x);
+};
+void func0b(C0b<int> &x);
+
+
+template<typename T> class C1a {
+  friend void func1a(T x, C1a<T> &y); // expected-warning{{friend declaration 'pr23342::func1a' depends on template parameter but is not a function template}}
+                                      // expected-note@-1{{declare function outside class template to suppress this warning}}
+                                      // expected-note@-2{{to befriend a template specialization, make sure the function template has already been declared and use '<>'}}
+};
+void func1a(char x, C1a<int> &y);
+
+template<typename T> class C1b {
+  friend void func1b(T x, C1b<T> &y);
+};
+void func1b(int x, C1b<int> &y);
+
+
+template<typename T> class C2a {
+  friend void func2a(C2a<T> &); // expected-warning{{friend declaration 'pr23342::func2a' depends on template parameter but is not a function template}}
+                                // expected-note@-1{{declare function outside class template to suppress this warning}}
+                                // expected-note@-2{{to befriend a template specialization, make sure the function template has already been declared and use '<>'}}
+};
+template<typename T>
+void func2a(const C2a<T> &x);
+
+
+template<typename T> class C2b {
+  friend void func2b(C2b<T> &); // expected-warning{{friend declaration 'pr23342::func2b' depends on template parameter but is not a function template}}
+                                // expected-note@-1{{to befriend a template specialization, make sure the function template has already been declared and use '<>'}}
+};
+template<typename T>
+void func2b(C2b<T> &x);
+
+template<typename T> class C2c;
+template<typename T>
+void func2c(C2c<T> &x);
+template<typename T> class C2c {
+  friend void func2c<>(C2c<T> &);
+};
+
+}
Index: test/CXX/drs/dr5xx.cpp
===================================================================
--- test/CXX/drs/dr5xx.cpp
+++ test/CXX/drs/dr5xx.cpp
@@ -581,8 +581,12 @@
 
 namespace dr557 { // dr557: yes
   template<typename T> struct S {
-    friend void f(S<T> *);
-    friend void g(S<S<T> > *);
+    friend void f(S<T> *);  // expected-warning{{friend declaration 'dr557::f' depends on template parameter but is not a function template}}
+                            // expected-note@-1{{declare function outside class template to suppress this warning}}
+                            // expected-note@-2{{to befriend a template specialization, make sure the function template has already been declared and use '<>'}}
+    friend void g(S<S<T> > *); // expected-warning{{friend declaration 'dr557::g' depends on template parameter but is not a function template}}
+                               // expected-note@-1{{declare function outside class template to suppress this warning}}
+                               // expected-note@-2{{to befriend a template specialization, make sure the function template has already been declared and use '<>'}}
   };
   void x(S<int> *p, S<S<int> > *q) {
     f(p);
Index: test/CXX/drs/dr3xx.cpp
===================================================================
--- test/CXX/drs/dr3xx.cpp
+++ test/CXX/drs/dr3xx.cpp
@@ -291,7 +291,9 @@
   template void g(N::A<0>::B<0>);
 
   namespace N {
-    template<typename> struct I { friend bool operator==(const I&, const I&); };
+    template<typename> struct I { friend bool operator==(const I&, const I&); };  // expected-warning{{friend declaration 'dr321::N::operator==' depends on template parameter but is not a function template}}
+                                      // expected-note@-1{{declare function outside class template to suppress this warning}}
+                                      // expected-note@-2{{to befriend a template specialization, make sure the function template has already been declared and use '<>'}}
   }
   N::I<int> i, j;
   bool x = i == j;
Index: lib/Sema/SemaTemplate.cpp
===================================================================
--- lib/Sema/SemaTemplate.cpp
+++ lib/Sema/SemaTemplate.cpp
@@ -2113,7 +2113,8 @@
 
 QualType Sema::CheckTemplateIdType(TemplateName Name,
                                    SourceLocation TemplateLoc,
-                                   TemplateArgumentListInfo &TemplateArgs) {
+                                   TemplateArgumentListInfo &TemplateArgs,
+                                   bool IsFriend) {
   DependentTemplateName *DTN
     = Name.getUnderlying().getAsDependentTemplateName();
   if (DTN && DTN->isIdentifier())
@@ -2203,7 +2204,7 @@
     // TODO: in theory this could be a simple hashtable lookup; most
     // changes to CurContext don't change the set of current
     // instantiations.
-    if (isa<ClassTemplateDecl>(Template)) {
+    if (isa<ClassTemplateDecl>(Template) && !IsFriend) {
       for (DeclContext *Ctx = CurContext; Ctx; Ctx = Ctx->getLookupParent()) {
         // If we get out to a namespace, we're done.
         if (Ctx->isFileContext()) break;
Index: lib/Sema/SemaDecl.cpp
===================================================================
--- lib/Sema/SemaDecl.cpp
+++ lib/Sema/SemaDecl.cpp
@@ -8615,6 +8615,14 @@
     AddToScope = false;
   }
 
+  // Queue non-template, non-defining friends of dependent type whose DC is not
+  // a record; CheckDependentFriends() examines them later.
+  if (isFriend && !NewFD->isInvalidDecl() && TemplateParamLists.empty() &&
+      !DC->isRecord() && !isFunctionTemplateSpecialization &&
+      NewFD->getType()->isDependentType() &&
+      !NewFD->isThisDeclarationADefinition())
+    FriendsOfTemplates.push_back(NewFD);
+
   return NewFD;
 }
 
Index: lib/Sema/SemaChecking.cpp
===================================================================
--- lib/Sema/SemaChecking.cpp
+++ lib/Sema/SemaChecking.cpp
@@ -24,6 +24,7 @@
 #include "clang/AST/ExprOpenMP.h"
 #include "clang/AST/StmtCXX.h"
 #include "clang/AST/StmtObjC.h"
+#include "clang/AST/TypeInstantiationMatcher.h"
 #include "clang/Analysis/Analyses/FormatString.h"
 #include "clang/Basic/CharInfo.h"
 #include "clang/Basic/TargetBuiltins.h"
@@ -10922,3 +10923,69 @@
         << ArgumentExpr->getSourceRange()
         << TypeTagExpr->getSourceRange();
 }
+
+/// \brief Checks whether a non-template friend function whose type depends on
+/// template parameters was likely meant to befriend a specialization of a
+/// function template.
+///
+/// Declarations with the friend's name are looked up in its declaration
+/// context. If a function template whose type TypeInstantiationMatcher
+/// matches against the friend's type is found, a warning is issued together
+/// with a note carrying a fix-it that inserts '<>' after the friend's name.
+/// If instead a matching non-template function is found, the friend is
+/// assumed to refer to it and nothing is reported. Otherwise the warning is
+/// issued with notes that suggest how to suppress it.
+static void CheckDependentFriend(Sema &S, FunctionDecl *FriendD) {
+  if (FriendD->isInvalidDecl())
+    return;
+  LookupResult FRes(S, FriendD->getNameInfo(), Sema::LookupOrdinaryName);
+  if (S.LookupQualifiedName(FRes, FriendD->getDeclContext())) {
+    QualType FriendT = FriendD->getType().getCanonicalType();
+    // First check if there is a suitable function template, as this is the
+    // more probable misuse case.
+    for (NamedDecl *D : FRes) {
+      if (D->isInvalidDecl())
+        continue;
+      if (auto *FTD = dyn_cast<FunctionTemplateDecl>(D)) {
+        FunctionDecl *FD = FTD->getTemplatedDecl();
+        QualType FDT = FD->getType().getCanonicalType();
+        if (TypeInstantiationMatcher().match(FriendT, FDT)) {
+          // An appropriate function template was found.
+          S.Diag(FriendD->getLocation(), diag::warn_non_template_friend)
+            << FriendD;
+          SourceLocation NameLoc = FriendD->getNameInfo().getLocEnd();
+          SourceLocation PastName = S.getLocForEndOfToken(NameLoc);
+          S.Diag(PastName, diag::note_befriend_template)
+            << FixItHint::CreateInsertion(PastName, "<>");
+          return;
+        }
+      }
+    }
+    // Then check for a suitable function that uses a particular
+    // specialization of the parameter type.
+    for (NamedDecl *D : FRes) {
+      if (D->isInvalidDecl())
+        continue;
+      if (auto *FD = dyn_cast<FunctionDecl>(D)) {
+        QualType FT = FD->getType().getCanonicalType();
+        // A suitable file-level function exists, so do not warn.
+        if (TypeInstantiationMatcher().match(FriendT, FT))
+          return;
+      }
+    }
+  }
+  // Neither a matching function template nor a matching function was found.
+  S.Diag(FriendD->getLocation(), diag::warn_non_template_friend) << FriendD;
+  S.Diag(FriendD->getLocation(), diag::note_add_template_friend_decl);
+  S.Diag(FriendD->getLocation(), diag::note_befriend_template);
+}
+
+/// Checks the queued friend declarations, then clears FriendsOfTemplates.
+void Sema::CheckDependentFriends() {
+  for (FunctionDecl *FriendD : FriendsOfTemplates) {
+    if (FriendD->getType()->isDependentType() &&
+        !FriendD->isThisDeclarationADefinition())
+      CheckDependentFriend(*this, FriendD);
+  }
+  FriendsOfTemplates.clear();
+}
Index: lib/Sema/Sema.cpp
===================================================================
--- lib/Sema/Sema.cpp
+++ lib/Sema/Sema.cpp
@@ -689,6 +689,7 @@
       LateTemplateParserCleanup(OpaqueParser);
 
     CheckDelayedMemberExceptionSpecs();
+    CheckDependentFriends();
   }
 
   // All delayed member exception specs should be checked or we end up accepting
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -5752,7 +5752,8 @@
 
   QualType CheckTemplateIdType(TemplateName Template,
                                SourceLocation TemplateLoc,
-                              TemplateArgumentListInfo &TemplateArgs);
+                               TemplateArgumentListInfo &TemplateArgs,
+                               bool IsFriend = false);
 
   TypeResult
   ActOnTemplateIdType(CXXScopeSpec &SS, SourceLocation TemplateKWLoc,
@@ -9462,6 +9463,15 @@
   /// attempts to add itself into the container
   void CheckObjCCircularContainer(ObjCMessageExpr *Message);
 
+  /// \brief Dependent non-template friend functions declared in class
+  /// templates. They are not added to redeclaration chains until the
+  /// templates are instantiated, so they are queued here for later checks.
+  SmallVector<FunctionDecl *, 16> FriendsOfTemplates;
+
+  /// \brief Diagnose queued friends that may have been meant to befriend a
+  /// function template specialization, then clear FriendsOfTemplates.
+  void CheckDependentFriends();
+
   void AnalyzeDeleteExprMismatch(const CXXDeleteExpr *DE);
   void AnalyzeDeleteExprMismatch(FieldDecl *Field, SourceLocation DeleteLoc,
                                  bool DeleteWasArrayForm);
Index: include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- include/clang/Basic/DiagnosticSemaKinds.td
+++ include/clang/Basic/DiagnosticSemaKinds.td
@@ -1150,6 +1150,13 @@
   "enclosing namespace is a Microsoft extension; add a nested name specifier">,
   InGroup<MicrosoftUnqualifiedFriend>;
 def err_pure_friend : Error<"friend declaration cannot have a pure-specifier">;
+def warn_non_template_friend : Warning<
+  "friend declaration %q0 depends on template parameter but is not a "
+  "function template">, InGroup<NonTemplateFriend>;
+def note_add_template_friend_decl : Note<
+  "declare function outside class template to suppress this warning">;
+def note_befriend_template : Note<"to befriend a template specialization, "
+  "make sure the function template has already been declared and use '<>'">;
 
 def err_invalid_member_in_interface : Error<
   "%select{data member |non-public member function |static member function |"
Index: include/clang/Basic/DiagnosticGroups.td
===================================================================
--- include/clang/Basic/DiagnosticGroups.td
+++ include/clang/Basic/DiagnosticGroups.td
@@ -282,6 +282,7 @@
 def InitializerOverrides : DiagGroup<"initializer-overrides">;
 def NonNull : DiagGroup<"nonnull">;
 def NonPODVarargs : DiagGroup<"non-pod-varargs">;
 def ClassVarargs : DiagGroup<"class-varargs", [NonPODVarargs]>;
+def NonTemplateFriend : DiagGroup<"non-template-friend">;
 def : DiagGroup<"nonportable-cfstrings">;
 def NonVirtualDtor : DiagGroup<"non-virtual-dtor">;
Index: include/clang/AST/TypeMatcher.h
===================================================================
--- /dev/null
+++ include/clang/AST/TypeMatcher.h
@@ -0,0 +1,428 @@
+//===--- TypeMatcher.h - Visitor for Type subclasses ------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+//  This file defines the TypeMatcher interface, which recursively compares
+//  two types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_AST_TYPEMATCHER_H
+#define LLVM_CLANG_AST_TYPEMATCHER_H
+
+#include "clang/AST/Type.h"
+
+namespace clang {
+
+// Short-circuit helper: calls CALL_EXPR (a method call) on the derived object
+// and returns false from the enclosing function if that call returns false.
+#define TRY_TO(CALL_EXPR)                                                      \
+  do {                                                                         \
+    if (!getDerived().CALL_EXPR)                                               \
+      return false;                                                            \
+  } while (0)
+
+
+/// \brief A visitor that traverses one type, the "pattern", and compares it
+/// with another type that is traversed synchronously.
+///
+/// The visitor is used much like other visitor classes, for instance,
+/// RecursiveASTVisitor. The necessary functionality is implemented in a new
+/// class that must derive from this template according to the Curiously
+/// Recurring Template Pattern:
+///
+/// \code
+///     class AMatcher : public TypeMatcher<AMatcher> {
+///       ...
+///     };
+/// \endcode
+///
+/// The derived class overrides methods declared in TypeMatcher to implement the
+/// necessary behavior.
+///
+/// Most of the matching is done by the Visit methods, which are called for
+/// pairs of QualTypes found in the type hierarchy. There are Visit methods for
+/// each subclass of Type, including abstract types, so for a particular
+/// pattern type several Visit methods are usually called. The other methods
+/// of this visitor define which Visit methods are called and in what
+/// order.
+///
+/// Pattern matching starts by calling method TraverseType, which implements
+/// a polymorphic operation on a pair of objects of types derived from Type:
+///
+/// \code
+///     TraverseType(PatternType, EvaluatedType)
+/// \endcode
+///
+/// The function returns true if \c EvaluatedType matches \c PatternType, false
+/// otherwise.
+///
+/// Depending on the actual type of \c PatternType, this function dispatches
+/// the call to the function that processes that type, for instance
+/// \c TraverseConstantArrayType if the pattern type is \c ConstantArrayType.
+/// That function first calls \c WalkUpFromConstantArrayType, which calls the
+/// appropriate Visit methods, and then recursively calls \c TraverseType on
+/// the components of the pattern type.
+///
+/// A \c WalkUpFrom* function exists for every type. It calls the Visit methods
+/// for the particular pattern type. The default implementation makes the calls
+/// starting from the most general type, so if \c PatternType is a
+/// ConstantArrayType, the sequence of calls is as follows:
+///
+/// \li VisitType
+/// \li VisitArrayType
+/// \li VisitConstantArrayType
+///
+/// If any \c Visit call returns false, the types do not match and no further
+/// checks are made.
+///
+template<typename ImplClass>
+class TypeMatcher {
+public:
+  // Entry point: matches the canonical forms of the two types.
+  bool match(QualType PatternT, QualType InstT) {
+    return TraverseType(PatternT.getCanonicalType(), InstT.getCanonicalType());
+  }
+
+  // Dispatches to Traverse* for PatternT's class; null matches only null.
+  bool TraverseType(QualType PatternT, QualType InstT) {
+    if (PatternT.isNull() || InstT.isNull())
+      return PatternT.isNull() && InstT.isNull();
+    switch (PatternT->getTypeClass()) {
+#define ABSTRACT_TYPE(CLASS, BASE)
+#define TYPE(CLASS, BASE)                                                      \
+  case Type::CLASS:                                                            \
+    return getDerived().Traverse##CLASS##Type(PatternT, InstT);
+#include "clang/AST/TypeNodes.def"
+    default:
+      llvm_unreachable("Unknown type class!");
+    }
+  }
+
+  // Declare Traverse*() for all concrete Type classes.
+#define ABSTRACT_TYPE(CLASS, BASE)
+#define TYPE(CLASS, BASE) bool Traverse##CLASS##Type(QualType P, QualType T);
+#include "clang/AST/TypeNodes.def"
+
+  // Define WalkUpFrom*() for all Type classes.
+#define TYPE(CLASS, BASE)                                                      \
+  bool WalkUpFrom##CLASS##Type(QualType P, QualType S) {                       \
+    TRY_TO(WalkUpFrom##BASE(P, S));                                            \
+    TRY_TO(Visit##CLASS##Type(P, S));                                          \
+    return true;                                                               \
+  }
+#include "clang/AST/TypeNodes.def"
+
+  // Generic version of WalkUpFrom*, which is called first for every type
+  // matching.
+  bool WalkUpFromType(QualType PatternT, QualType T) {
+    return getDerived().VisitType(PatternT, T);
+  }
+
+  // Define default Visit*() for all Type classes.
+#define TYPE(CLASS, BASE)                                                      \
+  bool Visit##CLASS##Type(QualType P, QualType T) {                            \
+    return P->getTypeClass() == T->getTypeClass();                             \
+  }
+#include "clang/AST/TypeNodes.def"
+
+  // Generic version of Visit*, which is called first for every type matching.
+  bool VisitType(QualType PatternT, QualType T) {
+    return true;
+  }
+
+  /// \brief Return a reference to the derived class.
+  ImplClass &getDerived() { return *static_cast<ImplClass *>(this); }
+
+  bool TraverseTemplateName(const TemplateName &PN, const TemplateName &TN) {
+    return PN.getKind() == TN.getKind();
+  }
+
+  bool TraverseTemplateArguments(const TemplateArgument *PArgs, unsigned PNum,
+                                 const TemplateArgument *SArgs, unsigned SNum) {
+    if (PNum != SNum)
+      return false;
+    for (unsigned I = 0; I != PNum; ++I) {
+      TRY_TO(TraverseTemplateArgument(PArgs[I], SArgs[I]));
+    }
+    return true;
+  }
+
+  bool TraverseTemplateArgument(const TemplateArgument &PArg,
+                                const TemplateArgument &SArg) {
+    if (PArg.getKind() != SArg.getKind())
+      return false;
+    switch (PArg.getKind()) {
+    case TemplateArgument::Null:
+    case TemplateArgument::Declaration:
+    case TemplateArgument::Integral:
+    case TemplateArgument::NullPtr:
+      return true;
+
+    case TemplateArgument::Type:
+      return getDerived().TraverseType(PArg.getAsType(), SArg.getAsType());
+
+    case TemplateArgument::Template:
+    case TemplateArgument::TemplateExpansion:
+      return getDerived().TraverseTemplateName(
+        PArg.getAsTemplateOrTemplatePattern(),
+        SArg.getAsTemplateOrTemplatePattern());
+
+    case TemplateArgument::Expression:
+      return getDerived().TraverseStmt(PArg.getAsExpr(), SArg.getAsExpr());
+
+    case TemplateArgument::Pack:
+      return getDerived().TraverseTemplateArguments(PArg.pack_begin(),
+                         PArg.pack_size(), SArg.pack_begin(), SArg.pack_size());
+    }
+
+    return true;
+  }
+
+  bool TraverseNestedNameSpecifier(NestedNameSpecifier *PNNS,
+                                   NestedNameSpecifier *SNNS) {
+    if (!PNNS)
+      return !SNNS;
+    if (!SNNS)
+      return false;
+
+    TRY_TO(TraverseNestedNameSpecifier(PNNS->getPrefix(), SNNS->getPrefix()));
+
+    if (PNNS->getKind() != SNNS->getKind())
+      return false;
+    switch (PNNS->getKind()) {
+    case NestedNameSpecifier::Identifier:
+    case NestedNameSpecifier::Namespace:
+    case NestedNameSpecifier::NamespaceAlias:
+    case NestedNameSpecifier::Global:
+    case NestedNameSpecifier::Super:
+      return true;
+
+    case NestedNameSpecifier::TypeSpec:
+    case NestedNameSpecifier::TypeSpecWithTemplate:
+      TRY_TO(TraverseType(QualType(PNNS->getAsType(), 0),
+                          QualType(SNNS->getAsType(), 0)));
+    }
+
+    return true;
+  }
+
+  // Despite its name, this compares expressions that occur as parts of types
+  // (as in VariableArrayType, for instance) by comparing their types.
+  bool TraverseStmt(Stmt *P, Stmt *S) {
+    if (!P || !S)
+      return !P == !S;
+    Expr *PE = cast<Expr>(P);
+    Expr *SE = cast<Expr>(S);
+    return getDerived().TraverseType(PE->getType(), SE->getType());
+  }
+};
+
+// Implementation of Traverse* methods.
+//
+// These methods are obtained from corresponding definitions of
+// RecursiveASTVisitor almost mechanically.  The main difference is that
+// Traverse* functions get arguments in pairs: for pattern type and for sample.
+
+// Defines method TypeMatcher<Derived>::Traverse* for the given TYPE. Inside
+// CODE, P and T are the pattern and sample types cast to class TYPE. The
+// WalkUpFrom* call goes through getDerived() so derived overrides are honored.
+#define DEF_TRAVERSE_TYPE(TYPE, CODE)                                          \
+  template <typename Derived>                                                  \
+  bool TypeMatcher<Derived>::Traverse##TYPE(QualType PQ, QualType TQ) {        \
+    if (!getDerived().WalkUpFrom##TYPE(PQ, TQ))                                \
+      return false;                                                            \
+    const TYPE *P = cast<TYPE>(PQ.getTypePtr());                               \
+    const TYPE *T = cast<TYPE>(TQ.getTypePtr());                               \
+    (void)P; (void)T;                                                          \
+    CODE;                                                                      \
+    return true;                                                               \
+  }
+
+DEF_TRAVERSE_TYPE(BuiltinType, {})
+
+DEF_TRAVERSE_TYPE(ComplexType, {
+  TRY_TO(TraverseType(P->getElementType(), T->getElementType()));
+})
+
+DEF_TRAVERSE_TYPE(PointerType, {
+  TRY_TO(TraverseType(P->getPointeeType(), T->getPointeeType()));
+})
+
+DEF_TRAVERSE_TYPE(BlockPointerType, {
+  TRY_TO(TraverseType(P->getPointeeType(), T->getPointeeType()));
+})
+
+DEF_TRAVERSE_TYPE(LValueReferenceType, {
+  TRY_TO(TraverseType(P->getPointeeType(), T->getPointeeType()));
+})
+
+DEF_TRAVERSE_TYPE(RValueReferenceType, {
+  TRY_TO(TraverseType(P->getPointeeType(), T->getPointeeType()));
+})
+
+DEF_TRAVERSE_TYPE(MemberPointerType, {
+  TRY_TO(TraverseType(QualType(P->getClass(), 0), QualType(T->getClass(), 0)));
+  TRY_TO(TraverseType(P->getPointeeType(), T->getPointeeType()));
+})
+
+// Adjusted and decayed types: only the originally written type is traversed;
+// the adjusted form is not.
+DEF_TRAVERSE_TYPE(AdjustedType,
+                  { TRY_TO(TraverseType(P->getOriginalType(),
+                                        T->getOriginalType())); })
+
+DEF_TRAVERSE_TYPE(DecayedType,
+                  { TRY_TO(TraverseType(P->getOriginalType(),
+                                        T->getOriginalType())); })
+
+// Arrays: the element types are always traversed. Arrays whose bound is an
+// expression then traverse the size expressions as well.
+DEF_TRAVERSE_TYPE(ConstantArrayType,
+                  { TRY_TO(TraverseType(P->getElementType(),
+                                        T->getElementType())); })
+
+DEF_TRAVERSE_TYPE(IncompleteArrayType,
+                  { TRY_TO(TraverseType(P->getElementType(),
+                                        T->getElementType())); })
+
+DEF_TRAVERSE_TYPE(VariableArrayType, {
+  QualType PElem = P->getElementType();
+  QualType TElem = T->getElementType();
+  TRY_TO(TraverseType(PElem, TElem));
+  TRY_TO(TraverseStmt(P->getSizeExpr(), T->getSizeExpr()));
+})
+
+DEF_TRAVERSE_TYPE(DependentSizedArrayType, {
+  QualType PElem = P->getElementType();
+  QualType TElem = T->getElementType();
+  TRY_TO(TraverseType(PElem, TElem));
+  TRY_TO(TraverseStmt(P->getSizeExpr(), T->getSizeExpr()));
+})
+
+// Unlike the arrays above, the dependent-size vector traverses its size
+// expression before its element type.
+DEF_TRAVERSE_TYPE(DependentSizedExtVectorType, {
+  TRY_TO(TraverseStmt(P->getSizeExpr(), T->getSizeExpr()));
+  QualType PElem = P->getElementType();
+  QualType TElem = T->getElementType();
+  TRY_TO(TraverseType(PElem, TElem));
+})
+
+// Vectors: only the element type is traversed; the element count is not
+// compared here.
+DEF_TRAVERSE_TYPE(VectorType,
+                  { TRY_TO(TraverseType(P->getElementType(),
+                                        T->getElementType())); })
+
+DEF_TRAVERSE_TYPE(ExtVectorType,
+                  { TRY_TO(TraverseType(P->getElementType(),
+                                        T->getElementType())); })
+
+// A K&R-style function type only has a return type to traverse.
+DEF_TRAVERSE_TYPE(FunctionNoProtoType,
+                  { TRY_TO(TraverseType(P->getReturnType(),
+                                        T->getReturnType())); })
+
+DEF_TRAVERSE_TYPE(FunctionProtoType, {
+  // Components are traversed pairwise, so the pattern and the sample must have
+  // the same number of parameters and of exception types. Checking that is
+  // the job of a Visit* method; here it is only asserted.
+  const unsigned NumParams = P->getNumParams();
+  const unsigned NumExceptions = P->exceptions().size();
+  assert(NumParams == T->getNumParams());
+  assert(NumExceptions == T->exceptions().size());
+
+  // Order: return type, parameters, dynamic exception types, then the
+  // noexcept expression.
+  TRY_TO(TraverseType(P->getReturnType(), T->getReturnType()));
+  for (unsigned Idx = 0; Idx != NumParams; ++Idx) {
+    TRY_TO(TraverseType(P->getParamType(Idx), T->getParamType(Idx)));
+  }
+  for (unsigned Idx = 0; Idx != NumExceptions; ++Idx) {
+    TRY_TO(TraverseType(P->exceptions()[Idx], T->exceptions()[Idx]));
+  }
+  TRY_TO(TraverseStmt(P->getNoexceptExpr(), T->getNoexceptExpr()));
+})
+
+DEF_TRAVERSE_TYPE(UnresolvedUsingType, {})
+DEF_TRAVERSE_TYPE(TypedefType, {})
+
+DEF_TRAVERSE_TYPE(TypeOfExprType, {
+  TRY_TO(TraverseStmt(P->getUnderlyingExpr(), T->getUnderlyingExpr()));
+})
+
+DEF_TRAVERSE_TYPE(TypeOfType, {
+  TRY_TO(TraverseType(P->getUnderlyingType(), T->getUnderlyingType()));
+})
+
+DEF_TRAVERSE_TYPE(DecltypeType, {
+  TRY_TO(TraverseStmt(P->getUnderlyingExpr(), T->getUnderlyingExpr()));
+})
+
+DEF_TRAVERSE_TYPE(UnaryTransformType, {
+  TRY_TO(TraverseType(P->getBaseType(), T->getBaseType()));
+  TRY_TO(TraverseType(P->getUnderlyingType(), T->getUnderlyingType()));
+})
+
+DEF_TRAVERSE_TYPE(AutoType, {
+  TRY_TO(TraverseType(P->getDeducedType(), T->getDeducedType()));
+})
+
+DEF_TRAVERSE_TYPE(RecordType, {})
+DEF_TRAVERSE_TYPE(EnumType, {})
+DEF_TRAVERSE_TYPE(TemplateTypeParmType, {})
+DEF_TRAVERSE_TYPE(SubstTemplateTypeParmType, {})
+DEF_TRAVERSE_TYPE(SubstTemplateTypeParmPackType, {})
+
+// A template specialization traverses its template name, then its argument
+// list.
+DEF_TRAVERSE_TYPE(TemplateSpecializationType, {
+  auto PName = P->getTemplateName();
+  auto TName = T->getTemplateName();
+  TRY_TO(TraverseTemplateName(PName, TName));
+  TRY_TO(TraverseTemplateArguments(P->getArgs(), P->getNumArgs(), T->getArgs(),
+                                   T->getNumArgs()));
+})
+
+DEF_TRAVERSE_TYPE(InjectedClassNameType, {})
+
+// Attributed types traverse the modified type; paren types traverse the type
+// inside the parentheses.
+DEF_TRAVERSE_TYPE(AttributedType,
+                  { TRY_TO(TraverseType(P->getModifiedType(),
+                                        T->getModifiedType())); })
+
+DEF_TRAVERSE_TYPE(ParenType,
+                  { TRY_TO(TraverseType(P->getInnerType(),
+                                        T->getInnerType())); })
+
+// Qualified and dependent names traverse their nested-name-specifier first.
+DEF_TRAVERSE_TYPE(ElaboratedType, {
+  TRY_TO(TraverseNestedNameSpecifier(P->getQualifier(), T->getQualifier()));
+  QualType PNamed = P->getNamedType();
+  QualType TNamed = T->getNamedType();
+  TRY_TO(TraverseType(PNamed, TNamed));
+})
+
+DEF_TRAVERSE_TYPE(DependentNameType,
+                  { TRY_TO(TraverseNestedNameSpecifier(P->getQualifier(),
+                                                       T->getQualifier())); })
+
+DEF_TRAVERSE_TYPE(DependentTemplateSpecializationType, {
+  TRY_TO(TraverseNestedNameSpecifier(P->getQualifier(), T->getQualifier()));
+  TRY_TO(TraverseTemplateArguments(P->getArgs(), P->getNumArgs(), T->getArgs(),
+                                   T->getNumArgs()));
+})
+
+// A pack expansion traverses its pattern.
+DEF_TRAVERSE_TYPE(PackExpansionType,
+                  { TRY_TO(TraverseType(P->getPattern(), T->getPattern())); })
+
+// An Objective-C interface type has no component types to traverse.
+DEF_TRAVERSE_TYPE(ObjCInterfaceType, {})
+
+DEF_TRAVERSE_TYPE(ObjCObjectType, {
+  // Written type arguments are traversed pairwise, so the pattern and the
+  // sample must have the same number of them. Checking that is the job of a
+  // Visit* method of the derived matcher; here it is only asserted, and the
+  // traversal must go on to the base type and type arguments rather than
+  // bailing out.
+  assert(P->getTypeArgsAsWritten().size() == T->getTypeArgsAsWritten().size());
+  TRY_TO(TraverseType(P->getBaseType(), T->getBaseType()));
+  for (unsigned I = 0, E = P->getTypeArgsAsWritten().size(); I != E; ++I) {
+    TRY_TO(TraverseType(P->getTypeArgsAsWritten()[I],
+                        T->getTypeArgsAsWritten()[I]));
+  }
+})
+
+// Objective-C object pointers, atomics and pipes each wrap a single component
+// type, which is traversed for the pattern and the sample in lockstep.
+DEF_TRAVERSE_TYPE(ObjCObjectPointerType,
+                  { TRY_TO(TraverseType(P->getPointeeType(),
+                                        T->getPointeeType())); })
+
+DEF_TRAVERSE_TYPE(AtomicType,
+                  { TRY_TO(TraverseType(P->getValueType(),
+                                        T->getValueType())); })
+
+DEF_TRAVERSE_TYPE(PipeType,
+                  { TRY_TO(TraverseType(P->getElementType(),
+                                        T->getElementType())); })
+
+#undef DEF_TRAVERSE_TYPE
+
+}
+#endif
Index: include/clang/AST/TypeInstantiationMatcher.h
===================================================================
--- /dev/null
+++ include/clang/AST/TypeInstantiationMatcher.h
@@ -0,0 +1,145 @@
+//===--- TypeInstantiationMatcher.h -----------------------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+//  Defines TypeInstantiationMatcher, a TypeMatcher used to check whether one
+//  type is, or could be, an instantiation of another type.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_AST_TYPEINSTANTIATIONMATCHER_H
+#define LLVM_CLANG_AST_TYPEINSTANTIATIONMATCHER_H
+
+#include "clang/AST/TypeMatcher.h"
+#include "clang/AST/DeclTemplate.h"
+#include <map>
+
+namespace clang {
+
+/// \brief Visitor class used to check if one type is an instantiation of
+/// another.
+///
+/// Every Visit*/Traverse* method receives a pair of types: PQ, taken from the
+/// pattern (which may mention template type parameters), and TQ, the
+/// corresponding part of the sample being tested. A template type parameter
+/// met in the pattern is bound to the sample type at the same position; any
+/// later occurrence of that parameter must agree with the binding.
+///
+class TypeInstantiationMatcher : public TypeMatcher<TypeInstantiationMatcher> {
+  /// Binding of each template type parameter (in canonical form) to the
+  /// canonical sample type it has been matched against.
+  std::map<const TemplateTypeParmType *, QualType> ParamMapping;
+
+  /// Binds \p Param to \p Arg or, if \p Param is already bound, checks that
+  /// the earlier binding denotes the same type. Both sides are kept in
+  /// canonical form so that sugar (typedefs and the like) cannot make two
+  /// spellings of one type look different.
+  /// \returns false if \p Param was already bound to a different type.
+  bool bindParameter(const TemplateTypeParmType *Param, QualType Arg) {
+    QualType CanonArg = Arg.getCanonicalType();
+    // emplace leaves an existing entry untouched and reports whether it
+    // inserted, so one lookup serves both the bind and the consistency check.
+    auto Result = ParamMapping.emplace(Param, CanonArg);
+    return Result.second || Result.first->second == CanonArg;
+  }
+
+public:
+
+  /// Fails if the two types carry different qualifiers.
+  bool VisitType(QualType PQ, QualType TQ) {
+    return PQ.getQualifiers() == TQ.getQualifiers();
+  }
+
+  /// Builtin types match only if they are the same canonical type.
+  bool VisitBuiltinType(QualType PQ, QualType TQ) {
+    return PQ.getCanonicalType() == TQ.getCanonicalType();
+  }
+
+  /// Constant arrays must have the same size; element types are compared by
+  /// the traversal.
+  bool VisitConstantArrayType(QualType PQ, QualType TQ) {
+    const auto *P = cast<ConstantArrayType>(PQ.getTypePtr());
+    const auto *T = cast<ConstantArrayType>(TQ.getTypePtr());
+    return P->getSize() == T->getSize();
+  }
+
+  /// Vectors must have the same number of elements; element types are
+  /// compared by the traversal.
+  bool VisitVectorType(QualType PQ, QualType TQ) {
+    const auto *P = PQ->getAs<VectorType>();
+    const auto *T = TQ->getAs<VectorType>();
+    return P->getNumElements() == T->getNumElements();
+  }
+
+  /// Function prototypes must agree in the number of parameters and of
+  /// exception types, because the traversal compares them pairwise.
+  bool VisitFunctionProtoType(QualType PQ, QualType TQ) {
+    const auto *P = PQ->getAs<FunctionProtoType>();
+    const auto *T = TQ->getAs<FunctionProtoType>();
+    return P->getNumParams() == T->getNumParams() &&
+           P->exceptions().size() == T->exceptions().size();
+  }
+
+  /// Binds the pattern's template type parameter to the sample type, or
+  /// checks the sample type against an earlier binding.
+  bool VisitTemplateTypeParmType(QualType PQ, QualType TQ) {
+    const auto *Param =
+        PQ->getCanonicalTypeInternal()->getAs<TemplateTypeParmType>();
+    return bindParameter(Param, TQ);
+  }
+
+  /// Overridden so that a template type parameter is only visited (bound),
+  /// with no further traversal.
+  bool TraverseTemplateTypeParmType(QualType PQ, QualType TQ) {
+    return WalkUpFromTemplateTypeParmType(PQ, TQ);
+  }
+
+  /// Binds each template type parameter in \p PL to the type argument at the
+  /// same position in \p AL.
+  ///
+  /// Non-type and template template parameters are not checked. Every type
+  /// parameter is examined, including those that were already bound. Only a
+  /// conflicting binding, a non-type argument for a type parameter, or a size
+  /// mismatch makes the match fail.
+  bool matchTemplateParametersAndArguments(const TemplateParameterList &PL,
+                                           const TemplateArgumentList &AL) {
+    if (PL.size() != AL.size())
+      return false;
+    for (unsigned I = 0, E = PL.size(); I < E; ++I) {
+      const auto *TTP = dyn_cast<TemplateTypeParmDecl>(PL.getParam(I));
+      if (!TTP)
+        continue;
+      const TemplateArgument &TA = AL.get(I);
+      if (TA.getKind() != TemplateArgument::Type)
+        return false;
+      const auto *ParmType = TTP->getTypeForDecl()
+                                 ->getCanonicalTypeInternal()
+                                 ->getAs<TemplateTypeParmType>();
+      // Do not stop at the first already-bound parameter: the remaining
+      // parameters must still be checked.
+      if (!bindParameter(ParmType, TA.getAsType()))
+        return false;
+    }
+    return true;
+  }
+
+  /// Overridden so that the written template arguments are not traversed;
+  /// VisitTemplateSpecializationType performs all of the matching.
+  bool TraverseTemplateSpecializationType(QualType PQ, QualType TQ) {
+    return WalkUpFromTemplateSpecializationType(PQ, TQ);
+  }
+
+  /// Matches a specialization of a class template against the sample.
+  ///
+  /// If the sample is a class, it must be a specialization of the same class
+  /// template. Its template arguments are then bound to that template's
+  /// parameters.
+  ///
+  /// NOTE(review): the arguments written in the pattern (P->getArgs()) are
+  /// never consulted. The sample's arguments are bound to the class
+  /// template's own parameters instead. That is only right when the pattern's
+  /// arguments are exactly those parameters; confirm whether other patterns
+  /// need matching against P->getArgs().
+  bool VisitTemplateSpecializationType(QualType PQ, QualType TQ) {
+    const auto *P = PQ->getAs<TemplateSpecializationType>();
+    TemplateName PTN = P->getTemplateName();
+    if (PTN.getKind() != TemplateName::Template)
+      return false;
+    auto *PTD = dyn_cast<ClassTemplateDecl>(PTN.getAsTemplateDecl());
+    if (!PTD)
+      return false;
+
+    const CXXRecordDecl *CD = TQ->getAsCXXRecordDecl();
+    // NOTE(review): any sample that is not a class is accepted here; confirm
+    // that this permissiveness is intended.
+    if (!CD)
+      return true;
+    const auto *SD = dyn_cast<ClassTemplateSpecializationDecl>(CD);
+    if (!SD)
+      return false;
+    ClassTemplateDecl *TD = SD->getSpecializedTemplate();
+    if (!TD || PTD->getCanonicalDecl() != TD->getCanonicalDecl())
+      return false;
+    // The size comparison is done by matchTemplateParametersAndArguments.
+    return matchTemplateParametersAndArguments(*PTD->getTemplateParameters(),
+                                               SD->getTemplateArgs());
+  }
+
+  /// An injected-class-name is matched as its injected specialization type,
+  /// against the unqualified sample.
+  bool VisitInjectedClassNameType(QualType PQ, QualType TQ) {
+    const auto *P = PQ->getAs<InjectedClassNameType>();
+    return match(P->getInjectedSpecializationType(), TQ.getUnqualifiedType());
+  }
+
+  /// Overridden so that the injected-class-name is only visited;
+  /// VisitInjectedClassNameType performs the match.
+  bool TraverseInjectedClassNameType(QualType PQ, QualType TQ) {
+    return WalkUpFromInjectedClassNameType(PQ, TQ);
+  }
+
+  /// Objective-C object types must have the same number of written type
+  /// arguments, because the traversal compares them pairwise.
+  bool VisitObjCObjectType(QualType PQ, QualType TQ) {
+    const auto *P = PQ->getAs<ObjCObjectType>();
+    const auto *T = TQ->getAs<ObjCObjectType>();
+    return P->getTypeArgsAsWritten().size() == T->getTypeArgsAsWritten().size();
+  }
+};
+
+}
+#endif
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to