thecoop commented on code in PR #14304: URL: https://github.com/apache/lucene/pull/14304#discussion_r1987194449
########## lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java: ########## @@ -907,4 +907,87 @@ public static long int4BitDotProduct128(byte[] q, byte[] d) { } return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); } + + @Override + public float quantize( + float[] vector, byte[] dest, float scale, float alpha, float minQuantile, float maxQuantile) { + float correction = 0; + int i = 0; + // only vectorize if we have a viable BYTE_SPECIES we can use for output + if (VECTOR_BITSIZE >= 256) { + for (; i < FLOAT_SPECIES.loopBound(vector.length); i += FLOAT_SPECIES.length()) { + FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i); + + // Make sure the value is within the quantile range, cutting off the tails + // see first parenthesis in equation: byte = (float - minQuantile) * 127/(maxQuantile - + // minQuantile) + FloatVector dxc = v.min(maxQuantile).max(minQuantile).sub(minQuantile); + // Scale the value to the range [0, 127], this is our quantized value + // scale = 127/(maxQuantile - minQuantile) + // Math.round rounds to the nearest integer, with ties going toward positive infinity; mimic that by adding 0.5 and then truncating to int + Vector<Integer> roundedDxs = dxc.mul(scale).add(0.5f).convert(VectorOperators.F2I, 0); + // output this to the array + ((ByteVector) roundedDxs.castShape(BYTE_SPECIES, 0)).intoArray(dest, i); + // We multiply by `alpha` here to get the quantized value back into the original range + // to aid in calculating the corrective offset + Vector<Float> dxq = ((FloatVector) roundedDxs.castShape(FLOAT_SPECIES, 0)).mul(alpha); + // Calculate the corrective offset that needs to be applied to the score + // in addition to the `byte * minQuantile * alpha` term in the equation + // we add the `(dx - dxq) * dxq` term to account for the fact that the quantized value + // will be rounded to the nearest whole number and lose some accuracy + // Additionally, we account for the global correction of `minQuantile^2` in the equation + 
correction += + v.sub(minQuantile / 2f) + .mul(minQuantile) + .add(v.sub(minQuantile).sub(dxq).mul(dxq)) + .reduceLanes(VectorOperators.ADD); Review Comment: And even more with FMA operations -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For additional commands, e-mail: issues-h...@lucene.apache.org