Module: Mesa
Branch: main
Commit: fe674f67b1b77c559e76aa6180d525a7609527f6
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=fe674f67b1b77c559e76aa6180d525a7609527f6

Author: Konstantin Seurer <[email protected]>
Date:   Wed Sep  6 15:23:01 2023 +0200

radv/rt: Use a helper for inlining non-recursive stages

So we don't have to write the same logic multiple times.

Reviewed-by: Friedrich Vock <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25089>

---

 src/amd/vulkan/radv_rt_shader.c | 339 +++++++++++++++++++++-------------------
 1 file changed, 176 insertions(+), 163 deletions(-)

diff --git a/src/amd/vulkan/radv_rt_shader.c b/src/amd/vulkan/radv_rt_shader.c
index dbc29d8573e..0667aa896ad 100644
--- a/src/amd/vulkan/radv_rt_shader.c
+++ b/src/amd/vulkan/radv_rt_shader.c
@@ -38,6 +38,63 @@
  * performance. */
 #define MAX_STACK_ENTRY_COUNT 16
 
+struct radv_rt_case_data {
+   struct radv_device *device;
+   struct radv_ray_tracing_pipeline *pipeline;
+   struct rt_variables *vars;
+};
+
+typedef void (*radv_get_group_info)(struct radv_ray_tracing_group *, uint32_t *, uint32_t *,
+                                    struct radv_rt_case_data *);
+typedef void (*radv_insert_shader_case)(nir_builder *, nir_def *, struct radv_ray_tracing_group *,
+                                        struct radv_rt_case_data *);
+
+static void
+radv_visit_inlined_shaders(nir_builder *b, nir_def *sbt_idx, bool can_have_null_shaders, struct radv_rt_case_data *data,
+                           radv_get_group_info group_info, radv_insert_shader_case shader_case)
+{
+   struct radv_ray_tracing_group **groups =
+      calloc(data->pipeline->group_count, sizeof(struct radv_ray_tracing_group *));
+   uint32_t case_count = 0;
+
+   for (unsigned i = 0; i < data->pipeline->group_count; i++) {
+      struct radv_ray_tracing_group *group = &data->pipeline->groups[i];
+
+      uint32_t shader_index = VK_SHADER_UNUSED_KHR;
+      uint32_t handle_index = VK_SHADER_UNUSED_KHR;
+      group_info(group, &shader_index, &handle_index, data);
+      if (shader_index == VK_SHADER_UNUSED_KHR)
+         continue;
+
+      /* Avoid emitting stages with the same shaders/handles multiple times. */
+      bool duplicate = false;
+      for (unsigned j = 0; j < i; j++) {
+         uint32_t other_shader_index = VK_SHADER_UNUSED_KHR;
+         uint32_t other_handle_index = VK_SHADER_UNUSED_KHR;
+         group_info(&data->pipeline->groups[j], &other_shader_index, &other_handle_index, data);
+
+         if (handle_index == other_handle_index) {
+            duplicate = true;
+            break;
+         }
+      }
+
+      if (!duplicate)
+         groups[case_count++] = group;
+   }
+
+   if (can_have_null_shaders)
+      nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
+
+   for (unsigned i = 0; i < case_count; i++)
+      shader_case(b, sbt_idx, groups[i], data);
+
+   if (can_have_null_shaders)
+      nir_pop_if(b, NULL);
+
+   free(groups);
+}
+
 static bool
 lower_rt_derefs(nir_shader *shader)
 {
@@ -996,46 +1053,95 @@ struct traversal_data {
 };
 
 static void
-visit_any_hit_shaders(struct radv_device *device, nir_builder *b, struct traversal_data *data,
-                      struct rt_variables *vars)
+radv_ray_tracing_group_ahit_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
+                                 struct radv_rt_case_data *data)
 {
-   nir_def *sbt_idx = nir_load_var(b, vars->idx);
+   if (group->type == VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR) {
+      *shader_index = group->any_hit_shader;
+      *handle_index = group->handle.any_hit_index;
+   }
+}
 
-   if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR))
-      nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
+static void
+radv_build_ahit_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_group *group,
+                     struct radv_rt_case_data *data)
+{
+   nir_shader *nir_stage =
+      radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->any_hit_shader].nir);
+   assert(nir_stage);
 
-   for (unsigned i = 0; i < data->pipeline->group_count; ++i) {
-      struct radv_ray_tracing_group *group = &data->pipeline->groups[i];
-      uint32_t shader_id = VK_SHADER_UNUSED_KHR;
+   insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.any_hit_index);
+   ralloc_free(nir_stage);
+}
 
-      switch (group->type) {
-      case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
-         shader_id = group->any_hit_shader;
-         break;
-      default:
-         break;
-      }
-      if (shader_id == VK_SHADER_UNUSED_KHR)
-         continue;
+static void
+radv_ray_tracing_group_isec_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
+                                 struct radv_rt_case_data *data)
+{
+   if (group->type == VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR) {
+      *shader_index = group->intersection_shader;
+      *handle_index = group->handle.intersection_index;
+   }
+}
 
-      /* Avoid emitting stages with the same shaders/handles multiple times. */
-      bool is_dup = false;
-      for (unsigned j = 0; j < i; ++j)
-         if (data->pipeline->groups[j].handle.any_hit_index == data->pipeline->groups[i].handle.any_hit_index)
-            is_dup = true;
+static void
+radv_build_isec_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_group *group,
+                     struct radv_rt_case_data *data)
+{
+   nir_shader *nir_stage =
+      radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->intersection_shader].nir);
+   assert(nir_stage);
 
-      if (is_dup)
-         continue;
+   nir_shader *any_hit_stage = NULL;
+   if (group->any_hit_shader != VK_SHADER_UNUSED_KHR) {
+      any_hit_stage =
+         radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->any_hit_shader].nir);
+      assert(any_hit_stage);
 
-      nir_shader *nir_stage = radv_pipeline_cache_handle_to_nir(device, data->pipeline->stages[shader_id].nir);
-      assert(nir_stage);
+      /* reserve stack size for any_hit before it is inlined */
+      data->pipeline->stages[group->any_hit_shader].stack_size = any_hit_stage->scratch_size;
 
-      insert_rt_case(b, nir_stage, vars, sbt_idx, data->pipeline->groups[i].handle.any_hit_index);
-      ralloc_free(nir_stage);
+      nir_lower_intersection_shader(nir_stage, any_hit_stage);
+      ralloc_free(any_hit_stage);
    }
 
-   if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR))
-      nir_pop_if(b, NULL);
+   insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.intersection_index);
+   ralloc_free(nir_stage);
+}
+
+static void
+radv_ray_tracing_group_chit_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
+                                 struct radv_rt_case_data *data)
+{
+   if (group->type != VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR) {
+      *shader_index = group->recursive_shader;
+      *handle_index = group->handle.closest_hit_index;
+   }
+}
+
+static void
+radv_ray_tracing_group_miss_info(struct radv_ray_tracing_group *group, uint32_t *shader_index, uint32_t *handle_index,
+                                 struct radv_rt_case_data *data)
+{
+   if (group->type == VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR) {
+      if (data->pipeline->stages[group->recursive_shader].stage != MESA_SHADER_MISS)
+         return;
+
+      *shader_index = group->recursive_shader;
+      *handle_index = group->handle.general_index;
+   }
+}
+
+static void
+radv_build_recursive_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_group *group,
+                          struct radv_rt_case_data *data)
+{
+   nir_shader *nir_stage =
+      radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->recursive_shader].nir);
+   assert(nir_stage);
+
+   insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.general_index);
+   ralloc_free(nir_stage);
 }
 
 static void
@@ -1071,7 +1177,16 @@ handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *int
 
       load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, SBT_ANY_HIT_IDX);
 
-      visit_any_hit_shaders(data->device, b, args->data, &inner_vars);
+      struct radv_rt_case_data case_data = {
+         .device = data->device,
+         .pipeline = data->pipeline,
+         .vars = &inner_vars,
+      };
+
+      radv_visit_inlined_shaders(
+         b, nir_load_var(b, inner_vars.idx),
+         !(data->vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR), &case_data,
+         radv_ray_tracing_group_ahit_info, radv_build_ahit_case);
 
       nir_push_if(b, nir_inot(b, nir_load_var(b, data->vars->ahit_accept)));
       {
@@ -1129,56 +1244,16 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
    nir_store_var(b, data->vars->ahit_accept, nir_imm_false(b), 0x1);
    nir_store_var(b, data->vars->ahit_terminate, nir_imm_false(b), 0x1);
 
-   if (!(data->vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_INTERSECTION_SHADERS_BIT_KHR))
-      nir_push_if(b, nir_ine_imm(b, nir_load_var(b, inner_vars.idx), 0));
-
-   for (unsigned i = 0; i < data->pipeline->group_count; ++i) {
-      struct radv_ray_tracing_group *group = &data->pipeline->groups[i];
-      uint32_t shader_id = VK_SHADER_UNUSED_KHR;
-      uint32_t any_hit_shader_id = VK_SHADER_UNUSED_KHR;
-
-      switch (group->type) {
-      case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
-         shader_id = group->intersection_shader;
-         any_hit_shader_id = group->any_hit_shader;
-         break;
-      default:
-         break;
-      }
-      if (shader_id == VK_SHADER_UNUSED_KHR)
-         continue;
-
-      /* Avoid emitting stages with the same shaders/handles multiple times. */
-      bool is_dup = false;
-      for (unsigned j = 0; j < i; ++j)
-         if (data->pipeline->groups[j].handle.intersection_index == data->pipeline->groups[i].handle.intersection_index)
-            is_dup = true;
-
-      if (is_dup)
-         continue;
-
-      nir_shader *nir_stage = radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[shader_id].nir);
-      assert(nir_stage);
-
-      nir_shader *any_hit_stage = NULL;
-      if (any_hit_shader_id != VK_SHADER_UNUSED_KHR) {
-         any_hit_stage = radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[any_hit_shader_id].nir);
-         assert(any_hit_stage);
-
-         /* reserve stack size for any_hit before it is inlined */
-         data->pipeline->stages[any_hit_shader_id].stack_size = any_hit_stage->scratch_size;
-
-         nir_lower_intersection_shader(nir_stage, any_hit_stage);
-         ralloc_free(any_hit_stage);
-      }
-
-      insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx),
-                     data->pipeline->groups[i].handle.intersection_index);
-      ralloc_free(nir_stage);
-   }
+   struct radv_rt_case_data case_data = {
+      .device = data->device,
+      .pipeline = data->pipeline,
+      .vars = &inner_vars,
+   };
 
-   if (!(data->vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_INTERSECTION_SHADERS_BIT_KHR))
-      nir_pop_if(b, NULL);
+   radv_visit_inlined_shaders(
+      b, nir_load_var(b, inner_vars.idx),
+      !(data->vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_INTERSECTION_SHADERS_BIT_KHR), &case_data,
+      radv_ray_tracing_group_isec_info, radv_build_isec_case);
 
    nir_push_if(b, nir_load_var(b, data->vars->ahit_accept));
    {
@@ -1201,87 +1276,6 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
    nir_pop_if(b, NULL);
 }
 
-static void
-visit_closest_hit_shaders(struct radv_device *device, nir_builder *b, struct radv_ray_tracing_pipeline *pipeline,
-                          struct rt_variables *vars)
-{
-   nir_def *sbt_idx = nir_load_var(b, vars->idx);
-
-   if (!(vars->flags & VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_CLOSEST_HIT_SHADERS_BIT_KHR))
-      nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
-
-   for (unsigned i = 0; i < pipeline->group_count; ++i) {
-      struct radv_ray_tracing_group *group = &pipeline->groups[i];
-
-      unsigned shader_id = VK_SHADER_UNUSED_KHR;
-      if (group->type != VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR)
-         shader_id = group->recursive_shader;
-
-      if (shader_id == VK_SHADER_UNUSED_KHR)
-         continue;
-
-      /* Avoid emitting stages with the same shaders/handles multiple times. */
-      bool is_dup = false;
-      for (unsigned j = 0; j < i; ++j)
-         if (pipeline->groups[j].handle.closest_hit_index == pipeline->groups[i].handle.closest_hit_index)
-            is_dup = true;
-
-      if (is_dup)
-         continue;
-
-      nir_shader *nir_stage = radv_pipeline_cache_handle_to_nir(device, pipeline->stages[shader_id].nir);
-      assert(nir_stage);
-
-      insert_rt_case(b, nir_stage, vars, sbt_idx, pipeline->groups[i].handle.closest_hit_index);
-      ralloc_free(nir_stage);
-   }
-
-   if (!(vars->flags & VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_CLOSEST_HIT_SHADERS_BIT_KHR))
-      nir_pop_if(b, NULL);
-}
-
-static void
-visit_miss_shaders(struct radv_device *device, nir_builder *b, struct radv_ray_tracing_pipeline *pipeline,
-                   struct rt_variables *vars)
-{
-   nir_def *sbt_idx = nir_load_var(b, vars->idx);
-
-   if (!(vars->flags & VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR))
-      nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
-
-   for (unsigned i = 0; i < pipeline->group_count; ++i) {
-      struct radv_ray_tracing_group *group = &pipeline->groups[i];
-
-      unsigned shader_id = VK_SHADER_UNUSED_KHR;
-      if (group->type == VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR)
-         shader_id = group->recursive_shader;
-
-      if (shader_id == VK_SHADER_UNUSED_KHR)
-         continue;
-
-      if (pipeline->stages[shader_id].stage != MESA_SHADER_MISS)
-         continue;
-
-      /* Avoid emitting stages with the same shaders/handles multiple times. */
-      bool is_dup = false;
-      for (unsigned j = 0; j < i; ++j)
-         if (pipeline->groups[j].handle.general_index == pipeline->groups[i].handle.general_index)
-            is_dup = true;
-
-      if (is_dup)
-         continue;
-
-      nir_shader *nir_stage = radv_pipeline_cache_handle_to_nir(device, pipeline->stages[shader_id].nir);
-      assert(nir_stage);
-
-      insert_rt_case(b, nir_stage, vars, sbt_idx, pipeline->groups[i].handle.general_index);
-      ralloc_free(nir_stage);
-   }
-
-   if (!(vars->flags & VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR))
-      nir_pop_if(b, NULL);
-}
-
 static void
 store_stack_entry(nir_builder *b, nir_def *index, nir_def *value, const struct radv_ray_traversal_args *args)
 {
@@ -1410,7 +1404,18 @@ radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipelin
          /* should_return is set if we had a hit but we won't be calling the closest hit
           * shader and hence need to return immediately to the calling shader. */
          nir_push_if(b, nir_inot(b, should_return));
-         visit_closest_hit_shaders(device, b, pipeline, vars);
+
+         struct radv_rt_case_data case_data = {
+            .device = device,
+            .pipeline = pipeline,
+            .vars = vars,
+         };
+
+         radv_visit_inlined_shaders(
+            b, nir_load_var(b, vars->idx),
+            !(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_CLOSEST_HIT_SHADERS_BIT_KHR), &case_data,
+            radv_ray_tracing_group_chit_info, radv_build_recursive_case);
+
          nir_pop_if(b, NULL);
       } else {
          for (int i = 0; i < ARRAY_SIZE(hit_attribs); ++i)
@@ -1425,7 +1430,15 @@ radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipelin
       if (monolithic) {
          load_sbt_entry(b, vars, nir_load_var(b, vars->miss_index), SBT_MISS, SBT_GENERAL_IDX);
 
-         visit_miss_shaders(device, b, pipeline, vars);
+         struct radv_rt_case_data case_data = {
+            .device = device,
+            .pipeline = pipeline,
+            .vars = vars,
+         };
+
+         radv_visit_inlined_shaders(b, nir_load_var(b, vars->idx),
+                                    !(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR),
+                                    &case_data, radv_ray_tracing_group_miss_info, radv_build_recursive_case);
       } else {
          /* Only load the miss shader if we actually miss. It is valid to not specify an SBT pointer
           * for miss shaders if none of the rays miss. */

Reply via email to