Module: Mesa Branch: main Commit: c5abb7c8d1556c2af3d5e2f498f5a7460aeecea9 URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=c5abb7c8d1556c2af3d5e2f498f5a7460aeecea9
Author: Karol Herbst <[email protected]> Date: Tue Sep 19 14:44:26 2023 +0200 zink: variable shared mem support Signed-off-by: Karol Herbst <[email protected]> Reviewed-by: Mike Blumenkrantz <[email protected]> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24839> --- .../drivers/zink/nir_to_spirv/nir_to_spirv.c | 29 ++++++++++++++++++---- src/gallium/drivers/zink/zink_compiler.h | 1 + src/gallium/drivers/zink/zink_pipeline.c | 14 +++++++++-- src/gallium/drivers/zink/zink_program.c | 4 +++ src/gallium/drivers/zink/zink_types.h | 1 + 5 files changed, 42 insertions(+), 7 deletions(-) diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index 37059f45807..d1f26cbe158 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -101,6 +101,8 @@ struct ntv_context { local_group_size_var, base_vertex_var, base_instance_var, draw_id_var; + SpvId shared_mem_size; + SpvId subgroup_eq_mask_var, subgroup_ge_mask_var, subgroup_gt_mask_var, @@ -663,13 +665,25 @@ get_scratch_block(struct ntv_context *ctx, unsigned bit_size) } static void -create_shared_block(struct ntv_context *ctx, unsigned shared_size, unsigned bit_size) +create_shared_block(struct ntv_context *ctx, unsigned bit_size) { unsigned idx = bit_size >> 4; SpvId type = spirv_builder_type_uint(&ctx->builder, bit_size); - unsigned block_size = shared_size / (bit_size / 8); - assert(block_size); - SpvId array = spirv_builder_type_array(&ctx->builder, type, emit_uint_const(ctx, 32, block_size)); + SpvId array; + + assert(gl_shader_stage_is_compute(ctx->nir->info.stage)); + if (ctx->nir->info.cs.has_variable_shared_mem) { + assert(ctx->shared_mem_size); + SpvId const_shared_size = emit_uint_const(ctx, 32, ctx->nir->info.shared_size); + SpvId shared_mem_size = spirv_builder_emit_triop(&ctx->builder, SpvOpSpecConstantOp, spirv_builder_type_uint(&ctx->builder, 32), SpvOpIAdd, const_shared_size, ctx->shared_mem_size); + shared_mem_size = spirv_builder_emit_triop(&ctx->builder, SpvOpSpecConstantOp, spirv_builder_type_uint(&ctx->builder, 32), SpvOpUDiv, shared_mem_size, emit_uint_const(ctx, 32, bit_size / 8)); + array = spirv_builder_type_array(&ctx->builder, type, shared_mem_size); + } else { + unsigned block_size = ctx->nir->info.shared_size / (bit_size / 8); + assert(block_size); + array = spirv_builder_type_array(&ctx->builder, type, emit_uint_const(ctx, 32, block_size)); + } + spirv_builder_emit_array_stride(&ctx->builder, array, bit_size / 8); SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassWorkgroup, @@ -686,7 +700,7 @@ get_shared_block(struct ntv_context *ctx, unsigned bit_size) { unsigned idx = bit_size >> 4; if (!ctx->shared_block_var[idx]) - create_shared_block(ctx, ctx->nir->info.shared_size, bit_size); + create_shared_block(ctx, bit_size); if (ctx->sinfo->have_workgroup_memory_explicit_layout) { spirv_builder_emit_extension(&ctx->builder, "SPV_KHR_workgroup_memory_explicit_layout"); spirv_builder_emit_cap(&ctx->builder, SpvCapabilityWorkgroupMemoryExplicitLayoutKHR); @@ -4591,6 +4605,11 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_ spirv_builder_emit_builtin(&ctx.builder, ctx.local_group_size_var, SpvBuiltInWorkgroupSize); } } + if (s->info.cs.has_variable_shared_mem) { + ctx.shared_mem_size = spirv_builder_spec_const_uint(&ctx.builder, 32); + spirv_builder_emit_specid(&ctx.builder, ctx.shared_mem_size, ZINK_VARIABLE_SHARED_MEM); + spirv_builder_emit_name(&ctx.builder, ctx.shared_mem_size, "variable_shared_mem"); + } if (s->info.cs.derivative_group) { SpvCapability caps[] = { 0, SpvCapabilityComputeDerivativeGroupQuadsNV, SpvCapabilityComputeDerivativeGroupLinearNV }; SpvExecutionMode modes[] = { 0, SpvExecutionModeDerivativeGroupQuadsNV, SpvExecutionModeDerivativeGroupLinearNV }; diff --git a/src/gallium/drivers/zink/zink_compiler.h b/src/gallium/drivers/zink/zink_compiler.h index 834084dbc51..1319193f83c 100644 --- a/src/gallium/drivers/zink/zink_compiler.h +++ b/src/gallium/drivers/zink/zink_compiler.h @@ -29,6 +29,7 @@ #define ZINK_WORKGROUP_SIZE_X 1 #define ZINK_WORKGROUP_SIZE_Y 2 #define ZINK_WORKGROUP_SIZE_Z 3 +#define ZINK_VARIABLE_SHARED_MEM 4 #define ZINK_INLINE_VAL_FLAT_MASK 0 #define ZINK_INLINE_VAL_PV_LAST_VERT 1 diff --git a/src/gallium/drivers/zink/zink_pipeline.c b/src/gallium/drivers/zink/zink_pipeline.c index cceb54c4a2c..063fdd9d73f 100644 --- a/src/gallium/drivers/zink/zink_pipeline.c +++ b/src/gallium/drivers/zink/zink_pipeline.c @@ -457,8 +457,8 @@ zink_create_compute_pipeline(struct zink_screen *screen, struct zink_compute_pro stage.pName = "main"; VkSpecializationInfo sinfo = {0}; - VkSpecializationMapEntry me[3]; - uint32_t data[3]; + VkSpecializationMapEntry me[4]; + uint32_t data[4]; if (state) { int i = 0; @@ -475,6 +475,16 @@ zink_create_compute_pipeline(struct zink_screen *screen, struct zink_compute_pro } } + if (comp->has_variable_shared_mem) { + sinfo.mapEntryCount += 1; + sinfo.dataSize += sizeof(uint32_t); + data[i] = state->variable_shared_mem; + me[i].size = sizeof(uint32_t); + me[i].constantID = ZINK_VARIABLE_SHARED_MEM; + me[i].offset = i * sizeof(uint32_t); + i++; + } + if (sinfo.dataSize) { stage.pSpecializationInfo = &sinfo; sinfo.pData = data; diff --git a/src/gallium/drivers/zink/zink_program.c b/src/gallium/drivers/zink/zink_program.c index 96835868502..796cc258a4f 100644 --- a/src/gallium/drivers/zink/zink_program.c +++ b/src/gallium/drivers/zink/zink_program.c @@ -1304,6 +1304,10 @@ zink_program_update_compute_pipeline_state(struct zink_context *ctx, struct zink ctx->compute_pipeline_state.local_size[i] = info->block[i]; } } + if (ctx->compute_pipeline_state.variable_shared_mem != info->variable_shared_mem) { + ctx->compute_pipeline_state.dirty = true; + ctx->compute_pipeline_state.variable_shared_mem = info->variable_shared_mem; + } } static bool diff --git a/src/gallium/drivers/zink/zink_types.h b/src/gallium/drivers/zink/zink_types.h index 0dbb4d3c545..7a2c79ccf0b 100644 --- a/src/gallium/drivers/zink/zink_types.h +++ b/src/gallium/drivers/zink/zink_types.h @@ -937,6 +937,7 @@ struct zink_compute_pipeline_state { uint32_t final_hash; bool dirty; uint32_t local_size[3]; + uint32_t variable_shared_mem; uint32_t module_hash; VkShaderModule module;
