https://gcc.gnu.org/g:210d06502f22964c7214586c54f8eb54a6965bfd

commit r16-446-g210d06502f22964c7214586c54f8eb54a6965bfd
Author: Jennifer Schmitz <jschm...@nvidia.com>
Date:   Fri Feb 14 00:46:13 2025 -0800

    AArch64: Fold SVE load/store with certain ptrue patterns to LDR/STR.
    
    SVE loads/stores using predicates that select the bottom 8, 16, 32, 64,
    or 128 bits of a register can be folded to ASIMD LDR/STR, thus avoiding the
    predicate.
    For example,
    svuint8_t foo (uint8_t *x) {
      return svld1 (svwhilelt_b8 (0, 16), x);
    }
    was previously compiled to:
    foo:
            ptrue   p3.b, vl16
            ld1b    z0.b, p3/z, [x0]
            ret
    
    and is now compiled to:
    foo:
            ldr     q0, [x0]
            ret
    
    The optimization is applied during the expand pass and was implemented
    by making the following changes to maskload<mode><vpred> and
    maskstore<mode><vpred>:
    - the existing define_insns were renamed and new define_expands for maskloads
      and maskstores were added with nonmemory_operand as predicate such that the
      SVE predicate matches both register operands and constant-vector operands.
    - if the SVE predicate is a constant vector and contains a pattern as
      described above, an ASIMD load/store is emitted instead of the SVE
      load/store.
    
    The patch implements the optimization for LD1 and ST1, for 8-bit, 16-bit,
    32-bit, 64-bit, and 128-bit moves, for all full SVE data vector modes.
    
    Follow-up patches for LD2/3/4 and ST2/3/4 and potentially partial SVE vector
    modes are planned.
    
    The patch was bootstrapped and tested on aarch64-linux-gnu, no regression.
    
    Signed-off-by: Jennifer Schmitz <jschm...@nvidia.com>
    
    gcc/
            PR target/117978
            * config/aarch64/aarch64-protos.h: Declare
            aarch64_emit_load_store_through_mode and aarch64_expand_maskloadstore.
            * config/aarch64/aarch64-sve.md
            (maskload<mode><vpred>): New define_expand folding maskloads with
            certain predicate patterns to ASIMD loads.
            (*aarch64_maskload<mode><vpred>): Renamed from maskload<mode><vpred>.
            (maskstore<mode><vpred>): New define_expand folding maskstores with
            certain predicate patterns to ASIMD stores.
            (*aarch64_maskstore<mode><vpred>): Renamed from maskstore<mode><vpred>.
            * config/aarch64/aarch64.cc
            (aarch64_emit_load_store_through_mode): New function emitting a
            load/store through subregs of a given mode.
            (aarch64_emit_sve_pred_move): Refactor to use
            aarch64_emit_load_store_through_mode.
            (aarch64_expand_maskloadstore): New function to emit ASIMD loads/stores
            for maskloads/stores with SVE predicates with VL1, VL2, VL4, VL8, or
            VL16 patterns.
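            (aarch64_partial_ptrue_length): New overload taking a predicate rtx
            and returning the number of leading set bits in it.  In the
            rtx_vector_builder overload, lower ELT_SIZE to the number of
            patterns when the encoding has two elements per pattern.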
    
    gcc/testsuite/
            PR target/117978
            * gcc.target/aarch64/sve/acle/general/whilelt_5.c: Adjust expected
            outcome.
            * gcc.target/aarch64/sve/ldst_ptrue_pat_128_to_neon.c: New test.
            * gcc.target/aarch64/sve/while_7.c: Adjust expected outcome.
            * gcc.target/aarch64/sve/while_9.c: Adjust expected outcome.

Diff:
---
 gcc/config/aarch64/aarch64-protos.h                |  2 +
 gcc/config/aarch64/aarch64-sve.md                  | 38 ++++++++-
 gcc/config/aarch64/aarch64.cc                      | 98 +++++++++++++++++++---
 .../aarch64/sve/acle/general/whilelt_5.c           | 24 ++++--
 .../aarch64/sve/ldst_ptrue_pat_128_to_neon.c       | 81 ++++++++++++++++++
 gcc/testsuite/gcc.target/aarch64/sve/while_7.c     |  4 +-
 gcc/testsuite/gcc.target/aarch64/sve/while_9.c     |  2 +-
 7 files changed, 227 insertions(+), 22 deletions(-)

diff --git a/gcc/config/aarch64/aarch64-protos.h b/gcc/config/aarch64/aarch64-protos.h
index 1ca86c9d175d..c83c35c6d71e 100644
--- a/gcc/config/aarch64/aarch64-protos.h
+++ b/gcc/config/aarch64/aarch64-protos.h
@@ -1026,6 +1026,8 @@ rtx aarch64_ptrue_reg (machine_mode, unsigned int);
 rtx aarch64_ptrue_reg (machine_mode, machine_mode);
 rtx aarch64_pfalse_reg (machine_mode);
 bool aarch64_sve_same_pred_for_ptest_p (rtx *, rtx *);
+void aarch64_emit_load_store_through_mode (rtx, rtx, machine_mode);
+bool aarch64_expand_maskloadstore (rtx *, machine_mode);
 void aarch64_emit_sve_pred_move (rtx, rtx, rtx);
 void aarch64_expand_sve_mem_move (rtx, rtx, machine_mode);
 bool aarch64_maybe_expand_sve_subreg_move (rtx, rtx);
diff --git a/gcc/config/aarch64/aarch64-sve.md b/gcc/config/aarch64/aarch64-sve.md
index 7bf12ff25ccd..f39af6e24d51 100644
--- a/gcc/config/aarch64/aarch64-sve.md
+++ b/gcc/config/aarch64/aarch64-sve.md
@@ -1286,7 +1286,24 @@
 ;; -------------------------------------------------------------------------
 
 ;; Predicated LD1 (single).
-(define_insn "maskload<mode><vpred>"
+(define_expand "maskload<mode><vpred>"
+  [(set (match_operand:SVE_ALL 0 "register_operand")
+       (unspec:SVE_ALL
+         [(match_operand:<VPRED> 2 "nonmemory_operand")
+          (match_operand:SVE_ALL 1 "memory_operand")
+          (match_operand:SVE_ALL 3 "aarch64_maskload_else_operand")]
+         UNSPEC_LD1_SVE))]
+  "TARGET_SVE"
+  {
+    if (aarch64_expand_maskloadstore (operands, <MODE>mode))
+      DONE;
+    if (CONSTANT_P (operands[2]))
+      operands[2] = force_reg (<VPRED>mode, operands[2]);
+  }
+)
+
+;; Predicated LD1 (single).
+(define_insn "*aarch64_maskload<mode><vpred>"
   [(set (match_operand:SVE_ALL 0 "register_operand" "=w")
        (unspec:SVE_ALL
          [(match_operand:<VPRED> 2 "register_operand" "Upl")
@@ -2287,7 +2304,24 @@
 ;; -------------------------------------------------------------------------
 
 ;; Predicated ST1 (single).
-(define_insn "maskstore<mode><vpred>"
+(define_expand "maskstore<mode><vpred>"
+  [(set (match_operand:SVE_ALL 0 "memory_operand")
+       (unspec:SVE_ALL
+         [(match_operand:<VPRED> 2 "nonmemory_operand")
+          (match_operand:SVE_ALL 1 "register_operand")
+          (match_dup 0)]
+         UNSPEC_ST1_SVE))]
+  "TARGET_SVE"
+  {
+    if (aarch64_expand_maskloadstore (operands, <MODE>mode))
+      DONE;
+    if (CONSTANT_P (operands[2]))
+      operands[2] = force_reg (<VPRED>mode, operands[2]);
+  }
+)
+
+;; Predicated ST1 (single).
+(define_insn "*aarch64_maskstore<mode><vpred>"
   [(set (match_operand:SVE_ALL 0 "memory_operand" "+m")
        (unspec:SVE_ALL
          [(match_operand:<VPRED> 2 "register_operand" "Upl")
diff --git a/gcc/config/aarch64/aarch64.cc b/gcc/config/aarch64/aarch64.cc
index fff8d9da49d3..2dc5f4c4b59d 100644
--- a/gcc/config/aarch64/aarch64.cc
+++ b/gcc/config/aarch64/aarch64.cc
@@ -3667,6 +3667,14 @@ aarch64_partial_ptrue_length (rtx_vector_builder &builder,
   if (builder.nelts_per_pattern () == 3)
     return 0;
 
+  /* It is conservatively correct to drop the element size to a lower value,
+     and we must do so if the predicate consists of a leading "foreground"
+     sequence that is smaller than the element size.  Without this,
+     we would test only one bit and so treat everything as either an
+     all-true or an all-false predicate.  */
+  if (builder.nelts_per_pattern () == 2)
+    elt_size = MIN (elt_size, builder.npatterns ());
+
   /* Skip over leading set bits.  */
   unsigned int nelts = builder.encoded_nelts ();
   unsigned int i = 0;
@@ -3698,6 +3706,24 @@ aarch64_partial_ptrue_length (rtx_vector_builder &builder,
   return vl;
 }
 
+/* Return:
+
+   * -1 if all bits of PRED are set
+   * N if PRED has N leading set bits followed by all clear bits
+   * 0 if PRED does not have any of these forms.  */
+
+int
+aarch64_partial_ptrue_length (rtx pred)
+{
+  rtx_vector_builder builder;
+  if (!aarch64_get_sve_pred_bits (builder, pred))
+    return 0;
+
+  auto elt_size = vector_element_size (GET_MODE_BITSIZE (GET_MODE (pred)),
+                                      GET_MODE_NUNITS (GET_MODE (pred)));
+  return aarch64_partial_ptrue_length (builder, elt_size);
+}
+
 /* See if there is an svpattern that encodes an SVE predicate of mode
    PRED_MODE in which the first VL bits are set and the rest are clear.
    Return the pattern if so, otherwise return AARCH64_NUM_SVPATTERNS.
@@ -6410,8 +6436,32 @@ aarch64_stack_protect_canary_mem (machine_mode mode, rtx decl_rtl,
   return gen_rtx_MEM (mode, force_reg (Pmode, addr));
 }
 
-/* Emit an SVE predicated move from SRC to DEST.  PRED is a predicate
-   that is known to contain PTRUE.  */
+/* Emit a load/store from a subreg of SRC to a subreg of DEST.
+   The subregs have mode NEW_MODE. Use only for reg<->mem moves.  */
+void
+aarch64_emit_load_store_through_mode (rtx dest, rtx src, machine_mode new_mode)
+{
+  gcc_assert ((MEM_P (dest) && register_operand (src, VOIDmode))
+             || (MEM_P (src) && register_operand (dest, VOIDmode)));
+  auto mode = GET_MODE (dest);
+  auto int_mode = aarch64_sve_int_mode (mode);
+  if (MEM_P (src))
+    {
+      rtx tmp = force_reg (new_mode, adjust_address (src, new_mode, 0));
+      tmp = force_lowpart_subreg (int_mode, tmp, new_mode);
+      emit_move_insn (dest, force_lowpart_subreg (mode, tmp, int_mode));
+    }
+  else
+    {
+      src = force_lowpart_subreg (int_mode, src, mode);
+      emit_move_insn (adjust_address (dest, new_mode, 0),
+                     force_lowpart_subreg (new_mode, src, int_mode));
+    }
+}
+
+/* PRED is a predicate that is known to contain PTRUE.
+   For 128-bit VLS loads/stores, emit LDR/STR.
+   Else, emit an SVE predicated move from SRC to DEST.  */
 
 void
 aarch64_emit_sve_pred_move (rtx dest, rtx pred, rtx src)
@@ -6421,16 +6471,7 @@ aarch64_emit_sve_pred_move (rtx dest, rtx pred, rtx src)
       && known_eq (GET_MODE_SIZE (mode), 16)
       && aarch64_classify_vector_mode (mode) == VEC_SVE_DATA
       && !BYTES_BIG_ENDIAN)
-    {
-      if (MEM_P (src))
-       {
-         rtx tmp = force_reg (V16QImode, adjust_address (src, V16QImode, 0));
-         emit_move_insn (dest, lowpart_subreg (mode, tmp, V16QImode));
-       }
-      else
-       emit_move_insn (adjust_address (dest, V16QImode, 0),
-                       force_lowpart_subreg (V16QImode, src, mode));
-    }
+    aarch64_emit_load_store_through_mode (dest, src, V16QImode);
   else
     {
       expand_operand ops[3];
@@ -23526,6 +23567,39 @@ aarch64_simd_valid_imm (rtx op, simd_immediate_info *info,
   return false;
 }
 
+/* Try to optimize the expansion of a maskload or maskstore with
+   the operands in OPERANDS, given that the vector being loaded or
+   stored has mode MODE.  Return true on success or false if the normal
+   expansion should be used.  */
+
+bool
+aarch64_expand_maskloadstore (rtx *operands, machine_mode mode)
+{
+  /* If the predicate in operands[2] is a patterned SVE PTRUE predicate
+     with patterns VL1, VL2, VL4, VL8, or VL16 and at most the bottom
+     128 bits are loaded/stored, emit an ASIMD load/store.  */
+  int vl = aarch64_partial_ptrue_length (operands[2]);
+  int width = vl * GET_MODE_UNIT_BITSIZE (mode);
+  if (width <= 128
+      && pow2p_hwi (vl)
+      && (vl == 1
+         || (!BYTES_BIG_ENDIAN
+             && aarch64_classify_vector_mode (mode) == VEC_SVE_DATA)))
+    {
+      machine_mode new_mode;
+      if (known_eq (width, 128))
+       new_mode = V16QImode;
+      else if (known_eq (width, 64))
+       new_mode = V8QImode;
+      else
+       new_mode = int_mode_for_size (width, 0).require ();
+      aarch64_emit_load_store_through_mode (operands[0], operands[1],
+                                           new_mode);
+      return true;
+    }
+  return false;
+}
+
 /* Return true if OP is a valid SIMD move immediate for SVE or AdvSIMD.  */
 bool
 aarch64_simd_valid_mov_imm (rtx op)
diff --git a/gcc/testsuite/gcc.target/aarch64/sve/acle/general/whilelt_5.c b/gcc/testsuite/gcc.target/aarch64/sve/acle/general/whilelt_5.c
index f06a74aa2daa..05e266aad7d4 100644
--- a/gcc/testsuite/gcc.target/aarch64/sve/acle/general/whilelt_5.c
+++ b/gcc/testsuite/gcc.target/aarch64/sve/acle/general/whilelt_5.c
@@ -11,8 +11,7 @@ extern "C" {
 
 /*
 ** load_vl1:
-**     ptrue   (p[0-7])\.[bhsd], vl1
-**     ld1h    z0\.h, \1/z, \[x0\]
+**     ldr     h0, \[x0\]
 **     ret
 */
 svint16_t
@@ -22,7 +21,12 @@ load_vl1 (int16_t *ptr)
 }
 
 /*
-** load_vl2:
+** load_vl2: { target aarch64_little_endian }
+**     ldr     s0, \[x0\]
+**     ret
+*/
+/*
+** load_vl2: { target aarch64_big_endian }
 **     ptrue   (p[0-7])\.h, vl2
 **     ld1h    z0\.h, \1/z, \[x0\]
 **     ret
@@ -46,7 +50,12 @@ load_vl3 (int16_t *ptr)
 }
 
 /*
-** load_vl4:
+** load_vl4: { target aarch64_little_endian }
+**     ldr     d0, \[x0\]
+**     ret
+*/
+/*
+** load_vl4: { target aarch64_big_endian }
 **     ptrue   (p[0-7])\.h, vl4
 **     ld1h    z0\.h, \1/z, \[x0\]
 **     ret
@@ -94,7 +103,12 @@ load_vl7 (int16_t *ptr)
 }
 
 /*
-** load_vl8:
+** load_vl8: { target aarch64_little_endian }
+**     ldr     q0, \[x0\]
+**     ret
+*/
+/*
+** load_vl8: { target aarch64_big_endian }
 **     ptrue   (p[0-7])\.h, vl8
 **     ld1h    z0\.h, \1/z, \[x0\]
 **     ret
diff --git a/gcc/testsuite/gcc.target/aarch64/sve/ldst_ptrue_pat_128_to_neon.c b/gcc/testsuite/gcc.target/aarch64/sve/ldst_ptrue_pat_128_to_neon.c
new file mode 100644
index 000000000000..2d47c1f1a3d7
--- /dev/null
+++ b/gcc/testsuite/gcc.target/aarch64/sve/ldst_ptrue_pat_128_to_neon.c
@@ -0,0 +1,81 @@
+/* { dg-do compile } */
+/* { dg-options "-O2" } */
+/* { dg-require-effective-target aarch64_little_endian } */
+
+#include <arm_sve.h>
+
+#define TEST(TYPE, TY, W, B)                                           \
+  sv##TYPE                                                             \
+  ld1_##TY##W##B##_1 (TYPE *x)                                         \
+  {                                                                    \
+    svbool_t pg = svwhilelt_b##B (0, W);                               \
+    return svld1_##TY##B (pg, x);                                      \
+  }                                                                    \
+  sv##TYPE                                                             \
+  ld1_##TY##W##B##_2 (TYPE *x)                                         \
+  {                                                                    \
+    svbool_t pg = svptrue_pat_b##B ((enum svpattern) (W > 8 ? 9 : W)); \
+    return svld1_##TY##B (pg, x);                                      \
+  }                                                                    \
+  void                                                                 \
+  st1_##TY##W##B##_1 (TYPE *x, sv##TYPE data)                          \
+  {                                                                    \
+    svbool_t pg = svwhilelt_b##B (0, W);                               \
+    svst1_##TY##B (pg, x, data);                                       \
+  }                                                                    \
+  void                                                                 \
+  st1_##TY##W##B##_2 (TYPE *x, sv##TYPE data)                          \
+  {                                                                    \
+    svbool_t pg = svptrue_pat_b##B ((enum svpattern) (W > 8 ? 9 : W)); \
+    svst1_##TY##B (pg, x, data);                                       \
+  }                                                                    \
+
+#define TEST64(TYPE, TY, B)                            \
+  TEST (TYPE, TY, 1, B)                                        \
+  TEST (TYPE, TY, 2, B)                                        \
+
+#define TEST32(TYPE, TY, B)                            \
+  TEST64 (TYPE, TY, B)                                 \
+  TEST (TYPE, TY, 4, B)                                        \
+
+#define TEST16(TYPE, TY, B)                            \
+  TEST32 (TYPE, TY, B)                                 \
+  TEST (TYPE, TY, 8, B)                                        \
+
+#define TEST8(TYPE, TY, B)                             \
+  TEST16 (TYPE, TY, B)                                 \
+  TEST (TYPE, TY, 16, B)
+
+#define T(TYPE, TY, B)                 \
+  TEST##B (TYPE, TY, B)
+
+T (bfloat16_t, bf, 16)
+T (float16_t, f, 16)
+T (float32_t, f, 32)
+T (float64_t, f, 64)
+T (int8_t, s, 8)
+T (int16_t, s, 16)
+T (int32_t, s, 32)
+T (int64_t, s, 64)
+T (uint8_t, u, 8)
+T (uint16_t, u, 16)
+T (uint32_t, u, 32)
+T (uint64_t, u, 64)
+
+/* { dg-final { scan-assembler-times {\tldr\tq0, \[x0\]} 24 } } */
+/* { dg-final { scan-assembler-times {\tldr\td0, \[x0\]} 24 } } */
+/* { dg-final { scan-assembler-times {\tldr\ts0, \[x0\]} 18 } } */
+/* { dg-final { scan-assembler-times {\tldr\th0, \[x0\]} 12 } } */
+/* { dg-final { scan-assembler-times {\tldr\tb0, \[x0\]} 4 } } */
+
+/* { dg-final { scan-assembler-times {\tstr\tq0, \[x0\]} 24 } } */
+/* { dg-final { scan-assembler-times {\tstr\td0, \[x0\]} 24 } } */
+/* { dg-final { scan-assembler-times {\tstr\ts0, \[x0\]} 18 } } */
+/* { dg-final { scan-assembler-times {\tstr\th0, \[x0\]} 12 } } */
+/* { dg-final { scan-assembler-times {\tstr\tb0, \[x0\]} 4 } } */
+
+svint8_t foo (int8_t *x)
+{
+  return svld1_s8 (svptrue_b16 (), x);
+}
+/* { dg-final { scan-assembler-times {\tptrue\tp[0-7]\.h, all\n\tld1b} 1 } } */
\ No newline at end of file
diff --git a/gcc/testsuite/gcc.target/aarch64/sve/while_7.c b/gcc/testsuite/gcc.target/aarch64/sve/while_7.c
index a66a20d21f65..ab2fa3646fcf 100644
--- a/gcc/testsuite/gcc.target/aarch64/sve/while_7.c
+++ b/gcc/testsuite/gcc.target/aarch64/sve/while_7.c
@@ -19,7 +19,7 @@
 
 TEST_ALL (ADD_LOOP)
 
-/* { dg-final { scan-assembler-times {\tptrue\tp[0-7]\.b, vl8\n} 1 } } */
-/* { dg-final { scan-assembler-times {\tptrue\tp[0-7]\.h, vl8\n} 1 } } */
+/* { dg-final { scan-assembler-times {\tldr\td[0-9]+, \[x0\]} 1 } } */
+/* { dg-final { scan-assembler-times {\tldr\tq[0-9]+, \[x0\]} 1 } } */
 /* { dg-final { scan-assembler-times {\twhilelo\tp[0-7]\.s,} 2 } } */
 /* { dg-final { scan-assembler-times {\twhilelo\tp[0-7]\.d,} 2 } } */
diff --git a/gcc/testsuite/gcc.target/aarch64/sve/while_9.c b/gcc/testsuite/gcc.target/aarch64/sve/while_9.c
index dd3f404ab396..99940dd73fa1 100644
--- a/gcc/testsuite/gcc.target/aarch64/sve/while_9.c
+++ b/gcc/testsuite/gcc.target/aarch64/sve/while_9.c
@@ -19,7 +19,7 @@
 
 TEST_ALL (ADD_LOOP)
 
-/* { dg-final { scan-assembler-times {\tptrue\tp[0-7]\.b, vl16\n} 1 } } */
+/* { dg-final { scan-assembler-times {\tldr\tq[0-9]+\, \[x0\]} 1 } } */
 /* { dg-final { scan-assembler-times {\twhilelo\tp[0-7]\.h,} 2 } } */
 /* { dg-final { scan-assembler-times {\twhilelo\tp[0-7]\.s,} 2 } } */
 /* { dg-final { scan-assembler-times {\twhilelo\tp[0-7]\.d,} 2 } } */

Reply via email to