gf2121 opened a new issue, #12721:
URL: https://github.com/apache/lucene/issues/12721

   ### Description
   
   An immature idea! :)
   
   I noticed that `BPIndexReorderer$ComputeGainsTask#computeGain()` accounts for the largest share of samples in a CPU profile:
   
   ```
   PERCENT       CPU SAMPLES   STACK
   4.75%         53042         org.apache.lucene.misc.index.BPIndexReorderer$ComputeGainsTask#computeGain()
   4.47%         49905         org.apache.lucene.analysis.standard.StandardTokenizerImpl#getNextToken()
   3.87%         43194         org.apache.lucene.index.FreqProxTermsWriter$SortingPostingsEnum#reset()
   3.67%         40975         org.apache.lucene.index.TermsHashPerField#writeByte()
   3.59%         40003         org.apache.lucene.util.BytesRefHash#equals()
   3.55%         39588         org.apache.lucene.store.MemorySegmentIndexInput#readByte()
   2.80%         31197         org.apache.lucene.util.ByteBlockPool#setBytesRef()
   2.52%         28154         java.io.BufferedOutputStream#write()
   2.33%         25978         org.apache.lucene.util.BytesRefHash#findHash()
   2.32%         25896         org.apache.lucene.codecs.lucene90.Lucene90PostingsWriter#startDoc()
   2.24%         24966         org.apache.lucene.util.Sorter#binarySort()
   2.23%         24864         org.apache.lucene.store.DataInput#readVInt()
   2.10%         23484         org.apache.lucene.misc.index.BPIndexReorderer$ForwardIndex#seek()
   2.04%         22808         org.apache.lucene.codecs.PushPostingsWriterBase#writeTerm()
   ```
   
   I tried implementing this with the Vector API, and the results look good on AVX-512 (Java 21): the vectorized version reaches about 1.7x the throughput of the scalar one.
   ```
   Benchmark              (maxTerm)  (termsNum)   Mode  Cnt  Score   Error   Units
   GainBenchmark.gainNew       1024         128  thrpt    5  3.471 ± 0.142  ops/us
   GainBenchmark.gainOld       1024         128  thrpt    5  2.031 ± 0.432  ops/us
   ```
   
   
   <details><summary> Code </summary>
   
   ```
   package testing;
   
   import jdk.incubator.vector.FloatVector;
   import jdk.incubator.vector.IntVector;
   import jdk.incubator.vector.Vector;
   import jdk.incubator.vector.VectorMask;
   import jdk.incubator.vector.VectorOperators;
   import jdk.incubator.vector.VectorSpecies;
   import org.openjdk.jmh.annotations.*;
   
   import java.util.Arrays;
   import java.util.concurrent.ThreadLocalRandom;
   import java.util.concurrent.TimeUnit;
   
   @BenchmarkMode(Mode.Throughput)
   @OutputTimeUnit(TimeUnit.MICROSECONDS)
   @State(Scope.Benchmark)
   @Warmup(iterations = 3, time = 3)
   @Measurement(iterations = 5, time = 3)
   @Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
   public class GainBenchmark {
     private static final VectorSpecies<Integer> INT_SPECIES = 
IntVector.SPECIES_PREFERRED;
     private static final VectorSpecies<Float> FLOAT_SPECIES = 
FloatVector.SPECIES_PREFERRED;
   
     private int[] terms;
     private int[] fromFreqs;
     private int[] toFreqs;
     private final int[] indicesBuffer = new int[128];
     @Param({"128"})
     int termsNum;
     @Param({"1024"})
     int maxTerm;
   
     private static final float[] LOG2_TABLE = new float[256];
   
     static {
       LOG2_TABLE[0] = 1f;
       // float that has the biased exponent of 1f and zeros for sign and 
mantissa bits
       final int one = Float.floatToIntBits(1f);
       for (int i = 0; i < 256; ++i) {
         float f = Float.intBitsToFloat(one | (i << (23 - 8)));
         LOG2_TABLE[i] = (float) (Math.log(f) / Math.log(2));
       }
     }
   
     @Setup(Level.Trial)
     public void init() {
       terms = new int[termsNum];
       for (int i=0; i < termsNum; i++) {
         terms[i] = ThreadLocalRandom.current().nextInt(maxTerm);
       }
       fromFreqs = new int[maxTerm];
       toFreqs = new int[maxTerm];
       for (int i = 0; i < maxTerm; i++) {
         fromFreqs[i] = ThreadLocalRandom.current().nextInt(16);
         toFreqs[i] = ThreadLocalRandom.current().nextInt(16);
       }
       float o = gainOld();
       float n = gainNew();
       if (Math.abs(o - n) > 1E-4f) {
         throw new RuntimeException("New is wrong, old: " + o + ", new: " + n);
       }
     }
   
     @Benchmark
     public float gainOld() {
       float gain = 0;
       for (int i = 0; i < terms.length; ++i) {
         final int termID = terms[i];
         final int fromDocFreq = fromFreqs[termID];
         final int toDocFreq = toFreqs[termID];
         assert fromDocFreq >= 0;
         assert toDocFreq >= 0;
         gain +=
             (toDocFreq == 0 ? 0 : fastLog2(toDocFreq))
                 - (fromDocFreq == 0 ? 0 : fastLog2(fromDocFreq));
       }
       return gain;
     }
   
     @Benchmark
     public float gainNew() {
       final int upperBound = INT_SPECIES.loopBound(termsNum);
       int i = 0;
       FloatVector acc = FloatVector.zero(FLOAT_SPECIES);
       for (; i < upperBound; i += INT_SPECIES.length()) {
         IntVector fromVector = IntVector.fromArray(INT_SPECIES, fromFreqs, 0, 
terms, i);
         IntVector toVector = IntVector.fromArray(INT_SPECIES, toFreqs, 0, 
terms, i);
         Vector<Float> gainVector = 
fastLog2(toVector).sub(fastLog2(fromVector));
         acc = acc.add(gainVector);
   
       }
       float gain = acc.reduceLanes(VectorOperators.ADD);
   
       for ( ; i < terms.length; ++i) {
         final int termID = terms[i];
         final int fromDocFreq = fromFreqs[termID];
         final int toDocFreq = toFreqs[termID];
         assert fromDocFreq >= 0;
         assert toDocFreq >= 0;
         gain +=
             (toDocFreq == 0 ? 0 : fastLog2(toDocFreq))
                 - (fromDocFreq == 0 ? 0 : fastLog2(fromDocFreq));
       }
       return gain;
     }
   
     private static final IntVector BROAD_31 = IntVector.broadcast(INT_SPECIES, 
31);
     private static final IntVector BROAD_32 = IntVector.broadcast(INT_SPECIES, 
32);
     private static final IntVector BROAD_24 = IntVector.broadcast(INT_SPECIES, 
24);
     private static final FloatVector BROAD_0F = 
FloatVector.broadcast(FLOAT_SPECIES, 0);
   
     private Vector<Float> fastLog2(IntVector intVector) {
       VectorMask<Integer> mask = intVector.test(VectorOperators.IS_DEFAULT);
       IntVector floorLog2 = 
BROAD_31.sub(intVector.lanewise(VectorOperators.LEADING_ZEROS_COUNT));
       IntVector tableIndices = intVector
           .lanewise(VectorOperators.LSHL, BROAD_32.sub(floorLog2))
           .lanewise(VectorOperators.LSHR, BROAD_24);
       tableIndices.intoArray(indicesBuffer, 0);
       return floorLog2.convert(VectorOperators.I2F, 0)
           .add(FloatVector.fromArray(FLOAT_SPECIES, LOG2_TABLE, 0, 
indicesBuffer, 0))
           .blend(BROAD_0F, mask.cast(FLOAT_SPECIES));
     }
   
     static float fastLog2(int i) {
       assert i > 0 : "Cannot compute log of i=" + i;
       int floorLog2 = 31 - Integer.numberOfLeadingZeros(i);
       int tableIndex = i << (32 - floorLog2) >>> (32 - 8);
       return floorLog2 + LOG2_TABLE[tableIndex];
     }
   }
   ```
   
   </details>


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org
For additional commands, e-mail: issues-h...@lucene.apache.org

Reply via email to