mayya-sharipova commented on code in PR #12582: URL: https://github.com/apache/lucene/pull/12582#discussion_r1357345297
########## lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java: ########## @@ -0,0 +1,782 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs.lucene99; + +import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT; +import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.List; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.index.DocIDMerger; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.MergeState; +import 
org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.ScalarQuantizer; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorer; + +/** + * Writes quantized vector values and metadata to index segments. + * + * @lucene.experimental + */ +public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { + + private static final long BASE_RAM_BYTES_USED = + shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsWriter.class); + + private static final float QUANTIZATION_RECOMPUTE_LIMIT = 32; + private final IndexOutput quantizedVectorData; + private final Float quantile; + private boolean finished; + + Lucene99ScalarQuantizedVectorsWriter(IndexOutput quantizedVectorData, Float quantile) { + this.quantile = quantile; + this.quantizedVectorData = quantizedVectorData; + } + + QuantizationVectorWriter addField(FieldInfo fieldInfo, InfoStream infoStream) { + if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) { + throw new IllegalArgumentException( + "Only float32 vector fields are supported for quantization"); + } + float quantile = + this.quantile == null + ? 
calculateDefaultQuantile(fieldInfo.getVectorDimension()) + : this.quantile; + if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { + infoStream.message( + QUANTIZED_VECTOR_COMPONENT, + "quantizing field=" + + fieldInfo.name + + " dimension=" + + fieldInfo.getVectorDimension() + + " quantile=" + + quantile); + } + return QuantizationVectorWriter.create(fieldInfo, quantile, infoStream); + } + + long[] flush( + Sorter.DocMap sortMap, QuantizationVectorWriter field, DocsWithFieldSet docsWithField) + throws IOException { + field.finish(); + return sortMap == null ? writeField(field) : writeSortingField(field, sortMap, docsWithField); + } + + void finish() throws IOException { + if (finished) { + throw new IllegalStateException("already finished"); + } + finished = true; + if (quantizedVectorData != null) { + CodecUtil.writeFooter(quantizedVectorData); + } + } + + private long[] writeField(QuantizationVectorWriter fieldData) throws IOException { + long quantizedVectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); + writeQuantizedVectors(fieldData); + long quantizedVectorDataLength = + quantizedVectorData.getFilePointer() - quantizedVectorDataOffset; + return new long[] {quantizedVectorDataOffset, quantizedVectorDataLength}; + } + + private void writeQuantizedVectors(QuantizationVectorWriter fieldData) throws IOException { + ScalarQuantizer scalarQuantizer = fieldData.createQuantizer(); + byte[] vector = new byte[fieldData.dim]; + final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (float[] v : fieldData.floatVectors) { + float offsetCorrection = + scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction); + quantizedVectorData.writeBytes(vector, vector.length); + offsetBuffer.putFloat(offsetCorrection); + quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length); + offsetBuffer.rewind(); + } + } + + private long[] writeSortingField( + QuantizationVectorWriter 
fieldData, Sorter.DocMap sortMap, DocsWithFieldSet docsWithField) + throws IOException { + final int[] docIdOffsets = new int[sortMap.size()]; + int offset = 1; // 0 means no vector for this (field, document) + DocIdSetIterator iterator = docsWithField.iterator(); + for (int docID = iterator.nextDoc(); + docID != DocIdSetIterator.NO_MORE_DOCS; + docID = iterator.nextDoc()) { + int newDocID = sortMap.oldToNew(docID); + docIdOffsets[newDocID] = offset++; + } + DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); + final int[] ordMap = new int[offset - 1]; // new ord to old ord + final int[] oldOrdMap = new int[offset - 1]; // old ord to new ord + int ord = 0; + int doc = 0; + for (int docIdOffset : docIdOffsets) { + if (docIdOffset != 0) { + ordMap[ord] = docIdOffset - 1; + oldOrdMap[docIdOffset - 1] = ord; + newDocsWithField.add(doc); + ord++; + } + doc++; + } + + // write vector values + long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); + writeSortedQuantizedVectors(fieldData, ordMap); + long quantizedVectorLength = quantizedVectorData.getFilePointer() - vectorDataOffset; + + return new long[] {vectorDataOffset, quantizedVectorLength}; + } + + void writeSortedQuantizedVectors(QuantizationVectorWriter fieldData, int[] ordMap) + throws IOException { + ScalarQuantizer scalarQuantizer = fieldData.createQuantizer(); + byte[] vector = new byte[fieldData.dim]; + final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (int ordinal : ordMap) { + float[] v = fieldData.floatVectors.get(ordinal); + float offsetCorrection = + scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction); + quantizedVectorData.writeBytes(vector, vector.length); + offsetBuffer.putFloat(offsetCorrection); + quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length); + offsetBuffer.rewind(); + } + } + + ScalarQuantizer mergeQuantiles(FieldInfo fieldInfo, MergeState mergeState) throws 
IOException { + if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) { + return null; + } + float quantile = + this.quantile == null + ? calculateDefaultQuantile(fieldInfo.getVectorDimension()) + : this.quantile; + return mergeAndRecalculateQuantiles(mergeState, fieldInfo, quantile); + } + + ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneField( + SegmentWriteState segmentWriteState, + FieldInfo fieldInfo, + MergeState mergeState, + ScalarQuantizer mergedQuantizationState) + throws IOException { + if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) { + return null; + } + IndexOutput tempQuantizedVectorData = + segmentWriteState.directory.createTempOutput( + quantizedVectorData.getName(), "temp", segmentWriteState.context); + IndexInput quantizationDataInput = null; + boolean success = false; + try { + MergedQuantizedVectorValues byteVectorValues = + MergedQuantizedVectorValues.mergeQuantizedByteVectorValues( + fieldInfo, mergeState, mergedQuantizationState); + writeQuantizedVectorData(tempQuantizedVectorData, byteVectorValues); + CodecUtil.writeFooter(tempQuantizedVectorData); + IOUtils.close(tempQuantizedVectorData); + quantizationDataInput = + segmentWriteState.directory.openInput( + tempQuantizedVectorData.getName(), segmentWriteState.context); + quantizedVectorData.copyBytes( + quantizationDataInput, quantizationDataInput.length() - CodecUtil.footerLength()); + CodecUtil.retrieveChecksum(quantizationDataInput); + success = true; + final IndexInput finalQuantizationDataInput = quantizationDataInput; + return new ScalarQuantizedCloseableRandomVectorScorerSupplier( + () -> { + IOUtils.close(finalQuantizationDataInput); + segmentWriteState.directory.deleteFile(tempQuantizedVectorData.getName()); + }, + new ScalarQuantizedRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + mergedQuantizationState, + new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), 
byteVectorValues.size(), quantizationDataInput))); + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(quantizationDataInput); + IOUtils.deleteFilesIgnoringExceptions( + segmentWriteState.directory, tempQuantizedVectorData.getName()); + } + } + } + + static ScalarQuantizer mergeQuantiles( + List<ScalarQuantizer> quantizationStates, List<Integer> segmentSizes, float quantile) { + assert quantizationStates.size() == segmentSizes.size(); + if (quantizationStates.isEmpty()) { + return null; + } + float lowerQuantile = 0f; + float upperQuantile = 0f; + int totalCount = 0; + for (int i = 0; i < quantizationStates.size(); i++) { + if (quantizationStates.get(i) == null) { + return null; + } + lowerQuantile += quantizationStates.get(i).getLowerQuantile() * segmentSizes.get(i); + upperQuantile += quantizationStates.get(i).getUpperQuantile() * segmentSizes.get(i); + totalCount += segmentSizes.get(i); + } + lowerQuantile /= totalCount; + upperQuantile /= totalCount; + return new ScalarQuantizer(lowerQuantile, upperQuantile, quantile); + } + + static boolean shouldRecomputeQuantiles( + ScalarQuantizer mergedQuantizationState, List<ScalarQuantizer> quantizationStates) { + float limit = + (mergedQuantizationState.getUpperQuantile() - mergedQuantizationState.getLowerQuantile()) + / QUANTIZATION_RECOMPUTE_LIMIT; + for (ScalarQuantizer quantizationState : quantizationStates) { + if (Math.abs( + quantizationState.getUpperQuantile() - mergedQuantizationState.getUpperQuantile()) + > limit) { + return true; + } + if (Math.abs( + quantizationState.getLowerQuantile() - mergedQuantizationState.getLowerQuantile()) + > limit) { + return true; + } + } + return false; + } + + private static QuantizedVectorsReader getQuantizedKnnVectorsReader( + KnnVectorsReader vectorsReader, String fieldName) { + if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { + vectorsReader = candidateReader.getFieldReader(fieldName); + } + if (vectorsReader 
instanceof QuantizedVectorsReader reader) { + return reader; + } + return null; + } + + private static ScalarQuantizer getQuantizedState( + KnnVectorsReader vectorsReader, String fieldName) { + QuantizedVectorsReader reader = getQuantizedKnnVectorsReader(vectorsReader, fieldName); + if (reader != null) { + return reader.getQuantizationState(fieldName); + } + return null; + } + + static ScalarQuantizer mergeAndRecalculateQuantiles( + MergeState mergeState, FieldInfo fieldInfo, float quantile) throws IOException { + List<ScalarQuantizer> quantizationStates = new ArrayList<>(mergeState.liveDocs.length); + List<Integer> segmentSizes = new ArrayList<>(mergeState.liveDocs.length); + for (int i = 0; i < mergeState.liveDocs.length; i++) { + FloatVectorValues fvv; + if (mergeState.knnVectorsReaders[i] != null + && (fvv = mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name)) != null + && fvv.size() > 0) { + ScalarQuantizer quantizationState = + getQuantizedState(mergeState.knnVectorsReaders[i], fieldInfo.name); + // If we have quantization state, we can utilize that to make merging cheaper + quantizationStates.add(quantizationState); + segmentSizes.add(fvv.size()); + } + } + ScalarQuantizer mergedQuantiles = mergeQuantiles(quantizationStates, segmentSizes, quantile); + // Segments no providing quantization state indicates that their quantiles were never + // calculated. + // To be safe, we should always recalculate given a sample set over all the float vectors in the + // merged + // segment view + if (mergedQuantiles == null || shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) { + FloatVectorValues vectorValues = + KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + mergedQuantiles = ScalarQuantizer.fromVectors(vectorValues, quantile); + } + return mergedQuantiles; + } + + static boolean shouldRequantize(ScalarQuantizer existingQuantiles, ScalarQuantizer newQuantiles) { + // Should this instead be 128f? 
+ float tol = 0.2f * (newQuantiles.getUpperQuantile() - newQuantiles.getLowerQuantile()) / 128f; + if (Math.abs(existingQuantiles.getUpperQuantile() - newQuantiles.getUpperQuantile()) > tol) { + return true; + } + return Math.abs(existingQuantiles.getLowerQuantile() - newQuantiles.getLowerQuantile()) > tol; + } + + /** + * Writes the vector values to the output and returns a set of documents that contains vectors. + */ + private static DocsWithFieldSet writeQuantizedVectorData( + IndexOutput output, QuantizedByteVectorValues quantizedByteVectorValues) throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + for (int docV = quantizedByteVectorValues.nextDoc(); + docV != NO_MORE_DOCS; + docV = quantizedByteVectorValues.nextDoc()) { + // write vector + byte[] binaryValue = quantizedByteVectorValues.vectorValue(); + assert binaryValue.length == quantizedByteVectorValues.dimension() + : "dim=" + quantizedByteVectorValues.dimension() + " len=" + binaryValue.length; + output.writeBytes(binaryValue, binaryValue.length); + output.writeInt(Float.floatToIntBits(quantizedByteVectorValues.getScoreCorrectionConstant())); + docsWithField.add(docV); + } + return docsWithField; + } + + @Override + public long ramBytesUsed() { + return BASE_RAM_BYTES_USED; + } + + static class QuantizationVectorWriter implements Accountable { + private static final long SHALLOW_SIZE = shallowSizeOfInstance(QuantizationVectorWriter.class); + private final int dim; + private final List<float[]> floatVectors; + private final boolean normalize; + private final VectorSimilarityFunction vectorSimilarityFunction; + private final float quantile; + private final InfoStream infoStream; + private float minQuantile = Float.POSITIVE_INFINITY; + private float maxQuantile = Float.NEGATIVE_INFINITY; + private boolean finished; + + static QuantizationVectorWriter create( + FieldInfo fieldInfo, float quantile, InfoStream infoStream) { + return new QuantizationVectorWriter( + 
fieldInfo.getVectorDimension(), + quantile, + fieldInfo.getVectorSimilarityFunction(), + infoStream); + } + + QuantizationVectorWriter( + int dim, + float quantile, + VectorSimilarityFunction vectorSimilarityFunction, + InfoStream infoStream) { + this.dim = dim; + this.quantile = quantile; + this.normalize = vectorSimilarityFunction == VectorSimilarityFunction.COSINE; + this.vectorSimilarityFunction = vectorSimilarityFunction; + this.floatVectors = new ArrayList<>(); + this.infoStream = infoStream; + } + + void finish() throws IOException { + if (finished) { + return; + } + if (floatVectors.size() == 0) { + finished = true; + return; + } + ScalarQuantizer quantizer = + ScalarQuantizer.fromVectors(new FloatVectorWrapper(floatVectors, normalize), quantile); + minQuantile = quantizer.getLowerQuantile(); + maxQuantile = quantizer.getUpperQuantile(); + if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { + infoStream.message( + QUANTIZED_VECTOR_COMPONENT, + "quantized field=" + + " dimension=" + + dim + + " quantile=" + + quantile + + " minQuantile=" + + minQuantile + + " maxQuantile=" + + maxQuantile); + } + finished = true; + } + + public void addValue(float[] vectorValue) throws IOException { + floatVectors.add(vectorValue); + } + + float getMinQuantile() { + assert finished; + return minQuantile; + } + + float getMaxQuantile() { + assert finished; + return maxQuantile; + } + + float getQuantile() { + return quantile; + } + + ScalarQuantizer createQuantizer() { + assert finished; + return new ScalarQuantizer(minQuantile, maxQuantile, quantile); + } + + @Override + public long ramBytesUsed() { + if (floatVectors.size() == 0) return SHALLOW_SIZE; + return SHALLOW_SIZE + + (long) floatVectors.size() + * (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER); Review Comment: I think we don't need `RamUsageEstimator.NUM_BYTES_ARRAY_HEADER` here, as we only need to account for the object references held in `floatVectors` (one reference per vector).
-- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For additional commands, e-mail: issues-h...@lucene.apache.org