Module: Mesa
Branch: main
Commit: 288e9db05397822cdf22315c48b43d5a3810dc63
URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=288e9db05397822cdf22315c48b43d5a3810dc63
Author: Rhys Perry <[email protected]> Date: Wed Oct 4 14:23:59 2023 +0100 nir/lower_fp16_casts: add option to split fp64 casts Signed-off-by: Rhys Perry <[email protected]> Reviewed-by: Ivan Briano <[email protected]> Reviewed-by: Georg Lehmann <[email protected]> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25566> --- src/compiler/nir/nir.h | 1 + src/compiler/nir/nir_lower_fp16_conv.c | 46 +++++++++++++++++++++++++++------- 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index db3c8554357..96c5e7b7c22 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -6144,6 +6144,7 @@ typedef enum { nir_lower_fp16_ru = (1 << 2), nir_lower_fp16_rd = (1 << 3), nir_lower_fp16_all = 0xf, + nir_lower_fp16_split_fp64 = (1 << 4), } nir_lower_fp16_cast_options; bool nir_lower_fp16_casts(nir_shader *shader, nir_lower_fp16_cast_options options); bool nir_normalize_cubemap_coords(nir_shader *shader); diff --git a/src/compiler/nir/nir_lower_fp16_conv.c b/src/compiler/nir/nir_lower_fp16_conv.c index 50ece5cecb5..b6990b3da9c 100644 --- a/src/compiler/nir/nir_lower_fp16_conv.c +++ b/src/compiler/nir/nir_lower_fp16_conv.c @@ -227,13 +227,15 @@ split_f2f16_conversion(nir_builder *b, nir_def *src, nir_rounding_mode rnd) static bool lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data) { - nir_def *src, *dst; + nir_lower_fp16_cast_options options = *(nir_lower_fp16_cast_options *)data; + nir_src *src; + nir_def *dst; uint8_t *swizzle = NULL; nir_rounding_mode mode = nir_rounding_mode_undef; if (instr->type == nir_instr_type_alu) { nir_alu_instr *alu = nir_instr_as_alu(instr); - src = alu->src[0].src.ssa; + src = &alu->src[0].src; swizzle = alu->src[0].swizzle; dst = &alu->def; switch (alu->op) { @@ -249,22 +251,48 @@ lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data) case nir_op_f2f16_rtz: mode = nir_rounding_mode_rtz; break; + case nir_op_f2f64: + if 
(src->ssa->bit_size == 16 && (options & nir_lower_fp16_split_fp64)) { + b->cursor = nir_before_instr(instr); + nir_src_rewrite(src, nir_f2f32(b, src->ssa)); + return true; + } + return false; default: return false; } } else if (instr->type == nir_instr_type_intrinsic) { nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); - if (intrin->intrinsic != nir_intrinsic_convert_alu_types || - nir_intrinsic_dest_type(intrin) != nir_type_float16) + if (intrin->intrinsic != nir_intrinsic_convert_alu_types) return false; - src = intrin->src[0].ssa; + + src = &intrin->src[0]; dst = &intrin->def; mode = nir_intrinsic_rounding_mode(intrin); + + if (nir_intrinsic_src_type(intrin) == nir_type_float16 && + nir_intrinsic_dest_type(intrin) == nir_type_float64 && + (options & nir_lower_fp16_split_fp64)) { + b->cursor = nir_before_instr(instr); + nir_src_rewrite(src, nir_f2f32(b, src->ssa)); + return true; + } + + if (nir_intrinsic_dest_type(intrin) != nir_type_float16) + return false; } else { return false; } - nir_lower_fp16_cast_options options = *(nir_lower_fp16_cast_options *)data; + bool progress = false; + if (src->ssa->bit_size == 64 && (options & nir_lower_fp16_split_fp64)) { + b->cursor = nir_before_instr(instr); + nir_src_rewrite(src, split_f2f16_conversion(b, src->ssa, mode)); + if (instr->type == nir_instr_type_intrinsic) + nir_intrinsic_set_src_type(nir_instr_as_intrinsic(instr), nir_type_float32); + progress = true; + } + nir_lower_fp16_cast_options req_option = 0; switch (mode) { case nir_rounding_mode_rtz: @@ -280,7 +308,7 @@ lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data) req_option = nir_lower_fp16_rd; break; case nir_rounding_mode_undef: - if (options == nir_lower_fp16_all) { + if ((options & nir_lower_fp16_all) == nir_lower_fp16_all) { /* Pick one arbitrarily for lowering */ mode = nir_rounding_mode_rtne; req_option = nir_lower_fp16_rtne; @@ -291,13 +319,13 @@ lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data) 
unreachable("Invalid rounding mode"); } if (!(options & req_option)) - return false; + return progress; b->cursor = nir_before_instr(instr); nir_def *rets[NIR_MAX_VEC_COMPONENTS] = { NULL }; for (unsigned i = 0; i < dst->num_components; i++) { - nir_def *comp = nir_channel(b, src, swizzle ? swizzle[i] : i); + nir_def *comp = nir_channel(b, src->ssa, swizzle ? swizzle[i] : i); if (comp->bit_size == 64) comp = split_f2f16_conversion(b, comp, mode); rets[i] = float_to_half_impl(b, comp, mode);
