Module: Mesa
Branch: main
Commit: fe674f67b1b77c559e76aa6180d525a7609527f6
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=fe674f67b1b77c559e76aa6180d525a7609527f6
Author: Konstantin Seurer <[email protected]>
Date:   Wed Sep  6 15:23:01 2023 +0200

radv/rt: Use a helper for inlining non-recursive stages

So we don't have to write the same logic multiple times.

Reviewed-by: Friedrich Vock <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25089>

---
 src/amd/vulkan/radv_rt_shader.c | 339 +++++++++++++++++++++-------------------
 1 file changed, 176 insertions(+), 163 deletions(-)

diff --git a/src/amd/vulkan/radv_rt_shader.c b/src/amd/vulkan/radv_rt_shader.c
index dbc29d8573e..0667aa896ad 100644
--- a/src/amd/vulkan/radv_rt_shader.c
+++ b/src/amd/vulkan/radv_rt_shader.c
@@ -38,6 +38,63 @@ * performance.
  */
 #define MAX_STACK_ENTRY_COUNT 16
 
+struct radv_rt_case_data {
+   struct radv_device *device;
+   struct radv_ray_tracing_pipeline *pipeline;
+   struct rt_variables *vars;
+};
+
+typedef void (*radv_get_group_info)(struct radv_ray_tracing_group *, uint32_t *, uint32_t *,
+                                    struct radv_rt_case_data *);
+typedef void (*radv_insert_shader_case)(nir_builder *, nir_def *, struct radv_ray_tracing_group *,
+                                        struct radv_rt_case_data *);
+
+static void
+radv_visit_inlined_shaders(nir_builder *b, nir_def *sbt_idx, bool can_have_null_shaders, struct radv_rt_case_data *data,
+                           radv_get_group_info group_info, radv_insert_shader_case shader_case)
+{
+   struct radv_ray_tracing_group **groups =
+      calloc(data->pipeline->group_count, sizeof(struct radv_ray_tracing_group *));
+   uint32_t case_count = 0;
+
+   for (unsigned i = 0; i < data->pipeline->group_count; i++) {
+      struct radv_ray_tracing_group *group = &data->pipeline->groups[i];
+
+      uint32_t shader_index = VK_SHADER_UNUSED_KHR;
+      uint32_t handle_index = VK_SHADER_UNUSED_KHR;
+      group_info(group, &shader_index, &handle_index, data);
+      if (shader_index == VK_SHADER_UNUSED_KHR)
+         continue;
+
+      /* Avoid emitting stages with the same shaders/handles multiple times. */
+      bool duplicate = false;
+      for (unsigned j = 0; j < i; j++) {
+         uint32_t other_shader_index = VK_SHADER_UNUSED_KHR;
+         uint32_t other_handle_index = VK_SHADER_UNUSED_KHR;
+         group_info(&data->pipeline->groups[j], &other_shader_index, &other_handle_index, data);
+
+         if (handle_index == other_handle_index) {
+            duplicate = true;
+            break;
+         }
+      }
+
+      if (!duplicate)
+         groups[case_count++] = group;
+   }
+
+   if (can_have_null_shaders)
+      nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
+
+   for (unsigned i = 0; i < case_count; i++)
+      shader_case(b, sbt_idx, groups[i], data);
+
+   if (can_have_null_shaders)
+      nir_pop_if(b, NULL);
+
+   free(groups);
+}
+
 static bool
 lower_rt_derefs(nir_shader *shader)
 {
@@ -996,46 +1053,95 @@ struct traversal_data {
 };
 
 static void
-visit_any_hit_shaders(struct radv_device *device, nir_builder *b, struct traversal_data *data,
-                      struct rt_variables *vars)
+radv_ray_tracing_group_ahit_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
+                                 struct radv_rt_case_data *data)
 {
-   nir_def *sbt_idx = nir_load_var(b, vars->idx);
+   if (group->type == VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR) {
+      *shader_index = group->any_hit_shader;
+      *handle_index = group->handle.any_hit_index;
+   }
+}
 
-   if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR))
-      nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
+static void
+radv_build_ahit_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_group *group,
+                     struct radv_rt_case_data *data)
+{
+   nir_shader *nir_stage =
+      radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->any_hit_shader].nir);
+   assert(nir_stage);
 
-   for (unsigned i = 0; i < data->pipeline->group_count; ++i) {
-      struct radv_ray_tracing_group *group = &data->pipeline->groups[i];
-      uint32_t shader_id = VK_SHADER_UNUSED_KHR;
+   insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.any_hit_index);
+   ralloc_free(nir_stage);
+}
 
-      switch (group->type) {
-      case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
-         shader_id = group->any_hit_shader;
-         break;
-      default:
-         break;
-      }
-      if (shader_id == VK_SHADER_UNUSED_KHR)
-         continue;
+static void
+radv_ray_tracing_group_isec_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
+                                 struct radv_rt_case_data *data)
+{
+   if (group->type == VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR) {
+      *shader_index = group->intersection_shader;
+      *handle_index = group->handle.intersection_index;
+   }
+}
 
-      /* Avoid emitting stages with the same shaders/handles multiple times. */
-      bool is_dup = false;
-      for (unsigned j = 0; j < i; ++j)
-         if (data->pipeline->groups[j].handle.any_hit_index == data->pipeline->groups[i].handle.any_hit_index)
-            is_dup = true;
+static void
+radv_build_isec_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_group *group,
+                     struct radv_rt_case_data *data)
+{
+   nir_shader *nir_stage =
+      radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->intersection_shader].nir);
+   assert(nir_stage);
 
-      if (is_dup)
-         continue;
+   nir_shader *any_hit_stage = NULL;
+   if (group->any_hit_shader != VK_SHADER_UNUSED_KHR) {
+      any_hit_stage =
+         radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->any_hit_shader].nir);
+      assert(any_hit_stage);
 
-      nir_shader *nir_stage = radv_pipeline_cache_handle_to_nir(device, data->pipeline->stages[shader_id].nir);
-      assert(nir_stage);
+      /* reserve stack size for any_hit before it is inlined */
+      data->pipeline->stages[group->any_hit_shader].stack_size = any_hit_stage->scratch_size;
 
-      insert_rt_case(b, nir_stage, vars, sbt_idx, data->pipeline->groups[i].handle.any_hit_index);
-      ralloc_free(nir_stage);
+      nir_lower_intersection_shader(nir_stage, any_hit_stage);
+      ralloc_free(any_hit_stage);
    }
 
-   if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR))
-      nir_pop_if(b, NULL);
+   insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.intersection_index);
+   ralloc_free(nir_stage);
+}
+
+static void
+radv_ray_tracing_group_chit_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
+                                 struct radv_rt_case_data *data)
+{
+   if (group->type != VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR) {
+      *shader_index = group->recursive_shader;
+      *handle_index = group->handle.closest_hit_index;
+   }
+}
+
+static void
+radv_ray_tracing_group_miss_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
+                                 struct radv_rt_case_data *data)
+{
+   if (group->type == VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR) {
+      if (data->pipeline->stages[group->recursive_shader].stage != MESA_SHADER_MISS)
+         return;
+
+      *shader_index = group->recursive_shader;
+      *handle_index = group->handle.general_index;
+   }
+}
+
+static void
+radv_build_recursive_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_group *group,
+                          struct radv_rt_case_data *data)
+{
+   nir_shader *nir_stage =
+      radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->recursive_shader].nir);
+   assert(nir_stage);
+
+   insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.general_index);
+   ralloc_free(nir_stage);
 }
 
 static void
@@ -1071,7 +1177,16 @@ handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *int
 
    load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, SBT_ANY_HIT_IDX);
 
-   visit_any_hit_shaders(data->device, b, args->data, &inner_vars);
+   struct radv_rt_case_data case_data = {
+      .device = data->device,
+      .pipeline = data->pipeline,
+      .vars = &inner_vars,
+   };
+
+   radv_visit_inlined_shaders(
+      b, nir_load_var(b, inner_vars.idx),
+      !(data->vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR), &case_data,
+      radv_ray_tracing_group_ahit_info, radv_build_ahit_case);
 
    nir_push_if(b, nir_inot(b, nir_load_var(b, data->vars->ahit_accept)));
    {
@@ -1129,56 +1244,16 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersectio
    nir_store_var(b, data->vars->ahit_accept, nir_imm_false(b), 0x1);
    nir_store_var(b, data->vars->ahit_terminate, nir_imm_false(b), 0x1);
 
-   if (!(data->vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_INTERSECTION_SHADERS_BIT_KHR))
-      nir_push_if(b, nir_ine_imm(b, nir_load_var(b, inner_vars.idx), 0));
-
-   for (unsigned i = 0; i < data->pipeline->group_count; ++i) {
-      struct radv_ray_tracing_group *group = &data->pipeline->groups[i];
-      uint32_t shader_id = VK_SHADER_UNUSED_KHR;
-      uint32_t any_hit_shader_id = VK_SHADER_UNUSED_KHR;
-
-      switch (group->type) {
-      case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
-         shader_id = group->intersection_shader;
-         any_hit_shader_id = group->any_hit_shader;
-         break;
-      default:
-         break;
-      }
-      if (shader_id == VK_SHADER_UNUSED_KHR)
-         continue;
-
-      /* Avoid emitting stages with the same shaders/handles multiple times. */
-      bool is_dup = false;
-      for (unsigned j = 0; j < i; ++j)
-         if (data->pipeline->groups[j].handle.intersection_index == data->pipeline->groups[i].handle.intersection_index)
-            is_dup = true;
-
-      if (is_dup)
-         continue;
-
-      nir_shader *nir_stage = radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[shader_id].nir);
-      assert(nir_stage);
-
-      nir_shader *any_hit_stage = NULL;
-      if (any_hit_shader_id != VK_SHADER_UNUSED_KHR) {
-         any_hit_stage = radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[any_hit_shader_id].nir);
-         assert(any_hit_stage);
-
-         /* reserve stack size for any_hit before it is inlined */
-         data->pipeline->stages[any_hit_shader_id].stack_size = any_hit_stage->scratch_size;
-
-         nir_lower_intersection_shader(nir_stage, any_hit_stage);
-         ralloc_free(any_hit_stage);
-      }
-
-      insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx),
-                     data->pipeline->groups[i].handle.intersection_index);
-      ralloc_free(nir_stage);
-   }
+   struct radv_rt_case_data case_data = {
+      .device = data->device,
+      .pipeline = data->pipeline,
+      .vars = &inner_vars,
+   };
 
-   if (!(data->vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_INTERSECTION_SHADERS_BIT_KHR))
-      nir_pop_if(b, NULL);
+   radv_visit_inlined_shaders(
+      b, nir_load_var(b, inner_vars.idx),
+      !(data->vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_INTERSECTION_SHADERS_BIT_KHR), &case_data,
+      radv_ray_tracing_group_isec_info, radv_build_isec_case);
 
    nir_push_if(b, nir_load_var(b, data->vars->ahit_accept));
    {
@@ -1201,87 +1276,6 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersectio
    nir_pop_if(b, NULL);
 }
 
-static void
-visit_closest_hit_shaders(struct radv_device *device, nir_builder *b, struct radv_ray_tracing_pipeline *pipeline,
-                          struct rt_variables *vars)
-{
-   nir_def *sbt_idx = nir_load_var(b, vars->idx);
-
-   if (!(vars->flags & VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_CLOSEST_HIT_SHADERS_BIT_KHR))
-      nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
-
-   for (unsigned i = 0; i < pipeline->group_count; ++i) {
-      struct radv_ray_tracing_group *group = &pipeline->groups[i];
-
-      unsigned shader_id = VK_SHADER_UNUSED_KHR;
-      if (group->type != VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR)
-         shader_id = group->recursive_shader;
-
-      if (shader_id == VK_SHADER_UNUSED_KHR)
-         continue;
-
-      /* Avoid emitting stages with the same shaders/handles multiple times. */
-      bool is_dup = false;
-      for (unsigned j = 0; j < i; ++j)
-         if (pipeline->groups[j].handle.closest_hit_index == pipeline->groups[i].handle.closest_hit_index)
-            is_dup = true;
-
-      if (is_dup)
-         continue;
-
-      nir_shader *nir_stage = radv_pipeline_cache_handle_to_nir(device, pipeline->stages[shader_id].nir);
-      assert(nir_stage);
-
-      insert_rt_case(b, nir_stage, vars, sbt_idx, pipeline->groups[i].handle.closest_hit_index);
-      ralloc_free(nir_stage);
-   }
-
-   if (!(vars->flags & VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_CLOSEST_HIT_SHADERS_BIT_KHR))
-      nir_pop_if(b, NULL);
-}
-
-static void
-visit_miss_shaders(struct radv_device *device, nir_builder *b, struct radv_ray_tracing_pipeline *pipeline,
-                   struct rt_variables *vars)
-{
-   nir_def *sbt_idx = nir_load_var(b, vars->idx);
-
-   if (!(vars->flags & VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR))
-      nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
-
-   for (unsigned i = 0; i < pipeline->group_count; ++i) {
-      struct radv_ray_tracing_group *group = &pipeline->groups[i];
-
-      unsigned shader_id = VK_SHADER_UNUSED_KHR;
-      if (group->type == VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR)
-         shader_id = group->recursive_shader;
-
-      if (shader_id == VK_SHADER_UNUSED_KHR)
-         continue;
-
-      if (pipeline->stages[shader_id].stage != MESA_SHADER_MISS)
-         continue;
-
-      /* Avoid emitting stages with the same shaders/handles multiple times. */
-      bool is_dup = false;
-      for (unsigned j = 0; j < i; ++j)
-         if (pipeline->groups[j].handle.general_index == pipeline->groups[i].handle.general_index)
-            is_dup = true;
-
-      if (is_dup)
-         continue;
-
-      nir_shader *nir_stage = radv_pipeline_cache_handle_to_nir(device, pipeline->stages[shader_id].nir);
-      assert(nir_stage);
-
-      insert_rt_case(b, nir_stage, vars, sbt_idx, pipeline->groups[i].handle.general_index);
-      ralloc_free(nir_stage);
-   }
-
-   if (!(vars->flags & VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR))
-      nir_pop_if(b, NULL);
-}
-
 static void
 store_stack_entry(nir_builder *b, nir_def *index, nir_def *value, const struct radv_ray_traversal_args *args)
 {
@@ -1410,7 +1404,18 @@ radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipelin
       /* should_return is set if we had a hit but we won't be calling the closest hit
        * shader and hence need to return immediately to the calling shader. */
       nir_push_if(b, nir_inot(b, should_return));
-      visit_closest_hit_shaders(device, b, pipeline, vars);
+
+      struct radv_rt_case_data case_data = {
+         .device = device,
+         .pipeline = pipeline,
+         .vars = vars,
+      };
+
+      radv_visit_inlined_shaders(
+         b, nir_load_var(b, vars->idx),
+         !(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_CLOSEST_HIT_SHADERS_BIT_KHR), &case_data,
+         radv_ray_tracing_group_chit_info, radv_build_recursive_case);
+
      nir_pop_if(b, NULL);
    } else {
      for (int i = 0; i < ARRAY_SIZE(hit_attribs); ++i)
@@ -1425,7 +1430,15 @@ radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipelin
   if (monolithic) {
      load_sbt_entry(b, vars, nir_load_var(b, vars->miss_index), SBT_MISS, SBT_GENERAL_IDX);
 
-      visit_miss_shaders(device, b, pipeline, vars);
+      struct radv_rt_case_data case_data = {
+         .device = device,
+         .pipeline = pipeline,
+         .vars = vars,
+      };
+
+      radv_visit_inlined_shaders(b, nir_load_var(b, vars->idx),
+                                 !(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR),
+                                 &case_data, radv_ray_tracing_group_miss_info, radv_build_recursive_case);
   } else {
      /* Only load the miss shader if we actually miss. It is valid to not specify an SBT pointer
       * for miss shaders if none of the rays miss. */
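
The core of the patch is the visitor at the top of the diff: one loop asks a callback which (shader, handle) pair a group contributes for the stage kind being inlined, deduplicates by handle, and then hands each surviving group to a second callback that emits the switch case. The following condensed, compilable C sketch mirrors that control flow with stand-in types; example_group, example_case_data, visit_groups, and the rest are invented for illustration and are not part of radv_rt_shader.c, and the NIR building and the null-shader nir_push_if/nir_pop_if guard are omitted.

/* A condensed sketch of the pattern behind radv_visit_inlined_shaders();
 * all types and names below are stand-ins for the example, not Mesa code. */
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#define UNUSED_INDEX UINT32_MAX /* stand-in for VK_SHADER_UNUSED_KHR */

struct example_group {
   uint32_t shader_index; /* shader referenced by this group */
   uint32_t handle_index; /* SBT handle the emitted case is keyed on */
};

struct example_case_data {
   struct example_group *groups;
   uint32_t group_count;
};

/* Same shape as radv_get_group_info / radv_insert_shader_case. */
typedef void (*get_group_info)(struct example_group *, uint32_t *, uint32_t *,
                               struct example_case_data *);
typedef void (*insert_case)(struct example_group *, struct example_case_data *);

static void
visit_groups(struct example_case_data *data, get_group_info info, insert_case emit)
{
   struct example_group **cases = calloc(data->group_count, sizeof(*cases));
   uint32_t case_count = 0;

   for (uint32_t i = 0; i < data->group_count; i++) {
      uint32_t shader_index = UNUSED_INDEX, handle_index = UNUSED_INDEX;
      info(&data->groups[i], &shader_index, &handle_index, data);
      if (shader_index == UNUSED_INDEX)
         continue; /* the callback filtered out this group type */

      /* Avoid emitting the same handle twice, as in the real helper. */
      bool duplicate = false;
      for (uint32_t j = 0; j < i; j++) {
         uint32_t other_shader = UNUSED_INDEX, other_handle = UNUSED_INDEX;
         info(&data->groups[j], &other_shader, &other_handle, data);
         if (handle_index == other_handle) {
            duplicate = true;
            break;
         }
      }

      if (!duplicate)
         cases[case_count++] = &data->groups[i];
   }

   /* The real helper brackets this loop with nir_push_if/nir_pop_if when
    * null shaders are possible; here each surviving case is just emitted. */
   for (uint32_t i = 0; i < case_count; i++)
      emit(cases[i], data);

   free(cases);
}

static void
take_all_groups(struct example_group *group, uint32_t *shader_index, uint32_t *handle_index,
                struct example_case_data *data)
{
   (void)data;
   *shader_index = group->shader_index;
   *handle_index = group->handle_index;
}

static void
print_case(struct example_group *group, struct example_case_data *data)
{
   (void)data;
   printf("case for handle %u (shader %u)\n", group->handle_index, group->shader_index);
}

int
main(void)
{
   struct example_group groups[] = {
      {.shader_index = 0, .handle_index = 0},
      {.shader_index = 1, .handle_index = 1},
      {.shader_index = 0, .handle_index = 0}, /* duplicate handle: skipped */
   };
   struct example_case_data data = {.groups = groups, .group_count = 3};

   visit_groups(&data, take_all_groups, print_case); /* prints two cases */
   return 0;
}

Note that deduplication compares handle indices rather than shader indices, matching the helper in the patch: insert_rt_case() keys each case on the SBT handle, so two groups that resolve to the same handle must collapse into a single case.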
