nitirajrathore commented on a change in pull request #83:
URL: https://github.com/apache/lucene/pull/83#discussion_r616404025



##########
File path: lucene/test-framework/src/java/org/apache/lucene/util/FullKnn.java
##########
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.FloatBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import org.apache.lucene.index.VectorValues;
+
+/**
+ * A utility class to calculate the Full KNN / Exact KNN over a set of query vectors and document
+ * vectors.
+ */
+public class FullKnn {
+
+  private final int dim;
+  private final int topK;
+  private final VectorValues.SearchStrategy searchStrategy;
+  private final boolean quiet;
+
+  /**
+   * Creates an exact-KNN calculator.
+   *
+   * @param dim number of float components per vector; file sizes are divided by {@code dim *
+   *     Float.BYTES} to count the vectors they hold
+   * @param topK number of nearest documents to keep for each query
+   * @param searchStrategy similarity function used to score a query vector against a document
+   *     vector
+   * @param quiet when true, the progress message is not printed to standard output
+   */
+  public FullKnn(int dim, int topK, VectorValues.SearchStrategy searchStrategy, boolean quiet) {
+    this.quiet = quiet;
+    this.searchStrategy = searchStrategy;
+    this.topK = topK;
+    this.dim = dim;
+  }
+
+  /** Internal state for the exact-KNN scan of a single query vector. */
+  private static class KnnJob {
+    /** Index passed to the heap for the next document; advanced once per document scored. */
+    public int currDocIndex;
+
+    /** The query vector every document is compared against. */
+    float[] queryVector;
+
+    /** Scratch array that receives each document vector read from {@link #docVectors}. */
+    float[] currDocVector;
+
+    /** Position of this query among all queries. */
+    int queryIndex;
+
+    /** Bounded heap (capacity topK) holding the encoded (doc index, score) entries. */
+    private LongHeap queue;
+
+    /** Block of document vectors to scan; assigned by the caller before {@link #execute()}. */
+    FloatBuffer docVectors;
+
+    /** Similarity function used to score a document against the query. */
+    VectorValues.SearchStrategy searchStrategy;
+
+    public KnnJob(
+        int queryIndex, float[] queryVector, int topK, VectorValues.SearchStrategy searchStrategy) {
+      this.queryIndex = queryIndex;
+      this.queryVector = queryVector;
+      this.currDocVector = new float[queryVector.length];
+      // NOTE(review): MAX order for reversed strategies presumably keeps the worst entry on top
+      // so that it is the one evicted on overflow - confirm against LongHeap.
+      final LongHeap.Order heapOrder =
+          searchStrategy.reversed ? LongHeap.Order.MAX : LongHeap.Order.MIN;
+      this.queue = LongHeap.create(heapOrder, topK);
+      this.searchStrategy = searchStrategy;
+    }
+
+    /** Scores every vector remaining in {@link #docVectors} and offers each one to the heap. */
+    public void execute() {
+      final FloatBuffer block = docVectors;
+      for (; block.hasRemaining(); currDocIndex++) {
+        block.get(currDocVector);
+        final float score = searchStrategy.compare(queryVector, currDocVector);
+        queue.insertWithOverflow(encodeNodeIdAndScore(currDocIndex, score));
+      }
+    }
+  }
+
+  /**
+   * Computes the exact KNN matches for each query vector in queryPath against all the document
+   * vectors in docPath.
+   *
+   * @param docPath path to the file containing the 32-bit floating point document vectors in
+   *     little-endian byte order
+   * @param queryPath path to the file containing the 32-bit floating point query vectors in
+   *     little-endian byte order
+   * @param numThreads number of threads used to parallelize the work; must be positive
+   * @return an int 2D array ({@code int matches[][]}) of size 'numQueries x topK'. matches[i] is
+   *     an array containing the indexes of the topK most similar document vectors to the ith query
+   *     vector, and is sorted by similarity, with the most similar vector first. Similarity is
+   *     defined by the searchStrategy used to construct this FullKnn.
+   * @throws IllegalArgumentException if topK is greater than the number of documents in docPath
+   * @throws IOException if an I/O error occurs while reading the files
+   */
+  public int[][] computeNN(Path docPath, Path queryPath, int numThreads) throws IOException {
+    assert numThreads > 0;
+    final int vectorBytes = dim * Float.BYTES;
+    final int numDocs = (int) (Files.size(docPath) / vectorBytes);
+    // The query count must come from the query file; sizing it from docPath gives the wrong
+    // count whenever the two files hold different numbers of vectors.
+    final int numQueries = (int) (Files.size(queryPath) / vectorBytes);
+
+    if (!quiet) {
+      System.out.println(
+          "computing true nearest neighbors of "
+              + numQueries
+              + " target vectors using "
+              + numThreads
+              + " threads.");
+    }
+
+    // Both channels are closed when the computation finishes or fails.
+    try (FileChannel docInput = FileChannel.open(docPath);
+        FileChannel queryInput = FileChannel.open(queryPath)) {
+      return doFullKnn(
+          numDocs,
+          numQueries,
+          numThreads,
+          new FileChannelBufferProvider(docInput),
+          new FileChannelBufferProvider(queryInput));
+    }
+  }
+
+  int[][] doFullKnn(
+      int numDocs,
+      int numQueries,
+      int numThreads,
+      BufferProvider docInput,
+      BufferProvider queryInput)
+      throws IOException {
+    if (numDocs < topK) {
+      throw new IllegalArgumentException(
+          String.format(
+              Locale.ROOT,
+              "topK (%d) cannot be greater than number of docs in docPath 
(%d)",
+              topK,
+              numDocs));
+    }
+
+    final ExecutorService executorService =
+        Executors.newFixedThreadPool(numThreads, new 
NamedThreadFactory("FullKnnExecutor"));
+    int[][] result = new int[numQueries][];
+
+    FloatBuffer queries = queryInput.getBuffer(0, numQueries * dim * 
Float.BYTES).asFloatBuffer();
+    float[] query = new float[dim];
+    List<KnnJob> jobList = new ArrayList<>(numThreads);
+    for (int i = 0; i < numQueries; ) {
+
+      for (int j = 0; j < numThreads && i < numQueries; i++, j++) {
+        queries.get(query);
+        jobList.add(
+            new KnnJob(i, ArrayUtil.copyOfSubArray(query, 0, query.length), 
topK, searchStrategy));
+      }
+
+      long maxBufferSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * 
Float.BYTES);
+      int docsLeft = numDocs;
+      int currDocIndex = 0;
+      int offset = 0;
+      while (docsLeft > 0) {
+        long totalBytes = (long) docsLeft * dim * Float.BYTES;
+        int blockSize = (int) Math.min(totalBytes, maxBufferSize);
+
+        FloatBuffer docVectors = docInput.getBuffer(offset, 
blockSize).asFloatBuffer();
+        offset += blockSize;
+
+        final List<CompletableFuture<Void>> completableFutures =
+            jobList.stream()
+                .peek(job -> job.docVectors = docVectors.duplicate())

Review comment:
       converted to for loop.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org
For additional commands, e-mail: issues-h...@lucene.apache.org

Reply via email to