jpountz commented on code in PR #12434: URL: https://github.com/apache/lucene/pull/12434#discussion_r1273607520
########## lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java: ########## @@ -117,6 +118,71 @@ public abstract TopDocs search( */ public abstract TopDocs search( String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException; + + /** + * Return the k nearest neighbor documents as determined by comparison of their vector values for + * this field, to the given vector, by the field's similarity function. The score of each document + * is derived from the vector similarity in a way that ensures scores are positive and that a + * larger score corresponds to a higher ranking. + * + * <p>The search is allowed to be approximate, meaning the results are not guaranteed to be the + * true k closest neighbors. For large values of k (for example when k is close to the total + * number of documents), the search may also retrieve fewer than k documents. + * + * <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor, in + * order of their similarity to the query vector (decreasing scores). The {@link TotalHits} + * contains the number of documents visited during the search. If the search stopped early because + * it hit {@code visitedLimit}, it is indicated through the relation {@code + * TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}. + * + * <p>The behavior is undefined if the given field doesn't have KNN vectors enabled on its {@link + * FieldInfo}. The return value is never {@code null}. + * + * @param field the vector field to search + * @param target the vector-valued query + * @param knnResults a KnnResults collector and relevant settings for gathering vector results + * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} + * if they are all allowed to match. + * @return the k nearest neighbor documents, along with their (similarity-specific) scores. 
+ */ + public TopDocs search(String field, float[] target, KnnResults knnResults, Bits acceptDocs) + throws IOException { + throw new UnsupportedOperationException( + "vector reader doesn't provide KNN search with results provider"); Review Comment: How feasible do you think it would be to only have this method on `KnnVectorsReader`, fix older codecs to implement it, and make `LeafReader#searchNearestVectors(String, float[], int, Bits, int)` final by delegating to `LeafReader#searchNearestVectors(String, float[], KnnResults, int)`? ########## lucene/core/src/java/org/apache/lucene/util/hnsw/KnnResults.java: ########## @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.util.hnsw; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; + +/** + * KnnResults is a collector for gathering kNN results and providing topDocs from the gathered + * neighbors + */ +public abstract class KnnResults { + + /** KnnResults when exiting search early and returning empty top docs */ + static class EmptyKnnResults extends KnnResults { + public EmptyKnnResults(int k, int visitedCount, int visitLimit) { + super(k, visitLimit); + this.visitedCount = visitedCount; + } + + @Override + public void doClear() {} + + @Override + public boolean collect(int vectorId, float similarity) { + throw new IllegalArgumentException(); + } + + @Override + public boolean isFull() { + return true; + } + + @Override + public float minSimilarity() { + return 0; + } + + @Override + public TopDocs topDocs() { + TotalHits th = new TotalHits(visitedCount, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO); + return new TopDocs(th, new ScoreDoc[0]); + } + } + + static class OrdinalTranslatedKnnResults extends KnnResults { + private final KnnResults in; + private final IntToIntFunction vectorOrdinalToDocId; + + OrdinalTranslatedKnnResults(KnnResults in, IntToIntFunction vectorOrdinalToDocId) { + super(in.k, in.visitLimit); + this.in = in; + this.vectorOrdinalToDocId = vectorOrdinalToDocId; + } + + @Override + void doClear() { + in.clear(); + } + + @Override + boolean collect(int vectorId, float similarity) { + return in.collect(vectorOrdinalToDocId.apply(vectorId), similarity); + } + + @Override + boolean isFull() { + return in.isFull(); + } + + @Override + float minSimilarity() { + return in.minSimilarity(); + } + + @Override + public TopDocs topDocs() { + TopDocs td = in.topDocs(); + return new TopDocs( + new TotalHits( + visitedCount(), + incomplete() + ? 
TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + : TotalHits.Relation.EQUAL_TO), + td.scoreDocs); + } + } + + protected int visitedCount; + private final int visitLimit; + private final int k; + + protected KnnResults(int k, int visitLimit) { + this.visitLimit = visitLimit; + this.k = k; + } + + final void clear() { + this.visitedCount = 0; + doClear(); + } + + /** Clear the current results. */ + abstract void doClear(); Review Comment: Both `clear` and `doClear` seem to be needed only for graph building; could we remove them from here? ########## lucene/core/src/java/org/apache/lucene/util/hnsw/ToParentJoinKnnResults.java: ########## @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import java.util.HashMap; +import java.util.Map; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.BitSet; + +/** parent joining knn results, vectorIds are deduplicated according to the parent bit set.
*/ +public class ToParentJoinKnnResults extends KnnResults { Review Comment: Could this live in the lucene/join module instead of core? Likewise for `NodeIdCachingHead`? ########## lucene/core/src/java/org/apache/lucene/util/hnsw/KnnResults.java: ########## @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; + +/** + * KnnResults is a collector for gathering kNN results and providing topDocs from the gathered + * neighbors + */ +public abstract class KnnResults { Review Comment: This class should probably be in oal.search rather than oal.util.hnsw since it's exposed by the search APIs (but the empty and ordinals-translating impls should stay here since they're implementation details of HNSWGraphBuilder)? ########## lucene/core/src/java/org/apache/lucene/util/hnsw/KnnResults.java: ########## @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; + +/** + * KnnResults is a collector for gathering kNN results and providing topDocs from the gathered + * neighbors + */ +public abstract class KnnResults { + + /** KnnResults when exiting search early and returning empty top docs */ + static class EmptyKnnResults extends KnnResults { + public EmptyKnnResults(int k, int visitedCount, int visitLimit) { + super(k, visitLimit); + this.visitedCount = visitedCount; + } + + @Override + public void doClear() {} + + @Override + public boolean collect(int vectorId, float similarity) { + throw new IllegalArgumentException(); + } + + @Override + public boolean isFull() { + return true; + } + + @Override + public float minSimilarity() { + return 0; + } + + @Override + public TopDocs topDocs() { + TotalHits th = new TotalHits(visitedCount, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO); + return new TopDocs(th, new ScoreDoc[0]); + } + } + + static class OrdinalTranslatedKnnResults extends KnnResults { + private final KnnResults in; + private final IntToIntFunction vectorOrdinalToDocId; + + OrdinalTranslatedKnnResults(KnnResults in, IntToIntFunction vectorOrdinalToDocId) { + super(in.k, in.visitLimit); + this.in = in; + this.vectorOrdinalToDocId = vectorOrdinalToDocId; + } + + 
@Override + void doClear() { + in.clear(); + } + + @Override + boolean collect(int vectorId, float similarity) { + return in.collect(vectorOrdinalToDocId.apply(vectorId), similarity); + } + + @Override + boolean isFull() { + return in.isFull(); + } + + @Override + float minSimilarity() { + return in.minSimilarity(); + } + + @Override + public TopDocs topDocs() { + TopDocs td = in.topDocs(); + return new TopDocs( + new TotalHits( + visitedCount(), + incomplete() + ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + : TotalHits.Relation.EQUAL_TO), + td.scoreDocs); + } + } + + protected int visitedCount; + private final int visitLimit; + private final int k; + + protected KnnResults(int k, int visitLimit) { + this.visitLimit = visitLimit; + this.k = k; + } + + final void clear() { + this.visitedCount = 0; + doClear(); + } + + /** Clear the current results. */ + abstract void doClear(); + + /** + * @return is the current result set marked as incomplete? + */ + final boolean incomplete() { + return visitedCount >= visitLimit; + } + + final void incVisitedCount(int count) { + assert count > 0; + this.visitedCount += count; + } + + /** + * @return the current visited count + */ + final int visitedCount() { + return visitedCount; + } + + final int visitLimit() { + return visitLimit; + } + + public final int k() { + return k; + } + + /** + * Collect the provided vectorId and include in the result set. + * + * @param vectorId the vector to collect + * @param similarity its calculated similarity + * @return true if the vector is collected + */ + abstract boolean collect(int vectorId, float similarity); Review Comment: The first parameter should be called docID now, right? ########## lucene/core/src/java/org/apache/lucene/util/hnsw/KnnResults.java: ########## @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; + +/** + * KnnResults is a collector for gathering kNN results and providing topDocs from the gathered + * neighbors + */ +public abstract class KnnResults { + + /** KnnResults when exiting search early and returning empty top docs */ + static class EmptyKnnResults extends KnnResults { + public EmptyKnnResults(int k, int visitedCount, int visitLimit) { + super(k, visitLimit); + this.visitedCount = visitedCount; + } + + @Override + public void doClear() {} + + @Override + public boolean collect(int vectorId, float similarity) { + throw new IllegalArgumentException(); + } + + @Override + public boolean isFull() { + return true; + } + + @Override + public float minSimilarity() { + return 0; + } + + @Override + public TopDocs topDocs() { + TotalHits th = new TotalHits(visitedCount, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO); + return new TopDocs(th, new ScoreDoc[0]); + } + } + + static class OrdinalTranslatedKnnResults extends KnnResults { + private final KnnResults in; + private final IntToIntFunction vectorOrdinalToDocId; + + OrdinalTranslatedKnnResults(KnnResults in, IntToIntFunction vectorOrdinalToDocId) { + 
super(in.k, in.visitLimit); + this.in = in; + this.vectorOrdinalToDocId = vectorOrdinalToDocId; + } + + @Override + void doClear() { + in.clear(); + } + + @Override + boolean collect(int vectorId, float similarity) { + return in.collect(vectorOrdinalToDocId.apply(vectorId), similarity); + } + + @Override + boolean isFull() { + return in.isFull(); + } + + @Override + float minSimilarity() { + return in.minSimilarity(); + } + + @Override + public TopDocs topDocs() { + TopDocs td = in.topDocs(); + return new TopDocs( + new TotalHits( + visitedCount(), + incomplete() + ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + : TotalHits.Relation.EQUAL_TO), + td.scoreDocs); + } + } + + protected int visitedCount; + private final int visitLimit; + private final int k; + + protected KnnResults(int k, int visitLimit) { + this.visitLimit = visitLimit; + this.k = k; + } + + final void clear() { + this.visitedCount = 0; + doClear(); + } + + /** Clear the current results. */ + abstract void doClear(); + + /** + * @return is the current result set marked as incomplete? + */ + final boolean incomplete() { + return visitedCount >= visitLimit; + } + + final void incVisitedCount(int count) { + assert count > 0; + this.visitedCount += count; + } + + /** + * @return the current visited count + */ + final int visitedCount() { + return visitedCount; + } + + final int visitLimit() { + return visitLimit; + } + + public final int k() { + return k; + } + + /** + * Collect the provided vectorId and include in the result set. 
+ * + * @param vectorId the vector to collect + * @param similarity its calculated similarity + * @return true if the vector is collected + */ + abstract boolean collect(int vectorId, float similarity); + + /** + * @return Is the current result set considered full + */ + abstract boolean isFull(); Review Comment: I wonder if we can avoid introducing this method and change call-sites to check if minSimilarity() returns a greater value than `NEGATIVE_INFINITY`, in order to keep the surface area minimal. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For additional commands, e-mail: issues-h...@lucene.apache.org