Issue |
150321
|
Summary |
Vector reduction on zeroed vectors not eliminated
|
Labels |
new issue
|
Assignees |
|
Reporter |
Validark
|
Zig version: [Godbolt link](https://zig.godbo.lt/#g:!((g:!((g:!((h:codeEditor,i:(filename:'1',fontScale:14,fontUsePx:'0',j:1,lang:zig,selection:(endColumn:2,endLineNumber:34,positionColumn:2,positionLineNumber:34,selectionStartColumn:1,selectionStartLineNumber:1,startColumn:1,startLineNumber:1),source:'const+std+%3D+@import(%22std%22)%3B%0A%0Aexport+fn+count_external(ptr:+%5B*%5Dconst+u8,+len:+usize)+usize+%7B%0A++++return+count(ptr%5B0..len%5D,+0)%3B%0A%7D%0A%0A///+In+a+64+byte+vector,+we+consider+8+%22lanes%22+of+8+bytes+each.%0A///+This+function+returns+a+vector+where+each+lane+is+the+sum+of+the+8+bytes+in+the+corresponding+input+lane.%0Afn+sum_u64_lanes(a:+@Vector(64,+u8))+@Vector(8,+u64)+%7B%0A++++return+@as(@Vector(8,+u64),+@bitCast(a))+*+@as(@Vector(8,+u64),+@splat(0x0101010101010101))+%3E%3E+@splat(56)%3B%0A%7D%0A%0Afn+count(x:+%5B%5Dconst+u8,+char:+u8)+usize+%7B%0A++++var+cur+%3D+x%3B%0A++++var+accumulators:+@Vector(8,+u64)+%3D+@splat(0)%3B%0A%0A++++const+UNROLL_FACTOR+%3D+1%3B%0A++++//std.debug.assert(x.len+%25+64+%3D%3D+0+and+x.len+!!%3D+0)%3B%0A++++std.debug.assert(x.len+%25+64+*+UNROLL_FACTOR+%3D%3D+0)%3B%0A%0A++++while+(cur.len+%3E%3D+64+*+UNROLL_FACTOR)+%7B%0A++++++++inline+for+(0..UNROLL_FACTOR)+%7C_%7C+%7B%0A++++++++++++accumulators+%2B%3D+sum_u64_lanes(@select(%0A++++++++++++++++u8,%0A++++++++++++++++cur%5B0..64%5D.*+%3D%3D+@as(@Vector(64,+u8),+@splat(char)),%0A++++++++++++++++%5B1%5Du8%7B1%7D+**+64,%0A++++++++++++++++%5B1%5Du8%7B0%7D+**+64,%0A++++++++++++))%3B%0A++++++++++++cur+%3D+cur%5B64..%5D%3B%0A++++++++%7D%0A++++%7D%0A%0A++++return+@reduce(.Add,+accumulators)%3B%0A%7D'),l:'5',n:'0',o:'Zig+source+%231',t:'0')),k:52.020916986304115,l:'4',n:'0',o:'',s:0,t:'0'),(g:!((h:compiler,i:(compiler:ztrunk,filters:(b:'0',binary:'1',binaryObject:'1',commentOnly:'0',debugCalls:'1',demangle:'0',directives:'0',execute:'1',intel:'0',libraryCode:'0',trim:'1',verboseDemangling:'0'),flagsViewOpen:'1',fontScale:14,fontUsePx:'0',j:1,lang:zig,libs:!(),options:'-O+Re
leaseFast+-target+x86_64-linux+-mcpu%3Dznver5',overrides:!(),selection:(endColumn:1,endLineNumber:1,positionColumn:1,positionLineNumber:1,selectionStartColumn:1,selectionStartLineNumber:1,startColumn:1,startLineNumber:1),source:1),l:'5',n:'0',o:'+zig+trunk+(Editor+%231)',t:'0')),k:47.9790830136959,l:'4',m:100,n:'0',o:'',s:0,t:'0')),l:'2',n:'0',o:'',t:'0')),version:4)
```zig
const std = @import("std");
/// Exported entry point: builds a slice from the raw pointer/length pair and
/// forwards to `count` with `char` fixed to 0, i.e. it counts the zero bytes.
export fn count_external(ptr: [*]const u8, len: usize) usize {
return count(ptr[0..len], 0);
}
/// In a 64 byte vector, we consider 8 "lanes" of 8 bytes each.
/// This function returns a vector where each lane is the sum of the 8 bytes in the corresponding input lane.
///
/// The sum is obtained by multiplying each u64 lane by 0x0101010101010101 and keeping only the
/// top byte (`>> 56`), so the result is exact only while the byte sums fit in 8 bits.
/// The only caller in this snippet (`count`) passes bytes that are 0 or 1.
fn sum_u64_lanes(a: @Vector(64, u8)) @Vector(8, u64) {
// `*` binds tighter than `>>`, so this is `(lanes * 0x0101...01) >> 56`.
// NOTE(review): plain `*` is overflow-checked in Zig (it lowers to `mul nuw`), but this
// product exceeds 64 bits whenever any byte above the lowest one is nonzero; the trick
// relies on discarding those high bits. Wrapping multiplication (`*%`) looks like the
// intended operator here — confirm.
return @as(@Vector(8, u64), @bitCast(a)) * @as(@Vector(8, u64), @splat(0x0101010101010101)) >> @splat(56);
}
/// Counts the bytes in `x` that are equal to `char`, consuming `x` in 64-byte chunks.
/// The length of `x` is asserted to be a multiple of 64 (see the note on the assert below).
fn count(x: []const u8, char: u8) usize {
// Unprocessed remainder of `x`; advanced by 64 bytes after every chunk.
var cur = x;
// Eight running u64 totals, one per 8-byte lane of a 64-byte chunk.
var accumulators: @Vector(8, u64) = @splat(0);
// Number of 64-byte chunks handled per iteration of the `while` loop.
const UNROLL_FACTOR = 1;
//std.debug.assert(x.len % 64 == 0 and x.len != 0);
// NOTE(review): `%` and `*` have the same precedence and associate left-to-right, so this
// parses as `(x.len % 64) * UNROLL_FACTOR == 0`. That equals
// `x.len % (64 * UNROLL_FACTOR) == 0` only while UNROLL_FACTOR is 1 — confirm the intent
// before raising the unroll factor.
std.debug.assert(x.len % 64 * UNROLL_FACTOR == 0);
while (cur.len >= 64 * UNROLL_FACTOR) {
inline for (0..UNROLL_FACTOR) |_| {
// Build a 0/1 byte vector (1 where the input byte equals `char`) from the next
// 64 bytes, then add its per-lane sums to the accumulators.
accumulators += sum_u64_lanes(@select(
u8,
cur[0..64].* == @as(@Vector(64, u8), @splat(char)),
[1]u8{1} ** 64,
[1]u8{0} ** 64,
));
cur = cur[64..];
}
}
// Horizontal add of the eight lane totals. If the loop never ran (empty input),
// `accumulators` is still all zeros, so this reduces a zero vector.
return @reduce(.Add, accumulators);
}
```
On Zen 5, we get the following emit:
```asm
.LCPI0_0:
.zero 64,1
.LCPI0_1:
.quad 72340172838076673
count_external:
test rsi, rsi
je .LBB0_1
push rbp
mov rbp, rsp
vpbroadcastq zmm1, qword ptr [rip + .LCPI0_1]
vpxor xmm0, xmm0, xmm0
xor eax, eax
.LBB0_4:
vmovdqu64 zmm2, zmmword ptr [rdi + rax]
add rax, 64
vptestnmb k1, zmm2, zmm2
vmovdqu8 zmm2 {k1} {z}, zmmword ptr [rip + .LCPI0_0]
vpmullq zmm2, zmm2, zmm1
vpsrlq zmm2, zmm2, 56
vpaddq zmm0, zmm2, zmm0
cmp rsi, rax
jne .LBB0_4
pop rbp
jmp .LBB0_2
.LBB0_1:
vpxor xmm0, xmm0, xmm0
.LBB0_2:
vextracti64x4 ymm1, zmm0, 1
vpaddq zmm0, zmm0, zmm1
vextracti128 xmm1, ymm0, 1
vpaddq xmm0, xmm0, xmm1
vpshufd xmm1, xmm0, 238
vpaddq xmm0, xmm0, xmm1
vmovq rax, xmm0
vzeroupper
ret
```
The same problem occurs on ARM, and I assume on other architectures as well.
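Doing the reduction math in `.LBB0_2` is unnecessary when it is reached from `.LBB0_1` (the `len == 0` path): `zmm0` has just been zeroed there, so the result is known to be 0. I also wish the two constants here would be recognized as identical (64 bytes of `0x01` is the same bit pattern as `0x0101010101010101` broadcast to every 64-bit lane). And, if we wanted to get even more nitpicky, the `vpxor xmm0, xmm0, xmm0` which happens in both the TRUE and FALSE branches of the first `test` conditional could have been hoisted above the `test` conditional.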
Should be:
```diff
-.LCPI0_0:
- .zero 64,1
.LCPI0_1:
.quad 72340172838076673
count_external:
+ xor eax, eax
test rsi, rsi
je .LBB0_1
push rbp
mov rbp, rsp
vpbroadcastq zmm1, qword ptr [rip + .LCPI0_1]
vpxor xmm0, xmm0, xmm0
- xor eax, eax
.LBB0_4:
vmovdqu64 zmm2, zmmword ptr [rdi + rax]
add rax, 64
vptestnmb k1, zmm2, zmm2
- vmovdqu8 zmm2 {k1} {z}, zmmword ptr [rip + .LCPI0_0]
+ vmovdqu8 zmm2 {k1} {z}, zmm1
vpmullq zmm2, zmm2, zmm1
vpsrlq zmm2, zmm2, 56
vpaddq zmm0, zmm2, zmm0
cmp rsi, rax
jne .LBB0_4
pop rbp
- jmp .LBB0_2
-.LBB0_1:
- vpxor xmm0, xmm0, xmm0
-.LBB0_2:
vextracti64x4 ymm1, zmm0, 1
vpaddq zmm0, zmm0, zmm1
vextracti128 xmm1, ymm0, 1
vpaddq xmm0, xmm0, xmm1
vpshufd xmm1, xmm0, 238
vpaddq xmm0, xmm0, xmm1
vmovq rax, xmm0
vzeroupper
+.LBB0_1:
ret
```
LLVM emit (before optimizations):
```llvm
; ModuleID = 'main'
source_filename = "main"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux4.19.0-musl"
@builtin.zig_backend = internal unnamed_addr constant i64 2, align 8
@start.simplified_logic = internal unnamed_addr constant i1 false, align 1
@builtin.output_mode = internal unnamed_addr constant i2 -2, align 1
; Function Attrs: nounwind uwtable
define dso_local i64 @count_external(ptr align 1 readonly nonnull %0, i64 %1) #0 {
2:
%3 = alloca [8 x i8], align 8
store ptr %0, ptr %3, align 8
%4 = load ptr, ptr %3, align 8
%5 = getelementptr inbounds i8, ptr %4, i64 0
%6 = insertvalue { ptr, i64 } poison, ptr %5, 0
%7 = insertvalue { ptr, i64 } %6, i64 %1, 1
%8 = extractvalue { ptr, i64 } %7, 0
%9 = extractvalue { ptr, i64 } %7, 1
%10 = call fastcc i64 @main.count(ptr align 1 readonly nonnull %8, i64 %9, i8 0)
ret i64 %10
}
; Function Attrs: nounwind uwtable
define internal fastcc i64 @main.count(ptr align 1 readonly nonnull %0, i64 %1, i8 %2) unnamed_addr #0 {
3:
%4 = alloca [64 x i8], align 64
%5 = alloca [16 x i8], align 8
%6 = insertvalue { ptr, i64 } poison, ptr %0, 0
%7 = insertvalue { ptr, i64 } %6, i64 %1, 1
store { ptr, i64 } %7, ptr %5, align 8
store <8 x i64> zeroinitializer, ptr %4, align 64
%8 = extractvalue { ptr, i64 } %7, 1
%9 = urem i64 %8, 64
%10 = icmp eq i64 %9, 0
call fastcc void @debug.assert(i1 %10)
br label %14
11:
%12 = load <8 x i64>, ptr %4, align 64
%13 = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> %12)
ret i64 %13
14:
%15 = load { ptr, i64 }, ptr %5, align 8
%16 = extractvalue { ptr, i64 } %15, 1
%17 = icmp uge i64 %16, 64
br i1 %17, label %19, label %38
18:
br label %14
19:
%20 = load <8 x i64>, ptr %4, align 64
%21 = load { ptr, i64 }, ptr %5, align 8
%22 = extractvalue { ptr, i64 } %21, 0
%23 = getelementptr inbounds i8, ptr %22, i64 0
%24 = insertelement <1 x i8> poison, i8 %2, i32 0
%25 = shufflevector <1 x i8> %24, <1 x i8> poison, <64 x i32> zeroinitializer
%26 = load <64 x i8>, ptr %23, align 1
%27 = icmp eq <64 x i8> %26, %25
%28 = select <64 x i1> %27, <64 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>, <64 x i8> zeroinitializer
%29 = call fastcc <8 x i64> @main.sum_u64_lanes(<64 x i8> %28)
%30 = add nuw <8 x i64> %20, %29
store <8 x i64> %30, ptr %4, align 64
%31 = load { ptr, i64 }, ptr %5, align 8
%32 = extractvalue { ptr, i64 } %31, 0
%33 = getelementptr inbounds i8, ptr %32, i64 64
%34 = extractvalue { ptr, i64 } %31, 1
%35 = sub nuw i64 %34, 64
%36 = insertvalue { ptr, i64 } poison, ptr %33, 0
%37 = insertvalue { ptr, i64 } %36, i64 %35, 1
store { ptr, i64 } %37, ptr %5, align 8
br label %18
38:
br label %11
}
; Function Attrs: nounwind uwtable
define internal fastcc void @debug.assert(i1 %0) unnamed_addr #0 {
1:
%2 = xor i1 %0, true
br i1 %2, label %4, label %5
3:
ret void
4:
unreachable
5:
br label %3
}
; Function Attrs: nounwind uwtable
define internal fastcc <8 x i64> @main.sum_u64_lanes(<64 x i8> %0) unnamed_addr #0 {
1:
%2 = bitcast <64 x i8> %0 to <8 x i64>
%3 = mul nuw <8 x i64> %2, <i64 72340172838076673, i64 72340172838076673, i64 72340172838076673, i64 72340172838076673, i64 72340172838076673, i64 72340172838076673, i64 72340172838076673, i64 72340172838076673>
%4 = zext <8 x i6> <i6 -8, i6 -8, i6 -8, i6 -8, i6 -8, i6 -8, i6 -8, i6 -8> to <8 x i64>
%5 = lshr <8 x i64> %3, %4
ret <8 x i64> %5
}
; Function Attrs: nounwind speculatable willreturn nofree nosync nocallback memory(none)
declare i64 @llvm.vector.reduce.add.v8i64(<8 x i64> %0) #1
attributes #0 = { nounwind uwtable "frame-pointer"="all" "target-cpu"="znver5" "target-features"="+64bit,+adx,+aes,+allow-light-256-bit,+avx,+avx2,+avx512bf16,+avx512bitalg,+avx512bw,+avx512cd,+avx512dq,+avx512f,+avx512ifma,+avx512vbmi,+avx512vbmi2,+avx512vl,+avx512vnni,+avx512vp2intersect,+avx512vpopcntdq,+avxvnni,+bmi,+bmi2,+branchfusion,+clflushopt,+clwb,+clzero,+cmov,+crc32,+cx16,+cx8,+evex512,+f16c,+fast-15bytenop,+fast-bextr,+fast-dpwssd,+fast-imm16,+fast-lzcnt,+fast-movbe,+fast-scalar-fsqrt,+fast-scalar-shift-masks,+fast-variable-perlane-shuffle,+fast-vector-fsqrt,+fma,+fsgsbase,+fsrm,+fxsr,+gfni,+idivq-to-divl,+invpcid,+lzcnt,+macrofusion,+mmx,+movbe,+movdir64b,+movdiri,+mwaitx,+nopl,+pclmul,+pku,+popcnt,+prefetchi,+prfchw,+rdpid,+rdpru,+rdrnd,+rdseed,+sahf,+sbb-dep-breaking,+sha,+shstk,+slow-shld,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+sse4a,+ssse3,+vaes,+vpclmulqdq,+vzeroupper,+wbnoinvd,+x87,+xsave,+xsavec,+xsaveopt,+xsaves,-16bit-mode,-32bit-mode,-amx-bf16,-amx-complex,-amx-fp16,-amx-int8,-amx-tile,-avx10.1-256,-avx10.1-512,-avx512fp16,-avxifma,-avxneconvert,-avxvnniint16,-avxvnniint8,-branch-hint,-ccmp,-cf,-cldemote,-cmpccxadd,-egpr,-enqcmd,-ermsb,-false-deps-getmant,-false-deps-lzcnt-tzcnt,-false-deps-mulc,-false-deps-mullq,-false-deps-perm,-false-deps-popcnt,-false-deps-range,-fast-11bytenop,-fast-7bytenop,-fast-gather,-fast-hops,-fast-shld-rotate,-fast-variable-crosslane-shuffle,-fast-vector-shift-masks,-faster-shift-than-shuffle,-fma4,-harden-sls-ijmp,-harden-sls-ret,-hreset,-idivl-to-divb,-inline-asm-use-gpr32,-kl,-lea-sp,-lea-uses-ag,-lvi-cfi,-lvi-load-hardening,-lwp,-ndd,-nf,-no-bypass-delay,-no-bypass-delay-blend,-no-bypass-delay-mov,-no-bypass-delay-shuffle,-pad-short-functions,-pconfig,-ppx,-prefer-128-bit,-prefer-256-bit,-prefer-mask-registers,-prefer-movmsk-over-vtest,-prefer-no-gather,-prefer-no-scatter,-ptwrite,-push2pop2,-raoint,-retpoline,-retpoline-external-thunk,-retpoline-indirect-branches,-retpoline-indirect-calls,-rtm,-serialize,-seses,-sg
x,-sha512,-slow-3ops-lea,-slow-incdec,-slow-lea,-slow-pmaddwd,-slow-pmulld,-slow-two-mem-ops,-slow-unaligned-mem-16,-slow-unaligned-mem-32,-sm3,-sm4,-soft-float,-sse-unaligned-mem,-tagged-globals,-tbm,-tsxldtrk,-tuning-fast-imm-vector-shift,-uintr,-use-glm-div-sqrt-costs,-use-slm-arith-costs,-usermsr,-waitpkg,-widekl,-xop,-zu" }
attributes #1 = { nounwind speculatable willreturn nofree nosync nocallback memory(none) }
!llvm.module.flags = !{}
```
_______________________________________________
llvm-bugs mailing list
llvm-bugs@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-bugs