thecoop commented on code in PR #14863:
URL: https://github.com/apache/lucene/pull/14863#discussion_r2254153413


##########
lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java:
##########
@@ -455,49 +455,85 @@ private static int dotProductBody128(ByteVectorLoader a, ByteVectorLoader b, int
 
   @Override
   public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) {
+    return int4DotProductBody(new ArrayLoader(a), apacked, new ArrayLoader(b), bpacked);
+  }
+
+  public static int int4DotProduct(
+      MemorySegment a, boolean apacked, MemorySegment b, boolean bpacked) {
+    return int4DotProductBody(
+        new MemorySegmentLoader(a), apacked, new MemorySegmentLoader(b), bpacked);
+  }
+
+  public static int int4DotProduct(byte[] a, boolean apacked, MemorySegment b, boolean bpacked) {
+    return int4DotProductBody(new ArrayLoader(a), apacked, new MemorySegmentLoader(b), bpacked);
+  }
+
+  private static int int4DotProductBody(
+      ByteVectorLoader a, boolean apacked, ByteVectorLoader b, boolean bpacked) {
     assert (apacked && bpacked) == false;
     int i = 0;
     int res = 0;
-    if (apacked || bpacked) {
-      byte[] packed = apacked ? a : b;
-      byte[] unpacked = apacked ? b : a;
-      if (packed.length >= 32) {
+    if (apacked && bpacked) {
+      if (a.length() >= 32) {
+        if (VECTOR_BITSIZE >= 512) {
+          i += ByteVector.SPECIES_256.loopBound(a.length());
+          res += dotProductBody512Int4PackedPacked(a, b, i);
+        } else if (VECTOR_BITSIZE == 256) {
+          i += ByteVector.SPECIES_128.loopBound(a.length());
+          res += dotProductBody256Int4PackedPacked(a, b, i);
+        } else {
+          i += ByteVector.SPECIES_64.loopBound(a.length());
+          res += dotProductBody128Int4PackedPacked(a, b, i);
+        }
+      }
+      // scalar tail
+      for (; i < a.length(); i++) {
+        byte aByte = a.tail(i);
+        byte bByte = b.tail(i);
+        res += (aByte & 0x0F) * (bByte & 0x0F);
+        res += ((aByte & 0xFF) >> 4) * ((bByte & 0xFF) >> 4);

Review Comment:
   Ah, hadn't spotted that. Your call then.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org
For additional commands, e-mail: issues-h...@lucene.apache.org

Reply via email to