Claudio Bantaloukas <claudio.bantalou...@arm.com> writes:
> [...]
> @@ -4004,6 +4008,44 @@ SHAPE (ternary_bfloat_lane)
>  typedef ternary_bfloat_lane_base<2> ternary_bfloat_lanex2_def;
>  SHAPE (ternary_bfloat_lanex2)
 
> +/* sv<t0>_t svfoo[_t0](sv<t0>_t, svmfloat8_t, svmfloat8_t, uint64_t)
> +
> +   where the final argument is an integer constant expression in the range
> +   [0, 15].  */
> +struct ternary_mfloat8_lane_def
> +    : public ternary_resize2_lane_base<8, TYPE_mfloat, TYPE_mfloat>
> +{
> +  void
> +  build (function_builder &b, const function_group_info &group) const override
> +  {
> +    gcc_assert (group.fpm_mode == FPM_set);
> +    b.add_overloaded_functions (group, MODE_none);
> +    build_all (b, "v0,v0,vM,vM,su64", group, MODE_none);
> +  }
> +
> +  bool
> +  check (function_checker &c) const override
> +  {
> +    return c.require_immediate_lane_index (3, 2, 1);
> +  }
> +
> +  tree
> +  resolve (function_resolver &r) const override
> +  {
> +    type_suffix_index type;
> +    if (!r.check_num_arguments (5)
> +     || (type = r.infer_vector_type (0)) == NUM_TYPE_SUFFIXES
> +     || !r.require_vector_type (1, VECTOR_TYPE_svmfloat8_t)
> +     || !r.require_vector_type (2, VECTOR_TYPE_svmfloat8_t)
> +     || !r.require_integer_immediate (3)
> +     || !r.require_scalar_type (4, "int64_t"))

uint64_t

> +      return error_mark_node;
> +
> +    return r.resolve_to (r.mode_suffix_id, type, TYPE_SUFFIX_mf8, 
> GROUP_none);
> +  }
> +};
> +SHAPE (ternary_mfloat8_lane)
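
FWIW, the call-site view of what this shape describes is roughly the
following (intrinsic and type names as used in the new tests further down;
untested sketch):

#include <arm_sve.h>

/* The overloaded svmlalb_lane_fpm call resolves against this shape: the
   f16 accumulator selects the type suffix, both multiplicands are
   svmfloat8_t, the lane index must be a constant in [0, 15], and the
   trailing fpm_t value ends up in FPMR.  */
svfloat16_t
use_mlalb_lane (svfloat16_t acc, svmfloat8_t a, svmfloat8_t b, fpm_t fpm)
{
  return svmlalb_lane_fpm (acc, a, b, 0, fpm);
}
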
> +
>  /* sv<t0>_t svfoo[_t0](sv<t0>_t, svbfloat16_t, svbfloat16_t)
>     sv<t0>_t svfoo[_n_t0](sv<t0>_t, svbfloat16_t, bfloat16_t).  */
>  struct ternary_bfloat_opt_n_def
> @@ -4019,6 +4061,46 @@ struct ternary_bfloat_opt_n_def
>  };
>  SHAPE (ternary_bfloat_opt_n)
>  
> +/* sv<t0>_t svfoo[_t0](sv<t0>_t, svmfloat8_t, svmfloat8_t)
> +   sv<t0>_t svfoo[_n_t0](sv<t0>_t, svmfloat8_t, mfloat8_t).  */
> +struct ternary_mfloat8_opt_n_def
> +    : public ternary_resize2_opt_n_base<8, TYPE_mfloat, TYPE_mfloat>
> +{
> +  void
> +  build (function_builder &b, const function_group_info &group) const override
> +  {
> +    gcc_assert (group.fpm_mode == FPM_set);
> +    b.add_overloaded_functions (group, MODE_none);
> +    build_all (b, "v0,v0,vM,vM", group, MODE_none);
> +    build_all (b, "v0,v0,vM,sM", group, MODE_n);
> +  }
> +
> +  tree
> +  resolve (function_resolver &r) const override
> +  {
> +    type_suffix_index type;
> +    if (!r.check_num_arguments (4)
> +     || (type = r.infer_vector_type (0)) == NUM_TYPE_SUFFIXES
> +     || !r.require_vector_type (1, VECTOR_TYPE_svmfloat8_t)
> +     || !r.require_scalar_type (3, "int64_t"))
> +      return error_mark_node;
> +
> +    tree scalar_form
> +     = r.lookup_form (MODE_n, type, TYPE_SUFFIX_mf8, GROUP_none);
> +    if (r.scalar_argument_p (2))
> +      {
> +     if (scalar_form)
> +       return scalar_form;
> +     return error_mark_node;

It looks like this would return error_mark_node without reporting
an error first.

> +      }
> +    if (scalar_form && !r.require_vector_or_scalar_type (2))
> +      return error_mark_node;
> +
> +    return r.resolve_to (r.mode_suffix_id, type, TYPE_SUFFIX_mf8, GROUP_none);
> +  }

In this context (unlike finish_opt_n_resolution) we know that there is
a bijection between the vector and scalar forms.  So I think we can just
add require_vector_or_scalar_type to the initial checks:

    if (!r.check_num_arguments (4)
        || (type = r.infer_vector_type (0)) == NUM_TYPE_SUFFIXES
        || !r.require_vector_type (1, VECTOR_TYPE_svmfloat8_t)
        || !r.require_vector_or_scalar_type (2)
        || !r.require_scalar_type (3, "uint64_t"))
      return error_mark_node;

    auto mode = r.mode_suffix_id;
    if (r.scalar_argument_p (2))
      mode = MODE_n;
    else if (!r.require_vector_type (2, VECTOR_TYPE_svmfloat8_t))
      return error_mark_node;

    return r.resolve_to (mode, type, TYPE_SUFFIX_mf8, GROUP_none);

(untested).
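
For reference, the two call forms being distinguished here look roughly
like this (overloaded and full names assumed from the ACLE FP8 naming
scheme used by the tests below):

#include <arm_sve.h>

/* Vector third operand: should resolve to svmlalb_f16_mf8_fpm.  */
svfloat16_t
mla_vec (svfloat16_t acc, svmfloat8_t a, svmfloat8_t b, fpm_t fpm)
{
  return svmlalb_fpm (acc, a, b, fpm);
}

/* Scalar third operand: should resolve to svmlalb_n_f16_mf8_fpm.  */
svfloat16_t
mla_scalar (svfloat16_t acc, svmfloat8_t a, mfloat8_t b, fpm_t fpm)
{
  return svmlalb_fpm (acc, a, b, fpm);
}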

> [...]
> +;; -------------------------------------------------------------------------
> +;; ---- [FP] Mfloat8 Multiply-and-accumulate operations
> +;; -------------------------------------------------------------------------
> +;; Includes:
> +;; - FMLALB (vectors, FP8 to FP16)
> +;; - FMLALT (vectors, FP8 to FP16)
> +;; - FMLALB (indexed, FP8 to FP16)
> +;; - FMLALT (indexed, FP8 to FP16)
> +;; - FMLALLBB (vectors)
> +;; - FMLALLBB (indexed)
> +;; - FMLALLBT (vectors)
> +;; - FMLALLBT (indexed)
> +;; - FMLALLTB (vectors)
> +;; - FMLALLTB (indexed)
> +;; - FMLALLTT (vectors)
> +;; - FMLALLTT (indexed)
> +;; -------------------------------------------------------------------------
> +
> +(define_insn "@aarch64_sve_add_<sve2_fp8_fma_op><mode>"
> +  [(set (match_operand:SVE_FULL_HSF 0 "register_operand")
> +     (unspec:SVE_FULL_HSF
> +       [(match_operand:SVE_FULL_HSF 1 "register_operand")
> +        (match_operand:VNx16QI 2 "register_operand")
> +        (match_operand:VNx16QI 3 "register_operand")
> +        (reg:DI FPM_REGNUM)]
> +       SVE2_FP8_TERNARY))]
> +  "TARGET_SSVE_FP8FMA"
> +  {@ [ cons: =0 , 1 , 2 , 3 ; attrs: movprfx ]
> +     [ w        , 0 , w , w ; *              ] <sve2_fp8_fma_op>\t%0.<Vetype>, %2.b, %3.b
> +     [ ?&w      , w , w , w ; yes            ] movprfx\t%0, %1\;<sve2_fp8_fma_op>\t%0.<Vetype>, %2.b, %3.b
> +  }
> +)
> +
> +(define_insn "@aarch64_sve_add_lane_<sve2_fp8_fma_op><mode>"
> +  [(set (match_operand:SVE_FULL_HSF 0 "register_operand")
> +     (unspec:SVE_FULL_HSF
> +       [(match_operand:SVE_FULL_HSF 1 "register_operand")
> +        (match_operand:VNx16QI 2 "register_operand")
> +        (match_operand:VNx16QI 3 "register_operand")
> +        (match_operand:SI 4 "const_int_operand")
> +        (reg:DI FPM_REGNUM)]
> +       SVE2_FP8_TERNARY_LANE))]
> +  "TARGET_SSVE_FP8FMA"
> +  {@ [ cons: =0 , 1 , 2 , 3 ; attrs: movprfx ]
> +     [ w        , 0 , w , y ; *              ] <sve2_fp8_fma_op>\t%0.<Vetype>, %2.b, %3.b[%4]
> +     [ ?&w      , w , w , y ; yes            ] movprfx\t%0, %1\;<sve2_fp8_fma_op>\t%0.<Vetype>, %2.b, %3.b[%4]
> +  }
> +)
> +

It goes against my instincts to ask for more cut-&-paste, but:
I think we should split the operator list into HF-only and SF-only,
rather than define invalid combinations.  [ Hope I didn't suggest the
opposite earlier -- always a risk, unfortunately. :( ]
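
To spell out the reason at the source level: the FP8 FMLALB/FMLALT forms
only ever accumulate into f16 vectors, while the FMLALL* forms only ever
accumulate into f32 vectors, so iterating a single operator list over the
whole of SVE_FULL_HSF creates patterns with no matching instruction.
Roughly (intrinsic names assumed from the ACLE FP8 spec; untested):

#include <arm_sve.h>

/* FMLALB: FP8 products widened into an f16 accumulator.  */
svfloat16_t
widen_to_f16 (svfloat16_t acc, svmfloat8_t a, svmfloat8_t b, fpm_t fpm)
{
  return svmlalb_f16_mf8_fpm (acc, a, b, fpm);
}

/* FMLALLBB: FP8 products widened into an f32 accumulator; there is no
   f32 form of the FP8 FMLALB and no f16 form of FMLALLBB.  */
svfloat32_t
widen_to_f32 (svfloat32_t acc, svmfloat8_t a, svmfloat8_t b, fpm_t fpm)
{
  return svmlallbb_f32_mf8_fpm (acc, a, b, fpm);
}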

> [...]
> +/* SVE2 versions of fp8 multiply-accumulate instructions are enabled through +ssve-fp8fma.  */
> +#define TARGET_SSVE_FP8FMA ((\
> +             (TARGET_SVE2 && TARGET_FP8FMA) || TARGET_STREAMING) \
> +             && (AARCH64_HAVE_ISA(SSVE_FP8FMA) || TARGET_NON_STREAMING))

Formatting nits, sorry, but: long line for the comment, and missing space
in the final line.  Also, the comment doesn't cover the non-streaming case.
Maybe:

/* SVE2 versions of fp8 multiply-accumulate instructions are enabled for
   non-streaming mode by +fp8fma and for streaming mode by +ssve-fp8fma.  */
#define TARGET_SSVE_FP8FMA \
  (((TARGET_SVE2 && TARGET_FP8FMA) || TARGET_STREAMING) \
   && (AARCH64_HAVE_ISA (SSVE_FP8FMA) || TARGET_NON_STREAMING))

> diff --git a/gcc/doc/invoke.texi b/gcc/doc/invoke.texi
> index 93e096bc9d5..119f636dc16 100644
> --- a/gcc/doc/invoke.texi
> +++ b/gcc/doc/invoke.texi
> @@ -21824,6 +21824,10 @@ Enable support for Armv8.9-a/9.4-a translation hardening extension.
>  Enable the RCpc3 (Release Consistency) extension.
>  @item fp8
>  Enable the fp8 (8-bit floating point) extension.
> +@item fp8fma
> +Enable the fp8 (8-bit floating point) multiply accumulate extension.
> +@item ssve-fp8fma
> +Enable the fp8 (8-bit floating point) multiply accumulate extension streaming mode.

Maybe "in streaming mode"?  Also: the usual 80-character line limit applies
here too, where possible.

> [...]
> diff --git a/gcc/testsuite/gcc.target/aarch64/sve2/acle/asm/mlalb_lane_mf8.c b/gcc/testsuite/gcc.target/aarch64/sve2/acle/asm/mlalb_lane_mf8.c
> new file mode 100644
> index 00000000000..5b43f4d6611
> --- /dev/null
> +++ b/gcc/testsuite/gcc.target/aarch64/sve2/acle/asm/mlalb_lane_mf8.c
> @@ -0,0 +1,88 @@
> +/* { dg-final { check-function-bodies "**" "" "-DCHECK_ASM" } } */
> +/* { dg-additional-options "-march=armv8.5-a+sve2+fp8fma" } */
> +/* { dg-require-effective-target aarch64_asm_fp8fma_ok }  */
> +/* { dg-require-effective-target aarch64_asm_ssve-fp8fma_ok }  */
> +/* { dg-skip-if "" { *-*-* } { "-DSTREAMING_COMPATIBLE" } { "" } } */
> +
> +#include "test_sve_acle.h"

Following on from the comment on patch 3, the corresponding change here
would probably be:

/* { dg-do assemble { target aarch64_asm_ssve-fp8fma_ok } } */
/* { dg-do compile { target { ! aarch64_asm_ssve-fp8fma_ok } } } */
/* { dg-final { check-function-bodies "**" "" "-DCHECK_ASM" } } */

#include "test_sve_acle.h"

#pragma GCC target "+fp8fma"
#ifdef STREAMING_COMPATIBLE
#pragma GCC target "+ssve-fp8fma"
#endif

(which assumes that +ssve-fp8fma is good for +fp8fma too).

> +/*
> +** mlalb_lane_0_f16_tied1:
> +**   msr     fpmr, x0
> +**   fmlalb  z0\.h, z4\.b, z5\.b\[0\]
> +**   ret
> +*/
> +TEST_DUAL_Z (mlalb_lane_0_f16_tied1, svfloat16_t, svmfloat8_t,
> +          z0 = svmlalb_lane_f16_mf8_fpm (z0, z4, z5, 0, fpm0),
> +          z0 = svmlalb_lane_fpm (z0, z4, z5, 0, fpm0))
> +
> +/*
> +** mlalb_lane_0_f16_tied2:
> +**   msr     fpmr, x0
> +**   mov     (z[0-9]+)\.d, z0\.d
> +**   movprfx z0, z4
> +**   fmlalb  z0\.h, \1\.b, z1\.b\[0\]
> +**   ret
> +*/
> +TEST_DUAL_Z_REV (mlalb_lane_0_f16_tied2, svfloat16_t, svmfloat8_t,
> +              z0_res = svmlalb_lane_f16_mf8_fpm (z4, z0, z1, 0, fpm0),
> +              z0_res = svmlalb_lane_fpm (z4, z0, z1, 0, fpm0))
> +
> +/*
> +** mlalb_lane_0_f16_tied3:
> +**   msr     fpmr, x0
> +**   mov     (z[0-9]+)\.d, z0\.d
> +**   movprfx z0, z4
> +**   fmlalb  z0\.h, z1\.b, \1\.b\[0\]
> +**   ret
> +*/
> +TEST_DUAL_Z_REV (mlalb_lane_0_f16_tied3, svfloat16_t, svmfloat8_t,
> +              z0_res = svmlalb_lane_f16_mf8_fpm (z4, z1, z0, 0, fpm0),
> +              z0_res = svmlalb_lane_fpm (z4, z1, z0, 0, fpm0))
> +
> +/*
> +** mlalb_lane_0_f16_untied:
> +**   msr     fpmr, x0
> +**   movprfx z0, z1
> +**   fmlalb  z0\.h, z4\.b, z5\.b\[0\]
> +**   ret
> +*/
> +TEST_DUAL_Z (mlalb_lane_0_f16_untied, svfloat16_t, svmfloat8_t,
> +          z0 = svmlalb_lane_f16_mf8_fpm (z1, z4, z5, 0, fpm0),
> +          z0 = svmlalb_lane_fpm (z1, z4, z5, 0, fpm0))
> +
> +/*
> +** mlalb_lane_1_f16:
> +**   msr     fpmr, x0
> +**   fmlalb  z0\.h, z4\.b, z5\.b\[1\]
> +**   ret
> +*/
> +TEST_DUAL_Z (mlalb_lane_1_f16, svfloat16_t, svmfloat8_t,
> +          z0 = svmlalb_lane_f16_mf8_fpm (z0, z4, z5, 1, fpm0),
> +          z0 = svmlalb_lane_fpm (z0, z4, z5, 1, fpm0))
> +
> +/*
> +** mlalb_lane_z8_f16:
> +**   ...
> +**   msr     fpmr, x0
> +**   mov     (z[0-7])\.d, z8\.d
> +**   fmlalb  z0\.h, z1\.b, \1\.b\[1\]
> +**   ldr     d8, \[sp\], 32
> +**   ret
> +*/
> +TEST_DUAL_LANE_REG (mlalb_lane_z8_f16, svfloat16_t, svmfloat8_t, z8,
> +                 z0 = svmlalb_lane_f16_mf8_fpm (z0, z1, z8, 1, fpm0),
> +                 z0 = svmlalb_lane_fpm (z0, z1, z8, 1, fpm0))
> +
> +/*
> +** mlalb_lane_z16_f16:
> +**   ...
> +**   msr     fpmr, x0
> +**   mov     (z[0-7])\.d, z16\.d
> +**   fmlalb  z0\.h, z1\.b, \1\.b\[1\]
> +**   ...
> +**   ret
> +*/
> +TEST_DUAL_LANE_REG (mlalb_lane_z16_f16, svfloat16_t, svmfloat8_t, z16,
> +                 z0 = svmlalb_lane_f16_mf8_fpm (z0, z1, z16, 1, fpm0),
> +                 z0 = svmlalb_lane_fpm (z0, z1, z16, 1, fpm0))

It would be good to have a test for the upper limit of the index range,
like for the _f32 tests.  Same for svmlalt_lane.
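
Something along the lines of the following, mirroring the _1_ test above,
would cover that (expected assembly assumed, untested):

/*
** mlalb_lane_15_f16:
**   msr     fpmr, x0
**   fmlalb  z0\.h, z4\.b, z5\.b\[15\]
**   ret
*/
TEST_DUAL_Z (mlalb_lane_15_f16, svfloat16_t, svmfloat8_t,
             z0 = svmlalb_lane_f16_mf8_fpm (z0, z4, z5, 15, fpm0),
             z0 = svmlalb_lane_fpm (z0, z4, z5, 15, fpm0))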

Looks good to me otherwise, thanks,

Richard
