https://gcc.gnu.org/bugzilla/show_bug.cgi?id=118933

            Bug ID: 118933
           Summary: Missed optimization: __builtin_unreachable() does not
                    work as expected on std::vector data compared to raw
                    pointer and length
           Product: gcc
           Version: 14.2.0
            Status: UNCONFIRMED
          Severity: normal
          Priority: P3
         Component: tree-optimization
          Assignee: unassigned at gcc dot gnu.org
          Reporter: nicula.iccc at gmail dot com
  Target Milestone: ---

I have the following code which calculates the sum of an array of uint8_t
values into a uint8_t accumulator:

    #include <cstdint>
    #include <cstdlib>
    #include <numeric>

    uint8_t get_sum(const uint8_t *data, size_t len)
    {
        constexpr size_t U8_VALUES_PER_YMMWORD = 32;
        if ((len % U8_VALUES_PER_YMMWORD) != 0 || len == 0)
            __builtin_unreachable();

        return std::accumulate(data, data + len, uint8_t(0));
    }

The key point here is that I'm telling the compiler to assume that the
entire array can be traversed in YMM-sized chunks and that the array is
non-empty. The compiler uses this information to generate a really
compact, vectorized loop (.L2 ... jne .L2), which traverses the entire
array YMMWORD-by-YMMWORD, with no superfluous branches:

    get_sum(unsigned char const*, unsigned long):
            and     rsi, -32
            vpxor   xmm0, xmm0, xmm0
            add     rsi, rdi
    .L2:
            vpaddb  ymm0, ymm0, YMMWORD PTR [rdi]
            add     rdi, 32
            cmp     rdi, rsi
            jne     .L2
            vextracti32x4   xmm1, ymm0, 0x1
            vpaddb  xmm0, xmm1, xmm0
            vpsrldq xmm1, xmm0, 8
            vpaddb  xmm0, xmm0, xmm1
            vpxor   xmm1, xmm1, xmm1
            vpsadbw xmm0, xmm0, xmm1
            vpextrb eax, xmm0, 0
            vzeroupper
            ret
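
For reference, the same assumption can also be spelled with the C++23
[[assume]] attribute, which GCC accepts since GCC 13; a minimal sketch
that is intended to be equivalent to the __builtin_unreachable() form
above:

    #include <cstdint>
    #include <cstdlib>
    #include <numeric>

    uint8_t get_sum_assume(const uint8_t *data, size_t len)
    {
        constexpr size_t U8_VALUES_PER_YMMWORD = 32;
        // Assume len is a non-zero multiple of 32, i.e. the negation of
        // the condition guarded by __builtin_unreachable() above.
        [[assume((len % U8_VALUES_PER_YMMWORD) == 0 && len != 0)]];

        return std::accumulate(data, data + len, uint8_t(0));
    }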

The problem is that when I try to do the same thing with a std::vector,
__builtin_unreachable() has no effect. Code:

    #include <cstdint>
    #include <cstdlib>
    #include <numeric>
    #include <vector>

    __always_inline
    uint8_t get_sum_helper(const uint8_t *data, size_t len)
    {
        constexpr size_t U8_VALUES_PER_YMMWORD = 32;
        if ((len % U8_VALUES_PER_YMMWORD) != 0 || len == 0)
            __builtin_unreachable();

        return std::accumulate(data, data + len, uint8_t(0));
    }

    uint8_t get_sum(const std::vector<uint8_t> &vec)
    {
        const uint8_t *data = vec.begin().base(); // i.e. vec.data()
        const size_t   len  = vec.size();
        return get_sum_helper(data, len);
    }

Assembly:

    get_sum(std::vector<unsigned char, std::allocator<unsigned char>> const&):
            mov     r8, QWORD PTR [rdi]
            mov     rsi, QWORD PTR [rdi+8]
            cmp     rsi, r8
            je      .L9
            mov     rdi, rsi
            mov     rax, r8
            sub     rdi, r8
            lea     rdx, [rdi-1]
            cmp     rdx, 30
            jbe     .L10
            mov     r9, rdi
            vpxor   xmm0, xmm0, xmm0
            and     r9, -32
            lea     rcx, [r8+r9]
    .L4:
            vpaddb  ymm0, ymm0, YMMWORD PTR [rax]
            add     rax, 32
            cmp     rcx, rax
            jne     .L4
            vextracti32x4   xmm1, ymm0, 0x1
            vpxor   xmm2, xmm2, xmm2
            mov     rax, rcx
            vpaddb  xmm0, xmm1, xmm0
            vpsrldq xmm1, xmm0, 8
            vpaddb  xmm1, xmm0, xmm1
            vpsadbw xmm1, xmm1, xmm2
            vpextrb edx, xmm1, 0
            cmp     rdi, r9
            je      .L17
            vzeroupper
    .L3:
            sub     rdi, r9
            lea     rcx, [rdi-1]
            cmp     rcx, 14
            jbe     .L7
            vpaddb  xmm0, xmm0, XMMWORD PTR [r8+r9]
            mov     rcx, rdi
            and     rcx, -16
            vpsrldq xmm1, xmm0, 8
            add     rax, rcx
            and     edi, 15
            vpaddb  xmm0, xmm0, xmm1
            vpxor   xmm1, xmm1, xmm1
            vpsadbw xmm0, xmm0, xmm1
            vpextrb edx, xmm0, 0
            je      .L1
    .L7:
            lea     rcx, [rax+1]
            add     dl, BYTE PTR [rax]
            cmp     rsi, rcx
            je      .L1
            lea     rcx, [rax+2]
            add     dl, BYTE PTR [rax+1]
            cmp     rsi, rcx
            je      .L1
            lea     rcx, [rax+3]
            add     dl, BYTE PTR [rax+2]
            cmp     rsi, rcx
            je      .L1
            lea     rcx, [rax+4]
            add     dl, BYTE PTR [rax+3]
            cmp     rsi, rcx
            je      .L1
            lea     rcx, [rax+5]
            add     dl, BYTE PTR [rax+4]
            cmp     rsi, rcx
            je      .L1
            lea     rcx, [rax+6]
            add     dl, BYTE PTR [rax+5]
            cmp     rsi, rcx
            je      .L1
            lea     rcx, [rax+7]
            add     dl, BYTE PTR [rax+6]
            cmp     rsi, rcx
            je      .L1
            lea     rcx, [rax+8]
            add     dl, BYTE PTR [rax+7]
            cmp     rsi, rcx
            je      .L1
            lea     rcx, [rax+9]
            add     dl, BYTE PTR [rax+8]
            cmp     rsi, rcx
            je      .L1
            lea     rdi, [rax+10]
            add     dl, BYTE PTR [rax+9]
            cmp     rsi, rdi
            je      .L1
            lea     rcx, [rax+11]
            add     dl, BYTE PTR [rax+10]
            cmp     rsi, rcx
            je      .L1
            lea     rcx, [rax+12]
            add     dl, BYTE PTR [rax+11]
            cmp     rsi, rcx
            je      .L1
            lea     rcx, [rax+13]
            add     dl, BYTE PTR [rax+12]
            cmp     rsi, rcx
            je      .L1
            lea     rcx, [rax+14]
            add     dl, BYTE PTR [rax+13]
            cmp     rsi, rcx
            je      .L1
            add     dl, BYTE PTR [rax+14]
    .L1:
            mov     eax, edx
            ret
    .L9:
            xor     edx, edx
            mov     eax, edx
            ret
    .L10:
            vpxor   xmm0, xmm0, xmm0
            xor     r9d, r9d
            xor     edx, edx
            jmp     .L3
    .L17:
            vzeroupper
            jmp     .L1

As you can see, it generates redundant branches for a possible remaining
XMM-sized chunk and for leftover single-byte elements, even though the
assumption rules both out. I would expect this std::vector version to
generate assembly *identical* to the first version, but here the
__builtin_unreachable() call has no effect, as noted above.

Can this std::vector case be optimized just like the raw pointer + length one?
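
In case it is useful as a data point, here is a workaround sketch. It
rests on an untested assumption on my part: that the information is lost
because, after inlining, data and len are values loaded from the
vector's internal pointers rather than plain function parameters.
Keeping the helper out of line recreates the raw pointer + length
situation, at the cost of a call:

    #include <cstdint>
    #include <cstdlib>
    #include <numeric>
    #include <vector>

    __attribute__((noinline))
    uint8_t get_sum_noinline(const uint8_t *data, size_t len)
    {
        constexpr size_t U8_VALUES_PER_YMMWORD = 32;
        if ((len % U8_VALUES_PER_YMMWORD) != 0 || len == 0)
            __builtin_unreachable();

        return std::accumulate(data, data + len, uint8_t(0));
    }

    uint8_t get_sum(const std::vector<uint8_t> &vec)
    {
        // The helper body is now compiled exactly like the first
        // version, with data and len as genuine parameters.
        return get_sum_noinline(vec.data(), vec.size());
    }

Obviously this sidesteps the question above rather than answering it.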

Note: I'm compiling with -O3 -march=skylake-avx512, gcc 14.2. Same problem on
trunk.

Godbolt links:
* raw pointer + length version: https://godbolt.org/z/Tq9xaofhj
* std::vector version: https://godbolt.org/z/PGq49vPPP
