johannes created this revision.
Herald added a subscriber: klimek.

This adds to each node a reference to its syntax tree. As a result,
instead of passing around the tree plus a node ID, we can simply pass a
reference to the node, which removes some potential for errors. Users
will almost always work with node references and need not know node IDs.

Iteration over lists of node IDs is provided by NodeRefIterator, which
yields a reference to the node for each ID.


https://reviews.llvm.org/D39644

Files:
  include/clang/Tooling/ASTDiff/ASTDiff.h
  include/clang/Tooling/ASTDiff/ASTDiffInternal.h
  lib/Tooling/ASTDiff/ASTDiff.cpp
  tools/clang-diff/ClangDiff.cpp

Index: tools/clang-diff/ClangDiff.cpp
===================================================================
--- tools/clang-diff/ClangDiff.cpp
+++ tools/clang-diff/ClangDiff.cpp
@@ -265,30 +265,31 @@
 }
 
 static unsigned printHtmlForNode(raw_ostream &OS, const diff::ASTDiff &Diff,
-                                 diff::SyntaxTree &Tree, bool IsLeft,
-                                 diff::NodeId Id, unsigned Offset) {
-  const diff::Node &Node = Tree.getNode(Id);
+                                 bool IsLeft, const diff::Node &Node,
+                                 unsigned Offset) {
   char MyTag, OtherTag;
   diff::NodeId LeftId, RightId;
-  diff::NodeId TargetId = Diff.getMapped(Tree, Id);
+  diff::SyntaxTree &Tree = Node.getTree();
+  const diff::Node *Target = Diff.getMapped(Tree, Node);
+  diff::NodeId TargetId = Target ? Target->getId() : diff::NodeId();
   if (IsLeft) {
     MyTag = 'L';
     OtherTag = 'R';
-    LeftId = Id;
+    LeftId = Node.getId();
     RightId = TargetId;
   } else {
     MyTag = 'R';
     OtherTag = 'L';
     LeftId = TargetId;
-    RightId = Id;
+    RightId = Node.getId();
   }
   unsigned Begin, End;
   std::tie(Begin, End) = Tree.getSourceRangeOffsets(Node);
-  const SourceManager &SrcMgr = Tree.getASTContext().getSourceManager();
-  auto Code = SrcMgr.getBuffer(SrcMgr.getMainFileID())->getBuffer();
+  const SourceManager &SM = Tree.getASTContext().getSourceManager();
+  auto Code = SM.getBuffer(SM.getMainFileID())->getBuffer();
   for (; Offset < Begin; ++Offset)
     printHtml(OS, Code[Offset]);
-  OS << "<span id='" << MyTag << Id << "' "
+  OS << "<span id='" << MyTag << Node.getId() << "' "
      << "tid='" << OtherTag << TargetId << "' ";
   OS << "title='";
   printHtml(OS, Node.getTypeLabel());
@@ -303,12 +304,12 @@
     OS << " class='" << getChangeKindAbbr(Node.Change) << "'";
   OS << ">";
 
-  for (diff::NodeId Child : Node.Children)
-    Offset = printHtmlForNode(OS, Diff, Tree, IsLeft, Child, Offset);
+  for (const diff::Node &Child : Node)
+    Offset = printHtmlForNode(OS, Diff, IsLeft, Child, Offset);
 
   for (; Offset < End; ++Offset)
     printHtml(OS, Code[Offset]);
-  if (Id == Tree.getRootId()) {
+  if (&Node == &Tree.getRoot()) {
     End = Code.size();
     for (; Offset < End; ++Offset)
       printHtml(OS, Code[Offset]);
@@ -343,28 +344,26 @@
 }
 
 static void printNodeAttributes(raw_ostream &OS, diff::SyntaxTree &Tree,
-                                diff::NodeId Id) {
-  const diff::Node &N = Tree.getNode(Id);
-  OS << R"("id":)" << int(Id);
-  OS << R"(,"type":")" << N.getTypeLabel() << '"';
-  auto Offsets = Tree.getSourceRangeOffsets(N);
+                                const diff::Node &Node) {
+  OS << R"("id":)" << int(Node.getId());
+  OS << R"(,"type":")" << Node.getTypeLabel() << '"';
+  auto Offsets = Tree.getSourceRangeOffsets(Node);
   OS << R"(,"begin":)" << Offsets.first;
   OS << R"(,"end":)" << Offsets.second;
-  std::string Value = Tree.getNodeValue(N);
+  std::string Value = Tree.getNodeValue(Node);
   if (!Value.empty()) {
     OS << R"(,"value":")";
     printJsonString(OS, Value);
     OS << '"';
   }
 }
 
 static void printNodeAsJson(raw_ostream &OS, diff::SyntaxTree &Tree,
-                            diff::NodeId Id) {
-  const diff::Node &N = Tree.getNode(Id);
+                            const diff::Node &Node) {
   OS << "{";
-  printNodeAttributes(OS, Tree, Id);
-  auto Identifier = N.getIdentifier();
-  auto QualifiedIdentifier = N.getQualifiedIdentifier();
+  printNodeAttributes(OS, Tree, Node);
+  auto Identifier = Node.getIdentifier();
+  auto QualifiedIdentifier = Node.getQualifiedIdentifier();
   if (Identifier) {
     OS << R"(,"identifier":")";
     printJsonString(OS, *Identifier);
@@ -376,66 +375,65 @@
     }
   }
   OS << R"(,"children":[)";
-  if (N.Children.size() > 0) {
-    printNodeAsJson(OS, Tree, N.Children[0]);
-    for (size_t I = 1, E = N.Children.size(); I < E; ++I) {
+  auto ChildBegin = Node.begin(), ChildEnd = Node.end();
+  if (ChildBegin != ChildEnd) {
+    printNodeAsJson(OS, Tree, *ChildBegin);
+    for (++ChildBegin; ChildBegin != ChildEnd; ++ChildBegin) {
       OS << ",";
-      printNodeAsJson(OS, Tree, N.Children[I]);
+      printNodeAsJson(OS, Tree, *ChildBegin);
     }
   }
   OS << "]}";
 }
 
 static void printNode(raw_ostream &OS, diff::SyntaxTree &Tree,
-                      diff::NodeId Id) {
-  if (Id.isInvalid()) {
-    OS << "None";
-    return;
-  }
-  OS << Tree.getNode(Id).getTypeLabel();
-  std::string Value = Tree.getNodeValue(Id);
+                      const diff::Node &Node) {
+  OS << Node.getTypeLabel();
+  std::string Value = Tree.getNodeValue(Node);
   if (!Value.empty())
     OS << ": " << Value;
-  OS << "(" << Id << ")";
+  OS << "(" << Node.getId() << ")";
 }
 
 static void printTree(raw_ostream &OS, diff::SyntaxTree &Tree) {
-  for (diff::NodeId Id : Tree) {
-    for (int I = 0; I < Tree.getNode(Id).Depth; ++I)
+  for (const diff::Node &Node : Tree) {
+    for (int I = 0; I < Node.Depth; ++I)
       OS << " ";
-    printNode(OS, Tree, Id);
+    printNode(OS, Tree, Node);
     OS << "\n";
   }
 }
 
 static void printDstChange(raw_ostream &OS, diff::ASTDiff &Diff,
                            diff::SyntaxTree &SrcTree, diff::SyntaxTree &DstTree,
-                           diff::NodeId Dst) {
-  const diff::Node &DstNode = DstTree.getNode(Dst);
-  diff::NodeId Src = Diff.getMapped(DstTree, Dst);
-  switch (DstNode.Change) {
+                           const diff::Node &Dst) {
+  const diff::Node *Src = Diff.getMapped(DstTree, Dst);
+  switch (Dst.Change) {
   case diff::None:
     break;
   case diff::Delete:
     llvm_unreachable("The destination tree can't have deletions.");
   case diff::Update:
     OS << "Update ";
-    printNode(OS, SrcTree, Src);
+    printNode(OS, SrcTree, *Src);
     OS << " to " << DstTree.getNodeValue(Dst) << "\n";
     break;
   case diff::Insert:
   case diff::Move:
   case diff::UpdateMove:
-    if (DstNode.Change == diff::Insert)
+    if (Dst.Change == diff::Insert)
       OS << "Insert";
-    else if (DstNode.Change == diff::Move)
+    else if (Dst.Change == diff::Move)
       OS << "Move";
-    else if (DstNode.Change == diff::UpdateMove)
+    else if (Dst.Change == diff::UpdateMove)
       OS << "Update and Move";
     OS << " ";
     printNode(OS, DstTree, Dst);
     OS << " into ";
-    printNode(OS, DstTree, DstNode.Parent);
+    if (!Dst.getParent())
+      OS << "None";
+    else
+      printNode(OS, DstTree, *Dst.getParent());
     OS << " at " << DstTree.findPositionInParent(Dst) << "\n";
     break;
   }
@@ -471,7 +469,7 @@
     llvm::outs() << R"({"filename":")";
     printJsonString(llvm::outs(), SourcePath);
     llvm::outs() << R"(","root":)";
-    printNodeAsJson(llvm::outs(), Tree, Tree.getRootId());
+    printNodeAsJson(llvm::outs(), Tree, Tree.getRoot());
     llvm::outs() << "}\n";
     return 0;
   }
@@ -504,29 +502,28 @@
   if (HtmlDiff) {
     llvm::outs() << HtmlDiffHeader << "<pre>";
     llvm::outs() << "<div id='L' class='code'>";
-    printHtmlForNode(llvm::outs(), Diff, SrcTree, true, SrcTree.getRootId(), 0);
+    printHtmlForNode(llvm::outs(), Diff, true, SrcTree.getRoot(), 0);
     llvm::outs() << "</div>";
     llvm::outs() << "<div id='R' class='code'>";
-    printHtmlForNode(llvm::outs(), Diff, DstTree, false, DstTree.getRootId(),
-                     0);
+    printHtmlForNode(llvm::outs(), Diff, false, DstTree.getRoot(), 0);
     llvm::outs() << "</div>";
     llvm::outs() << "</pre></div></body></html>\n";
     return 0;
   }
 
-  for (diff::NodeId Dst : DstTree) {
-    diff::NodeId Src = Diff.getMapped(DstTree, Dst);
-    if (PrintMatches && Src.isValid()) {
+  for (const diff::Node &Dst : DstTree) {
+    const diff::Node *Src = Diff.getMapped(DstTree, Dst);
+    if (PrintMatches && Src) {
       llvm::outs() << "Match ";
-      printNode(llvm::outs(), SrcTree, Src);
+      printNode(llvm::outs(), SrcTree, *Src);
       llvm::outs() << " to ";
       printNode(llvm::outs(), DstTree, Dst);
       llvm::outs() << "\n";
     }
     printDstChange(llvm::outs(), Diff, SrcTree, DstTree, Dst);
   }
-  for (diff::NodeId Src : SrcTree) {
-    if (Diff.getMapped(SrcTree, Src).isInvalid()) {
+  for (const diff::Node &Src : SrcTree) {
+    if (!Diff.getMapped(SrcTree, Src)) {
       llvm::outs() << "Delete ";
       printNode(llvm::outs(), SrcTree, Src);
       llvm::outs() << "\n";
Index: lib/Tooling/ASTDiff/ASTDiff.cpp
===================================================================
--- lib/Tooling/ASTDiff/ASTDiff.cpp
+++ lib/Tooling/ASTDiff/ASTDiff.cpp
@@ -68,37 +68,33 @@
   // Compute Change for each node based on similarity.
   void computeChangeKinds(Mapping &M);
 
-  NodeId getMapped(const std::unique_ptr<SyntaxTree::Impl> &Tree,
-                   NodeId Id) const {
-    if (&*Tree == &T1)
-      return TheMapping.getDst(Id);
-    assert(&*Tree == &T2 && "Invalid tree.");
-    return TheMapping.getSrc(Id);
-  }
+  const Node *getMapped(const std::unique_ptr<SyntaxTree::Impl> &Tree,
+                        const Node &N) const;
 
 private:
   // Returns true if the two subtrees are identical.
-  bool identical(NodeId Id1, NodeId Id2) const;
+  bool identical(const Node &N1, const Node &N2) const;
 
   // Returns false if the nodes must not be mached.
-  bool isMatchingPossible(NodeId Id1, NodeId Id2) const;
+  bool isMatchingPossible(const Node &N1, const Node &N2) const;
 
   // Returns true if the nodes' parents are matched.
-  bool haveSameParents(const Mapping &M, NodeId Id1, NodeId Id2) const;
+  bool haveSameParents(const Mapping &M, const Node &N1, const Node &N2) const;
 
   // Uses an optimal albeit slow algorithm to compute a mapping between two
   // subtrees, but only if both have fewer nodes than MaxSize.
-  void addOptimalMapping(Mapping &M, NodeId Id1, NodeId Id2) const;
+  void addOptimalMapping(Mapping &M, const Node &N1, const Node &N2) const;
 
   // Computes the ratio of common descendants between the two nodes.
-  // Descendants are only considered to be equal when they are mapped in M.
-  double getJaccardSimilarity(const Mapping &M, NodeId Id1, NodeId Id2) const;
+  // Descendants are only considered to be equal when they are mapped.
+  double getJaccardSimilarity(const Mapping &M, const Node &N1,
+                              const Node &N2) const;
 
   // Returns the node that has the highest degree of similarity.
-  NodeId findCandidate(const Mapping &M, NodeId Id1) const;
+  const Node *findCandidate(const Mapping &M, const Node &N1) const;
 
   // Returns a mapping of identical subtrees.
-  Mapping matchTopDown() const;
+  Mapping matchTopDown();
 
   // Tries to match any yet unmapped nodes, in a bottom-up fashion.
   void matchBottomUp(Mapping &M) const;
@@ -108,6 +104,20 @@
   friend class ZhangShashaMatcher;
 };
 
+namespace {
+struct NodeList { // Ids of nodes in Tree; iterating yields `const Node &`.
+  SyntaxTree::Impl &Tree;
+  std::vector<NodeId> Ids;
+  NodeList(SyntaxTree::Impl &Tree) : Tree(Tree) {}
+  void push_back(NodeId Id) { Ids.push_back(Id); }
+  NodeRefIterator begin() const { return {&Tree, Ids.data()}; }
+  NodeRefIterator end() const { return {&Tree, Ids.data() + Ids.size()}; }
+  const Node &operator[](size_t Index) const { return *(begin() + Index); }
+  size_t size() const { return Ids.size(); }
+  void sort() { std::sort(Ids.begin(), Ids.end()); }
+};
+} // end anonymous namespace
+
 /// Represents the AST of a TranslationUnit.
 class SyntaxTree::Impl {
 public:
@@ -131,29 +141,28 @@
   PrintingPolicy TypePP;
   /// Nodes in preorder.
   std::vector<Node> Nodes;
-  std::vector<NodeId> Leaves;
+  NodeList Leaves;
   // Maps preorder indices to postorder ones.
   std::vector<int> PostorderIds;
-  std::vector<NodeId> NodesBfs;
+  NodeList NodesBfs;
 
   int getSize() const { return Nodes.size(); }
+  const Node &getRoot() const { return getNode(getRootId()); }
   NodeId getRootId() const { return 0; }
-  PreorderIterator begin() const { return getRootId(); }
-  PreorderIterator end() const { return getSize(); }
+  PreorderIterator begin() const { return &getRoot(); }
+  PreorderIterator end() const { return begin() + getSize(); }
 
   const Node &getNode(NodeId Id) const { return Nodes[Id]; }
   Node &getMutableNode(NodeId Id) { return Nodes[Id]; }
-  bool isValidNodeId(NodeId Id) const { return Id >= 0 && Id < getSize(); }
-  void addNode(Node &N) { Nodes.push_back(N); }
-  int getNumberOfDescendants(NodeId Id) const;
-  bool isInSubtree(NodeId Id, NodeId SubtreeRoot) const;
-  int findPositionInParent(NodeId Id, bool Shifted = false) const;
+  Node &getMutableNode(const Node &N) { return getMutableNode(N.getId()); }
+  int getNumberOfDescendants(const Node &N) const;
+  bool isInSubtree(const Node &N, const Node &SubtreeRoot) const;
+  int findPositionInParent(const Node &Id, bool Shifted = false) const;
 
   std::string getRelativeName(const NamedDecl *ND,
                               const DeclContext *Context) const;
   std::string getRelativeName(const NamedDecl *ND) const;
 
-  std::string getNodeValue(NodeId Id) const;
   std::string getNodeValue(const Node &Node) const;
   std::string getDeclValue(const Decl *D) const;
   std::string getStmtValue(const Stmt *S) const;
@@ -163,23 +172,37 @@
   void setLeftMostDescendants();
 };
 
+const Node &NodeRefIterator::operator*() const {
+  return Tree->getNode(*IdPointer);
+}
+
+NodeRefIterator &NodeRefIterator::operator++() { return ++IdPointer, *this; }
+NodeRefIterator &NodeRefIterator::operator+(int Offset) {
+  return IdPointer += Offset, *this;
+}
+
+bool NodeRefIterator::operator!=(const NodeRefIterator &Other) const {
+  assert(Tree == Other.Tree &&
+         "Cannot compare two iterators of different trees.");
+  return IdPointer != Other.IdPointer;
+}
+
 static bool isSpecializedNodeExcluded(const Decl *D) { return D->isImplicit(); }
 static bool isSpecializedNodeExcluded(const Stmt *S) { return false; }
 static bool isSpecializedNodeExcluded(CXXCtorInitializer *I) {
   return !I->isWritten();
 }
 
-template <class T>
-static bool isNodeExcluded(const SourceManager &SrcMgr, T *N) {
+template <class T> static bool isNodeExcluded(const SourceManager &SM, T *N) {
   if (!N)
     return true;
   SourceLocation SLoc = N->getSourceRange().getBegin();
   if (SLoc.isValid()) {
     // Ignore everything from other files.
-    if (!SrcMgr.isInMainFile(SLoc))
+    if (!SM.isInMainFile(SLoc))
       return true;
     // Ignore macros.
-    if (SLoc != SrcMgr.getSpellingLoc(SLoc))
+    if (SLoc != SM.getSpellingLoc(SLoc))
       return true;
   }
   return isSpecializedNodeExcluded(N);
@@ -196,7 +219,7 @@
 
   template <class T> std::tuple<NodeId, NodeId> PreTraverse(T *ASTNode) {
     NodeId MyId = Id;
-    Tree.Nodes.emplace_back();
+    Tree.Nodes.emplace_back(Tree);
     Node &N = Tree.getMutableNode(MyId);
     N.Parent = Parent;
     N.Depth = Depth;
@@ -220,7 +243,7 @@
     --Depth;
     Node &N = Tree.getMutableNode(MyId);
     N.RightMostDescendant = Id - 1;
-    assert(N.RightMostDescendant >= 0 &&
+    assert(N.RightMostDescendant >= Tree.getRootId() &&
            N.RightMostDescendant < Tree.getSize() &&
            "Rightmost descendant must be a valid tree node.");
     if (N.isLeaf())
@@ -260,7 +283,8 @@
 } // end anonymous namespace
 
 SyntaxTree::Impl::Impl(SyntaxTree *Parent, ASTContext &AST)
-    : Parent(Parent), AST(AST), TypePP(AST.getLangOpts()) {
+    : Parent(Parent), AST(AST), TypePP(AST.getLangOpts()), Leaves(*this),
+      NodesBfs(*this) {
   TypePP.AnonymousTagLocations = false;
 }
 
@@ -278,7 +302,7 @@
   initTree();
 }
 
-static std::vector<NodeId> getSubtreePostorder(const SyntaxTree::Impl &Tree,
+static std::vector<NodeId> getSubtreePostorder(SyntaxTree::Impl &Tree,
                                                NodeId Root) {
   std::vector<NodeId> Postorder;
   std::function<void(NodeId)> Traverse = [&](NodeId Id) {
@@ -291,61 +315,59 @@
   return Postorder;
 }
 
-static std::vector<NodeId> getSubtreeBfs(const SyntaxTree::Impl &Tree,
-                                         NodeId Root) {
-  std::vector<NodeId> Ids;
+static void getSubtreeBfs(NodeList &Ids, const Node &Root) {
   size_t Expanded = 0;
-  Ids.push_back(Root);
+  Ids.push_back(Root.getId());
   while (Expanded < Ids.size())
-    for (NodeId Child : Tree.getNode(Ids[Expanded++]).Children)
-      Ids.push_back(Child);
-  return Ids;
+    for (const Node &Child : Ids[Expanded++])
+      Ids.push_back(Child.getId());
 }
 
 void SyntaxTree::Impl::initTree() {
   setLeftMostDescendants();
   int PostorderId = 0;
   PostorderIds.resize(getSize());
-  std::function<void(NodeId)> PostorderTraverse = [&](NodeId Id) {
-    for (NodeId Child : getNode(Id).Children)
+  std::function<void(const Node &)> PostorderTraverse = [&](const Node &N) {
+    for (const Node &Child : N)
       PostorderTraverse(Child);
-    PostorderIds[Id] = PostorderId;
+    PostorderIds[N.getId()] = PostorderId;
     ++PostorderId;
   };
-  PostorderTraverse(getRootId());
-  NodesBfs = getSubtreeBfs(*this, getRootId());
+  PostorderTraverse(getRoot());
+  getSubtreeBfs(NodesBfs, getRoot());
 }
 
 void SyntaxTree::Impl::setLeftMostDescendants() {
-  for (NodeId Leaf : Leaves) {
-    getMutableNode(Leaf).LeftMostDescendant = Leaf;
-    NodeId Parent, Cur = Leaf;
-    while ((Parent = getNode(Cur).Parent).isValid() &&
-           getNode(Parent).Children[0] == Cur) {
+  for (const Node &Leaf : Leaves) {
+    getMutableNode(Leaf).LeftMostDescendant = Leaf.getId();
+    const Node *Parent, *Cur = &Leaf;
+    while ((Parent = Cur->getParent()) && &Parent->getChild(0) == Cur) {
       Cur = Parent;
-      getMutableNode(Cur).LeftMostDescendant = Leaf;
+      getMutableNode(*Cur).LeftMostDescendant = Leaf.getId();
     }
   }
 }
 
-int SyntaxTree::Impl::getNumberOfDescendants(NodeId Id) const {
-  return getNode(Id).RightMostDescendant - Id + 1;
+int SyntaxTree::Impl::getNumberOfDescendants(const Node &N) const {
+  return N.RightMostDescendant - N.getId() + 1;
 }
 
-bool SyntaxTree::Impl::isInSubtree(NodeId Id, NodeId SubtreeRoot) const {
-  return Id >= SubtreeRoot && Id <= getNode(SubtreeRoot).RightMostDescendant;
+bool SyntaxTree::Impl::isInSubtree(const Node &N,
+                                   const Node &SubtreeRoot) const {
+  return N.getId() >= SubtreeRoot.getId() &&
+         N.getId() <= SubtreeRoot.RightMostDescendant;
 }
 
-int SyntaxTree::Impl::findPositionInParent(NodeId Id, bool Shifted) const {
-  NodeId Parent = getNode(Id).Parent;
-  if (Parent.isInvalid())
+int SyntaxTree::Impl::findPositionInParent(const Node &N, bool Shifted) const {
+  if (!N.getParent())
     return 0;
-  const auto &Siblings = getNode(Parent).Children;
+  const Node &Parent = *N.getParent();
+  const auto &Siblings = Parent.Children;
   int Position = 0;
   for (size_t I = 0, E = Siblings.size(); I < E; ++I) {
     if (Shifted)
       Position += getNode(Siblings[I]).Shift;
-    if (Siblings[I] == Id) {
+    if (Siblings[I] == N.getId()) {
       Position += I;
       return Position;
     }
@@ -406,11 +428,8 @@
   llvm_unreachable("Unknown initializer type");
 }
 
-std::string SyntaxTree::Impl::getNodeValue(NodeId Id) const {
-  return getNodeValue(getNode(Id));
-}
-
 std::string SyntaxTree::Impl::getNodeValue(const Node &N) const {
+  assert(&N.Tree == this);
   const DynTypedNode &DTN = N.ASTNode;
   if (auto *S = DTN.get<Stmt>())
     return getStmtValue(S);
@@ -486,16 +505,16 @@
 class Subtree {
 private:
   /// The parent tree.
-  const SyntaxTree::Impl &Tree;
+  SyntaxTree::Impl &Tree;
   /// Maps SNodeIds to original ids.
   std::vector<NodeId> RootIds;
   /// Maps subtree nodes to their leftmost descendants wtihin the subtree.
   std::vector<SNodeId> LeftMostDescendants;
 
 public:
   std::vector<SNodeId> KeyRoots;
 
-  Subtree(const SyntaxTree::Impl &Tree, NodeId SubtreeRoot) : Tree(Tree) {
+  Subtree(SyntaxTree::Impl &Tree, NodeId SubtreeRoot) : Tree(Tree) {
     RootIds = getSubtreePostorder(Tree, SubtreeRoot);
     int NumLeaves = setLeftMostDescendants();
     computeKeyRoots(NumLeaves);
@@ -517,7 +536,7 @@
     return Tree.PostorderIds[getIdInRoot(SNodeId(1))];
   }
   std::string getNodeValue(SNodeId Id) const {
-    return Tree.getNodeValue(getIdInRoot(Id));
+    return Tree.getNodeValue(getNode(Id));
   }
 
 private:
@@ -563,8 +582,8 @@
   std::unique_ptr<std::unique_ptr<double[]>[]> TreeDist, ForestDist;
 
 public:
-  ZhangShashaMatcher(const ASTDiff::Impl &DiffImpl, const SyntaxTree::Impl &T1,
-                     const SyntaxTree::Impl &T2, NodeId Id1, NodeId Id2)
+  ZhangShashaMatcher(const ASTDiff::Impl &DiffImpl, SyntaxTree::Impl &T1,
+                     SyntaxTree::Impl &T2, NodeId Id1, NodeId Id2)
       : DiffImpl(DiffImpl), S1(T1, Id1), S2(T2, Id2) {
     TreeDist = llvm::make_unique<std::unique_ptr<double[]>[]>(
         size_t(S1.getSize()) + 1);
@@ -615,11 +634,11 @@
           SNodeId LMD2 = S2.getLeftMostDescendant(Col);
           if (LMD1 == S1.getLeftMostDescendant(LastRow) &&
               LMD2 == S2.getLeftMostDescendant(LastCol)) {
-            NodeId Id1 = S1.getIdInRoot(Row);
-            NodeId Id2 = S2.getIdInRoot(Col);
-            assert(DiffImpl.isMatchingPossible(Id1, Id2) &&
+            const Node &N1 = S1.getNode(Row);
+            const Node &N2 = S2.getNode(Col);
+            assert(DiffImpl.isMatchingPossible(N1, N2) &&
                    "These nodes must not be matched.");
-            Matches.emplace_back(Id1, Id2);
+            Matches.emplace_back(N1.getId(), N2.getId());
             --Row;
             --Col;
           } else {
@@ -643,7 +662,8 @@
   static constexpr double InsertionCost = 1;
 
   double getUpdateCost(SNodeId Id1, SNodeId Id2) {
-    if (!DiffImpl.isMatchingPossible(S1.getIdInRoot(Id1), S2.getIdInRoot(Id2)))
+    const Node &N1 = S1.getNode(Id1), &N2 = S2.getNode(Id2);
+    if (!DiffImpl.isMatchingPossible(N1, N2))
       return std::numeric_limits<double>::max();
     return S1.getNodeValue(Id1) != S2.getNodeValue(Id2);
   }
@@ -684,6 +704,18 @@
   }
 };
 
+NodeId Node::getId() const { return this - &Tree.getRoot(); }
+SyntaxTree &Node::getTree() const { return *Tree.Parent; }
+const Node *Node::getParent() const {
+  if (Parent.isInvalid())
+    return nullptr;
+  return &Tree.getNode(Parent);
+}
+
+const Node &Node::getChild(size_t Index) const {
+  return Tree.getNode(Children[Index]);
+}
+
 ast_type_traits::ASTNodeKind Node::getType() const {
   return ASTNode.getNodeKind();
 }
@@ -706,11 +738,18 @@
   return llvm::None;
 }
 
+NodeRefIterator Node::begin() const {
+  return {&Tree, isLeaf() ? nullptr : &Children[0]};
+}
+NodeRefIterator Node::end() const {
+  return {&Tree, isLeaf() ? nullptr : &Children[0] + Children.size()};
+}
+
 namespace {
 // Compares nodes by their depth.
 struct HeightLess {
-  const SyntaxTree::Impl &Tree;
-  HeightLess(const SyntaxTree::Impl &Tree) : Tree(Tree) {}
+  SyntaxTree::Impl &Tree;
+  HeightLess(SyntaxTree::Impl &Tree) : Tree(Tree) {}
   bool operator()(NodeId Id1, NodeId Id2) const {
     return Tree.getNode(Id1).Height < Tree.getNode(Id2).Height;
   }
@@ -720,113 +759,111 @@
 namespace {
 // Priority queue for nodes, sorted descendingly by their height.
 class PriorityList {
-  const SyntaxTree::Impl &Tree;
+  SyntaxTree::Impl &Tree;
   HeightLess Cmp;
   std::vector<NodeId> Container;
   PriorityQueue<NodeId, std::vector<NodeId>, HeightLess> List;
 
 public:
-  PriorityList(const SyntaxTree::Impl &Tree)
+  PriorityList(SyntaxTree::Impl &Tree)
       : Tree(Tree), Cmp(Tree), List(Cmp, Container) {}
 
-  void push(NodeId id) { List.push(id); }
+  void push(NodeId Id) { List.push(Id); }
 
-  std::vector<NodeId> pop() {
+  NodeList pop() {
     int Max = peekMax();
-    std::vector<NodeId> Result;
+    NodeList Result(Tree);
     if (Max == 0)
       return Result;
     while (peekMax() == Max) {
       Result.push_back(List.top());
       List.pop();
     }
     // TODO this is here to get a stable output, not a good heuristic
-    std::sort(Result.begin(), Result.end());
+    Result.sort();
     return Result;
   }
   int peekMax() const {
     if (List.empty())
       return 0;
     return Tree.getNode(List.top()).Height;
   }
-  void open(NodeId Id) {
-    for (NodeId Child : Tree.getNode(Id).Children)
-      push(Child);
+  void open(const Node &N) {
+    for (const Node &Child : N)
+      push(Child.getId());
   }
 };
 } // end anonymous namespace
 
-bool ASTDiff::Impl::identical(NodeId Id1, NodeId Id2) const {
-  const Node &N1 = T1.getNode(Id1);
-  const Node &N2 = T2.getNode(Id2);
-  if (N1.Children.size() != N2.Children.size() ||
-      !isMatchingPossible(Id1, Id2) ||
-      T1.getNodeValue(Id1) != T2.getNodeValue(Id2))
+bool ASTDiff::Impl::identical(const Node &N1, const Node &N2) const {
+  if (N1.getNumChildren() != N2.getNumChildren() ||
+      !isMatchingPossible(N1, N2) || T1.getNodeValue(N1) != T2.getNodeValue(N2))
     return false;
-  for (size_t Id = 0, E = N1.Children.size(); Id < E; ++Id)
-    if (!identical(N1.Children[Id], N2.Children[Id]))
+  for (size_t Id = 0, E = N1.getNumChildren(); Id < E; ++Id)
+    if (!identical(N1.getChild(Id), N2.getChild(Id)))
       return false;
   return true;
 }
 
-bool ASTDiff::Impl::isMatchingPossible(NodeId Id1, NodeId Id2) const {
-  return Options.isMatchingAllowed(T1.getNode(Id1), T2.getNode(Id2));
+bool ASTDiff::Impl::isMatchingPossible(const Node &N1, const Node &N2) const {
+  return Options.isMatchingAllowed(N1, N2);
 }
 
-bool ASTDiff::Impl::haveSameParents(const Mapping &M, NodeId Id1,
-                                    NodeId Id2) const {
-  NodeId P1 = T1.getNode(Id1).Parent;
-  NodeId P2 = T2.getNode(Id2).Parent;
-  return (P1.isInvalid() && P2.isInvalid()) ||
-         (P1.isValid() && P2.isValid() && M.getDst(P1) == P2);
+bool ASTDiff::Impl::haveSameParents(const Mapping &M, const Node &N1,
+                                    const Node &N2) const {
+  const Node *P1 = N1.getParent();
+  const Node *P2 = N2.getParent();
+  return (!P1 && !P2) || (P1 && P2 && M.getDst(P1->getId()) == P2->getId());
 }
 
-void ASTDiff::Impl::addOptimalMapping(Mapping &M, NodeId Id1,
-                                      NodeId Id2) const {
-  if (std::max(T1.getNumberOfDescendants(Id1), T2.getNumberOfDescendants(Id2)) >
+void ASTDiff::Impl::addOptimalMapping(Mapping &M, const Node &N1,
+                                      const Node &N2) const {
+  if (std::max(T1.getNumberOfDescendants(N1), T2.getNumberOfDescendants(N2)) >
       Options.MaxSize)
     return;
-  ZhangShashaMatcher Matcher(*this, T1, T2, Id1, Id2);
+  ZhangShashaMatcher Matcher(*this, T1, T2, N1.getId(), N2.getId());
   std::vector<std::pair<NodeId, NodeId>> R = Matcher.getMatchingNodes();
   for (const auto Tuple : R) {
     NodeId Src = Tuple.first;
     NodeId Dst = Tuple.second;
     if (!M.hasSrc(Src) && !M.hasDst(Dst))
       M.link(Src, Dst);
   }
 }
 
-double ASTDiff::Impl::getJaccardSimilarity(const Mapping &M, NodeId Id1,
-                                           NodeId Id2) const {
+double ASTDiff::Impl::getJaccardSimilarity(const Mapping &M, const Node &N1,
+                                           const Node &N2) const {
   int CommonDescendants = 0;
-  const Node &N1 = T1.getNode(Id1);
   // Count the common descendants, excluding the subtree root.
-  for (NodeId Src = Id1 + 1; Src <= N1.RightMostDescendant; ++Src) {
-    NodeId Dst = M.getDst(Src);
-    CommonDescendants += int(Dst.isValid() && T2.isInSubtree(Dst, Id2));
+  for (NodeId Src = N1.getId() + 1; Src <= N1.RightMostDescendant; ++Src) {
+    // Use M itself rather than getMapped(), which looks up TheMapping.
+    NodeId Dst = M.getDst(Src);
+    if (Dst.isValid())
+      CommonDescendants += T2.isInSubtree(T2.getNode(Dst), N2);
   }
   // We need to subtract 1 to get the number of descendants excluding the root.
-  double Denominator = T1.getNumberOfDescendants(Id1) - 1 +
-                       T2.getNumberOfDescendants(Id2) - 1 - CommonDescendants;
+  double Denominator = T1.getNumberOfDescendants(N1) - 1 +
+                       T2.getNumberOfDescendants(N2) - 1 - CommonDescendants;
   // CommonDescendants is less than the size of one subtree.
   assert(Denominator >= 0 && "Expected non-negative denominator.");
   if (Denominator == 0)
     return 0;
   return CommonDescendants / Denominator;
 }
 
-NodeId ASTDiff::Impl::findCandidate(const Mapping &M, NodeId Id1) const {
-  NodeId Candidate;
+const Node *ASTDiff::Impl::findCandidate(const Mapping &M,
+                                         const Node &N1) const {
+  const Node *Candidate = nullptr;
   double HighestSimilarity = 0.0;
-  for (NodeId Id2 : T2) {
-    if (!isMatchingPossible(Id1, Id2))
+  for (const Node &N2 : T2) {
+    if (!isMatchingPossible(N1, N2))
       continue;
-    if (M.hasDst(Id2))
+    if (M.hasDst(N2.getId()))
       continue;
-    double Similarity = getJaccardSimilarity(M, Id1, Id2);
+    double Similarity = getJaccardSimilarity(M, N1, N2);
     if (Similarity >= Options.MinSimilarity && Similarity > HighestSimilarity) {
       HighestSimilarity = Similarity;
-      Candidate = Id2;
+      Candidate = &N2;
     }
   }
   return Candidate;
@@ -837,9 +874,9 @@
   for (NodeId Id1 : Postorder) {
     if (Id1 == T1.getRootId() && !M.hasSrc(T1.getRootId()) &&
         !M.hasDst(T2.getRootId())) {
-      if (isMatchingPossible(T1.getRootId(), T2.getRootId())) {
+      if (isMatchingPossible(T1.getRoot(), T2.getRoot())) {
         M.link(T1.getRootId(), T2.getRootId());
-        addOptimalMapping(M, T1.getRootId(), T2.getRootId());
+        addOptimalMapping(M, T1.getRoot(), T2.getRoot());
       }
       break;
     }
@@ -850,15 +887,15 @@
                     [&](NodeId Child) { return M.hasSrc(Child); });
     if (Matched || !MatchedChildren)
       continue;
-    NodeId Id2 = findCandidate(M, Id1);
-    if (Id2.isValid()) {
-      M.link(Id1, Id2);
-      addOptimalMapping(M, Id1, Id2);
+    const Node *N2 = findCandidate(M, N1);
+    if (N2) {
+      M.link(N1.getId(), N2->getId());
+      addOptimalMapping(M, N1, *N2);
     }
   }
 }
 
-Mapping ASTDiff::Impl::matchTopDown() const {
+Mapping ASTDiff::Impl::matchTopDown() {
   PriorityList L1(T1);
   PriorityList L2(T2);
 
@@ -871,33 +908,32 @@
   while (std::min(Max1 = L1.peekMax(), Max2 = L2.peekMax()) >
          Options.MinHeight) {
     if (Max1 > Max2) {
-      for (NodeId Id : L1.pop())
-        L1.open(Id);
+      for (const Node &N1 : L1.pop())
+        L1.open(N1);
       continue;
     }
     if (Max2 > Max1) {
-      for (NodeId Id : L2.pop())
-        L2.open(Id);
+      for (const Node &N2 : L2.pop())
+        L2.open(N2);
       continue;
     }
-    std::vector<NodeId> H1, H2;
-    H1 = L1.pop();
-    H2 = L2.pop();
-    for (NodeId Id1 : H1) {
-      for (NodeId Id2 : H2) {
-        if (identical(Id1, Id2) && !M.hasSrc(Id1) && !M.hasDst(Id2)) {
-          for (int I = 0, E = T1.getNumberOfDescendants(Id1); I < E; ++I)
-            M.link(Id1 + I, Id2 + I);
+    NodeList H1 = L1.pop(), H2 = L2.pop();
+    for (const Node &N1 : H1) {
+      for (const Node &N2 : H2) {
+        if (identical(N1, N2) && !M.hasSrc(N1.getId()) &&
+            !M.hasDst(N2.getId())) {
+          for (int I = 0, E = T1.getNumberOfDescendants(N1); I < E; ++I)
+            M.link(N1.getId() + I, N2.getId() + I);
         }
       }
     }
-    for (NodeId Id1 : H1) {
-      if (!M.hasSrc(Id1))
-        L1.open(Id1);
+    for (const Node &N1 : H1) {
+      if (!M.hasSrc(N1.getId()))
+        L1.open(N1);
     }
-    for (NodeId Id2 : H2) {
-      if (!M.hasDst(Id2))
-        L2.open(Id2);
+    for (const Node &N2 : H2) {
+      if (!M.hasDst(N2.getId()))
+        L2.open(N2);
     }
   }
   return M;
@@ -918,56 +954,68 @@
 }
 
 void ASTDiff::Impl::computeChangeKinds(Mapping &M) {
-  for (NodeId Id1 : T1) {
-    if (!M.hasSrc(Id1)) {
-      T1.getMutableNode(Id1).Change = Delete;
-      T1.getMutableNode(Id1).Shift -= 1;
+  for (const Node &N1 : T1) {
+    if (!M.hasSrc(N1.getId())) {
+      T1.getMutableNode(N1.getId()).Change = Delete; // Unmatched in T2.
+      T1.getMutableNode(N1.getId()).Shift -= 1;
     }
   }
-  for (NodeId Id2 : T2) {
-    if (!M.hasDst(Id2)) {
-      T2.getMutableNode(Id2).Change = Insert;
-      T2.getMutableNode(Id2).Shift -= 1;
+  for (const Node &N2 : T2) {
+    if (!M.hasDst(N2.getId())) {
+      T2.getMutableNode(N2.getId()).Change = Insert; // Unmatched in T1.
+      T2.getMutableNode(N2.getId()).Shift -= 1;
     }
   }
-  for (NodeId Id1 : T1.NodesBfs) {
-    NodeId Id2 = M.getDst(Id1);
+  for (const Node &N1 : T1.NodesBfs) {
+    NodeId Id2 = M.getDst(N1.getId());
     if (Id2.isInvalid())
       continue;
-    if (!haveSameParents(M, Id1, Id2) ||
-        T1.findPositionInParent(Id1, true) !=
-            T2.findPositionInParent(Id2, true)) {
-      T1.getMutableNode(Id1).Shift -= 1;
-      T2.getMutableNode(Id2).Shift -= 1;
+    const Node &N2 = T2.getNode(Id2);
+    if (!haveSameParents(M, N1, N2) || T1.findPositionInParent(N1, true) !=
+                                           T2.findPositionInParent(N2, true)) {
+      T1.getMutableNode(N1).Shift -= 1;
+      T2.getMutableNode(N2).Shift -= 1;
     }
   }
-  for (NodeId Id2 : T2.NodesBfs) {
-    NodeId Id1 = M.getSrc(Id2);
+  for (const Node &ConstN2 : T2.NodesBfs) {
+    NodeId Id1 = M.getSrc(ConstN2.getId());
     if (Id1.isInvalid())
       continue;
     Node &N1 = T1.getMutableNode(Id1);
-    Node &N2 = T2.getMutableNode(Id2);
+    Node &N2 = T2.getMutableNode(ConstN2.getId());
-    if (Id1.isInvalid())
-      continue;
+
+    // N1 and N2 are matched: flag moves and value updates on both.
-    if (!haveSameParents(M, Id1, Id2) ||
-        T1.findPositionInParent(Id1, true) !=
-            T2.findPositionInParent(Id2, true)) {
+    if (!haveSameParents(M, N1, N2) || T1.findPositionInParent(N1, true) !=
+                                           T2.findPositionInParent(N2, true)) {
       N1.Change = N2.Change = Move;
     }
-    if (T1.getNodeValue(Id1) != T2.getNodeValue(Id2)) {
+    if (T1.getNodeValue(N1) != T2.getNodeValue(N2)) {
       N1.Change = N2.Change = (N1.Change == Move ? UpdateMove : Update);
     }
   }
 }
 
+// Returns the node that N is mapped to in the other tree, or nullptr.
+const Node *
+ASTDiff::Impl::getMapped(const std::unique_ptr<SyntaxTree::Impl> &Tree,
+                         const Node &N) const {
+  const bool FromSrc = &*Tree == &T1;
+  assert((FromSrc || &*Tree == &T2) && "Invalid tree.");
+  // Query the mapping in the direction matching the tree N came from.
+  NodeId Id = FromSrc ? TheMapping.getDst(N.getId())
+                      : TheMapping.getSrc(N.getId());
+  return Id.isValid() ? &(FromSrc ? T2 : T1).getNode(Id) : nullptr;
+}
+
 ASTDiff::ASTDiff(SyntaxTree &T1, SyntaxTree &T2,
                  const ComparisonOptions &Options)
     : DiffImpl(llvm::make_unique<Impl>(*T1.TreeImpl, *T2.TreeImpl, Options)) {}
 
 ASTDiff::~ASTDiff() = default;
 
-NodeId ASTDiff::getMapped(const SyntaxTree &SourceTree, NodeId Id) const {
-  return DiffImpl->getMapped(SourceTree.TreeImpl, Id);
+const Node *ASTDiff::getMapped(const SyntaxTree &SourceTree,
+                               const Node &N) const {
+  return DiffImpl->getMapped(SourceTree.TreeImpl, N); // Null if unmapped.
 }
 
 SyntaxTree::SyntaxTree(ASTContext &AST)
@@ -983,36 +1031,32 @@
 }
 
 int SyntaxTree::getSize() const { return TreeImpl->getSize(); }
-NodeId SyntaxTree::getRootId() const { return TreeImpl->getRootId(); }
+const Node &SyntaxTree::getRoot() const { return TreeImpl->getRoot(); }
 SyntaxTree::PreorderIterator SyntaxTree::begin() const {
   return TreeImpl->begin();
 }
 SyntaxTree::PreorderIterator SyntaxTree::end() const { return TreeImpl->end(); }
 
-int SyntaxTree::findPositionInParent(NodeId Id) const {
-  return TreeImpl->findPositionInParent(Id);
+int SyntaxTree::findPositionInParent(const Node &N) const {
+  return TreeImpl->findPositionInParent(N); // Uses Impl's default bool arg.
 }
 
 std::pair<unsigned, unsigned>
 SyntaxTree::getSourceRangeOffsets(const Node &N) const {
-  const SourceManager &SrcMgr = TreeImpl->AST.getSourceManager();
+  const SourceManager &SM = TreeImpl->AST.getSourceManager();
   SourceRange Range = N.ASTNode.getSourceRange();
   SourceLocation BeginLoc = Range.getBegin();
   SourceLocation EndLoc = Lexer::getLocForEndOfToken(
-      Range.getEnd(), /*Offset=*/0, SrcMgr, TreeImpl->AST.getLangOpts());
+      Range.getEnd(), /*Offset=*/0, SM, TreeImpl->AST.getLangOpts());
   if (auto *ThisExpr = N.ASTNode.get<CXXThisExpr>()) {
     if (ThisExpr->isImplicit())
       EndLoc = BeginLoc;
   }
-  unsigned Begin = SrcMgr.getFileOffset(SrcMgr.getExpansionLoc(BeginLoc));
-  unsigned End = SrcMgr.getFileOffset(SrcMgr.getExpansionLoc(EndLoc));
+  unsigned Begin = SM.getFileOffset(SM.getExpansionLoc(BeginLoc));
+  unsigned End = SM.getFileOffset(SM.getExpansionLoc(EndLoc));
   return {Begin, End};
 }
 
-std::string SyntaxTree::getNodeValue(NodeId Id) const {
-  return TreeImpl->getNodeValue(Id);
-}
-
 std::string SyntaxTree::getNodeValue(const Node &N) const {
   return TreeImpl->getNodeValue(N);
 }
Index: include/clang/Tooling/ASTDiff/ASTDiffInternal.h
===================================================================
--- include/clang/Tooling/ASTDiff/ASTDiffInternal.h
+++ include/clang/Tooling/ASTDiff/ASTDiffInternal.h
@@ -19,8 +19,9 @@
 using DynTypedNode = ast_type_traits::DynTypedNode;
 
 class SyntaxTree;
-class SyntaxTreeImpl;
 struct ComparisonOptions;
+struct Node;            // Defined in ASTDiff.h.
+struct NodeRefIterator; // Defined in ASTDiff.h.
 
 /// Within a tree, this identifies a node by its preorder offset.
 struct NodeId {
@@ -36,8 +37,6 @@
   operator int() const { return Id; }
   NodeId &operator++() { return ++Id, *this; }
   NodeId &operator--() { return --Id, *this; }
-  // Support defining iterators on NodeId.
-  NodeId &operator*() { return *this; }
 
   bool isValid() const { return Id != InvalidNodeId; }
   bool isInvalid() const { return Id == InvalidNodeId; }
Index: include/clang/Tooling/ASTDiff/ASTDiff.h
===================================================================
--- include/clang/Tooling/ASTDiff/ASTDiff.h
+++ include/clang/Tooling/ASTDiff/ASTDiff.h
@@ -34,28 +34,14 @@
   UpdateMove // Same as Move plus Update.
 };
 
-/// Represents a Clang AST node, alongside some additional information.
-struct Node {
-  NodeId Parent, LeftMostDescendant, RightMostDescendant;
-  int Depth, Height, Shift = 0;
-  ast_type_traits::DynTypedNode ASTNode;
-  SmallVector<NodeId, 4> Children;
-  ChangeKind Change = None;
-
-  ast_type_traits::ASTNodeKind getType() const;
-  StringRef getTypeLabel() const;
-  bool isLeaf() const { return Children.empty(); }
-  llvm::Optional<StringRef> getIdentifier() const;
-  llvm::Optional<std::string> getQualifiedIdentifier() const;
-};
 
 class ASTDiff {
 public:
   ASTDiff(SyntaxTree &Src, SyntaxTree &Dst, const ComparisonOptions &Options);
   ~ASTDiff();
 
   // Returns the ID of the node that is mapped to the given node in SourceTree.
-  NodeId getMapped(const SyntaxTree &SourceTree, NodeId Id) const;
+  const Node *getMapped(const SyntaxTree &SourceTree, const Node &N) const;
 
   class Impl;
 
@@ -80,27 +66,63 @@
   StringRef getFilename() const;
 
   int getSize() const;
-  NodeId getRootId() const;
-  using PreorderIterator = NodeId;
+  const Node &getRoot() const;
+  using PreorderIterator = const Node *;
   PreorderIterator begin() const;
   PreorderIterator end() const;
 
   const Node &getNode(NodeId Id) const;
-  int findPositionInParent(NodeId Id) const;
+  int findPositionInParent(const Node &N) const;
 
   // Returns the starting and ending offset of the node in its source file.
   std::pair<unsigned, unsigned> getSourceRangeOffsets(const Node &N) const;
 
   /// Serialize the node attributes to a string representation. This should
-  /// uniquely distinguish nodes of the same kind. Note that this function just
-  /// returns a representation of the node value, not considering descendants.
-  std::string getNodeValue(NodeId Id) const;
-  std::string getNodeValue(const Node &Node) const;
+  /// uniquely distinguish nodes of the same kind. Note that this function
+  /// just returns a representation of the node value, not considering
+  /// descendants.
+  std::string getNodeValue(const Node &N) const;
 
   class Impl;
   std::unique_ptr<Impl> TreeImpl;
 };
 
+/// Represents a Clang AST node, alongside some additional information.
+struct Node {
+  SyntaxTree::Impl &Tree; // The tree this node belongs to.
+  NodeId Parent, LeftMostDescendant, RightMostDescendant;
+  int Depth = 0, Height = 0, Shift = 0;
+  ast_type_traits::DynTypedNode ASTNode;
+  SmallVector<NodeId, 4> Children;
+  ChangeKind Change = None;
+  Node(SyntaxTree::Impl &Tree) : Tree(Tree) {}
+
+  NodeId getId() const;
+  SyntaxTree &getTree() const;
+  const Node *getParent() const;
+  const Node &getChild(size_t Index) const;
+  size_t getNumChildren() const { return Children.size(); }
+  ast_type_traits::ASTNodeKind getType() const;
+  StringRef getTypeLabel() const;
+  bool isLeaf() const { return Children.empty(); }
+  llvm::Optional<StringRef> getIdentifier() const;
+  llvm::Optional<std::string> getQualifiedIdentifier() const;
+
+  NodeRefIterator begin() const;
+  NodeRefIterator end() const;
+};
+
+struct NodeRefIterator { // Iterates NodeIds, yielding Node refs.
+  SyntaxTree::Impl *Tree; // Presumably used to resolve *IdPointer.
+  const NodeId *IdPointer; // Current position.
+  NodeRefIterator(SyntaxTree::Impl *Tree, const NodeId *IdPointer)
+      : Tree(Tree), IdPointer(IdPointer) {}
+  const Node &operator*() const;
+  NodeRefIterator &operator++();
+  NodeRefIterator &operator+(int Offset); // NOTE(review): may mutate *this.
+  bool operator!=(const NodeRefIterator &Other) const;
+};
+
 struct ComparisonOptions {
   /// During top-down matching, only consider nodes of at least this height.
   int MinHeight = 2;
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to