steveire created this revision.
steveire added a reviewer: aaron.ballman.
Herald added a subscriber: cfe-commits.

This change records, alongside each bound node, the `ASTNodeKind` of the
matcher that bound it. This is necessary so that, when we wish to print
the matchers for a binding of type `CXXMethodDecl` that was matched with
a base matcher such as `functionDecl()`, we can inform the user that they
can write `functionDecl(cxxMethodDecl(isOverride()))`, etc.


Repository:
  rC Clang

https://reviews.llvm.org/D54407

Files:
  include/clang/ASTMatchers/ASTMatchersInternal.h
  lib/ASTMatchers/ASTMatchersInternal.cpp
  lib/Tooling/RefactoringCallbacks.cpp
  unittests/ASTMatchers/ASTMatchersNodeTest.cpp
  unittests/ASTMatchers/ASTMatchersTest.h

Index: unittests/ASTMatchers/ASTMatchersTest.h
===================================================================
--- unittests/ASTMatchers/ASTMatchersTest.h
+++ unittests/ASTMatchers/ASTMatchersTest.h
@@ -349,12 +349,12 @@
       BoundNodes::IDToNodeMap::const_iterator I = M.find(Id);
       EXPECT_NE(M.end(), I);
       if (I != M.end()) {
-        EXPECT_EQ(Nodes->getNodeAs<T>(Id), I->second.get<T>());
+        EXPECT_EQ(Nodes->getNodeAs<T>(Id), I->second.first.get<T>());
       }
       return true;
     }
     EXPECT_TRUE(M.count(Id) == 0 ||
-      M.find(Id)->second.template get<T>() == nullptr);
+      M.find(Id)->second.first.get<T>() == nullptr);
     return false;
   }
 
Index: unittests/ASTMatchers/ASTMatchersNodeTest.cpp
===================================================================
--- unittests/ASTMatchers/ASTMatchersNodeTest.cpp
+++ unittests/ASTMatchers/ASTMatchersNodeTest.cpp
@@ -1724,19 +1724,30 @@
   std::string SourceCode = "struct A { void f() {} };";
   auto Matcher = functionDecl(isDefinition()).bind("method");
 
+  using namespace ast_type_traits;
+
   auto astUnit = tooling::buildASTFromCode(SourceCode);
 
   auto GlobalBoundNodes = matchDynamic(Matcher, astUnit->getASTContext());
 
   EXPECT_EQ(GlobalBoundNodes.size(), 1u);
   EXPECT_EQ(GlobalBoundNodes[0].getMap().size(), 1u);
+  auto GlobalMapPair = *GlobalBoundNodes[0].getMap().begin();
+  EXPECT_TRUE(GlobalMapPair.second.first.getNodeKind().isSame(ASTNodeKind::getFromNodeKind<CXXMethodDecl>()));
+  EXPECT_TRUE(GlobalMapPair.second.second.isSame(ASTNodeKind::getFromNodeKind<FunctionDecl>()));
 
   auto GlobalMethodNode = GlobalBoundNodes[0].getNodeAs<FunctionDecl>("method");
   EXPECT_TRUE(GlobalMethodNode != nullptr);
 
   auto MethodBoundNodes = matchDynamic(Matcher, *GlobalMethodNode, astUnit->getASTContext());
   EXPECT_EQ(MethodBoundNodes.size(), 1u);
   EXPECT_EQ(MethodBoundNodes[0].getMap().size(), 1u);
+  auto MethodMapPair = *MethodBoundNodes[0].getMap().begin();
+  EXPECT_TRUE(MethodMapPair.second.first.getNodeKind().isSame(ASTNodeKind::getFromNodeKind<CXXMethodDecl>()));
+  EXPECT_TRUE(MethodMapPair.second.second.isSame(ASTNodeKind::getFromNodeKind<FunctionDecl>()));
+  EXPECT_EQ(MethodMapPair.second.first, GlobalMapPair.second.first);
+  EXPECT_TRUE(MethodMapPair.second.second.isSame(GlobalMapPair.second.second));
+
 
   auto MethodNode = MethodBoundNodes[0].getNodeAs<FunctionDecl>("method");
   EXPECT_EQ(MethodNode, GlobalMethodNode);
Index: lib/Tooling/RefactoringCallbacks.cpp
===================================================================
--- lib/Tooling/RefactoringCallbacks.cpp
+++ lib/Tooling/RefactoringCallbacks.cpp
@@ -213,8 +213,8 @@
                      << " used in replacement template not bound in Matcher \n";
         llvm::report_fatal_error("Unbound node in replacement template.");
       }
-      CharSourceRange Source =
-          CharSourceRange::getTokenRange(NodeIter->second.getSourceRange());
+      CharSourceRange Source = CharSourceRange::getTokenRange(
+          NodeIter->second.first.getSourceRange());
       ToText += Lexer::getSourceText(Source, *Result.SourceManager,
                                      Result.Context->getLangOpts());
       break;
@@ -227,8 +227,8 @@
     llvm::report_fatal_error("FromId node not bound in MatchResult");
   }
   auto Replacement =
-      tooling::Replacement(*Result.SourceManager, &NodeMap.at(FromId), ToText,
-                           Result.Context->getLangOpts());
+      tooling::Replacement(*Result.SourceManager, &NodeMap.at(FromId).first,
+                           ToText, Result.Context->getLangOpts());
   llvm::Error Err = Replace.add(Replacement);
   if (Err) {
     llvm::errs() << "Query and replace failed in " << Replacement.getFilePath()
Index: lib/ASTMatchers/ASTMatchersInternal.cpp
===================================================================
--- lib/ASTMatchers/ASTMatchersInternal.cpp
+++ lib/ASTMatchers/ASTMatchersInternal.cpp
@@ -52,21 +52,25 @@
 
 bool NotUnaryOperator(const ast_type_traits::DynTypedNode &DynNode,
                       ASTMatchFinder *Finder, BoundNodesTreeBuilder *Builder,
+                      ast_type_traits::ASTNodeKind NodeKind,
                       ArrayRef<DynTypedMatcher> InnerMatchers);
 
 bool AllOfVariadicOperator(const ast_type_traits::DynTypedNode &DynNode,
                            ASTMatchFinder *Finder,
                            BoundNodesTreeBuilder *Builder,
+                           ast_type_traits::ASTNodeKind NodeKind,
                            ArrayRef<DynTypedMatcher> InnerMatchers);
 
 bool EachOfVariadicOperator(const ast_type_traits::DynTypedNode &DynNode,
                             ASTMatchFinder *Finder,
                             BoundNodesTreeBuilder *Builder,
+                            ast_type_traits::ASTNodeKind NodeKind,
                             ArrayRef<DynTypedMatcher> InnerMatchers);
 
 bool AnyOfVariadicOperator(const ast_type_traits::DynTypedNode &DynNode,
                            ASTMatchFinder *Finder,
                            BoundNodesTreeBuilder *Builder,
+                           ast_type_traits::ASTNodeKind NodeKind,
                            ArrayRef<DynTypedMatcher> InnerMatchers);
 
 void BoundNodesTreeBuilder::visitMatches(Visitor *ResultVisitor) {
@@ -81,18 +85,19 @@
 
 using VariadicOperatorFunction = bool (*)(
     const ast_type_traits::DynTypedNode &DynNode, ASTMatchFinder *Finder,
-    BoundNodesTreeBuilder *Builder, ArrayRef<DynTypedMatcher> InnerMatchers);
+    BoundNodesTreeBuilder *Builder, ast_type_traits::ASTNodeKind NodeKind,
+    ArrayRef<DynTypedMatcher> InnerMatchers);
 
 template <VariadicOperatorFunction Func>
 class VariadicMatcher : public DynMatcherInterface {
 public:
   VariadicMatcher(std::vector<DynTypedMatcher> InnerMatchers)
       : InnerMatchers(std::move(InnerMatchers)) {}
 
   bool dynMatches(const ast_type_traits::DynTypedNode &DynNode,
-                  ASTMatchFinder *Finder,
-                  BoundNodesTreeBuilder *Builder) const override {
-    return Func(DynNode, Finder, Builder, InnerMatchers);
+                  ASTMatchFinder *Finder, BoundNodesTreeBuilder *Builder,
+                  ast_type_traits::ASTNodeKind NodeKind) const override {
+    return Func(DynNode, Finder, Builder, NodeKind, InnerMatchers);
   }
 
 private:
@@ -106,10 +111,11 @@
       : ID(ID), InnerMatcher(std::move(InnerMatcher)) {}
 
   bool dynMatches(const ast_type_traits::DynTypedNode &DynNode,
-                  ASTMatchFinder *Finder,
-                  BoundNodesTreeBuilder *Builder) const override {
-    bool Result = InnerMatcher->dynMatches(DynNode, Finder, Builder);
-    if (Result) Builder->setBinding(ID, DynNode);
+                  ASTMatchFinder *Finder, BoundNodesTreeBuilder *Builder,
+                  ast_type_traits::ASTNodeKind NodeKind) const override {
+    bool Result = InnerMatcher->dynMatches(DynNode, Finder, Builder, NodeKind);
+    if (Result)
+      Builder->setBinding(ID, DynNode, NodeKind);
     return Result;
   }
 
@@ -130,7 +136,8 @@
   }
 
   bool dynMatches(const ast_type_traits::DynTypedNode &, ASTMatchFinder *,
-                  BoundNodesTreeBuilder *) const override {
+                  BoundNodesTreeBuilder *,
+                  ast_type_traits::ASTNodeKind) const override {
     return true;
   }
 };
@@ -213,7 +220,7 @@
                               ASTMatchFinder *Finder,
                               BoundNodesTreeBuilder *Builder) const {
   if (RestrictKind.isBaseOf(DynNode.getNodeKind()) &&
-      Implementation->dynMatches(DynNode, Finder, Builder)) {
+      Implementation->dynMatches(DynNode, Finder, Builder, RestrictKind)) {
     return true;
   }
   // Delete all bindings when a matcher does not match.
@@ -227,7 +234,7 @@
     const ast_type_traits::DynTypedNode &DynNode, ASTMatchFinder *Finder,
     BoundNodesTreeBuilder *Builder) const {
   assert(RestrictKind.isBaseOf(DynNode.getNodeKind()));
-  if (Implementation->dynMatches(DynNode, Finder, Builder)) {
+  if (Implementation->dynMatches(DynNode, Finder, Builder, RestrictKind)) {
     return true;
   }
   // Delete all bindings when a matcher does not match.
@@ -262,6 +269,7 @@
 
 bool NotUnaryOperator(const ast_type_traits::DynTypedNode &DynNode,
                       ASTMatchFinder *Finder, BoundNodesTreeBuilder *Builder,
+                      ast_type_traits::ASTNodeKind NodeKind,
                       ArrayRef<DynTypedMatcher> InnerMatchers) {
   if (InnerMatchers.size() != 1)
     return false;
@@ -283,6 +291,7 @@
 bool AllOfVariadicOperator(const ast_type_traits::DynTypedNode &DynNode,
                            ASTMatchFinder *Finder,
                            BoundNodesTreeBuilder *Builder,
+                           ast_type_traits::ASTNodeKind NodeKind,
                            ArrayRef<DynTypedMatcher> InnerMatchers) {
   // allOf leads to one matcher for each alternative in the first
   // matcher combined with each alternative in the second matcher.
@@ -297,6 +306,7 @@
 bool EachOfVariadicOperator(const ast_type_traits::DynTypedNode &DynNode,
                             ASTMatchFinder *Finder,
                             BoundNodesTreeBuilder *Builder,
+                            ast_type_traits::ASTNodeKind NodeKind,
                             ArrayRef<DynTypedMatcher> InnerMatchers) {
   BoundNodesTreeBuilder Result;
   bool Matched = false;
@@ -314,6 +324,7 @@
 bool AnyOfVariadicOperator(const ast_type_traits::DynTypedNode &DynNode,
                            ASTMatchFinder *Finder,
                            BoundNodesTreeBuilder *Builder,
+                           ast_type_traits::ASTNodeKind NodeKind,
                            ArrayRef<DynTypedMatcher> InnerMatchers) {
   for (const DynTypedMatcher &InnerMatcher : InnerMatchers) {
     BoundNodesTreeBuilder Result = *Builder;
Index: include/clang/ASTMatchers/ASTMatchersInternal.h
===================================================================
--- include/clang/ASTMatchers/ASTMatchersInternal.h
+++ include/clang/ASTMatchers/ASTMatchersInternal.h
@@ -149,8 +149,9 @@
   /// Adds \c Node to the map with key \c ID.
   ///
   /// The node's base type should be in NodeBaseType or it will be unaccessible.
-  void addNode(StringRef ID, const ast_type_traits::DynTypedNode& DynNode) {
-    NodeMap[ID] = DynNode;
+  void addNode(StringRef ID, const ast_type_traits::DynTypedNode &DynNode,
+               ast_type_traits::ASTNodeKind NodeKind) {
+    NodeMap[ID] = std::make_pair(DynNode, NodeKind);
   }
 
   /// Returns the AST node bound to \c ID.
@@ -163,15 +164,15 @@
     if (It == NodeMap.end()) {
       return nullptr;
     }
-    return It->second.get<T>();
+    return It->second.first.get<T>();
   }
 
   ast_type_traits::DynTypedNode getNode(StringRef ID) const {
     IDToNodeMap::const_iterator It = NodeMap.find(ID);
     if (It == NodeMap.end()) {
       return ast_type_traits::DynTypedNode();
     }
-    return It->second;
+    return It->second.first;
   }
 
   /// Imposes an order on BoundNodesMaps.
@@ -184,7 +185,9 @@
   /// Note that we're using std::map here, as for memoization:
   /// - we need a comparison operator
   /// - we need an assignment operator
-  using IDToNodeMap = std::map<std::string, ast_type_traits::DynTypedNode>;
+  using IDToNodeMap =
+      std::map<std::string, std::pair<ast_type_traits::DynTypedNode,
+                                      ast_type_traits::ASTNodeKind>>;
 
   const IDToNodeMap &getMap() const {
     return NodeMap;
@@ -194,7 +197,7 @@
   /// stored nodes have memoization data.
   bool isComparable() const {
     for (const auto &IDAndNode : NodeMap) {
-      if (!IDAndNode.second.getMemoizationData())
+      if (!IDAndNode.second.first.getMemoizationData())
         return false;
     }
     return true;
@@ -223,11 +226,12 @@
   };
 
   /// Add a binding from an id to a node.
-  void setBinding(StringRef Id, const ast_type_traits::DynTypedNode &DynNode) {
+  void setBinding(StringRef Id, const ast_type_traits::DynTypedNode &DynNode,
+                  ast_type_traits::ASTNodeKind NodeKind) {
     if (Bindings.empty())
       Bindings.emplace_back();
     for (BoundNodesMap &Binding : Bindings)
-      Binding.addNode(Id, DynNode);
+      Binding.addNode(Id, DynNode, NodeKind);
   }
 
   /// Adds a branch in the tree.
@@ -282,7 +286,8 @@
   /// the AST via \p Finder.
   virtual bool dynMatches(const ast_type_traits::DynTypedNode &DynNode,
                           ASTMatchFinder *Finder,
-                          BoundNodesTreeBuilder *Builder) const = 0;
+                          BoundNodesTreeBuilder *Builder,
+                          ast_type_traits::ASTNodeKind NodeKind) const = 0;
 };
 
 /// Generic interface for matchers on an AST node of type T.
@@ -304,8 +309,8 @@
                        BoundNodesTreeBuilder *Builder) const = 0;
 
   bool dynMatches(const ast_type_traits::DynTypedNode &DynNode,
-                  ASTMatchFinder *Finder,
-                  BoundNodesTreeBuilder *Builder) const override {
+                  ASTMatchFinder *Finder, BoundNodesTreeBuilder *Builder,
+                  ast_type_traits::ASTNodeKind NodeKind) const override {
     return matches(DynNode.getUnchecked<T>(), Finder, Builder);
   }
 };
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to