benwtrent commented on code in PR #12582: URL: https://github.com/apache/lucene/pull/12582#discussion_r1362664506
########## lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java: ########## @@ -0,0 +1,782 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs.lucene99; + +import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT; +import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.List; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.index.DocIDMerger; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.MergeState; +import 
org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.ScalarQuantizer; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorer; + +/** + * Writes quantized vector values and metadata to index segments. + * + * @lucene.experimental + */ +public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { + + private static final long BASE_RAM_BYTES_USED = + shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsWriter.class); + + private static final float QUANTIZATION_RECOMPUTE_LIMIT = 32; + private final IndexOutput quantizedVectorData; + private final Float quantile; + private boolean finished; + + Lucene99ScalarQuantizedVectorsWriter(IndexOutput quantizedVectorData, Float quantile) { + this.quantile = quantile; + this.quantizedVectorData = quantizedVectorData; + } + + QuantizationVectorWriter addField(FieldInfo fieldInfo, InfoStream infoStream) { + if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) { + throw new IllegalArgumentException( + "Only float32 vector fields are supported for quantization"); + } + float quantile = + this.quantile == null + ? 
calculateDefaultQuantile(fieldInfo.getVectorDimension()) + : this.quantile; + if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { + infoStream.message( + QUANTIZED_VECTOR_COMPONENT, + "quantizing field=" + + fieldInfo.name + + " dimension=" + + fieldInfo.getVectorDimension() + + " quantile=" + + quantile); + } + return QuantizationVectorWriter.create(fieldInfo, quantile, infoStream); + } + + long[] flush( + Sorter.DocMap sortMap, QuantizationVectorWriter field, DocsWithFieldSet docsWithField) + throws IOException { + field.finish(); + return sortMap == null ? writeField(field) : writeSortingField(field, sortMap, docsWithField); + } + + void finish() throws IOException { + if (finished) { + throw new IllegalStateException("already finished"); + } + finished = true; + if (quantizedVectorData != null) { + CodecUtil.writeFooter(quantizedVectorData); + } + } + + private long[] writeField(QuantizationVectorWriter fieldData) throws IOException { + long quantizedVectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); + writeQuantizedVectors(fieldData); + long quantizedVectorDataLength = + quantizedVectorData.getFilePointer() - quantizedVectorDataOffset; + return new long[] {quantizedVectorDataOffset, quantizedVectorDataLength}; + } + + private void writeQuantizedVectors(QuantizationVectorWriter fieldData) throws IOException { + ScalarQuantizer scalarQuantizer = fieldData.createQuantizer(); + byte[] vector = new byte[fieldData.dim]; + final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (float[] v : fieldData.floatVectors) { + float offsetCorrection = + scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction); + quantizedVectorData.writeBytes(vector, vector.length); + offsetBuffer.putFloat(offsetCorrection); + quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length); + offsetBuffer.rewind(); + } + } + + private long[] writeSortingField( + QuantizationVectorWriter 
fieldData, Sorter.DocMap sortMap, DocsWithFieldSet docsWithField) + throws IOException { + final int[] docIdOffsets = new int[sortMap.size()]; + int offset = 1; // 0 means no vector for this (field, document) + DocIdSetIterator iterator = docsWithField.iterator(); + for (int docID = iterator.nextDoc(); + docID != DocIdSetIterator.NO_MORE_DOCS; + docID = iterator.nextDoc()) { + int newDocID = sortMap.oldToNew(docID); + docIdOffsets[newDocID] = offset++; + } + DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); + final int[] ordMap = new int[offset - 1]; // new ord to old ord + final int[] oldOrdMap = new int[offset - 1]; // old ord to new ord + int ord = 0; + int doc = 0; + for (int docIdOffset : docIdOffsets) { + if (docIdOffset != 0) { + ordMap[ord] = docIdOffset - 1; + oldOrdMap[docIdOffset - 1] = ord; + newDocsWithField.add(doc); + ord++; + } + doc++; + } + + // write vector values + long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); + writeSortedQuantizedVectors(fieldData, ordMap); + long quantizedVectorLength = quantizedVectorData.getFilePointer() - vectorDataOffset; + + return new long[] {vectorDataOffset, quantizedVectorLength}; + } + + void writeSortedQuantizedVectors(QuantizationVectorWriter fieldData, int[] ordMap) + throws IOException { + ScalarQuantizer scalarQuantizer = fieldData.createQuantizer(); + byte[] vector = new byte[fieldData.dim]; + final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (int ordinal : ordMap) { + float[] v = fieldData.floatVectors.get(ordinal); + float offsetCorrection = + scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction); + quantizedVectorData.writeBytes(vector, vector.length); + offsetBuffer.putFloat(offsetCorrection); + quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length); + offsetBuffer.rewind(); + } + } + + ScalarQuantizer mergeQuantiles(FieldInfo fieldInfo, MergeState mergeState) throws 
IOException { + if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) { + return null; + } + float quantile = + this.quantile == null + ? calculateDefaultQuantile(fieldInfo.getVectorDimension()) + : this.quantile; + return mergeAndRecalculateQuantiles(mergeState, fieldInfo, quantile); + } + + ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneField( + SegmentWriteState segmentWriteState, + FieldInfo fieldInfo, + MergeState mergeState, + ScalarQuantizer mergedQuantizationState) + throws IOException { + if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) { + return null; + } + IndexOutput tempQuantizedVectorData = + segmentWriteState.directory.createTempOutput( + quantizedVectorData.getName(), "temp", segmentWriteState.context); + IndexInput quantizationDataInput = null; + boolean success = false; + try { + MergedQuantizedVectorValues byteVectorValues = + MergedQuantizedVectorValues.mergeQuantizedByteVectorValues( + fieldInfo, mergeState, mergedQuantizationState); + writeQuantizedVectorData(tempQuantizedVectorData, byteVectorValues); + CodecUtil.writeFooter(tempQuantizedVectorData); + IOUtils.close(tempQuantizedVectorData); + quantizationDataInput = + segmentWriteState.directory.openInput( + tempQuantizedVectorData.getName(), segmentWriteState.context); + quantizedVectorData.copyBytes( + quantizationDataInput, quantizationDataInput.length() - CodecUtil.footerLength()); + CodecUtil.retrieveChecksum(quantizationDataInput); + success = true; + final IndexInput finalQuantizationDataInput = quantizationDataInput; + return new ScalarQuantizedCloseableRandomVectorScorerSupplier( + () -> { + IOUtils.close(finalQuantizationDataInput); + segmentWriteState.directory.deleteFile(tempQuantizedVectorData.getName()); + }, + new ScalarQuantizedRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + mergedQuantizationState, + new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), 
byteVectorValues.size(), quantizationDataInput))); + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(quantizationDataInput); + IOUtils.deleteFilesIgnoringExceptions( + segmentWriteState.directory, tempQuantizedVectorData.getName()); + } + } + } + + static ScalarQuantizer mergeQuantiles( + List<ScalarQuantizer> quantizationStates, List<Integer> segmentSizes, float quantile) { + assert quantizationStates.size() == segmentSizes.size(); + if (quantizationStates.isEmpty()) { + return null; + } + float lowerQuantile = 0f; + float upperQuantile = 0f; + int totalCount = 0; + for (int i = 0; i < quantizationStates.size(); i++) { + if (quantizationStates.get(i) == null) { Review Comment: Since we have no knowledge of the vectors or their quantiles from this segment, it's better to require us to recalculate the quantiles and requantize. We could continue merging the rest, but in the end, we have no information about how different the other vectors are. I think it's good to play it safe here. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For additional commands, e-mail: issues-h...@lucene.apache.org