Module: Mesa
Branch: main
Commit: c5abb7c8d1556c2af3d5e2f498f5a7460aeecea9
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=c5abb7c8d1556c2af3d5e2f498f5a7460aeecea9

Author: Karol Herbst <[email protected]>
Date:   Tue Sep 19 14:44:26 2023 +0200

zink: variable shared mem support

Signed-off-by: Karol Herbst <[email protected]>
Reviewed-by: Mike Blumenkrantz <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24839>

---

 .../drivers/zink/nir_to_spirv/nir_to_spirv.c       | 29 ++++++++++++++++++----
 src/gallium/drivers/zink/zink_compiler.h           |  1 +
 src/gallium/drivers/zink/zink_pipeline.c           | 14 +++++++++--
 src/gallium/drivers/zink/zink_program.c            |  4 +++
 src/gallium/drivers/zink/zink_types.h              |  1 +
 5 files changed, 42 insertions(+), 7 deletions(-)

diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
index 37059f45807..d1f26cbe158 100644
--- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
+++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
@@ -101,6 +101,8 @@ struct ntv_context {
          local_group_size_var,
          base_vertex_var, base_instance_var, draw_id_var;
 
+   SpvId shared_mem_size;
+
    SpvId subgroup_eq_mask_var,
          subgroup_ge_mask_var,
          subgroup_gt_mask_var,
@@ -663,13 +665,25 @@ get_scratch_block(struct ntv_context *ctx, unsigned bit_size)
 }
 
 static void
-create_shared_block(struct ntv_context *ctx, unsigned shared_size, unsigned bit_size)
+create_shared_block(struct ntv_context *ctx, unsigned bit_size)
 {
    unsigned idx = bit_size >> 4;
    SpvId type = spirv_builder_type_uint(&ctx->builder, bit_size);
-   unsigned block_size = shared_size / (bit_size / 8);
-   assert(block_size);
-   SpvId array = spirv_builder_type_array(&ctx->builder, type, emit_uint_const(ctx, 32, block_size));
+   SpvId array;
+
+   assert(gl_shader_stage_is_compute(ctx->nir->info.stage));
+   if (ctx->nir->info.cs.has_variable_shared_mem) {
+      assert(ctx->shared_mem_size);
+      SpvId const_shared_size = emit_uint_const(ctx, 32, ctx->nir->info.shared_size);
+      SpvId shared_mem_size = spirv_builder_emit_triop(&ctx->builder, SpvOpSpecConstantOp, spirv_builder_type_uint(&ctx->builder, 32), SpvOpIAdd, const_shared_size, ctx->shared_mem_size);
+      shared_mem_size = spirv_builder_emit_triop(&ctx->builder, SpvOpSpecConstantOp, spirv_builder_type_uint(&ctx->builder, 32), SpvOpUDiv, shared_mem_size, emit_uint_const(ctx, 32, bit_size / 8));
+      array = spirv_builder_type_array(&ctx->builder, type, shared_mem_size);
+   } else {
+      unsigned block_size = ctx->nir->info.shared_size / (bit_size / 8);
+      assert(block_size);
+      array = spirv_builder_type_array(&ctx->builder, type, emit_uint_const(ctx, 32, block_size));
+   }
+
    spirv_builder_emit_array_stride(&ctx->builder, array, bit_size / 8);
    SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
                                                SpvStorageClassWorkgroup,
@@ -686,7 +700,7 @@ get_shared_block(struct ntv_context *ctx, unsigned bit_size)
 {
    unsigned idx = bit_size >> 4;
    if (!ctx->shared_block_var[idx])
-      create_shared_block(ctx, ctx->nir->info.shared_size, bit_size);
+      create_shared_block(ctx, bit_size);
    if (ctx->sinfo->have_workgroup_memory_explicit_layout) {
       spirv_builder_emit_extension(&ctx->builder, "SPV_KHR_workgroup_memory_explicit_layout");
       spirv_builder_emit_cap(&ctx->builder, SpvCapabilityWorkgroupMemoryExplicitLayoutKHR);
@@ -4591,6 +4605,11 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_
             spirv_builder_emit_builtin(&ctx.builder, ctx.local_group_size_var, SpvBuiltInWorkgroupSize);
          }
       }
+      if (s->info.cs.has_variable_shared_mem) {
+         ctx.shared_mem_size = spirv_builder_spec_const_uint(&ctx.builder, 32);
+         spirv_builder_emit_specid(&ctx.builder, ctx.shared_mem_size, ZINK_VARIABLE_SHARED_MEM);
+         spirv_builder_emit_name(&ctx.builder, ctx.shared_mem_size, "variable_shared_mem");
+      }
       if (s->info.cs.derivative_group) {
          SpvCapability caps[] = { 0, SpvCapabilityComputeDerivativeGroupQuadsNV, SpvCapabilityComputeDerivativeGroupLinearNV };
          SpvExecutionMode modes[] = { 0, SpvExecutionModeDerivativeGroupQuadsNV, SpvExecutionModeDerivativeGroupLinearNV };
diff --git a/src/gallium/drivers/zink/zink_compiler.h b/src/gallium/drivers/zink/zink_compiler.h
index 834084dbc51..1319193f83c 100644
--- a/src/gallium/drivers/zink/zink_compiler.h
+++ b/src/gallium/drivers/zink/zink_compiler.h
@@ -29,6 +29,7 @@
 #define ZINK_WORKGROUP_SIZE_X 1
 #define ZINK_WORKGROUP_SIZE_Y 2
 #define ZINK_WORKGROUP_SIZE_Z 3
+#define ZINK_VARIABLE_SHARED_MEM 4
 #define ZINK_INLINE_VAL_FLAT_MASK 0
 #define ZINK_INLINE_VAL_PV_LAST_VERT 1
 
diff --git a/src/gallium/drivers/zink/zink_pipeline.c b/src/gallium/drivers/zink/zink_pipeline.c
index cceb54c4a2c..063fdd9d73f 100644
--- a/src/gallium/drivers/zink/zink_pipeline.c
+++ b/src/gallium/drivers/zink/zink_pipeline.c
@@ -457,8 +457,8 @@ zink_create_compute_pipeline(struct zink_screen *screen, struct zink_compute_pro
    stage.pName = "main";
 
    VkSpecializationInfo sinfo = {0};
-   VkSpecializationMapEntry me[3];
-   uint32_t data[3];
+   VkSpecializationMapEntry me[4];
+   uint32_t data[4];
    if (state)  {
       int i = 0;
 
@@ -475,6 +475,16 @@ zink_create_compute_pipeline(struct zink_screen *screen, struct zink_compute_pro
          }
       }
 
+      if (comp->has_variable_shared_mem) {
+         sinfo.mapEntryCount += 1;
+         sinfo.dataSize += sizeof(uint32_t);
+         data[i] = state->variable_shared_mem;
+         me[i].size = sizeof(uint32_t);
+         me[i].constantID = ZINK_VARIABLE_SHARED_MEM;
+         me[i].offset = i * sizeof(uint32_t);
+         i++;
+      }
+
       if (sinfo.dataSize) {
          stage.pSpecializationInfo = &sinfo;
          sinfo.pData = data;
diff --git a/src/gallium/drivers/zink/zink_program.c b/src/gallium/drivers/zink/zink_program.c
index 96835868502..796cc258a4f 100644
--- a/src/gallium/drivers/zink/zink_program.c
+++ b/src/gallium/drivers/zink/zink_program.c
@@ -1304,6 +1304,10 @@ zink_program_update_compute_pipeline_state(struct zink_context *ctx, struct zink
          ctx->compute_pipeline_state.local_size[i] = info->block[i];
       }
    }
+   if (ctx->compute_pipeline_state.variable_shared_mem != info->variable_shared_mem) {
+      ctx->compute_pipeline_state.dirty = true;
+      ctx->compute_pipeline_state.variable_shared_mem = info->variable_shared_mem;
+   }
 }
 
 static bool
diff --git a/src/gallium/drivers/zink/zink_types.h b/src/gallium/drivers/zink/zink_types.h
index 0dbb4d3c545..7a2c79ccf0b 100644
--- a/src/gallium/drivers/zink/zink_types.h
+++ b/src/gallium/drivers/zink/zink_types.h
@@ -937,6 +937,7 @@ struct zink_compute_pipeline_state {
    uint32_t final_hash;
    bool dirty;
    uint32_t local_size[3];
+   uint32_t variable_shared_mem;
 
    uint32_t module_hash;
    VkShaderModule module;

Reply via email to