Module: Mesa
Branch: main
Commit: 9df4703fbb59d1295a9d3daf6320f329c9de2d66
URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=9df4703fbb59d1295a9d3daf6320f329c9de2d66

Author: Bas Nieuwenhuizen <[email protected]>
Date:   Sun Jul 16 01:01:55 2023 +0200

radv: Add cooperative matrix lowering.
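
This adds a NIR pass that rewrites cooperative matrix values into plain
per-lane vectors (sized from the matrix shape and the wave size) and lowers
the cmat_* intrinsics into ordinary vector ALU operations and per-lane deref
arithmetic. The pass is wired into radv_shader_spirv_to_nir() and only runs
for shaders that actually use cooperative matrices.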

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24683>

---

 src/amd/vulkan/meson.build                         |   1 +
 src/amd/vulkan/nir/radv_nir.h                      |   2 +
 .../vulkan/nir/radv_nir_lower_cooperative_matrix.c | 435 +++++++++++++++++++++
 src/amd/vulkan/radv_shader.c                       |   4 +
 4 files changed, 442 insertions(+)

diff --git a/src/amd/vulkan/meson.build b/src/amd/vulkan/meson.build
index 8de1ccc9d74..b72fcfcb0fe 100644
--- a/src/amd/vulkan/meson.build
+++ b/src/amd/vulkan/meson.build
@@ -76,6 +76,7 @@ libradv_files = files(
   'nir/radv_nir_apply_pipeline_layout.c',
   'nir/radv_nir_export_multiview.c',
   'nir/radv_nir_lower_abi.c',
+  'nir/radv_nir_lower_cooperative_matrix.c',
   'nir/radv_nir_lower_fs_barycentric.c',
   'nir/radv_nir_lower_fs_intrinsics.c',
   'nir/radv_nir_lower_intrinsics_early.c',
diff --git a/src/amd/vulkan/nir/radv_nir.h b/src/amd/vulkan/nir/radv_nir.h
index 279f161e71e..44a51d747e1 100644
--- a/src/amd/vulkan/nir/radv_nir.h
+++ b/src/amd/vulkan/nir/radv_nir.h
@@ -78,6 +78,8 @@ bool radv_nir_lower_io_to_mem(struct radv_device *device, struct radv_shader_sta
 
 void radv_nir_lower_poly_line_smooth(nir_shader *nir, const struct radv_pipeline_key *key);
 
+bool radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
new file mode 100644
index 00000000000..d81231b0137
--- /dev/null
+++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
@@ -0,0 +1,435 @@
+/*
+ * Copyright © 2023 Bas Nieuwenhuizen
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the "Software"),
+ * to deal in the Software without restriction, including without limitation
+ * the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ * and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice (including the next
+ * paragraph) shall be included in all copies or substantial portions of the
+ * Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
+ * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#include "nir_builder.h"
+#include "radv_nir.h"
+
+static unsigned
+radv_nir_cmat_length(struct glsl_cmat_description desc, unsigned wave_size)
+{
+   return desc.use != GLSL_CMAT_USE_ACCUMULATOR
+             ? 16
+             : (desc.cols * desc.rows / wave_size * 32 / glsl_base_type_bit_size(desc.element_type));
+}
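
The arithmetic is easier to see with numbers plugged in. A minimal host-side
model of radv_nir_cmat_length(), for illustration only (not part of the
patch):

   /* Per-lane component count for a cooperative matrix, mirroring the
    * function above. A/B matrices always use 16 components; accumulators
    * scale with the wave size. E.g. a 16x16 fp32 accumulator at wave32:
    * 16*16/32 * 32/32 = 8 components; at wave64 it is 4. */
   unsigned cmat_length(bool accumulator, unsigned cols, unsigned rows,
                        unsigned wave_size, unsigned elem_bits)
   {
      return !accumulator ? 16 : cols * rows / wave_size * 32 / elem_bits;
   }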
+
+/* For C matrices we have 1 VGPR per element even if the element type is < 32
+ * bits, so with 8 fp16 elements we implement that with an f16vec16. We then
+ * use the coefficient generated by this function to figure out how many
+ * elements we really have.
+ */
+static unsigned
+radv_nir_cmat_length_mul(struct glsl_cmat_description desc)
+{
+   return desc.use == GLSL_CMAT_USE_ACCUMULATOR ? (32 / glsl_base_type_bit_size(desc.element_type)) : 1;
+}
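
Worked through for the case the comment above describes: a 16x16 fp16
accumulator at wave32 gets length 16*16/32 * 32/16 = 16 components, but
length_mul = 32/16 = 2, so only every second component of the f16vec16 holds
a real element (8 in total). For A/B matrices and 32-bit accumulators the
coefficient is 1 and the vector is fully populated.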
+
+static unsigned
+radv_nir_cmat_bits(struct glsl_cmat_description desc)
+{
+   return glsl_base_type_bit_size(desc.element_type);
+}
+
+static nir_def *
+radv_nir_load_cmat(nir_builder *b, unsigned wave_size, nir_def *src)
+{
+   nir_deref_instr *deref = nir_instr_as_deref(src->parent_instr);
+   struct glsl_cmat_description desc = *glsl_get_cmat_description(deref->type);
+   return nir_build_load_deref(b, radv_nir_cmat_length(desc, wave_size), glsl_base_type_bit_size(desc.element_type),
+                               src, 0);
+}
+
+static const struct glsl_type *
+radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_table *type_map, unsigned wave_size)
+{
+   struct hash_entry *entry = _mesa_hash_table_search(type_map, orig_type);
+   if (entry) {
+      return entry->data;
+   } else if (glsl_type_is_cmat(orig_type)) {
+      struct glsl_cmat_description desc = *glsl_get_cmat_description(orig_type);
+      unsigned length = radv_nir_cmat_length(desc, wave_size);
+
+      return glsl_vector_type(desc.element_type, length);
+   } else if (glsl_type_is_array(orig_type)) {
+      const struct glsl_type *elem_type = glsl_get_array_element(orig_type);
+      const struct glsl_type *new_elem_type = radv_nir_translate_matrix_type(elem_type, type_map, wave_size);
+
+      if (elem_type == new_elem_type)
+         return orig_type;
+
+      return glsl_array_type(new_elem_type, glsl_get_length(orig_type), glsl_get_explicit_stride(orig_type));
+   } else if (glsl_type_is_struct(orig_type)) {
+      unsigned num_fields = glsl_get_length(orig_type);
+
+      bool change = false;
+      for (unsigned i = 0; i < num_fields; ++i) {
+         const struct glsl_type *field_type = glsl_get_struct_field(orig_type, i);
+         const struct glsl_type *new_field_type = radv_nir_translate_matrix_type(field_type, type_map, wave_size);
+
+         if (field_type != new_field_type) {
+            change = true;
+            break;
+         }
+      }
+
+      if (!change)
+         return orig_type;
+
+      struct glsl_struct_field *fields = malloc(sizeof(struct glsl_struct_field) * num_fields);
+
+      for (unsigned i = 0; i < num_fields; ++i) {
+         fields[i] = *glsl_get_struct_field_data(orig_type, i);
+
+         fields[i].type = radv_nir_translate_matrix_type(fields[i].type, type_map, wave_size);
+      }
+
+      const struct glsl_type *ret =
+         glsl_struct_type(fields, num_fields, glsl_get_type_name(orig_type), glsl_struct_type_is_packed(orig_type));
+      free(fields);
+
+      _mesa_hash_table_insert(type_map, orig_type, (void *)ret);
+      return ret;
+   } else
+      return orig_type;
+}
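
To make the translation concrete, a sketch of what the recursion above
produces under wave32 (illustrative pseudo-types, not actual GLSL/NIR
syntax):

   /* cmat<fp16, use=A, 16x16>           -> f16vec16
    * cmat<fp32, use=Accumulator, 16x16> -> vec8     (16*16/32 * 32/32)
    * cmat<fp16, use=Accumulator, 16x16> -> f16vec16 (every 2nd comp real)
    * cmat_a[4]                          -> f16vec16[4]
    * struct { cmat_acc m; int x; }      -> struct { vec8 m; int x; }   */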
+
+bool
+radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
+{
+   bool progress = false;
+
+   if (!shader->info.cs.has_cooperative_matrix)
+      return false;
+
+   struct nir_function *func = (struct nir_function *)exec_list_get_head_const(&shader->functions);
+   struct hash_table *type_map = _mesa_pointer_hash_table_create(NULL);
+
+   nir_foreach_variable_with_modes (var, shader, nir_var_shader_temp) {
+      const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, wave_size);
+      if (new_type != var->type) {
+         var->type = new_type;
+         progress = true;
+      }
+   }
+
+   nir_foreach_function_temp_variable (var, func->impl) {
+      const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, wave_size);
+      if (new_type != var->type) {
+         var->type = new_type;
+         progress = true;
+      }
+   }
+
+   nir_builder b = nir_builder_create(func->impl);
+
+   /* Iterate in reverse order so that lowering can still use the matrix types from the derefs before we change them. */
+   nir_foreach_block_reverse (block, func->impl) {
+      nir_foreach_instr_reverse_safe (instr, block) {
+         b.cursor = nir_before_instr(instr);
+
+         switch (instr->type) {
+         case nir_instr_type_intrinsic: {
+            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
+            switch (intr->intrinsic) {
+            case nir_intrinsic_cmat_length: {
+               struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intr);
+               unsigned len = radv_nir_cmat_length(desc, wave_size) / radv_nir_cmat_length_mul(desc);
+               nir_def_rewrite_uses(&intr->def, nir_imm_int(&b, len));
+               nir_instr_remove(instr);
+               progress = true;
+               break;
+            }
+            case nir_intrinsic_cmat_extract: {
+               nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
+               struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
+               nir_def *src0 = radv_nir_load_cmat(&b, wave_size, intr->src[0].ssa);
+
+               nir_def *index = intr->src[1].ssa;
+               index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));
+
+               nir_def *elem = nir_vector_extract(&b, src0, index);
+
+               nir_def_rewrite_uses(&intr->def, elem);
+               nir_instr_remove(instr);
+               progress = true;
+               break;
+            }
+            case nir_intrinsic_cmat_insert: {
+               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
+               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
+               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
+               nir_def *index = intr->src[3].ssa;
+               index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));
+
+               nir_def *elem = intr->src[1].ssa;
+               nir_def *r = nir_vector_insert(&b, src1, elem, index);
+               nir_store_deref(&b, dst_deref, r, 0xffff);
+               nir_instr_remove(instr);
+               progress = true;
+               break;
+            }
+            case nir_intrinsic_cmat_construct: {
+               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
+               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
+               nir_def *elem = intr->src[1].ssa;
+
+               nir_def *r = nir_replicate(&b, elem, radv_nir_cmat_length(desc, wave_size));
+
+               nir_store_deref(&b, dst_deref, r, 0xffff);
+               nir_instr_remove(instr);
+               progress = true;
+               break;
+            }
+            case nir_intrinsic_cmat_load: {
+               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
+               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
+               enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
+
+               nir_deref_instr *deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
+               nir_def *stride = intr->src[2].ssa;
+
+               nir_def *local_idx = nir_load_subgroup_invocation(&b);
+               nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);
+
+               /* The A input is transposed. */
+               if (desc.use == GLSL_CMAT_USE_A)
+                  layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
+                                                                     : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
+
+               unsigned length = radv_nir_cmat_length(desc, wave_size);
+               unsigned mul = radv_nir_cmat_length_mul(desc);
+               unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? wave_size : 16;
+               nir_def *vars[16];
+               if (mul > 1) {
+                  for (unsigned i = 0; i < length; ++i)
+                     if (i % mul != 0)
+                        vars[i] = nir_undef(&b, 1, glsl_base_type_bit_size(desc.element_type));
+               }
+
+               unsigned idx_bits = deref->def.bit_size;
+               nir_def *base_row =
+                  desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(&b, local_idx, 16) : nir_imm_int(&b, 0);
+
+               for (unsigned i = 0; i < length / mul; ++i) {
+                  nir_def *col_offset = inner_idx;
+                  nir_def *row_offset = nir_iadd_imm(&b, base_row, i * lanes_per_iter / 16);
+
+                  if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
+                     nir_def *tmp = col_offset;
+                     col_offset = row_offset;
+                     row_offset = tmp;
+                  }
+
+                  col_offset = nir_imul(&b, col_offset, stride);
+
+                  col_offset = nir_u2uN(&b, col_offset, idx_bits);
+                  row_offset = nir_u2uN(&b, row_offset, idx_bits);
+
+                  nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
+                  iter_deref =
+                     nir_build_deref_cast(&b, &iter_deref->def, deref->modes, glsl_scalar_type(desc.element_type),
+                                          glsl_base_type_bit_size(desc.element_type) / 8);
+                  iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);
+
+                  vars[i * mul] = nir_load_deref(&b, iter_deref);
+               }
+
+               nir_def *mat = nir_vec(&b, vars, length);
+               nir_store_deref(&b, dst_deref, mat, 0xffff);
+               nir_instr_remove(instr);
+               progress = true;
+               break;
+            }
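
The deref chain built above boils down to base + col*stride + row, in units
of elements. A hypothetical scalar model of the offset that lane `lane`
accesses on iteration i (illustration only, not part of the patch) may help
auditing:

   /* Mirrors the code: inner_idx = lane & 15; base_row = accumulator ?
    * lane/16 : 0; row/col are swapped for row-major layouts; the column
    * index is scaled by the user-provided stride. */
   unsigned cmat_elem_offset(unsigned lane, unsigned i, unsigned stride,
                             bool accumulator, bool row_major,
                             unsigned wave_size)
   {
      unsigned lanes_per_iter = accumulator ? wave_size : 16;
      unsigned col = lane & 15;
      unsigned row = (accumulator ? lane / 16 : 0) + i * lanes_per_iter / 16;
      if (row_major) {
         unsigned tmp = col;
         col = row;
         row = tmp;
      }
      return col * stride + row;
   }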
+            case nir_intrinsic_cmat_store: {
+               enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
+
+               nir_deref_instr *deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
+               nir_def *src = intr->src[1].ssa;
+               nir_def *stride = intr->src[2].ssa;
+
+               nir_deref_instr *src_deref = nir_instr_as_deref(src->parent_instr);
+               struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
+               src = radv_nir_load_cmat(&b, wave_size, src);
+
+               nir_def *local_idx = nir_load_subgroup_invocation(&b);
+
+               if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
+                  nir_push_if(&b, nir_ilt_imm(&b, local_idx, 16));
+
+               nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);
+
+               /* The A input is transposed. */
+               if (desc.use == GLSL_CMAT_USE_A)
+                  layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
+                                                                     : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
+
+               unsigned length = radv_nir_cmat_length(desc, wave_size);
+               unsigned mul = radv_nir_cmat_length_mul(desc);
+               unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? wave_size : 16;
+               nir_def *vars[16];
+               for (unsigned i = 0; i < length; ++i)
+                  vars[i] = nir_channel(&b, src, i);
+
+               unsigned idx_bits = deref->def.bit_size;
+               nir_def *base_row =
+                  desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(&b, local_idx, 16) : nir_imm_int(&b, 0);
+
+               for (unsigned i = 0; i < length / mul; ++i) {
+                  nir_def *col_offset = inner_idx;
+                  nir_def *row_offset = nir_iadd_imm(&b, base_row, i * lanes_per_iter / 16);
+
+                  if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
+                     nir_def *tmp = col_offset;
+                     col_offset = row_offset;
+                     row_offset = tmp;
+                  }
+
+                  col_offset = nir_imul(&b, col_offset, stride);
+
+                  col_offset = nir_u2uN(&b, col_offset, idx_bits);
+                  row_offset = nir_u2uN(&b, row_offset, idx_bits);
+
+                  nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
+                  iter_deref =
+                     nir_build_deref_cast(&b, &iter_deref->def, deref->modes, glsl_scalar_type(desc.element_type),
+                                          glsl_base_type_bit_size(desc.element_type) / 8);
+                  iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);
+
+                  nir_store_deref(&b, iter_deref, vars[i * mul], 1);
+               }
+
+               if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
+                  nir_pop_if(&b, NULL);
+
+               nir_instr_remove(instr);
+               progress = true;
+               break;
+            }
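
Note the predication in the store path: A/B matrices only carry data in the
first 16 lanes, so the nir_push_if/nir_pop_if pair above skips the upper
lanes entirely, while accumulators use every lane of the wave and store
unconditionally.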
+            case nir_intrinsic_cmat_muladd: {
+               nir_def *A = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
+               nir_def *B = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
+               nir_def *C = radv_nir_load_cmat(&b, wave_size, intr->src[3].ssa);
+               nir_def *ret;
+
+               ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr),
+                                         .cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr));
+
+               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff);
+               nir_instr_remove(instr);
+               progress = true;
+               break;
+            }
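
The multiply itself is not expanded here: the three operands are loaded as
plain vectors and handed to the nir_cmat_muladd_amd intrinsic together with
the saturate and signedness flags, leaving the actual matrix arithmetic to
later stages.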
+            case nir_intrinsic_cmat_unary_op: {
+               nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
+               nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
+               struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
+               struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);
+               nir_def *src = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
+               nir_op op = nir_intrinsic_alu_op(intr);
+
+               if (glsl_base_type_bit_size(src_desc.element_type) == 16 &&
+                   glsl_base_type_bit_size(desc.element_type) == 32 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
+                  nir_def *components[NIR_MAX_VEC_COMPONENTS];
+                  for (unsigned i = 0; i * 2 < src->num_components; ++i) {
+                     components[i] = nir_channel(&b, src, i * 2);
+                  }
+                  src = nir_vec(&b, components, src->num_components / 2);
+               }
+
+               nir_def *ret = nir_build_alu1(&b, op, src);
+
+               if (glsl_base_type_bit_size(src_desc.element_type) == 32 &&
+                   glsl_base_type_bit_size(desc.element_type) == 16 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
+                  nir_def *components[NIR_MAX_VEC_COMPONENTS];
+                  for (unsigned i = 0; i < ret->num_components; ++i) {
+                     components[i * 2] = nir_channel(&b, ret, i);
+                     components[i * 2 + 1] = nir_undef(&b, 1, 16);
+                  }
+                  ret = nir_vec(&b, components, ret->num_components * 2);
+               }
+
+               nir_store_deref(&b, dst_deref, ret, 0xffff);
+               nir_instr_remove(instr);
+               progress = true;
+               break;
+            }
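
The two shuffles around the ALU op exist because 16-bit accumulators only
populate every second component (length_mul == 2): converting from an fp16
to an fp32 accumulator first compacts the source by taking channels
0, 2, ..., 14, and the opposite direction re-expands the results with undefs
in the odd slots.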
+            case nir_intrinsic_cmat_scalar_op: {
+               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
+               nir_op op = nir_intrinsic_alu_op(intr);
+               nir_def *ret = nir_build_alu2(&b, op, src1, intr->src[2].ssa);
+               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff);
+               nir_instr_remove(instr);
+               progress = true;
+               break;
+            }
+            case nir_intrinsic_cmat_binary_op: {
+               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
+               nir_def *src2 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa);
+               nir_op op = nir_intrinsic_alu_op(intr);
+               nir_def *ret = nir_build_alu2(&b, op, src1, src2);
+               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff);
+               nir_instr_remove(instr);
+               progress = true;
+               break;
+            }
+            case nir_intrinsic_cmat_bitcast: {
+               nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa);
+               nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1, 0xffff);
+               nir_instr_remove(instr);
+               progress = true;
+               break;
+            }
+            case nir_intrinsic_cmat_copy: {
+               nir_build_copy_deref(&b, intr->src[0].ssa, intr->src[1].ssa);
+               nir_instr_remove(instr);
+               progress = true;
+               break;
+            }
+            default:
+               continue;
+            }
+            break;
+         }
+         case nir_instr_type_deref: {
+            nir_deref_instr *deref = nir_instr_as_deref(instr);
+            const struct glsl_type *new_type = radv_nir_translate_matrix_type(deref->type, type_map, wave_size);
+            if (new_type != deref->type) {
+               deref->type = new_type;
+               progress = true;
+            }
+            break;
+         }
+         default:
+            continue;
+         }
+      }
+   }
+
+   _mesa_hash_table_destroy(type_map, NULL);
+
+   if (progress) {
+      nir_metadata_preserve(func->impl, 0);
+   } else {
+      nir_metadata_preserve(func->impl, nir_metadata_all);
+   }
+
+   return progress;
+}
diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index df770f6a267..f1a980617f8 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -447,6 +447,7 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_shader_st
                .vk_memory_model_device_scope = true,
                .fragment_shading_rate = device->physical_device->rad_info.gfx_level >= GFX10_3,
                .workgroup_memory_explicit_layout = true,
+               .cooperative_matrix = true,
             },
          .ubo_addr_format = nir_address_format_vec2_index_32bit_offset,
          .ssbo_addr_format = nir_address_format_vec2_index_32bit_offset,
@@ -504,6 +505,8 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_shader_st
        */
       NIR_PASS(_, nir, nir_lower_variable_initializers, ~0);
 
+      NIR_PASS(_, nir, radv_nir_lower_cooperative_matrix, subgroup_size);
+
      /* Split member structs.  We do this before lower_io_to_temporaries so that
        * it doesn't lower system values to temporaries by accident.
        */
@@ -533,6 +536,7 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_shader_st
        * than it needs to be.
        */
       NIR_PASS(_, nir, nir_lower_global_vars_to_local);
+
       NIR_PASS(_, nir, nir_lower_vars_to_ssa);
 
       NIR_PASS(_, nir, nir_propagate_invariant, key->invariant_geom);
