On Fri, Nov 1, 2024 at 8:33 AM Hongyu Wang <hongyu.w...@intel.com> wrote: > > From: Levy Hsu <ad...@levyhsu.com> > > This patch enables the use of the VCOMSBF16 instruction from AVX10.2 for > efficient BF16 comparisons. > > Bootstrapped & regtested on x86-64-pc-linux-gnu. > Ok for trunk? Ok. > > gcc/ChangeLog: > > * config/i386/i386-expand.cc (ix86_expand_branch): Handle BFmode > when TARGET_AVX10_2_256 is enabled. > (ix86_prepare_fp_compare_args): Use SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P. > (ix86_expand_fp_movcc): Ditto. > (ix86_expand_fp_compare): Handle BFmode under IX86_FPCMP_COMI. > * config/i386/i386.cc (ix86_multiplication_cost): Use > SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P. > (ix86_division_cost): Ditto. > (ix86_rtx_costs): Ditto. > (ix86_vector_costs::add_stmt_cost): Ditto. > * config/i386/i386.h (SSE_FLOAT_MODE_SSEMATH_OR_HF_P): Rename to ... > (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P): ...this, and add BFmode. > * config/i386/i386.md (cbranchbf4): Call ix86_expand_branch > directly when TARGET_AVX10_2_256 && !flag_trapping_math instead of > converting the operands to SFmode. > (*cmpibf): New define_insn. > > gcc/testsuite/ChangeLog: > > * gcc.target/i386/avx10_2-comibf-1.c: New test. > * gcc.target/i386/avx10_2-comibf-2.c: Ditto. 
> --- > gcc/config/i386/i386-expand.cc | 22 ++-- > gcc/config/i386/i386.cc | 22 ++-- > gcc/config/i386/i386.h | 7 +- > gcc/config/i386/i386.md | 33 +++-- > .../gcc.target/i386/avx10_2-comibf-1.c | 40 ++++++ > .../gcc.target/i386/avx10_2-comibf-2.c | 118 ++++++++++++++++++ > 6 files changed, 214 insertions(+), 28 deletions(-) > create mode 100644 gcc/testsuite/gcc.target/i386/avx10_2-comibf-1.c > create mode 100644 gcc/testsuite/gcc.target/i386/avx10_2-comibf-2.c > > diff --git a/gcc/config/i386/i386-expand.cc b/gcc/config/i386/i386-expand.cc > index 0de0e842731..96e4659da10 100644 > --- a/gcc/config/i386/i386-expand.cc > +++ b/gcc/config/i386/i386-expand.cc > @@ -2531,6 +2531,10 @@ ix86_expand_branch (enum rtx_code code, rtx op0, rtx > op1, rtx label) > emit_jump_insn (gen_rtx_SET (pc_rtx, tmp)); > return; > > + case E_BFmode: > + gcc_assert (TARGET_AVX10_2_256 && !flag_trapping_math); > + goto simple; > + > case E_DImode: > if (TARGET_64BIT) > goto simple; > @@ -2797,9 +2801,9 @@ ix86_prepare_fp_compare_args (enum rtx_code code, rtx > *pop0, rtx *pop1) > bool unordered_compare = ix86_unordered_fp_compare (code); > rtx op0 = *pop0, op1 = *pop1; > machine_mode op_mode = GET_MODE (op0); > - bool is_sse = SSE_FLOAT_MODE_SSEMATH_OR_HF_P (op_mode); > + bool is_sse = SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (op_mode); > > - if (op_mode == BFmode) > + if (op_mode == BFmode && (!TARGET_AVX10_2_256 || flag_trapping_math)) > { > rtx op = gen_lowpart (HImode, op0); > if (CONST_INT_P (op)) > @@ -2918,10 +2922,14 @@ ix86_expand_fp_compare (enum rtx_code code, rtx op0, > rtx op1) > { > case IX86_FPCMP_COMI: > tmp = gen_rtx_COMPARE (CCFPmode, op0, op1); > - if (TARGET_AVX10_2_256 && (code == EQ || code == NE)) > - tmp = gen_rtx_UNSPEC (CCFPmode, gen_rtvec (1, tmp), UNSPEC_OPTCOMX); > - if (unordered_compare) > - tmp = gen_rtx_UNSPEC (CCFPmode, gen_rtvec (1, tmp), UNSPEC_NOTRAP); > + /* We only have vcomsbf16, No vcomubf16 nor vcomxbf16 */ > + if (GET_MODE (op0) != E_BFmode) > + { > + if 
(TARGET_AVX10_2_256 && (code == EQ || code == NE)) > + tmp = gen_rtx_UNSPEC (CCFPmode, gen_rtvec (1, tmp), > UNSPEC_OPTCOMX); > + if (unordered_compare) > + tmp = gen_rtx_UNSPEC (CCFPmode, gen_rtvec (1, tmp), > UNSPEC_NOTRAP); > + } > cmp_mode = CCFPmode; > emit_insn (gen_rtx_SET (gen_rtx_REG (CCFPmode, FLAGS_REG), tmp)); > break; > @@ -4636,7 +4644,7 @@ ix86_expand_fp_movcc (rtx operands[]) > && !ix86_fp_comparison_operator (operands[1], VOIDmode)) > return false; > > - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode)) > + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode)) > { > machine_mode cmode; > > diff --git a/gcc/config/i386/i386.cc b/gcc/config/i386/i386.cc > index 473e4cbf10e..6ac3a5d55f2 100644 > --- a/gcc/config/i386/i386.cc > +++ b/gcc/config/i386/i386.cc > @@ -21324,7 +21324,7 @@ ix86_multiplication_cost (const struct > processor_costs *cost, > if (VECTOR_MODE_P (mode)) > inner_mode = GET_MODE_INNER (mode); > > - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode)) > + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode)) > return inner_mode == DFmode ? cost->mulsd : cost->mulss; > else if (X87_FLOAT_MODE_P (mode)) > return cost->fmul; > @@ -21449,7 +21449,7 @@ ix86_division_cost (const struct processor_costs > *cost, > if (VECTOR_MODE_P (mode)) > inner_mode = GET_MODE_INNER (mode); > > - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode)) > + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode)) > return inner_mode == DFmode ? 
cost->divsd : cost->divss; > else if (X87_FLOAT_MODE_P (mode)) > return cost->fdiv; > @@ -21991,7 +21991,7 @@ ix86_rtx_costs (rtx x, machine_mode mode, int > outer_code_i, int opno, > return true; > } > > - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode)) > + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode)) > *total = cost->addss; > else if (X87_FLOAT_MODE_P (mode)) > *total = cost->fadd; > @@ -22198,7 +22198,7 @@ ix86_rtx_costs (rtx x, machine_mode mode, int > outer_code_i, int opno, > return false; > > case NEG: > - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode)) > + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode)) > *total = cost->sse_op; > else if (X87_FLOAT_MODE_P (mode)) > *total = cost->fchs; > @@ -22306,14 +22306,14 @@ ix86_rtx_costs (rtx x, machine_mode mode, int > outer_code_i, int opno, > return false; > > case FLOAT_EXTEND: > - if (!SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode)) > + if (!SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode)) > *total = 0; > else > *total = ix86_vec_cost (mode, cost->addss); > return false; > > case FLOAT_TRUNCATE: > - if (!SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode)) > + if (!SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode)) > *total = cost->fadd; > else > *total = ix86_vec_cost (mode, cost->addss); > @@ -22323,7 +22323,7 @@ ix86_rtx_costs (rtx x, machine_mode mode, int > outer_code_i, int opno, > /* SSE requires memory load for the constant operand. It may make > sense to account for this. Of course the constant operand may or > may not be reused. */ > - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode)) > + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode)) > *total = cost->sse_op; > else if (X87_FLOAT_MODE_P (mode)) > *total = cost->fabs; > @@ -22334,7 +22334,7 @@ ix86_rtx_costs (rtx x, machine_mode mode, int > outer_code_i, int opno, > return false; > > case SQRT: > - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode)) > + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode)) > *total = mode == SFmode ? 
cost->sqrtss : cost->sqrtsd; > else if (X87_FLOAT_MODE_P (mode)) > *total = cost->fsqrt; > @@ -25083,7 +25083,7 @@ ix86_vector_costs::add_stmt_cost (int count, > vect_cost_for_stmt kind, > case MINUS_EXPR: > if (kind == scalar_stmt) > { > - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode)) > + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode)) > stmt_cost = ix86_cost->addss; > else if (X87_FLOAT_MODE_P (mode)) > stmt_cost = ix86_cost->fadd; > @@ -25109,7 +25109,7 @@ ix86_vector_costs::add_stmt_cost (int count, > vect_cost_for_stmt kind, > break; > > case NEGATE_EXPR: > - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode)) > + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode)) > stmt_cost = ix86_cost->sse_op; > else if (X87_FLOAT_MODE_P (mode)) > stmt_cost = ix86_cost->fchs; > @@ -25165,7 +25165,7 @@ ix86_vector_costs::add_stmt_cost (int count, > vect_cost_for_stmt kind, > case BIT_XOR_EXPR: > case BIT_AND_EXPR: > case BIT_NOT_EXPR: > - if (SSE_FLOAT_MODE_SSEMATH_OR_HF_P (mode)) > + if (SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P (mode)) > stmt_cost = ix86_cost->sse_op; > else if (VECTOR_MODE_P (mode)) > stmt_cost = ix86_vec_cost (mode, ix86_cost->sse_op); > diff --git a/gcc/config/i386/i386.h b/gcc/config/i386/i386.h > index 51934400951..a4874a46dc7 100644 > --- a/gcc/config/i386/i386.h > +++ b/gcc/config/i386/i386.h > @@ -1158,9 +1158,10 @@ extern const char *host_detect_local_cpu (int argc, > const char **argv); > #define SSE_FLOAT_MODE_P(MODE) \ > ((TARGET_SSE && (MODE) == SFmode) || (TARGET_SSE2 && (MODE) == DFmode)) > > -#define SSE_FLOAT_MODE_SSEMATH_OR_HF_P(MODE) \ > - ((SSE_FLOAT_MODE_P (MODE) && TARGET_SSE_MATH) > \ > - || (TARGET_AVX512FP16 && (MODE) == HFmode)) > +#define SSE_FLOAT_MODE_SSEMATH_OR_HFBF_P(MODE) \ > + ((SSE_FLOAT_MODE_P (MODE) && TARGET_SSE_MATH) \ > + || (TARGET_AVX512FP16 && (MODE) == HFmode) \ > + || (TARGET_AVX10_2_256 && (MODE) == BFmode)) > > #define FMA4_VEC_FLOAT_MODE_P(MODE) \ > (TARGET_FMA4 && ((MODE) == V4SFmode || (MODE) == V2DFmode \ > diff --git 
a/gcc/config/i386/i386.md b/gcc/config/i386/i386.md > index fb6aaa81505..11855d793a4 100644 > --- a/gcc/config/i386/i386.md > +++ b/gcc/config/i386/i386.md > @@ -1814,13 +1814,21 @@ (define_expand "cbranchbf4" > (pc)))] > "TARGET_80387 || (SSE_FLOAT_MODE_P (SFmode) && TARGET_SSE_MATH)" > { > - rtx op1 = ix86_expand_fast_convert_bf_to_sf (operands[1]); > - rtx op2 = ix86_expand_fast_convert_bf_to_sf (operands[2]); > - do_compare_rtx_and_jump (op1, op2, GET_CODE (operands[0]), 0, > - SFmode, NULL_RTX, NULL, > - as_a <rtx_code_label *> (operands[3]), > - /* Unfortunately this isn't propagated. */ > - profile_probability::even ()); > + if (TARGET_AVX10_2_256 && !flag_trapping_math) > + { > + ix86_expand_branch (GET_CODE (operands[0]), > + operands[1], operands[2], operands[3]); > + } > + else > + { > + rtx op1 = ix86_expand_fast_convert_bf_to_sf (operands[1]); > + rtx op2 = ix86_expand_fast_convert_bf_to_sf (operands[2]); > + do_compare_rtx_and_jump (op1, op2, GET_CODE (operands[0]), 0, > + SFmode, NULL_RTX, NULL, > + as_a <rtx_code_label *> (operands[3]), > + /* Unfortunately this isn't propagated. */ > + profile_probability::even ()); > + } > DONE; > }) > > @@ -2096,6 +2104,17 @@ (define_insn "*cmpi<unord>hf" > (set_attr "prefix" "evex") > (set_attr "mode" "HF")]) > > +(define_insn "*cmpibf" > + [(set (reg:CCFP FLAGS_REG) > + (compare:CCFP > + (match_operand:BF 0 "register_operand" "v") > + (match_operand:BF 1 "nonimmediate_operand" "vm")))] > + "TARGET_AVX10_2_256" > + "vcomsbf16\t{%1, %0|%0, %1}" > + [(set_attr "type" "ssecomi") > + (set_attr "prefix" "evex") > + (set_attr "mode" "BF")]) > + > ;; Set carry flag. 
> (define_insn "x86_stc" > [(set (reg:CCC FLAGS_REG) (unspec:CCC [(const_int 0)] UNSPEC_STC))] > diff --git a/gcc/testsuite/gcc.target/i386/avx10_2-comibf-1.c > b/gcc/testsuite/gcc.target/i386/avx10_2-comibf-1.c > new file mode 100644 > index 00000000000..85b773b89f2 > --- /dev/null > +++ b/gcc/testsuite/gcc.target/i386/avx10_2-comibf-1.c > @@ -0,0 +1,40 @@ > +/* { dg-do compile } */ > +/* { dg-options "-march=x86-64-v3 -mavx10.2 -O2 -fno-trapping-math" } */ > +/* { dg-final { scan-assembler-times "vcomsbf16\[ > \\t\]+\[^{}\n\]*%xmm\[0-9\]+(?:\n|\[ \\t\]+#)" 6 } } */ > +/* { dg-final { scan-assembler-times {j[a-z]+\s} 6 } } */ > + > +__bf16 > +foo_eq (__bf16 a, __bf16 b, __bf16 c, __bf16 d) > +{ > + return a == b ? c + d : c - d; > +} > + > +__bf16 > +foo_ne (__bf16 a, __bf16 b, __bf16 c, __bf16 d) > +{ > + return a != b ? c + d : c - d; > +} > + > +__bf16 > +foo_lt (__bf16 a, __bf16 b, __bf16 c, __bf16 d) > +{ > + return a < b ? c + d : c - d; > +} > + > +__bf16 > +foo_le (__bf16 a, __bf16 b, __bf16 c, __bf16 d) > +{ > + return a <= b ? c + d : c - d; > +} > + > +__bf16 > +foo_gt (__bf16 a, __bf16 b, __bf16 c, __bf16 d) > +{ > + return a > b ? c + d : c - d; > +} > + > +__bf16 > +foo_ge (__bf16 a, __bf16 b, __bf16 c, __bf16 d) > +{ > + return a >= b ? 
c + d : c - d; > +} > diff --git a/gcc/testsuite/gcc.target/i386/avx10_2-comibf-2.c > b/gcc/testsuite/gcc.target/i386/avx10_2-comibf-2.c > new file mode 100644 > index 00000000000..126957bf272 > --- /dev/null > +++ b/gcc/testsuite/gcc.target/i386/avx10_2-comibf-2.c > @@ -0,0 +1,118 @@ > + /* { dg-do run } */ > +/* { dg-options "-march=x86-64-v3 -mavx10.2 -O2 -fno-trapping-math" } */ > + > +#include <stdlib.h> > +#include <stdint.h> > +#include <string.h> > + > +/* Fast shift conversion here for convenience */ > +static __bf16 > +float_to_bf16 (float f) > +{ > + uint32_t float_bits; > + uint16_t bf16_bits; > + > + memcpy (&float_bits, &f, sizeof (float_bits)); > + bf16_bits = (uint16_t) (float_bits >> 16); > + > + __bf16 bf; > + memcpy (&bf, &bf16_bits, sizeof (bf)); > + return bf; > +} > + > +static float > +bf16_to_float (__bf16 bf) > +{ > + uint32_t float_bits; > + uint16_t bf16_bits; > + > + memcpy (&bf16_bits, &bf, sizeof (bf16_bits)); > + float_bits = ((uint32_t) bf16_bits) << 16; > + > + float f; > + memcpy (&f, &float_bits, sizeof (f)); > + return f; > +} > + > +static void > +test_eq (__bf16 a, __bf16 b) > +{ > + int result = (a == b); > + int expected = (bf16_to_float (a) == bf16_to_float (b)); > + if (result != expected) > + abort (); > +} > + > +static void > +test_ne (__bf16 a, __bf16 b) > +{ > + int result = (a != b); > + int expected = (bf16_to_float (a) != bf16_to_float (b)); > + if (result != expected) > + abort (); > +} > + > +static void > +test_lt (__bf16 a, __bf16 b) > +{ > + int result = (a < b); > + int expected = (bf16_to_float (a) < bf16_to_float (b)); > + if (result != expected) > + abort (); > +} > + > +static void > +test_le (__bf16 a, __bf16 b) > +{ > + int result = (a <= b); > + int expected = (bf16_to_float (a) <= bf16_to_float (b)); > + if (result != expected) > + abort (); > +} > + > +static void > +test_gt (__bf16 a, __bf16 b) > +{ > + int result = (a > b); > + int expected = (bf16_to_float (a) > bf16_to_float (b)); > + if (result 
!= expected) > + abort (); > +} > + > +static void > +test_ge (__bf16 a, __bf16 b) > +{ > + int result = (a >= b); > + int expected = (bf16_to_float (a) >= bf16_to_float (b)); > + if (result != expected) > + abort (); > +} > + > +int > +main (void) > +{ > + if (!__builtin_cpu_supports ("avx10.2")) > + return 0; > + > + float test_values[] = { > + -10.0f, -1.0f, -0.5f, 0.0f, 0.5f, 1.0f, 10.0f, 100.0f, -100.0f > + }; > + > + size_t num_values = sizeof (test_values) / sizeof (test_values[0]); > + > + for (size_t i = 0; i < num_values; i++) > + for (size_t j = 0; j < num_values; j++) > + { > + __bf16 a = float_to_bf16 (test_values[i]); > + __bf16 b = float_to_bf16 (test_values[j]); > + > + test_eq (a, b); > + test_ne (a, b); > + test_lt (a, b); > + test_le (a, b); > + test_gt (a, b); > + test_ge (a, b); > + } > + > + return 0; > +} > -- > 2.31.1 >
-- BR, Hongtao