jtibshirani commented on a change in pull request #416: URL: https://github.com/apache/lucene/pull/416#discussion_r783342986
########## File path: lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswGraph.java ########## @@ -116,33 +116,59 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { ((CodecReader) ctx.reader()).getVectorReader()) .getFieldReader("field")) .getGraphValues("field"); - assertGraphEqual(hnsw, graphValues, nVec); + assertGraphEqual(hnsw, graphValues); } } } } + private void assertGraphEqual(KnnGraphValues g, KnnGraphValues h) throws IOException { + assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels()); + assertEquals("the number of nodes in the graphs are different!", g.size(), h.size()); + + // assert equal nodes on each level + for (int level = 0; level < g.numLevels(); level++) { + NodesIterator nodesOnLevel = g.getNodesOnLevel(level); + NodesIterator nodesOnLevel2 = h.getNodesOnLevel(level); + while (nodesOnLevel.hasNext() && nodesOnLevel2.hasNext()) { + int node = nodesOnLevel.nextInt(); + int node2 = nodesOnLevel2.nextInt(); + assertEquals("nodes in the graphs are different", node, node2); + } + } + + // assert equal nodes' neighbours on each level + for (int level = 0; level < g.numLevels(); level++) { + NodesIterator nodesOnLevel = g.getNodesOnLevel(level); + while (nodesOnLevel.hasNext()) { + int node = nodesOnLevel.nextInt(); + g.seek(level, node); + h.seek(level, node); + assertEquals("arcs differ for node " + node, getNeighborNodes(g), getNeighborNodes(h)); + } + } + } + // Make sure we actually approximately find the closest k elements. 
Mostly this is about // ensuring that we have all the distance functions, comparators, priority queues and so on // oriented in the right directions public void testAknnDiverse() throws IOException { + int maxConn = 10; int nDoc = 100; - CircularVectorValues vectors = new CircularVectorValues(nDoc); + TestHnswGraph.CircularVectorValues vectors = new TestHnswGraph.CircularVectorValues(nDoc); Review comment: Tiny comment, don't need these qualifiers ########## File path: lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java ########## @@ -56,75 +57,124 @@ public final class HnswGraph extends KnnGraphValues { private final int maxConn; + private int numLevels; // the current number of levels in the graph + private int entryNode; // the current graph entry node on the top level - // Each entry lists the top maxConn neighbors of a node. The nodes correspond to vectors added to - // HnswBuilder, and the - // node values are the ordinals of those vectors. - private final List<NeighborArray> graph; + // Nodes by level expressed as the level 0's nodes' ordinals. + // As level 0 contains all nodes, nodesByLevel.get(0) is null. + private final List<int[]> nodesByLevel; + + // graph is a list of graph levels. + // Each level is represented as List<NeighborArray> – nodes' connections on this level. + // Each entry in the list has the top maxConn neighbors of a node. The nodes correspond to vectors + // added to HnswBuilder, and the node values are the ordinals of those vectors. + // Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals. + private final List<List<NeighborArray>> graph; // KnnGraphValues iterator members private int upto; private NeighborArray cur; - HnswGraph(int maxConn) { - graph = new ArrayList<>(); - // Typically with diversity criteria we see nodes not fully occupied; average fanout seems to be - // about 1/2 maxConn. 
There is some indexing time penalty for under-allocating, but saves RAM - graph.add(new NeighborArray(Math.max(32, maxConn / 4))); + HnswGraph(int maxConn, int levelOfFirstNode) { this.maxConn = maxConn; + this.numLevels = levelOfFirstNode + 1; + this.graph = new ArrayList<>(numLevels); + this.entryNode = 0; + for (int i = 0; i < numLevels; i++) { + graph.add(new ArrayList<>()); + // Typically with diversity criteria we see nodes not fully occupied; + // average fanout seems to be about 1/2 maxConn. + // There is some indexing time penalty for under-allocating, but saves RAM + graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4))); + } + + this.nodesByLevel = new ArrayList<>(numLevels); + nodesByLevel.add(null); // we don't need this for 0th level, as it contains all nodes + for (int l = 1; l < numLevels; l++) { + nodesByLevel.add(new int[] {0}); + } } /** - * Searches for the nearest neighbors of a query vector. + * Searches HNSW graph for the nearest neighbors of a query vector. * * @param query search query vector * @param topK the number of nodes to be returned - * @param numSeed the size of the queue maintained while searching, and controls the number of - * random entry points to sample * @param vectors vector values * @param graphValues the graph values. May represent the entire graph, or a level in a * hierarchical graph. * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or * {@code null} if they are all allowed to match. 
- * @param random a source of randomness, used for generating entry points to the graph * @return a priority queue holding the closest neighbors found */ public static NeighborQueue search( float[] query, int topK, - int numSeed, RandomAccessVectorValues vectors, VectorSimilarityFunction similarityFunction, KnnGraphValues graphValues, - Bits acceptOrds, - SplittableRandom random) + Bits acceptOrds) throws IOException { + int size = graphValues.size(); + int queueSize = Math.min(topK, 2 * size); + NeighborQueue results; + + int[] eps = new int[] {graphValues.entryNode()}; + for (int level = graphValues.numLevels() - 1; level >= 1; level--) { + results = searchLevel(query, 1, level, eps, vectors, similarityFunction, graphValues, null); + eps[0] = results.pop(); + } + results = + searchLevel(query, queueSize, 0, eps, vectors, similarityFunction, graphValues, acceptOrds); + return results; + } + + /** + * Searches for the nearest neighbors of a query vector in a given level + * + * @param query search query vector + * @param topK the number of nearest to query results to return + * @param level level to search + * @param eps the entry points for search at this level expressed as level 0th ordinals + * @param vectors vector values + * @param similarityFunction similarity function + * @param graphValues the graph values + * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or + * {@code null} if they are all allowed to match. + * @return a priority queue holding the closest neighbors found + */ + static NeighborQueue searchLevel( + float[] query, + int topK, + int level, + final int[] eps, + RandomAccessVectorValues vectors, + VectorSimilarityFunction similarityFunction, + KnnGraphValues graphValues, + Bits acceptOrds) + throws IOException { + int size = graphValues.size(); + int queueSize = Math.max(eps.length, topK); Review comment: Same question here, do we need `queueSize` or could it just be `topK`? 
From the paper it looks like `eps.length` will always be bounded by `topK`? ########## File path: lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java ########## @@ -279,6 +303,77 @@ public void testSearch() throws Exception { } } + private void indexData(IndexWriter iw) throws IOException { + // Add a document for every cartesian point in an NxN square so we can + // easily know which are the nearest neighbors to every point. Insert by iterating + // using a prime number that is not a divisor of N*N so that we will hit each point once, + // and chosen so that points will be inserted in a deterministic + // but somewhat distributed pattern + int n = 5, stepSize = 17; + float[][] values = new float[n * n][]; + int index = 0; + for (int i = 0; i < values.length; i++) { + // System.out.printf("%d: (%d, %d)\n", i, index % n, index / n); + int x = index % n, y = index / n; + values[i] = new float[] {x, y}; + index = (index + stepSize) % (n * n); + add(iw, i, values[i]); + if (i == 13) { + // create 2 segments + iw.commit(); + } + } + boolean forceMerge = random().nextBoolean(); + if (forceMerge) { + iw.forceMerge(1); + } + assertConsistentGraph(iw, values); + } + + public void testMultiThreadedSearch() throws Exception { + similarityFunction = VectorSimilarityFunction.EUCLIDEAN; + IndexWriterConfig config = newIndexWriterConfig(); + config.setCodec(codec); + Directory dir = newDirectory(); + IndexWriter iw = new IndexWriter(dir, config); + indexData(iw); + + final SearcherManager manager = new SearcherManager(iw, new SearcherFactory()); + Thread[] threads = new Thread[randomIntBetween(2, 5)]; + final CountDownLatch latch = new CountDownLatch(1); + for (int i = 0; i < threads.length; i++) { + threads[i] = + new Thread( + () -> { + try { + latch.await(); + IndexSearcher searcher = manager.acquire(); + try { + KnnVectorQuery query = new KnnVectorQuery("vector", new float[] {0f, 0.1f}, 5); Review comment: Tiny comment, should we use `assertGraphSearch` or 
`doKnnSearch` here like we do elsewhere in the test? ########## File path: lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java ########## @@ -153,21 +161,56 @@ public void testMergeProducesSameGraph() throws Exception { int dimension = atLeast(10); float[][] values = randomVectors(numDoc, dimension); int mergePoint = random().nextInt(numDoc); - int[][] mergedGraph = getIndexedGraph(values, mergePoint, seed); - int[][] singleSegmentGraph = getIndexedGraph(values, -1, seed); + int[][][] mergedGraph = getIndexedGraph(values, mergePoint, seed); + int[][][] singleSegmentGraph = getIndexedGraph(values, -1, seed); assertGraphEquals(singleSegmentGraph, mergedGraph); } - private void assertGraphEquals(int[][] expected, int[][] actual) { + /** Test writing and reading of multiple vector fields * */ + public void testMultipleVectorFields() throws Exception { + int numVectorFields = randomIntBetween(2, 5); + int numDoc = atLeast(100); + int[] dims = new int[numVectorFields]; + float[][][] values = new float[numVectorFields][][]; + for (int field = 0; field < numVectorFields; field++) { + dims[field] = atLeast(3); + values[field] = randomVectors(numDoc, dims[field]); + } + + try (Directory dir = newDirectory(); + IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(codec))) { + for (int docID = 0; docID < numDoc; docID++) { + Document doc = new Document(); + for (int field = 0; field < numVectorFields; field++) { + float[] vector = values[field][docID]; + if (vector != null) { + FieldType fieldType = KnnVectorField.createFieldType(vector.length, similarityFunction); + doc.add(new KnnVectorField(KNN_GRAPH_FIELD + field, vector, fieldType)); + } + } + String idString = Integer.toString(docID); + doc.add(new StringField("id", idString, Field.Store.YES)); + iw.addDocument(doc); + } + for (int field = 0; field < numVectorFields; field++) { + assertConsistentGraph(iw, values[field], KNN_GRAPH_FIELD + field); + } + } + } + + private void 
assertGraphEquals(int[][][] expected, int[][][] actual) { assertEquals("graph sizes differ", expected.length, actual.length); - for (int i = 0; i < expected.length; i++) { - assertArrayEquals("difference at ord=" + i, expected[i], actual[i]); + for (int level = 0; level < expected.length; level++) { + for (int node = 0; node < expected[level].length; node++) { + assertArrayEquals("difference at ord=" + node, expected[level][node], actual[level][node]); + } } } - private int[][] getIndexedGraph(float[][] values, int mergePoint, long seed) throws IOException { + private int[][][] getIndexedGraph(float[][] values, int mergePoint, long seed) Review comment: Super small comment, it'd be nice if javadoc described what was returned here (it's getting to be hard to follow with three array dimensions). ########## File path: lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java ########## @@ -40,10 +41,10 @@ * <h2>Hyperparameters</h2> * * <ul> - * <li><code>numSeed</code> is the equivalent of <code>m</code> in the 2012 paper; it controls the + * <li><code>numSeed</code> is the equivalent of <code>m</code> in the 2014 paper; it controls the Review comment: Should we update this javadoc comment now that we implement the algorithm from the 2018 paper? We could remove the link to the 2014 paper and this reference to `numSeed`? ########## File path: lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java ########## @@ -56,75 +57,124 @@ public final class HnswGraph extends KnnGraphValues { private final int maxConn; + private int numLevels; // the current number of levels in the graph + private int entryNode; // the current graph entry node on the top level - // Each entry lists the top maxConn neighbors of a node. The nodes correspond to vectors added to - // HnswBuilder, and the - // node values are the ordinals of those vectors. - private final List<NeighborArray> graph; + // Nodes by level expressed as the level 0's nodes' ordinals. 
+ // As level 0 contains all nodes, nodesByLevel.get(0) is null. + private final List<int[]> nodesByLevel; + + // graph is a list of graph levels. + // Each level is represented as List<NeighborArray> – nodes' connections on this level. + // Each entry in the list has the top maxConn neighbors of a node. The nodes correspond to vectors + // added to HnswBuilder, and the node values are the ordinals of those vectors. + // Thus, on all levels, neighbors expressed as the level 0's nodes' ordinals. + private final List<List<NeighborArray>> graph; // KnnGraphValues iterator members private int upto; private NeighborArray cur; - HnswGraph(int maxConn) { - graph = new ArrayList<>(); - // Typically with diversity criteria we see nodes not fully occupied; average fanout seems to be - // about 1/2 maxConn. There is some indexing time penalty for under-allocating, but saves RAM - graph.add(new NeighborArray(Math.max(32, maxConn / 4))); + HnswGraph(int maxConn, int levelOfFirstNode) { this.maxConn = maxConn; + this.numLevels = levelOfFirstNode + 1; + this.graph = new ArrayList<>(numLevels); + this.entryNode = 0; + for (int i = 0; i < numLevels; i++) { + graph.add(new ArrayList<>()); + // Typically with diversity criteria we see nodes not fully occupied; + // average fanout seems to be about 1/2 maxConn. + // There is some indexing time penalty for under-allocating, but saves RAM + graph.get(i).add(new NeighborArray(Math.max(32, maxConn / 4))); + } + + this.nodesByLevel = new ArrayList<>(numLevels); + nodesByLevel.add(null); // we don't need this for 0th level, as it contains all nodes + for (int l = 1; l < numLevels; l++) { + nodesByLevel.add(new int[] {0}); + } } /** - * Searches for the nearest neighbors of a query vector. + * Searches HNSW graph for the nearest neighbors of a query vector. 
* * @param query search query vector * @param topK the number of nodes to be returned - * @param numSeed the size of the queue maintained while searching, and controls the number of - * random entry points to sample * @param vectors vector values * @param graphValues the graph values. May represent the entire graph, or a level in a * hierarchical graph. * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or * {@code null} if they are all allowed to match. - * @param random a source of randomness, used for generating entry points to the graph * @return a priority queue holding the closest neighbors found */ public static NeighborQueue search( float[] query, int topK, - int numSeed, RandomAccessVectorValues vectors, VectorSimilarityFunction similarityFunction, KnnGraphValues graphValues, - Bits acceptOrds, - SplittableRandom random) + Bits acceptOrds) throws IOException { + int size = graphValues.size(); + int queueSize = Math.min(topK, 2 * size); Review comment: Do we need a separate `queueSize` parameter or can this be `topK`? We already bound `topK` by the number of vectors in `Lucene90HnswVectorsReader#search`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For additional commands, e-mail: issues-h...@lucene.apache.org