Module: Mesa
Branch: main
Commit: 2f467738193a8009cfd18c995aea13e63540062c
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=2f467738193a8009cfd18c995aea13e63540062c

Author: Rhys Perry <[email protected]>
Date:   Fri Oct 13 16:39:24 2023 +0100

nir/loop_analyze: scalarize try_eval_const_alu

This is simpler, and callers of this function only expected scalar results anyway.

Signed-off-by: Rhys Perry <[email protected]>
Acked-by: Timothy Arceri <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26225>

---

 src/compiler/nir/nir_loop_analyze.c | 97 +++++++++++++++++--------------------
 1 file changed, 44 insertions(+), 53 deletions(-)

diff --git a/src/compiler/nir/nir_loop_analyze.c b/src/compiler/nir/nir_loop_analyze.c
index f468d2316e5..e06a8f61392 100644
--- a/src/compiler/nir/nir_loop_analyze.c
+++ b/src/compiler/nir/nir_loop_analyze.c
@@ -717,11 +717,11 @@ eval_const_binop(nir_op op, unsigned bit_size,
 }
 
 static int
-find_replacement(const nir_def **originals, const nir_def *key,
+find_replacement(const nir_scalar *originals, nir_scalar key,
                  unsigned num_replacements)
 {
    for (int i = 0; i < num_replacements; i++) {
-      if (originals[i] == key)
+      if (nir_scalar_equal(originals[i], key))
          return i;
    }
 
@@ -750,12 +750,14 @@ find_replacement(const nir_def **originals, const nir_def *key,
  * applying the previously described substitution) or false otherwise.
  */
 static bool
-try_eval_const_alu(nir_const_value *dest, nir_alu_instr *alu,
-                   const nir_def **originals,
-                   const nir_const_value **replacements,
+try_eval_const_alu(nir_const_value *dest, nir_scalar alu_s, const nir_scalar 
*originals,
+                   const nir_const_value *replacements,
                    unsigned num_replacements, unsigned execution_mode)
 {
-   nir_const_value src[NIR_MAX_VEC_COMPONENTS][NIR_MAX_VEC_COMPONENTS];
+   nir_alu_instr *alu = nir_instr_as_alu(alu_s.def->parent_instr);
+
+   if (nir_op_infos[alu->op].output_size)
+      return false;
 
    /* In the case that any outputs/inputs have unsized types, then we need to
     * guess the bit-size. In this case, the validator ensures that all
@@ -767,55 +769,42 @@ try_eval_const_alu(nir_const_value *dest, nir_alu_instr *alu,
     * (although it still requires to receive a valid bit-size).
     */
    unsigned bit_size = 0;
-   if (!nir_alu_type_get_type_size(nir_op_infos[alu->op].output_type))
+   if (!nir_alu_type_get_type_size(nir_op_infos[alu->op].output_type)) {
       bit_size = alu->def.bit_size;
+   } else {
+      for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
+         if (!nir_alu_type_get_type_size(nir_op_infos[alu->op].input_types[i]))
+            bit_size = alu->src[i].src.ssa->bit_size;
+      }
 
-   for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
-      if (bit_size == 0 &&
-          !nir_alu_type_get_type_size(nir_op_infos[alu->op].input_types[i]))
-         bit_size = alu->src[i].src.ssa->bit_size;
-
-      nir_instr *src_instr = alu->src[i].src.ssa->parent_instr;
+      if (bit_size == 0)
+         bit_size = 32;
+   }
 
-      if (src_instr->type == nir_instr_type_load_const) {
-         nir_load_const_instr *load_const = nir_instr_as_load_const(src_instr);
+   nir_const_value src[NIR_MAX_VEC_COMPONENTS];
+   nir_const_value *src_ptrs[NIR_MAX_VEC_COMPONENTS];
 
-         for (unsigned j = 0; j < nir_ssa_alu_instr_src_components(alu, i);
-              j++) {
-            src[i][j] = load_const->value[alu->src[i].swizzle[j]];
-         }
-      } else {
-         int r = find_replacement(originals, alu->src[i].src.ssa,
-                                  num_replacements);
+   for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
+      nir_scalar src_s = nir_scalar_chase_alu_src(alu_s, i);
 
-         if (r >= 0) {
-            for (unsigned j = 0; j < nir_ssa_alu_instr_src_components(alu, i);
-                 j++) {
-               src[i][j] = replacements[r][alu->src[i].swizzle[j]];
-            }
-         } else if (src_instr->type == nir_instr_type_alu) {
-            memset(src[i], 0, sizeof(src[i]));
+      src_ptrs[i] = &src[i];
+      if (nir_scalar_is_const(src_s)) {
+         src[i] = nir_scalar_as_const_value(src_s);
+         continue;
+      }
 
-            if (!try_eval_const_alu(src[i], nir_instr_as_alu(src_instr),
-                                    originals, replacements, num_replacements,
-                                    execution_mode))
-               return false;
-         } else {
-            return false;
-         }
+      int r = find_replacement(originals, src_s, num_replacements);
+      if (r >= 0) {
+         src[i] = replacements[r];
+      } else if (!nir_scalar_is_alu(src_s) ||
+                 !try_eval_const_alu(&src[i], src_s,
+                                     originals, replacements,
+                                     num_replacements, execution_mode)) {
+         return false;
       }
    }
 
-   if (bit_size == 0)
-      bit_size = 32;
-
-   nir_const_value *srcs[NIR_MAX_VEC_COMPONENTS];
-
-   for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; ++i)
-      srcs[i] = src[i];
-
-   nir_eval_const_opcode(alu->op, dest, alu->def.num_components,
-                         bit_size, srcs, execution_mode);
+   nir_eval_const_opcode(alu->op, dest, 1, bit_size, src_ptrs, execution_mode);
 
    return true;
 }
@@ -931,13 +920,14 @@ get_iteration_empirical(nir_alu_instr *cond_alu, nir_alu_instr *incr_alu,
    nir_const_value result;
    nir_const_value iter = initial;
 
-   const nir_def *originals[2] = { basis, NULL };
-   const nir_const_value *replacements[2] = { &iter, NULL };
+   const nir_scalar original = nir_get_scalar(basis, 0);
+   const nir_scalar cond = nir_get_scalar(&cond_alu->def, 0);
+   const nir_scalar incr = nir_get_scalar(&incr_alu->def, 0);
 
    while (iter_count <= max_unroll_iterations) {
       bool success;
 
-      success = try_eval_const_alu(&result, cond_alu, originals, replacements,
+      success = try_eval_const_alu(&result, cond, &original, &iter,
                                    1, execution_mode);
       if (!success)
          return -1;
@@ -948,7 +938,7 @@ get_iteration_empirical(nir_alu_instr *cond_alu, nir_alu_instr *incr_alu,
 
       iter_count++;
 
-      success = try_eval_const_alu(&result, incr_alu, originals, replacements,
+      success = try_eval_const_alu(&result, incr, &original, &iter,
                                    1, execution_mode);
       assert(success);
 
@@ -966,10 +956,11 @@ will_break_on_first_iteration(nir_alu_instr *cond_alu, nir_def *basis,
 {
    nir_const_value result;
 
-   const nir_def *originals[2] = { basis, limit_basis };
-   const nir_const_value *replacements[2] = { &initial, &limit };
+   const nir_scalar originals[2] = { nir_get_scalar(basis, 0), 
nir_get_scalar(limit_basis, 0) };
+   const nir_const_value replacements[2] = { initial, limit };
 
-   ASSERTED bool success = try_eval_const_alu(&result, cond_alu, originals,
+   const nir_scalar cond = nir_get_scalar(&cond_alu->def, 0);
+   ASSERTED bool success = try_eval_const_alu(&result, cond, originals,
                                               replacements, 2, execution_mode);
 
    assert(success);

Reply via email to