https://gcc.gnu.org/g:8fbea0880f7a4082203dd4ac32596e3970d5f7d9
commit 8fbea0880f7a4082203dd4ac32596e3970d5f7d9
Author: Michael Meissner <[email protected]>
Date:   Mon Oct 13 12:41:28 2025 -0400

    Rework bfloat16 to v4sf optimization.

    2025-10-13  Michael Meissner  <[email protected]>

    gcc/

            * config/rs6000/float16.cc (bfloat16_operation_as_v4sf): Rewrite
            bfloat16_binary_op_as_v4sf so it will be able to handle FMA
            operations in the future.
            * config/rs6000/float16.md (bfloat16_binary_op_internal1): Likewise.
            (bfloat16_binary_op_internal2): Likewise.
            (bfloat16_binary_op_internal3): Likewise.
            (bfloat16_binary_op_internal4): Likewise.
            (bfloat16_binary_op_internal5): Likewise.
            (bfloat16_binary_op_internal6): Likewise.
            * config/rs6000/rs6000-protos.h (enum bfloat16_operation): New
            enumeration.
            (bfloat16_binary_op_as_v4sf): Delete.
            (bfloat16_operation_as_v4sf): New declaration.
            * config/rs6000/vsx.md (vsx_fmav4sf4): Add generator.
            (vsx_fms<mode>4): Likewise.
            (vsx_nfma<mode>4): Likewise.
            (vsx_nfmsv4sf4): Likewise.

Diff:
---
 gcc/config/rs6000/float16.cc      | 185 ++++++++++++++++++++++----------------
 gcc/config/rs6000/float16.md      |  84 +++++++----------
 gcc/config/rs6000/rs6000-protos.h |  13 ++-
 gcc/config/rs6000/vsx.md          |   8 +-
 4 files changed, 151 insertions(+), 139 deletions(-)
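
As a rough illustration (not part of the patch itself), the kind of scalar
bfloat16 source this rework targets looks like the following, assuming a
target where TARGET_BFLOAT16_HW is set and the __bf16 scalar type is
available.  Each operation is splatted into a V4SFmode vector, performed
there, and converted back, instead of round-tripping every operand through
scalar SFmode:

    /* Sketch only.  bf16_add goes through the BF16_BINARY path added below;
       bf16_fma is the future FMA case this rework prepares for (whether
       a * b + c actually contracts to an FMA also depends on -ffp-contract).  */
    __bf16
    bf16_add (__bf16 a, __bf16 b)
    {
      return a + b;
    }

    __bf16
    bf16_fma (__bf16 a, __bf16 b, __bf16 c)
    {
      return a * b + c;
    }
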
diff --git a/gcc/config/rs6000/float16.cc b/gcc/config/rs6000/float16.cc
index 0d606609dab3..484d04f4ddb4 100644
--- a/gcc/config/rs6000/float16.cc
+++ b/gcc/config/rs6000/float16.cc
@@ -42,15 +42,14 @@
 #include "common/common-target.h"
 #include "rs6000-internal.h"
 
-/* Expand a bfloat16 floating point binary operation:
+/* Expand a bfloat16 floating point operation:
 
-   ICODE: Operation to perform.
-   OP0: Result (BFmode or SFmode).
-   OP1: First input argument (BFmode or SFmode).
-   OP2: Second input argument (BFmode or SFmode).
-   TMP0: Temporary for result (V4SFmode).
-   TMP1: Temporary for first input argument (V4SFmode).
-   TMP2: Temporary for second input argument (V4SFmode).
+   ICODE: Operation to perform.
+   RESULT: Result of the operation.
+   OP1: Input operand1.
+   OP2: Input operand2.
+   OP3: Input operand3 or NULL_RTX.
+   SUBTYPE: Describe the operation.
 
    The operation is done as a V4SFmode vector operation.  This is because
    converting BFmode from a scalar BFmode to SFmode to do the operation and
@@ -60,108 +59,136 @@
    SFmode.  */
 
 void
-bfloat16_binary_op_as_v4sf (enum rtx_code icode,
-                            rtx op0,
+bfloat16_operation_as_v4sf (enum rtx_code icode,
+                            rtx result,
                             rtx op1,
                             rtx op2,
-                            rtx tmp0,
-                            rtx tmp1,
-                            rtx tmp2)
+                            rtx op3,
+                            enum bfloat16_operation subtype)
 {
-  if (GET_CODE (tmp0) == SCRATCH)
-    tmp0 = gen_reg_rtx (V4SFmode);
+  gcc_assert (can_create_pseudo_p ());
 
-  if (GET_CODE (tmp1) == SCRATCH)
-    tmp1 = gen_reg_rtx (V4SFmode);
+  rtx result_v4sf = gen_reg_rtx (V4SFmode);
+  rtx ops_bf[3];
+  rtx ops_v4sf[3];
+  size_t n_opts;
 
-  if (GET_CODE (tmp2) == SCRATCH)
-    tmp2 = gen_reg_rtx (V4SFmode);
-
-  /* Convert operand1 and operand2 to V4SFmode format.  We use SPLAT for
-     registers to get the value into the upper 32-bits.  We can use XXSPLTW
-     to splat words instead of VSPLTIH since the XVCVBF16SPN instruction
-     ignores the odd half-words, and XXSPLTW can operate on all VSX registers
-     instead of just the Altivec registers.  Using SPLAT instead of a shift
-     also insure that other bits are not a signalling NaN.  If we are using
-     XXSPLTIW or XXSPLTIB to load the constant the other bits are duplicated.  */
-
-  /* Operand1.  */
-  if (GET_MODE (op1) == BFmode)
+  switch (subtype)
     {
-      emit_insn (gen_xxspltw_bf (tmp1, op1));
-      emit_insn (gen_xvcvbf16spn_bf (tmp1, tmp1));
+    case BF16_BINARY:
+      n_opts = 2;
+      ops_bf[0] = op1;
+      ops_bf[1] = op2;
+      gcc_assert (op3 == NULL_RTX);
+      break;
+
+    case BF16_FMA:
+    case BF16_FMS:
+    case BF16_NFMA:
+    case BF16_NFMS:
+      gcc_assert (icode == FMA);
+      n_opts = 3;
+      ops_bf[0] = op1;
+      ops_bf[1] = op2;
+      ops_bf[2] = op3;
+      break;
+
+    default:
+      gcc_unreachable ();
     }
 
-  else if (GET_MODE (op1) == SFmode)
-    emit_insn (gen_vsx_splat_v4sf (tmp1,
-                                   force_reg (SFmode, op1)));
-
-  else
-    gcc_unreachable ();
-
-  /* Operand2.  */
-  if (GET_MODE (op2) == BFmode)
+  for (size_t i = 0; i < n_opts; i++)
     {
-      if (REG_P (op2) || SUBREG_P (op2))
-        emit_insn (gen_xxspltw_bf (tmp2, op2));
+      rtx op = ops_bf[i];
+      rtx tmp = ops_v4sf[i] = gen_reg_rtx (V4SFmode);
+
+      gcc_assert (op != NULL_RTX);
 
-      else if (op2 == CONST0_RTX (BFmode))
-        emit_move_insn (tmp2, CONST0_RTX (V4SFmode));
+      /* Convert operands to V4SFmode format.  We use SPLAT for registers to
+         get the value into the upper 32-bits.  We can use XXSPLTW to splat
+         words instead of VSPLTIH since the XVCVBF16SPN instruction ignores the
+         odd half-words, and XXSPLTW can operate on all VSX registers instead
+         of just the Altivec registers.  Using SPLAT instead of a shift also
+         ensures that other bits are not a signalling NaN.  If we are using
+         XXSPLTIW or XXSPLTIB to load the constant the other bits are
+         duplicated.  */
 
-      else if (fp16_xxspltiw_constant (op2, BFmode))
+      if (GET_MODE (op) == BFmode)
         {
-          rtx op2_bf = gen_lowpart (BFmode, tmp2);
-          emit_move_insn (op2_bf, op2);
+          emit_insn (gen_xxspltw_bf (tmp, op));
+          emit_insn (gen_xvcvbf16spn_bf (tmp, tmp));
         }
 
-      else
-        gcc_unreachable ();
+      else if (op == CONST0_RTX (SFmode)
+               || op == CONST0_RTX (BFmode))
+        emit_move_insn (tmp, CONST0_RTX (V4SFmode));
 
-      emit_insn (gen_xvcvbf16spn_bf (tmp2, tmp2));
-    }
+      else if (GET_MODE (op) == SFmode)
+        {
+          if (GET_CODE (op) == CONST_DOUBLE)
+            {
+              rtvec v = rtvec_alloc (4);
 
-  else if (GET_MODE (op2) == SFmode)
-    {
-      if (REG_P (op2) || SUBREG_P (op2))
-        emit_insn (gen_vsx_splat_v4sf (tmp2, op2));
+              for (size_t i = 0; i < 4; i++)
+                RTVEC_ELT (v, i) = op;
 
-      else if (op2 == CONST0_RTX (SFmode))
-        emit_move_insn (tmp2, CONST0_RTX (V4SFmode));
+              emit_insn (gen_rtx_SET (tmp,
+                                      gen_rtx_CONST_VECTOR (V4SFmode, v)));
+            }
 
-      else if (GET_CODE (op2) == CONST_DOUBLE)
-        {
-          rtvec v = rtvec_alloc (4);
-          RTVEC_ELT (v, 0) = op2;
-          RTVEC_ELT (v, 1) = op2;
-          RTVEC_ELT (v, 2) = op2;
-          RTVEC_ELT (v, 3) = op2;
-          emit_insn (gen_rtx_SET (tmp2,
-                                  gen_rtx_CONST_VECTOR (V4SFmode, v)));
+          else
+            emit_insn (gen_vsx_splat_v4sf (tmp,
+                                           force_reg (SFmode, op)));
         }
 
       else
-        emit_insn (gen_vsx_splat_v4sf (tmp2,
-                                       force_reg (SFmode, op2)));
+        gcc_unreachable ();
     }
 
-  else
-    gcc_unreachable ();
-
-  /* Do the operation in V4SFmode.  */
-  emit_insn (gen_rtx_SET (tmp0,
-                          gen_rtx_fmt_ee (icode, V4SFmode, tmp1, tmp2)));
+  switch (subtype)
+    {
+    case BF16_BINARY:
+      emit_insn (gen_rtx_SET (result_v4sf,
+                              gen_rtx_fmt_ee (icode, V4SFmode,
+                                              ops_v4sf[0],
+                                              ops_v4sf[1])));
+      break;
+
+    case BF16_FMA:
+      emit_insn (gen_vsx_fmav4sf4 (result_v4sf, ops_v4sf[0], ops_v4sf[1],
+                                   ops_v4sf[2]));
+      break;
+
+    case BF16_FMS:
+      emit_insn (gen_vsx_fmsv4sf4 (result_v4sf, ops_v4sf[0], ops_v4sf[1],
+                                   ops_v4sf[2]));
+      break;
+
+    case BF16_NFMA:
+      emit_insn (gen_vsx_nfmav4sf4 (result_v4sf, ops_v4sf[0], ops_v4sf[1],
+                                    ops_v4sf[2]));
+      break;
+
+    case BF16_NFMS:
+      emit_insn (gen_vsx_nfmsv4sf4 (result_v4sf, ops_v4sf[0], ops_v4sf[1],
+                                    ops_v4sf[2]));
+      break;
+
+    default:
+      gcc_unreachable ();
+    }
 
   /* Convert V4SF result back to scalar mode.  */
-  if (GET_MODE (op0) == BFmode)
-    emit_insn (gen_xvcvspbf16_bf (op0, tmp0));
+  if (GET_MODE (result) == BFmode)
+    emit_insn (gen_xvcvspbf16_bf (result, result_v4sf));
 
-  else if (GET_MODE (op0) == SFmode)
+  else if (GET_MODE (result) == SFmode)
     {
       rtx element = GEN_INT (WORDS_BIG_ENDIAN ? 2 : 3);
-      emit_insn (gen_vsx_extract_v4sf (op0, tmp0, element));
+      emit_insn (gen_vsx_extract_v4sf (result, result_v4sf, element));
     }
 
   else
     gcc_unreachable ();
 }
-
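
The rewritten helper keeps a single entry point for both shapes of
operation.  Only the BF16_BINARY callers in float16.md below exist in this
patch; a future FMA expander could invoke the same helper roughly as follows
(hypothetical call for illustration, not committed code):

    /* Hypothetical caller.  OP3 is only meaningful for the FMA subtypes;
       the binary-op splitters below pass NULL_RTX and BF16_BINARY.  */
    bfloat16_operation_as_v4sf (FMA, operands[0], operands[1],
                                operands[2], operands[3], BF16_FMA);
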
diff --git a/gcc/config/rs6000/float16.md b/gcc/config/rs6000/float16.md
index bab03ffddb6e..3715bde0df03 100644
--- a/gcc/config/rs6000/float16.md
+++ b/gcc/config/rs6000/float16.md
@@ -450,22 +450,18 @@
       [(float_extend:SF
          (match_operand:BF 2 "vsx_register_operand" "wa"))
        (float_extend:SF
-         (match_operand:BF 3 "vsx_register_operand" "wa"))]))
-   (clobber (match_scratch:V4SF 4 "=&wa"))
-   (clobber (match_scratch:V4SF 5 "=&wa"))
-   (clobber (match_scratch:V4SF 6 "=&wa"))]
-  "TARGET_BFLOAT16_HW"
+         (match_operand:BF 3 "vsx_register_operand" "wa"))]))]
+  "TARGET_BFLOAT16_HW && can_create_pseudo_p ()"
   "#"
   "&& 1"
   [(pc)]
 {
-  bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
+  bfloat16_operation_as_v4sf (GET_CODE (operands[1]),
                               operands[0],
                               operands[2],
                               operands[3],
-                              operands[4],
-                              operands[5],
-                              operands[6]);
+                              NULL_RTX,
+                              BF16_BINARY);
   DONE;
 })
 
@@ -476,22 +472,18 @@
       [(float_extend:SF
          (match_operand:BF 2 "vsx_register_operand" "wa"))
        (float_extend:SF
-         (match_operand:BF 3 "vsx_register_operand" "wa"))])))
-   (clobber (match_scratch:V4SF 4 "=&wa"))
-   (clobber (match_scratch:V4SF 5 "=&wa"))
-   (clobber (match_scratch:V4SF 6 "=&wa"))]
-  "TARGET_BFLOAT16_HW"
+         (match_operand:BF 3 "vsx_register_operand" "wa"))])))]
+  "TARGET_BFLOAT16_HW && can_create_pseudo_p ()"
   "#"
   "&& 1"
   [(pc)]
 {
-  bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
+  bfloat16_operation_as_v4sf (GET_CODE (operands[1]),
                               operands[0],
                               operands[2],
                               operands[3],
-                              operands[4],
-                              operands[5],
-                              operands[6]);
+                              NULL_RTX,
+                              BF16_BINARY);
   DONE;
 })
 
@@ -500,22 +492,18 @@
     (match_operator:SF 1
      "bfloat16_binary_operator"
      [(float_extend:SF
        (match_operand:BF 2 "vsx_register_operand" "wa,wa,wa"))
-      (match_operand:SF 3 "input_operand" "wa,j,eP")]))
-   (clobber (match_scratch:V4SF 4 "=&wa,&wa,&wa"))
-   (clobber (match_scratch:V4SF 5 "=&wa,&wa,&wa"))
-   (clobber (match_scratch:V4SF 6 "=&wa,&wa,&wa"))]
-  "TARGET_BFLOAT16_HW"
+      (match_operand:SF 3 "input_operand" "wa,j,eP")]))]
+  "TARGET_BFLOAT16_HW && can_create_pseudo_p ()"
   "#"
   "&& 1"
   [(pc)]
 {
-  bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
+  bfloat16_operation_as_v4sf (GET_CODE (operands[1]),
                               operands[0],
                               operands[2],
                               operands[3],
-                              operands[4],
-                              operands[5],
-                              operands[6]);
+                              NULL_RTX,
+                              BF16_BINARY);
   DONE;
 })
 
@@ -525,22 +513,18 @@
     (match_operator:SF 1
      "bfloat16_binary_operator"
      [(float_extend:SF
        (match_operand:BF 2 "vsx_register_operand" "wa,wa,wa"))
-      (match_operand:SF 3 "input_operand" "wa,j,eP")])))
-   (clobber (match_scratch:V4SF 4 "=&wa,&wa,&wa"))
-   (clobber (match_scratch:V4SF 5 "=&wa,&wa,&wa"))
-   (clobber (match_scratch:V4SF 6 "=&wa,&wa,&wa"))]
-  "TARGET_BFLOAT16_HW"
+      (match_operand:SF 3 "input_operand" "wa,j,eP")])))]
+  "TARGET_BFLOAT16_HW && can_create_pseudo_p ()"
   "#"
   "&& 1"
   [(pc)]
 {
-  bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
+  bfloat16_operation_as_v4sf (GET_CODE (operands[1]),
                               operands[0],
                               operands[2],
                               operands[3],
-                              operands[4],
-                              operands[5],
-                              operands[6]);
+                              NULL_RTX,
+                              BF16_BINARY);
   DONE;
 })
 
@@ -549,22 +533,18 @@
     (match_operator:SF 1
      "bfloat16_binary_operator"
      [(match_operand:SF 2 "vsx_register_operand" "wa")
       (float_extend:SF
-       (match_operand:BF 3 "vsx_register_operand" "wa"))]))
-   (clobber (match_scratch:V4SF 4 "=&wa"))
-   (clobber (match_scratch:V4SF 5 "=&wa"))
-   (clobber (match_scratch:V4SF 6 "=&wa"))]
-  "TARGET_BFLOAT16_HW"
+       (match_operand:BF 3 "vsx_register_operand" "wa"))]))]
+  "TARGET_BFLOAT16_HW && can_create_pseudo_p ()"
   "#"
   "&& 1"
   [(pc)]
 {
-  bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
+  bfloat16_operation_as_v4sf (GET_CODE (operands[1]),
                               operands[0],
                               operands[2],
                               operands[3],
-                              operands[4],
-                              operands[5],
-                              operands[6]);
+                              NULL_RTX,
+                              BF16_BINARY);
   DONE;
 })
 
@@ -574,22 +554,18 @@
     (match_operator:SF 1
      "bfloat16_binary_operator"
      [(match_operand:SF 3 "vsx_register_operand" "wa")
       (float_extend:SF
-       (match_operand:BF 2 "vsx_register_operand" "wa"))])))
-   (clobber (match_scratch:V4SF 4 "=&wa"))
-   (clobber (match_scratch:V4SF 5 "=&wa"))
-   (clobber (match_scratch:V4SF 6 "=&wa"))]
-  "TARGET_BFLOAT16_HW"
+       (match_operand:BF 2 "vsx_register_operand" "wa"))])))]
+  "TARGET_BFLOAT16_HW && can_create_pseudo_p ()"
   "#"
   "&& 1"
   [(pc)]
 {
-  bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
+  bfloat16_operation_as_v4sf (GET_CODE (operands[1]),
                               operands[0],
                               operands[2],
                               operands[3],
-                              operands[4],
-                              operands[5],
-                              operands[6]);
+                              NULL_RTX,
+                              BF16_BINARY);
   DONE;
 })
diff --git a/gcc/config/rs6000/rs6000-protos.h b/gcc/config/rs6000/rs6000-protos.h
index 063f74f6e3f6..db38468df816 100644
--- a/gcc/config/rs6000/rs6000-protos.h
+++ b/gcc/config/rs6000/rs6000-protos.h
@@ -260,8 +260,17 @@ extern unsigned constant_generates_xxspltiw (vec_const_128bit_type *);
 extern unsigned constant_generates_xxspltidp (vec_const_128bit_type *);
 
 /* From float16.cc.  */
-extern void bfloat16_binary_op_as_v4sf (enum rtx_code, rtx, rtx, rtx,
-                                        rtx, rtx, rtx);
+/* Optimize bfloat16 operations.  */
+enum bfloat16_operation {
+  BF16_BINARY,          /* Bfloat16 binary op.  */
+  BF16_FMA,             /* (a * b) + c.  */
+  BF16_FMS,             /* (a * b) - c.  */
+  BF16_NFMA,            /* - ((a * b) + c).  */
+  BF16_NFMS             /* - ((a * b) - c).  */
+};
+
+extern void bfloat16_operation_as_v4sf (enum rtx_code, rtx, rtx, rtx, rtx,
+                                        enum bfloat16_operation);
 #endif /* RTX_CODE */
 
 #ifdef TREE_CODE
diff --git a/gcc/config/rs6000/vsx.md b/gcc/config/rs6000/vsx.md
index 6c11d7766ed1..2611660921a5 100644
--- a/gcc/config/rs6000/vsx.md
+++ b/gcc/config/rs6000/vsx.md
@@ -2098,7 +2098,7 @@
 ;; vmaddfp and vnmsubfp can have different behaviors than the VSX instructions
 ;; in some corner cases due to VSCR[NJ] being set or if the addend is +0.0
 ;; instead of -0.0.
-(define_insn "*vsx_fmav4sf4"
+(define_insn "vsx_fmav4sf4"
   [(set (match_operand:V4SF 0 "vsx_register_operand" "=wa,wa")
         (fma:V4SF
           (match_operand:V4SF 1 "vsx_register_operand" "%wa,wa")
@@ -2122,7 +2122,7 @@
    xvmaddmdp %x0,%x1,%x3"
   [(set_attr "type" "vecdouble")])
 
-(define_insn "*vsx_fms<mode>4"
+(define_insn "vsx_fms<mode>4"
   [(set (match_operand:VSX_F 0 "vsx_register_operand" "=wa,wa")
         (fma:VSX_F
           (match_operand:VSX_F 1 "vsx_register_operand" "%wa,wa")
@@ -2135,7 +2135,7 @@
    xvmsubm<sd>p %x0,%x1,%x3"
   [(set_attr "type" "<VStype_mul>")])
 
-(define_insn "*vsx_nfma<mode>4"
+(define_insn "vsx_nfma<mode>4"
   [(set (match_operand:VSX_F 0 "vsx_register_operand" "=wa,wa")
         (neg:VSX_F
          (fma:VSX_F
@@ -2148,7 +2148,7 @@
    xvnmaddm<sd>p %x0,%x1,%x3"
  [(set_attr "type" "<VStype_mul>")])
 
-(define_insn "*vsx_nfmsv4sf4"
+(define_insn "vsx_nfmsv4sf4"
   [(set (match_operand:V4SF 0 "vsx_register_operand" "=wa,wa")
         (neg:V4SF
          (fma:V4SF
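
Dropping the leading "*" from these four define_insn names is what the
ChangeLog means by "Add generator": genemit only creates a gen_<name>
function for a named pattern, so float16.cc can now emit the V4SF fused
multiply-add forms directly, as in the new BF16_FMA case above:

    /* The call made possible by naming the pattern "vsx_fmav4sf4".  */
    emit_insn (gen_vsx_fmav4sf4 (result_v4sf, ops_v4sf[0], ops_v4sf[1],
                                 ops_v4sf[2]));
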
