jpountz commented on code in PR #12434:
URL: https://github.com/apache/lucene/pull/12434#discussion_r1269632250

##########
lucene/core/src/java/org/apache/lucene/util/hnsw/KnnResults.java:
##########
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.hnsw;
+
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+
+/**
+ * KnnResults is a collector for gathering kNN results and providing topDocs from the gathered
+ * neighbors
+ */
+public abstract class KnnResults {
+
+  /** KnnResults when exiting search early and returning empty top docs */
+  static class EmptyKnnResults extends KnnResults {
+    public EmptyKnnResults(int visitedCount) {
+      this.visitedCount = visitedCount;
+    }
+
+    @Override
+    public void doClear() {}
+
+    @Override
+    public void collect(int vectorId, float similarity) {
+      throw new IllegalArgumentException();
+    }
+
+    @Override
+    public boolean collectWithOverflow(int vectorId, float similarity) {
+      return false;
+    }
+
+    @Override
+    public boolean isFull() {
+      return true;
+    }
+
+    @Override
+    public float minSimilarity() {
+      return 0;
+    }
+
+    @Override
+    public TopDocs topDocs() {
+      TotalHits th = new TotalHits(visitedCount, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
+      return new TopDocs(th, new ScoreDoc[0]);
+    }
+  }
+
+  protected int visitedCount;
+  private boolean incomplete;
+
+  final void clear() {
+    this.visitedCount = 0;
+    this.incomplete = false;
+    doClear();
+  }
+
+  /** Clear the current results. */
+  abstract void doClear();
+
+  /**
+   * @return is the current result set marked as incomplete?
+   */
+  final boolean incomplete() {
+    return incomplete;
+  }
+
+  /** Mark the current result set as incomplete */
+  final void markIncomplete() {
+    this.incomplete = true;
+  }
+
+  /**
+   * @param count set the current visited count to the provided value
+   */
+  final void setVisitedCount(int count) {
+    this.visitedCount = count;
+  }
+
+  /**
+   * @return the current visited count
+   */
+  final int visitedCount() {
+    return visitedCount;
+  }
+
+  /**
+   * Collect the provided vectorId and include in the result set.
+   *
+   * @param vectorId the vector to collect
+   * @param similarity its calculated similarity
+   */
+  abstract void collect(int vectorId, float similarity);
+
+  /**
+   * @param vectorId the vector to collect
+   * @param similarity its calculated similarity
+   * @return true if the vector is collected
+   */
+  abstract boolean collectWithOverflow(int vectorId, float similarity);

Review Comment:
   We seem to be calling `collectWithOverflow` when the result set is full, so maybe we could have a single `collect` method that automatically does the right thing depending on whether it's full or not?



##########
lucene/core/src/java/org/apache/lucene/util/hnsw/KnnResultsProvider.java:
##########
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.hnsw;
+
+/** knn results provider */
+public interface KnnResultsProvider {
+  KnnResults getKnnResults(IntToIntFunction vectorToOrd);

Review Comment:
   This `vectorToOrd` argument feels specific to the current implementation. Could the API work on doc IDs instead?



##########
lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java:
##########
@@ -117,6 +118,85 @@ public abstract TopDocs search(
     */
    public abstract TopDocs search(
        String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
+
+  /**
+   * Return the k nearest neighbor documents as determined by comparison of their vector values for
+   * this field, to the given vector, by the field's similarity function. The score of each document
+   * is derived from the vector similarity in a way that ensures scores are positive and that a
+   * larger score corresponds to a higher ranking.
+   *
+   * <p>The search is allowed to be approximate, meaning the results are not guaranteed to be the
+   * true k closest neighbors. For large values of k (for example when k is close to the total
+   * number of documents), the search may also retrieve fewer than k documents.
+   *
+   * <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor, in
+   * order of their similarity to the query vector (decreasing scores). The {@link TotalHits}
+   * contains the number of documents visited during the search. If the search stopped early because
+   * it hit {@code visitedLimit}, it is indicated through the relation {@code
+   * TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}.
+   *
+   * <p>The behavior is undefined if the given field doesn't have KNN vectors enabled on its {@link
+   * FieldInfo}. The return value is never {@code null}.
+   *
+   * @param field the vector field to search
+   * @param target the vector-valued query
+   * @param knnResultsProvider a provider that returns a KnnResults collector and topK for gathering
+   *     the vector results
+   * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
+   *     if they are all allowed to match.
+   * @param visitedLimit the maximum number of nodes that the search is allowed to visit
+   * @return the k nearest neighbor documents, along with their (similarity-specific) scores.
+   */
+  public TopDocs search(
+      String field,
+      float[] target,
+      KnnResultsProvider knnResultsProvider,
+      Bits acceptDocs,
+      int visitedLimit)

Review Comment:
   If we're adding this more flexible way of collecting KNN hits, it would be nice if `visitedLimit` could be handled directly by the `KnnResultsProvider` instead of being a separate parameter. Handling `acceptDocs` via this new interface crossed my mind too, though I'm less sure about it: collectors don't need to handle deleted docs, and it might be nice to keep collectors and this new abstraction you're adding, which looks similar, consistent with each other.



##########
lucene/core/src/java/org/apache/lucene/util/hnsw/ToParentJoinKnnResults.java:
##########
@@ -0,0 +1,362 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.hnsw;
+
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.BitSet;
+
+/** parent joining knn results, vectorIds are deduplicated according to the parent bit set. */
+public class ToParentJoinKnnResults extends KnnResults {
+
+  /** provider class for creating a new {@link ToParentJoinKnnResults} */
+  public static class Provider implements KnnResultsProvider {
+
+    private final int k;
+    private final BitSet parentBitSet;
+
+    public Provider(int k, BitSet parentBitSet) {
+      this.k = k;
+      this.parentBitSet = parentBitSet;
+    }
+
+    @Override
+    public int k() {
+      return k;
+    }
+
+    @Override
+    public KnnResults getKnnResults(IntToIntFunction vectorToOrd) {
+      return new ToParentJoinKnnResults(k, parentBitSet, vectorToOrd);
+    }
+  }
+
+  private final BitSet parentBitSet;
+  private final int k;
+  private final IntToIntFunction vectorToOrd;
+  private final NodeIdCachingHeap heap;
+
+  public ToParentJoinKnnResults(int k, BitSet parentBitSet, IntToIntFunction vectorToOrd) {
+    this.parentBitSet = parentBitSet;
+    this.k = k;
+    this.vectorToOrd = vectorToOrd;
+    this.heap = new NodeIdCachingHeap(k);
+  }
+
+  /**
+   * Adds a new graph arc, extending the storage as needed.
+   *
+   * <p>If the provided childNodeId's parent has been previously collected and the nodeScore is less
+   * than the previously stored score, this node will not be added to the collection.
+   *
+   * @param childNodeId the neighbor node id
+   * @param nodeScore the score of the neighbor, relative to some other node
+   */
+  @Override
+  public void collect(int childNodeId, float nodeScore) {
+    childNodeId = vectorToOrd.apply(childNodeId);
+    assert !parentBitSet.get(childNodeId);
+    int nodeId = parentBitSet.nextSetBit(childNodeId);
+    heap.push(nodeId, nodeScore);
+  }
+
+  /**
+   * If the heap is not full (size is less than the initialSize provided to the constructor), adds a
+   * new node-and-score element. If the heap is full, compares the score against the current top
+   * score, and replaces the top element if newScore is better than (greater than unless the heap is
+   * reversed), the current top score.
+   *
+   * <p>If childNodeId's parent node has previously been collected and the provided nodeScore is
+   * less than the stored score it will not be collected.
+   *
+   * @param childNodeId the neighbor node id
+   * @param nodeScore the score of the neighbor, relative to some other node
+   */
+  @Override
+  public boolean collectWithOverflow(int childNodeId, float nodeScore) {
+    // Parent and child nodes should be disjoint sets parent bit set should never have a child node
+    // ID present
+    childNodeId = vectorToOrd.apply(childNodeId);
+    assert !parentBitSet.get(childNodeId);
+    int nodeId = parentBitSet.nextSetBit(childNodeId);
+    return heap.insertWithOverflow(nodeId, nodeScore);
+  }
+
+  @Override
+  public boolean isFull() {
+    return heap.size >= k;
+  }
+
+  @Override
+  public float minSimilarity() {
+    return heap.topScore();
+  }
+
+  @Override
+  public void doClear() {
+    heap.clear();
+  }
+
+  @Override
+  public String toString() {
+    return "ToParentJoinKnnResults[" + heap.size + "]";
+  }
+
+  @Override
+  public TopDocs topDocs() {
+    while (heap.size() > k) {
+      heap.popToDrain();
+    }
+    int i = 0;
+    ScoreDoc[] scoreDocs = new ScoreDoc[heap.size()];
+    while (i < scoreDocs.length) {
+      int node = heap.topNode();
+      float score = heap.topScore();
+      heap.popToDrain();
+      scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(node, score);
+    }
+
+    TotalHits.Relation relation =
+        incomplete() ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO : TotalHits.Relation.EQUAL_TO;
+    return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs);

Review Comment:
   Using `visitedCount()` doesn't feel correct, as it would be counting nested docs rather than unique parent docs? It doesn't feel super useful either; maybe we should always set it to 1 (with `TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO`)?



##########
lucene/core/src/java/org/apache/lucene/util/hnsw/KnnResultsProvider.java:
##########
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.hnsw;
+
+/** knn results provider */
+public interface KnnResultsProvider {

Review Comment:
   I wonder if we need this indirection with `KnnResultsProvider`/`KnnResults`, or if we could just have `KnnResults`?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
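
To make the review suggestions above more concrete (a single `collect` method instead of `collect`/`collectWithOverflow`, no separate `KnnResultsProvider`, doc IDs rather than `vectorToOrd`, and the visit limit owned by the collector rather than passed as `visitedLimit`), here is a minimal, hypothetical sketch. It is not the PR's code: it uses plain `java.util.PriorityQueue` instead of the PR's heaps, and the names `SimpleKnnCollector`, `incVisitedCount`, `earlyTerminated`, `minCompetitiveSimilarity` and `topDocIds` are illustrative choices made for this sketch only.

```java
import java.util.Arrays;
import java.util.PriorityQueue;

/**
 * Hypothetical sketch: one collector type (no provider) with a single collect() that appends while
 * not full and otherwise replaces the worst kept hit. The visit limit lives in the collector.
 */
final class SimpleKnnCollector {

  /** A collected (docId, score) pair; assumes the graph search already maps to doc IDs. */
  record Hit(int docId, float score) {}

  private final int k;
  private final long visitLimit;
  private long visitedCount;
  // Min-heap on score: the head is the worst hit currently kept.
  private final PriorityQueue<Hit> heap = new PriorityQueue<>((a, b) -> Float.compare(a.score(), b.score()));

  SimpleKnnCollector(int k, long visitLimit) {
    if (k <= 0) {
      throw new IllegalArgumentException("k must be > 0");
    }
    this.k = k;
    this.visitLimit = visitLimit;
  }

  /** Called by the (hypothetical) search loop for every node it scores. */
  void incVisitedCount(int count) {
    visitedCount += count;
  }

  /** True once the search has visited at least visitLimit nodes. */
  boolean earlyTerminated() {
    return visitedCount >= visitLimit;
  }

  boolean isFull() {
    return heap.size() >= k;
  }

  /** Score a candidate must beat to be kept once full; anything is competitive before that. */
  float minCompetitiveSimilarity() {
    return isFull() ? heap.peek().score() : Float.NEGATIVE_INFINITY;
  }

  /** Single entry point: callers never need to know whether the collector is full. */
  boolean collect(int docId, float score) {
    if (!isFull()) {
      heap.add(new Hit(docId, score));
      return true;
    }
    if (score > heap.peek().score()) {
      heap.poll(); // evict the current worst hit
      heap.add(new Hit(docId, score));
      return true;
    }
    return false;
  }

  /** Returns doc IDs by decreasing score; drains the collector. */
  int[] topDocIds() {
    int[] ids = new int[heap.size()];
    for (int i = ids.length - 1; i >= 0; i--) {
      ids[i] = heap.poll().docId(); // smallest score goes last
    }
    return ids;
  }

  public static void main(String[] args) {
    SimpleKnnCollector c = new SimpleKnnCollector(2, 10);
    float[] scores = {0.1f, 0.9f, 0.5f, 0.3f};
    for (int doc = 0; doc < scores.length; doc++) {
      c.incVisitedCount(1);
      c.collect(doc, scores[doc]);
    }
    System.out.println(Arrays.toString(c.topDocIds())); // prints [1, 2]
    System.out.println(c.earlyTerminated()); // prints false
  }
}
```

The design choice in this sketch is that fullness becomes an internal detail: the search loop only calls `collect` and reads `minCompetitiveSimilarity()` for pruning. Parent-join deduplication (keeping only the best score per parent) is deliberately left out, since it needs an updatable heap along the lines of the PR's `NodeIdCachingHeap`.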