mikemccand commented on code in PR #14178: URL: https://github.com/apache/lucene/pull/14178#discussion_r2105066032
########## lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsFormat.java: ########## @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import java.io.IOException; +import java.util.Locale; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; + +/** + * A format which uses <a href="https://github.com/facebookresearch/faiss">Faiss</a> to create and + * search vector indexes, using {@link LibFaissC} to interact with the native library. + * + * <p>A separate Faiss index is created per-segment, and uses the following files: + * + * <ul> + * <li><code>.faissm</code> (metadata file): stores field number, offset and length of actual + * Faiss index in data file. 
+ * <li><code>.faissd</code> (data file): stores concatenated Faiss indexes for all fields. + * <li>All files required by {@link Lucene99FlatVectorsFormat} for storing raw vectors. + * </ul> + * + * <p>Note: Set the {@code $OMP_NUM_THREADS} environment variable to control internal threading. Review Comment: Maybe link out to https://github.com/facebookresearch/faiss/wiki/Threads-and-asynchronous-calls for more details? ########## lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsFormat.java: ########## @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.sandbox.codecs.faiss; + +import java.io.IOException; +import java.util.Locale; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; + +/** + * A format which uses <a href="https://github.com/facebookresearch/faiss">Faiss</a> to create and + * search vector indexes, using {@link LibFaissC} to interact with the native library. + * + * <p>A separate Faiss index is created per-segment, and uses the following files: + * + * <ul> + * <li><code>.faissm</code> (metadata file): stores field number, offset and length of actual + * Faiss index in data file. + * <li><code>.faissd</code> (data file): stores concatenated Faiss indexes for all fields. + * <li>All files required by {@link Lucene99FlatVectorsFormat} for storing raw vectors. + * </ul> + * + * <p>Note: Set the {@code $OMP_NUM_THREADS} environment variable to control internal threading. + * + * @lucene.experimental Review Comment: Could you also add a sentence making it clear there is no promise of backwards compatibility here? That said, does Faiss have any promise? If I write a Faiss HNSW graph with version X, and then upgrade Faiss to version X+1, can the HNSW graph be read/written? ########## lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java: ########## @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_CURRENT; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.createIndex; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexWrite; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import 
org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSet; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.hnsw.IntToIntFunction; + +/** + * Write per-segment Faiss indexes and associated metadata. + * + * @lucene.experimental + */ +final class FaissKnnVectorsWriter extends KnnVectorsWriter { + private final String description, indexParams; + private final FlatVectorsWriter rawVectorsWriter; + private final IndexOutput meta, data; + private final Map<FieldInfo, FlatFieldVectorsWriter<?>> rawFields; + private boolean closed, finished; + + public FaissKnnVectorsWriter( + String description, + String indexParams, + SegmentWriteState state, + FlatVectorsWriter rawVectorsWriter) + throws IOException { + + this.description = description; + this.indexParams = indexParams; + this.rawVectorsWriter = rawVectorsWriter; + this.rawFields = new HashMap<>(); + this.closed = false; + this.finished = false; + + boolean failure = true; Review Comment: Hmm there is another PR (maybe merged already?) that is trying to clean up this pattern (though, it's usually `success` not `failure` lol) -- maybe clean this up the same way that PR did? ########## lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsFormat.java: ########## @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import java.io.IOException; +import java.util.Locale; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; + +/** + * A format which uses <a href="https://github.com/facebookresearch/faiss">Faiss</a> to create and + * search vector indexes, using {@link LibFaissC} to interact with the native library. + * + * <p>A separate Faiss index is created per-segment, and uses the following files: + * + * <ul> + * <li><code>.faissm</code> (metadata file): stores field number, offset and length of actual + * Faiss index in data file. + * <li><code>.faissd</code> (data file): stores concatenated Faiss indexes for all fields. + * <li>All files required by {@link Lucene99FlatVectorsFormat} for storing raw vectors. + * </ul> + * + * <p>Note: Set the {@code $OMP_NUM_THREADS} environment variable to control internal threading. 
+ * + * @lucene.experimental + */ +public final class FaissKnnVectorsFormat extends KnnVectorsFormat { + public static final String NAME = FaissKnnVectorsFormat.class.getSimpleName(); + static final int VERSION_START = 0; + static final int VERSION_CURRENT = VERSION_START; + static final String META_CODEC_NAME = NAME + "Meta"; + static final String DATA_CODEC_NAME = NAME + "Data"; + static final String META_EXTENSION = "faissm"; + static final String DATA_EXTENSION = "faissd"; + + private final String description; + private final String indexParams; + private final FlatVectorsFormat rawVectorsFormat; + + public FaissKnnVectorsFormat() { Review Comment: Could you add a comment somewhere or maybe a `README` with rough high-level instructions for how to build this? E.g. install Faiss 1.11.0 or newer, typically via pytorch or conda-forge channel from XYZ, then set this gradle property or so, then `./gradlew such-and-such` at root level, etc? ########## lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsReader.java: ########## @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_CURRENT; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_START; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexRead; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexSearch; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.util.HashMap; +import java.util.Map; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileTypeHint; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.IOUtils; + +/** + * Read per-segment Faiss indexes and associated metadata. 
+ * + * @lucene.experimental + */ +final class FaissKnnVectorsReader extends KnnVectorsReader { + private final FlatVectorsReader rawVectorsReader; + private final IndexInput meta, data; + private final Map<String, IndexEntry> indexMap; + private final Arena arena; + private boolean closed; + + public FaissKnnVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) + throws IOException { + this.rawVectorsReader = rawVectorsReader; + this.indexMap = new HashMap<>(); + this.arena = Arena.ofShared(); + this.closed = false; + + boolean failure = true; + try { + meta = + openInput( + state, + META_EXTENSION, + META_CODEC_NAME, + VERSION_START, + VERSION_CURRENT, + state.context); + data = + openInput( + state, + DATA_EXTENSION, + DATA_CODEC_NAME, + VERSION_START, + VERSION_CURRENT, + state.context.withHints(FileTypeHint.DATA, DataAccessHint.RANDOM)); + + Map.Entry<String, IndexEntry> entry; + while ((entry = parseNextField(state)) != null) { + this.indexMap.put(entry.getKey(), entry.getValue()); + } + + failure = false; + } finally { + if (failure) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @SuppressWarnings("SameParameterValue") + private IndexInput openInput( + SegmentReadState state, + String extension, + String codecName, + int versionStart, + int versionEnd, + IOContext context) + throws IOException { + + String fileName = + IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, extension); + IndexInput input = state.directory.openInput(fileName, context); + CodecUtil.checkIndexHeader( + input, codecName, versionStart, versionEnd, state.segmentInfo.getId(), state.segmentSuffix); + return input; + } + + private Map.Entry<String, IndexEntry> parseNextField(SegmentReadState state) throws IOException { + int fieldNumber = meta.readInt(); + if (fieldNumber == -1) { + return null; + } + + FieldInfo fieldInfo = state.fieldInfos.fieldInfo(fieldNumber); + if (fieldInfo == null) { + throw new 
IllegalStateException("Invalid field"); Review Comment: Include the `fieldNumber` in the exception message? ########## lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java: ########## @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static org.apache.lucene.index.VectorEncoding.FLOAT32; +import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; + +import java.io.IOException; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.junit.BeforeClass; +import org.junit.Ignore; + +/** + * Tests for {@link FaissKnnVectorsFormat}. 
Will run only if required shared libraries (including + * dependencies) are present at runtime, or the {@value #FAISS_RUN_TESTS} JVM arg is set to {@code + * true} + */ +public class TestFaissKnnVectorsFormat extends BaseKnnVectorsFormatTestCase { + private static final String FAISS_RUN_TESTS = "tests.faiss.run"; + + private static final VectorEncoding[] SUPPORTED_ENCODINGS = {FLOAT32}; + private static final VectorSimilarityFunction[] SUPPORTED_FUNCTIONS = {DOT_PRODUCT, EUCLIDEAN}; + + @BeforeClass + public static void maybeSuppress() throws ClassNotFoundException { + // Explicitly run tests + if (Boolean.getBoolean(FAISS_RUN_TESTS)) { + return; + } + + // Otherwise check if dependencies are present + boolean faissLibraryPresent; + try { + Class.forName("org.apache.lucene.sandbox.codecs.faiss.LibFaissC"); + faissLibraryPresent = true; + } catch (UnsatisfiedLinkError _) { + faissLibraryPresent = false; + } + assumeTrue("Native libraries present", faissLibraryPresent); + } + + @Override + protected VectorEncoding randomVectorEncoding() { + return SUPPORTED_ENCODINGS[random().nextInt(SUPPORTED_ENCODINGS.length)]; + } + + @Override + protected VectorSimilarityFunction randomSimilarity() { + return SUPPORTED_FUNCTIONS[random().nextInt(SUPPORTED_FUNCTIONS.length)]; + } + + @Override Review Comment: Does base class have a test that stresses creating/destroying many segments with KNN vectors? If not, maybe add one? Let's try to gain some confidence that we're not leaking RAM/HEAP/objects/file descriptors/files on disk? ########## gradle/testing/defaults-tests.gradle: ########## @@ -145,6 +145,7 @@ allprojects { ':lucene:core', ':lucene:codecs', ":lucene:distribution.tests", + ':lucene:sandbox', Review Comment: Hmm was this a pre-existing issue (not running any tests in `sandbox` module when running a root `./gradlew test`)? 
########## lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsReader.java: ########## @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_CURRENT; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_START; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexRead; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexSearch; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.util.HashMap; +import java.util.Map; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import 
org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileTypeHint; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.IOUtils; + +/** + * Read per-segment Faiss indexes and associated metadata. + * + * @lucene.experimental + */ +final class FaissKnnVectorsReader extends KnnVectorsReader { + private final FlatVectorsReader rawVectorsReader; + private final IndexInput meta, data; + private final Map<String, IndexEntry> indexMap; + private final Arena arena; + private boolean closed; + + public FaissKnnVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) + throws IOException { + this.rawVectorsReader = rawVectorsReader; + this.indexMap = new HashMap<>(); + this.arena = Arena.ofShared(); + this.closed = false; + + boolean failure = true; + try { + meta = + openInput( + state, + META_EXTENSION, + META_CODEC_NAME, + VERSION_START, + VERSION_CURRENT, + state.context); + data = + openInput( + state, + DATA_EXTENSION, + DATA_CODEC_NAME, + VERSION_START, + VERSION_CURRENT, + state.context.withHints(FileTypeHint.DATA, DataAccessHint.RANDOM)); + + Map.Entry<String, IndexEntry> entry; + while ((entry = parseNextField(state)) != null) { + this.indexMap.put(entry.getKey(), entry.getValue()); + } + + failure = false; + } finally { + if (failure) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @SuppressWarnings("SameParameterValue") + private IndexInput openInput( + SegmentReadState state, + String extension, + String codecName, 
+ int versionStart, + int versionEnd, + IOContext context) + throws IOException { + + String fileName = + IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, extension); + IndexInput input = state.directory.openInput(fileName, context); + CodecUtil.checkIndexHeader( + input, codecName, versionStart, versionEnd, state.segmentInfo.getId(), state.segmentSuffix); + return input; + } + + private Map.Entry<String, IndexEntry> parseNextField(SegmentReadState state) throws IOException { + int fieldNumber = meta.readInt(); + if (fieldNumber == -1) { + return null; + } + + FieldInfo fieldInfo = state.fieldInfos.fieldInfo(fieldNumber); + if (fieldInfo == null) { + throw new IllegalStateException("Invalid field"); + } + + long dataOffset = meta.readLong(); + long dataLength = meta.readLong(); + + // See flags defined in c_api/index_io_c.h + int ioFlags = 3; Review Comment: Hmm can you use the flag names? Maybe 3 `FAISS_IO_FLAG_MMAP | FAISS_IO_FLAG_READ_ONLY`? We'd have to define these flags here in javaland, and add a comment pointing back to the C header where they are originally defined. ########## lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsReader.java: ########## @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_CURRENT; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_START; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexRead; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexSearch; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.util.HashMap; +import java.util.Map; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileTypeHint; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.IOUtils; + +/** + * Read per-segment Faiss indexes and associated metadata. 
+ * + * @lucene.experimental + */ +final class FaissKnnVectorsReader extends KnnVectorsReader { + private final FlatVectorsReader rawVectorsReader; + private final IndexInput meta, data; + private final Map<String, IndexEntry> indexMap; + private final Arena arena; + private boolean closed; + + public FaissKnnVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) + throws IOException { + this.rawVectorsReader = rawVectorsReader; + this.indexMap = new HashMap<>(); + this.arena = Arena.ofShared(); + this.closed = false; + + boolean failure = true; + try { + meta = + openInput( + state, + META_EXTENSION, + META_CODEC_NAME, + VERSION_START, + VERSION_CURRENT, + state.context); + data = + openInput( + state, + DATA_EXTENSION, + DATA_CODEC_NAME, + VERSION_START, + VERSION_CURRENT, + state.context.withHints(FileTypeHint.DATA, DataAccessHint.RANDOM)); + + Map.Entry<String, IndexEntry> entry; + while ((entry = parseNextField(state)) != null) { + this.indexMap.put(entry.getKey(), entry.getValue()); + } + + failure = false; + } finally { + if (failure) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @SuppressWarnings("SameParameterValue") + private IndexInput openInput( + SegmentReadState state, + String extension, + String codecName, + int versionStart, + int versionEnd, + IOContext context) + throws IOException { + + String fileName = + IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, extension); + IndexInput input = state.directory.openInput(fileName, context); + CodecUtil.checkIndexHeader( + input, codecName, versionStart, versionEnd, state.segmentInfo.getId(), state.segmentSuffix); + return input; + } + + private Map.Entry<String, IndexEntry> parseNextField(SegmentReadState state) throws IOException { + int fieldNumber = meta.readInt(); + if (fieldNumber == -1) { + return null; + } + + FieldInfo fieldInfo = state.fieldInfos.fieldInfo(fieldNumber); + if (fieldInfo == null) { + throw new 
IllegalStateException("Invalid field"); + } + + long dataOffset = meta.readLong(); + long dataLength = meta.readLong(); + + // See flags defined in c_api/index_io_c.h + int ioFlags = 3; + + // Read index into memory + MemorySegment indexPointer = + indexRead(data.slice(fieldInfo.name, dataOffset, dataLength), ioFlags) + // Ensure timely cleanup + .reinterpret(arena, LibFaissC::freeIndex); + + return Map.entry( + fieldInfo.name, new IndexEntry(indexPointer, fieldInfo.getVectorSimilarityFunction())); + } + + @Override + public void checkIntegrity() throws IOException { Review Comment: Have you tried running Lucene's `CheckIndex` command-line tool on a Faiss index to confirm it's happy? Does Faiss have any API to "check integrity" that we could call here? ########## .github/workflows/run-special-checks-sandbox.yml: ########## @@ -0,0 +1,51 @@ +name: "Run special checks: module lucene/sandbox" + Review Comment: Hmm add a `CHANGES.txt` entry? I think this can be safely backported to 10.x, after we bake for a week or two in `main`? It's sandbox, completely separate from everything else, experimental ... ########## lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsReader.java: ########## @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_CURRENT; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_START; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexRead; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexSearch; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.util.HashMap; +import java.util.Map; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileTypeHint; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.IOUtils; + +/** + * Read per-segment Faiss indexes and associated metadata. 
+ * + * @lucene.experimental + */ +final class FaissKnnVectorsReader extends KnnVectorsReader { + private final FlatVectorsReader rawVectorsReader; + private final IndexInput meta, data; + private final Map<String, IndexEntry> indexMap; + private final Arena arena; + private boolean closed; + + public FaissKnnVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) + throws IOException { + this.rawVectorsReader = rawVectorsReader; + this.indexMap = new HashMap<>(); + this.arena = Arena.ofShared(); + this.closed = false; + + boolean failure = true; + try { + meta = + openInput( + state, + META_EXTENSION, + META_CODEC_NAME, + VERSION_START, + VERSION_CURRENT, + state.context); + data = + openInput( + state, + DATA_EXTENSION, + DATA_CODEC_NAME, + VERSION_START, + VERSION_CURRENT, + state.context.withHints(FileTypeHint.DATA, DataAccessHint.RANDOM)); + + Map.Entry<String, IndexEntry> entry; + while ((entry = parseNextField(state)) != null) { + this.indexMap.put(entry.getKey(), entry.getValue()); + } + + failure = false; + } finally { + if (failure) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @SuppressWarnings("SameParameterValue") + private IndexInput openInput( + SegmentReadState state, + String extension, + String codecName, + int versionStart, + int versionEnd, + IOContext context) + throws IOException { + + String fileName = + IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, extension); + IndexInput input = state.directory.openInput(fileName, context); + CodecUtil.checkIndexHeader( + input, codecName, versionStart, versionEnd, state.segmentInfo.getId(), state.segmentSuffix); + return input; + } + + private Map.Entry<String, IndexEntry> parseNextField(SegmentReadState state) throws IOException { + int fieldNumber = meta.readInt(); + if (fieldNumber == -1) { + return null; + } + + FieldInfo fieldInfo = state.fieldInfos.fieldInfo(fieldNumber); + if (fieldInfo == null) { + throw new 
IllegalStateException("Invalid field"); + } + + long dataOffset = meta.readLong(); + long dataLength = meta.readLong(); + + // See flags defined in c_api/index_io_c.h + int ioFlags = 3; + + // Read index into memory + MemorySegment indexPointer = + indexRead(data.slice(fieldInfo.name, dataOffset, dataLength), ioFlags) + // Ensure timely cleanup + .reinterpret(arena, LibFaissC::freeIndex); + + return Map.entry( + fieldInfo.name, new IndexEntry(indexPointer, fieldInfo.getVectorSimilarityFunction())); + } + + @Override + public void checkIntegrity() throws IOException { + rawVectorsReader.checkIntegrity(); + CodecUtil.checksumEntireFile(meta); + CodecUtil.checksumEntireFile(data); + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + return rawVectorsReader.getFloatVectorValues(field); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) { + // TODO: Support using SQ8 quantization, see: + // - https://github.com/opensearch-project/k-NN/pull/2425 + throw new UnsupportedOperationException("Byte vectors not supported"); + } + + @Override + public void search(String field, float[] vector, KnnCollector knnCollector, Bits acceptDocs) { + IndexEntry entry = indexMap.get(field); + if (entry != null) { + indexSearch(entry.indexPointer, entry.function, vector, knnCollector, acceptDocs); + } Review Comment: Hmm, else should we throw an exception? Or are other `KnnVectorsReader` also quietly lenient too? Oh, hmm, I guess one segment may not have the field and others do (in the sparse case -- not all docs have vectors), so we probably cannot throw an exception if we don't recognize the field? ########## lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java: ########## @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_CURRENT; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.createIndex; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexWrite; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import 
org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSet; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.hnsw.IntToIntFunction; + +/** + * Write per-segment Faiss indexes and associated metadata. + * + * @lucene.experimental + */ +final class FaissKnnVectorsWriter extends KnnVectorsWriter { + private final String description, indexParams; + private final FlatVectorsWriter rawVectorsWriter; + private final IndexOutput meta, data; + private final Map<FieldInfo, FlatFieldVectorsWriter<?>> rawFields; + private boolean closed, finished; + + public FaissKnnVectorsWriter( + String description, + String indexParams, + SegmentWriteState state, + FlatVectorsWriter rawVectorsWriter) + throws IOException { + + this.description = description; + this.indexParams = indexParams; + this.rawVectorsWriter = rawVectorsWriter; + this.rawFields = new HashMap<>(); + this.closed = false; + this.finished = false; + + boolean failure = true; + try { + this.meta = openOutput(state, META_EXTENSION, META_CODEC_NAME); + this.data = openOutput(state, DATA_EXTENSION, DATA_CODEC_NAME); + failure = false; + } finally { + if (failure) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + private IndexOutput openOutput(SegmentWriteState state, String extension, String codecName) + throws IOException { + String fileName = + IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, extension); + IndexOutput output = state.directory.createOutput(fileName, state.context); + CodecUtil.writeIndexHeader( + output, codecName, VERSION_CURRENT, state.segmentInfo.getId(), state.segmentSuffix); + return output; + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + 
rawVectorsWriter.mergeOneField(fieldInfo, mergeState); + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> + // TODO: Support using SQ8 quantization, see: + // - https://github.com/opensearch-project/k-NN/pull/2425 + throw new UnsupportedOperationException("Byte vectors not supported"); + case FLOAT32 -> { + FloatVectorValues merged = + KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + writeFloatField(fieldInfo, merged, doc -> doc); + } + } + } + + @Override + public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException { + FlatFieldVectorsWriter<?> rawFieldVectorsWriter = rawVectorsWriter.addField(fieldInfo); + rawFields.put(fieldInfo, rawFieldVectorsWriter); + return rawFieldVectorsWriter; + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { Review Comment: Does Faiss also not need the full precision vectors at search time when using quantization? Does it have a rerank phase where it uses the original vectors? Is that phase optional? If so, when we finally find a clean way to not replicate full precision vectors to searchers (during nrt segment replication -- there is an issue open for this), we could do something for Faiss as well? ########## lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/package-info.java: ########## @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Provides a Faiss-based vector codec via {@link + * org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat}. + * + * @lucene.experimental + */ +package org.apache.lucene.sandbox.codecs.faiss; Review Comment: Oh this is maybe the right place to put some nice docs about how to build this? ########## lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/LibFaissC.java: ########## @@ -0,0 +1,493 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static java.lang.foreign.ValueLayout.ADDRESS; +import static java.lang.foreign.ValueLayout.JAVA_FLOAT; +import static java.lang.foreign.ValueLayout.JAVA_INT; +import static java.lang.foreign.ValueLayout.JAVA_LONG; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.FunctionDescriptor; +import java.lang.foreign.Linker; +import java.lang.foreign.MemoryLayout; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SymbolLookup; +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.nio.LongBuffer; +import java.util.Locale; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.hnsw.IntToIntFunction; + +/** + * Utility class to wrap necessary functions of the native C_API of Faiss using <a + * href="https://openjdk.org/projects/panama">Project Panama</a> (<a + * href="https://anaconda.org/pytorch/faiss-cpu">install from Conda</a> or build using <a + * href="https://github.com/facebookresearch/faiss/blob/main/c_api/INSTALL.md">this guide</a> and + * add to runtime along with all dependencies). + * + * <p>Important Note: When installing from Conda, ensure that the license of the distribution and + * channels being used is applicable to you! 
+ * + * @lucene.experimental + */ +final class LibFaissC { + // TODO: Use vectorized version where available + public static final String LIBRARY_NAME = "faiss_c"; + public static final String LIBRARY_VERSION = "1.11.0"; + + static { + System.loadLibrary(LIBRARY_NAME); + checkLibraryVersion(); + } + + private LibFaissC() {} + + @SuppressWarnings("SameParameterValue") + private static MemorySegment getUpcallStub( + Arena arena, MethodHandle target, MemoryLayout resLayout, MemoryLayout... argLayouts) { + return Linker.nativeLinker() + .upcallStub(target, FunctionDescriptor.of(resLayout, argLayouts), arena); + } + + private static MethodHandle getDowncallHandle( + String functionName, MemoryLayout resLayout, MemoryLayout... argLayouts) { + return Linker.nativeLinker() + .downcallHandle( + SymbolLookup.loaderLookup().find(functionName).orElseThrow(), + FunctionDescriptor.of(resLayout, argLayouts)); + } + + private static void checkLibraryVersion() { + MethodHandle getVersion = getDowncallHandle("faiss_get_version", ADDRESS); + String actualVersion = callAndGetString(getVersion); + if (LIBRARY_VERSION.equals(actualVersion) == false) { + throw new UnsupportedOperationException( + String.format( + Locale.ROOT, + "Expected Faiss library version %s, found %s", + LIBRARY_VERSION, + actualVersion)); + } + } + + private static final MethodHandle FREE_INDEX = + getDowncallHandle("faiss_Index_free", JAVA_INT, ADDRESS); + + public static void freeIndex(MemorySegment indexPointer) { + callAndHandleError(FREE_INDEX, indexPointer); + } + + private static final MethodHandle FREE_CUSTOM_IO_WRITER = + getDowncallHandle("faiss_CustomIOWriter_free", JAVA_INT, ADDRESS); + + public static void freeCustomIOWriter(MemorySegment customIOWriterPointer) { + callAndHandleError(FREE_CUSTOM_IO_WRITER, customIOWriterPointer); + } + + private static final MethodHandle FREE_CUSTOM_IO_READER = + getDowncallHandle("faiss_CustomIOReader_free", JAVA_INT, ADDRESS); + + public static void 
freeCustomIOReader(MemorySegment customIOReaderPointer) { + callAndHandleError(FREE_CUSTOM_IO_READER, customIOReaderPointer); + } + + private static final MethodHandle FREE_PARAMETER_SPACE = + getDowncallHandle("faiss_ParameterSpace_free", JAVA_INT, ADDRESS); + + private static void freeParameterSpace(MemorySegment parameterSpacePointer) { + callAndHandleError(FREE_PARAMETER_SPACE, parameterSpacePointer); + } + + private static final MethodHandle FREE_ID_SELECTOR_BITMAP = + getDowncallHandle("faiss_IDSelectorBitmap_free", JAVA_INT, ADDRESS); + + private static void freeIDSelectorBitmap(MemorySegment idSelectorBitmapPointer) { + callAndHandleError(FREE_ID_SELECTOR_BITMAP, idSelectorBitmapPointer); + } + + private static final MethodHandle FREE_SEARCH_PARAMETERS = + getDowncallHandle("faiss_SearchParameters_free", JAVA_INT, ADDRESS); + + private static void freeSearchParameters(MemorySegment searchParametersPointer) { + callAndHandleError(FREE_SEARCH_PARAMETERS, searchParametersPointer); + } + + private static final MethodHandle INDEX_FACTORY = + getDowncallHandle("faiss_index_factory", JAVA_INT, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT); + + private static final MethodHandle PARAMETER_SPACE_NEW = + getDowncallHandle("faiss_ParameterSpace_new", JAVA_INT, ADDRESS); + + private static final MethodHandle SET_INDEX_PARAMETERS = + getDowncallHandle( + "faiss_ParameterSpace_set_index_parameters", JAVA_INT, ADDRESS, ADDRESS, ADDRESS); + + private static final MethodHandle ID_SELECTOR_BITMAP_NEW = + getDowncallHandle("faiss_IDSelectorBitmap_new", JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS); + + private static final MethodHandle SEARCH_PARAMETERS_NEW = + getDowncallHandle("faiss_SearchParameters_new", JAVA_INT, ADDRESS, ADDRESS); + + private static final MethodHandle INDEX_IS_TRAINED = + getDowncallHandle("faiss_Index_is_trained", JAVA_INT, ADDRESS); + + private static final MethodHandle INDEX_TRAIN = + getDowncallHandle("faiss_Index_train", JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS); + + 
private static final MethodHandle INDEX_ADD_WITH_IDS = + getDowncallHandle("faiss_Index_add_with_ids", JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS); + + public static MemorySegment createIndex( + String description, + String indexParams, + VectorSimilarityFunction function, + FloatVectorValues floatVectorValues, + IntToIntFunction oldToNewDocId) + throws IOException { + + try (Arena temp = Arena.ofConfined()) { + int size = floatVectorValues.size(); + int dimension = floatVectorValues.dimension(); + + // Mapped from faiss/MetricType.h + int metric = + switch (function) { + case DOT_PRODUCT -> 0; + case EUCLIDEAN -> 1; + case COSINE, MAXIMUM_INNER_PRODUCT -> + throw new UnsupportedOperationException("Metric type not supported"); + }; + + // Create an index + MemorySegment pointer = temp.allocate(ADDRESS); + callAndHandleError(INDEX_FACTORY, pointer, dimension, temp.allocateFrom(description), metric); + MemorySegment indexPointer = pointer.get(ADDRESS, 0); + + // Set index params + callAndHandleError(PARAMETER_SPACE_NEW, pointer); + MemorySegment parameterSpacePointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, LibFaissC::freeParameterSpace); + + callAndHandleError( + SET_INDEX_PARAMETERS, + parameterSpacePointer, + indexPointer, + temp.allocateFrom(indexParams)); + + // TODO: Improve memory usage (with a tradeoff in performance) by batched indexing, see: + // - https://github.com/opensearch-project/k-NN/issues/1506 + // - https://github.com/opensearch-project/k-NN/issues/1938 + + // Allocate docs in native memory + MemorySegment docs = temp.allocate(JAVA_FLOAT, (long) size * dimension); + FloatBuffer docsBuffer = docs.asByteBuffer().order(ByteOrder.nativeOrder()).asFloatBuffer(); + + // Allocate ids in native memory + MemorySegment ids = temp.allocate(JAVA_LONG, size); + LongBuffer idsBuffer = ids.asByteBuffer().order(ByteOrder.nativeOrder()).asLongBuffer(); + + KnnVectorValues.DocIndexIterator iterator = 
floatVectorValues.iterator(); + for (int i = iterator.nextDoc(); i != NO_MORE_DOCS; i = iterator.nextDoc()) { + idsBuffer.put(oldToNewDocId.apply(i)); + docsBuffer.put(floatVectorValues.vectorValue(iterator.index())); + } + + // Train index + if (callAndGetInt(INDEX_IS_TRAINED, indexPointer) == 0) { + callAndHandleError(INDEX_TRAIN, indexPointer, size, docs); + } + + // Add docs to index + callAndHandleError(INDEX_ADD_WITH_IDS, indexPointer, size, docs, ids); + + return indexPointer; + } + } + + @SuppressWarnings("unused") // called using a MethodHandle + private static int writeBytes( + IndexOutput output, MemorySegment inputPointer, int itemSize, int numItems) + throws IOException { + // TODO: Can we avoid copying to heap? + byte[] bytes = + new byte[(int) (Integer.toUnsignedLong(itemSize) * Integer.toUnsignedLong(numItems))]; + inputPointer.reinterpret(bytes.length).asByteBuffer().order(ByteOrder.nativeOrder()).get(bytes); + output.writeBytes(bytes, 0, bytes.length); + return numItems; + } + + @SuppressWarnings("unused") // called using a MethodHandle + private static int readBytes( + IndexInput input, MemorySegment outputPointer, int itemSize, int numItems) + throws IOException { + // TODO: Can we avoid copying to heap? 
+ byte[] bytes = + new byte[(int) (Integer.toUnsignedLong(itemSize) * Integer.toUnsignedLong(numItems))]; + input.readBytes(bytes, 0, bytes.length); + outputPointer + .reinterpret(bytes.length) + .asByteBuffer() + .order(ByteOrder.nativeOrder()) + .put(bytes); + return numItems; + } + + private static final MethodHandle WRITE_BYTES_HANDLE; + private static final MethodHandle READ_BYTES_HANDLE; + + static { + try { + WRITE_BYTES_HANDLE = + MethodHandles.lookup() + .findStatic( + LibFaissC.class, + "writeBytes", + MethodType.methodType( + int.class, IndexOutput.class, MemorySegment.class, int.class, int.class)); + + READ_BYTES_HANDLE = + MethodHandles.lookup() + .findStatic( + LibFaissC.class, + "readBytes", + MethodType.methodType( + int.class, IndexInput.class, MemorySegment.class, int.class, int.class)); + } catch (NoSuchMethodException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + private static final MethodHandle CUSTOM_IO_WRITER_NEW = + getDowncallHandle("faiss_CustomIOWriter_new", JAVA_INT, ADDRESS, ADDRESS); + + private static final MethodHandle WRITE_INDEX_CUSTOM = + getDowncallHandle("faiss_write_index_custom", JAVA_INT, ADDRESS, ADDRESS, JAVA_INT); + + public static void indexWrite(MemorySegment indexPointer, IndexOutput output, int ioFlags) { + try (Arena temp = Arena.ofConfined()) { + MethodHandle writerHandle = WRITE_BYTES_HANDLE.bindTo(output); + MemorySegment writerStub = + getUpcallStub(temp, writerHandle, JAVA_INT, ADDRESS, JAVA_INT, JAVA_INT); + + MemorySegment pointer = temp.allocate(ADDRESS); + callAndHandleError(CUSTOM_IO_WRITER_NEW, pointer, writerStub); + MemorySegment customIOWriterPointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, LibFaissC::freeCustomIOWriter); + + callAndHandleError(WRITE_INDEX_CUSTOM, indexPointer, customIOWriterPointer, ioFlags); + } + } + + private static final MethodHandle CUSTOM_IO_READER_NEW = + getDowncallHandle("faiss_CustomIOReader_new", JAVA_INT, 
ADDRESS, ADDRESS); + + private static final MethodHandle READ_INDEX_CUSTOM = + getDowncallHandle("faiss_read_index_custom", JAVA_INT, ADDRESS, JAVA_INT, ADDRESS); + + public static MemorySegment indexRead(IndexInput input, int ioFlags) { + try (Arena temp = Arena.ofConfined()) { + MethodHandle readerHandle = READ_BYTES_HANDLE.bindTo(input); + MemorySegment readerStub = + getUpcallStub(temp, readerHandle, JAVA_INT, ADDRESS, JAVA_INT, JAVA_INT); + + MemorySegment pointer = temp.allocate(ADDRESS); + callAndHandleError(CUSTOM_IO_READER_NEW, pointer, readerStub); + MemorySegment customIOReaderPointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, LibFaissC::freeCustomIOReader); + + callAndHandleError(READ_INDEX_CUSTOM, customIOReaderPointer, ioFlags, pointer); + return pointer.get(ADDRESS, 0); + } + } + + private static final MethodHandle INDEX_SEARCH = + getDowncallHandle( + "faiss_Index_search", JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS); + + private static final MethodHandle INDEX_SEARCH_WITH_PARAMS = + getDowncallHandle( + "faiss_Index_search_with_params", + JAVA_INT, + ADDRESS, + JAVA_LONG, + ADDRESS, + JAVA_LONG, + ADDRESS, + ADDRESS, + ADDRESS); + + public static void indexSearch( + MemorySegment indexPointer, + VectorSimilarityFunction function, + float[] query, + KnnCollector knnCollector, + Bits acceptDocs) { + + try (Arena temp = Arena.ofConfined()) { + FixedBitSet fixedBitSet = + switch (acceptDocs) { + case null -> null; + case FixedBitSet bitSet -> bitSet; + // TODO: Add optimized case for SparseFixedBitSet + case Bits bits -> FixedBitSet.copyOf(bits); + }; + + // Allocate queries in native memory + MemorySegment queries = temp.allocate(JAVA_FLOAT, query.length); + queries.asByteBuffer().order(ByteOrder.nativeOrder()).asFloatBuffer().put(query); + + // Faiss knn search + int k = knnCollector.k(); + MemorySegment distancesPointer = temp.allocate(JAVA_FLOAT, k); + MemorySegment idsPointer = 
temp.allocate(JAVA_LONG, k); + + MemorySegment localIndex = indexPointer.reinterpret(temp, null); + if (fixedBitSet == null) { + // Search without runtime filters + callAndHandleError(INDEX_SEARCH, localIndex, 1, queries, k, distancesPointer, idsPointer); + } else { + MemorySegment pointer = temp.allocate(ADDRESS); + + long[] bits = fixedBitSet.getBits(); + MemorySegment nativeBits = temp.allocate(JAVA_LONG, bits.length); + + // Use LITTLE_ENDIAN to convert long[] -> uint8_t* + nativeBits.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asLongBuffer().put(bits); + + callAndHandleError(ID_SELECTOR_BITMAP_NEW, pointer, fixedBitSet.length(), nativeBits); + MemorySegment idSelectorBitmapPointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, LibFaissC::freeIDSelectorBitmap); + + callAndHandleError(SEARCH_PARAMETERS_NEW, pointer, idSelectorBitmapPointer); + MemorySegment searchParametersPointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, LibFaissC::freeSearchParameters); + + // Search with runtime filters + callAndHandleError( + INDEX_SEARCH_WITH_PARAMS, + localIndex, + 1, + queries, + k, + searchParametersPointer, + distancesPointer, + idsPointer); + } + + // Retrieve scores + float[] distances = new float[k]; + distancesPointer.asByteBuffer().order(ByteOrder.nativeOrder()).asFloatBuffer().get(distances); + + // Retrieve ids + long[] ids = new long[k]; + idsPointer.asByteBuffer().order(ByteOrder.nativeOrder()).asLongBuffer().get(ids); + + // Record hits + for (int i = 0; i < k; i++) { + // Not enough results + if (ids[i] == -1) { + break; + } + + // Scale Faiss distances to Lucene scores, see VectorSimilarityFunction.java + float score = + switch (function) { + case DOT_PRODUCT -> + // distance in Faiss === dotProduct in Lucene + Math.max((1 + distances[i]) / 2, 0); + + case EUCLIDEAN -> + // distance in Faiss === squareDistance in Lucene + 1 / (1 + distances[i]); + + case COSINE, MAXIMUM_INNER_PRODUCT -> + 
throw new UnsupportedOperationException("Metric type not supported"); + }; + + knnCollector.collect((int) ids[i], score); + } + } + } + + private static final MethodHandle GET_LAST_ERROR = Review Comment: This doesn't seem thread safe? Or maybe the recorded error in C is somehow thread-local/-private? ########## lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java: ########## @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_CURRENT; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.createIndex; +import static org.apache.lucene.sandbox.codecs.faiss.LibFaissC.indexWrite; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSet; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.hnsw.IntToIntFunction; + +/** + * Write per-segment Faiss indexes and associated metadata. 
+ * + * @lucene.experimental + */ +final class FaissKnnVectorsWriter extends KnnVectorsWriter { + private final String description, indexParams; + private final FlatVectorsWriter rawVectorsWriter; + private final IndexOutput meta, data; + private final Map<FieldInfo, FlatFieldVectorsWriter<?>> rawFields; + private boolean closed, finished; + + public FaissKnnVectorsWriter( + String description, + String indexParams, + SegmentWriteState state, + FlatVectorsWriter rawVectorsWriter) + throws IOException { + + this.description = description; + this.indexParams = indexParams; + this.rawVectorsWriter = rawVectorsWriter; + this.rawFields = new HashMap<>(); + this.closed = false; + this.finished = false; + + boolean failure = true; + try { + this.meta = openOutput(state, META_EXTENSION, META_CODEC_NAME); + this.data = openOutput(state, DATA_EXTENSION, DATA_CODEC_NAME); + failure = false; + } finally { + if (failure) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + private IndexOutput openOutput(SegmentWriteState state, String extension, String codecName) + throws IOException { + String fileName = + IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, extension); + IndexOutput output = state.directory.createOutput(fileName, state.context); + CodecUtil.writeIndexHeader( + output, codecName, VERSION_CURRENT, state.segmentInfo.getId(), state.segmentSuffix); + return output; + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + rawVectorsWriter.mergeOneField(fieldInfo, mergeState); + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> + // TODO: Support using SQ8 quantization, see: + // - https://github.com/opensearch-project/k-NN/pull/2425 + throw new UnsupportedOperationException("Byte vectors not supported"); + case FLOAT32 -> { + FloatVectorValues merged = + KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + writeFloatField(fieldInfo, merged, 
doc -> doc); + } + } + } + + @Override + public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException { + FlatFieldVectorsWriter<?> rawFieldVectorsWriter = rawVectorsWriter.addField(fieldInfo); + rawFields.put(fieldInfo, rawFieldVectorsWriter); + return rawFieldVectorsWriter; + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + rawVectorsWriter.flush(maxDoc, sortMap); + for (Map.Entry<FieldInfo, FlatFieldVectorsWriter<?>> entry : rawFields.entrySet()) { + FieldInfo fieldInfo = entry.getKey(); + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> + // TODO: Support using SQ8 quantization, see: + // - https://github.com/opensearch-project/k-NN/pull/2425 + throw new UnsupportedOperationException("Byte vectors not supported"); + + case FLOAT32 -> { + @SuppressWarnings("unchecked") + FlatFieldVectorsWriter<float[]> rawWriter = + (FlatFieldVectorsWriter<float[]>) entry.getValue(); + + List<float[]> vectors = rawWriter.getVectors(); + int dimension = fieldInfo.getVectorDimension(); + DocIdSet docIdSet = rawWriter.getDocsWithFieldSet(); + + writeFloatField( + fieldInfo, + new BufferedFloatVectorValues(vectors, dimension, docIdSet), + (sortMap != null) ? 
sortMap::oldToNew : doc -> doc); + } + } + } + } + + private void writeFloatField( + FieldInfo fieldInfo, FloatVectorValues floatVectorValues, IntToIntFunction oldToNewDocId) + throws IOException { + int number = fieldInfo.number; + meta.writeInt(number); + + // Write index to temp file and deallocate from memory + try (Arena temp = Arena.ofConfined()) { + VectorSimilarityFunction function = fieldInfo.getVectorSimilarityFunction(); + MemorySegment indexPointer = + createIndex(description, indexParams, function, floatVectorValues, oldToNewDocId) + // Ensure timely cleanup + .reinterpret(temp, LibFaissC::freeIndex); + + // See flags defined in c_api/index_io_c.h + int ioFlags = 3; + + // Write index + long dataOffset = data.getFilePointer(); + indexWrite(indexPointer, data, ioFlags); + long dataLength = data.getFilePointer() - dataOffset; + + meta.writeLong(dataOffset); + meta.writeLong(dataLength); + } + } + + @Override + public void finish() throws IOException { + if (finished) { + throw new IllegalStateException("Already finished"); + } + finished = true; + + rawVectorsWriter.finish(); + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + CodecUtil.writeFooter(data); + } + + @Override + public void close() throws IOException { + if (closed == false) { + IOUtils.close(rawVectorsWriter, meta, data); + closed = true; + } + } + + @Override + public long ramBytesUsed() { + // TODO: How to estimate Faiss usage? Review Comment: Do we have a rough sense of how RAM hungry Faiss is during construction? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For additional commands, e-mail: issues-h...@lucene.apache.org