benwtrent commented on code in PR #12434: URL: https://github.com/apache/lucene/pull/12434#discussion_r1269789239
########## lucene/core/src/java/org/apache/lucene/util/hnsw/ToParentJoinKnnResults.java: ########## @@ -0,0 +1,362 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import java.util.HashMap; +import java.util.Map; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.BitSet; + +/** parent joining knn results, vectorIds are deduplicated according to the parent bit set. 
*/ +public class ToParentJoinKnnResults extends KnnResults { + + /** provider class for creating a new {@link ToParentJoinKnnResults} */ + public static class Provider implements KnnResultsProvider { + + private final int k; + private final BitSet parentBitSet; + + public Provider(int k, BitSet parentBitSet) { + this.k = k; + this.parentBitSet = parentBitSet; + } + + @Override + public int k() { + return k; + } + + @Override + public KnnResults getKnnResults(IntToIntFunction vectorToOrd) { + return new ToParentJoinKnnResults(k, parentBitSet, vectorToOrd); + } + } + + private final BitSet parentBitSet; + private final int k; + private final IntToIntFunction vectorToOrd; + private final NodeIdCachingHeap heap; + + public ToParentJoinKnnResults(int k, BitSet parentBitSet, IntToIntFunction vectorToOrd) { + this.parentBitSet = parentBitSet; + this.k = k; + this.vectorToOrd = vectorToOrd; + this.heap = new NodeIdCachingHeap(k); + } + + /** + * Adds a new graph arc, extending the storage as needed. + * + * <p>If the provided childNodeId's parent has been previously collected and the nodeScore is less + * than the previously stored score, this node will not be added to the collection. + * + * @param childNodeId the neighbor node id + * @param nodeScore the score of the neighbor, relative to some other node + */ + @Override + public void collect(int childNodeId, float nodeScore) { + childNodeId = vectorToOrd.apply(childNodeId); + assert !parentBitSet.get(childNodeId); + int nodeId = parentBitSet.nextSetBit(childNodeId); + heap.push(nodeId, nodeScore); + } + + /** + * If the heap is not full (size is less than the initialSize provided to the constructor), adds a + * new node-and-score element. If the heap is full, compares the score against the current top + * score, and replaces the top element if newScore is better than (greater than unless the heap is + * reversed), the current top score. 
+ * + * <p>If childNodeId's parent node has previously been collected and the provided nodeScore is + * less than the stored score it will not be collected. + * + * @param childNodeId the neighbor node id + * @param nodeScore the score of the neighbor, relative to some other node + */ + @Override + public boolean collectWithOverflow(int childNodeId, float nodeScore) { + // Parent and child nodes should be disjoint sets parent bit set should never have a child node + // ID present + childNodeId = vectorToOrd.apply(childNodeId); + assert !parentBitSet.get(childNodeId); + int nodeId = parentBitSet.nextSetBit(childNodeId); + return heap.insertWithOverflow(nodeId, nodeScore); + } + + @Override + public boolean isFull() { + return heap.size >= k; + } + + @Override + public float minSimilarity() { + return heap.topScore(); + } + + @Override + public void doClear() { + heap.clear(); + } + + @Override + public String toString() { + return "ToParentJoinKnnResults[" + heap.size + "]"; + } + + @Override + public TopDocs topDocs() { + while (heap.size() > k) { + heap.popToDrain(); + } + int i = 0; + ScoreDoc[] scoreDocs = new ScoreDoc[heap.size()]; + while (i < scoreDocs.length) { + int node = heap.topNode(); + float score = heap.topScore(); + heap.popToDrain(); + scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(node, score); + } + + TotalHits.Relation relation = + incomplete() ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO : TotalHits.Relation.EQUAL_TO; + return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs); Review Comment: `visitedCount()` has always seemed odd to me anyway, as it's the number of vectors we compared the query vector with, not necessarily the number of valid vectors within the index. Knowing the number of vector comparisons is useful, but it definitely doesn't mean the same thing as other "TotalHits" calculations :/. I guess vector search has hijacked this class to mean something else. -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For additional commands, e-mail: issues-h...@lucene.apache.org