Issue 150321
Summary Vector reduction on zeroed vectors not eliminated
Labels new issue
Assignees
Reporter Validark
Zig version: trunk, see this [Godbolt link](https://zig.godbo.lt/#g:!((g:!((g:!((h:codeEditor,i:(filename:'1',fontScale:14,fontUsePx:'0',j:1,lang:zig,selection:(endColumn:2,endLineNumber:34,positionColumn:2,positionLineNumber:34,selectionStartColumn:1,selectionStartLineNumber:1,startColumn:1,startLineNumber:1),source:'const+std+%3D+@import(%22std%22)%3B%0A%0Aexport+fn+count_external(ptr:+%5B*%5Dconst+u8,+len:+usize)+usize+%7B%0A++++return+count(ptr%5B0..len%5D,+0)%3B%0A%7D%0A%0A///+In+a+64+byte+vector,+we+consider+8+%22lanes%22+of+8+bytes+each.%0A///+This+function+returns+a+vector+where+each+lane+is+the+sum+of+the+8+bytes+in+the+corresponding+input+lane.%0Afn+sum_u64_lanes(a:+@Vector(64,+u8))+@Vector(8,+u64)+%7B%0A++++return+@as(@Vector(8,+u64),+@bitCast(a))+*+@as(@Vector(8,+u64),+@splat(0x0101010101010101))+%3E%3E+@splat(56)%3B%0A%7D%0A%0Afn+count(x:+%5B%5Dconst+u8,+char:+u8)+usize+%7B%0A++++var+cur+%3D+x%3B%0A++++var+accumulators:+@Vector(8,+u64)+%3D+@splat(0)%3B%0A%0A++++const+UNROLL_FACTOR+%3D+1%3B%0A++++//std.debug.assert(x.len+%25+64+%3D%3D+0+and+x.len+!!%3D+0)%3B%0A++++std.debug.assert(x.len+%25+64+*+UNROLL_FACTOR+%3D%3D+0)%3B%0A%0A++++while+(cur.len+%3E%3D+64+*+UNROLL_FACTOR)+%7B%0A++++++++inline+for+(0..UNROLL_FACTOR)+%7C_%7C+%7B%0A++++++++++++accumulators+%2B%3D+sum_u64_lanes(@select(%0A++++++++++++++++u8,%0A++++++++++++++++cur%5B0..64%5D.*+%3D%3D+@as(@Vector(64,+u8),+@splat(char)),%0A++++++++++++++++%5B1%5Du8%7B1%7D+**+64,%0A++++++++++++++++%5B1%5Du8%7B0%7D+**+64,%0A++++++++++++))%3B%0A++++++++++++cur+%3D+cur%5B64..%5D%3B%0A++++++++%7D%0A++++%7D%0A%0A++++return+@reduce(.Add,+accumulators)%3B%0A%7D'),l:'5',n:'0',o:'Zig+source+%231',t:'0')),k:52.020916986304115,l:'4',n:'0',o:'',s:0,t:'0'),(g:!((h:compiler,i:(compiler:ztrunk,filters:(b:'0',binary:'1',binaryObject:'1',commentOnly:'0',debugCalls:'1',demangle:'0',directives:'0',execute:'1',intel:'0',libraryCode:'0',trim:'1',verboseDemangling:'0'),flagsViewOpen:'1',fontScale:14,fontUsePx:'0',j:1,lang:zig,libs:!(),options:'-O+ReleaseFast+-target+x86_64-linux+-mcpu%3Dznver5',overrides:!(),selection:(endColumn:1,endLineNumber:1,positionColumn:1,positionLineNumber:1,selectionStartColumn:1,selectionStartLineNumber:1,startColumn:1,startLineNumber:1),source:1),l:'5',n:'0',o:'+zig+trunk+(Editor+%231)',t:'0')),k:47.9790830136959,l:'4',m:100,n:'0',o:'',s:0,t:'0')),l:'2',n:'0',o:'',t:'0')),version:4)

```zig
const std = @import("std");

/// Exported entry point (`export fn`): counts the bytes equal to 0 in
/// `ptr[0..len]` by forwarding the rebuilt slice to `count` with `char == 0`.
export fn count_external(ptr: [*]const u8, len: usize) usize {
    return count(ptr[0..len], 0);
}

/// In a 64 byte vector, we consider 8 "lanes" of 8 bytes each.
/// This function returns a vector where each lane is the sum of the 8 bytes in the corresponding input lane.
///
/// Multiplying a lane by 0x0101010101010101 accumulates every byte into the
/// lane's most-significant byte, which `>> 56` then extracts. That product
/// intentionally overflows u64 whenever a byte other than the lane's lowest is
/// nonzero (only the low 64 bits are wanted), so it must use the wrapping
/// operator `*%`: plain `*` makes overflow illegal behavior (a panic in safe
/// builds, a `mul nuw` that yields poison in ReleaseFast).
/// The result is exact only while a lane's byte sum is <= 255; larger sums
/// wrap modulo 256.
fn sum_u64_lanes(a: @Vector(64, u8)) @Vector(8, u64) {
    return @as(@Vector(8, u64), @bitCast(a)) *% @as(@Vector(8, u64), @splat(0x0101010101010101)) >> @splat(56);
}

/// Counts how many bytes of `x` equal `char`, 64 bytes per step.
///
/// Each 64-byte chunk is compared against `char` and turned into a vector of
/// 0/1 bytes. Those bytes are summed per 8-byte lane with `sum_u64_lanes`, and
/// the eight lane totals are reduced once at the end.
///
/// `x.len` must be a multiple of `64 * UNROLL_FACTOR` (checked with
/// `std.debug.assert`). The loop only consumes whole strides, so a shorter
/// tail would not be counted.
fn count(x: []const u8, char: u8) usize {
    var cur = x;
    var accumulators: @Vector(8, u64) = @splat(0);

    const UNROLL_FACTOR = 1;
    //std.debug.assert(x.len % 64 == 0 and x.len != 0);
    // Parenthesized: `%` and `*` share precedence and associate left, so
    // without parentheses this would test `(x.len % 64) * UNROLL_FACTOR`
    // instead of divisibility by the loop's stride.
    std.debug.assert(x.len % (64 * UNROLL_FACTOR) == 0);

    while (cur.len >= 64 * UNROLL_FACTOR) {
        inline for (0..UNROLL_FACTOR) |_| {
            // 1 in every byte position equal to `char`, 0 elsewhere.
            accumulators += sum_u64_lanes(@select(
                u8,
                cur[0..64].* == @as(@Vector(64, u8), @splat(char)),
                [1]u8{1} ** 64,
                [1]u8{0} ** 64,
            ));
            cur = cur[64..];
        }
    }

    // Horizontal sum of the eight per-lane counts. If the loop never ran
    // (e.g. `x` is empty), this reduces the all-zero initial vector.
    return @reduce(.Add, accumulators);
}
```

On Zen 5 (`-O ReleaseFast -target x86_64-linux -mcpu=znver5`), we get the following output:

```asm
# 64 bytes of 0x01: the `[1]u8{1} ** 64` operand of @select.
.LCPI0_0:
        .zero 64,1
# 72340172838076673 == 0x0101010101010101, the sum_u64_lanes multiplier.
# Broadcast into zmm1 below, it is byte-for-byte the same data as .LCPI0_0.
.LCPI0_1:
        .quad   72340172838076673
# count_external: rdi holds ptr, rsi holds len; char is the constant 0.
count_external:
# len == 0: branch to .LBB0_1, which zeroes xmm0 and still runs the reduction.
 test    rsi, rsi
        je      .LBB0_1
        push    rbp
        mov rbp, rsp
        vpbroadcastq    zmm1, qword ptr [rip + .LCPI0_1]
# accumulators = 0; rax = byte offset into ptr.
 vpxor   xmm0, xmm0, xmm0
        xor     eax, eax
# Loop: one 64-byte chunk per iteration.
.LBB0_4:
 vmovdqu64       zmm2, zmmword ptr [rdi + rax]
        add     rax, 64
# k1 = mask of bytes equal to 0 (bit set where byte AND byte == 0).
 vptestnmb       k1, zmm2, zmm2
# @select: 0x01 where k1 is set, 0 elsewhere (zero-masking).
        vmovdqu8        zmm2 {k1} {z}, zmmword ptr [rip + .LCPI0_0]
# sum_u64_lanes: multiply each qword by 0x0101010101010101, shift right by 56.
        vpmullq zmm2, zmm2, zmm1
 vpsrlq  zmm2, zmm2, 56
# accumulators += lane sums.
        vpaddq  zmm0, zmm2, zmm0
        cmp rsi, rax
        jne     .LBB0_4
        pop     rbp
        jmp .LBB0_2
# Empty-input path: the accumulators are all zeros here.
.LBB0_1:
        vpxor   xmm0, xmm0, xmm0
# @reduce(.Add, accumulators): horizontal add of the eight qwords into rax.
# When reached from .LBB0_1 the input is known to be zero, so this work is redundant.
.LBB0_2:
 vextracti64x4   ymm1, zmm0, 1
        vpaddq  zmm0, zmm0, zmm1
 vextracti128    xmm1, ymm0, 1
        vpaddq  xmm0, xmm0, xmm1
 vpshufd xmm1, xmm0, 238
        vpaddq  xmm0, xmm0, xmm1
        vmovq rax, xmm0
        vzeroupper
        ret
```

The same problem occurs on ARM, and presumably on other architectures as well.

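On the empty-input path, the horizontal reduction in `.LBB0_2` is unnecessary: coming from `.LBB0_1`, `zmm0` is known to be all zeros, so the result is simply 0. I also wish the two constants were recognized as identical: `.LCPI0_0` is 64 bytes of `0x01`, which is exactly what broadcasting the `.LCPI0_1` quadword (`0x0101010101010101`) into `zmm1` produces. And, to be even more nitpicky, the `vpxor   xmm0, xmm0, xmm0` that executes on both the taken and the not-taken path of the initial `test` could have been hoisted above the `test`.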

It should instead be:

```diff
-.LCPI0_0:
-       .zero   64,1
.LCPI0_1:
        .quad 72340172838076673
count_external:
+       xor     eax, eax
        test rsi, rsi
        je      .LBB0_1
        push    rbp
        mov     rbp, rsp
        vpbroadcastq    zmm1, qword ptr [rip + .LCPI0_1]
        vpxor xmm0, xmm0, xmm0
-       xor     eax, eax
.LBB0_4:
        vmovdqu64 zmm2, zmmword ptr [rdi + rax]
        add     rax, 64
        vptestnmb k1, zmm2, zmm2
-       vmovdqu8        zmm2 {k1} {z}, zmmword ptr [rip + .LCPI0_0]
+       vmovdqu8        zmm2 {k1} {z}, zmm1
        vpmullq zmm2, zmm2, zmm1
        vpsrlq  zmm2, zmm2, 56
        vpaddq  zmm0, zmm2, zmm0
        cmp     rsi, rax
        jne     .LBB0_4
        pop rbp
-       jmp     .LBB0_2
-.LBB0_1:
-       vpxor   xmm0, xmm0, xmm0
-.LBB0_2:
        vextracti64x4   ymm1, zmm0, 1
        vpaddq  zmm0, zmm0, zmm1
        vextracti128    xmm1, ymm0, 1
        vpaddq  xmm0, xmm0, xmm1
        vpshufd xmm1, xmm0, 238
        vpaddq  xmm0, xmm0, xmm1
        vmovq   rax, xmm0
        vzeroupper
+.LBB0_1:
 ret
```

LLVM IR emitted by Zig (before optimizations):

```llvm
; ModuleID = 'main'
source_filename = "main"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux4.19.0-musl"

; Module-level constants (their names suggest Zig's `builtin` and `start`
; modules); none of them is referenced by the functions in this module.
@builtin.zig_backend = internal unnamed_addr constant i64 2, align 8
@start.simplified_logic = internal unnamed_addr constant i1 false, align 1
@builtin.output_mode = internal unnamed_addr constant i2 -2, align 1

; Function Attrs: nounwind uwtable
; Exported wrapper: rebuilds the { ptr, len } slice from its two arguments and
; calls @main.count with char = 0. The alloca/store/load round-trip of %0 is
; unoptimized frontend output.
define dso_local i64 @count_external(ptr align 1 readonly nonnull %0, i64 %1) #0 {
2:
  %3 = alloca [8 x i8], align 8
  store ptr %0, ptr %3, align 8
  %4 = load ptr, ptr %3, align 8
  %5 = getelementptr inbounds i8, ptr %4, i64 0
  %6 = insertvalue { ptr, i64 } poison, ptr %5, 0
  %7 = insertvalue { ptr, i64 } %6, i64 %1, 1
  %8 = extractvalue { ptr, i64 } %7, 0
  %9 = extractvalue { ptr, i64 } %7, 1
  %10 = call fastcc i64 @main.count(ptr align 1 readonly nonnull %8, i64 %9, i8 0)
  ret i64 %10
}

; Function Attrs: nounwind uwtable
define internal fastcc i64 @main.count(ptr align 1 readonly nonnull %0, i64 %1, i8 %2) unnamed_addr #0 {
3:
  %4 = alloca [64 x i8], align 64
  %5 = alloca [16 x i8], align 8
 %6 = insertvalue { ptr, i64 } poison, ptr %0, 0
  %7 = insertvalue { ptr, i64 } %6, i64 %1, 1
  store { ptr, i64 } %7, ptr %5, align 8
  store <8 x i64> zeroinitializer, ptr %4, align 64
  %8 = extractvalue { ptr, i64 } %7, 1
  %9 = urem i64 %8, 64
  %10 = icmp eq i64 %9, 0
  call fastcc void @debug.assert(i1 %10)
  br label %14

11:
  %12 = load <8 x i64>, ptr %4, align 64
  %13 = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> %12)
 ret i64 %13

14:
  %15 = load { ptr, i64 }, ptr %5, align 8
  %16 = extractvalue { ptr, i64 } %15, 1
  %17 = icmp uge i64 %16, 64
  br i1 %17, label %19, label %38

18:
  br label %14

19:
  %20 = load <8 x i64>, ptr %4, align 64
  %21 = load { ptr, i64 }, ptr %5, align 8
  %22 = extractvalue { ptr, i64 } %21, 0
  %23 = getelementptr inbounds i8, ptr %22, i64 0
  %24 = insertelement <1 x i8> poison, i8 %2, i32 0
  %25 = shufflevector <1 x i8> %24, <1 x i8> poison, <64 x i32> zeroinitializer
 %26 = load <64 x i8>, ptr %23, align 1
  %27 = icmp eq <64 x i8> %26, %25
 %28 = select <64 x i1> %27, <64 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>, <64 x i8> zeroinitializer
  %29 = call fastcc <8 x i64> @main.sum_u64_lanes(<64 x i8> %28)
  %30 = add nuw <8 x i64> %20, %29
  store <8 x i64> %30, ptr %4, align 64
  %31 = load { ptr, i64 }, ptr %5, align 8
  %32 = extractvalue { ptr, i64 } %31, 0
  %33 = getelementptr inbounds i8, ptr %32, i64 64
  %34 = extractvalue { ptr, i64 } %31, 1
  %35 = sub nuw i64 %34, 64
  %36 = insertvalue { ptr, i64 } poison, ptr %33, 0
  %37 = insertvalue { ptr, i64 } %36, i64 %35, 1
  store { ptr, i64 } %37, ptr %5, align 8
  br label %18

38:
  br label %11
}

; Function Attrs: nounwind uwtable
; Lowering of std.debug.assert: a false condition reaches `unreachable`;
; otherwise the function returns.
define internal fastcc void @debug.assert(i1 %0) unnamed_addr #0 {
1:
  %2 = xor i1 %0, true
  br i1 %2, label %4, label %5

3:
  ret void

4:
  unreachable

5:
  br label %3
}

; Function Attrs: nounwind uwtable
; Lowering of sum_u64_lanes: reinterpret the 64 bytes as eight i64 lanes,
; multiply each lane by 72340172838076673 (0x0101010101010101), and shift right
; by 56 so each lane keeps its top byte, the sum of that lane's bytes (exact
; while the sum is <= 255).
define internal fastcc <8 x i64> @main.sum_u64_lanes(<64 x i8> %0) unnamed_addr #0 {
1:
  %2 = bitcast <64 x i8> %0 to <8 x i64>
  ; NOTE(review): this multiply overflows i64 whenever a lane has a nonzero
  ; byte above its least-significant byte, so the `nuw` flag makes %3 poison
  ; for such inputs; the Zig source presumably wants a wrapping multiply (`*%`).
  %3 = mul nuw <8 x i64> %2, <i64 72340172838076673, i64 72340172838076673, i64 72340172838076673, i64 72340172838076673, i64 72340172838076673, i64 72340172838076673, i64 72340172838076673, i64 72340172838076673>
  ; Shift amount: `i6 -8` is 56 when read as an unsigned 6-bit value.
  %4 = zext <8 x i6> <i6 -8, i6 -8, i6 -8, i6 -8, i6 -8, i6 -8, i6 -8, i6 -8> to <8 x i64>
  %5 = lshr <8 x i64> %3, %4
  ret <8 x i64> %5
}

; Function Attrs: nounwind speculatable willreturn nofree nosync nocallback memory(none)
; Horizontal integer add of all eight lanes; the lowering of @reduce(.Add, ...).
declare i64 @llvm.vector.reduce.add.v8i64(<8 x i64> %0) #1

; #0: attributes of the functions defined above (znver5 target, "frame-pointer"="all").
; The target-features value is one comma-separated string and must stay on a
; single line; a wrapped line break would become part of a feature name.
attributes #0 = { nounwind uwtable "frame-pointer"="all" "target-cpu"="znver5" "target-features"="+64bit,+adx,+aes,+allow-light-256-bit,+avx,+avx2,+avx512bf16,+avx512bitalg,+avx512bw,+avx512cd,+avx512dq,+avx512f,+avx512ifma,+avx512vbmi,+avx512vbmi2,+avx512vl,+avx512vnni,+avx512vp2intersect,+avx512vpopcntdq,+avxvnni,+bmi,+bmi2,+branchfusion,+clflushopt,+clwb,+clzero,+cmov,+crc32,+cx16,+cx8,+evex512,+f16c,+fast-15bytenop,+fast-bextr,+fast-dpwssd,+fast-imm16,+fast-lzcnt,+fast-movbe,+fast-scalar-fsqrt,+fast-scalar-shift-masks,+fast-variable-perlane-shuffle,+fast-vector-fsqrt,+fma,+fsgsbase,+fsrm,+fxsr,+gfni,+idivq-to-divl,+invpcid,+lzcnt,+macrofusion,+mmx,+movbe,+movdir64b,+movdiri,+mwaitx,+nopl,+pclmul,+pku,+popcnt,+prefetchi,+prfchw,+rdpid,+rdpru,+rdrnd,+rdseed,+sahf,+sbb-dep-breaking,+sha,+shstk,+slow-shld,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+sse4a,+ssse3,+vaes,+vpclmulqdq,+vzeroupper,+wbnoinvd,+x87,+xsave,+xsavec,+xsaveopt,+xsaves,-16bit-mode,-32bit-mode,-amx-bf16,-amx-complex,-amx-fp16,-amx-int8,-amx-tile,-avx10.1-256,-avx10.1-512,-avx512fp16,-avxifma,-avxneconvert,-avxvnniint16,-avxvnniint8,-branch-hint,-ccmp,-cf,-cldemote,-cmpccxadd,-egpr,-enqcmd,-ermsb,-false-deps-getmant,-false-deps-lzcnt-tzcnt,-false-deps-mulc,-false-deps-mullq,-false-deps-perm,-false-deps-popcnt,-false-deps-range,-fast-11bytenop,-fast-7bytenop,-fast-gather,-fast-hops,-fast-shld-rotate,-fast-variable-crosslane-shuffle,-fast-vector-shift-masks,-faster-shift-than-shuffle,-fma4,-harden-sls-ijmp,-harden-sls-ret,-hreset,-idivl-to-divb,-inline-asm-use-gpr32,-kl,-lea-sp,-lea-uses-ag,-lvi-cfi,-lvi-load-hardening,-lwp,-ndd,-nf,-no-bypass-delay,-no-bypass-delay-blend,-no-bypass-delay-mov,-no-bypass-delay-shuffle,-pad-short-functions,-pconfig,-ppx,-prefer-128-bit,-prefer-256-bit,-prefer-mask-registers,-prefer-movmsk-over-vtest,-prefer-no-gather,-prefer-no-scatter,-ptwrite,-push2pop2,-raoint,-retpoline,-retpoline-external-thunk,-retpoline-indirect-branches,-retpoline-indirect-calls,-rtm,-serialize,-seses,-sgx,-sha512,-slow-3ops-lea,-slow-incdec,-slow-lea,-slow-pmaddwd,-slow-pmulld,-slow-two-mem-ops,-slow-unaligned-mem-16,-slow-unaligned-mem-32,-sm3,-sm4,-soft-float,-sse-unaligned-mem,-tagged-globals,-tbm,-tsxldtrk,-tuning-fast-imm-vector-shift,-uintr,-use-glm-div-sqrt-costs,-use-slm-arith-costs,-usermsr,-waitpkg,-widekl,-xop,-zu" }
; #1: attributes of the llvm.vector.reduce.add intrinsic declaration.
attributes #1 = { nounwind speculatable willreturn nofree nosync nocallback memory(none) }

; No module flags are emitted.
!llvm.module.flags = !{}
```
_______________________________________________
llvm-bugs mailing list
llvm-bugs@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-bugs

Reply via email to