ChrisHegarty commented on code in PR #13321: URL: https://github.com/apache/lucene/pull/13321#discussion_r1582893750
########## lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java: ########## @@ -390,22 +392,202 @@ private int dotProductBody128(byte[] a, byte[] b, int limit) { } @Override - public int int4DotProduct(byte[] a, byte[] b) { + public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) { + assert (apacked && bpacked) == false; int i = 0; int res = 0; - if (VECTOR_BITSIZE >= 512 || VECTOR_BITSIZE == 256) { - return dotProduct(a, b); - } else if (a.length >= 32 && HAS_FAST_INTEGER_VECTORS) { - i += ByteVector.SPECIES_128.loopBound(a.length); - res += int4DotProductBody128(a, b, i); - } - // scalar tail - for (; i < a.length; i++) { - res += b[i] * a[i]; + if (apacked || bpacked) { + byte[] packed = apacked ? a : b; + byte[] unpacked = apacked ? b : a; + if (packed.length >= 32) { + if (VECTOR_BITSIZE >= 512) { + i += ByteVector.SPECIES_256.loopBound(packed.length); + res += dotProductBody512Int4Packed(unpacked, packed, i); + } else if (VECTOR_BITSIZE == 256) { + i += ByteVector.SPECIES_128.loopBound(packed.length); + res += dotProductBody256Int4Packed(unpacked, packed, i); + } else if (HAS_FAST_INTEGER_VECTORS) { + i += ByteVector.SPECIES_64.loopBound(packed.length); + res += dotProductBody128Int4Packed(unpacked, packed, i); + } + } + // scalar tail + for (; i < packed.length; i++) { + byte packedByte = packed[i]; + byte unpacked1 = unpacked[i]; + byte unpacked2 = unpacked[i + packed.length]; + res += (packedByte & 0x0F) * unpacked2; + res += ((packedByte & 0xFF) >> 4) * unpacked1; + } + } else { + if (VECTOR_BITSIZE >= 512 || VECTOR_BITSIZE == 256) { + return dotProduct(a, b); + } else if (a.length >= 32 && HAS_FAST_INTEGER_VECTORS) { + i += ByteVector.SPECIES_128.loopBound(a.length); + res += int4DotProductBody128(a, b, i); + } + // scalar tail + for (; i < a.length; i++) { + res += b[i] * a[i]; + } } + return res; } + private int dotProductBody512Int4Packed(byte[] unpacked, byte[] packed, int limit) { + 
int sum = 0; + // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += 4096) { + ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_512); + ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_512); + int innerLimit = Math.min(limit - i, 4096); + for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) { + // packed + var vb8 = ByteVector.fromArray(ByteVector.SPECIES_256, packed, i + j); + // unpacked + var va8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j + packed.length); + + // upper + ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); + Vector<Short> prod16 = prod8.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0); + acc0 = acc0.add(prod16); + + // lower + ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j); + ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); + Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0); + acc1 = acc1.add(prod16a); + } + IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_512, 0).reinterpretAsInts(); + IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_512, 1).reinterpretAsInts(); + IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_512, 0).reinterpretAsInts(); + IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_512, 1).reinterpretAsInts(); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + } + return sum; + } + + private int dotProductBody256Int4Packed(byte[] unpacked, byte[] packed, int limit) { + int sum = 0; + // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += 2048) { + ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_256); + ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_256); + int innerLimit = Math.min(limit - i, 2048); + for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) { + // packed + var 
vb8 = ByteVector.fromArray(ByteVector.SPECIES_128, packed, i + j); + // unpacked + var va8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j + packed.length); + + // upper + ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); + Vector<Short> prod16 = prod8.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0); + acc0 = acc0.add(prod16); + + // lower + ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j); + ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); + Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0); + acc1 = acc1.add(prod16a); + } + IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts(); + IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts(); + IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts(); + IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts(); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + } + return sum; + } + + /** vectorized dot product body (128 bit vectors) */ + private int dotProductBody128Int4Packed(byte[] unpacked, byte[] packed, int limit) { + int sum = 0; + // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += 1024) { + ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_128); + ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128); + int innerLimit = Math.min(limit - i, 1024); + for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_64.length()) { + // packed + ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, packed, i + j); + // unpacked + ByteVector va8 = + ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j + packed.length); + + // upper + ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); + ShortVector prod16 = + prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); + 
acc0 = acc0.add(prod16.and((short) 0xFF)); + + // lower + va8 = ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j); + prod8 = vb8.lanewise(LSHR, 4).mul(va8); + prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); + acc1 = acc1.add(prod16.and((short) 0xFF)); + } + IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts(); + IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts(); + IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts(); + IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts(); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + } + return sum; + } + + private int dotProductBody512Packed(byte[] unpacked, byte[] packed, int limit) { + IntVector acc = IntVector.zero(INT_SPECIES); + for (int i = 0; i < limit; i += BYTE_SPECIES.length()) { + ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, unpacked, i); + ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, packed, i); + + // 16-bit multiply: avoid AVX-512 heavy multiply on zmm + Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0); + Vector<Short> vb16 = vb8.lanewise(LSHR, 4).convertShape(B2S, SHORT_SPECIES, 0); + Vector<Short> prod16 = va16.mul(vb16); + + // 32-bit add + Vector<Integer> prod32 = prod16.convertShape(S2I, INT_SPECIES, 0); + acc = acc.add(prod32); + + va8 = ByteVector.fromArray(BYTE_SPECIES, unpacked, i + packed.length); + va16 = va8.convertShape(B2S, SHORT_SPECIES, 0); + vb16 = vb8.and((byte) 0x0F).convertShape(B2S, SHORT_SPECIES, 0); + prod16 = va16.mul(vb16); + prod32 = prod16.convertShape(S2I, INT_SPECIES, 0); + acc = acc.add(prod32); + } + // reduce + return acc.reduceLanes(ADD); + } + + private int dotProductBody256Packed(byte[] unpacked, byte[] packed, int limit) { Review Comment: Same as above, now unused. 
########## lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java: ########## @@ -390,22 +392,202 @@ private int dotProductBody128(byte[] a, byte[] b, int limit) { } @Override - public int int4DotProduct(byte[] a, byte[] b) { + public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) { + assert (apacked && bpacked) == false; int i = 0; int res = 0; - if (VECTOR_BITSIZE >= 512 || VECTOR_BITSIZE == 256) { - return dotProduct(a, b); - } else if (a.length >= 32 && HAS_FAST_INTEGER_VECTORS) { - i += ByteVector.SPECIES_128.loopBound(a.length); - res += int4DotProductBody128(a, b, i); - } - // scalar tail - for (; i < a.length; i++) { - res += b[i] * a[i]; + if (apacked || bpacked) { + byte[] packed = apacked ? a : b; + byte[] unpacked = apacked ? b : a; + if (packed.length >= 32) { + if (VECTOR_BITSIZE >= 512) { + i += ByteVector.SPECIES_256.loopBound(packed.length); + res += dotProductBody512Int4Packed(unpacked, packed, i); + } else if (VECTOR_BITSIZE == 256) { + i += ByteVector.SPECIES_128.loopBound(packed.length); + res += dotProductBody256Int4Packed(unpacked, packed, i); + } else if (HAS_FAST_INTEGER_VECTORS) { + i += ByteVector.SPECIES_64.loopBound(packed.length); + res += dotProductBody128Int4Packed(unpacked, packed, i); + } + } + // scalar tail + for (; i < packed.length; i++) { + byte packedByte = packed[i]; + byte unpacked1 = unpacked[i]; + byte unpacked2 = unpacked[i + packed.length]; + res += (packedByte & 0x0F) * unpacked2; + res += ((packedByte & 0xFF) >> 4) * unpacked1; + } + } else { + if (VECTOR_BITSIZE >= 512 || VECTOR_BITSIZE == 256) { + return dotProduct(a, b); + } else if (a.length >= 32 && HAS_FAST_INTEGER_VECTORS) { + i += ByteVector.SPECIES_128.loopBound(a.length); + res += int4DotProductBody128(a, b, i); + } + // scalar tail + for (; i < a.length; i++) { + res += b[i] * a[i]; + } } + return res; } + private int dotProductBody512Int4Packed(byte[] unpacked, byte[] packed, int limit) { + 
int sum = 0; + // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += 4096) { + ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_512); + ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_512); + int innerLimit = Math.min(limit - i, 4096); + for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) { + // packed + var vb8 = ByteVector.fromArray(ByteVector.SPECIES_256, packed, i + j); + // unpacked + var va8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j + packed.length); + + // upper + ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); + Vector<Short> prod16 = prod8.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0); + acc0 = acc0.add(prod16); + + // lower + ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_256, unpacked, i + j); + ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); + Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_512, 0); + acc1 = acc1.add(prod16a); + } + IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_512, 0).reinterpretAsInts(); + IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_512, 1).reinterpretAsInts(); + IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_512, 0).reinterpretAsInts(); + IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_512, 1).reinterpretAsInts(); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + } + return sum; + } + + private int dotProductBody256Int4Packed(byte[] unpacked, byte[] packed, int limit) { + int sum = 0; + // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += 2048) { + ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_256); + ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_256); + int innerLimit = Math.min(limit - i, 2048); + for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) { + // packed + var 
vb8 = ByteVector.fromArray(ByteVector.SPECIES_128, packed, i + j); + // unpacked + var va8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j + packed.length); + + // upper + ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); + Vector<Short> prod16 = prod8.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0); + acc0 = acc0.add(prod16); + + // lower + ByteVector vc8 = ByteVector.fromArray(ByteVector.SPECIES_128, unpacked, i + j); + ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(vc8); + Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, ShortVector.SPECIES_256, 0); + acc1 = acc1.add(prod16a); + } + IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts(); + IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts(); + IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_256, 0).reinterpretAsInts(); + IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_256, 1).reinterpretAsInts(); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + } + return sum; + } + + /** vectorized dot product body (128 bit vectors) */ + private int dotProductBody128Int4Packed(byte[] unpacked, byte[] packed, int limit) { + int sum = 0; + // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += 1024) { + ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_128); + ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128); + int innerLimit = Math.min(limit - i, 1024); + for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_64.length()) { + // packed + ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, packed, i + j); + // unpacked + ByteVector va8 = + ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j + packed.length); + + // upper + ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8); + ShortVector prod16 = + prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); + 
acc0 = acc0.add(prod16.and((short) 0xFF)); + + // lower + va8 = ByteVector.fromArray(ByteVector.SPECIES_64, unpacked, i + j); + prod8 = vb8.lanewise(LSHR, 4).mul(va8); + prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); + acc1 = acc1.add(prod16.and((short) 0xFF)); + } + IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts(); + IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts(); + IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts(); + IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts(); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + } + return sum; + } + + private int dotProductBody512Packed(byte[] unpacked, byte[] packed, int limit) { Review Comment: This is currently unused; I didn't remove it, in case @benwtrent wanted it for something else. If not, then it can be removed. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For additional commands, e-mail: issues-h...@lucene.apache.org