https://gcc.gnu.org/g:ed6d0867fdbb95691f4054eef6c4b1bd16099700

commit r15-5495-ged6d0867fdbb95691f4054eef6c4b1bd16099700
Author: Richard Sandiford <richard.sandif...@arm.com>
Date:   Wed Nov 20 10:04:46 2024 +0000

    Extend expand_absneg_bit to vector modes
    
    Expand can implement NEG and ABS of scalar floating-point modes
    by using logic ops to manipulate the sign bit.  This patch extends
    that approach to vectors, since it fits relatively easily into the
    same structure.
    
    The motivating use case was to inline bf16 NEG and ABS operations
    for AArch64.  The patch includes tests for that.
    
    get_absneg_bit_mode required a new opt_mode constructor, so that
    opt_mode<T> can be constructed from opt_mode<U> if T is no less
    general than U.
    
    gcc/
            * machmode.h (opt_mode::opt_mode): New overload.
            * optabs-query.h (get_absneg_bit_mode): Declare.
            * optabs-query.cc (get_absneg_bit_mode): New function, split
            out from expand_absneg_bit.
            (can_open_code_p): Use get_absneg_bit_mode.
            * optabs.cc (expand_absneg_bit): Likewise.  Take an outer and inner
            mode, rather than just one.  Handle vector modes.
            (expand_unop, expand_abs_nojump): Update calls accordingly.
            Handle vector modes.
    
    gcc/testsuite/
            * gcc.target/aarch64/abs_bf_1.c: New test.
            * gcc.target/aarch64/neg_bf_1.c: Likewise.
            * gcc.target/aarch64/neg_bf_2.c: Likewise.

Diff:
---
 gcc/machmode.h                              | 10 ++++++
 gcc/optabs-query.cc                         | 42 +++++++++++++++++++++++
 gcc/optabs-query.h                          |  2 ++
 gcc/optabs.cc                               | 53 ++++++++++++-----------------
 gcc/testsuite/gcc.target/aarch64/abs_bf_1.c | 10 ++++++
 gcc/testsuite/gcc.target/aarch64/neg_bf_1.c | 11 ++++++
 gcc/testsuite/gcc.target/aarch64/neg_bf_2.c | 10 ++++++
 7 files changed, 107 insertions(+), 31 deletions(-)

diff --git a/gcc/machmode.h b/gcc/machmode.h
index 4c2a8d943cf6..9cf792b5ccab 100644
--- a/gcc/machmode.h
+++ b/gcc/machmode.h
@@ -268,6 +268,8 @@ public:
   ALWAYS_INLINE CONSTEXPR opt_mode (const T &m) : m_mode (m) {}
   template<typename U>
   ALWAYS_INLINE CONSTEXPR opt_mode (const U &m) : m_mode (T (m)) {}
+  template<typename U>
+  ALWAYS_INLINE CONSTEXPR opt_mode (const opt_mode<U> &);
   ALWAYS_INLINE CONSTEXPR opt_mode (from_int m) : m_mode (machine_mode (m)) {}
 
   machine_mode else_void () const;
@@ -285,6 +287,14 @@ private:
   machine_mode m_mode;
 };
 
+/* Convert an opt_mode<U> to opt_mode<T>; an empty input stays empty.  */
+template<typename T>
+template<typename U>
+ALWAYS_INLINE CONSTEXPR
+opt_mode<T>::opt_mode (const opt_mode<U> &other)
+  : m_mode (!other.exists () ? E_VOIDmode : T (other.require ()))
+{}
+
 /* If the object contains a T, return its enum value, otherwise return
    E_VOIDmode.  */
 
diff --git a/gcc/optabs-query.cc b/gcc/optabs-query.cc
index 6d28d620eb51..8ab4164e82c8 100644
--- a/gcc/optabs-query.cc
+++ b/gcc/optabs-query.cc
@@ -782,6 +782,39 @@ can_vec_extract (machine_mode mode, machine_mode extr_mode)
   return true;
 }
 
+/* OP is either neg_optab or abs_optab and FMODE is the floating-point inner
+   mode of MODE.  Check whether OP can be done on MODE using xor_optab to flip
+   the sign bit (for neg_optab) or and_optab to clear it (for abs_optab).  If
+   so, return the integral mode to use and set *BITPOS to the index of the
+   sign bit (counting from the lsb).  Fail unless MODE is FMODE itself or a
+   vector of FMODE.  */
+
+opt_machine_mode
+get_absneg_bit_mode (optab op, machine_mode mode,
+                    scalar_float_mode fmode, int *bitpos)
+{
+  /* The format has to have a simple sign bit.  */
+  auto fmt = REAL_MODE_FORMAT (fmode);
+  if (fmt == NULL)
+    return {};
+  *bitpos = fmt->signbit_rw;
+  if (*bitpos < 0)
+    return {};
+
+  /* Don't create negative zeros if the format doesn't support them.  */
+  if (op == neg_optab && !fmt->has_signed_zero)
+    return {};
+
+  if (VECTOR_MODE_P (mode))
+    return related_int_vector_mode (mode);
+  /* Reject e.g. complex modes, which have a sign bit in each part.  */
+  if (mode != fmode)
+    return {};
+  if (GET_MODE_SIZE (fmode) <= UNITS_PER_WORD)
+    return int_mode_for_mode (fmode);
+  return word_mode;
+}
+
 /* Return true if we can implement OP for mode MODE directly, without resorting
    to a libfunc.   This usually means that OP will be implemented inline.
 
@@ -800,6 +833,15 @@ can_open_code_p (optab op, machine_mode mode)
   if (op == smul_highpart_optab)
     return can_mult_highpart_p (mode, false);
 
+  machine_mode new_mode;
+  scalar_float_mode fmode;
+  int bitpos;
+  if ((op == neg_optab || op == abs_optab)
+      && is_a<scalar_float_mode> (GET_MODE_INNER (mode), &fmode)
+      && get_absneg_bit_mode (op, mode, fmode, &bitpos).exists (&new_mode)
+      && can_implement_p (op == neg_optab ? xor_optab : and_optab, new_mode))
+    return true;
+
   return false;
 }
 
diff --git a/gcc/optabs-query.h b/gcc/optabs-query.h
index 89a7b02ef437..60c8021a1b75 100644
--- a/gcc/optabs-query.h
+++ b/gcc/optabs-query.h
@@ -171,6 +171,8 @@ bool lshift_cheap_p (bool);
 bool supports_vec_gather_load_p (machine_mode = E_VOIDmode,
                                 vec<int> * = nullptr);
 bool supports_vec_scatter_store_p (machine_mode = E_VOIDmode);
+opt_machine_mode get_absneg_bit_mode (optab, machine_mode,
+                                     scalar_float_mode, int *);
 bool can_vec_extract (machine_mode, machine_mode);
 bool can_open_code_p (optab, machine_mode);
 bool can_implement_p (optab, machine_mode);
diff --git a/gcc/optabs.cc b/gcc/optabs.cc
index fa51e498a98e..b9c51f78af41 100644
--- a/gcc/optabs.cc
+++ b/gcc/optabs.cc
@@ -3101,48 +3101,37 @@ expand_ffs (scalar_int_mode mode, rtx op0, rtx target)
 }
 
 /* Expand a floating point absolute value or negation operation via a
-   logical operation on the sign bit.  */
+   logical operation on the sign bit.  MODE is the mode of the operands
+   and FMODE is the scalar inner mode.  */
 
 static rtx
-expand_absneg_bit (enum rtx_code code, scalar_float_mode mode,
-                  rtx op0, rtx target)
+expand_absneg_bit (rtx_code code, machine_mode mode,
+                  scalar_float_mode fmode, rtx op0, rtx target)
 {
-  const struct real_format *fmt;
   int bitpos, word, nwords, i;
+  machine_mode new_mode;
   scalar_int_mode imode;
   rtx temp;
   rtx_insn *insns;
 
-  /* The format has to have a simple sign bit.  */
-  fmt = REAL_MODE_FORMAT (mode);
-  if (fmt == NULL)
-    return NULL_RTX;
-
-  bitpos = fmt->signbit_rw;
-  if (bitpos < 0)
+  auto op = code == NEG ? neg_optab : abs_optab;
+  if (!get_absneg_bit_mode (op, mode, fmode, &bitpos).exists (&new_mode))
     return NULL_RTX;
 
-  /* Don't create negative zeros if the format doesn't support them.  */
-  if (code == NEG && !fmt->has_signed_zero)
-    return NULL_RTX;
-
-  if (GET_MODE_SIZE (mode) <= UNITS_PER_WORD)
+  imode = as_a<scalar_int_mode> (GET_MODE_INNER (new_mode));
+  if (VECTOR_MODE_P (mode) || GET_MODE_SIZE (fmode) <= UNITS_PER_WORD)
     {
-      if (!int_mode_for_mode (mode).exists (&imode))
-       return NULL_RTX;
       word = 0;
       nwords = 1;
     }
   else
     {
-      imode = word_mode;
-
       if (FLOAT_WORDS_BIG_ENDIAN)
-       word = (GET_MODE_BITSIZE (mode) - bitpos) / BITS_PER_WORD;
+       word = (GET_MODE_BITSIZE (fmode) - bitpos) / BITS_PER_WORD;
       else
        word = bitpos / BITS_PER_WORD;
       bitpos = bitpos % BITS_PER_WORD;
-      nwords = (GET_MODE_BITSIZE (mode) + BITS_PER_WORD - 1) / BITS_PER_WORD;
+      nwords = (GET_MODE_BITSIZE (fmode) + BITS_PER_WORD - 1) / BITS_PER_WORD;
     }
 
   wide_int mask = wi::set_bit_in_zero (bitpos, GET_MODE_PRECISION (imode));
@@ -3184,11 +3173,13 @@ expand_absneg_bit (enum rtx_code code, scalar_float_mode mode,
     }
   else
     {
-      temp = expand_binop (imode, code == ABS ? and_optab : xor_optab,
-                          gen_lowpart (imode, op0),
-                          immed_wide_int_const (mask, imode),
-                          gen_lowpart (imode, target), 1, OPTAB_LIB_WIDEN);
-      target = force_lowpart_subreg (mode, temp, imode);
+      rtx mask_rtx = immed_wide_int_const (mask, imode);
+      if (VECTOR_MODE_P (new_mode))
+       mask_rtx = gen_const_vec_duplicate (new_mode, mask_rtx);
+      temp = expand_binop (new_mode, code == ABS ? and_optab : xor_optab,
+                          gen_lowpart (new_mode, op0), mask_rtx,
+                          gen_lowpart (new_mode, target), 1, OPTAB_LIB_WIDEN);
+      target = force_lowpart_subreg (mode, temp, new_mode);
 
       set_dst_reg_note (get_last_insn (), REG_EQUAL,
                        gen_rtx_fmt_e (code, mode, copy_rtx (op0)),
@@ -3478,9 +3469,9 @@ expand_unop (machine_mode mode, optab unoptab, rtx op0, rtx target,
   if (optab_to_code (unoptab) == NEG)
     {
       /* Try negating floating point values by flipping the sign bit.  */
-      if (is_a <scalar_float_mode> (mode, &float_mode))
+      if (is_a <scalar_float_mode> (GET_MODE_INNER (mode), &float_mode))
        {
-         temp = expand_absneg_bit (NEG, float_mode, op0, target);
+         temp = expand_absneg_bit (NEG, mode, float_mode, op0, target);
          if (temp)
            return temp;
        }
@@ -3698,9 +3689,9 @@ expand_abs_nojump (machine_mode mode, rtx op0, rtx target,
 
   /* For floating point modes, try clearing the sign bit.  */
   scalar_float_mode float_mode;
-  if (is_a <scalar_float_mode> (mode, &float_mode))
+  if (is_a <scalar_float_mode> (GET_MODE_INNER (mode), &float_mode))
     {
-      temp = expand_absneg_bit (ABS, float_mode, op0, target);
+      temp = expand_absneg_bit (ABS, mode, float_mode, op0, target);
       if (temp)
        return temp;
     }
diff --git a/gcc/testsuite/gcc.target/aarch64/abs_bf_1.c b/gcc/testsuite/gcc.target/aarch64/abs_bf_1.c
new file mode 100644
index 000000000000..42e03bca0bea
--- /dev/null
+++ b/gcc/testsuite/gcc.target/aarch64/abs_bf_1.c
@@ -0,0 +1,10 @@
+/* { dg-options "-O2 -ffast-math" } */
+/* Check that fabs of bf16 elements uses AND/BIC on vector registers.  */
+void
+foo (__bf16 *ptr)
+{
+  for (int i = 0; i < 8; ++i)
+    ptr[i] = __builtin_fabsf (ptr[i]);
+}
+
+/* { dg-final { scan-assembler {\t(?:bic|and)\t[zv]} } } */
diff --git a/gcc/testsuite/gcc.target/aarch64/neg_bf_1.c b/gcc/testsuite/gcc.target/aarch64/neg_bf_1.c
new file mode 100644
index 000000000000..564ff1ec9cb2
--- /dev/null
+++ b/gcc/testsuite/gcc.target/aarch64/neg_bf_1.c
@@ -0,0 +1,11 @@
+/* { dg-options "-O2" } */
+/* Check that negating bf16 vectors uses EOR on vector registers.  */
+typedef __bf16 v8bf __attribute__((vector_size(16)));
+typedef __bf16 v16bf __attribute__((vector_size(32)));
+typedef __bf16 v64bf __attribute__((vector_size(128)));
+
+v8bf f1(v8bf x) { return -x; }
+v16bf f2(v16bf x) { return -x; }
+v64bf f3(v64bf x) { return -x; }
+/* Presumably one EOR per 16 bytes: 1 (f1) + 2 (f2) + 8 (f3) = 11.  */
+/* { dg-final { scan-assembler-times {\teor\t[zv]} 11 } } */
diff --git a/gcc/testsuite/gcc.target/aarch64/neg_bf_2.c b/gcc/testsuite/gcc.target/aarch64/neg_bf_2.c
new file mode 100644
index 000000000000..072924284188
--- /dev/null
+++ b/gcc/testsuite/gcc.target/aarch64/neg_bf_2.c
@@ -0,0 +1,10 @@
+/* { dg-options "-O2" } */
+/* Check that bf16 negation in this loop uses EOR on vector registers.  */
+void
+foo (__bf16 *ptr)
+{
+  for (int i = 0; i < 8; ++i)
+    ptr[i] = -ptr[i];
+}
+
+/* { dg-final { scan-assembler {\teor\t[zv]} } } */

Reply via email to