Module: Mesa
Branch: main
Commit: 3da27733166450a86815276b4adcd352fb29186b
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=3da27733166450a86815276b4adcd352fb29186b

Author: Alyssa Rosenzweig <[email protected]>
Date:   Mon Jan  8 12:01:49 2024 -0400

vtn: fuse OpenCL mad if we can

Improves the clpeak "float" case from 1112 to 1978 GFLOPS with rusticl on M1.

Signed-off-by: Alyssa Rosenzweig <[email protected]>
Reviewed-by: Karol Herbst <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26932>

---

 src/compiler/spirv/vtn_opencl.c | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/src/compiler/spirv/vtn_opencl.c b/src/compiler/spirv/vtn_opencl.c
index e3e4e61802c..59ec75da885 100644
--- a/src/compiler/spirv/vtn_opencl.c
+++ b/src/compiler/spirv/vtn_opencl.c
@@ -508,8 +508,25 @@ handle_special(struct vtn_builder *b, uint32_t opcode,
       return nir_cross3(nb, srcs[0], srcs[1]);
    case OpenCLstd_Fdim:
       return nir_fdim(nb, srcs[0], srcs[1]);
-   case OpenCLstd_Mad:
-      return nir_fmad(nb, srcs[0], srcs[1], srcs[2]);
+   case OpenCLstd_Mad: {
+      /* The spec says mad is
+       *
+       *    Implemented either as a correctly rounded fma or as a multiply
+       *    followed by an add both of which are correctly rounded
+       *
+       * So lower to fmul+fadd if we have to, but fuse to an ffma if the backend
+       * supports that. This can be significantly faster.
+       */
+      bool lower =
+         ((nb->shader->options->lower_ffma16 && srcs[0]->bit_size == 16) ||
+          (nb->shader->options->lower_ffma32 && srcs[0]->bit_size == 32) ||
+          (nb->shader->options->lower_ffma64 && srcs[0]->bit_size == 64));
+
+      if (lower)
+         return nir_fmad(nb, srcs[0], srcs[1], srcs[2]);
+      else
+         return nir_ffma(nb, srcs[0], srcs[1], srcs[2]);
+   }
    case OpenCLstd_Maxmag:
       return nir_maxmag(nb, srcs[0], srcs[1]);
    case OpenCLstd_Minmag:
