benwtrent commented on code in PR #12421: URL: https://github.com/apache/lucene/pull/12421#discussion_r1285112800
########## lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java: ########## @@ -87,6 +88,67 @@ public static NeighborQueue search( Bits acceptOrds, int visitedLimit) throws IOException { + return search( + query, + topK, + vectors, + vectorEncoding, + similarityFunction, + graph, + acceptOrds, + visitedLimit, + new SparseFixedBitSet(vectors.size())); + } + + /** + * Searches a concurrent HNSW graph for the nearest neighbors of a query vector. + * + * @param query search query vector + * @param topK the number of nodes to be returned + * @param vectors the vector values + * @param similarityFunction the similarity function to compare vectors + * @param graph the graph values. May represent the entire graph, or a level in a hierarchical + * graph. + * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or + * {@code null} if they are all allowed to match. + * @param visitedLimit the maximum number of nodes that the search is allowed to visit + * @return a priority queue holding the closest neighbors found + */ + public static NeighborQueue searchConcurrent( Review Comment: This is weird IMO: this accepts an `HnswGraph` object, but it says "concurrent". Additionally, the only thing that is concurrent is the visited counter. It seems to me that this method should accept only a `ConcurrentOnHeapHnswGraph`. ########## lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java: ########## @@ -104,7 +103,9 @@ public void addNode(int level, int node) { // and make this node the graph's new entry point if (level >= numLevels) { for (int i = numLevels; i <= level; i++) { - graphUpperLevels.add(new HashMap<>()); + graphUpperLevels.add( + new HashMap<>( + 16, levelLoadFactor)); // these are the default parameters, made explicit } Review Comment: We shouldn't specify the load factor; that doesn't make sense to me.
Also, with `16`, this only pre-allocates room for `12` entries (16 × the default 0.75 load factor) before the map grows again; I am not sure whether that is your intent. ########## lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java: ########## @@ -173,7 +175,7 @@ public String toString() { return "NeighborArray[" + size + "]"; } - private int ascSortFindRightMostInsertionPoint(float newScore, int bound) { + protected int ascSortFindRightMostInsertionPoint(float newScore, int bound) { Review Comment: Make this final? It seems like we should discourage subclasses from overriding this. ########## lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentNeighborSet.java: ########## @@ -0,0 +1,292 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.PrimitiveIterator; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.FixedBitSet; + +/** A concurrent set of neighbors.
*/ +public class ConcurrentNeighborSet { + /** the node id whose neighbors we are storing */ + private final int nodeId; + + /** + * We use a copy-on-write NeighborArray to store the neighbors. Even though updating this is + * expensive, it is still faster than using a concurrent Collection because "iterate through a + * node's neighbors" is a hot loop in adding to the graph, and NeighborArray can do that much + * faster: no boxing/unboxing, all the data is stored sequentially instead of having to follow + * references, and no fancy encoding necessary for node/score. + */ + private final AtomicReference<ConcurrentNeighborArray> neighborsRef; + + private final NeighborSimilarity similarity; + + /** the maximum number of neighbors we can store */ + private final int maxConnections; + + public ConcurrentNeighborSet(int nodeId, int maxConnections, NeighborSimilarity similarity) { + this.nodeId = nodeId; + this.maxConnections = maxConnections; + this.similarity = similarity; + neighborsRef = new AtomicReference<>(new ConcurrentNeighborArray(maxConnections, true)); + } + + public PrimitiveIterator.OfInt nodeIterator() { + // don't use a stream here. stream's implementation of iterator buffers + // very aggressively, which is a big waste for a lot of searches. 
+ return new NeighborIterator(neighborsRef.get()); + } + + public void backlink(Function<Integer, ConcurrentNeighborSet> neighborhoodOf) throws IOException { + NeighborArray neighbors = neighborsRef.get(); + for (int i = 0; i < neighbors.size(); i++) { + int nbr = neighbors.node[i]; + float nbrScore = neighbors.score[i]; + ConcurrentNeighborSet nbrNbr = neighborhoodOf.apply(nbr); + nbrNbr.insert(nodeId, nbrScore); + } + } + + private static class NeighborIterator implements PrimitiveIterator.OfInt { + private final NeighborArray neighbors; + private int i; + + private NeighborIterator(NeighborArray neighbors) { + this.neighbors = neighbors; + i = 0; + } + + @Override + public boolean hasNext() { + return i < neighbors.size(); + } + + @Override + public int nextInt() { + return neighbors.node[i++]; + } + } + + public int size() { + return neighborsRef.get().size(); + } + + public int arrayLength() { + return neighborsRef.get().node.length; + } + + /** + * For each candidate (going from best to worst), select it only if it is closer to target than it + * is to any of the already-selected candidates. This is maintained whether those other neighbors + * were selected by this method, or were added as a "backlink" to a node inserted concurrently + * that chose this one as a neighbor. + */ + public void insertDiverse(NeighborArray candidates) { + BitSet selected = new FixedBitSet(candidates.size()); + for (int i = candidates.size() - 1; i >= 0; i--) { + int cNode = candidates.node[i]; + float cScore = candidates.score[i]; + if (isDiverse(cNode, cScore, candidates, selected)) { + selected.set(i); + } + } + insertMultiple(candidates, selected); + // This leaves the paper's keepPrunedConnection option out; we might want to add that + // as an option in the future. 
+ } + + private void insertMultiple(NeighborArray others, BitSet selected) { + neighborsRef.getAndUpdate( + current -> { + ConcurrentNeighborArray next = current.copy(); + for (int i = others.size() - 1; i >= 0; i--) { + if (!selected.get(i)) { + continue; + } + int node = others.node[i]; + float score = others.score[i]; + next.insertSorted(node, score); + } + enforceMaxConnLimit(next); + return next; + }); + } + + /** + * Insert a new neighbor, maintaining our size cap by removing the least diverse neighbor if + * necessary. + */ + public void insert(int neighborId, float score) throws IOException { + assert neighborId != nodeId : "can't add self as neighbor at node " + nodeId; + neighborsRef.getAndUpdate( + current -> { + ConcurrentNeighborArray next = current.copy(); + next.insertSorted(neighborId, score); + enforceMaxConnLimit(next); + return next; + }); + } + + // is the candidate node with the given score closer to the base node than it is to any of the + // existing neighbors + private boolean isDiverse(int node, float score, NeighborArray others, BitSet selected) { + if (others.size() == 0) { + return true; + } + + NeighborSimilarity.ScoreFunction scoreProvider = similarity.scoreProvider(node); + for (int i = others.size() - 1; i >= 0; i--) { + if (!selected.get(i)) { + continue; + } + int candidateNode = others.node[i]; + if (node == candidateNode) { + break; + } + if (scoreProvider.apply(candidateNode) > score) { + return false; + } + } + return true; + } + + private void enforceMaxConnLimit(NeighborArray neighbors) { + while (neighbors.size() > maxConnections) { + try { + removeLeastDiverse(neighbors); + } catch (IOException e) { + throw new UncheckedIOException(e); // called from closures + } + } + } + + /** + * For each node e1 starting with the last neighbor (i.e. least similar to the base node), look at + * all nodes e2 that are closer to the base node than e1 is. If any e2 is closer to e1 than e1 is + * to the base node, remove e1. 
+ */ + private void removeLeastDiverse(NeighborArray neighbors) throws IOException { + for (int i = neighbors.size() - 1; i >= 1; i--) { + int e1Id = neighbors.node[i]; + float baseScore = neighbors.score[i]; + NeighborSimilarity.ScoreFunction scoreProvider = similarity.scoreProvider(e1Id); + + for (int j = i - 1; j >= 0; j--) { + int n2Id = neighbors.node[j]; + float n1n2Score = scoreProvider.apply(n2Id); + if (n1n2Score > baseScore) { + neighbors.removeIndex(i); + return; + } + } + } + + // couldn't find any "non-diverse" neighbors, so remove the one farthest from the base node + neighbors.removeIndex(neighbors.size() - 1); + } + + /** Only for testing; this is a linear search */ + boolean contains(int i) { + var it = this.nodeIterator(); + while (it.hasNext()) { + if (it.nextInt() == i) { + return true; + } + } + return false; + } + + /** Encapsulates comparing node distances for diversity checks. */ + public interface NeighborSimilarity { Review Comment: This should be its own class in the package. I could see other things using this. ########## lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswGraphBuilder.java: ########## @@ -0,0 +1,468 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.util.hnsw; + +import static java.lang.Math.log; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.Semaphore; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.GrowableBitSet; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.ThreadInterruptedException; +import org.apache.lucene.util.hnsw.ConcurrentNeighborSet.NeighborSimilarity; +import org.apache.lucene.util.hnsw.ConcurrentOnHeapHnswGraph.NodeAtLevel; + +/** + * Builder for Concurrent HNSW graph. See {@link HnswGraph} for a high level overview, and the + * comments to `addGraphNode` for details on the concurrent building approach. + * + * @param <T> the type of vector + */ +public class ConcurrentHnswGraphBuilder<T> { + + /** Default number of maximum connections per node */ + public static final int DEFAULT_MAX_CONN = 16; + + /** + * Default number of the size of the queue maintained while searching during a graph construction. 
+ */ + public static final int DEFAULT_BEAM_WIDTH = 100; + + /** A name for the HNSW component for the info-stream */ + public static final String HNSW_COMPONENT = "HNSW"; + + private final int beamWidth; + private final double ml; + private final ExplicitThreadLocal<NeighborArray> scratchNeighbors; + + private final VectorSimilarityFunction similarityFunction; + private final VectorEncoding vectorEncoding; + private final ExplicitThreadLocal<RandomAccessVectorValues<T>> vectors; + private final ExplicitThreadLocal<HnswGraphSearcher<T>> graphSearcher; + private final ExplicitThreadLocal<NeighborQueue> beamCandidates; + + final ConcurrentOnHeapHnswGraph hnsw; + private final ConcurrentSkipListSet<NodeAtLevel> insertionsInProgress = + new ConcurrentSkipListSet<>(); + + private InfoStream infoStream = InfoStream.getDefault(); + + // we need two sources of vectors in order to perform diversity check comparisons without + // colliding + private final ExplicitThreadLocal<RandomAccessVectorValues<T>> vectorsCopy; + + /** This is the "native" factory for ConcurrentHnswGraphBuilder. */ + public static <T> ConcurrentHnswGraphBuilder<T> create( + RandomAccessVectorValues<T> vectors, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + int M, + int beamWidth) + throws IOException { + return new ConcurrentHnswGraphBuilder<>( + vectors, vectorEncoding, similarityFunction, M, beamWidth); + } + + /** + * Reads all the vectors from vector values, builds a graph connecting them by their dense + * ordinals, using the given hyperparameter settings, and returns the resulting graph. + * + * @param vectorValues the vectors whose relations are represented by the graph - must provide a + * different view over those vectors than the one used to add via addGraphNode. + * @param M – graph fanout parameter used to calculate the maximum number of connections a node + * can have – M on upper layers, and M * 2 on the lowest level. 
+ * @param beamWidth the size of the beam search to use when finding nearest neighbors. + */ + public ConcurrentHnswGraphBuilder( + RandomAccessVectorValues<T> vectorValues, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + int M, + int beamWidth) { + this.vectors = createThreadSafeVectors(vectorValues); + this.vectorsCopy = createThreadSafeVectors(vectorValues); + this.vectorEncoding = Objects.requireNonNull(vectorEncoding); + this.similarityFunction = Objects.requireNonNull(similarityFunction); + if (M <= 0) { + throw new IllegalArgumentException("maxConn must be positive"); + } + if (beamWidth <= 0) { + throw new IllegalArgumentException("beamWidth must be positive"); + } + this.beamWidth = beamWidth; + // normalization factor for level generation; currently not configurable + this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M); + + NeighborSimilarity similarity = + new NeighborSimilarity() { + @Override + public float score(int node1, int node2) { + try { + return scoreBetween( + vectors.get().vectorValue(node1), vectorsCopy.get().vectorValue(node2)); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public ScoreFunction scoreProvider(int node1) { + T v1; + try { + v1 = vectors.get().vectorValue(node1); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + return node2 -> { + try { + return scoreBetween(v1, vectorsCopy.get().vectorValue(node2)); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }; + } + }; + this.hnsw = new ConcurrentOnHeapHnswGraph(M, similarity); + + this.graphSearcher = + ExplicitThreadLocal.withInitial( + () -> { + return new HnswGraphSearcher<>( + vectorEncoding, + similarityFunction, + new NeighborQueue(beamWidth, true), + new GrowableBitSet(this.vectors.get().size())); + }); + // in scratch we store candidates in reverse order: worse candidates are first + this.scratchNeighbors = + ExplicitThreadLocal.withInitial(() -> new 
NeighborArray(Math.max(beamWidth, M + 1), false)); + this.beamCandidates = + ExplicitThreadLocal.withInitial(() -> new NeighborQueue(beamWidth, false)); + } + + private abstract static class ExplicitThreadLocal<U> { + private final ConcurrentHashMap<Long, U> map = new ConcurrentHashMap<>(); + + public U get() { + return map.computeIfAbsent(Thread.currentThread().getId(), k -> initialValue()); + } + + protected abstract U initialValue(); + + public static <U> ExplicitThreadLocal<U> withInitial(Supplier<U> initialValue) { + return new ExplicitThreadLocal<U>() { + @Override + protected U initialValue() { + return initialValue.get(); + } + }; + } + } + + /** + * Bring-your-own ExecutorService graph builder. + * + * <p>Reads all the vectors from two copies of a {@link RandomAccessVectorValues}. Providing two + * copies enables efficient retrieval without extra data copying, while avoiding collision of the + * returned values. + * + * @param vectorsToAdd the vectors for which to build a nearest neighbors graph. Must be an + * independent accessor for the vectors + * @param pool The ExecutorService to use. Must be an instance of ThreadPoolExecutor. + * @param concurrentTasks the number of tasks to submit in parallel. + */ + public Future<ConcurrentOnHeapHnswGraph> buildAsync( + RandomAccessVectorValues<T> vectorsToAdd, ExecutorService pool, int concurrentTasks) { + if (vectorsToAdd == this.vectors) { + throw new IllegalArgumentException( + "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()"); + } + if (infoStream.isEnabled(HNSW_COMPONENT)) { + infoStream.message(HNSW_COMPONENT, "build graph from " + vectorsToAdd.size() + " vectors"); + } + return addVectors(vectorsToAdd, pool, concurrentTasks); + } + + // the goal here is to keep all the ExecutorService threads busy, but not to create potentially + // millions of futures by naively throwing everything at submit at once. 
So, we use + // a semaphore to wait until a thread is free before adding a new task. + private Future<ConcurrentOnHeapHnswGraph> addVectors( + RandomAccessVectorValues<T> vectorsToAdd, ExecutorService pool, int concurrentTasks) { + Semaphore semaphore = new Semaphore(concurrentTasks); + Set<Integer> inFlight = ConcurrentHashMap.newKeySet(); + AtomicReference<Throwable> asyncException = new AtomicReference<>(null); + + ExplicitThreadLocal<RandomAccessVectorValues<T>> threadSafeVectors = + createThreadSafeVectors(vectorsToAdd); + + for (int i = 0; i < vectorsToAdd.size(); i++) { + final int node = i; // copy for closure + try { + semaphore.acquire(); + inFlight.add(node); + pool.submit( + () -> { + try { + addGraphNode(node, threadSafeVectors.get()); + } catch (Throwable e) { + asyncException.set(e); + } finally { Review Comment: Also, wouldn't additional exceptions overwrite the previous one? What if different failures occur? It seems like this should be set-once, and everything should be cancelled and cleaned up as soon as possible. ########## lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswGraphBuilder.java: ########## @@ -0,0 +1,468 @@ [...] + pool.submit( + () -> { + try { + addGraphNode(node, threadSafeVectors.get()); + } catch (Throwable e) { + asyncException.set(e); + } finally { Review Comment: How we handle a failure here is key. It seems like the entire build should stop, and the other tasks should be cancelled, if one fails. Otherwise, could we do a lot of unnecessary work only to fail later? ########## lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java: ########## @@ -187,7 +189,7 @@ private int ascSortFindRightMostInsertionPoint(float newScore, int bound) { return insertionPoint; } - private int descSortFindRightMostInsertionPoint(float newScore, int bound) { + protected int descSortFindRightMostInsertionPoint(float newScore, int bound) { Review Comment: Make this final?
It seems like we should discourage subclasses from overriding this. ########## lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java: ########## @@ -169,35 +170,43 @@ public NodesIterator getNodesOnLevel(int level) { @Override public long ramBytesUsed() { + // local vars here just to make it easier to keep lines short enough to read + long AH_BYTES = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; + long REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF; + long neighborArrayBytes0 = - nsize0 * (Integer.BYTES + Float.BYTES) - + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER - + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2 - + Integer.BYTES * 3; + (long) nsize0 * (Integer.BYTES + Float.BYTES) + + AH_BYTES * 2 + + REF_BYTES + + Integer.BYTES * 2; Review Comment: This should still be `Integer.BYTES * 3`, correct? ########## lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentNeighborSet.java: ########## @@ -0,0 +1,292 @@ [...] + private void insertMultiple(NeighborArray others, BitSet selected) { + neighborsRef.getAndUpdate( + current -> { + ConcurrentNeighborArray next = current.copy(); Review Comment: We are doing a lot of copies here: every insert copies the whole neighbor array, and `getAndUpdate` may re-apply the update function (copying again) under contention.
The other neighbor array is sorted. Why can't we do a merge over the selected candidates to reduce our copies? Or is "others" always only 1 or 2 values? ########## lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java: ########## @@ -126,12 +124,16 @@ private int insertSortedInternal() { return insertionPoint; } - /** This method is for test only. */ - void insertSorted(int newNode, float newScore) { + protected void insertSorted(int newNode, float newScore) { addOutOfOrder(newNode, newScore); insertSortedInternal(); } + protected void growArrays() { Review Comment: Should this be `final`? ########## lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentNeighborSet.java: ########## @@ -0,0 +1,292 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.PrimitiveIterator; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.FixedBitSet; + +/** A concurrent set of neighbors.
*/ +public class ConcurrentNeighborSet { + /** the node id whose neighbors we are storing */ + private final int nodeId; + + /** + * We use a copy-on-write NeighborArray to store the neighbors. Even though updating this is + * expensive, it is still faster than using a concurrent Collection because "iterate through a + * node's neighbors" is a hot loop in adding to the graph, and NeighborArray can do that much + * faster: no boxing/unboxing, all the data is stored sequentially instead of having to follow + * references, and no fancy encoding necessary for node/score. + */ + private final AtomicReference<ConcurrentNeighborArray> neighborsRef; + + private final NeighborSimilarity similarity; + + /** the maximum number of neighbors we can store */ + private final int maxConnections; + + public ConcurrentNeighborSet(int nodeId, int maxConnections, NeighborSimilarity similarity) { + this.nodeId = nodeId; + this.maxConnections = maxConnections; + this.similarity = similarity; + neighborsRef = new AtomicReference<>(new ConcurrentNeighborArray(maxConnections, true)); + } + + public PrimitiveIterator.OfInt nodeIterator() { + // don't use a stream here. stream's implementation of iterator buffers + // very aggressively, which is a big waste for a lot of searches. 
+ return new NeighborIterator(neighborsRef.get()); + } + + public void backlink(Function<Integer, ConcurrentNeighborSet> neighborhoodOf) throws IOException { + NeighborArray neighbors = neighborsRef.get(); + for (int i = 0; i < neighbors.size(); i++) { + int nbr = neighbors.node[i]; + float nbrScore = neighbors.score[i]; + ConcurrentNeighborSet nbrNbr = neighborhoodOf.apply(nbr); + nbrNbr.insert(nodeId, nbrScore); + } + } + + private static class NeighborIterator implements PrimitiveIterator.OfInt { + private final NeighborArray neighbors; + private int i; + + private NeighborIterator(NeighborArray neighbors) { + this.neighbors = neighbors; + i = 0; + } + + @Override + public boolean hasNext() { + return i < neighbors.size(); + } + + @Override + public int nextInt() { + return neighbors.node[i++]; + } + } + + public int size() { + return neighborsRef.get().size(); + } + + public int arrayLength() { + return neighborsRef.get().node.length; + } + + /** + * For each candidate (going from best to worst), select it only if it is closer to target than it + * is to any of the already-selected candidates. This is maintained whether those other neighbors + * were selected by this method, or were added as a "backlink" to a node inserted concurrently + * that chose this one as a neighbor. + */ + public void insertDiverse(NeighborArray candidates) { Review Comment: Why is this public? It seems like it's only used in this package. ########## lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java: ########## @@ -122,6 +123,24 @@ public NodesIterator getNodesOnLevel(int level) { } }; + /** + * Add node on the given level with an empty set of neighbors. + * + * <p>Nodes can be inserted out of order, but it requires that the nodes preceded by the node + * inserted out of order are eventually added. + * + * <p>Actually populating the neighbors, and establishing bidirectional links, is the + * responsibility of the caller.
+ * + * <p>It is also the responsibility of the caller to ensure that each node is only added once. + * + * @param level level to add a node on + * @param node the node to add, represented as an ordinal on the level 0. + */ + public void addNode(int level, int node) { Review Comment: I don't think this should be part of this class. This class is really an abstract HNSW searcher class. All the "onheap" flavors are for building graphs. There are more "search" flavors that do not allow additions. We shouldn't make this part of the HnswGraph base class. ########## lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java: ########## @@ -126,12 +124,16 @@ private int insertSortedInternal() { return insertionPoint; } - /** This method is for test only. */ - void insertSorted(int newNode, float newScore) { + protected void insertSorted(int newNode, float newScore) { Review Comment: This could stay package-private. I'm unsure why it was made protected. ########## lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswGraphBuilder.java: ########## @@ -0,0 +1,468 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.lucene.util.hnsw; + +import static java.lang.Math.log; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.Semaphore; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.GrowableBitSet; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.ThreadInterruptedException; +import org.apache.lucene.util.hnsw.ConcurrentNeighborSet.NeighborSimilarity; +import org.apache.lucene.util.hnsw.ConcurrentOnHeapHnswGraph.NodeAtLevel; + +/** + * Builder for Concurrent HNSW graph. See {@link HnswGraph} for a high level overview, and the + * comments to `addGraphNode` for details on the concurrent building approach. + * + * @param <T> the type of vector + */ +public class ConcurrentHnswGraphBuilder<T> { + + /** Default number of maximum connections per node */ + public static final int DEFAULT_MAX_CONN = 16; + + /** + * Default number of the size of the queue maintained while searching during a graph construction. 
+ */ + public static final int DEFAULT_BEAM_WIDTH = 100; + + /** A name for the HNSW component for the info-stream */ + public static final String HNSW_COMPONENT = "HNSW"; + + private final int beamWidth; + private final double ml; + private final ExplicitThreadLocal<NeighborArray> scratchNeighbors; + + private final VectorSimilarityFunction similarityFunction; + private final VectorEncoding vectorEncoding; + private final ExplicitThreadLocal<RandomAccessVectorValues<T>> vectors; + private final ExplicitThreadLocal<HnswGraphSearcher<T>> graphSearcher; + private final ExplicitThreadLocal<NeighborQueue> beamCandidates; + + final ConcurrentOnHeapHnswGraph hnsw; + private final ConcurrentSkipListSet<NodeAtLevel> insertionsInProgress = + new ConcurrentSkipListSet<>(); + + private InfoStream infoStream = InfoStream.getDefault(); + + // we need two sources of vectors in order to perform diversity check comparisons without + // colliding + private final ExplicitThreadLocal<RandomAccessVectorValues<T>> vectorsCopy; + + /** This is the "native" factory for ConcurrentHnswGraphBuilder. */ + public static <T> ConcurrentHnswGraphBuilder<T> create( + RandomAccessVectorValues<T> vectors, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + int M, + int beamWidth) + throws IOException { + return new ConcurrentHnswGraphBuilder<>( + vectors, vectorEncoding, similarityFunction, M, beamWidth); + } + + /** + * Reads all the vectors from vector values, builds a graph connecting them by their dense + * ordinals, using the given hyperparameter settings, and returns the resulting graph. + * + * @param vectorValues the vectors whose relations are represented by the graph - must provide a + * different view over those vectors than the one used to add via addGraphNode. + * @param M – graph fanout parameter used to calculate the maximum number of connections a node + * can have – M on upper layers, and M * 2 on the lowest level. 
+ * @param beamWidth the size of the beam search to use when finding nearest neighbors. + */ + public ConcurrentHnswGraphBuilder( + RandomAccessVectorValues<T> vectorValues, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + int M, + int beamWidth) { + this.vectors = createThreadSafeVectors(vectorValues); + this.vectorsCopy = createThreadSafeVectors(vectorValues); + this.vectorEncoding = Objects.requireNonNull(vectorEncoding); + this.similarityFunction = Objects.requireNonNull(similarityFunction); + if (M <= 0) { + throw new IllegalArgumentException("maxConn must be positive"); + } + if (beamWidth <= 0) { + throw new IllegalArgumentException("beamWidth must be positive"); + } + this.beamWidth = beamWidth; + // normalization factor for level generation; currently not configurable + this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M); + + NeighborSimilarity similarity = + new NeighborSimilarity() { + @Override + public float score(int node1, int node2) { + try { + return scoreBetween( + vectors.get().vectorValue(node1), vectorsCopy.get().vectorValue(node2)); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public ScoreFunction scoreProvider(int node1) { + T v1; + try { + v1 = vectors.get().vectorValue(node1); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + return node2 -> { + try { + return scoreBetween(v1, vectorsCopy.get().vectorValue(node2)); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }; + } + }; + this.hnsw = new ConcurrentOnHeapHnswGraph(M, similarity); + + this.graphSearcher = + ExplicitThreadLocal.withInitial( + () -> { + return new HnswGraphSearcher<>( + vectorEncoding, + similarityFunction, + new NeighborQueue(beamWidth, true), + new GrowableBitSet(this.vectors.get().size())); + }); + // in scratch we store candidates in reverse order: worse candidates are first + this.scratchNeighbors = + ExplicitThreadLocal.withInitial(() -> new 
NeighborArray(Math.max(beamWidth, M + 1), false)); + this.beamCandidates = + ExplicitThreadLocal.withInitial(() -> new NeighborQueue(beamWidth, false)); + } + + private abstract static class ExplicitThreadLocal<U> { + private final ConcurrentHashMap<Long, U> map = new ConcurrentHashMap<>(); + + public U get() { + return map.computeIfAbsent(Thread.currentThread().getId(), k -> initialValue()); + } + + protected abstract U initialValue(); + + public static <U> ExplicitThreadLocal<U> withInitial(Supplier<U> initialValue) { + return new ExplicitThreadLocal<U>() { + @Override + protected U initialValue() { + return initialValue.get(); + } + }; + } + } + + /** + * Bring-your-own ExecutorService graph builder. + * + * <p>Reads all the vectors from two copies of a {@link RandomAccessVectorValues}. Providing two + * copies enables efficient retrieval without extra data copying, while avoiding collision of the + * returned values. + * + * @param vectorsToAdd the vectors for which to build a nearest neighbors graph. Must be an + * independent accessor for the vectors + * @param pool The ExecutorService to use. Must be an instance of ThreadPoolExecutor. + * @param concurrentTasks the number of tasks to submit in parallel. + */ + public Future<ConcurrentOnHeapHnswGraph> buildAsync( + RandomAccessVectorValues<T> vectorsToAdd, ExecutorService pool, int concurrentTasks) { + if (vectorsToAdd == this.vectors) { + throw new IllegalArgumentException( + "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()"); + } + if (infoStream.isEnabled(HNSW_COMPONENT)) { + infoStream.message(HNSW_COMPONENT, "build graph from " + vectorsToAdd.size() + " vectors"); + } + return addVectors(vectorsToAdd, pool, concurrentTasks); + } + + // the goal here is to keep all the ExecutorService threads busy, but not to create potentially + // millions of futures by naively throwing everything at submit at once. 
So, we use + // a semaphore to wait until a thread is free before adding a new task. + private Future<ConcurrentOnHeapHnswGraph> addVectors( + RandomAccessVectorValues<T> vectorsToAdd, ExecutorService pool, int concurrentTasks) { + Semaphore semaphore = new Semaphore(concurrentTasks); + Set<Integer> inFlight = ConcurrentHashMap.newKeySet(); + AtomicReference<Throwable> asyncException = new AtomicReference<>(null); + + ExplicitThreadLocal<RandomAccessVectorValues<T>> threadSafeVectors = + createThreadSafeVectors(vectorsToAdd); + + for (int i = 0; i < vectorsToAdd.size(); i++) { + final int node = i; // copy for closure + try { + semaphore.acquire(); + inFlight.add(node); + pool.submit( + () -> { + try { + addGraphNode(node, threadSafeVectors.get()); + } catch (Throwable e) { + asyncException.set(e); + } finally { + semaphore.release(); + inFlight.remove(node); + } + }); + } catch (InterruptedException e) { + throw new ThreadInterruptedException(e); + } + } + + // return a future that will complete when the inflight set is empty + return CompletableFuture.supplyAsync( + () -> { + while (!inFlight.isEmpty()) { + try { + TimeUnit.MILLISECONDS.sleep(10); + } catch (InterruptedException e) { + throw new ThreadInterruptedException(e); + } + } + if (asyncException.get() != null) { + throw new CompletionException(asyncException.get()); + } + hnsw.validateEntryNode(); + return hnsw; + }); + } + + private static <T> ExplicitThreadLocal<RandomAccessVectorValues<T>> createThreadSafeVectors( + RandomAccessVectorValues<T> vectorValues) { + return ExplicitThreadLocal.withInitial( + () -> { + try { + return vectorValues.copy(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }); + } + + /** + * Adds a node to the graph, with the vector at the same ordinal in the given provider. + * + * <p>See {@link #addGraphNode(int, Object)} for more details. 
+ */ + public long addGraphNode(int node, RandomAccessVectorValues<T> values) throws IOException { + return addGraphNode(node, values.vectorValue(node)); + } + + /** Set info-stream to output debugging information * */ + public void setInfoStream(InfoStream infoStream) { + this.infoStream = infoStream; + } + + public ConcurrentOnHeapHnswGraph getGraph() { + return hnsw; + } + + /** Number of inserts in progress, across all threads. */ + public int insertsInProgress() { + return insertionsInProgress.size(); + } + + /** + * Inserts a doc with vector value to the graph. + * + * <p>To allow correctness under concurrency, we track in-progress updates in a + * ConcurrentSkipListSet. After adding ourselves, we take a snapshot of this set, and consider all + * other in-progress updates as neighbor candidates (subject to normal level constraints). + * + * @param node the node ID to add + * @param value the vector value to add + * @return an estimate of the number of extra bytes used by the graph after adding the given node + */ + public long addGraphNode(int node, T value) throws IOException { + // do this before adding to in-progress, so a concurrent writer checking + // the in-progress set doesn't have to worry about uninitialized neighbor sets + final int nodeLevel = getRandomGraphLevel(ml); + for (int level = nodeLevel; level >= 0; level--) { + hnsw.addNode(level, node); + } + + HnswGraph consistentView = hnsw.getView(); + NodeAtLevel progressMarker = new NodeAtLevel(nodeLevel, node); + insertionsInProgress.add(progressMarker); + ConcurrentSkipListSet<NodeAtLevel> inProgressBefore = insertionsInProgress.clone(); + try { + // find ANN of the new node by searching the graph + NodeAtLevel entry = hnsw.entry(); + int ep = entry.node; + int[] eps = ep >= 0 ? 
new int[] {ep} : new int[0]; + var gs = graphSearcher.get(); + + // for levels > nodeLevel search with topk = 1 + NeighborQueue candidates = new NeighborQueue(1, false); + for (int level = entry.level; level > nodeLevel; level--) { + candidates.clear(); + gs.searchLevel( + candidates, + value, + 1, + level, + eps, + vectors.get(), + consistentView, + null, + Integer.MAX_VALUE); + eps = new int[] {candidates.pop()}; + } + + // for levels <= nodeLevel search with topk = beamWidth, and add connections + candidates = beamCandidates.get(); + for (int level = Math.min(nodeLevel, entry.level); level >= 0; level--) { + candidates.clear(); + // find best "natural" candidates at this level with a beam search + gs.searchLevel( + candidates, + value, + beamWidth, + level, + eps, + vectors.get(), + consistentView, + null, + Integer.MAX_VALUE); + eps = candidates.nodes(); + + // Update entry points and neighbors with these candidates. + // + // Note: We don't want to over-prune the neighbors, which can + // happen if we group the concurrent candidates and the natural candidates together. + // + // Consider the following graph with "circular" test vectors: + // + // 0 -> 1 + // 1 <- 0 + // At this point we insert nodes 2 and 3 concurrently, denoted T1 and T2 for threads 1 and 2 + // T1 T2 + // insert 2 to L1 [2 is marked "in progress"] + // insert 3 to L1 + // 3 considers as neighbors 0, 1, 2; 0 and 1 are not diverse wrt 2 + // 3 -> 2 is added to graph + // 3 is marked entry node + // 2 follows 3 to L0, where 3 only has 2 as a neighbor + // 2 -> 3 is added to graph + // all further nodes will only be added to the 2/3 subgraph; 0/1 are partitioned forever + // + // Considering concurrent inserts separately from natural candidates solves this problem; + // both 1 and 2 will be added as neighbors to 3, avoiding the partition, and 2 will then + // pick up the connection to 1 that it's supposed to have as well. 
+ addForwardLinks(level, node, candidates); // natural candidates + addForwardLinks(level, node, inProgressBefore, progressMarker); // concurrent candidates + // Backlinking is the same for both natural and concurrent candidates. + addBackLinks(level, node); Review Comment: This seems to break the idea of HNSW: the asymmetry of the navigable small worlds is a feature, not a bug. I guess I can understand potentially adding backlinks to the in-progress ones. But that seems weird to me, as the "in-progress" ones would also consider this node, right? Since this node is considered "in-progress" by those other ones? ########## lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswGraphBuilder.java: ########## @@ -0,0 +1,468 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.lucene.util.hnsw; + +import static java.lang.Math.log; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.Semaphore; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.GrowableBitSet; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.ThreadInterruptedException; +import org.apache.lucene.util.hnsw.ConcurrentNeighborSet.NeighborSimilarity; +import org.apache.lucene.util.hnsw.ConcurrentOnHeapHnswGraph.NodeAtLevel; + +/** + * Builder for Concurrent HNSW graph. See {@link HnswGraph} for a high level overview, and the + * comments to `addGraphNode` for details on the concurrent building approach. + * + * @param <T> the type of vector + */ +public class ConcurrentHnswGraphBuilder<T> { + + /** Default number of maximum connections per node */ + public static final int DEFAULT_MAX_CONN = 16; + + /** + * Default number of the size of the queue maintained while searching during a graph construction. 
+ */ + public static final int DEFAULT_BEAM_WIDTH = 100; + + /** A name for the HNSW component for the info-stream */ + public static final String HNSW_COMPONENT = "HNSW"; + + private final int beamWidth; + private final double ml; + private final ExplicitThreadLocal<NeighborArray> scratchNeighbors; + + private final VectorSimilarityFunction similarityFunction; + private final VectorEncoding vectorEncoding; + private final ExplicitThreadLocal<RandomAccessVectorValues<T>> vectors; + private final ExplicitThreadLocal<HnswGraphSearcher<T>> graphSearcher; + private final ExplicitThreadLocal<NeighborQueue> beamCandidates; + + final ConcurrentOnHeapHnswGraph hnsw; + private final ConcurrentSkipListSet<NodeAtLevel> insertionsInProgress = + new ConcurrentSkipListSet<>(); + + private InfoStream infoStream = InfoStream.getDefault(); + + // we need two sources of vectors in order to perform diversity check comparisons without + // colliding + private final ExplicitThreadLocal<RandomAccessVectorValues<T>> vectorsCopy; + + /** This is the "native" factory for ConcurrentHnswGraphBuilder. */ + public static <T> ConcurrentHnswGraphBuilder<T> create( + RandomAccessVectorValues<T> vectors, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + int M, + int beamWidth) + throws IOException { + return new ConcurrentHnswGraphBuilder<>( + vectors, vectorEncoding, similarityFunction, M, beamWidth); + } + + /** + * Reads all the vectors from vector values, builds a graph connecting them by their dense + * ordinals, using the given hyperparameter settings, and returns the resulting graph. + * + * @param vectorValues the vectors whose relations are represented by the graph - must provide a + * different view over those vectors than the one used to add via addGraphNode. + * @param M – graph fanout parameter used to calculate the maximum number of connections a node + * can have – M on upper layers, and M * 2 on the lowest level. 
+ * @param beamWidth the size of the beam search to use when finding nearest neighbors. + */ + public ConcurrentHnswGraphBuilder( + RandomAccessVectorValues<T> vectorValues, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + int M, + int beamWidth) { + this.vectors = createThreadSafeVectors(vectorValues); + this.vectorsCopy = createThreadSafeVectors(vectorValues); + this.vectorEncoding = Objects.requireNonNull(vectorEncoding); + this.similarityFunction = Objects.requireNonNull(similarityFunction); + if (M <= 0) { + throw new IllegalArgumentException("maxConn must be positive"); + } + if (beamWidth <= 0) { + throw new IllegalArgumentException("beamWidth must be positive"); + } + this.beamWidth = beamWidth; + // normalization factor for level generation; currently not configurable + this.ml = M == 1 ? 1 : 1 / Math.log(1.0 * M); + + NeighborSimilarity similarity = + new NeighborSimilarity() { + @Override + public float score(int node1, int node2) { + try { + return scoreBetween( + vectors.get().vectorValue(node1), vectorsCopy.get().vectorValue(node2)); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public ScoreFunction scoreProvider(int node1) { + T v1; + try { + v1 = vectors.get().vectorValue(node1); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + return node2 -> { + try { + return scoreBetween(v1, vectorsCopy.get().vectorValue(node2)); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }; + } + }; + this.hnsw = new ConcurrentOnHeapHnswGraph(M, similarity); + + this.graphSearcher = + ExplicitThreadLocal.withInitial( + () -> { + return new HnswGraphSearcher<>( + vectorEncoding, + similarityFunction, + new NeighborQueue(beamWidth, true), + new GrowableBitSet(this.vectors.get().size())); + }); + // in scratch we store candidates in reverse order: worse candidates are first + this.scratchNeighbors = + ExplicitThreadLocal.withInitial(() -> new 
NeighborArray(Math.max(beamWidth, M + 1), false)); + this.beamCandidates = + ExplicitThreadLocal.withInitial(() -> new NeighborQueue(beamWidth, false)); + } + + private abstract static class ExplicitThreadLocal<U> { + private final ConcurrentHashMap<Long, U> map = new ConcurrentHashMap<>(); + + public U get() { + return map.computeIfAbsent(Thread.currentThread().getId(), k -> initialValue()); + } + + protected abstract U initialValue(); + + public static <U> ExplicitThreadLocal<U> withInitial(Supplier<U> initialValue) { + return new ExplicitThreadLocal<U>() { + @Override + protected U initialValue() { + return initialValue.get(); + } + }; + } + } + + /** + * Bring-your-own ExecutorService graph builder. + * + * <p>Reads all the vectors from two copies of a {@link RandomAccessVectorValues}. Providing two + * copies enables efficient retrieval without extra data copying, while avoiding collision of the + * returned values. + * + * @param vectorsToAdd the vectors for which to build a nearest neighbors graph. Must be an + * independent accessor for the vectors + * @param pool The ExecutorService to use. Must be an instance of ThreadPoolExecutor. + * @param concurrentTasks the number of tasks to submit in parallel. + */ + public Future<ConcurrentOnHeapHnswGraph> buildAsync( + RandomAccessVectorValues<T> vectorsToAdd, ExecutorService pool, int concurrentTasks) { + if (vectorsToAdd == this.vectors) { + throw new IllegalArgumentException( + "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()"); + } + if (infoStream.isEnabled(HNSW_COMPONENT)) { + infoStream.message(HNSW_COMPONENT, "build graph from " + vectorsToAdd.size() + " vectors"); + } + return addVectors(vectorsToAdd, pool, concurrentTasks); + } + + // the goal here is to keep all the ExecutorService threads busy, but not to create potentially + // millions of futures by naively throwing everything at submit at once. 
So, we use + // a semaphore to wait until a thread is free before adding a new task. + private Future<ConcurrentOnHeapHnswGraph> addVectors( + RandomAccessVectorValues<T> vectorsToAdd, ExecutorService pool, int concurrentTasks) { + Semaphore semaphore = new Semaphore(concurrentTasks); + Set<Integer> inFlight = ConcurrentHashMap.newKeySet(); + AtomicReference<Throwable> asyncException = new AtomicReference<>(null); + + ExplicitThreadLocal<RandomAccessVectorValues<T>> threadSafeVectors = + createThreadSafeVectors(vectorsToAdd); + + for (int i = 0; i < vectorsToAdd.size(); i++) { + final int node = i; // copy for closure + try { + semaphore.acquire(); + inFlight.add(node); + pool.submit( + () -> { + try { + addGraphNode(node, threadSafeVectors.get()); + } catch (Throwable e) { + asyncException.set(e); + } finally { + semaphore.release(); + inFlight.remove(node); + } + }); + } catch (InterruptedException e) { + throw new ThreadInterruptedException(e); + } + } + + // return a future that will complete when the inflight set is empty + return CompletableFuture.supplyAsync( + () -> { + while (!inFlight.isEmpty()) { + try { + TimeUnit.MILLISECONDS.sleep(10); + } catch (InterruptedException e) { + throw new ThreadInterruptedException(e); + } + } + if (asyncException.get() != null) { + throw new CompletionException(asyncException.get()); + } + hnsw.validateEntryNode(); + return hnsw; + }); + } + + private static <T> ExplicitThreadLocal<RandomAccessVectorValues<T>> createThreadSafeVectors( + RandomAccessVectorValues<T> vectorValues) { + return ExplicitThreadLocal.withInitial( + () -> { + try { + return vectorValues.copy(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }); + } + + /** + * Adds a node to the graph, with the vector at the same ordinal in the given provider. + * + * <p>See {@link #addGraphNode(int, Object)} for more details. 
+   */
+  public long addGraphNode(int node, RandomAccessVectorValues<T> values) throws IOException {
+    return addGraphNode(node, values.vectorValue(node));
+  }
+
+  /** Set info-stream to output debugging information * */
+  public void setInfoStream(InfoStream infoStream) {
+    this.infoStream = infoStream;
+  }
+
+  public ConcurrentOnHeapHnswGraph getGraph() {
+    return hnsw;
+  }
+
+  /** Number of inserts in progress, across all threads. */
+  public int insertsInProgress() {
+    return insertionsInProgress.size();
+  }
+
+  /**
+   * Inserts a doc with vector value to the graph.
+   *
+   * <p>To allow correctness under concurrency, we track in-progress updates in a
+   * ConcurrentSkipListSet. After adding ourselves, we take a snapshot of this set, and consider all
+   * other in-progress updates as neighbor candidates (subject to normal level constraints).
+   *
+   * @param node the node ID to add
+   * @param value the vector value to add
+   * @return an estimate of the number of extra bytes used by the graph after adding the given node
+   */
+  public long addGraphNode(int node, T value) throws IOException {
+    // do this before adding to in-progress, so a concurrent writer checking
+    // the in-progress set doesn't have to worry about uninitialized neighbor sets
+    final int nodeLevel = getRandomGraphLevel(ml);
+    for (int level = nodeLevel; level >= 0; level--) {
+      hnsw.addNode(level, node);
+    }
+
+    HnswGraph consistentView = hnsw.getView();
+    NodeAtLevel progressMarker = new NodeAtLevel(nodeLevel, node);
+    insertionsInProgress.add(progressMarker);
+    ConcurrentSkipListSet<NodeAtLevel> inProgressBefore = insertionsInProgress.clone();
+    try {
+      // find ANN of the new node by searching the graph
+      NodeAtLevel entry = hnsw.entry();
+      int ep = entry.node;
+      int[] eps = ep >= 0 ? new int[] {ep} : new int[0];
+      var gs = graphSearcher.get();
+
+      // for levels > nodeLevel search with topk = 1
+      NeighborQueue candidates = new NeighborQueue(1, false);
+      for (int level = entry.level; level > nodeLevel; level--) {
+        candidates.clear();
+        gs.searchLevel(
+            candidates,
+            value,
+            1,
+            level,
+            eps,
+            vectors.get(),
+            consistentView,
+            null,
+            Integer.MAX_VALUE);
+        eps = new int[] {candidates.pop()};
+      }
+
+      // for levels <= nodeLevel search with topk = beamWidth, and add connections
+      candidates = beamCandidates.get();
+      for (int level = Math.min(nodeLevel, entry.level); level >= 0; level--) {
+        candidates.clear();
+        // find best "natural" candidates at this level with a beam search
+        gs.searchLevel(
+            candidates,
+            value,
+            beamWidth,
+            level,
+            eps,
+            vectors.get(),
+            consistentView,
+            null,
+            Integer.MAX_VALUE);
+        eps = candidates.nodes();
+
+        // Update entry points and neighbors with these candidates.
+        //
+        // Note: We don't want to over-prune the neighbors, which can
+        // happen if we group the concurrent candidates and the natural candidates together.
+        //
+        // Consider the following graph with "circular" test vectors:
+        //
+        // 0 -> 1
+        // 1 <- 0
+        // At this point we insert nodes 2 and 3 concurrently, denoted T1 and T2 for threads 1 and 2
+        //   T1                                    T2
+        //   insert 2 to L1 [2 is marked "in progress"]
+        //                                         insert 3 to L1
+        //                                         3 considers as neighbors 0, 1, 2; 0 and 1 are not diverse wrt 2
+        //                                         3 -> 2 is added to graph
+        //                                         3 is marked entry node
+        //   2 follows 3 to L0, where 3 only has 2 as a neighbor
+        //   2 -> 3 is added to graph
+        //   all further nodes will only be added to the 2/3 subgraph; 0/1 are partitioned forever
+        //
+        // Considering concurrent inserts separately from natural candidates solves this problem;
+        // both 1 and 2 will be added as neighbors to 3, avoiding the partition, and 2 will then
+        // pick up the connection to 1 that it's supposed to have as well.
+        addForwardLinks(level, node, candidates); // natural candidates
+        addForwardLinks(level, node, inProgressBefore, progressMarker); // concurrent candidates

Review Comment:
The "over-pruning" seems to be controlled by beam width and the furthest candidate. Why do we ignore the furthest candidate here? It seems like we shouldn't even bother adding forward links to the in-progress candidates if none of them is closer than the current furthest candidate: if they were already part of the graph, they would likely be ignored. Why aren't beam width and ignoring candidates further than the current furthest one adequate?

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
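To make the alternative raised in the review comment above concrete, here is a minimal, hypothetical sketch in plain Java. It is not code from the PR or from Lucene. It assumes dot-product similarity (higher score means closer), plain `float[]` vectors instead of `RandomAccessVectorValues`, and invented names (`InProgressFilterSketch`, `filterInProgress`). An in-progress node is kept as a forward-link candidate only if it scores strictly better than the furthest candidate the beam search already kept; when the beam is not full, the caller is assumed to pass `Float.NEGATIVE_INFINITY` so every in-progress node can still compete.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

/**
 * Hypothetical sketch, not Lucene code: filter in-progress (concurrently inserted) nodes so that
 * only those closer than the furthest "natural" beam-search candidate become link candidates.
 */
public final class InProgressFilterSketch {

  private InProgressFilterSketch() {}

  /** Dot-product similarity; higher means closer. Assumes equal-length vectors. */
  public static float dot(float[] a, float[] b) {
    float sum = 0f;
    for (int i = 0; i < a.length; i++) {
      sum += a[i] * b[i];
    }
    return sum;
  }

  /**
   * Returns ids of in-progress nodes scoring strictly above {@code furthestNaturalScore},
   * closest first.
   *
   * @param query the vector being inserted
   * @param furthestNaturalScore score of the worst candidate kept by a full beam, or
   *     Float.NEGATIVE_INFINITY when the beam is not full (no cutoff applies)
   * @param inProgress node id to vector, for inserts in flight when this insert started
   */
  public static List<Integer> filterInProgress(
      float[] query, float furthestNaturalScore, Map<Integer, float[]> inProgress) {
    List<Integer> kept = new ArrayList<>();
    for (Map.Entry<Integer, float[]> e : inProgress.entrySet()) {
      // drop nodes that are no closer than the furthest natural candidate
      if (dot(query, e.getValue()) > furthestNaturalScore) {
        kept.add(e.getKey());
      }
    }
    // closest first: sort by descending similarity
    kept.sort(
        Comparator.comparingDouble((Integer id) -> dot(query, inProgress.get(id))).reversed());
    return kept;
  }

  public static void main(String[] args) {
    float[] query = {1f, 0f};
    Map<Integer, float[]> inProgress =
        Map.of(
            2, new float[] {0.9f, 0.1f}, // score 0.9
            3, new float[] {0.1f, 0.9f}, // score 0.1
            4, new float[] {0.6f, 0.4f}); // score 0.6
    // assume the furthest natural candidate in a full beam scored 0.5
    System.out.println(filterInProgress(query, 0.5f, inProgress));
  }
}
```

Run as written, `main` prints `[2, 4]`: node 3 (score 0.1) falls below the assumed cutoff of 0.5 and is not linked. Whether such a cutoff still avoids the 0/1 partition described in the diff comment is exactly the open question in the review; the sketch only shows the mechanics of the filter.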