kaivalnp commented on code in PR #14863: URL: https://github.com/apache/lucene/pull/14863#discussion_r2253097713
########## lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java: ########## @@ -530,7 +566,41 @@ private int dotProductBody512Int4Packed(byte[] unpacked, byte[] packed, int limi return sum; } - private int dotProductBody256Int4Packed(byte[] unpacked, byte[] packed, int limit) { + private static int dotProductBody512Int4PackedPacked( + ByteVectorLoader a, ByteVectorLoader b, int limit) { + int sum = 0; + // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += 4096) { + ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_512); + ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_512); + int innerLimit = Math.min(limit - i, 4096); + for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) { + // packed + var vb8 = b.load(ByteVector.SPECIES_256, i + j); + // packed + var va8 = a.load(ByteVector.SPECIES_256, i + j); + + // upper + ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8.and((byte) 0x0F)); + Vector<Short> prod16 = prod8.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0); Review Comment: Ah, I see what you mean.. 
On my machine, the preferred bit size was 256 so I changed the following function: ```java private static int dotProductBody256Int4PackedPacked( ByteVectorLoader a, ByteVectorLoader b, int limit) { int sum = 0; // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += 2048) { ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_256); ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_256); int innerLimit = Math.min(limit - i, 2048); for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) { // packed var vb8 = b.load(ByteVector.SPECIES_128, i + j); // packed var va8 = a.load(ByteVector.SPECIES_128, i + j); // upper ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8.and((byte) 0x0F)); Vector<Short> prod16 = prod8.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0); acc0 = acc0.add(prod16); // lower ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(va8.lanewise(LSHR, 4)); Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0); acc1 = acc1.add(prod16a); } IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts(); IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts(); IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts(); IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts(); sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); } return sum; } ``` ..to: ```java private static int dotProductBody256Int4PackedPacked( ByteVectorLoader a, ByteVectorLoader b, int limit) { int sum = 0; // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator for (int i = 0; i < limit; i += 2048) { ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_256); ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_256); ShortVector acc2 = ShortVector.zero(ShortVector.SPECIES_256); ShortVector acc3 = 
ShortVector.zero(ShortVector.SPECIES_256); int innerLimit = Math.min(limit - i, 2048); for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) { // packed var sb8 = b.load(ByteVector.SPECIES_256, i + j).reinterpretAsShorts(); // packed var sa8 = a.load(ByteVector.SPECIES_256, i + j).reinterpretAsShorts(); // s1 ShortVector prod8a = sb8.and((short) 0x000F).mul(sa8.and((short) 0x00F)); acc0 = acc0.add(prod8a); // s2 ShortVector prod8b = sb8.lanewise(LSHR, 4).and((short) 0x000F).mul(sa8.lanewise(LSHR, 4).and((short) 0x000F)); acc1 = acc1.add(prod8b); // s3 ShortVector prod8c = sb8.lanewise(LSHR, 8).and((short) 0x000F).mul(sa8.lanewise(LSHR, 8).and((short) 0x000F)); acc2 = acc2.add(prod8c); // s4 ShortVector prod8d = sb8.lanewise(LSHR, 12) .and((short) 0x000F) .mul(sa8.lanewise(LSHR, 12).and((short) 0x000F)); acc3 = acc3.add(prod8d); } IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts(); IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts(); IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts(); IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts(); IntVector intAcc4 = acc2.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts(); IntVector intAcc5 = acc2.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts(); IntVector intAcc6 = acc3.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts(); IntVector intAcc7 = acc3.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts(); sum += intAcc0 .add(intAcc1) .add(intAcc2) .add(intAcc3) .add(intAcc4) .add(intAcc5) .add(intAcc6) .add(intAcc7) .reduceLanes(ADD); } return sum; } ``` - I basically read 256 bits worth of `short` values, extracted 4x "half-bytes" per `short` using a combination of `LSHR` and `and 0x000F` - Since we have 4x short accumulators, they'll map to 8x int vectors on expansion - Note that I also changed the limit above (`i += 
ByteVector.SPECIES_256.loopBound(a.length());`) to account for reading 256 bits at the same time. Is this what you meant? Baseline (this PR): ``` VectorUtilBenchmark.binaryHalfByteVectorPackedPacked 1024 thrpt 15 13.427 ± 0.031 ops/us ``` Modification: ``` VectorUtilBenchmark.binaryHalfByteVectorPackedPacked 1024 thrpt 15 13.213 ± 0.039 ops/us ``` There doesn't seem to be much difference between the two (the modification is roughly 1.6% lower in throughput, so it is not an improvement). -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For additional commands, e-mail: issues-h...@lucene.apache.org