benwtrent commented on code in PR #13651: URL: https://github.com/apache/lucene/pull/13651#discussion_r1731849841
########## lucene/core/src/java/org/apache/lucene/codecs/lucene912/Lucene912BinaryFlatVectorsScorer.java: ########## @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene912; + +import java.io.IOException; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.quantization.BQSpaceUtils; +import org.apache.lucene.util.quantization.BQVectorUtils; +import org.apache.lucene.util.quantization.BinaryQuantizer; + +public class Lucene912BinaryFlatVectorsScorer implements BinaryFlatVectorsScorer { + private final FlatVectorsScorer nonQuantizedDelegate; + + public Lucene912BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) { + this.nonQuantizedDelegate = nonQuantizedDelegate; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) + throws IOException { + // FIXME: write me ... presumably we can create a supplier here without a set of query vectors; + // need to do that and figure out what that instantiation looks like for the Supplier + if (vectorValues instanceof RandomAccessBinarizedByteVectorValues binarizedQueryVectors) { + return new BinarizedRandomVectorScorerSupplier( + null, binarizedQueryVectors.copy(), similarityFunction); + } + return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + RandomAccessVectorValues vectorValues, + float[] target) + throws IOException { + if (vectorValues instanceof RandomAccessBinarizedByteVectorValues binarizedQueryVectors) { + // TODO, implement & handle more than one coarse grained cluster + BinaryQuantizer quantizer = binarizedQueryVectors.getQuantizer(); + float[][] centroids = binarizedQueryVectors.getCentroids(); + int discretizedDimensions = BQVectorUtils.discretize(target.length, 64); + byte[] quantized = new byte[BQSpaceUtils.B_QUERY * discretizedDimensions / 8]; + float distToCentroid = VectorUtil.squareDistance(target, centroids[0]); + BinaryQuantizer.QueryFactors factors = + quantizer.quantizeForQuery(target, quantized, centroids[0]); + return new BinarizedRandomVectorScorer( + new BinaryQueryVector[] {new BinaryQueryVector(quantized, distToCentroid, factors)}, + binarizedQueryVectors, + similarityFunction, + discretizedDimensions); + } + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + RandomAccessVectorValues vectorValues, + byte[] target) + throws IOException { + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + RandomAccessBinarizedQueryByteVectorValues scoringVectors, + RandomAccessBinarizedByteVectorValues targetVectors) + throws IOException { + return null; + } + + @Override + public String toString() { + return "Lucene912BinaryFlatVectorsScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")"; + } + + public static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { + private final RandomAccessBinarizedQueryByteVectorValues queryVectors; + private final RandomAccessBinarizedByteVectorValues targetVectors; + private final VectorSimilarityFunction similarityFunction; + + private final int discretizedDimensions; + + public BinarizedRandomVectorScorerSupplier( + RandomAccessBinarizedQueryByteVectorValues queryVectors, + RandomAccessBinarizedByteVectorValues targetVectors, + VectorSimilarityFunction similarityFunction) + throws IOException { + this.queryVectors = queryVectors; + this.targetVectors = targetVectors; + this.similarityFunction = similarityFunction; + this.discretizedDimensions = BQVectorUtils.discretize(this.queryVectors.dimension(), 64); + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + byte[] queryVector = queryVectors.vectorValue(ord); + + short quantizedSum = queryVectors.sumQuantizedValues(ord, 0); + + float distanceToCentroid = queryVectors.getCentroidDistance(ord, 0); + float vl = queryVectors.getLower(ord, 0); + float width = queryVectors.getWidth(ord, 0); + + float normVmC = queryVectors.getNormVmC(ord, 0); + float vDotC = queryVectors.getVDotC(ord, 0); + float cDotC = queryVectors.getCDotC(ord, 0); + + return new BinarizedRandomVectorScorer( + new BinaryQueryVector[] { + new BinaryQueryVector( + queryVector, + distanceToCentroid, + new BinaryQuantizer.QueryFactors(quantizedSum, vl, width, normVmC, vDotC, cDotC)) + }, + targetVectors, + similarityFunction, + discretizedDimensions); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new BinarizedRandomVectorScorerSupplier( + queryVectors.copy(), targetVectors.copy(), similarityFunction); + } + } + + public record BinaryQueryVector( + byte[] vector, float distanceToCentroid, BinaryQuantizer.QueryFactors factors) {} + + public static class BinarizedRandomVectorScorer + extends RandomVectorScorer.AbstractRandomVectorScorer { + private final BinaryQueryVector[] queryVectors; + private final RandomAccessBinarizedByteVectorValues targetVectors; + private final VectorSimilarityFunction similarityFunction; + + private final int discretizedDimensions; + private final float sqrtDimensions; + private final float maxX1; + + public BinarizedRandomVectorScorer( + BinaryQueryVector[] queryVectors, Review Comment: Because I didn't add all the logic for multiple clusters -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For additional commands, e-mail: issues-h...@lucene.apache.org