jpountz commented on code in PR #13697:
URL: https://github.com/apache/lucene/pull/13697#discussion_r1735760199


##########
lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinQuery.java:
##########
@@ -156,6 +161,15 @@ public Scorer get(long leadCost) throws IOException {
           return new BlockJoinScorer(childScorerSupplier.get(leadCost), parents, scoreMode);
         }
 
+        @Override
+        public BulkScorer bulkScorer() throws IOException {
+          final BulkScorer innerBulkScorer = childScorerSupplier.bulkScorer();
+          if (innerBulkScorer == null) {

Review Comment:
   `innerBulkScorer` cannot be `null`, you can skip this check.



##########
lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinQuery.java:
##########
@@ -440,6 +500,83 @@ private String formatScoreExplanation(int matches, int start, int end, ScoreMode
     }
   }
 
+  private abstract static class BatchAwareLeafCollector extends FilterLeafCollector {
+    public BatchAwareLeafCollector(LeafCollector in) {
+      super(in);
+    }
+
+    public void endBatch(int doc) throws IOException {}
+  }
+
+  private static class BlockJoinBulkScorer extends BulkScorer {
+    private final BulkScorer childBulkScorer;
+    private final ScoreMode scoreMode;
+    private final BitSet parents;
+    private final Score currentParentScore;
+    private Integer currentParent;
+
+    public BlockJoinBulkScorer(BulkScorer childBulkScorer, ScoreMode scoreMode, BitSet parents) {
+      this.childBulkScorer = childBulkScorer;
+      this.scoreMode = scoreMode;
+      this.parents = parents;
+      this.currentParentScore = new Score(scoreMode);
+      this.currentParent = null;
+    }
+
+    @Override
+    public int score(LeafCollector collector, Bits acceptDocs, int min, int max)
+        throws IOException {
+      BatchAwareLeafCollector wrappedCollector = wrapCollector(collector);
+      childBulkScorer.score(wrappedCollector, acceptDocs, min, max);

Review Comment:
   When this method is called on parent documents, it is expected to score all parent documents in the [min, max) range. So I think that we should translate the range of doc IDs when calling the child bulk scorer so that:
    - We do not score child documents beyond the last parent document in this range, which would be wasteful.
    - We include all child documents of the first parent document in the range (some of them may be before `min`).



##########
lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinQuery.java:
##########
@@ -275,6 +289,51 @@ public float matchCost() {
     }
   }
 
+  private static class Score extends Scorable {
+    private final ScoreMode scoreMode;
+    private Float score;

Review Comment:
   would it work if we made it a primitive `float` and set it to `0` in `reset()`?
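   One way the primitive-`float` version could look, as a standalone sketch: `ParentScore` and its local `Mode` enum are hypothetical stand-ins for the PR's `Score` class and Lucene's join `ScoreMode`. The child counter is an added assumption: with `0` as the reset value, `Min` would otherwise compute `min(0, s)` and stay at 0 for positive scores, so here the first child score initializes the value instead.

```java
// Hypothetical stand-in for the PR's Score class: the running score is a
// primitive float that reset() sets back to 0, so nothing is boxed per child.
class ParentScore {
  // Local mirror of the constant names in Lucene's join ScoreMode.
  enum Mode { None, Avg, Max, Total, Min }

  private final Mode mode;
  private float score; // 0 after reset(), never null
  private int childCount; // child scores added since the last reset()

  ParentScore(Mode mode) {
    this.mode = mode;
  }

  void reset() {
    score = 0;
    childCount = 0;
  }

  void addChildScore(float childScore) {
    if (mode == Mode.None) {
      return; // scores are not tracked, so score() stays 0
    }
    if (childCount == 0) {
      score = childScore; // first child initializes; avoids min(0, s) for Min
    } else {
      switch (mode) {
        case Total, Avg -> score += childScore;
        case Max -> score = Math.max(score, childScore);
        case Min -> score = Math.min(score, childScore);
        default -> throw new AssertionError(mode);
      }
    }
    childCount++;
  }

  float score() {
    // Avg is the sum divided by the number of children seen.
    return mode == Mode.Avg && childCount > 0 ? score / childCount : score;
  }

  public static void main(String[] args) {
    ParentScore min = new ParentScore(Mode.Min);
    min.addChildScore(3f);
    min.addChildScore(1f);
    System.out.println(min.score()); // prints 1.0
  }
}
```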



##########
lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinQuery.java:
##########
@@ -440,6 +500,83 @@ private String formatScoreExplanation(int matches, int start, int end, ScoreMode
     }
   }
 
+  private abstract static class BatchAwareLeafCollector extends FilterLeafCollector {
+    public BatchAwareLeafCollector(LeafCollector in) {
+      super(in);
+    }
+
+    public void endBatch(int doc) throws IOException {}
+  }
+
+  private static class BlockJoinBulkScorer extends BulkScorer {
+    private final BulkScorer childBulkScorer;
+    private final ScoreMode scoreMode;
+    private final BitSet parents;
+    private final Score currentParentScore;
+    private Integer currentParent;
+
+    public BlockJoinBulkScorer(BulkScorer childBulkScorer, ScoreMode scoreMode, BitSet parents) {
+      this.childBulkScorer = childBulkScorer;
+      this.scoreMode = scoreMode;
+      this.parents = parents;
+      this.currentParentScore = new Score(scoreMode);
+      this.currentParent = null;
+    }
+
+    @Override
+    public int score(LeafCollector collector, Bits acceptDocs, int min, int max)
+        throws IOException {
+      BatchAwareLeafCollector wrappedCollector = wrapCollector(collector);
+      childBulkScorer.score(wrappedCollector, acceptDocs, min, max);
+      wrappedCollector.endBatch(max);
+      return max;
+    }
+
+    @Override
+    public long cost() {
+      return childBulkScorer.cost();
+    }
+
+    // TODO: Need to resolve parent doc IDs in multi-reader space?
+    private BatchAwareLeafCollector wrapCollector(LeafCollector collector) {
+      return new BatchAwareLeafCollector(collector) {
+        private Scorable scorer = null;
+
+        @Override
+        public void setScorer(Scorable scorer) throws IOException {
+          this.scorer = scorer;
+          super.setScorer(scorer != null ? currentParentScore : null);
+        }
+
+        @Override
+        public void collect(int doc) throws IOException {
+          if (currentParent == null) {
+            currentParent = parents.nextSetBit(doc);
+          } else if (doc > currentParent) {
+            in.collect(currentParent); // Emit the current parent
+
+            // Get the next parent and reset the score
+            currentParent = parents.nextSetBit(doc);
+            currentParentScore.reset();
+          }
+
+          if (scorer != null && scoreMode != ScoreMode.None) {
+            currentParentScore.addChildScore(scorer.score());
+          }
+        }
+
+        @Override
+        public void endBatch(int doc) throws IOException {
+          if (currentParent != null && doc > currentParent) {

Review Comment:
   I believe that the second condition (`doc > currentParent`) would always be true? So we don't even need to pass a doc ID to this method, and we could just make sure to collect the current parent if there is one?
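   A standalone sketch of the simplification suggested above, assuming each batch ends on a parent boundary (for example because the child range was translated as discussed in the comment on `score`). `ParentEmitter` is a hypothetical class that records emitted parent IDs in a list instead of forwarding them to a `LeafCollector`; score accumulation is omitted, and clearing `currentParent` after the flush is an added assumption rather than something visible in the diff.

```java
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;

// Hypothetical model of the wrapped collector (not the PR's code): matching
// child doc IDs arrive in increasing order and each parent is emitted once.
class ParentEmitter {
  private final BitSet parents;
  private final List<Integer> emitted = new ArrayList<>();
  private Integer currentParent = null; // kept as in the PR under review

  ParentEmitter(BitSet parents) {
    this.parents = parents;
  }

  void collect(int childDoc) {
    if (currentParent == null) {
      currentParent = parents.nextSetBit(childDoc);
    } else if (childDoc > currentParent) {
      emitted.add(currentParent); // all children of the previous parent are done
      currentParent = parents.nextSetBit(childDoc);
    }
  }

  // No doc ID parameter: flush the pending parent, if there is one, and clear
  // it so a later batch cannot emit it a second time.
  void endBatch() {
    if (currentParent != null) {
      emitted.add(currentParent);
      currentParent = null;
    }
  }

  List<Integer> emitted() {
    return emitted;
  }

  public static void main(String[] args) {
    BitSet parents = new BitSet();
    parents.set(2);
    parents.set(6);
    parents.set(9);
    ParentEmitter e = new ParentEmitter(parents);
    for (int child : new int[] {0, 1, 4, 8}) {
      e.collect(child);
    }
    e.endBatch();
    System.out.println(e.emitted()); // prints [2, 6, 9]
  }
}
```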



##########
lucene/join/src/test/org/apache/lucene/search/join/TestBlockJoinBulkScorer.java:
##########
@@ -0,0 +1,264 @@
+package org.apache.lucene.search.join;
+
+import com.carrotsearch.randomizedtesting.generators.RandomPicks;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.BoostQuery;
+import org.apache.lucene.search.BulkScorer;
+import org.apache.lucene.search.ConstantScoreQuery;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.LeafCollector;
+import org.apache.lucene.search.Scorable;
+import org.apache.lucene.search.ScorerSupplier;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.util.LuceneTestCase;
+
+public class TestBlockJoinBulkScorer extends LuceneTestCase {
+  private static final String TYPE_FIELD_NAME = "type";
+  private static final String VALUE_FIELD_NAME = "value";
+  private static final String PARENT_FILTER_VALUE = "parent";
+  private static final String CHILD_FILTER_VALUE = "child";
+
+  private enum MatchValue {
+    MATCH_A("A", 1),
+    MATCH_B("B", 2),
+    MATCH_C("C", 3);
+
+    private static final List<MatchValue> VALUES = List.of(values());
+
+    private final String text;
+    private final int score;
+
+    MatchValue(String text, int score) {
+      this.text = text;
+      this.score = score;
+    }
+
+    public String getText() {
+      return text;
+    }
+
+    public int getScore() {
+      return score;
+    }
+
+    @Override
+    public String toString() {
+      return text;
+    }
+
+    public static MatchValue random() {
+      return RandomPicks.randomFrom(LuceneTestCase.random(), VALUES);
+    }
+  }
+
+  private record ChildDocMatch(int docId, List<MatchValue> matches) {
+    public ChildDocMatch(int docId, List<MatchValue> matches) {
+      this.docId = docId;
+      this.matches = Collections.unmodifiableList(matches);
+    }
+  }
+
+  private static Map<Integer, List<ChildDocMatch>> populateIndex(
+      RandomIndexWriter writer, int maxParentDocCount, int maxChildDocCount, int maxChildDocMatches)
+      throws IOException {
+    Map<Integer, List<ChildDocMatch>> expectedMatches = new HashMap<>();
+
+    final int parentDocCount = random().nextInt(1, maxParentDocCount + 1);
+    int currentDocId = 0;
+    for (int i = 0; i < parentDocCount; i++) {
+      final int childDocCount = random().nextInt(maxChildDocCount + 1);
+      List<Document> docs = new ArrayList<>(childDocCount);
+      List<ChildDocMatch> childDocMatches = new ArrayList<>(childDocCount);
+
+      for (int j = 0; j < childDocCount; j++) {
+        // Build a child doc
+        Document childDoc = new Document();
+        childDoc.add(newStringField(TYPE_FIELD_NAME, CHILD_FILTER_VALUE, Field.Store.NO));
+
+        final int matchCount = random().nextInt(maxChildDocMatches + 1);
+        List<MatchValue> matchValues = new ArrayList<>(matchCount);
+        for (int k = 0; k < matchCount; k++) {
+          // Add a match to the child doc
+          MatchValue matchValue = MatchValue.random();
+          matchValues.add(matchValue);
+          childDoc.add(newStringField(VALUE_FIELD_NAME, matchValue.getText(), Field.Store.NO));
+        }
+
+        docs.add(childDoc);
+        childDocMatches.add(new ChildDocMatch(currentDocId++, matchValues));
+      }
+
+      // Build a parent doc
+      Document parentDoc = new Document();
+      parentDoc.add(newStringField(TYPE_FIELD_NAME, PARENT_FILTER_VALUE, Field.Store.NO));
+      docs.add(parentDoc);
+
+      // Don't add parent docs with no children to expectedMatches
+      if (childDocCount > 0) {
+        expectedMatches.put(currentDocId, childDocMatches);
+      }
+      currentDocId++;
+
+      writer.addDocuments(docs);
+    }
+
+    return expectedMatches;
+  }
+
+  private static Map<Integer, Float> computeExpectedScores(
+      Map<Integer, List<ChildDocMatch>> expectedMatches,
+      ScoreMode joinScoreMode,
+      org.apache.lucene.search.ScoreMode searchScoreMode) {
+    Map<Integer, Float> expectedScores = new HashMap<>();
+    for (var entry : expectedMatches.entrySet()) {
+      // Filter out child docs with no matches since those will never contribute to the score
+      List<ChildDocMatch> childDocMatches =
+          entry.getValue().stream().filter(m -> !m.matches().isEmpty()).toList();
+      if (childDocMatches.isEmpty()) {
+        continue;
+      }
+
+      Float expectedScore = null;
+      if (searchScoreMode.needsScores()) {
+        for (ChildDocMatch childDocMatch : childDocMatches) {
+          float expectedChildDocScore = computeExpectedScore(childDocMatch);
+          switch (joinScoreMode) {
+            case Total:
+            case Avg:
+              expectedScore =
+                  expectedScore == null
+                      ? expectedChildDocScore
+                      : expectedScore + expectedChildDocScore;
+              break;
+            case Min:
+              expectedScore =
+                  expectedScore == null
+                      ? expectedChildDocScore
+                      : Math.min(expectedScore, expectedChildDocScore);
+              break;
+            case Max:
+              expectedScore =
+                  expectedScore == null
+                      ? expectedChildDocScore
+                      : Math.max(expectedScore, expectedChildDocScore);
+              break;
+            case None:
+              break;
+            default:
+              throw new AssertionError();
+          }
+        }
+
+        if (joinScoreMode == ScoreMode.Avg) {
+          expectedScore /= childDocMatches.size();
+        }
+      }
+
+      expectedScores.put(entry.getKey(), expectedScore != null ? expectedScore : 0);
+    }
+
+    return expectedScores;
+  }
+
+  private static float computeExpectedScore(ChildDocMatch childDocMatch) {
+    float expectedScore = 0.0f;
+    Set<MatchValue> matchValueSet = new HashSet<>(childDocMatch.matches());
+    for (MatchValue matchValue : matchValueSet) {
+      expectedScore += matchValue.getScore();
+    }
+
+    return expectedScore;
+  }
+
+  public void testScoreRandomIndices() throws IOException {
+    for (int i = 0; i < 100; i++) {
+      try (Directory dir = newDirectory()) {
+        Map<Integer, List<ChildDocMatch>> expectedMatches;
+        try (RandomIndexWriter w =
+            new RandomIndexWriter(
+                random(),
+                dir,
+                newIndexWriterConfig()
+                    .setMergePolicy(
+                        // retain doc id order
+                        newLogMergePolicy(random().nextBoolean())))) {
+
+          expectedMatches = populateIndex(w, 10, 5, 3);
+          w.forceMerge(1);
+        }
+
+        try (IndexReader reader = DirectoryReader.open(dir)) {
+          final IndexSearcher searcher = newSearcher(reader);
+          final ScoreMode joinScoreMode =
+              RandomPicks.randomFrom(LuceneTestCase.random(), ScoreMode.values());
+          final org.apache.lucene.search.ScoreMode searchScoreMode =
+              RandomPicks.randomFrom(
+                  LuceneTestCase.random(), org.apache.lucene.search.ScoreMode.values());
+          final Map<Integer, Float> expectedScores =
+              computeExpectedScores(expectedMatches, joinScoreMode, searchScoreMode);
+
+          BooleanQuery.Builder childQueryBuilder = new BooleanQuery.Builder();
+          for (MatchValue matchValue : MatchValue.VALUES) {
+            childQueryBuilder.add(
+                new BoostQuery(
+                    new ConstantScoreQuery(
+                        new TermQuery(new Term(VALUE_FIELD_NAME, matchValue.getText()))),
+                    matchValue.getScore()),
+                BooleanClause.Occur.SHOULD);
+          }
+          BitSetProducer parentsFilter =
+              new QueryBitSetProducer(
+                  new TermQuery(new Term(TYPE_FIELD_NAME, PARENT_FILTER_VALUE)));
+          ToParentBlockJoinQuery parentQuery =
+              new ToParentBlockJoinQuery(childQueryBuilder.build(), parentsFilter, joinScoreMode);
+
+          Weight weight = searcher.createWeight(searcher.rewrite(parentQuery), searchScoreMode, 1);
+          ScorerSupplier ss = weight.scorerSupplier(searcher.getIndexReader().leaves().get(0));
+
+          // TODO: Why is score supplier null sometimes?
+          if (ss == null) {
+            continue;
+          }

Review Comment:
   `ScorerSupplier` may be null when a query has no matches on a segment. However, if the scorer supplier is not null, then both `ScorerSupplier#get` and `ScorerSupplier#bulkScorer` return a non-null value.



##########
lucene/join/src/test/org/apache/lucene/search/join/TestBlockJoinBulkScorer.java:
##########
@@ -0,0 +1,264 @@
+package org.apache.lucene.search.join;
+
+import com.carrotsearch.randomizedtesting.generators.RandomPicks;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.BoostQuery;
+import org.apache.lucene.search.BulkScorer;
+import org.apache.lucene.search.ConstantScoreQuery;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.LeafCollector;
+import org.apache.lucene.search.Scorable;
+import org.apache.lucene.search.ScorerSupplier;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.util.LuceneTestCase;
+
+public class TestBlockJoinBulkScorer extends LuceneTestCase {
+  private static final String TYPE_FIELD_NAME = "type";
+  private static final String VALUE_FIELD_NAME = "value";
+  private static final String PARENT_FILTER_VALUE = "parent";
+  private static final String CHILD_FILTER_VALUE = "child";
+
+  private enum MatchValue {
+    MATCH_A("A", 1),
+    MATCH_B("B", 2),
+    MATCH_C("C", 3);
+
+    private static final List<MatchValue> VALUES = List.of(values());

Review Comment:
   Yes, we'll fix it upon backporting.



##########
lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinQuery.java:
##########
@@ -440,6 +500,83 @@ private String formatScoreExplanation(int matches, int start, int end, ScoreMode
     }
   }
 
+  private abstract static class BatchAwareLeafCollector extends FilterLeafCollector {
+    public BatchAwareLeafCollector(LeafCollector in) {
+      super(in);
+    }
+
+    public void endBatch(int doc) throws IOException {}
+  }
+
+  private static class BlockJoinBulkScorer extends BulkScorer {
+    private final BulkScorer childBulkScorer;
+    private final ScoreMode scoreMode;
+    private final BitSet parents;
+    private final Score currentParentScore;
+    private Integer currentParent;
+
+    public BlockJoinBulkScorer(BulkScorer childBulkScorer, ScoreMode scoreMode, BitSet parents) {
+      this.childBulkScorer = childBulkScorer;
+      this.scoreMode = scoreMode;
+      this.parents = parents;
+      this.currentParentScore = new Score(scoreMode);
+      this.currentParent = null;
+    }
+
+    @Override
+    public int score(LeafCollector collector, Bits acceptDocs, int min, int max)
+        throws IOException {
+      BatchAwareLeafCollector wrappedCollector = wrapCollector(collector);
+      childBulkScorer.score(wrappedCollector, acceptDocs, min, max);
+      wrappedCollector.endBatch(max);
+      return max;
+    }
+
+    @Override
+    public long cost() {
+      return childBulkScorer.cost();
+    }
+
+    // TODO: Need to resolve parent doc IDs in multi-reader space?
+    private BatchAwareLeafCollector wrapCollector(LeafCollector collector) {
+      return new BatchAwareLeafCollector(collector) {
+        private Scorable scorer = null;
+
+        @Override
+        public void setScorer(Scorable scorer) throws IOException {
+          this.scorer = scorer;
+          super.setScorer(scorer != null ? currentParentScore : null);

Review Comment:
   No, it will never be `null`.



##########
lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinQuery.java:
##########
@@ -440,6 +500,83 @@ private String formatScoreExplanation(int matches, int start, int end, ScoreMode
     }
   }
 
+  private abstract static class BatchAwareLeafCollector extends FilterLeafCollector {
+    public BatchAwareLeafCollector(LeafCollector in) {
+      super(in);
+    }
+
+    public void endBatch(int doc) throws IOException {}
+  }
+
+  private static class BlockJoinBulkScorer extends BulkScorer {
+    private final BulkScorer childBulkScorer;
+    private final ScoreMode scoreMode;
+    private final BitSet parents;
+    private final Score currentParentScore;
+    private Integer currentParent;
+
+    public BlockJoinBulkScorer(BulkScorer childBulkScorer, ScoreMode scoreMode, BitSet parents) {
+      this.childBulkScorer = childBulkScorer;
+      this.scoreMode = scoreMode;
+      this.parents = parents;
+      this.currentParentScore = new Score(scoreMode);
+      this.currentParent = null;
+    }
+
+    @Override
+    public int score(LeafCollector collector, Bits acceptDocs, int min, int max)
+        throws IOException {
+      BatchAwareLeafCollector wrappedCollector = wrapCollector(collector);
+      childBulkScorer.score(wrappedCollector, acceptDocs, min, max);
+      wrappedCollector.endBatch(max);
+      return max;
+    }
+
+    @Override
+    public long cost() {
+      return childBulkScorer.cost();
+    }
+
+    // TODO: Need to resolve parent doc IDs in multi-reader space?
+    private BatchAwareLeafCollector wrapCollector(LeafCollector collector) {
+      return new BatchAwareLeafCollector(collector) {
+        private Scorable scorer = null;
+
+        @Override
+        public void setScorer(Scorable scorer) throws IOException {
+          this.scorer = scorer;
+          super.setScorer(scorer != null ? currentParentScore : null);
+        }
+
+        @Override
+        public void collect(int doc) throws IOException {
+          if (currentParent == null) {

Review Comment:
   I'd rather avoid boxing/unboxing constantly by initializing `currentParent` to `-1` and then checking if `currentParent < doc` here.
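   A standalone sketch of the primitive-`int` variant suggested above, under stated assumptions: `PrimitiveParentTracker` is a hypothetical class, parent emission goes to an `IntConsumer` instead of the wrapped `LeafCollector`, and score accumulation is left out. With `-1` as the sentinel, a single `currentParent < doc` comparison handles both the very first child and the first child of each new block; the only extra check is not emitting the sentinel itself.

```java
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.function.IntConsumer;

// Hypothetical sketch: the current parent lives in a primitive int with -1
// meaning "no parent yet", so collect() never boxes or unboxes an Integer.
class PrimitiveParentTracker {
  private final BitSet parents;
  private final IntConsumer emit; // receives each completed parent doc ID
  private int currentParent = -1;

  PrimitiveParentTracker(BitSet parents, IntConsumer emit) {
    this.parents = parents;
    this.emit = emit;
  }

  void collect(int doc) {
    // One comparison covers both the first child (-1 < doc) and the first
    // child of a new block (doc is past the current parent).
    if (currentParent < doc) {
      if (currentParent != -1) {
        emit.accept(currentParent); // the previous parent is complete
      }
      currentParent = parents.nextSetBit(doc);
    }
  }

  void endBatch() {
    if (currentParent != -1) {
      emit.accept(currentParent);
      currentParent = -1;
    }
  }

  public static void main(String[] args) {
    BitSet parents = new BitSet();
    parents.set(2);
    parents.set(6);
    parents.set(9);
    List<Integer> out = new ArrayList<>();
    PrimitiveParentTracker t = new PrimitiveParentTracker(parents, doc -> out.add(doc));
    for (int child : new int[] {0, 1, 4, 8}) {
      t.collect(child);
    }
    t.endBatch();
    System.out.println(out); // prints [2, 6, 9]
  }
}
```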



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
For additional commands, e-mail: issues-h...@lucene.apache.org
