ChrisHegarty commented on code in PR #13651: URL: https://github.com/apache/lucene/pull/13651#discussion_r1731499127
########## lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java: ########## @@ -761,4 +763,81 @@ private static int squareDistanceBody128(MemorySegment a, MemorySegment b, int l // reduce return acc1.add(acc2).reduceLanes(ADD); } + + @Override + public long ipByteBinByte(byte[] q, byte[] d) { + if (VECTOR_BITSIZE == 128) { + return ipByteBinByte128(MemorySegment.ofArray(q), MemorySegment.ofArray(d)); + } + throw new UnsupportedOperationException("Vector bit size=" + VECTOR_BITSIZE); + } + + static final ByteOrder LE = ByteOrder.LITTLE_ENDIAN; + + public static long ipByteBinByte128(MemorySegment q, MemorySegment d) { + long ret = 0; + long subRet0 = 0; + long subRet1 = 0; + long subRet2 = 0; + long subRet3 = 0; + + final int limit = (int) ByteVector.SPECIES_128.loopBound(d.byteSize()); + // iterate in chunks of 256 bytes to ensure we don't overflow the accumulator + // (256bytes/16lanes=16itrs) + final int dSize = (int) d.byteSize(); + for (int j = 0; j < limit; j += 256) { + ByteVector acc0 = ByteVector.zero(ByteVector.SPECIES_128); + ByteVector acc1 = ByteVector.zero(ByteVector.SPECIES_128); + ByteVector acc2 = ByteVector.zero(ByteVector.SPECIES_128); + ByteVector acc3 = ByteVector.zero(ByteVector.SPECIES_128); + int innerLimit = Math.min(limit - j, 256); + for (int k = 0; k < innerLimit; k += ByteVector.SPECIES_128.length()) { + var vd = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, d, j + k, LE); + var vq0 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, q, j + k, LE); + var vq1 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, q, j + k + dSize, LE); + var vq2 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, q, j + k + 2 * dSize, LE); + var vq3 = ByteVector.fromMemorySegment(ByteVector.SPECIES_128, q, j + k + 3 * dSize, LE); + ByteVector vres0 = vq0.and(vd); + ByteVector vres1 = vq1.and(vd); + ByteVector vres2 = vq2.and(vd); + ByteVector vres3 = vq3.and(vd); + vres0 = 
vres0.lanewise(VectorOperators.BIT_COUNT); + vres1 = vres1.lanewise(VectorOperators.BIT_COUNT); + vres2 = vres2.lanewise(VectorOperators.BIT_COUNT); + vres3 = vres3.lanewise(VectorOperators.BIT_COUNT); + acc0 = acc0.add(vres0); + acc1 = acc1.add(vres1); + acc2 = acc2.add(vres2); + acc3 = acc3.add(vres3); + } + ShortVector sumShort1 = acc0.reinterpretAsShorts().and((short) 0xFF); + ShortVector sumShort2 = acc0.reinterpretAsShorts().lanewise(VectorOperators.LSHR, 8); + subRet0 += sumShort1.add(sumShort2).reduceLanes(VectorOperators.ADD); + + sumShort1 = acc1.reinterpretAsShorts().and((short) 0xFF); + sumShort2 = acc1.reinterpretAsShorts().lanewise(VectorOperators.LSHR, 8); + subRet1 += sumShort1.add(sumShort2).reduceLanes(VectorOperators.ADD); + + sumShort1 = acc2.reinterpretAsShorts().and((short) 0xFF); + sumShort2 = acc2.reinterpretAsShorts().lanewise(VectorOperators.LSHR, 8); + subRet2 += sumShort1.add(sumShort2).reduceLanes(VectorOperators.ADD); + + sumShort1 = acc3.reinterpretAsShorts().and((short) 0xFF); + sumShort2 = acc3.reinterpretAsShorts().lanewise(VectorOperators.LSHR, 8); + subRet3 += sumShort1.add(sumShort2).reduceLanes(VectorOperators.ADD); Review Comment: reduceLanes is indeed costly, but it is only done once per accumulator for each 256-byte chunk processed by the outer loop (i.e. once per 2048 elements of the dimension of each vector), so it is typically not too bad in practice. I'll add some better docs/comments to help describe this. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For additional commands, e-mail: issues-h...@lucene.apache.org