Module: Mesa Branch: main Commit: 3da27733166450a86815276b4adcd352fb29186b URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=3da27733166450a86815276b4adcd352fb29186b
Author: Alyssa Rosenzweig <[email protected]> Date: Mon Jan 8 12:01:49 2024 -0400 vtn: fuse OpenCL mad if we can can clpeak "float" case from 1112 -> 1978 GFLOPS on rusticl on m1. Signed-off-by: Alyssa Rosenzweig <[email protected]> Reviewed-by: Karol Herbst <[email protected]> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26932> --- src/compiler/spirv/vtn_opencl.c | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/compiler/spirv/vtn_opencl.c b/src/compiler/spirv/vtn_opencl.c index e3e4e61802c..59ec75da885 100644 --- a/src/compiler/spirv/vtn_opencl.c +++ b/src/compiler/spirv/vtn_opencl.c @@ -508,8 +508,25 @@ handle_special(struct vtn_builder *b, uint32_t opcode, return nir_cross3(nb, srcs[0], srcs[1]); case OpenCLstd_Fdim: return nir_fdim(nb, srcs[0], srcs[1]); - case OpenCLstd_Mad: - return nir_fmad(nb, srcs[0], srcs[1], srcs[2]); + case OpenCLstd_Mad: { + /* The spec says mad is + * + * Implemented either as a correctly rounded fma or as a multiply + * followed by an add both of which are correctly rounded + * + * So lower to fmul+fadd if we have to, but fuse to an ffma if the backend + * supports that. This can be significantly faster. + */ + bool lower = + ((nb->shader->options->lower_ffma16 && srcs[0]->bit_size == 16) || + (nb->shader->options->lower_ffma32 && srcs[0]->bit_size == 32) || + (nb->shader->options->lower_ffma64 && srcs[0]->bit_size == 64)); + + if (lower) + return nir_fmad(nb, srcs[0], srcs[1], srcs[2]); + else + return nir_ffma(nb, srcs[0], srcs[1], srcs[2]); + } case OpenCLstd_Maxmag: return nir_maxmag(nb, srcs[0], srcs[1]); case OpenCLstd_Minmag:
