kaivalnp commented on code in PR #14863:
URL: https://github.com/apache/lucene/pull/14863#discussion_r2253097713


##########
lucene/core/src/java24/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java:
##########
@@ -530,7 +566,41 @@ private int dotProductBody512Int4Packed(byte[] unpacked, 
byte[] packed, int limi
     return sum;
   }
 
-  private int dotProductBody256Int4Packed(byte[] unpacked, byte[] packed, int 
limit) {
+  private static int dotProductBody512Int4PackedPacked(
+      ByteVectorLoader a, ByteVectorLoader b, int limit) {
+    int sum = 0;
+    // iterate in chunks of 1024 items to ensure we don't overflow the short 
accumulator
+    for (int i = 0; i < limit; i += 4096) {
+      ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_512);
+      ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_512);
+      int innerLimit = Math.min(limit - i, 4096);
+      for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) {
+        // packed
+        var vb8 = b.load(ByteVector.SPECIES_256, i + j);
+        // packed
+        var va8 = a.load(ByteVector.SPECIES_256, i + j);
+
+        // upper
+        ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8.and((byte) 0x0F));
+        Vector<Short> prod16 = prod8.convertShape(ZERO_EXTEND_B2S, 
ShortVector.SPECIES_512, 0);

Review Comment:
   Ah, I see what you mean..
   
   On my machine, the preferred bit size was 256 so I changed the following 
function:
   ```java
     /**
      * Sums the products of the corresponding 4-bit values held in {@code a} and {@code b}, where
      * every byte of both inputs packs two 4-bit values. Bytes are read 128 bits at a time and the
      * per-byte products are accumulated in 256-bit short vectors.
      *
      * @param a first packed input
      * @param b second packed input
      * @param limit number of packed bytes to process. NOTE(review): the loops have no tail
      *     handling, so this presumably has to be a multiple of the 128-bit byte vector length;
      *     confirm against the caller.
      * @return the accumulated sum of all nibble products
      */
     private static int dotProductBody256Int4PackedPacked(
         ByteVectorLoader a, ByteVectorLoader b, int limit) {
       int sum = 0;
       // NOTE(review): the comment below says 1024, but this loop advances in chunks of 2048.
       // Overflow bound: each product is at most 15 * 15 = 225 and a chunk runs at most
       // 2048 / 16 = 128 inner iterations, so a short lane never exceeds 28800.
       // iterate in chunks of 1024 items to ensure we don't overflow the short 
accumulator
       for (int i = 0; i < limit; i += 2048) {
         // Two independent short accumulators, one per nibble position.
         ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_256);
         ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_256);
         // The last chunk may be shorter than 2048.
         int innerLimit = Math.min(limit - i, 2048);
         for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) {
           // packed
           var vb8 = b.load(ByteVector.SPECIES_128, i + j);
           // packed
           var va8 = a.load(ByteVector.SPECIES_128, i + j);
   
           // Mask keeps the low 4 bits of every byte. The product (at most 225) fits in 8 bits only
           // when read as unsigned, hence the zero-extending widening to short.
           // upper
           ByteVector prod8 = vb8.and((byte) 0x0F).mul(va8.and((byte) 0x0F));
           Vector<Short> prod16 = prod8.convertShape(ZERO_EXTEND_B2S, 
ShortVector.SPECIES_256, 0);
           acc0 = acc0.add(prod16);
   
           // Logical shift right by 4 keeps the high 4 bits of every byte.
           // lower
           ByteVector prod8a = vb8.lanewise(LSHR, 4).mul(va8.lanewise(LSHR, 4));
           Vector<Short> prod16a = prod8a.convertShape(ZERO_EXTEND_B2S, 
ShortVector.SPECIES_256, 0);
           acc1 = acc1.add(prod16a);
         }
         // Widen each 256-bit short accumulator into two 256-bit int vectors (part 0 and part 1)
         // before the horizontal reduction.
         IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_256, 
0).reinterpretAsInts();
         IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_256, 
1).reinterpretAsInts();
         IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_256, 
0).reinterpretAsInts();
         IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_256, 
1).reinterpretAsInts();
         sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD);
       }
       return sum;
     }
   ```
   
   ..to:
   ```java
     /**
      * Sums the products of the corresponding 4-bit values held in {@code a} and {@code b}, where
      * every byte of both inputs packs two 4-bit values. Bytes are read 256 bits at a time and
      * reinterpreted as short lanes, so each lane carries four 4-bit values that are extracted
      * with shifts of 0, 4, 8 and 12 bits.
      *
      * @param a first packed input
      * @param b second packed input
      * @param limit number of packed bytes to process. NOTE(review): the loops have no tail
      *     handling, so this presumably has to be a multiple of the 256-bit byte vector length;
      *     confirm against the caller.
      * @return the accumulated sum of all nibble products
      */
     private static int dotProductBody256Int4PackedPacked(
         ByteVectorLoader a, ByteVectorLoader b, int limit) {
       int sum = 0;
       // NOTE(review): the comment below says 1024, but this loop advances in chunks of 2048.
       // Overflow bound: each product is at most 15 * 15 = 225 and a chunk runs at most
       // 2048 / 32 = 64 inner iterations, so a short lane never exceeds 14400.
       // iterate in chunks of 1024 items to ensure we don't overflow the short 
accumulator
       for (int i = 0; i < limit; i += 2048) {
         // Four independent short accumulators, one per nibble position inside a short lane.
         ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_256);
         ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_256);
         ShortVector acc2 = ShortVector.zero(ShortVector.SPECIES_256);
         ShortVector acc3 = ShortVector.zero(ShortVector.SPECIES_256);
         // The last chunk may be shorter than 2048.
         int innerLimit = Math.min(limit - i, 2048);
         for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_256.length()) {
           // Both inputs are reinterpreted the same way, so nibble positions stay aligned
           // between a and b.
           // packed
           var sb8 = b.load(ByteVector.SPECIES_256, i + 
j).reinterpretAsShorts();
           // packed
           var sa8 = a.load(ByteVector.SPECIES_256, i + 
j).reinterpretAsShorts();
   
           // Bits 0-3 of every short lane. The mask spelled 0x00F is the same value as 0x000F.
           // s1
           ShortVector prod8a = sb8.and((short) 0x000F).mul(sa8.and((short) 
0x00F));
           acc0 = acc0.add(prod8a);
   
           // Bits 4-7 of every short lane.
           // s2
           ShortVector prod8b =
               sb8.lanewise(LSHR, 4).and((short) 0x000F).mul(sa8.lanewise(LSHR, 
4).and((short) 0x000F));
           acc1 = acc1.add(prod8b);
   
           // Bits 8-11 of every short lane.
           // s3
           ShortVector prod8c =
               sb8.lanewise(LSHR, 8).and((short) 0x000F).mul(sa8.lanewise(LSHR, 
8).and((short) 0x000F));
           acc2 = acc2.add(prod8c);
   
           // Bits 12-15 of every short lane.
           // s4
           ShortVector prod8d =
               sb8.lanewise(LSHR, 12)
                   .and((short) 0x000F)
                   .mul(sa8.lanewise(LSHR, 12).and((short) 0x000F));
           acc3 = acc3.add(prod8d);
         }
         // Widen each 256-bit short accumulator into two 256-bit int vectors (part 0 and part 1),
         // giving eight int vectors in total, before the horizontal reduction.
         IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_256, 
0).reinterpretAsInts();
         IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_256, 
1).reinterpretAsInts();
         IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_256, 
0).reinterpretAsInts();
         IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_256, 
1).reinterpretAsInts();
         IntVector intAcc4 = acc2.convertShape(S2I, IntVector.SPECIES_256, 
0).reinterpretAsInts();
         IntVector intAcc5 = acc2.convertShape(S2I, IntVector.SPECIES_256, 
1).reinterpretAsInts();
         IntVector intAcc6 = acc3.convertShape(S2I, IntVector.SPECIES_256, 
0).reinterpretAsInts();
         IntVector intAcc7 = acc3.convertShape(S2I, IntVector.SPECIES_256, 
1).reinterpretAsInts();
         sum +=
             intAcc0
                 .add(intAcc1)
                 .add(intAcc2)
                 .add(intAcc3)
                 .add(intAcc4)
                 .add(intAcc5)
                 .add(intAcc6)
                 .add(intAcc7)
                 .reduceLanes(ADD);
       }
       return sum;
     }
   ```
   
   - I basically read 256 bits worth of `short` values, extracted 4x 
"half-bytes" per `short` using a combination of `LSHR` and `and 0x000F`
   - Since we have 4x short accumulators, they'll map to 8x int vectors on 
expansion
   - Note that I also changed the limit above (`i += 
ByteVector.SPECIES_256.loopBound(a.length());`) to account for reading 256 bits 
at a time
   
   Is this what you meant?
   
   Baseline (this PR):
   ```
   VectorUtilBenchmark.binaryHalfByteVectorPackedPacked    1024  thrpt   15  
13.427 ± 0.031  ops/us
   ```
   
   Modification:
   ```
   VectorUtilBenchmark.binaryHalfByteVectorPackedPacked    1024  thrpt   15  
13.213 ± 0.039  ops/us
   ```
   
   There doesn't seem to be much difference between the two: the modification is
   about 1.6% slower (13.213 vs 13.427 ops/us).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org
For additional commands, e-mail: issues-h...@lucene.apache.org

Reply via email to