Module: Mesa
Branch: main
Commit: fce434818a6d48a0541c7bc181ad201ae5d6503c
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=fce434818a6d48a0541c7bc181ad201ae5d6503c

Author: Rhys Perry <[email protected]>
Date:   Wed Oct  4 14:14:19 2023 +0100

nir/lower_fp16_casts: correctly round RTNE f64->f16 casts

Based on brw_nir_lower_conversions.

Signed-off-by: Rhys Perry <[email protected]>
Reviewed-by: Ivan Briano <[email protected]>
Reviewed-by: Georg Lehmann <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25566>

---

 src/compiler/nir/nir_lower_fp16_conv.c | 62 ++++++++++++++++++++++++++++++++--
 1 file changed, 60 insertions(+), 2 deletions(-)

diff --git a/src/compiler/nir/nir_lower_fp16_conv.c b/src/compiler/nir/nir_lower_fp16_conv.c
index 571cdf40f08..50ece5cecb5 100644
--- a/src/compiler/nir/nir_lower_fp16_conv.c
+++ b/src/compiler/nir/nir_lower_fp16_conv.c
@@ -69,8 +69,6 @@ float_to_half_impl(nir_builder *b, nir_def *src, nir_rounding_mode mode)
    nir_def *f32infinity = nir_imm_int(b, 255 << 23);
    nir_def *f16max = nir_imm_int(b, (127 + 16) << 23);
 
-   if (src->bit_size == 64)
-      src = nir_f2f32(b, src);
    nir_def *sign = nir_iand_imm(b, src, 0x80000000);
    nir_def *one = nir_imm_int(b, 1);
 
@@ -168,6 +166,82 @@ float_to_half_impl(nir_builder *b, nir_def *src, nir_rounding_mode mode)
    return nir_u2u16(b, nir_ior(b, fp16, nir_ushr_imm(b, sign, 16)));
 }
 
+/* Convert the f64 value src to f32 so that converting the result to f16 with
+ * rounding mode rnd (RTNE or RTZ) gives the same result as converting src to
+ * f16 directly. Other rounding modes get a plain nir_f2f32.
+ */
+static nir_def *
+split_f2f16_conversion(nir_builder *b, nir_def *src, nir_rounding_mode rnd)
+{
+   nir_def *tmp = nir_f2f32(b, src);
+
+   if (rnd == nir_rounding_mode_rtne) {
+      /* Converting double to half by going through float rounds twice,
+       * which can give inaccurate results. One such case is
+       * 0x40ee6a0000000001, which should round to 0x7b9b, but going through
+       * float first gives 0x7b9a instead. The first bit that doesn't fit in
+       * a half is set, so this is a tie, but the least significant bit of
+       * the original is also set, so the tie should break upwards. The cast
+       * to float, however, gives 0x47735000, which still ties when going to
+       * half, but has lost the tie-up bit, so it rounds to the nearest even
+       * value, which in this case is down.
+       *
+       * To fix this, if any bit of the original below the half tie bit is
+       * set, set the least significant bit of the intermediate float to 1
+       * (a sticky bit), so that a tie on the next cast rounds up as well. If
+       * there is no tie, that bit is simply truncated away and the end
+       * result doesn't change.
+       *
+       * Another failing case is 0x40effdffffffffff. It is below the tie from
+       * double to half, so it should round down to 0x7bff (65504.0), but
+       * going through float first gives 0x477ff000, which does have the tie
+       * bit for half set, and rounding that gives 0x7c00 (Infinity). The fix
+       * for that one is to clear the tie bit of the intermediate float if
+       * the original didn't have it set.
+       *
+       * The sticky bit must only be added to finite floats: if the f64->f32
+       * conversion overflowed to infinity (0x7f800000), setting bit 0 would
+       * turn it into a NaN instead of the correctly rounded infinity.
+       */
+      int significand_bits16 = 10;
+      int significand_bits32 = 23;
+      int significand_bits64 = 52;
+      int f64_to_16_tie_bit = significand_bits64 - significand_bits16 - 1;
+      int f32_to_16_tie_bit = significand_bits32 - significand_bits16 - 1;
+      uint64_t f64_rounds_up_mask = ((1ULL << f64_to_16_tie_bit) - 1);
+
+      nir_def *would_tie = nir_iand_imm(b, src, 1ULL << f64_to_16_tie_bit);
+      nir_def *would_rnd_up = nir_iand_imm(b, src, f64_rounds_up_mask);
+      nir_def *is_finite32 = nir_ult(b, nir_iand_imm(b, tmp, 0x7fffffff),
+                                     nir_imm_int(b, 0x7f800000));
+
+      nir_def *tie_up =
+         nir_b2i32(b, nir_iand(b, nir_ine_imm(b, would_rnd_up, 0), is_finite32));
+
+      nir_def *break_tie = nir_bcsel(b,
+                                     nir_ine_imm(b, would_tie, 0),
+                                     nir_imm_int(b, ~0),
+                                     nir_imm_int(b, ~(1U << f32_to_16_tie_bit)));
+
+      tmp = nir_ior(b, tmp, tie_up);
+      tmp = nir_iand(b, tmp, break_tie);
+   } else if (rnd == nir_rounding_mode_rtz) {
+      /* Truncating to float and then to half gives the same result as
+       * truncating straight to half, but nir_f2f32 does not necessarily
+       * truncate: with round-to-nearest, 1.0 - 2^-30 becomes 1.0 as a float,
+       * whose RTZ half is 0x3c00 instead of the correct 0x3bff. If the float
+       * ended up with a larger magnitude than the original (this includes
+       * overflow to infinity), step it back towards zero by one ulp by
+       * decrementing its bit pattern.
+       */
+      nir_def *rounded_away = nir_flt(b, nir_fabs(b, src),
+                                      nir_fabs(b, nir_f2f64(b, tmp)));
+      tmp = nir_bcsel(b, rounded_away, nir_iadd_imm(b, tmp, -1), tmp);
+   }
+
+   return tmp;
+}
+
 static bool
 lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data)
 {
@@ -242,6 +298,8 @@ lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data)
 
    for (unsigned i = 0; i < dst->num_components; i++) {
       nir_def *comp = nir_channel(b, src, swizzle ? swizzle[i] : i);
+      if (comp->bit_size == 64)
+         comp = split_f2f16_conversion(b, comp, mode);
       rets[i] = float_to_half_impl(b, comp, mode);
    }
 

Reply via email to