Module: Mesa
Branch: main
Commit: 178603202959c528c28b6c18739b2335477b4e09
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=178603202959c528c28b6c18739b2335477b4e09

Author: Alyssa Rosenzweig <[email protected]>
Date:   Mon Nov  6 12:25:59 2023 -0400

nir/validate: Optimize ssa_srcs set

Profiling showed that maintaining this ssa_srcs set consumes ~3% of CTS time
with a debugoptimized build. Unfortunately, we really do benefit from getting
this coverage in CI. So rather than remove the validation, let's optimize the
data structure used so we can keep the coverage at a fraction of the cost.

The expensive piece is the pointer set, which is backed by a relatively
expensive hash table. It would be much cheaper to use an invasive set instead,
with a single "present" bit. We don't want to bloat nir_src for this; however,
there's an easy solution: use a tagged pointer to steal a bit in the nir_src for
the job. We untag everything at the end of validation (and this meta-invariant
is asserted with an auxiliary counter), so while we mutate the IR while
validating, the mutations do not escape nir_validate.

We tag the parent pointer and not the def pointer, because it is dramatically
less used and therefore has far fewer disrupted call sites.

The M1 job is improved from 3:03 to 2:55 of deqp-runner reported time, which is
excellent.

Signed-off-by: Alyssa Rosenzweig <[email protected]>
Reviewed-by: Rhys Perry <[email protected]>
Reviewed-by: Faith Ekstrand <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26084>

---

 src/compiler/nir/nir_validate.c | 90 ++++++++++++++++++++++++++++-------------
 1 file changed, 61 insertions(+), 29 deletions(-)

diff --git a/src/compiler/nir/nir_validate.c b/src/compiler/nir/nir_validate.c
index 3762b713ed9..74998b02fa2 100644
--- a/src/compiler/nir/nir_validate.c
+++ b/src/compiler/nir/nir_validate.c
@@ -73,8 +73,10 @@ typedef struct {
    /* Set of all blocks in the list */
    struct set *blocks;
 
-   /* Set of seen SSA sources */
-   struct set *ssa_srcs;
+   /* Number of tagged nir_src's. This is implicitly the cardinality of the set
+    * of pending nir_src's.
+    */
+   uint32_t nr_tagged_srcs;
 
    /* bitset of ssa definitions we have found; used to check uniqueness */
    BITSET_WORD *ssa_defs_found;
@@ -127,22 +129,57 @@ validate_num_components(validate_state *state, unsigned num_components)
    validate_assert(state, nir_num_components_valid(num_components));
 }
 
+/* Tag used in nir_src::_parent to indicate that a source has been seen. */
+#define SRC_TAG_SEEN (0x2)
+
+static_assert(SRC_TAG_SEEN == (~NIR_SRC_PARENT_MASK + 1),
+              "Parent pointer tags chosen not to collide");
+
+static void
+tag_src(nir_src *src, validate_state *state)
+{
+   /* nir_src only appears once and only in one SSA def use list, since we
+    * mark nir_src's as we go by tagging this pointer.
+    */
+   if (validate_assert(state, (src->_parent & SRC_TAG_SEEN) == 0)) {
+      src->_parent |= SRC_TAG_SEEN;
+      state->nr_tagged_srcs++;
+   }
+}
+
+/* Due to tagging, it's not safe to use nir_src_parent_instr during the main
+ * validate loop. This is a tagging-aware version.
+ */
+static nir_instr *
+src_parent_instr_safe(nir_src *src)
+{
+   uintptr_t untagged = (src->_parent & ~SRC_TAG_SEEN);
+   assert(!(untagged & NIR_SRC_PARENT_IS_IF) && "precondition");
+   return (nir_instr *)untagged;
+}
+
+/*
+ * As we walk SSA defs, we mark every use as seen by tagging the parent pointer.
+ * We need to make sure our use is seen in a use list.
+ *
+ * Then we unmark when we hit the source. This will let us prove that we've
+ * seen all the sources.
+ */
+static void
+validate_src_tag(nir_src *src, validate_state *state)
+{
+   if (validate_assert(state, src->_parent & SRC_TAG_SEEN)) {
+      src->_parent &= ~SRC_TAG_SEEN;
+      state->nr_tagged_srcs--;
+   }
+}
+
 static void
 validate_ssa_src(nir_src *src, validate_state *state,
                  unsigned bit_sizes, unsigned num_components)
 {
    validate_assert(state, src->ssa != NULL);
 
-   /* As we walk SSA defs, we add every use to this set.  We need to make sure
-    * our use is seen in a use list.
-    */
-   struct set_entry *entry = _mesa_set_search(state->ssa_srcs, src);
-   validate_assert(state, entry);
-
-   /* This will let us prove that we've seen all the sources */
-   if (entry)
-      _mesa_set_remove(state->ssa_srcs, entry);
-
    if (bit_sizes)
       validate_assert(state, src->ssa->bit_size & bit_sizes);
    if (num_components)
@@ -155,6 +192,9 @@ static void
 validate_src(nir_src *src, validate_state *state,
              unsigned bit_sizes, unsigned num_components)
 {
+   /* Validate the tag first, so that nir_src_parent_instr is valid */
+   validate_src_tag(src, state);
+
    if (state->instr)
       validate_assert(state, nir_src_parent_instr(src) == state->instr);
    else
@@ -197,12 +237,11 @@ validate_def(nir_def *def, validate_state *state,
 
    list_validate(&def->uses);
    nir_foreach_use_including_if(src, def) {
+      /* Check that the def matches. */
       validate_assert(state, src->ssa == def);
 
-      bool already_seen = false;
-      _mesa_set_search_and_add(state->ssa_srcs, src, &already_seen);
-      /* A nir_src should only appear once and only in one SSA def use list */
-      validate_assert(state, !already_seen);
+      /* Check that nir_src's are unique */
+      tag_src(src, state);
    }
 }
 
@@ -396,7 +435,7 @@ validate_deref_instr(nir_deref_instr *instr, validate_state *state)
       if (!validate_assert(state, !nir_src_is_if(use)))
          continue;
 
-      if (nir_src_parent_instr(use)->type == nir_instr_type_phi) {
+      if (src_parent_instr_safe(use)->type == nir_instr_type_phi) {
          validate_assert(state, !(instr->modes & (nir_var_shader_in |
                                                   nir_var_shader_out |
                                                   nir_var_shader_out |
@@ -1481,15 +1520,6 @@ validate_ssa_dominance(nir_function_impl *impl, validate_state *state)
 static void
 validate_function_impl(nir_function_impl *impl, validate_state *state)
 {
-   /* Resize the ssa_srcs set.  It's likely that the size of this set will
-    * never actually hit the number of SSA defs because we remove sources from
-    * the set as we visit them.  (It could actually be much larger because
-    * each SSA def can be used more than once.)  However, growing it now costs
-    * us very little (the extra memory is already dwarfed by the SSA defs
-    * themselves) and makes collisions much less likely.
-    */
-   _mesa_set_resize(state->ssa_srcs, impl->ssa_alloc);
-
    validate_assert(state, impl->function->impl == impl);
    validate_assert(state, impl->cf_node.parent == NULL);
 
@@ -1524,8 +1554,10 @@ validate_function_impl(nir_function_impl *impl, validate_state *state)
    }
    validate_end_block(impl->end_block, state);
 
-   validate_assert(state, state->ssa_srcs->entries == 0);
-   _mesa_set_clear(state->ssa_srcs, NULL);
+   /* We must have seen every source by now. This also means that we've untagged
+    * every source, so we have valid (unaugmented) NIR once again.
+    */
+   validate_assert(state, state->nr_tagged_srcs == 0);
 
    static int validate_dominance = -1;
    if (validate_dominance < 0) {
@@ -1551,11 +1583,11 @@ static void
 init_validate_state(validate_state *state)
 {
    state->mem_ctx = ralloc_context(NULL);
-   state->ssa_srcs = _mesa_pointer_set_create(state->mem_ctx);
    state->ssa_defs_found = NULL;
    state->blocks = _mesa_pointer_set_create(state->mem_ctx);
    state->var_defs = _mesa_pointer_hash_table_create(state->mem_ctx);
    state->errors = _mesa_pointer_hash_table_create(state->mem_ctx);
+   state->nr_tagged_srcs = 0;
 
    state->loop = NULL;
    state->in_loop_continue_construct = false;

Reply via email to