sammccall created this revision.
sammccall added a reviewer: hokein.
Herald added a subscriber: mgrang.
Herald added a project: All.
sammccall requested review of this revision.
Herald added subscribers: cfe-commits, alextsao1999.
Herald added a project: clang-tools-extra.
The actions table is very compact but the binary search to find the
correct action is relatively expensive.
A hashtable is faster but pretty large (64 bits per value, plus empty slots),
and its lookup, while constant time, is not trivial due to collisions.

The structure in this patch uses 1.25 bits per key (whether a value is present
or absent) plus the size of the values, and lookup is trivial.

The Shift table is 119KB = 27KB values + 92KB keys.
The Goto table is 86KB = 30KB values + 57KB keys.
(Goto has a smaller keyspace as #nonterminals < #terminals, and more entries).

This patch improves glrParse speed by 28%: 4.69 => 5.99 MB/s
Overall the table grows by 60%: 142 => 228KB.

By comparison, DenseMap<unsigned, StateID> is "only" 16% faster (5.43 MB/s),
and results in a 285% larger table (547 KB) vs the baseline.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D128485

Files:
  clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
  clang-tools-extra/pseudo/lib/GLR.cpp
  clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
  clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp
  clang-tools-extra/pseudo/unittests/LRTableTest.cpp

Index: clang-tools-extra/pseudo/unittests/LRTableTest.cpp
===================================================================
--- clang-tools-extra/pseudo/unittests/LRTableTest.cpp
+++ clang-tools-extra/pseudo/unittests/LRTableTest.cpp
@@ -61,7 +61,7 @@
 
   EXPECT_EQ(T.getShiftState(1, Eof), llvm::None);
   EXPECT_EQ(T.getShiftState(1, Semi), llvm::None);
-  EXPECT_EQ(T.getGoToState(1, Term), 3);
+  EXPECT_THAT(T.getGoToState(1, Term), ValueIs(3));
   EXPECT_THAT(T.getReduceRules(1), ElementsAre(2));
 
   // Verify the behaivor for other non-available-actions terminals.
Index: clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp
===================================================================
--- clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp
+++ clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp
@@ -11,6 +11,7 @@
 #include "clang-pseudo/grammar/LRTable.h"
 #include "clang/Basic/TokenKinds.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/raw_ostream.h"
 #include <cstdint>
 
 namespace llvm {
@@ -45,49 +46,23 @@
       : StartStates(StartStates) {}
 
   bool insert(Entry E) { return Entries.insert(std::move(E)).second; }
-  LRTable build(unsigned NumStates) && {
-    // E.g. given the following parsing table with 3 states and 3 terminals:
-    //
-    //            a    b     c
-    // +-------+----+-------+-+
-    // |state0 |    | s0,r0 | |
-    // |state1 | acc|       | |
-    // |state2 |    |  r1   | |
-    // +-------+----+-------+-+
-    //
-    // The final LRTable:
-    //  - StateOffset: [s0] = 0, [s1] = 2, [s2] = 3, [sentinel] = 4
-    //  - Symbols:     [ b,   b,   a,  b]
-    //    Actions:     [ s0, r0, acc, r1]
-    //                   ~~~~~~ range for state 0
-    //                           ~~~~ range for state 1
-    //                                ~~ range for state 2
-    // First step, we sort all entries by (State, Symbol, Action).
-    std::vector<Entry> Sorted(Entries.begin(), Entries.end());
-    llvm::sort(Sorted, [](const Entry &L, const Entry &R) {
-      return std::forward_as_tuple(L.State, L.Symbol, L.Act.opaque()) <
-             std::forward_as_tuple(R.State, R.Symbol, R.Act.opaque());
-    });
-
+  LRTable build(unsigned NumStates, unsigned NumNonterminals) && {
     LRTable Table;
-    Table.Actions.reserve(Sorted.size());
-    Table.Symbols.reserve(Sorted.size());
-    // We are good to finalize the States and Actions.
-    for (const auto &E : Sorted) {
-      Table.Actions.push_back(E.Act);
-      Table.Symbols.push_back(E.Symbol);
-    }
-    // Initialize the terminal and nonterminal offset, all ranges are empty by
-    // default.
-    Table.StateOffset = std::vector<uint32_t>(NumStates + 1, 0);
-    size_t SortedIndex = 0;
-    for (StateID State = 0; State < Table.StateOffset.size(); ++State) {
-      Table.StateOffset[State] = SortedIndex;
-      while (SortedIndex < Sorted.size() && Sorted[SortedIndex].State == State)
-        ++SortedIndex;
-    }
     Table.StartStates = std::move(StartStates);
 
+    // Compile the goto and shift actions into sparse tables.
+    llvm::DenseMap<unsigned, SymbolID> Gotos;
+    llvm::DenseMap<unsigned, SymbolID> Shifts;
+    for (const auto &E : Entries) {
+      if (E.Act.kind() == Action::Shift)
+        Shifts.try_emplace(symbolToToken(E.Symbol) * NumStates + E.State,
+                           E.Act.getShiftState());
+      else if (E.Act.kind() == Action::GoTo)
+        Gotos.try_emplace(E.Symbol * NumStates + E.State, E.Act.getGoToState());
+    }
+    Table.Shifts = SparseTable(Shifts, NumStates * NumTerminals);
+    Table.Gotos = SparseTable(Gotos, NumStates * NumNonterminals);
+
     // Compile the follow sets into a bitmap.
     Table.FollowSets.resize(tok::NUM_TOKENS * FollowSets.size());
     for (SymbolID NT = 0; NT < FollowSets.size(); ++NT)
@@ -136,7 +111,8 @@
   for (const ReduceEntry &E : Reduces)
     Build.Reduces[E.State].insert(E.Rule);
   Build.FollowSets = followSets(G);
-  return std::move(Build).build(/*NumStates=*/MaxState + 1);
+  return std::move(Build).build(/*NumStates=*/MaxState + 1,
+                                G.table().Nonterminals.size());
 }
 
 LRTable LRTable::buildSLR(const Grammar &G) {
@@ -163,7 +139,8 @@
         Build.Reduces[SID].insert(I.rule());
     }
   }
-  return std::move(Build).build(Graph.states().size());
+  return std::move(Build).build(Graph.states().size(),
+                                G.table().Nonterminals.size());
 }
 
 } // namespace pseudo
Index: clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
===================================================================
--- clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
+++ clang-tools-extra/pseudo/lib/grammar/LRTable.cpp
@@ -34,12 +34,12 @@
   return llvm::formatv(R"(
 Statistics of the LR parsing table:
     number of states: {0}
-    number of actions: {1}
-    number of reduces: {2}
-    size of the table (bytes): {3}
+    number of actions: shift={1} goto={2} reduce={3}
+    size of Shifts {5}, Gotos={6}
+    size of the table (bytes): {4}
 )",
-                       StateOffset.size() - 1, Actions.size(), Reduces.size(),
-                       bytes())
+                       numStates(), Shifts.size(), Gotos.size(), Reduces.size(),
+                       bytes(), Shifts.bytes(), Gotos.bytes())
       .str();
 }
 
@@ -47,15 +47,13 @@
   std::string Result;
   llvm::raw_string_ostream OS(Result);
   OS << "LRTable:\n";
-  for (StateID S = 0; S < StateOffset.size() - 1; ++S) {
+  for (StateID S = 0; S < numStates(); ++S) {
     OS << llvm::formatv("State {0}\n", S);
     for (uint16_t Terminal = 0; Terminal < NumTerminals; ++Terminal) {
       SymbolID TokID = tokenSymbol(static_cast<tok::TokenKind>(Terminal));
-      for (auto A : find(S, TokID)) {
-        if (A.kind() == LRTable::Action::Shift)
-          OS.indent(4) << llvm::formatv("{0}: shift state {1}\n",
-                                        G.symbolName(TokID), A.getShiftState());
-      }
+      if (auto SS = getShiftState(S, TokID))
+        OS.indent(4) << llvm::formatv("{0}: shift state {1}\n",
+                                      G.symbolName(TokID), SS);
     }
     for (RuleID R : getReduceRules(S)) {
       SymbolID Target = G.lookupRule(R).Target;
@@ -71,55 +69,15 @@
     }
     for (SymbolID NontermID = 0; NontermID < G.table().Nonterminals.size();
          ++NontermID) {
-      if (find(S, NontermID).empty())
-        continue;
-      OS.indent(4) << llvm::formatv("{0}: go to state {1}\n",
-                                    G.symbolName(NontermID),
-                                    getGoToState(S, NontermID));
+      if (auto GS = getGoToState(S, NontermID)) {
+        OS.indent(4) << llvm::formatv("{0}: go to state {1}\n",
+                                      G.symbolName(NontermID), *GS);
+      }
     }
   }
   return OS.str();
 }
 
-llvm::Optional<LRTable::StateID>
-LRTable::getShiftState(StateID State, SymbolID Terminal) const {
-  // FIXME: we spend a significant amount of time on misses here.
-  // We could consider storing a std::bitset for a cheaper test?
-  assert(pseudo::isToken(Terminal) && "expected terminal symbol!");
-  for (const auto &Result : find(State, Terminal))
-    if (Result.kind() == Action::Shift)
-      return Result.getShiftState(); // unique: no shift/shift conflicts.
-  return llvm::None;
-}
-
-LRTable::StateID LRTable::getGoToState(StateID State,
-                                       SymbolID Nonterminal) const {
-  assert(pseudo::isNonterminal(Nonterminal) && "expected nonterminal symbol!");
-  auto Result = find(State, Nonterminal);
-  assert(Result.size() == 1 && Result.front().kind() == Action::GoTo);
-  return Result.front().getGoToState();
-}
-
-llvm::ArrayRef<LRTable::Action> LRTable::find(StateID Src, SymbolID ID) const {
-  assert(Src + 1u < StateOffset.size());
-  std::pair<size_t, size_t> Range =
-      std::make_pair(StateOffset[Src], StateOffset[Src + 1]);
-  auto SymbolRange = llvm::makeArrayRef(Symbols.data() + Range.first,
-                                        Symbols.data() + Range.second);
-
-  assert(llvm::is_sorted(SymbolRange) &&
-         "subrange of the Symbols should be sorted!");
-  const LRTable::StateID *Start =
-      llvm::partition_point(SymbolRange, [&ID](SymbolID S) { return S < ID; });
-  if (Start == SymbolRange.end())
-    return {};
-  const LRTable::StateID *End = Start;
-  while (End != SymbolRange.end() && *End == ID)
-    ++End;
-  return llvm::makeArrayRef(&Actions[Start - Symbols.data()],
-                            /*length=*/End - Start);
-}
-
 LRTable::StateID LRTable::getStartState(SymbolID Target) const {
   assert(llvm::is_sorted(StartStates) && "StartStates must be sorted!");
   auto It = llvm::partition_point(
Index: clang-tools-extra/pseudo/lib/GLR.cpp
===================================================================
--- clang-tools-extra/pseudo/lib/GLR.cpp
+++ clang-tools-extra/pseudo/lib/GLR.cpp
@@ -318,9 +318,11 @@
     do {
       const PushSpec &Push = Sequences.top().second;
       FamilySequences.emplace_back(Sequences.top().first.Rule, *Push.Seq);
-      for (const GSS::Node *Base : Push.LastPop->parents())
-        FamilyBases.emplace_back(
-            Params.Table.getGoToState(Base->State, F.Symbol), Base);
+      for (const GSS::Node *Base : Push.LastPop->parents()) {
+        auto NextState = Params.Table.getGoToState(Base->State, F.Symbol);
+        assert(NextState.hasValue() && "goto must succeed after reduce!");
+        FamilyBases.emplace_back(*NextState, Base);
+      }
 
       Sequences.pop();
     } while (!Sequences.empty() && Sequences.top().first == F);
@@ -391,8 +393,9 @@
     }
     const ForestNode *Parsed =
         &Params.Forest.createSequence(Rule.Target, *RID, TempSequence);
-    StateID NextState = Params.Table.getGoToState(Base->State, Rule.Target);
-    Heads->push_back(Params.GSStack.addNode(NextState, Parsed, {Base}));
+    auto NextState = Params.Table.getGoToState(Base->State, Rule.Target);
+    assert(NextState.hasValue() && "goto must succeed after reduce!");
+    Heads->push_back(Params.GSStack.addNode(*NextState, Parsed, {Base}));
     return true;
   }
 };
@@ -442,7 +445,8 @@
   }
   LLVM_DEBUG(llvm::dbgs() << llvm::formatv("Reached eof\n"));
 
-  StateID AcceptState = Params.Table.getGoToState(StartState, StartSymbol);
+  auto AcceptState = Params.Table.getGoToState(StartState, StartSymbol);
+  assert(AcceptState.hasValue() && "goto must succeed after start symbol!");
   const ForestNode *Result = nullptr;
   for (const auto *Head : Heads) {
     if (Head->State == AcceptState) {
Index: clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
===================================================================
--- clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
+++ clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h
@@ -40,6 +40,8 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/BitVector.h"
 #include "llvm/Support/Capacity.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
 #include <cstdint>
 #include <vector>
 
@@ -124,11 +126,17 @@
   // Returns the state after we reduce a nonterminal.
   // Expected to be called by LR parsers.
   // REQUIRES: Nonterminal is valid here.
-  StateID getGoToState(StateID State, SymbolID Nonterminal) const;
+  llvm::Optional<StateID> getGoToState(StateID State,
+                                       SymbolID Nonterminal) const {
+    return Gotos.get(numStates() * Nonterminal + State);
+  }
   // Returns the state after we shift a terminal.
   // Expected to be called by LR parsers.
   // If the terminal is invalid here, returns None.
-  llvm::Optional<StateID> getShiftState(StateID State, SymbolID Terminal) const;
+  llvm::Optional<StateID> getShiftState(StateID State,
+                                        SymbolID Terminal) const {
+    return Shifts.get(numStates() * symbolToToken(Terminal) + State);
+  }
 
   // Returns the possible reductions from a state.
   // These are not keyed by a lookahead token. Instead, call canFollow() to
@@ -157,9 +165,7 @@
   StateID getStartState(SymbolID StartSymbol) const;
 
   size_t bytes() const {
-    return sizeof(*this) + llvm::capacity_in_bytes(Actions) +
-           llvm::capacity_in_bytes(Symbols) +
-           llvm::capacity_in_bytes(StateOffset) +
+    return sizeof(*this) + Gotos.bytes() + Shifts.bytes() +
            llvm::capacity_in_bytes(Reduces) +
            llvm::capacity_in_bytes(ReduceOffset) +
            llvm::capacity_in_bytes(FollowSets);
@@ -187,22 +193,79 @@
                                llvm::ArrayRef<ReduceEntry>);
 
 private:
-  // Looks up actions stored in the generic table.
-  llvm::ArrayRef<Action> find(StateID State, SymbolID Symbol) const;
-
-  // Conceptually the LR table is a multimap from (State, SymbolID) => Action.
-  // Our physical representation is quite different for compactness.
-
-  // Index is StateID, value is the offset into Symbols/Actions
-  // where the entries for this state begin.
-  // Give a state id, the corresponding half-open range of Symbols/Actions is
-  // [StateOffset[id], StateOffset[id+1]).
-  std::vector<uint32_t> StateOffset;
-  // Parallel to Actions, the value is SymbolID (columns of the matrix).
-  // Grouped by the StateID, and only subranges are sorted.
-  std::vector<SymbolID> Symbols;
-  // A flat list of available actions, sorted by (State, SymbolID).
-  std::vector<Action> Actions;
+  unsigned numStates() const { return ReduceOffset.size() - 1; }
+
+  // A map from unsigned key => StateID, used to store actions.
+  // Keys span a dense range [0, NumKeys), but only some keys have values.
+  //
+  // We store one bit for presence/absence of the value for each key.
+  // At every 64th key, we store the offset into the table of values.
+  //   e.g. key 0x500 is checkpoint 0x500/64 = 20
+  //                     Checkpoints[20] = 34
+  //        get(0x500) = Values[34]                (assuming it has a value)
+  // To look up values in between, we count the set bits:
+  //        get(0x509) has a value if HasValue[20] & (1<<9)
+  //        #values between 0x500 and 0x509: popcnt(HasValue[20] & ((1<<9) - 1))
+  //        get(0x509) = Values[34 + popcnt(...)]
+  //
+  // Overall size is 1.25 bits/key + 16 bits/value.
+  // Lookup is constant time with a low factor (no hashing).
+  class SparseTable {
+    using Word = uint64_t;
+    constexpr static unsigned WordBits = CHAR_BIT * sizeof(Word);
+
+    std::vector<StateID> Values;
+    std::vector<Word> HasValue;
+    std::vector<uint16_t> Checkpoints;
+
+  public:
+    SparseTable() = default;
+    SparseTable(const llvm::DenseMap<unsigned, StateID> &V, unsigned NumKeys) {
+      assert(
+          V.size() <
+              std::numeric_limits<decltype(Checkpoints)::value_type>::max() &&
+          "16 bits too small for value offsets!");
+      unsigned NumWords = (NumKeys + WordBits - 1) / WordBits;
+      HasValue.resize(NumWords, 0);
+      Checkpoints.reserve(NumWords);
+      Values.reserve(V.size());
+      for (unsigned I = 0; I < NumKeys; ++I) {
+        if ((I % WordBits) == 0)
+          Checkpoints.push_back(Values.size());
+        auto It = V.find(I);
+        if (It != V.end()) {
+          HasValue[I / WordBits] |= (Word(1) << (I % WordBits));
+          Values.push_back(It->second);
+        }
+      }
+    }
+
+    llvm::Optional<StateID> get(unsigned Key) const {
+      // Do we have a value for this key?
+      Word KeyMask = Word(1) << (Key % WordBits);
+      unsigned KeyWord = Key / WordBits;
+      if ((HasValue[KeyWord] & KeyMask) == 0)
+        return llvm::None;
+      // Count the number of values since the checkpoint.
+      Word BelowKeyMask = KeyMask - 1;
+      unsigned CountSinceCheckpoint =
+          llvm::countPopulation(HasValue[KeyWord] & BelowKeyMask);
+      // Find the value relative to the last checkpoint.
+      return Values[Checkpoints[KeyWord] + CountSinceCheckpoint];
+    }
+
+    unsigned size() const { return Values.size(); }
+
+    size_t bytes() const {
+      return llvm::capacity_in_bytes(HasValue) +
+             llvm::capacity_in_bytes(Values) +
+             llvm::capacity_in_bytes(Checkpoints);
+    }
+  };
+  // Shift and Goto tables are keyed by (State + Symbol*NumStates).
+  SparseTable Shifts;
+  SparseTable Gotos;
+
   // A sorted table, storing the start state for each target parsing symbol.
   std::vector<std::pair<SymbolID, StateID>> StartStates;
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to