jpountz commented on code in PR #12434:
URL: https://github.com/apache/lucene/pull/12434#discussion_r1274490843


##########
lucene/core/src/java/org/apache/lucene/util/hnsw/KnnResults.java:
##########
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.hnsw;
+
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+
+/**
+ * KnnResults is a collector for gathering kNN results and providing topDocs 
from the gathered
+ * neighbors
+ */
+public abstract class KnnResults {
+
+  /** KnnResults when exiting search early and returning empty top docs */
+  /**
+   * KnnResults used when a search exits early: it never accepts vectors and always reports empty
+   * top docs.
+   */
+  static class EmptyKnnResults extends KnnResults {
+    /**
+     * @param k value reported by {@link #k()}
+     * @param visitedCount visited count reported as total hits by {@link #topDocs()}
+     * @param visitLimit visit limit handed to the base class
+     */
+    public EmptyKnnResults(int k, int visitedCount, int visitLimit) {
+      super(k, visitLimit);
+      this.visitedCount = visitedCount;
+    }
+
+    /** Nothing to clear: no results are ever held. */
+    @Override
+    public void doClear() {}
+
+    /**
+     * Always throws; vectors must never be collected into an early-exit result set.
+     *
+     * @throws IllegalArgumentException on every call
+     */
+    @Override
+    public boolean collect(int vectorId, float similarity) {
+      // Exception type kept for compatibility; the message now identifies the misuse.
+      throw new IllegalArgumentException(
+          "cannot collect vectorId=" + vectorId + " into EmptyKnnResults");
+    }
+
+    /** Unconditionally reports the result set as full. */
+    @Override
+    public boolean isFull() {
+      return true;
+    }
+
+    /** Returns a fixed similarity of 0, since no vectors are ever held. */
+    @Override
+    public float minSimilarity() {
+      return 0;
+    }
+
+    /**
+     * Returns empty top docs whose total hits is the visited count, always flagged as a lower bound.
+     */
+    @Override
+    public TopDocs topDocs() {
+      TotalHits th = new TotalHits(visitedCount, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
+      return new TopDocs(th, new ScoreDoc[0]);
+    }
+  }
+
+  static class OrdinalTranslatedKnnResults extends KnnResults {
+    private final KnnResults in;
+    private final IntToIntFunction vectorOrdinalToDocId;
+
+    OrdinalTranslatedKnnResults(KnnResults in, IntToIntFunction 
vectorOrdinalToDocId) {
+      super(in.k, in.visitLimit);
+      this.in = in;
+      this.vectorOrdinalToDocId = vectorOrdinalToDocId;
+    }
+
+    @Override
+    void doClear() {
+      in.clear();
+    }
+
+    @Override
+    boolean collect(int vectorId, float similarity) {
+      return in.collect(vectorOrdinalToDocId.apply(vectorId), similarity);
+    }
+
+    @Override
+    boolean isFull() {
+      return in.isFull();
+    }
+
+    @Override
+    float minSimilarity() {
+      return in.minSimilarity();
+    }
+
+    @Override
+    public TopDocs topDocs() {
+      TopDocs td = in.topDocs();
+      return new TopDocs(
+          new TotalHits(
+              visitedCount(),
+              incomplete()
+                  ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
+                  : TotalHits.Relation.EQUAL_TO),
+          td.scoreDocs);
+    }
+  }
+
+  /** Running visited count; incremented by incVisitedCount and reset to 0 by clear(). */
+  protected int visitedCount;
+  /** Visited count at or beyond which {@link #incomplete()} returns true. */
+  private final int visitLimit;
+  /** The {@code k} of the kNN search, as reported by {@link #k()}. */
+  private final int k;
+
+  /**
+   * @param k value reported by {@link #k()}
+   * @param visitLimit visited count at which the results start being reported as incomplete
+   */
+  protected KnnResults(int k, int visitLimit) {
+    this.k = k;
+    this.visitLimit = visitLimit;
+  }
+
+  /** Resets the visited count to zero, then lets the subclass discard its own state. */
+  final void clear() {
+    visitedCount = 0;
+    doClear();
+  }
+
+  /** Subclass hook, invoked by {@link #clear()}, that discards the currently collected results. */
+  abstract void doClear();
+
+  /**
+   * Reports whether the visit limit has been reached, in which case the results may be incomplete.
+   *
+   * @return true once the visited count is at least the visit limit
+   */
+  final boolean incomplete() {
+    return visitLimit <= visitedCount;
+  }
+
+  /**
+   * Adds to the visited count.
+   *
+   * @param count number of additional visits; expected to be positive (checked only by assertion)
+   */
+  final void incVisitedCount(int count) {
+    assert count > 0;
+    visitedCount = visitedCount + count;
+  }
+
+  /** Returns the visited count recorded so far. */
+  final int visitedCount() {
+    return this.visitedCount;
+  }
+
+  /** Returns the visit limit supplied at construction. */
+  final int visitLimit() {
+    return this.visitLimit;
+  }
+
+  /** Returns the {@code k} supplied at construction. */
+  public final int k() {
+    return this.k;
+  }
+
+  /**
+   * Collects the provided vectorId and includes it in the result set.
+   *
+   * <p>Wrappers may translate the id before delegating (OrdinalTranslatedKnnResults maps a vector
+   * ordinal to a doc id first), and some implementations reject collection entirely
+   * (EmptyKnnResults throws).
+   *
+   * @param vectorId the vector to collect
+   * @param similarity its calculated similarity
+   * @return true if the vector is collected
+   */
+  abstract boolean collect(int vectorId, float similarity);
+
+  /**
+   * Returns whether the current result set is considered full. EmptyKnnResults always answers true;
+   * wrappers such as OrdinalTranslatedKnnResults defer to the results they wrap.
+   *
+   * @return Is the current result set considered full
+   */
+  abstract boolean isFull();

Review Comment:
   I see what you are saying, but I think I still like the "trusting 
implementers" option better. The semantics I'd like for this method is 
something like "the minimum similarity for a vector to be competitive", so it 
would naturally be NEGATIVE_INFINITY as long as the queue is not full. If we 
don't trust implementers, then we need to update the javadocs of `minSimilarity()` 
to add something like "it is only legal to call this method when isFull() 
returns true", which isn't nice. Could we make `minSimilarity()` always correct instead? In 
terms of implementation, I imagine that `minSimilarity()` would need to do 
something like `queue.size() >= k() ? queue.topScore() : 
Float.NEGATIVE_INFINITY`? I see that we also use `minSimilarity()` on 
`GraphBuilderKnnResults` as a way to know the score of the top node, but I'm 
assuming we could address this by adding a new method on 
`GraphBuilderKnnResults`, since it's an implementation detail of our HNSW impl?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org
For additional commands, e-mail: issues-h...@lucene.apache.org

Reply via email to