Copilot commented on code in PR #17994:
URL: https://github.com/apache/pinot/pull/17994#discussion_r3004241057
##########
pinot-core/src/main/java/org/apache/pinot/core/operator/filter/VectorSimilarityFilterOperator.java:
##########
@@ -120,6 +157,106 @@ protected void explainAttributes(ExplainAttributeBuilder
attributeBuilder) {
attributeBuilder.putString("vectorIdentifier",
_predicate.getLhs().getIdentifier());
attributeBuilder.putString("vectorLiteral",
Arrays.toString(_predicate.getValue()));
attributeBuilder.putLongIdempotent("topKtoSearch", _predicate.getTopK());
+ if (_searchParams.isExactRerank()) {
+ attributeBuilder.putString("exactRerank", "true");
+ }
+ }
+
+ /**
+ * Executes the vector search with backend-specific parameter dispatch and
optional rerank.
+ */
+ private ImmutableRoaringBitmap executeSearch() {
+ String column = _predicate.getLhs().getIdentifier();
+ float[] queryVector = _predicate.getValue();
+ int topK = _predicate.getTopK();
+
+ // 1. Configure backend-specific parameters via interfaces
+ configureBackendParams(column);
+
+ // 2. Determine effective search count (higher if rerank is enabled)
+ int searchCount = topK;
+ if (_searchParams.isExactRerank()) {
+ searchCount = _searchParams.getEffectiveMaxCandidates(topK);
+ }
+
+ // 3. Execute ANN search
+ ImmutableRoaringBitmap annResults =
_vectorIndexReader.getDocIds(queryVector, searchCount);
+ int annCandidateCount = annResults.getCardinality();
+
+ LOGGER.debug("Vector search on column: {}, backend: {}, topK: {},
searchCount: {}, annCandidates: {}",
+ column, getBackendName(), topK, searchCount, annCandidateCount);
+
+ // 4. Apply exact rerank if requested
+ if (_searchParams.isExactRerank() && _forwardIndexReader != null &&
annCandidateCount > 0) {
+ ImmutableRoaringBitmap reranked = applyExactRerank(annResults,
queryVector, topK, column);
+ LOGGER.debug("Exact rerank on column: {}, candidates: {} -> final: {}",
+ column, annCandidateCount, reranked.getCardinality());
+ return reranked;
+ }
+
+ return annResults;
+ }
+
+ /**
+ * Configures backend-specific search parameters on the reader if it
supports them.
+ */
+ private void configureBackendParams(String column) {
+ // Set nprobe on IVF_FLAT readers
+ if (_vectorIndexReader instanceof NprobeAware) {
+ int nprobe = _searchParams.getNprobe();
+ ((NprobeAware) _vectorIndexReader).setNprobe(nprobe);
+ LOGGER.debug("Set nprobe={} on IVF_FLAT reader for column: {}", nprobe,
column);
+ }
+ }
+
+ /**
+ * Re-scores ANN candidates using exact distance from the forward index and
returns top-K.
+ */
+ @SuppressWarnings("unchecked")
+ private ImmutableRoaringBitmap applyExactRerank(ImmutableRoaringBitmap
annResults, float[] queryVector,
+ int topK, String column) {
+ // Max-heap: largest distance on top for efficient eviction
+ PriorityQueue<DocDistance> maxHeap = new PriorityQueue<>(topK + 1,
+ (a, b) -> Float.compare(b._distance, a._distance));
+
+ ForwardIndexReader rawReader = _forwardIndexReader;
+ try (ForwardIndexReaderContext context = rawReader.createContext()) {
+ org.roaringbitmap.IntIterator it = annResults.getIntIterator();
+ while (it.hasNext()) {
+ int docId = it.next();
+ float[] docVector = rawReader.getFloatMV(docId, context);
+ if (docVector == null || docVector.length == 0) {
+ continue;
+ }
+ // TODO: derive distance function from segment's vector index config
instead of hardcoding L2.
+ // Currently correct for EUCLIDEAN/L2; may produce suboptimal rerank
ordering for COSINE/DOT_PRODUCT.
+ float distance =
ExactVectorScanFilterOperator.computeL2SquaredDistance(queryVector, docVector);
+ if (maxHeap.size() < topK) {
Review Comment:
Exact rerank currently hard-codes L2-squared distance
(`computeL2SquaredDistance`) for rescoring, which will produce incorrect
ordering for segments configured with COSINE / INNER_PRODUCT / DOT_PRODUCT
distance functions. This is a correctness issue when `vectorExactRerank=true`.
Please rerank using the same distance function as the segment’s vector index
config (or expose it via the reader/config and branch accordingly).
##########
pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/index/readers/vector/IvfFlatVectorIndexReader.java:
##########
@@ -0,0 +1,337 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.segment.local.segment.index.readers.vector;
+
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.PriorityQueue;
import org.apache.pinot.common.function.scalar.VectorFunctions;
import org.apache.pinot.segment.local.segment.index.vector.IvfFlatVectorIndexCreator;
import org.apache.pinot.segment.spi.V1Constants;
import org.apache.pinot.segment.spi.index.creator.VectorIndexConfig;
import org.apache.pinot.segment.spi.index.reader.NprobeAware;
import org.apache.pinot.segment.spi.index.reader.VectorIndexReader;
import org.apache.pinot.segment.spi.store.SegmentDirectoryPaths;
import org.roaringbitmap.buffer.MutableRoaringBitmap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+
+
+/**
+ * Reader for IVF_FLAT (Inverted File with flat vectors) index.
+ *
+ * <p>Loads the entire index into memory at construction time for fast search.
+ * The search algorithm:
+ * <ol>
+ * <li>Computes distance from the query to all centroids.</li>
+ * <li>Selects the {@code nprobe} closest centroids.</li>
+ * <li>Scans all vectors in those centroids' inverted lists.</li>
+ * <li>Returns the top-K doc IDs as a bitmap.</li>
+ * </ol>
+ *
+ * <h3>Thread safety</h3>
+ * <p>This class is thread-safe for concurrent reads. The loaded index data is
immutable
+ * after construction. The only mutable state is {@code _nprobe}, which is
volatile to
+ * allow query-time tuning from another thread. However, the typical pattern is
+ * single-threaded: set nprobe, then call getDocIds.</p>
+ */
+public class IvfFlatVectorIndexReader implements VectorIndexReader,
NprobeAware {
+ private static final Logger LOGGER =
LoggerFactory.getLogger(IvfFlatVectorIndexReader.class);
+
+ /** Default nprobe value when not explicitly set. */
+ static final int DEFAULT_NPROBE = 4;
+
+ // Index data loaded from file
+ private final int _dimension;
+ private final int _numVectors;
+ private final int _nlist;
+ private final VectorIndexConfig.VectorDistanceFunction _distanceFunction;
+ private final float[][] _centroids;
+ private final int[][] _listDocIds;
+ private final float[][][] _listVectors;
+ private final String _column;
+
+ /** Number of centroids to probe during search. */
+ private volatile int _nprobe;
+
+ /**
+ * Opens and loads an IVF_FLAT index from disk.
+ *
+ * @param column the column name
+ * @param indexDir the segment index directory
+ * @param config the vector index configuration
+ * @throws RuntimeException if the index file cannot be read or is corrupt
+ */
+ public IvfFlatVectorIndexReader(String column, File indexDir,
VectorIndexConfig config) {
+ _column = column;
+
+ // Initialize nprobe to the default; query-time tuning should use
NprobeAware#setNprobe.
+ int configuredNprobe = DEFAULT_NPROBE;
+
+ File indexFile = SegmentDirectoryPaths.findVectorIndexIndexFile(indexDir,
column);
+ if (indexFile == null || !indexFile.exists()) {
+ throw new IllegalStateException(
+ "Failed to find IVF_FLAT index file for column: " + column + " in
dir: " + indexDir
+ + ". Expected file: " + column +
V1Constants.Indexes.VECTOR_IVF_FLAT_INDEX_FILE_EXTENSION);
+ }
+
+ try (DataInputStream in = new DataInputStream(new
FileInputStream(indexFile))) {
+ // --- Header ---
+ int magic = in.readInt();
+ Preconditions.checkState(magic == IvfFlatVectorIndexCreator.MAGIC,
+ "Invalid IVF_FLAT magic: 0x%s, expected 0x%s",
+ Integer.toHexString(magic),
Integer.toHexString(IvfFlatVectorIndexCreator.MAGIC));
+
+ int version = in.readInt();
+ Preconditions.checkState(version ==
IvfFlatVectorIndexCreator.FORMAT_VERSION,
+ "Unsupported IVF_FLAT format version: %s, expected: %s",
+ version, IvfFlatVectorIndexCreator.FORMAT_VERSION);
+
+ _dimension = in.readInt();
+ _numVectors = in.readInt();
+ _nlist = in.readInt();
+ int distanceFunctionOrdinal = in.readInt();
+ _distanceFunction =
VectorIndexConfig.VectorDistanceFunction.values()[distanceFunctionOrdinal];
+
+ // Clamp nprobe to valid range
+ _nprobe = Math.min(configuredNprobe, _nlist);
+ if (_nprobe <= 0) {
+ _nprobe = Math.min(DEFAULT_NPROBE, _nlist);
+ }
+
+ // --- Centroids ---
+ _centroids = new float[_nlist][_dimension];
+ for (int c = 0; c < _nlist; c++) {
+ for (int d = 0; d < _dimension; d++) {
+ _centroids[c][d] = in.readFloat();
+ }
+ }
+
+ // --- Inverted Lists ---
+ _listDocIds = new int[_nlist][];
+ _listVectors = new float[_nlist][][];
+
+ for (int c = 0; c < _nlist; c++) {
+ int listSize = in.readInt();
+ _listDocIds[c] = new int[listSize];
+ for (int i = 0; i < listSize; i++) {
+ _listDocIds[c][i] = in.readInt();
+ }
+ _listVectors[c] = new float[listSize][_dimension];
+ for (int i = 0; i < listSize; i++) {
+ for (int d = 0; d < _dimension; d++) {
+ _listVectors[c][i][d] = in.readFloat();
+ }
+ }
+ }
+
+ // We skip reading the offset table and footer since we read sequentially
+
+ LOGGER.info("Loaded IVF_FLAT index for column: {}: {} vectors, {}
centroids, dim={}, nprobe={}, distance={}",
+ column, _numVectors, _nlist, _dimension, _nprobe, _distanceFunction);
+ } catch (IOException e) {
+ throw new RuntimeException(
+ "Failed to load IVF_FLAT index for column: " + column + " from file:
" + indexFile, e);
+ }
+ }
+
+ @Override
+ public MutableRoaringBitmap getDocIds(float[] searchQuery, int topK) {
+ Preconditions.checkArgument(searchQuery.length == _dimension,
+ "Query dimension mismatch: expected %s, got %s", _dimension,
searchQuery.length);
+ Preconditions.checkArgument(topK > 0, "topK must be positive, got: %s",
topK);
+
+ if (_numVectors == 0 || _nlist == 0) {
+ return new MutableRoaringBitmap();
+ }
+
+ int effectiveNprobe = Math.min(_nprobe, _nlist);
+
+ // Step 1: Find the nprobe closest centroids
+ int[] probeCentroids = findClosestCentroids(searchQuery, effectiveNprobe);
+
+ // Step 2: Scan all vectors in the selected inverted lists, maintaining a
max-heap of size topK
+ // Max-heap: the largest distance is at the top, so we can efficiently
evict the worst candidate.
+ int effectiveTopK = Math.min(topK, _numVectors);
+ PriorityQueue<ScoredDoc> maxHeap = new PriorityQueue<>(effectiveTopK,
+ (a, b) -> Float.compare(b._distance, a._distance));
+
+ for (int probeIdx : probeCentroids) {
+ int[] docIds = _listDocIds[probeIdx];
+ float[][] vectors = _listVectors[probeIdx];
+
+ for (int i = 0; i < docIds.length; i++) {
+ float dist = computeDistance(searchQuery, vectors[i]);
+ if (maxHeap.size() < effectiveTopK) {
+ maxHeap.offer(new ScoredDoc(docIds[i], dist));
+ } else if (dist < maxHeap.peek()._distance) {
+ maxHeap.poll();
+ maxHeap.offer(new ScoredDoc(docIds[i], dist));
+ }
+ }
+ }
+
+ // Step 3: Collect results into a bitmap
+ MutableRoaringBitmap result = new MutableRoaringBitmap();
+ for (ScoredDoc doc : maxHeap) {
+ result.add(doc._docId);
+ }
+ return result;
+ }
+
+ /**
+ * Sets the number of centroids to probe during search.
+ * This allows query-time tuning of the recall/speed tradeoff.
+ *
+ * @param nprobe number of centroids to probe (clamped to [1, nlist])
+ *
+ * <p><b>Thread-safety note:</b> This method mutates a volatile field on the
shared reader instance.
+ * In Pinot's query execution model, nprobe is set once per query before
calling getDocIds(),
+ * and each query runs on a single thread per segment. A future improvement
could pass nprobe
+ * as a parameter to getDocIds() to eliminate any cross-query visibility
concern.</p>
+ */
+ public void setNprobe(int nprobe) {
+ _nprobe = Math.max(1, Math.min(nprobe, _nlist));
Review Comment:
`NprobeAware#setNprobe` Javadoc says implementations should throw
`IllegalArgumentException` when `nprobe < 1`, but this implementation silently
clamps values. Either update the interface contract, or validate and throw on
invalid nprobe to match the documented behavior (clamping to `nlist` is fine).
```suggestion
* @param nprobe number of centroids to probe (must be >= 1; values
greater than {@code nlist} are clamped)
* @throws IllegalArgumentException if {@code nprobe < 1}
*
* <p><b>Thread-safety note:</b> This method mutates a volatile field on
the shared reader instance.
* In Pinot's query execution model, nprobe is set once per query before
calling getDocIds(),
* and each query runs on a single thread per segment. A future
improvement could pass nprobe
* as a parameter to getDocIds() to eliminate any cross-query visibility
concern.</p>
*/
@Override
public void setNprobe(int nprobe) {
if (nprobe < 1) {
throw new IllegalArgumentException("nprobe must be >= 1, got: " +
nprobe);
}
_nprobe = Math.min(nprobe, _nlist);
```
##########
pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/index/vector/IvfFlatVectorIndexCreator.java:
##########
@@ -0,0 +1,561 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.segment.local.segment.index.vector;
+
+import com.google.common.base.Preconditions;
+import java.io.BufferedOutputStream;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import javax.annotation.Nullable;
+import org.apache.pinot.common.function.scalar.VectorFunctions;
+import org.apache.pinot.segment.spi.V1Constants;
+import org.apache.pinot.segment.spi.index.creator.VectorIndexConfig;
+import org.apache.pinot.segment.spi.index.creator.VectorIndexCreator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Creates an IVF_FLAT (Inverted File with flat vectors) index for immutable
segments.
+ *
+ * <p>The creator buffers all vectors in memory during {@link #add(float[])}
calls, then
+ * trains k-means centroids, assigns vectors to their nearest centroids, and
serializes
+ * the complete index to a single {@code .ivfflat.index} file during {@link
#seal()}.</p>
+ *
+ * <h3>Thread safety</h3>
+ * <p>This class is NOT thread-safe. It is designed for single-threaded
segment creation.</p>
+ *
+ * <h3>File format (version 1)</h3>
+ * <pre>
+ * [Header]
+ * magic: 4 bytes (0x49564646 = "IVFF")
+ * version: 4 bytes (1)
+ * dimension: 4 bytes
+ * numVectors: 4 bytes
+ * nlist: 4 bytes
+ * distanceFunctionOrd: 4 bytes
+ *
+ * [Centroids Section]
+ * nlist x dimension x 4 bytes (float32)
+ *
+ * [Inverted Lists Section]
+ * For each centroid i (0..nlist-1):
+ * listSize_i: 4 bytes
+ * docIds_i: listSize_i x 4 bytes (int32)
+ * vectors_i: listSize_i x dimension x 4 bytes (float32)
+ *
+ * [Inverted List Offsets]
+ * nlist x 8 bytes (long offset to start of each inverted list)
+ *
+ * [Footer]
+ * offsetToOffsets: 8 bytes (position of the offsets section)
+ * </pre>
+ *
+ * <p>All multi-byte values are written in big-endian order (Java {@link
DataOutputStream} default).</p>
+ */
+public class IvfFlatVectorIndexCreator implements VectorIndexCreator {
+ private static final Logger LOGGER =
LoggerFactory.getLogger(IvfFlatVectorIndexCreator.class);
+
+ /** Magic bytes identifying an IVF_FLAT index file: ASCII "IVFF". */
+ public static final int MAGIC = 0x49564646;
+
+ /** Current file format version. */
+ public static final int FORMAT_VERSION = 1;
+
+ /** Default number of Voronoi cells (centroids). */
+ public static final int DEFAULT_NLIST = 128;
+
+ /** Maximum number of k-means iterations. */
+ static final int MAX_KMEANS_ITERATIONS = 50;
+
+ /** Convergence threshold: stop when centroid movement is below this
fraction. */
+ static final float CONVERGENCE_THRESHOLD = 1e-5f;
+
+ /** Default training sample size multiplier relative to nlist. */
+ static final int DEFAULT_TRAIN_SAMPLE_MULTIPLIER = 40;
+
+ /** Minimum training sample size. */
+ static final int DEFAULT_MIN_TRAIN_SAMPLE_SIZE = 10000;
+
+ private final String _column;
+ private final File _indexDir;
+ private final int _dimension;
+ private final int _nlist;
+ private final int _trainSampleSize;
+ private final long _trainingSeed;
+ private final VectorIndexConfig.VectorDistanceFunction _distanceFunction;
+
+ /** All vectors collected during add(), indexed by docId (ordinal). */
+ private final List<float[]> _vectors = new ArrayList<>();
+
+ private boolean _sealed = false;
+
+ /**
+ * Creates a new IVF_FLAT index creator.
+ *
+ * @param column the column name
+ * @param indexDir the segment index directory
+ * @param config the vector index configuration
+ */
+ public IvfFlatVectorIndexCreator(String column, File indexDir,
VectorIndexConfig config) {
+ _column = column;
+ _indexDir = indexDir;
+ _dimension = config.getVectorDimension();
+ _distanceFunction = config.getVectorDistanceFunction();
+
+ Map<String, String> properties = config.getProperties();
+ _nlist = properties != null && properties.containsKey("nlist")
+ ? Integer.parseInt(properties.get("nlist"))
+ : DEFAULT_NLIST;
+ _trainSampleSize = properties != null &&
properties.containsKey("trainSampleSize")
+ ? Integer.parseInt(properties.get("trainSampleSize"))
+ : Math.max(_nlist * DEFAULT_TRAIN_SAMPLE_MULTIPLIER,
DEFAULT_MIN_TRAIN_SAMPLE_SIZE);
+ _trainingSeed = properties != null &&
properties.containsKey("trainingSeed")
+ ? Long.parseLong(properties.get("trainingSeed"))
+ : System.nanoTime();
+
+ Preconditions.checkArgument(_dimension > 0, "Vector dimension must be
positive, got: %s", _dimension);
+ Preconditions.checkArgument(_nlist > 0, "nlist must be positive, got: %s",
_nlist);
+
+ LOGGER.info("Creating IVF_FLAT index for column: {} in dir: {},
dimension={}, nlist={}, distance={}",
+ column, indexDir.getAbsolutePath(), _dimension, _nlist,
_distanceFunction);
+ }
+
+ @Override
+ public void add(Object[] values, @Nullable int[] dictIds) {
+ // The segment builder calls this overload for multi-value columns.
+ // Convert Object[] (boxed Floats) to float[] and delegate to add(float[]).
+ float[] floatValues = new float[_dimension];
+ for (int i = 0; i < values.length; i++) {
+ floatValues[i] = (Float) values[i];
+ }
+ add(floatValues);
+ }
+
+ @Override
+ public void add(float[] document) {
+ Preconditions.checkState(!_sealed, "Cannot add documents after seal()");
+ Preconditions.checkArgument(document.length == _dimension,
+ "Vector dimension mismatch: expected %s, got %s", _dimension,
document.length);
+ _vectors.add(document.clone());
+ }
+
+ @Override
+ public void seal()
+ throws IOException {
+ Preconditions.checkState(!_sealed, "seal() already called");
+ _sealed = true;
+
+ int numVectors = _vectors.size();
+ if (numVectors == 0) {
+ LOGGER.warn("No vectors to index for column: {}. Writing empty index.",
_column);
+ writeIndex(new float[0][0], new int[0], new List[0], 0);
+ return;
+ }
+
+ // Determine effective nlist (cannot have more centroids than vectors)
+ int effectiveNlist = Math.min(_nlist, numVectors);
+ LOGGER.info("IVF_FLAT seal: column={}, numVectors={}, effectiveNlist={}",
_column, numVectors, effectiveNlist);
+
+ // Collect training samples
+ float[][] trainingSamples = collectTrainingSamples(numVectors,
effectiveNlist);
+
+ // Train centroids using k-means
+ float[][] centroids = trainKMeans(trainingSamples, effectiveNlist);
+
+ // Assign all vectors to their nearest centroids
+ int[] assignments = assignVectors(centroids);
+
+ // Build inverted lists
+ @SuppressWarnings("unchecked")
+ List<Integer>[] invertedLists = new List[effectiveNlist];
+ for (int i = 0; i < effectiveNlist; i++) {
+ invertedLists[i] = new ArrayList<>();
+ }
+ for (int docId = 0; docId < numVectors; docId++) {
+ invertedLists[assignments[docId]].add(docId);
+ }
+
+ // Write the index file
+ writeIndex(centroids, assignments, invertedLists, effectiveNlist);
+
+ LOGGER.info("IVF_FLAT index sealed for column: {}. {} vectors across {}
centroids.",
+ _column, numVectors, effectiveNlist);
+ }
+
+ @Override
+ public void close()
+ throws IOException {
+ // Release references to allow GC
+ _vectors.clear();
+ }
+
+ // -----------------------------------------------------------------------
+ // Training
+ // -----------------------------------------------------------------------
+
+ /**
+ * Collects a subsample of vectors for k-means training.
+ */
+ float[][] collectTrainingSamples(int numVectors, int effectiveNlist) {
+ int sampleSize = Math.min(_trainSampleSize, numVectors);
+ if (sampleSize >= numVectors) {
+ // Use all vectors for training
+ return _vectors.toArray(new float[0][]);
+ }
+
+ Random rng = new Random(_trainingSeed);
+ // Fisher-Yates partial shuffle to select sampleSize unique indices
+ int[] indices = new int[numVectors];
+ for (int i = 0; i < numVectors; i++) {
+ indices[i] = i;
+ }
+ for (int i = 0; i < sampleSize; i++) {
+ int j = i + rng.nextInt(numVectors - i);
+ int tmp = indices[i];
+ indices[i] = indices[j];
+ indices[j] = tmp;
+ }
+
+ float[][] samples = new float[sampleSize][];
+ for (int i = 0; i < sampleSize; i++) {
+ samples[i] = _vectors.get(indices[i]);
+ }
+ return samples;
+ }
+
+ /**
+ * Trains centroids using k-means++ initialization followed by Lloyd's
algorithm.
+ *
+ * @param samples the training vectors
+ * @param numCentroids the number of centroids to train
+ * @return the trained centroids
+ */
+ float[][] trainKMeans(float[][] samples, int numCentroids) {
+ int numSamples = samples.length;
+ if (numCentroids >= numSamples) {
+ // Use each sample as its own centroid
+ float[][] centroids = new float[numSamples][];
+ for (int i = 0; i < numSamples; i++) {
+ centroids[i] = samples[i].clone();
+ }
+ return centroids;
+ }
+
+ // k-means++ initialization
+ float[][] centroids = kMeansPlusPlusInit(samples, numCentroids);
+
+ // Lloyd's iterations
+ int[] assignments = new int[numSamples];
+ for (int iter = 0; iter < MAX_KMEANS_ITERATIONS; iter++) {
+ // Assign each sample to the nearest centroid
+ for (int i = 0; i < numSamples; i++) {
+ assignments[i] = findNearestCentroid(samples[i], centroids);
+ }
+
+ // Recompute centroids
+ float[][] newCentroids = new float[numCentroids][_dimension];
+ int[] counts = new int[numCentroids];
+ for (int i = 0; i < numSamples; i++) {
+ int cluster = assignments[i];
+ counts[cluster]++;
+ for (int d = 0; d < _dimension; d++) {
+ newCentroids[cluster][d] += samples[i][d];
+ }
+ }
+
+ // Finalize centroids (divide by count), handle empty clusters
+ float maxMovement = 0.0f;
+ for (int c = 0; c < numCentroids; c++) {
+ if (counts[c] == 0) {
+ // Empty cluster: keep old centroid
+ newCentroids[c] = centroids[c].clone();
+ } else {
+ for (int d = 0; d < _dimension; d++) {
+ newCentroids[c][d] /= counts[c];
+ }
+ }
+ // Track maximum centroid movement for convergence check
+ float movement = (float)
VectorFunctions.euclideanDistance(centroids[c], newCentroids[c]);
+ maxMovement = Math.max(maxMovement, movement);
+ }
+
+ centroids = newCentroids;
+
+ if (maxMovement < CONVERGENCE_THRESHOLD) {
+ LOGGER.debug("K-means converged at iteration {} with maxMovement={}",
iter, maxMovement);
+ break;
+ }
+ }
+
+ return centroids;
+ }
+
+ /**
+ * K-means++ initialization: selects initial centroids with probability
proportional
+ * to the squared distance from the nearest existing centroid.
+ */
+ private float[][] kMeansPlusPlusInit(float[][] samples, int numCentroids) {
+ int numSamples = samples.length;
+ Random rng = new Random(_trainingSeed);
+
+ float[][] centroids = new float[numCentroids][];
+ // Pick first centroid uniformly at random
+ centroids[0] = samples[rng.nextInt(numSamples)].clone();
+
+ // Distances from each sample to the nearest chosen centroid
+ float[] minDistances = new float[numSamples];
+ Arrays.fill(minDistances, Float.MAX_VALUE);
+
+ for (int c = 1; c < numCentroids; c++) {
+ // Update minimum distances with the most recently added centroid
+ float totalWeight = 0.0f;
+ for (int i = 0; i < numSamples; i++) {
+ float dist = computeTrainingDistance(samples[i], centroids[c - 1]);
+ if (dist < minDistances[i]) {
+ minDistances[i] = dist;
+ }
+ totalWeight += minDistances[i];
+ }
+
+ // Weighted random selection
+ float target = rng.nextFloat() * totalWeight;
+ float cumulative = 0.0f;
+ int selected = numSamples - 1; // fallback
+ for (int i = 0; i < numSamples; i++) {
+ cumulative += minDistances[i];
+ if (cumulative >= target) {
+ selected = i;
+ break;
+ }
+ }
+ centroids[c] = samples[selected].clone();
+ }
+
+ return centroids;
+ }
+
+ /**
+ * Computes distance used for training. Always uses L2 squared distance for
k-means
+ * training regardless of the configured distance function, because k-means
minimizes
+ * squared Euclidean distance by construction.
+ *
+ * <p>For COSINE distance, we normalize vectors before computing L2, which
is equivalent
+ * to using angular distance for clustering.</p>
+ */
+ private float computeTrainingDistance(float[] a, float[] b) {
+ // For cosine distance, use L2 on normalized vectors which groups by
angular similarity
+ if (_distanceFunction == VectorIndexConfig.VectorDistanceFunction.COSINE) {
+ return (float) VectorFunctions.euclideanDistance(normalizeVector(a),
normalizeVector(b));
+ }
+ return (float) VectorFunctions.euclideanDistance(a, b);
+ }
+
+ // -----------------------------------------------------------------------
+ // Assignment
+ // -----------------------------------------------------------------------
+
+ /**
+ * Assigns each vector to its nearest centroid using the configured distance
function.
+ */
+ private int[] assignVectors(float[][] centroids) {
+ int numVectors = _vectors.size();
+ int[] assignments = new int[numVectors];
+ for (int i = 0; i < numVectors; i++) {
+ assignments[i] = findNearestCentroidForSearch(_vectors.get(i),
centroids);
+ }
+ return assignments;
+ }
+
+ /**
+ * Finds the index of the nearest centroid to the given vector using L2
distance
+ * (used during k-means training).
+ */
+ private int findNearestCentroid(float[] vector, float[][] centroids) {
+ int nearest = 0;
+ float nearestDist = Float.MAX_VALUE;
+ for (int c = 0; c < centroids.length; c++) {
+ float dist = computeTrainingDistance(vector, centroids[c]);
+ if (dist < nearestDist) {
+ nearestDist = dist;
+ nearest = c;
+ }
+ }
+ return nearest;
+ }
+
+ /**
+ * Finds the index of the nearest centroid to the given vector using the
configured
+ * distance function (used during vector assignment after training).
+ */
+ private int findNearestCentroidForSearch(float[] vector, float[][]
centroids) {
+ int nearest = 0;
+ float nearestDist = Float.MAX_VALUE;
+ for (int c = 0; c < centroids.length; c++) {
+ float dist = computeDistance(vector, centroids[c]);
+ if (dist < nearestDist) {
+ nearestDist = dist;
+ nearest = c;
+ }
+ }
+ return nearest;
+ }
+
+ // -----------------------------------------------------------------------
+ // Distance computation helpers (delegates to VectorFunctions)
+ // -----------------------------------------------------------------------
+
+ /**
+ * Computes distance between two vectors using the configured distance
function.
+ * Internally uses L2 for EUCLIDEAN/L2, cosine for COSINE, negative dot for
INNER_PRODUCT/DOT_PRODUCT.
+ */
+ private float computeDistance(float[] a, float[] b) {
+ switch (_distanceFunction) {
+ case EUCLIDEAN:
+ case L2:
+ return (float) VectorFunctions.euclideanDistance(a, b);
+ case COSINE:
+ return (float) VectorFunctions.cosineDistance(a, b);
+ case INNER_PRODUCT:
+ case DOT_PRODUCT:
+ return (float) -VectorFunctions.dotProduct(a, b);
+ default:
+ throw new IllegalArgumentException("Unsupported distance function: " +
_distanceFunction);
+ }
+ }
+
+ /**
+ * Returns a new unit-length copy of the given vector.
+ * If the vector has zero magnitude, a zero vector of the same length is
returned.
+ */
+ private static float[] normalizeVector(float[] vector) {
+ float norm = 0.0f;
+ for (float v : vector) {
+ norm += v * v;
+ }
+ norm = (float) Math.sqrt(norm);
+ float[] result = new float[vector.length];
+ if (norm > 0.0f) {
+ for (int i = 0; i < vector.length; i++) {
+ result[i] = vector[i] / norm;
+ }
+ }
+ return result;
+ }
+
+ // -----------------------------------------------------------------------
+ // Serialization
+ // -----------------------------------------------------------------------
+
+ /**
+ * Writes the complete IVF_FLAT index to disk.
+ */
+ private void writeIndex(float[][] centroids, int[] assignments,
List<Integer>[] invertedLists, int effectiveNlist)
+ throws IOException {
+ File indexFile = new File(_indexDir, _column +
V1Constants.Indexes.VECTOR_IVF_FLAT_INDEX_FILE_EXTENSION);
+ int numVectors = _vectors.size();
+
+ try (DataOutputStream out = new DataOutputStream(new
BufferedOutputStream(new FileOutputStream(indexFile)))) {
+ // --- Header ---
+ out.writeInt(MAGIC);
+ out.writeInt(FORMAT_VERSION);
+ out.writeInt(_dimension);
+ out.writeInt(numVectors);
+ out.writeInt(effectiveNlist);
+ out.writeInt(_distanceFunction.ordinal());
Review Comment:
The IVF_FLAT file format persists the distance function using
`enum.ordinal()`. This is fragile because adding/reordering enum constants in
`VectorDistanceFunction` will make existing index files unreadable. Since this
is a new on-disk format, consider writing a stable identifier (e.g., the enum
name as a string, or an explicit numeric code you control) instead of the
ordinal.
##########
pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/index/readers/vector/IvfFlatVectorIndexReader.java:
##########
@@ -0,0 +1,337 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.segment.local.segment.index.readers.vector;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import java.io.DataInputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.PriorityQueue;
+import org.apache.pinot.common.function.scalar.VectorFunctions;
+import
org.apache.pinot.segment.local.segment.index.vector.IvfFlatVectorIndexCreator;
+import org.apache.pinot.segment.spi.V1Constants;
+import org.apache.pinot.segment.spi.index.creator.VectorIndexConfig;
+import org.apache.pinot.segment.spi.index.reader.NprobeAware;
+import org.apache.pinot.segment.spi.index.reader.VectorIndexReader;
+import org.apache.pinot.segment.spi.store.SegmentDirectoryPaths;
+import org.roaringbitmap.buffer.MutableRoaringBitmap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Reader for IVF_FLAT (Inverted File with flat vectors) index.
+ *
+ * <p>Loads the entire index into memory at construction time for fast search.
+ * The search algorithm:
+ * <ol>
+ * <li>Computes distance from the query to all centroids.</li>
+ * <li>Selects the {@code nprobe} closest centroids.</li>
+ * <li>Scans all vectors in those centroids' inverted lists.</li>
+ * <li>Returns the top-K doc IDs as a bitmap.</li>
+ * </ol>
+ *
+ * <h3>Thread safety</h3>
+ * <p>This class is thread-safe for concurrent reads. The loaded index data is
immutable
+ * after construction. The only mutable state is {@code _nprobe}, which is
volatile to
+ * allow query-time tuning from another thread. However, the typical pattern is
+ * single-threaded: set nprobe, then call getDocIds.</p>
+ */
+public class IvfFlatVectorIndexReader implements VectorIndexReader,
NprobeAware {
+ private static final Logger LOGGER =
LoggerFactory.getLogger(IvfFlatVectorIndexReader.class);
+
+ /** Default nprobe value when not explicitly set. */
+ static final int DEFAULT_NPROBE = 4;
+
+ // Index data loaded from file
+ private final int _dimension;
+ private final int _numVectors;
+ private final int _nlist;
+ private final VectorIndexConfig.VectorDistanceFunction _distanceFunction;
+ private final float[][] _centroids;
+ private final int[][] _listDocIds;
+ private final float[][][] _listVectors;
+ private final String _column;
+
+ /** Number of centroids to probe during search. */
+ private volatile int _nprobe;
+
+ /**
+ * Opens and loads an IVF_FLAT index from disk.
+ *
+ * @param column the column name
+ * @param indexDir the segment index directory
+ * @param config the vector index configuration
+ * @throws RuntimeException if the index file cannot be read or is corrupt
+ */
+ public IvfFlatVectorIndexReader(String column, File indexDir,
VectorIndexConfig config) {
+ _column = column;
+
+ // Initialize nprobe to the default; query-time tuning should use
NprobeAware#setNprobe.
+ int configuredNprobe = DEFAULT_NPROBE;
+
+ File indexFile = SegmentDirectoryPaths.findVectorIndexIndexFile(indexDir,
column);
+ if (indexFile == null || !indexFile.exists()) {
+ throw new IllegalStateException(
+ "Failed to find IVF_FLAT index file for column: " + column + " in
dir: " + indexDir
+ + ". Expected file: " + column +
V1Constants.Indexes.VECTOR_IVF_FLAT_INDEX_FILE_EXTENSION);
+ }
+
+ try (DataInputStream in = new DataInputStream(new
FileInputStream(indexFile))) {
+ // --- Header ---
+ int magic = in.readInt();
+ Preconditions.checkState(magic == IvfFlatVectorIndexCreator.MAGIC,
+ "Invalid IVF_FLAT magic: 0x%s, expected 0x%s",
+ Integer.toHexString(magic),
Integer.toHexString(IvfFlatVectorIndexCreator.MAGIC));
+
+ int version = in.readInt();
+ Preconditions.checkState(version ==
IvfFlatVectorIndexCreator.FORMAT_VERSION,
+ "Unsupported IVF_FLAT format version: %s, expected: %s",
+ version, IvfFlatVectorIndexCreator.FORMAT_VERSION);
+
+ _dimension = in.readInt();
+ _numVectors = in.readInt();
+ _nlist = in.readInt();
+ int distanceFunctionOrdinal = in.readInt();
+ _distanceFunction =
VectorIndexConfig.VectorDistanceFunction.values()[distanceFunctionOrdinal];
Review Comment:
The reader trusts the persisted `distanceFunctionOrdinal` and indexes into
`VectorDistanceFunction.values()` without bounds checking. A corrupt/unknown
value will throw `ArrayIndexOutOfBoundsException` and bypass the intended
validation/error message path. Please validate the ordinal range and fail with
a clear exception; ideally also avoid ordinal-based serialization altogether
(use a stable id/name).
```suggestion
VectorIndexConfig.VectorDistanceFunction[] distanceFunctions =
VectorIndexConfig.VectorDistanceFunction.values();
Preconditions.checkState(distanceFunctionOrdinal >= 0 &&
distanceFunctionOrdinal < distanceFunctions.length,
"Unsupported IVF_FLAT distance function ordinal: %s for column:
%s, file: %s",
distanceFunctionOrdinal, column, indexFile);
_distanceFunction = distanceFunctions[distanceFunctionOrdinal];
```
##########
pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/index/readers/vector/IvfFlatVectorIndexReader.java:
##########
@@ -0,0 +1,337 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.segment.local.segment.index.readers.vector;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import java.io.DataInputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.PriorityQueue;
+import org.apache.pinot.common.function.scalar.VectorFunctions;
+import
org.apache.pinot.segment.local.segment.index.vector.IvfFlatVectorIndexCreator;
+import org.apache.pinot.segment.spi.V1Constants;
+import org.apache.pinot.segment.spi.index.creator.VectorIndexConfig;
+import org.apache.pinot.segment.spi.index.reader.NprobeAware;
+import org.apache.pinot.segment.spi.index.reader.VectorIndexReader;
+import org.apache.pinot.segment.spi.store.SegmentDirectoryPaths;
+import org.roaringbitmap.buffer.MutableRoaringBitmap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Reader for IVF_FLAT (Inverted File with flat vectors) index.
+ *
+ * <p>Loads the entire index into memory at construction time for fast search.
+ * The search algorithm:
+ * <ol>
+ * <li>Computes distance from the query to all centroids.</li>
+ * <li>Selects the {@code nprobe} closest centroids.</li>
+ * <li>Scans all vectors in those centroids' inverted lists.</li>
+ * <li>Returns the top-K doc IDs as a bitmap.</li>
+ * </ol>
+ *
+ * <h3>Thread safety</h3>
+ * <p>This class is thread-safe for concurrent reads. The loaded index data is
immutable
+ * after construction. The only mutable state is {@code _nprobe}, which is
volatile to
+ * allow query-time tuning from another thread. However, the typical pattern is
+ * single-threaded: set nprobe, then call getDocIds.</p>
+ */
+public class IvfFlatVectorIndexReader implements VectorIndexReader,
NprobeAware {
+ private static final Logger LOGGER =
LoggerFactory.getLogger(IvfFlatVectorIndexReader.class);
+
+ /** Default nprobe value when not explicitly set. */
+ static final int DEFAULT_NPROBE = 4;
+
+ // Index data loaded from file
+ private final int _dimension;
+ private final int _numVectors;
+ private final int _nlist;
+ private final VectorIndexConfig.VectorDistanceFunction _distanceFunction;
+ private final float[][] _centroids;
+ private final int[][] _listDocIds;
+ private final float[][][] _listVectors;
+ private final String _column;
+
+ /** Number of centroids to probe during search. */
+ private volatile int _nprobe;
+
+ /**
+ * Opens and loads an IVF_FLAT index from disk.
+ *
+ * @param column the column name
+ * @param indexDir the segment index directory
+ * @param config the vector index configuration
+ * @throws RuntimeException if the index file cannot be read or is corrupt
+ */
+ public IvfFlatVectorIndexReader(String column, File indexDir,
VectorIndexConfig config) {
+ _column = column;
+
+ // Initialize nprobe to the default; query-time tuning should use
NprobeAware#setNprobe.
+ int configuredNprobe = DEFAULT_NPROBE;
+
+ File indexFile = SegmentDirectoryPaths.findVectorIndexIndexFile(indexDir,
column);
+ if (indexFile == null || !indexFile.exists()) {
+ throw new IllegalStateException(
+ "Failed to find IVF_FLAT index file for column: " + column + " in
dir: " + indexDir
+ + ". Expected file: " + column +
V1Constants.Indexes.VECTOR_IVF_FLAT_INDEX_FILE_EXTENSION);
+ }
+
+ try (DataInputStream in = new DataInputStream(new
FileInputStream(indexFile))) {
+ // --- Header ---
+ int magic = in.readInt();
+ Preconditions.checkState(magic == IvfFlatVectorIndexCreator.MAGIC,
+ "Invalid IVF_FLAT magic: 0x%s, expected 0x%s",
+ Integer.toHexString(magic),
Integer.toHexString(IvfFlatVectorIndexCreator.MAGIC));
+
+ int version = in.readInt();
+ Preconditions.checkState(version ==
IvfFlatVectorIndexCreator.FORMAT_VERSION,
+ "Unsupported IVF_FLAT format version: %s, expected: %s",
+ version, IvfFlatVectorIndexCreator.FORMAT_VERSION);
+
+ _dimension = in.readInt();
+ _numVectors = in.readInt();
+ _nlist = in.readInt();
+ int distanceFunctionOrdinal = in.readInt();
+ _distanceFunction =
VectorIndexConfig.VectorDistanceFunction.values()[distanceFunctionOrdinal];
+
+ // Clamp nprobe to valid range
+ _nprobe = Math.min(configuredNprobe, _nlist);
+ if (_nprobe <= 0) {
+ _nprobe = Math.min(DEFAULT_NPROBE, _nlist);
+ }
+
+ // --- Centroids ---
+ _centroids = new float[_nlist][_dimension];
+ for (int c = 0; c < _nlist; c++) {
+ for (int d = 0; d < _dimension; d++) {
+ _centroids[c][d] = in.readFloat();
+ }
+ }
+
+ // --- Inverted Lists ---
+ _listDocIds = new int[_nlist][];
+ _listVectors = new float[_nlist][][];
+
+ for (int c = 0; c < _nlist; c++) {
+ int listSize = in.readInt();
+ _listDocIds[c] = new int[listSize];
+ for (int i = 0; i < listSize; i++) {
+ _listDocIds[c][i] = in.readInt();
+ }
+ _listVectors[c] = new float[listSize][_dimension];
+ for (int i = 0; i < listSize; i++) {
+ for (int d = 0; d < _dimension; d++) {
+ _listVectors[c][i][d] = in.readFloat();
+ }
+ }
+ }
+
+ // We skip reading the offset table and footer since we read sequentially
+
+ LOGGER.info("Loaded IVF_FLAT index for column: {}: {} vectors, {}
centroids, dim={}, nprobe={}, distance={}",
+ column, _numVectors, _nlist, _dimension, _nprobe, _distanceFunction);
+ } catch (IOException e) {
+ throw new RuntimeException(
+ "Failed to load IVF_FLAT index for column: " + column + " from file:
" + indexFile, e);
+ }
+ }
+
+ @Override
+ public MutableRoaringBitmap getDocIds(float[] searchQuery, int topK) {
+ Preconditions.checkArgument(searchQuery.length == _dimension,
+ "Query dimension mismatch: expected %s, got %s", _dimension,
searchQuery.length);
+ Preconditions.checkArgument(topK > 0, "topK must be positive, got: %s",
topK);
+
+ if (_numVectors == 0 || _nlist == 0) {
+ return new MutableRoaringBitmap();
+ }
+
+ int effectiveNprobe = Math.min(_nprobe, _nlist);
+
+ // Step 1: Find the nprobe closest centroids
+ int[] probeCentroids = findClosestCentroids(searchQuery, effectiveNprobe);
+
+ // Step 2: Scan all vectors in the selected inverted lists, maintaining a
max-heap of size topK
+ // Max-heap: the largest distance is at the top, so we can efficiently
evict the worst candidate.
+ int effectiveTopK = Math.min(topK, _numVectors);
+ PriorityQueue<ScoredDoc> maxHeap = new PriorityQueue<>(effectiveTopK,
+ (a, b) -> Float.compare(b._distance, a._distance));
+
+ for (int probeIdx : probeCentroids) {
+ int[] docIds = _listDocIds[probeIdx];
+ float[][] vectors = _listVectors[probeIdx];
+
+ for (int i = 0; i < docIds.length; i++) {
+ float dist = computeDistance(searchQuery, vectors[i]);
+ if (maxHeap.size() < effectiveTopK) {
+ maxHeap.offer(new ScoredDoc(docIds[i], dist));
+ } else if (dist < maxHeap.peek()._distance) {
+ maxHeap.poll();
+ maxHeap.offer(new ScoredDoc(docIds[i], dist));
+ }
+ }
+ }
+
+ // Step 3: Collect results into a bitmap
+ MutableRoaringBitmap result = new MutableRoaringBitmap();
+ for (ScoredDoc doc : maxHeap) {
+ result.add(doc._docId);
+ }
+ return result;
+ }
+
+ /**
+ * Sets the number of centroids to probe during search.
+ * This allows query-time tuning of the recall/speed tradeoff.
+ *
+ * @param nprobe number of centroids to probe (clamped to [1, nlist])
+ *
+ * <p><b>Thread-safety note:</b> This method mutates a volatile field on the
shared reader instance.
+ * In Pinot's query execution model, nprobe is set once per query before
calling getDocIds(),
+ * and each query runs on a single thread per segment. A future improvement
could pass nprobe
+ * as a parameter to getDocIds() to eliminate any cross-query visibility
concern.</p>
+ */
+ public void setNprobe(int nprobe) {
+ _nprobe = Math.max(1, Math.min(nprobe, _nlist));
+ }
Review Comment:
This reader stores `nprobe` in a mutable field (`volatile _nprobe`) and
updates it via `setNprobe()`. Because index readers are shared across
concurrent queries for a segment, this can cause cross-query interference when
different queries set different nprobe values concurrently. Please avoid
query-specific mutable state on the shared reader (e.g., pass nprobe into the
search method, or return a per-query searcher/context object).
##########
pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/index/creator/VectorIndexConfigValidator.java:
##########
@@ -0,0 +1,213 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.segment.spi.index.creator;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+
+/**
+ * Validates {@link VectorIndexConfig} for backend-specific correctness.
+ *
+ * <p>This validator ensures that:
+ * <ul>
+ * <li>Required common fields (vectorDimension, vectorDistanceFunction) are
present and valid.</li>
+ * <li>The vectorIndexType resolves to a known {@link
VectorBackendType}.</li>
+ * <li>Backend-specific properties are valid for the resolved backend
type.</li>
+ * <li>Properties belonging to a different backend are rejected with a clear
error message.</li>
+ * </ul>
+ *
+ * <p>Thread-safe: this class is stateless and all methods are static.</p>
+ */
+public final class VectorIndexConfigValidator {
+
+ // HNSW-specific property keys
+ static final Set<String> HNSW_PROPERTIES = Collections.unmodifiableSet(new
HashSet<>(
+ Arrays.asList("maxCon", "beamWidth", "maxDimensions", "maxBufferSizeMB",
+ "useCompoundFile", "mode", "commit", "commitIntervalMs",
"commitDocs")));
+
+ // IVF_FLAT-specific property keys
+ static final Set<String> IVF_FLAT_PROPERTIES =
Collections.unmodifiableSet(new HashSet<>(
+ Arrays.asList("nlist", "trainSampleSize", "trainingSeed",
"minRowsForIndex")));
+
+ // Common property keys that appear in the properties map (legacy format
stores common fields there too)
+ private static final Set<String> COMMON_PROPERTIES =
Collections.unmodifiableSet(new HashSet<>(
+ Arrays.asList("vectorIndexType", "vectorDimension",
"vectorDistanceFunction", "version")));
+
+ private VectorIndexConfigValidator() {
+ }
+
+ /**
+ * Validates the given {@link VectorIndexConfig} for backend-specific
correctness.
+ *
+ * @param config the config to validate
+ * @throws IllegalArgumentException if validation fails
+ */
+ public static void validate(VectorIndexConfig config) {
+ if (config.isDisabled()) {
+ return;
+ }
+
+ VectorBackendType backendType = resolveBackendType(config);
+ validateCommonFields(config);
+ validateBackendSpecificProperties(config, backendType);
+ }
+
+ /**
+ * Resolves the {@link VectorBackendType} from the config. Defaults to HNSW
if the
+ * vectorIndexType field is null or empty, preserving backward compatibility.
+ *
+ * @param config the config to resolve from
+ * @return the resolved backend type
+ * @throws IllegalArgumentException if the vectorIndexType is not recognized
+ */
+ public static VectorBackendType resolveBackendType(VectorIndexConfig config)
{
+ String typeString = config.getVectorIndexType();
+ if (typeString == null || typeString.isEmpty()) {
+ return VectorBackendType.HNSW;
+ }
+ return VectorBackendType.fromString(typeString);
+ }
+
+ /**
+ * Validates common fields shared across all backend types.
+ */
+ private static void validateCommonFields(VectorIndexConfig config) {
+ if (config.getVectorDimension() <= 0) {
+ throw new IllegalArgumentException(
+ "vectorDimension must be a positive integer, got: " +
config.getVectorDimension());
+ }
+
+ if (config.getVectorDistanceFunction() == null) {
+ throw new IllegalArgumentException("vectorDistanceFunction is required");
+ }
+ }
+
+ /**
+ * Validates that the properties map only contains keys valid for the
resolved backend type,
+ * and that backend-specific property values are within acceptable ranges.
+ */
+ private static void validateBackendSpecificProperties(VectorIndexConfig
config, VectorBackendType backendType) {
+ Map<String, String> properties = config.getProperties();
+ if (properties == null || properties.isEmpty()) {
+ return;
+ }
+
+ switch (backendType) {
+ case HNSW:
+ validateNoForeignProperties(properties, HNSW_PROPERTIES,
IVF_FLAT_PROPERTIES, "HNSW", "IVF_FLAT");
+ validateHnswProperties(properties);
+ break;
+ case IVF_FLAT:
+ validateNoForeignProperties(properties, IVF_FLAT_PROPERTIES,
HNSW_PROPERTIES, "IVF_FLAT", "HNSW");
+ validateIvfFlatProperties(properties);
+ break;
+ default:
+ throw new IllegalArgumentException("Unsupported vector backend type: "
+ backendType);
+ }
+ }
+
+ /**
+ * Ensures that properties belonging to a foreign backend are not present.
+ * Note: this only rejects known foreign-backend keys; arbitrary unknown
keys are allowed
+ * to support forward-compatible extensibility.
+ */
+ private static void validateNoForeignProperties(Map<String, String>
properties,
+ Set<String> ownProperties, Set<String> foreignProperties,
+ String ownType, String foreignType) {
Review Comment:
`validateNoForeignProperties(...)` takes an `ownProperties` parameter but
never uses it. This is confusing and may trip unused-parameter/static-analysis
rules. Consider removing the parameter (or using it, e.g., to optionally
validate/whitelist known keys) to keep the API and implementation consistent.
##########
pinot-core/src/main/java/org/apache/pinot/core/operator/filter/VectorSimilarityFilterOperator.java:
##########
@@ -120,6 +157,106 @@ protected void explainAttributes(ExplainAttributeBuilder
attributeBuilder) {
attributeBuilder.putString("vectorIdentifier",
_predicate.getLhs().getIdentifier());
attributeBuilder.putString("vectorLiteral",
Arrays.toString(_predicate.getValue()));
attributeBuilder.putLongIdempotent("topKtoSearch", _predicate.getTopK());
+ if (_searchParams.isExactRerank()) {
+ attributeBuilder.putString("exactRerank", "true");
+ }
+ }
+
+ /**
+ * Executes the vector search with backend-specific parameter dispatch and
optional rerank.
+ */
+ private ImmutableRoaringBitmap executeSearch() {
+ String column = _predicate.getLhs().getIdentifier();
+ float[] queryVector = _predicate.getValue();
+ int topK = _predicate.getTopK();
+
+ // 1. Configure backend-specific parameters via interfaces
+ configureBackendParams(column);
+
+ // 2. Determine effective search count (higher if rerank is enabled)
+ int searchCount = topK;
+ if (_searchParams.isExactRerank()) {
+ searchCount = _searchParams.getEffectiveMaxCandidates(topK);
+ }
+
+ // 3. Execute ANN search
+ ImmutableRoaringBitmap annResults =
_vectorIndexReader.getDocIds(queryVector, searchCount);
+ int annCandidateCount = annResults.getCardinality();
+
+ LOGGER.debug("Vector search on column: {}, backend: {}, topK: {},
searchCount: {}, annCandidates: {}",
+ column, getBackendName(), topK, searchCount, annCandidateCount);
+
+ // 4. Apply exact rerank if requested
+ if (_searchParams.isExactRerank() && _forwardIndexReader != null &&
annCandidateCount > 0) {
+ ImmutableRoaringBitmap reranked = applyExactRerank(annResults,
queryVector, topK, column);
+ LOGGER.debug("Exact rerank on column: {}, candidates: {} -> final: {}",
+ column, annCandidateCount, reranked.getCardinality());
+ return reranked;
+ }
+
+ return annResults;
+ }
+
+ /**
+ * Configures backend-specific search parameters on the reader if it
supports them.
+ */
+ private void configureBackendParams(String column) {
+ // Set nprobe on IVF_FLAT readers
+ if (_vectorIndexReader instanceof NprobeAware) {
+ int nprobe = _searchParams.getNprobe();
+ ((NprobeAware) _vectorIndexReader).setNprobe(nprobe);
+ LOGGER.debug("Set nprobe={} on IVF_FLAT reader for column: {}", nprobe,
column);
+ }
Review Comment:
`configureBackendParams()` mutates shared `VectorIndexReader` state via
`NprobeAware#setNprobe`. Index readers are created once per segment and shared
across concurrent queries, so two queries on the same segment with different
`vectorNprobe` values can race and affect each other’s results. Please avoid
per-query mutable state on the shared reader (e.g., pass nprobe into the search
call, create a per-query search context/object, or ensure nprobe is stored in a
query-local structure rather than a field on the reader).
##########
pinot-core/src/main/java/org/apache/pinot/core/operator/filter/ExactVectorScanFilterOperator.java:
##########
@@ -0,0 +1,223 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.filter;
+
+import com.google.common.base.CaseFormat;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.PriorityQueue;
+import org.apache.pinot.common.function.scalar.VectorFunctions;
+import
org.apache.pinot.common.request.context.predicate.VectorSimilarityPredicate;
+import org.apache.pinot.core.common.BlockDocIdSet;
+import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.core.operator.ExplainAttributeBuilder;
+import org.apache.pinot.core.operator.docidsets.BitmapDocIdSet;
+import org.apache.pinot.segment.spi.index.reader.ForwardIndexReader;
+import org.apache.pinot.segment.spi.index.reader.ForwardIndexReaderContext;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.apache.pinot.spi.trace.FilterType;
+import org.apache.pinot.spi.trace.InvocationRecording;
+import org.apache.pinot.spi.trace.Tracing;
+import org.roaringbitmap.buffer.ImmutableRoaringBitmap;
+import org.roaringbitmap.buffer.MutableRoaringBitmap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Fallback operator that performs exact brute-force vector similarity search
by scanning the forward index.
+ *
+ * <p>This operator is used when no ANN vector index exists on a segment for
the target column
+ * (e.g., the segment was built before the vector index was added, or the
index type is not
+ * supported). It reads all vectors from the forward index, computes exact
distances to the
+ * query vector, and returns the top-K closest document IDs.</p>
+ *
+ * <p>The distance computation uses L2 (Euclidean) squared distance. For
COSINE similarity,
+ * vectors should be pre-normalized. This matches the behavior of Lucene's
HNSW implementation.</p>
+ *
+ * <p>This operator is intentionally simple and correct rather than fast -- it
is a safety net.
+ * A warning is logged when this operator is used because it scans all
documents in the segment.</p>
+ *
+ * <p>This class is thread-safe for single-threaded execution per query (same
as other filter operators).</p>
+ */
+public class ExactVectorScanFilterOperator extends BaseFilterOperator {
+ private static final Logger LOGGER =
LoggerFactory.getLogger(ExactVectorScanFilterOperator.class);
+ private static final String EXPLAIN_NAME = "VECTOR_SIMILARITY_EXACT_SCAN";
+
+ private final ForwardIndexReader<?> _forwardIndexReader;
+ private final VectorSimilarityPredicate _predicate;
+ private final String _column;
+ private ImmutableRoaringBitmap _matches;
+
+ /**
+ * Creates an exact scan operator.
+ *
+ * @param forwardIndexReader the forward index reader for the vector column
+ * @param predicate the vector similarity predicate containing query vector
and top-K
+ * @param column the column name (for logging and explain)
+ * @param numDocs the total number of documents in the segment
+ */
+ public ExactVectorScanFilterOperator(ForwardIndexReader<?>
forwardIndexReader,
+ VectorSimilarityPredicate predicate, String column, int numDocs) {
+ super(numDocs, false);
+ _forwardIndexReader = forwardIndexReader;
+ _predicate = predicate;
+ _column = column;
+ }
+
+ @Override
+ protected BlockDocIdSet getTrues() {
+ if (_matches == null) {
+ _matches = computeExactTopK();
+ }
+ return new BitmapDocIdSet(_matches, _numDocs);
+ }
+
+ @Override
+ public int getNumMatchingDocs() {
+ if (_matches == null) {
+ _matches = computeExactTopK();
+ }
+ return _matches.getCardinality();
+ }
+
+ @Override
+ public boolean canProduceBitmaps() {
+ return true;
+ }
+
+ @Override
+ public BitmapCollection getBitmaps() {
+ if (_matches == null) {
+ _matches = computeExactTopK();
+ }
+ record(_matches);
+ return new BitmapCollection(_numDocs, false, _matches);
+ }
+
+ @Override
+ public List<Operator> getChildOperators() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public String toExplainString() {
+ return EXPLAIN_NAME + "(indexLookUp:exact_scan"
+ + ", operator:" + _predicate.getType()
+ + ", vector identifier:" + _column
+ + ", vector literal:" + Arrays.toString(_predicate.getValue())
+ + ", topK to search:" + _predicate.getTopK()
+ + ')';
+ }
+
+ @Override
+ protected String getExplainName() {
+ return CaseFormat.UPPER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL,
EXPLAIN_NAME);
+ }
+
+ @Override
+ protected void explainAttributes(ExplainAttributeBuilder attributeBuilder) {
+ super.explainAttributes(attributeBuilder);
+ attributeBuilder.putString("indexLookUp", "exact_scan");
+ attributeBuilder.putString("operator", _predicate.getType().name());
+ attributeBuilder.putString("vectorIdentifier", _column);
+ attributeBuilder.putString("vectorLiteral",
Arrays.toString(_predicate.getValue()));
+ attributeBuilder.putLongIdempotent("topKtoSearch", _predicate.getTopK());
+ }
+
+ /**
+ * Performs brute-force exact search over all documents in the segment.
+ * Uses a max-heap to maintain the top-K closest vectors.
+ */
+ @SuppressWarnings("unchecked")
+ private ImmutableRoaringBitmap computeExactTopK() {
+ LOGGER.warn("Performing exact vector scan fallback on column: {} for
segment with {} docs. "
+ + "This is expensive -- consider adding a vector index.", _column,
_numDocs);
+
+ float[] queryVector = _predicate.getValue();
+ int topK = _predicate.getTopK();
+
+ // Max-heap: entry with largest distance is at the top so we can
efficiently evict it
+ PriorityQueue<DocDistance> maxHeap = new PriorityQueue<>(topK + 1,
+ (a, b) -> Float.compare(b._distance, a._distance));
+
+ ForwardIndexReader rawReader = _forwardIndexReader;
+ try (ForwardIndexReaderContext context = rawReader.createContext()) {
+ for (int docId = 0; docId < _numDocs; docId++) {
+ float[] docVector = rawReader.getFloatMV(docId, context);
+ if (docVector == null || docVector.length == 0) {
+ continue;
+ }
+ float distance = computeL2SquaredDistance(queryVector, docVector);
+ if (maxHeap.size() < topK) {
+ maxHeap.add(new DocDistance(docId, distance));
Review Comment:
`ExactVectorScanFilterOperator` always ranks by L2-squared distance (see
`computeL2SquaredDistance` usage). For segments configured with COSINE /
INNER_PRODUCT / DOT_PRODUCT distance functions, this exact-scan fallback will
return a different topK than the vector index would, changing query semantics
when the index is missing. Please compute exact distances using the segment’s
configured vector distance function (keep L2-squared only for EUCLIDEAN/L2).
##########
pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/index/converter/SegmentV1V2ToV3FormatConverter.java:
##########
@@ -263,15 +263,15 @@ public boolean accept(File dir, String name) {
private void copyVectorIndexIfExists(File segmentDirectory, File v3Dir)
throws IOException {
- // TODO: see if this can be done by reusing some existing methods
- String suffix = V1Constants.Indexes.VECTOR_V912_HNSW_INDEX_FILE_EXTENSION;
- File[] vectorIndexFiles = segmentDirectory.listFiles(new FilenameFilter() {
+ // Copy HNSW index directories (Lucene-based, stored as directories)
+ String hnswSuffix =
V1Constants.Indexes.VECTOR_V912_HNSW_INDEX_FILE_EXTENSION;
+ File[] hnswIndexFiles = segmentDirectory.listFiles(new FilenameFilter() {
@Override
public boolean accept(File dir, String name) {
- return name.endsWith(suffix);
+ return name.endsWith(hnswSuffix);
}
});
- for (File vectorIndexFile : vectorIndexFiles) {
+ for (File vectorIndexFile : hnswIndexFiles) {
File[] indexFiles = vectorIndexFile.listFiles();
Review Comment:
`segmentDirectory.listFiles(...)` can return null (I/O error or not a
directory). `hnswIndexFiles` is iterated without a null-check, which can cause
an NPE during segment conversion. Please add a null guard (as done for
`ivfFlatIndexFiles`) before iterating. The nested `vectorIndexFile.listFiles()`
call can likewise return null, so its result should be guarded as well.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]