benwtrent commented on code in PR #14304:
URL: https://github.com/apache/lucene/pull/14304#discussion_r1985728235


##########
lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java:
##########
@@ -907,4 +907,87 @@ public static long int4BitDotProduct128(byte[] q, byte[] d) {
     }
     return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
   }
+
+  @Override
+  public float quantize(
+      float[] vector, byte[] dest, float scale, float alpha, float minQuantile, float maxQuantile) {
+    float correction = 0;
+    int i = 0;
+    // only vectorize if we have a viable BYTE_SPECIES we can use for output
+    if (VECTOR_BITSIZE >= 256) {
+      for (; i < FLOAT_SPECIES.loopBound(vector.length); i += FLOAT_SPECIES.length()) {
+        FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
+
+        // Make sure the value is within the quantile range, cutting off the tails
+        // see first parenthesis in equation: byte = (float - minQuantile) * 127/(maxQuantile -
+        // minQuantile)
+        FloatVector dxc = v.min(maxQuantile).max(minQuantile).sub(minQuantile);
+        // Scale the value to the range [0, 127], this is our quantized value
+        // scale = 127/(maxQuantile - minQuantile)
+        // Math.round rounds half up (ties toward positive infinity), so do the same by adding 0.5
+        // and then truncating to int; this matches because the scaled value is non-negative
+        Vector<Integer> roundedDxs = dxc.mul(scale).add(0.5f).convert(VectorOperators.F2I, 0);
+        // output this to the array
+        ((ByteVector) roundedDxs.castShape(BYTE_SPECIES, 0)).intoArray(dest, i);
+        // We multiply by `alpha` here to get the quantized value back into the original range
+        // to aid in calculating the corrective offset
+        Vector<Float> dxq = ((FloatVector) roundedDxs.castShape(FLOAT_SPECIES, 0)).mul(alpha);
+        // Calculate the corrective offset that needs to be applied to the score
+        // in addition to the `byte * minQuantile * alpha` term in the equation
+        // we add the `(dx - dxq) * dxq` term to account for the fact that the quantized value
+        // will be rounded to the nearest whole number and lose some accuracy
+        // Additionally, we account for the global correction of `minQuantile^2` in the equation
+        correction +=
+            v.sub(minQuantile / 2f)
+                .mul(minQuantile)
+                .add(v.sub(minQuantile).sub(dxq).mul(dxq))
+                .reduceLanes(VectorOperators.ADD);

Review Comment:
   Could you collect the corrections in a float array? That way all lanes stay parallelized inside the loop, and we sum the floats once at the end.
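For illustration, a minimal sketch of the suggested pattern, assuming 8 lanes (a 256-bit species of 32-bit floats) and emulating the lanes with a plain `float[]` so it runs without the `jdk.incubator.vector` module. `LaneAccumulationSketch`, `LANES`, `correctionTerm`, and `sumCorrections` are hypothetical names; `correctionTerm` mirrors the expression in the diff. Each lane keeps its own running sum inside the main loop, and the single horizontal reduction happens once after it.

```java
// Hypothetical sketch: lanes are emulated with a plain float[] so this runs without
// the jdk.incubator.vector module. LANES stands in for FLOAT_SPECIES.length().
final class LaneAccumulationSketch {
  static final int LANES = 8; // assumption: 256-bit vectors of 32-bit floats

  /** Per-element correction term, mirroring the expression in the diff. */
  static float correctionTerm(float v, float dxq, float minQuantile) {
    return (v - minQuantile / 2f) * minQuantile + (v - minQuantile - dxq) * dxq;
  }

  /** Accumulates per-lane partial sums and reduces them only once, after the loop. */
  static float sumCorrections(float[] vector, float[] dxq, float minQuantile) {
    float[] laneSums = new float[LANES]; // one running sum per lane
    int loopBound = vector.length - (vector.length % LANES);
    int i = 0;
    for (; i < loopBound; i += LANES) {
      for (int lane = 0; lane < LANES; lane++) {
        // lane-wise add only; no cross-lane reduction inside the main loop
        laneSums[lane] += correctionTerm(vector[i + lane], dxq[i + lane], minQuantile);
      }
    }
    float correction = 0f;
    for (float s : laneSums) {
      correction += s; // the single horizontal reduction
    }
    for (; i < vector.length; i++) {
      // scalar tail for elements that do not fill a whole "vector"
      correction += correctionTerm(vector[i], dxq[i], minQuantile);
    }
    return correction;
  }
}
```

In the Panama version the same idea would likely be a `FloatVector` accumulator started from `FloatVector.zero(FLOAT_SPECIES)`, updated with `add` each iteration, and reduced with `reduceLanes(VectorOperators.ADD)` after the loop. The summation order changes, so the result may differ from the per-iteration reduction in the last few bits.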



##########
lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java:
##########
@@ -907,4 +907,87 @@ public static long int4BitDotProduct128(byte[] q, byte[] d) {
     }
     return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
   }
+
+  @Override
+  public float quantize(

Review Comment:
   Let's give this a more descriptive name; we could call it "minMaxScalarQuantization" or something similar.
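A scalar sketch of what the vectorized loop computes, written under the suggested name, may help when checking the tail loop or writing a test that compares against the Panama path. The name `minMaxScalarQuantization` comes from the suggestion above; the enclosing class `MinMaxQuantizationSketch` is hypothetical. The diff comments give `scale = 127/(maxQuantile - minQuantile)`; treating `alpha` as its reciprocal is an assumption.

```java
// Hypothetical scalar sketch of the loop body in the diff, under the suggested name.
// Assumes scale = 127 / (maxQuantile - minQuantile) and alpha = 1 / scale.
final class MinMaxQuantizationSketch {
  static float minMaxScalarQuantization(
      float[] vector, byte[] dest, float scale, float alpha, float minQuantile, float maxQuantile) {
    float correction = 0f;
    for (int i = 0; i < vector.length; i++) {
      float v = vector[i];
      // clamp to [minQuantile, maxQuantile] (cutting off the tails), then shift to start at 0
      float dxc = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile;
      // round half up; dxc * scale is non-negative, so this lands in [0, 127]
      int quantized = Math.round(dxc * scale);
      dest[i] = (byte) quantized;
      // quantized value mapped back into the original range
      float dxq = quantized * alpha;
      // same corrective offset as the vectorized code: uses the unclamped v
      correction += (v - minQuantile / 2f) * minQuantile + (v - minQuantile - dxq) * dxq;
    }
    return correction;
  }
}
```

With `minQuantile = 0`, `maxQuantile = 1`, and `scale = 127`, the inputs `{-1, 0, 0.5, 2}` clamp and quantize to `{0, 0, 64, 127}`.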



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
