Module: Mesa
Branch: main
Commit: 288e9db05397822cdf22315c48b43d5a3810dc63
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=288e9db05397822cdf22315c48b43d5a3810dc63

Author: Rhys Perry <[email protected]>
Date:   Wed Oct  4 14:23:59 2023 +0100

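nir/lower_fp16_casts: add option to split fp64 casts

Add nir_lower_fp16_split_fp64, which splits f64->f16 conversions into
f64->f32->f16 and f16->f64 conversions into f16->f32->f64.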

Signed-off-by: Rhys Perry <[email protected]>
Reviewed-by: Ivan Briano <[email protected]>
Reviewed-by: Georg Lehmann <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25566>

---

 src/compiler/nir/nir.h                 |  1 +
 src/compiler/nir/nir_lower_fp16_conv.c | 46 +++++++++++++++++++++++++++-------
 2 files changed, 38 insertions(+), 9 deletions(-)

diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index db3c8554357..96c5e7b7c22 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -6144,6 +6144,7 @@ typedef enum {
    nir_lower_fp16_ru = (1 << 2),
    nir_lower_fp16_rd = (1 << 3),
    nir_lower_fp16_all = 0xf,
+   nir_lower_fp16_split_fp64 = (1 << 4),
 } nir_lower_fp16_cast_options;
 bool nir_lower_fp16_casts(nir_shader *shader, nir_lower_fp16_cast_options 
options);
 bool nir_normalize_cubemap_coords(nir_shader *shader);
diff --git a/src/compiler/nir/nir_lower_fp16_conv.c b/src/compiler/nir/nir_lower_fp16_conv.c
index 50ece5cecb5..b6990b3da9c 100644
--- a/src/compiler/nir/nir_lower_fp16_conv.c
+++ b/src/compiler/nir/nir_lower_fp16_conv.c
@@ -227,13 +227,15 @@ split_f2f16_conversion(nir_builder *b, nir_def *src, nir_rounding_mode rnd)
 static bool
 lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data)
 {
-   nir_def *src, *dst;
+   nir_lower_fp16_cast_options options = *(nir_lower_fp16_cast_options *)data;
+   nir_src *src;
+   nir_def *dst;
    uint8_t *swizzle = NULL;
    nir_rounding_mode mode = nir_rounding_mode_undef;
 
    if (instr->type == nir_instr_type_alu) {
       nir_alu_instr *alu = nir_instr_as_alu(instr);
-      src = alu->src[0].src.ssa;
+      src = &alu->src[0].src;
       swizzle = alu->src[0].swizzle;
       dst = &alu->def;
       switch (alu->op) {
@@ -249,22 +251,48 @@ lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data)
       case nir_op_f2f16_rtz:
          mode = nir_rounding_mode_rtz;
          break;
+      case nir_op_f2f64:
+         if (src->ssa->bit_size == 16 && (options & 
nir_lower_fp16_split_fp64)) {
+            b->cursor = nir_before_instr(instr);
+            nir_src_rewrite(src, nir_f2f32(b, src->ssa));
+            return true;
+         }
+         return false;
       default:
          return false;
       }
    } else if (instr->type == nir_instr_type_intrinsic) {
       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
-      if (intrin->intrinsic != nir_intrinsic_convert_alu_types ||
-          nir_intrinsic_dest_type(intrin) != nir_type_float16)
+      if (intrin->intrinsic != nir_intrinsic_convert_alu_types)
          return false;
-      src = intrin->src[0].ssa;
+
+      src = &intrin->src[0];
       dst = &intrin->def;
       mode = nir_intrinsic_rounding_mode(intrin);
+
+      if (nir_intrinsic_src_type(intrin) == nir_type_float16 &&
+          nir_intrinsic_dest_type(intrin) == nir_type_float64 &&
+          (options & nir_lower_fp16_split_fp64)) {
+         b->cursor = nir_before_instr(instr);
+         nir_src_rewrite(src, nir_f2f32(b, src->ssa));
+         return true;
+      }
+
+      if (nir_intrinsic_dest_type(intrin) != nir_type_float16)
+         return false;
    } else {
       return false;
    }
 
-   nir_lower_fp16_cast_options options = *(nir_lower_fp16_cast_options *)data;
+   bool progress = false;
+   if (src->ssa->bit_size == 64 && (options & nir_lower_fp16_split_fp64)) {
+      b->cursor = nir_before_instr(instr);
+      nir_src_rewrite(src, split_f2f16_conversion(b, src->ssa, mode));
+      if (instr->type == nir_instr_type_intrinsic)
+         nir_intrinsic_set_src_type(nir_instr_as_intrinsic(instr), 
nir_type_float32);
+      progress = true;
+   }
+
    nir_lower_fp16_cast_options req_option = 0;
    switch (mode) {
    case nir_rounding_mode_rtz:
@@ -280,7 +308,7 @@ lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data)
       req_option = nir_lower_fp16_rd;
       break;
    case nir_rounding_mode_undef:
-      if (options == nir_lower_fp16_all) {
+      if ((options & nir_lower_fp16_all) == nir_lower_fp16_all) {
          /* Pick one arbitrarily for lowering */
          mode = nir_rounding_mode_rtne;
          req_option = nir_lower_fp16_rtne;
@@ -291,13 +319,13 @@ lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data)
       unreachable("Invalid rounding mode");
    }
    if (!(options & req_option))
-      return false;
+      return progress;
 
    b->cursor = nir_before_instr(instr);
    nir_def *rets[NIR_MAX_VEC_COMPONENTS] = { NULL };
 
    for (unsigned i = 0; i < dst->num_components; i++) {
-      nir_def *comp = nir_channel(b, src, swizzle ? swizzle[i] : i);
+      nir_def *comp = nir_channel(b, src->ssa, swizzle ? swizzle[i] : i);
       if (comp->bit_size == 64)
          comp = split_f2f16_conversion(b, comp, mode);
       rets[i] = float_to_half_impl(b, comp, mode);

Reply via email to