https://gcc.gnu.org/g:8e45a01d0fd36d21c9743f30a25e277b67e79f0e
commit 8e45a01d0fd36d21c9743f30a25e277b67e79f0e Author: Saurabh Jha <saurabh....@arm.com> Date: Wed Nov 13 17:16:37 2024 +0000 aarch64: Add support for fp8fma instructions The AArch64 FEAT_FP8FMA extension introduces instructions for multiply-add of vectors. This patch introduces the following instructions: 1. {vmlalbq|vmlaltq}_f16_mf8_fpm. 2. {vmlalbq|vmlaltq}_lane{q}_f16_mf8_fpm. 3. {vmlallbbq|vmlallbtq|vmlalltbq|vmlallttq}_f32_mf8_fpm. 4. {vmlallbbq|vmlallbtq|vmlalltbq|vmlallttq}_lane{q}_f32_mf8_fpm. It introduces the fp8fma flag. gcc/ChangeLog: * config/aarch64/aarch64-builtins.cc (check_simd_lane_bounds): Add support for new unspecs. (aarch64_expand_pragma_builtin): Add support for new unspecs. * config/aarch64/aarch64-c.cc (aarch64_update_cpp_builtins): New flags. * config/aarch64/aarch64-option-extensions.def (AARCH64_OPT_EXTENSION): New flags. * config/aarch64/aarch64-simd-pragma-builtins.def (ENTRY_FMA_FPM): Macro to declare fma intrinsics. (REQUIRED_EXTENSIONS): Define to declare functions behind command line flags. * config/aarch64/aarch64-simd.md: (@aarch64_<fpm_uns_op><VQ_HSF:mode><VQ_HSF:mode><V16QI_ONLY:mode><V16QI_ONLY:mode>): Instruction pattern for fma intrinsics. (@aarch64_<fpm_uns_op><VQ_HSF:mode><VQ_HSF:mode><V16QI_ONLY:mode><VB:mode><SI_ONLY:mode>): Instruction pattern for fma intrinsics with lane. * config/aarch64/aarch64.h (TARGET_FP8FMA): New flag for fp8fma instructions. * config/aarch64/iterators.md: New attributes and iterators. * doc/invoke.texi: New flag for fp8fma instructions. gcc/testsuite/ChangeLog: * gcc.target/aarch64/simd/fma_fpm.c: New test. 
Diff: --- gcc/config/aarch64/aarch64-builtins.cc | 30 ++- gcc/config/aarch64/aarch64-c.cc | 2 + gcc/config/aarch64/aarch64-option-extensions.def | 2 + .../aarch64/aarch64-simd-pragma-builtins.def | 16 ++ gcc/config/aarch64/aarch64-simd.md | 29 +++ gcc/config/aarch64/aarch64.h | 3 + gcc/config/aarch64/iterators.md | 18 ++ gcc/doc/invoke.texi | 2 + gcc/testsuite/gcc.target/aarch64/simd/fma_fpm.c | 221 +++++++++++++++++++++ 9 files changed, 319 insertions(+), 4 deletions(-) diff --git a/gcc/config/aarch64/aarch64-builtins.cc b/gcc/config/aarch64/aarch64-builtins.cc index a71c8c9a64e9..7b2decf671fa 100644 --- a/gcc/config/aarch64/aarch64-builtins.cc +++ b/gcc/config/aarch64/aarch64-builtins.cc @@ -2562,10 +2562,26 @@ check_simd_lane_bounds (location_t location, const aarch64_pragma_builtins_data = GET_MODE_NUNITS (vector_to_index_mode).to_constant (); auto low = 0; - int high - = builtin_data->unspec == UNSPEC_VDOT2 - ? vector_to_index_mode_size / 2 - 1 - : vector_to_index_mode_size / 4 - 1; + int high; + switch (builtin_data->unspec) + { + case UNSPEC_VDOT2: + high = vector_to_index_mode_size / 2 - 1; + break; + case UNSPEC_VDOT4: + high = vector_to_index_mode_size / 4 - 1; + break; + case UNSPEC_FMLALB: + case UNSPEC_FMLALT: + case UNSPEC_FMLALLBB: + case UNSPEC_FMLALLBT: + case UNSPEC_FMLALLTB: + case UNSPEC_FMLALLTT: + high = vector_to_index_mode_size - 1; + break; + default: + gcc_unreachable (); + } require_immediate_range (location, index_arg, low, high); break; } @@ -3552,6 +3568,12 @@ aarch64_expand_pragma_builtin (tree exp, rtx target, case UNSPEC_VDOT2: case UNSPEC_VDOT4: + case UNSPEC_FMLALB: + case UNSPEC_FMLALT: + case UNSPEC_FMLALLBB: + case UNSPEC_FMLALLBT: + case UNSPEC_FMLALLTB: + case UNSPEC_FMLALLTT: if (builtin_data->signature == aarch64_builtin_signatures::ternary) icode = code_for_aarch64 (builtin_data->unspec, builtin_data->types[0].mode, diff --git a/gcc/config/aarch64/aarch64-c.cc b/gcc/config/aarch64/aarch64-c.cc index 
ae1472e0fcf2..03f912cde077 100644 --- a/gcc/config/aarch64/aarch64-c.cc +++ b/gcc/config/aarch64/aarch64-c.cc @@ -264,6 +264,8 @@ aarch64_update_cpp_builtins (cpp_reader *pfile) aarch64_def_or_undef (TARGET_FP8DOT4, "__ARM_FEATURE_FP8DOT4", pfile); + aarch64_def_or_undef (TARGET_FP8FMA, "__ARM_FEATURE_FP8FMA", pfile); + aarch64_def_or_undef (TARGET_LS64, "__ARM_FEATURE_LS64", pfile); aarch64_def_or_undef (TARGET_RCPC, "__ARM_FEATURE_RCPC", pfile); diff --git a/gcc/config/aarch64/aarch64-option-extensions.def b/gcc/config/aarch64/aarch64-option-extensions.def index 44d2e18d46bd..8446d1bcd5dc 100644 --- a/gcc/config/aarch64/aarch64-option-extensions.def +++ b/gcc/config/aarch64/aarch64-option-extensions.def @@ -240,6 +240,8 @@ AARCH64_OPT_EXTENSION("fp8dot2", FP8DOT2, (SIMD), (), (), "fp8dot2") AARCH64_OPT_EXTENSION("fp8dot4", FP8DOT4, (SIMD), (), (), "fp8dot4") +AARCH64_OPT_EXTENSION("fp8fma", FP8FMA, (SIMD), (), (), "fp8fma") + AARCH64_OPT_EXTENSION("faminmax", FAMINMAX, (SIMD), (), (), "faminmax") #undef AARCH64_OPT_FMV_EXTENSION diff --git a/gcc/config/aarch64/aarch64-simd-pragma-builtins.def b/gcc/config/aarch64/aarch64-simd-pragma-builtins.def index 4a94a6613f08..c7857123ca03 100644 --- a/gcc/config/aarch64/aarch64-simd-pragma-builtins.def +++ b/gcc/config/aarch64/aarch64-simd-pragma-builtins.def @@ -48,6 +48,12 @@ ENTRY_TERNARY_FPM_LANE (vdotq_lane_##T##_mf8_fpm, T##q, T##q, f8q, f8, U) \ ENTRY_TERNARY_FPM_LANE (vdotq_laneq_##T##_mf8_fpm, T##q, T##q, f8q, f8q, U) +#undef ENTRY_FMA_FPM +#define ENTRY_FMA_FPM(N, T, U) \ + ENTRY_TERNARY_FPM (N##_##T##_mf8_fpm, T##q, T##q, f8q, f8q, U) \ + ENTRY_TERNARY_FPM_LANE (N##_lane_##T##_mf8_fpm, T##q, T##q, f8q, f8, U) \ + ENTRY_TERNARY_FPM_LANE (N##_laneq_##T##_mf8_fpm, T##q, T##q, f8q, f8q, U) + #undef ENTRY_VHSDF #define ENTRY_VHSDF(NAME, UNSPEC) \ ENTRY_BINARY (NAME##_f16, f16, f16, f16, UNSPEC) \ @@ -106,3 +112,13 @@ ENTRY_VDOT_FPM (f16, UNSPEC_VDOT2) #define REQUIRED_EXTENSIONS nonstreaming_only (AARCH64_FL_FP8DOT4) 
ENTRY_VDOT_FPM (f32, UNSPEC_VDOT4) #undef REQUIRED_EXTENSIONS + +// fp8 multiply-add +#define REQUIRED_EXTENSIONS nonstreaming_only (AARCH64_FL_FP8FMA) +ENTRY_FMA_FPM (vmlalbq, f16, UNSPEC_FMLALB) +ENTRY_FMA_FPM (vmlaltq, f16, UNSPEC_FMLALT) +ENTRY_FMA_FPM (vmlallbbq, f32, UNSPEC_FMLALLBB) +ENTRY_FMA_FPM (vmlallbtq, f32, UNSPEC_FMLALLBT) +ENTRY_FMA_FPM (vmlalltbq, f32, UNSPEC_FMLALLTB) +ENTRY_FMA_FPM (vmlallttq, f32, UNSPEC_FMLALLTT) +#undef REQUIRED_EXTENSIONS diff --git a/gcc/config/aarch64/aarch64-simd.md b/gcc/config/aarch64/aarch64-simd.md index 7b974865f555..df0d30af6a11 100644 --- a/gcc/config/aarch64/aarch64-simd.md +++ b/gcc/config/aarch64/aarch64-simd.md @@ -10155,3 +10155,32 @@ "TARGET_FP8DOT4" "<fpm_uns_op>\t%1.<VDQSF:Vtype>, %2.<VB:Vtype>, %3.<VDQSF:Vdotlanetype>[%4]" ) + +;; fpm fma instructions. +(define_insn + "@aarch64_<fpm_uns_op><VQ_HSF:mode><VQ_HSF:mode><V16QI_ONLY:mode><V16QI_ONLY:mode>" + [(set (match_operand:VQ_HSF 0 "register_operand" "=w") + (unspec:VQ_HSF + [(match_operand:VQ_HSF 1 "register_operand" "w") + (match_operand:V16QI_ONLY 2 "register_operand" "w") + (match_operand:V16QI_ONLY 3 "register_operand" "w") + (reg:DI FPM_REGNUM)] + FPM_FMA_UNS))] + "TARGET_FP8FMA" + "<fpm_uns_op>\t%1.<VQ_HSF:Vtype>, %2.<V16QI_ONLY:Vtype>, %3.<V16QI_ONLY:Vtype>" +) + +;; fpm fma instructions with lane. 
+(define_insn + "@aarch64_<fpm_uns_op><VQ_HSF:mode><VQ_HSF:mode><V16QI_ONLY:mode><VB:mode><SI_ONLY:mode>" + [(set (match_operand:VQ_HSF 0 "register_operand" "=w") + (unspec:VQ_HSF + [(match_operand:VQ_HSF 1 "register_operand" "w") + (match_operand:V16QI_ONLY 2 "register_operand" "w") + (match_operand:VB 3 "register_operand" "w") + (match_operand:SI_ONLY 4 "const_int_operand" "n") + (reg:DI FPM_REGNUM)] + FPM_FMA_UNS))] + "TARGET_FP8FMA" + "<fpm_uns_op>\t%1.<VQ_HSF:Vtype>, %2.<V16QI_ONLY:Vtype>, %3.b[%4]" +) diff --git a/gcc/config/aarch64/aarch64.h b/gcc/config/aarch64/aarch64.h index c50a578731a5..a691a0f2b181 100644 --- a/gcc/config/aarch64/aarch64.h +++ b/gcc/config/aarch64/aarch64.h @@ -500,6 +500,9 @@ constexpr auto AARCH64_FL_DEFAULT_ISA_MODE ATTRIBUTE_UNUSED /* fp8 dot product instructions are enabled through +fp8dot4. */ #define TARGET_FP8DOT4 AARCH64_HAVE_ISA (FP8DOT4) +/* fp8 multiply-add instructions are enabled through +fp8fma. */ +#define TARGET_FP8FMA AARCH64_HAVE_ISA (FP8FMA) + /* Standard register usage. */ /* 31 64-bit general purpose registers R0-R30: diff --git a/gcc/config/aarch64/iterators.md b/gcc/config/aarch64/iterators.md index 8c03dcd14dd1..82dc7dcf7621 100644 --- a/gcc/config/aarch64/iterators.md +++ b/gcc/config/aarch64/iterators.md @@ -722,6 +722,10 @@ UNSPEC_FMINNMV ; Used in aarch64-simd.md. UNSPEC_FMINV ; Used in aarch64-simd.md. UNSPEC_FADDV ; Used in aarch64-simd.md. + UNSPEC_FMLALLBB ; Used in aarch64-simd.md. + UNSPEC_FMLALLBT ; Used in aarch64-simd.md. + UNSPEC_FMLALLTB ; Used in aarch64-simd.md. + UNSPEC_FMLALLTT ; Used in aarch64-simd.md. UNSPEC_FNEG ; Used in aarch64-simd.md. UNSPEC_FSCALE ; Used in aarch64-simd.md. UNSPEC_ADDV ; Used in aarch64-simd.md. 
@@ -4735,9 +4739,23 @@ (define_int_iterator FPM_VDOT2_UNS [UNSPEC_VDOT2]) (define_int_iterator FPM_VDOT4_UNS [UNSPEC_VDOT4]) +(define_int_iterator FPM_FMA_UNS + [UNSPEC_FMLALB + UNSPEC_FMLALT + UNSPEC_FMLALLBB + UNSPEC_FMLALLBT + UNSPEC_FMLALLTB + UNSPEC_FMLALLTT]) + (define_int_attr fpm_uns_op [(UNSPEC_FSCALE "fscale") (UNSPEC_VCVT "fcvtn") (UNSPEC_VCVT_HIGH "fcvtn2") + (UNSPEC_FMLALB "fmlalb") + (UNSPEC_FMLALT "fmlalt") + (UNSPEC_FMLALLBB "fmlallbb") + (UNSPEC_FMLALLBT "fmlallbt") + (UNSPEC_FMLALLTB "fmlalltb") + (UNSPEC_FMLALLTT "fmlalltt") (UNSPEC_VDOT2 "fdot") (UNSPEC_VDOT4 "fdot")]) diff --git a/gcc/doc/invoke.texi b/gcc/doc/invoke.texi index bc3f74234259..d41136bebc1c 100644 --- a/gcc/doc/invoke.texi +++ b/gcc/doc/invoke.texi @@ -21811,6 +21811,8 @@ Enable the fp8 (8-bit floating point) extension. Enable the fp8dot2 (8-bit floating point dot product) extension. @item fp8dot4 Enable the fp8dot4 (8-bit floating point dot product) extension. +@item fp8fma +Enable the fp8fma (8-bit floating point multiply-add) extension. @item faminmax Enable the Floating Point Absolute Maximum/Minimum extension. 
diff --git a/gcc/testsuite/gcc.target/aarch64/simd/fma_fpm.c b/gcc/testsuite/gcc.target/aarch64/simd/fma_fpm.c new file mode 100644 index 000000000000..ea21856fa626 --- /dev/null +++ b/gcc/testsuite/gcc.target/aarch64/simd/fma_fpm.c @@ -0,0 +1,221 @@ +/* { dg-do compile } */ +/* { dg-additional-options "-O3 -march=armv9-a+fp8fma" } */ +/* { dg-final { check-function-bodies "**" "" } } */ + +#include "arm_neon.h" + +/* +** test_vmlalbq_f16_fpm: +** msr fpmr, x0 +** fmlalb v0.8h, v1.16b, v2.16b +** ret +*/ +float16x8_t +test_vmlalbq_f16_fpm (float16x8_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d) +{ + return vmlalbq_f16_mf8_fpm (a, b, c, d); +} + +/* +** test_vmlaltq_f16_fpm: +** msr fpmr, x0 +** fmlalt v0.8h, v1.16b, v2.16b +** ret +*/ +float16x8_t +test_vmlaltq_f16_fpm (float16x8_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d) +{ + return vmlaltq_f16_mf8_fpm (a, b, c, d); +} + +/* +** test_vmlallbbq_f32_fpm: +** msr fpmr, x0 +** fmlallbb v0.4s, v1.16b, v2.16b +** ret +*/ +float32x4_t +test_vmlallbbq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d) +{ + return vmlallbbq_f32_mf8_fpm (a, b, c, d); +} + +/* +** test_vmlallbtq_f32_fpm: +** msr fpmr, x0 +** fmlallbt v0.4s, v1.16b, v2.16b +** ret +*/ +float32x4_t +test_vmlallbtq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d) +{ + return vmlallbtq_f32_mf8_fpm (a, b, c, d); +} + +/* +** test_vmlalltbq_f32_fpm: +** msr fpmr, x0 +** fmlalltb v0.4s, v1.16b, v2.16b +** ret +*/ +float32x4_t +test_vmlalltbq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d) +{ + return vmlalltbq_f32_mf8_fpm (a, b, c, d); +} + +/* +** test_vmlallttq_f32_fpm: +** msr fpmr, x0 +** fmlalltt v0.4s, v1.16b, v2.16b +** ret +*/ +float32x4_t +test_vmlallttq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d) +{ + return vmlallttq_f32_mf8_fpm (a, b, c, d); +} + +/* +** test_vmlalbq_lane_f16_fpm: +** msr fpmr, x0 +** fmlalb v0.8h, v1.16b, v2.b\[1\] +** ret +*/ +float16x8_t +test_vmlalbq_lane_f16_fpm 
(float16x8_t a, mfloat8x16_t b, mfloat8x8_t c, fpm_t d) +{ + return vmlalbq_lane_f16_mf8_fpm (a, b, c, 1, d); +} + +/* +** test_vmlalbq_laneq_f16_fpm: +** msr fpmr, x0 +** fmlalb v0.8h, v1.16b, v2.b\[1\] +** ret +*/ +float16x8_t +test_vmlalbq_laneq_f16_fpm (float16x8_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d) +{ + return vmlalbq_laneq_f16_mf8_fpm (a, b, c, 1, d); +} + +/* +** test_vmlaltq_lane_f16_fpm: +** msr fpmr, x0 +** fmlalt v0.8h, v1.16b, v2.b\[1\] +** ret +*/ +float16x8_t +test_vmlaltq_lane_f16_fpm (float16x8_t a, mfloat8x16_t b, mfloat8x8_t c, fpm_t d) +{ + return vmlaltq_lane_f16_mf8_fpm (a, b, c, 1, d); +} + +/* +** test_vmlaltq_laneq_f16_fpm: +** msr fpmr, x0 +** fmlalt v0.8h, v1.16b, v2.b\[1\] +** ret +*/ +float16x8_t +test_vmlaltq_laneq_f16_fpm (float16x8_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d) +{ + return vmlaltq_laneq_f16_mf8_fpm (a, b, c, 1, d); +} + +/* +** test_vmlallbbq_lane_f32_fpm: +** msr fpmr, x0 +** fmlallbb v0.4s, v1.16b, v2.b\[1\] +** ret +*/ +float32x4_t +test_vmlallbbq_lane_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x8_t c, fpm_t d) +{ + return vmlallbbq_lane_f32_mf8_fpm (a, b, c, 1, d); +} + +/* +** test_vmlallbbq_laneq_f32_fpm: +** msr fpmr, x0 +** fmlallbb v0.4s, v1.16b, v2.b\[1\] +** ret +*/ +float32x4_t +test_vmlallbbq_laneq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d) +{ + return vmlallbbq_laneq_f32_mf8_fpm (a, b, c, 1, d); +} + +/* +** test_vmlallbtq_lane_f32_fpm: +** msr fpmr, x0 +** fmlallbt v0.4s, v1.16b, v2.b\[1\] +** ret +*/ +float32x4_t +test_vmlallbtq_lane_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x8_t c, fpm_t d) +{ + return vmlallbtq_lane_f32_mf8_fpm (a, b, c, 1, d); +} + +/* +** test_vmlallbtq_laneq_f32_fpm: +** msr fpmr, x0 +** fmlallbt v0.4s, v1.16b, v2.b\[1\] +** ret +*/ +float32x4_t +test_vmlallbtq_laneq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d) +{ + return vmlallbtq_laneq_f32_mf8_fpm (a, b, c, 1, d); +} + +/* +** test_vmlalltbq_lane_f32_fpm: +** msr fpmr, 
x0 +** fmlalltb v0.4s, v1.16b, v2.b\[1\] +** ret +*/ +float32x4_t +test_vmlalltbq_lane_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x8_t c, fpm_t d) +{ + return vmlalltbq_lane_f32_mf8_fpm (a, b, c, 1, d); +} + +/* +** test_vmlalltbq_laneq_f32_fpm: +** msr fpmr, x0 +** fmlalltb v0.4s, v1.16b, v2.b\[1\] +** ret +*/ +float32x4_t +test_vmlalltbq_laneq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d) +{ + return vmlalltbq_laneq_f32_mf8_fpm (a, b, c, 1, d); +} + +/* +** test_vmlallttq_lane_f32_fpm: +** msr fpmr, x0 +** fmlalltt v0.4s, v1.16b, v2.b\[1\] +** ret +*/ +float32x4_t +test_vmlallttq_lane_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x8_t c, fpm_t d) +{ + return vmlallttq_lane_f32_mf8_fpm (a, b, c, 1, d); +} + +/* +** test_vmlallttq_laneq_f32_fpm: +** msr fpmr, x0 +** fmlalltt v0.4s, v1.16b, v2.b\[1\] +** ret +*/ +float32x4_t +test_vmlallttq_laneq_f32_fpm (float32x4_t a, mfloat8x16_t b, mfloat8x16_t c, fpm_t d) +{ + return vmlallttq_laneq_f32_mf8_fpm (a, b, c, 1, d); +}