diff --git a/README.md b/README.md index 7ba559888..a21d6fa4a 100644 --- a/README.md +++ b/README.md @@ -10,17 +10,19 @@ There are two broad categories of ANN index: Graph-based indexes tend to be simpler to implement and faster, but more importantly they can be constructed and updated incrementally. This makes them a much better fit for a general-purpose index than partitioning approaches that only work on static datasets that are completely specified up front. That is why all the major commercial vector indexes use graph approaches. -JVector is a graph index in the DiskANN family tree. +JVector is a graph index that merges the DiskANN and HNSW family trees. +JVector borrows the hierarchical structure from HNSW and uses Vamana (the algorithm behind DiskANN) within each layer. ## JVector Architecture -JVector is a graph-based index that builds on the DiskANN design with composeable extensions. +JVector is a graph-based index that builds on the HNSW and DiskANN designs with composable extensions. -JVector implements a single-layer graph with nonblocking concurrency control, allowing construction to scale linearly with the number of cores: +JVector implements a multi-layer graph with nonblocking concurrency control, allowing construction to scale linearly with the number of cores: ![JVector scales linearly as thread count increases](https://github.com/jbellis/jvector/assets/42158/f0127bfc-6c45-48b9-96ea-95b2120da0d9) -The graph is represented by an on-disk adjacency list per node, with additional data stored inline to support two-pass searches, with the first pass powered by lossily compressed representations of the vectors kept in memory, and the second by a more accurate representation read from disk. The first pass can be performed with +The upper layers of the hierarchy are represented by an in-memory adjacency list per node. This allows for quick navigation without any disk I/O. +The bottom layer of the graph is represented by an on-disk adjacency list per node. JVector uses additional data stored inline to support two-pass searches, with the first pass powered by lossily compressed representations of the vectors kept in memory, and the second by a more accurate representation read from disk.
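+To make the two-pass idea concrete, here is a minimal, self-contained sketch. This is illustrative plain Java with stand-in scoring functions, not JVector's API (the real classes appear in the tutorial below):
+```java
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+class TwoPassSketch {
+    // pass 1 stand-in: a cheap, lossy score computed entirely in memory
+    // (the real index scores candidates with PQ/BQ-compressed vectors)
+    static float approxScore(float[] q, float[] v) {
+        float sum = 0;
+        for (int i = 0; i < 4; i++) sum += (q[i] - v[i]) * (q[i] - v[i]); // only a few dimensions; assumes dimension >= 4
+        return -sum;
+    }
+
+    // pass 2 stand-in: the exact score over the full vector
+    // (the real index reads the full-fidelity vector from disk)
+    static float exactScore(float[] q, float[] v) {
+        float sum = 0;
+        for (int i = 0; i < q.length; i++) sum += (q[i] - v[i]) * (q[i] - v[i]);
+        return -sum;
+    }
+
+    static List<Integer> search(float[] q, float[][] db, int topK) {
+        // pass 1: rank candidates by the cheap score, overquerying to compensate for lossiness
+        // (the real index only scores nodes visited by the graph traversal)
+        var candidates = IntStream.range(0, db.length).boxed()
+                .sorted(Comparator.comparingDouble(i -> -approxScore(q, db[i])))
+                .limit(3L * topK)
+                .collect(Collectors.toList());
+        // pass 2: rerank the survivors exactly and keep the best topK
+        return candidates.stream()
+                .sorted(Comparator.comparingDouble(i -> -exactScore(q, db[i])))
+                .limit(topK)
+                .collect(Collectors.toList());
+    }
+}
+```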
The first pass can be performed with * Product quantization (PQ), optionally with [anisotropic weighting](https://arxiv.org/abs/1908.10396) * [Binary quantization](https://huggingface.co/blog/embedding-quantization) (BQ) * Fused ADC, where PQ codebooks are transposed and written inline with the graph adjacency list @@ -51,7 +53,7 @@ First the code: int originalDimension = baseVectors.get(0).length(); // wrap the raw vectors in a RandomAccessVectorValues RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(baseVectors, originalDimension); - + // score provider using the raw, in-memory vectors BuildScoreProvider bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, VectorSimilarityFunction.EUCLIDEAN); try (GraphIndexBuilder builder = new GraphIndexBuilder(bsp, @@ -59,7 +61,8 @@ First the code: 16, // graph degree 100, // construction search depth 1.2f, // allow degree overflow during construction by this factor - 1.2f)) // relax neighbor diversity requirement by this factor + 1.2f, // relax neighbor diversity requirement by this factor (alpha) + true)) // use a hierarchical index { // build the index (in memory) OnHeapGraphIndex index = builder.build(ravv); @@ -86,6 +89,7 @@ Commentary: * For the overflow Builder parameter, the sweet spot is about 1.2 for in-memory construction and 1.5 for on-disk. (The more overflow is allowed, the fewer recomputations of best edges are required, but the more neighbors will be consulted in every search.) * The alpha parameter controls the tradeoff between edge distance and diversity; usually 1.2 is sufficient for high-dimensional vectors; 2.0 is recommended for 2D or 3D datasets. See [the DiskANN paper](https://suhasjs.github.io/files/diskann_neurips19.pdf) for more details. * The Bits parameter to GraphSearcher is intended for controlling your resultset based on external predicates and won’t be used in this tutorial. +* Setting the addHierarchy parameter to true builds a multi-layer index. This approach has proven more robust in highly challenging scenarios. #### Step 2: more control over GraphSearcher @@ -129,7 +133,7 @@ This is expected given the approximate nature of the index being created and the code: ```java Path indexPath = Files.createTempFile("siftsmall", ".inline"); - try (GraphIndexBuilder builder = new GraphIndexBuilder(bsp, ravv.dimension(), 16, 100, 1.2f, 1.2f)) { + try (GraphIndexBuilder builder = new GraphIndexBuilder(bsp, ravv.dimension(), 16, 100, 1.2f, 1.2f, true)) { // build the index (in memory) OnHeapGraphIndex index = builder.build(ravv); // write the index to disk with default options @@ -218,7 +222,7 @@ Then we need to set up an OnDiskGraphIndexWriter with full control over the cons Path indexPath = Files.createTempFile("siftsmall", ".inline"); Path pqPath = Files.createTempFile("siftsmall", ".pq"); // Builder creation looks mostly the same - try (GraphIndexBuilder builder = new GraphIndexBuilder(bsp, ravv.dimension(), 16, 100, 1.2f, 1.2f); + try (GraphIndexBuilder builder = new GraphIndexBuilder(bsp, ravv.dimension(), 16, 100, 1.2f, 1.2f, true); // explicit Writer for the first time, this is what's behind OnDiskGraphIndex.write OnDiskGraphIndexWriter writer = new OnDiskGraphIndexWriter.Builder(builder.getGraph(), indexPath) .with(new InlineVectors(ravv.dimension())) @@ -259,7 +263,7 @@ Commentary: ### Less-obvious points -* Embeddings models product output from a consistent distribution of vectors.
This means that you can save and re-use ProductQuantization codebooks, even for a different set of vectors, as long as you had a sufficiently large training set to build it the first time around. ProductQuantization.MAX_PQ_TRAINING_SET_SIZE (128,000 vectors) has proven to be sufficiently large. +* Embedding models produce output from a consistent distribution of vectors. This means that you can save and re-use ProductQuantization codebooks, even for a different set of vectors, as long as you had a sufficiently large training set to build it the first time around. ProductQuantization.MAX_PQ_TRAINING_SET_SIZE (128,000 vectors) has proven to be sufficiently large. * JDK ThreadLocal objects cannot be referenced except from the thread that created them. This is a difficult design into which to fit caching of Closeable objects like GraphSearcher. JVector provides the ExplicitThreadLocal class to solve this. * Fused ADC is only compatible with Product Quantization, not Binary Quantization. This is no great loss since [very few models generate embeddings that are best suited for BQ](https://thenewstack.io/why-vector-size-matters/). That said, BQ continues to be supported with non-Fused indexes. * JVector heavily utilizes the Panama Vector API(SIMD) for ANN indexing and search. We have seen cases where the memory bandwidth is saturated during indexing and product quantization and can cause the process to slow down. To avoid this, the batch methods for index and PQ builds use a [PhysicalCoreExecutor](https://javadoc.io/doc/io.github.jbellis/jvector/latest/io/github/jbellis/jvector/util/PhysicalCoreExecutor.html) to limit the amount of operations to the physical core count. The default value is 1/2 the processor count seen by Java. This may not be correct in all setups (e.g. no hyperthreading or hybrid architectures) so if you wish to override the default use the `-Djvector.physical_core_count` property, or pass in your own ForkJoinPool instance. diff --git a/UPGRADING.md b/UPGRADING.md index e4c58ca0d..7dfbb317d 100644 --- a/UPGRADING.md +++ b/UPGRADING.md @@ -5,11 +5,22 @@ in each vector with high accuracy by first applying a nonlinear transformation that is individually fit to each vector. These nonlinearities are designed to be lightweight and have a negligible impact on distance computation performance. +- Support for hierarchical graph indices. This new type of index blends HNSW and DiskANN in a novel way. An + HNSW-like hierarchy resides in memory for quickly seeding the search. This also reduces the need for caching the + DiskANN graph near the entry point. The base layer of the hierarchy is a DiskANN-like index and inherits its + properties. This hierarchical structure can be disabled, leaving just the base DiskANN layer. ## API changes - MemorySegmentReader.Supplier and SimpleMappedReader.Supplier must now be explicitly closed, instead of being closed by the first Reader created from them. - OnDiskGraphIndex no longer closes its ReaderSupplier +- The constructor of GraphIndexBuilder takes an additional parameter that enables or disables the use of the + hierarchy. +- GraphSearcher can be configured to run pruned searches using GraphSearcher.usePruning. When this is set to true, + the search terminates early. In certain cases, this can accelerate the search at the potential cost of some + accuracy. It is set to false by default. +- The constructors of GraphIndexBuilder allow specifying different maximum out-degrees for the graphs in each layer, for example:
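+  ```java
+  // A sketch of the per-layer degree constructor (hypothetical values; bsp and ravv as in the README tutorial).
+  // Assuming the list is indexed by layer, layer 0 (the on-disk base layer) gets degree 32 and all higher layers get 16.
+  try (var builder = new GraphIndexBuilder(bsp, ravv.dimension(),
+                                           List.of(32, 16), // max out-degree per layer
+                                           100,             // construction search depth
+                                           1.2f,            // degree overflow
+                                           1.2f,            // alpha
+                                           true))           // hierarchy must be enabled to use multiple degrees
+  {
+      OnHeapGraphIndex index = builder.build(ravv);
+  }
+  ```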
+ However, this feature does not work with FusedADC in this version. ### API changes in 3.0.6 diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/IndexConstructionWithStaticSetBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/IndexConstructionWithStaticSetBenchmark.java index b16df49e2..61794c23a 100644 --- a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/IndexConstructionWithStaticSetBenchmark.java +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/IndexConstructionWithStaticSetBenchmark.java @@ -78,7 +78,7 @@ public void tearDown() throws IOException { @Benchmark public void buildIndexBenchmark(Blackhole blackhole) throws IOException { // score provider using the raw, in-memory vectors - try (final var graphIndexBuilder = new GraphIndexBuilder(bsp, ravv.dimension(), M, beamWidth, 1.2f, 1.2f)) { + try (final var graphIndexBuilder = new GraphIndexBuilder(bsp, ravv.dimension(), M, beamWidth, 1.2f, 1.2f, true)) { final var graphIndex = graphIndexBuilder.build(ravv); blackhole.consume(graphIndex); } diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/RandomVectorsBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/RandomVectorsBenchmark.java index 33e7d9f7d..6f51fb0a1 100644 --- a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/RandomVectorsBenchmark.java +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/RandomVectorsBenchmark.java @@ -81,7 +81,8 @@ public void setup() throws IOException { 16, // graph degree 100, // construction search depth 1.2f, // allow degree overflow during construction by this factor - 1.2f); // relax neighbor diversity requirement by this factor + 1.2f, // relax neighbor diversity requirement by this factor + true); // add the hierarchy graphIndex = graphIndexBuilder.build(ravv); } diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/StaticSetVectorsBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/StaticSetVectorsBenchmark.java index 184aa2145..1e10c425d 100644 --- a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/StaticSetVectorsBenchmark.java +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/StaticSetVectorsBenchmark.java @@ -69,7 +69,8 @@ public void setup() throws IOException { 16, // graph degree 100, // construction search depth 1.2f, // allow degree overflow during construction by this factor - 1.2f); // relax neighbor diversity requirement by this factor + 1.2f, // relax neighbor diversity requirement by this factor + true); // add the hierarchy graphIndex = graphIndexBuilder.build(ravv); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ByteBufferReader.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ByteBufferReader.java index 55f578943..610069b7a 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ByteBufferReader.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ByteBufferReader.java @@ -74,7 +74,12 @@ public int readInt() { } @Override - public float readFloat() throws IOException { + public long readLong() { + return bb.getLong(); + } + + @Override + public float readFloat() { return bb.getFloat(); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/RandomAccessReader.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/RandomAccessReader.java index f3bea3c82..4a25457b2 100644 --- 
a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/RandomAccessReader.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/RandomAccessReader.java @@ -40,6 +40,8 @@ public interface RandomAccessReader extends AutoCloseable { float readFloat() throws IOException; + long readLong() throws IOException; + void readFully(byte[] bytes) throws IOException; void readFully(ByteBuffer buffer) throws IOException; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleReader.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleReader.java index caf79f469..fb0634029 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleReader.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleReader.java @@ -48,6 +48,11 @@ public int readInt() throws IOException { return raf.readInt(); } + @Override + public long readLong() throws IOException { + return raf.readLong(); + } + @Override public float readFloat() throws IOException { return raf.readFloat(); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java index 03a034d93..f25f61d8c 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java @@ -24,6 +24,7 @@ import io.github.jbellis.jvector.util.DenseIntMap; import io.github.jbellis.jvector.util.DocIdSetIterator; import io.github.jbellis.jvector.util.FixedBitSet; +import io.github.jbellis.jvector.util.IntMap; import static java.lang.Math.min; @@ -31,13 +32,13 @@ * Encapsulates operations on a graph's neighbors. 
*/ public class ConcurrentNeighborMap { - private final DenseIntMap neighbors; + final IntMap neighbors; /** the diversity threshold; 1.0 is equivalent to HNSW; Vamana uses 1.2 or more */ - private final float alpha; + final float alpha; /** used to compute diversity */ - private final BuildScoreProvider scoreProvider; + final BuildScoreProvider scoreProvider; /** the maximum number of neighbors desired per node */ public final int maxDegree; @@ -45,11 +46,16 @@ public class ConcurrentNeighborMap { public final int maxOverflowDegree; public ConcurrentNeighborMap(BuildScoreProvider scoreProvider, int maxDegree, int maxOverflowDegree, float alpha) { + this(new DenseIntMap<>(1024), scoreProvider, maxDegree, maxOverflowDegree, alpha); + } + + public ConcurrentNeighborMap(IntMap neighbors, BuildScoreProvider scoreProvider, int maxDegree, int maxOverflowDegree, float alpha) { + assert maxDegree <= maxOverflowDegree : String.format("maxDegree %d exceeds maxOverflowDegree %d", maxDegree, maxOverflowDegree); + this.neighbors = neighbors; this.alpha = alpha; this.scoreProvider = scoreProvider; this.maxDegree = maxDegree; this.maxOverflowDegree = maxOverflowDegree; - neighbors = new DenseIntMap<>(1024); } public void insertEdge(int fromId, int toId, float score, float overflow) { @@ -103,6 +109,7 @@ public void replaceDeletedNeighbors(int nodeId, BitSet toDelete, NodeArray candi public Neighbors insertDiverse(int nodeId, NodeArray candidates) { while (true) { var old = neighbors.get(nodeId); + assert old != null : nodeId; // graph.addNode should always be called before this method var next = old.insertDiverse(candidates, this); if (next == old || neighbors.compareAndPut(nodeId, old, next)) { return next; @@ -132,10 +139,6 @@ public void addNode(int nodeId) { addNode(nodeId, new NodeArray(0)); } - public NodesIterator nodesIterator() { - return neighbors.keysIterator(); - } - public Neighbors remove(int node) { return neighbors.remove(node); } @@ -262,7 +265,9 @@ private Neighbors insertDiverse(NodeArray toMerge, ConcurrentNeighborMap map) { retainDiverse(merged, 0, map); } // insertDiverse usually gets called with a LOT of candidates, so trim down the resulting NodeArray - var nextNodes = merged.getArrayLength() <= map.nodeArrayLength() ? merged : merged.copy(map.nodeArrayLength()); + var nextNodes = merged.getArrayLength() <= map.nodeArrayLength() + ? 
merged + : merged.copy(map.nodeArrayLength()); return new Neighbors(nodeId, nextNodes); } @@ -402,16 +407,20 @@ public NeighborWithShortEdges(Neighbors neighbors, double shortEdges) { } } - private static class NeighborIterator extends NodesIterator { + private static class NeighborIterator implements NodesIterator { private final NodeArray neighbors; private int i; private NeighborIterator(NodeArray neighbors) { - super(neighbors.size()); this.neighbors = neighbors; i = 0; } + @Override + public int size() { + return neighbors.size(); + } + @Override public boolean hasNext() { return i < neighbors.size(); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndex.java index 6eb09f02a..602bf71f1 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndex.java @@ -29,6 +29,7 @@ import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.types.VectorFloat; +import java.util.Objects; import java.io.Closeable; import java.io.IOException; @@ -45,7 +46,10 @@ */ public interface GraphIndex extends AutoCloseable, Accountable { /** Returns the number of nodes in the graph */ - int size(); + @Deprecated + default int size() { + return size(0); + } /** * Get all node ordinals included in the graph. The nodes are NOT guaranteed to be * * @return an iterator over nodes where {@code nextInt} returns the next node. */ - NodesIterator getNodes(); + NodesIterator getNodes(int level); /** * Return a View with which to navigate the graph. Views are not threadsafe -- that is, @@ -68,7 +72,7 @@ public interface GraphIndex extends AutoCloseable, Accountable { View getView(); /** - * @return the maximum number of edges per node + * @return the maximum number of edges per node across any layer */ int maxDegree(); @@ -91,6 +95,25 @@ default boolean containsNode(int nodeId) { @Override void close() throws IOException; + /** + * @return The maximum (coarsest) level that contains a vector in the graph. + */ + int getMaxLevel(); + + /** + * Return the maximum out-degree allowed at the given level. + * @param level The level of interest + * @return the maximum out-degree of the given level + */ + int getDegree(int level); + + /** + * Return the number of vectors/nodes in the given level. + * @param level The level of interest + * @return the number of vectors in the given level + */ + int size(int level); + /** * Encapsulates the state of a graph for searching. Re-usable across search calls, * but each thread needs its own. */ interface View extends Closeable { /** * Iterator over the neighbors of a given node. Only the most recently instantiated iterator * is guaranteed to be valid. */ - NodesIterator getNeighborsIterator(int node); + NodesIterator getNeighborsIterator(int level, int node); /** + * This method is deprecated as most View usages should not need size. + * Where they do, they can query the graph directly. * @return the number of nodes in the graph */ + @Deprecated int size(); /** * @return the node of the graph to start searches at */ - int entryNode(); + NodeAtLevel entryNode(); /** * Return a Bits instance indicating which nodes are live.
The result is undefined for @@ -141,12 +167,16 @@ static String prettyPrint(GraphIndex graph) { sb.append("\n"); try (var view = graph.getView()) { - NodesIterator it = graph.getNodes(); - while (it.hasNext()) { - int node = it.nextInt(); - sb.append(" ").append(node).append(" -> "); - for (var neighbors = view.getNeighborsIterator(node); neighbors.hasNext(); ) { - sb.append(" ").append(neighbors.nextInt()); + for (int level = 0; level <= graph.getMaxLevel(); level++) { + sb.append(String.format("# Level %d\n", level)); + NodesIterator it = graph.getNodes(level); + while (it.hasNext()) { + int node = it.nextInt(); + sb.append(" ").append(node).append(" -> "); + for (var neighbors = view.getNeighborsIterator(level, node); neighbors.hasNext(); ) { + sb.append(" ").append(neighbors.nextInt()); + } + sb.append("\n"); } sb.append("\n"); } @@ -156,4 +186,44 @@ static String prettyPrint(GraphIndex graph) { return sb.toString(); } + + // Comparable b/c it gets used in ConcurrentSkipListMap + final class NodeAtLevel implements Comparable<NodeAtLevel> { + public final int level; + public final int node; + + public NodeAtLevel(int level, int node) { + assert level >= 0 : level; + assert node >= 0 : node; + this.level = level; + this.node = node; + } + + @Override + public int compareTo(NodeAtLevel o) { + int cmp = Integer.compare(level, o.level); + if (cmp == 0) { + cmp = Integer.compare(node, o.node); + } + return cmp; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof NodeAtLevel)) return false; + NodeAtLevel that = (NodeAtLevel) o; + return level == that.level && node == that.node; + } + + @Override + public int hashCode() { + return Objects.hash(level, node); + } + + @Override + public String toString() { + return "NodeAtLevel(level=" + level + ", node=" + node + ")"; + } + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index 3f35b72b1..ef82243e2 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -18,10 +18,11 @@ import io.github.jbellis.jvector.annotations.VisibleForTesting; import io.github.jbellis.jvector.disk.RandomAccessReader; +import io.github.jbellis.jvector.graph.GraphIndex.NodeAtLevel; +import io.github.jbellis.jvector.graph.SearchResult.NodeScore; import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; -import io.github.jbellis.jvector.util.AtomicFixedBitSet; -import io.github.jbellis.jvector.util.BitSet; +import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.util.ExceptionUtils; import io.github.jbellis.jvector.util.ExplicitThreadLocal; @@ -29,15 +30,13 @@ import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.types.VectorFloat; import org.agrona.collections.IntArrayList; -import org.agrona.collections.IntArrayQueue; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.Closeable; import java.io.IOException; import java.io.UncheckedIOException; -import java.util.ArrayList; -import java.util.Set; +import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentSkipListSet; import
java.util.concurrent.ForkJoinPool; @@ -45,9 +44,8 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.IntStream; -import static io.github.jbellis.jvector.graph.OnHeapGraphIndex.NO_ENTRY_POINT; import static io.github.jbellis.jvector.util.DocIdSetIterator.NO_MORE_DOCS; -import static io.github.jbellis.jvector.vector.VectorUtil.dotProduct; +import static java.lang.Math.*; /** * Builder for Concurrent GraphIndex. See {@link GraphIndex} for a high level overview, and the @@ -59,6 +57,8 @@ * that spawning a new Thread per call is not advisable. This includes virtual threads. */ public class GraphIndexBuilder implements Closeable { + private static final int BUILD_BATCH_SIZE = 50; + private static final Logger logger = LoggerFactory.getLogger(GraphIndexBuilder.class); private final int beamWidth; @@ -68,12 +68,12 @@ public class GraphIndexBuilder implements Closeable { private final int dimension; private final float neighborOverflow; private final float alpha; + private final boolean addHierarchy; @VisibleForTesting final OnHeapGraphIndex graph; - private double averageShortEdges = Double.NaN; - private final ConcurrentSkipListSet<Integer> insertionsInProgress = new ConcurrentSkipListSet<>(); + private final ConcurrentSkipListSet<NodeAtLevel> insertionsInProgress = new ConcurrentSkipListSet<>(); private final BuildScoreProvider scoreProvider; @@ -82,7 +82,7 @@ public class GraphIndexBuilder implements Closeable { private final ExplicitThreadLocal<GraphSearcher> searchers; - private final AtomicInteger updateEntryNodeIn = new AtomicInteger(10_000); + private final Random rng; /** * Reads all the vectors from vector values, builds a graph connecting them by their dense @@ -97,20 +97,23 @@ public class GraphIndexBuilder implements Closeable { * @param alpha how aggressive pruning diverse neighbors should be. Set alpha > 1.0 to * allow longer edges. If alpha = 1.0 then the equivalent of the lowest level of * an HNSW graph will be created, which is usually not what you want. + * @param addHierarchy whether we want to add an HNSW-style hierarchy on top of the Vamana index. */ public GraphIndexBuilder(RandomAccessVectorValues vectorValues, VectorSimilarityFunction similarityFunction, int M, int beamWidth, float neighborOverflow, - float alpha) + float alpha, + boolean addHierarchy) { this(BuildScoreProvider.randomAccessScoreProvider(vectorValues, similarityFunction), - vectorValues.dimension(), - M, - beamWidth, - neighborOverflow, - alpha); + vectorValues.dimension(), + M, + beamWidth, + neighborOverflow, + alpha, + addHierarchy); } /** @@ -126,15 +129,17 @@ public GraphIndexBuilder(RandomAccessVectorValues vectorValues, * @param alpha how aggressive pruning diverse neighbors should be. Set alpha > 1.0 to * allow longer edges. If alpha = 1.0 then the equivalent of the lowest level of * an HNSW graph will be created, which is usually not what you want. + * @param addHierarchy whether we want to add an HNSW-style hierarchy on top of the Vamana index.
*/ public GraphIndexBuilder(BuildScoreProvider scoreProvider, int dimension, int M, int beamWidth, float neighborOverflow, - float alpha) + float alpha, + boolean addHierarchy) { - this(scoreProvider, dimension, M, beamWidth, neighborOverflow, alpha, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool()); + this(scoreProvider, dimension, M, beamWidth, neighborOverflow, alpha, addHierarchy, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool()); } /** @@ -149,6 +154,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, * @param alpha how aggressive pruning diverse neighbors should be. Set alpha > 1.0 to * allow longer edges. If alpha = 1.0 then the equivalent of the lowest level of * an HNSW graph will be created, which is usually not what you want. + * @param addHierarchy whether we want to add an HNSW-style hierarchy on top of the Vamana index. * @param simdExecutor ForkJoinPool instance for SIMD operations, best is to use a pool with the size of * the number of physical cores. * @param parallelExecutor ForkJoinPool instance for parallel stream operations @@ -159,11 +165,73 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, int beamWidth, float neighborOverflow, float alpha, + boolean addHierarchy, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) { - if (M <= 0) { - throw new IllegalArgumentException("maxConn must be positive"); + this(scoreProvider, dimension, List.of(M), beamWidth, neighborOverflow, alpha, addHierarchy, simdExecutor, parallelExecutor); + } + + /** + * Reads all the vectors from vector values, builds a graph connecting them by their dense + * ordinals, using the given hyperparameter settings, and returns the resulting graph. + * Default executor pools are used. + * + * @param scoreProvider describes how to determine the similarities between vectors + * @param maxDegrees the maximum number of connections a node can have in each layer; if fewer entries + * are specified than the number of layers, the last entry is used for all remaining layers. + * @param beamWidth the size of the beam search to use when finding nearest neighbors. + * @param neighborOverflow the ratio of extra neighbors to allow temporarily when inserting a + * node. larger values will build more efficiently, but use more memory. + * @param alpha how aggressive pruning diverse neighbors should be. Set alpha > 1.0 to + * allow longer edges. If alpha = 1.0 then the equivalent of the lowest level of + * an HNSW graph will be created, which is usually not what you want. + * @param addHierarchy whether we want to add an HNSW-style hierarchy on top of the Vamana index. + */ + public GraphIndexBuilder(BuildScoreProvider scoreProvider, + int dimension, + List<Integer> maxDegrees, + int beamWidth, + float neighborOverflow, + float alpha, + boolean addHierarchy) + { + this(scoreProvider, dimension, maxDegrees, beamWidth, neighborOverflow, alpha, addHierarchy, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool()); + } + + /** + * Reads all the vectors from vector values, builds a graph connecting them by their dense + * ordinals, using the given hyperparameter settings, and returns the resulting graph. + * + * @param scoreProvider describes how to determine the similarities between vectors + * @param maxDegrees the maximum number of connections a node can have in each layer; if fewer entries + * are specified than the number of layers, the last entry is used for all remaining layers. + * @param beamWidth the size of the beam search to use when finding nearest neighbors.
+ * @param neighborOverflow the ratio of extra neighbors to allow temporarily when inserting a + * node. larger values will build more efficiently, but use more memory. + * @param alpha how aggressive pruning diverse neighbors should be. Set alpha > 1.0 to + * allow longer edges. If alpha = 1.0 then the equivalent of the lowest level of + * an HNSW graph will be created, which is usually not what you want. + * @param addHierarchy whether we want to add an HNSW-style hierarchy on top of the Vamana index. + * @param simdExecutor ForkJoinPool instance for SIMD operations, best is to use a pool with the size of + * the number of physical cores. + * @param parallelExecutor ForkJoinPool instance for parallel stream operations + */ + public GraphIndexBuilder(BuildScoreProvider scoreProvider, + int dimension, + List<Integer> maxDegrees, + int beamWidth, + float neighborOverflow, + float alpha, + boolean addHierarchy, + ForkJoinPool simdExecutor, + ForkJoinPool parallelExecutor) + { + if (maxDegrees.stream().anyMatch(i -> i <= 0)) { + throw new IllegalArgumentException("layer degrees must be positive"); + } + if (maxDegrees.size() > 1 && !addHierarchy) { + throw new IllegalArgumentException("Cannot specify multiple max degrees with addHierarchy=false"); } if (beamWidth <= 0) { throw new IllegalArgumentException("beamWidth must be positive"); @@ -179,48 +247,67 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, this.dimension = dimension; this.neighborOverflow = neighborOverflow; this.alpha = alpha; + this.addHierarchy = addHierarchy; this.beamWidth = beamWidth; this.simdExecutor = simdExecutor; this.parallelExecutor = parallelExecutor; - int maxOverflowDegree = (int) (M * neighborOverflow); - this.graph = new OnHeapGraphIndex(M, maxOverflowDegree, scoreProvider, alpha); - this.searchers = ExplicitThreadLocal.withInitial(() -> new GraphSearcher(graph)); + this.graph = new OnHeapGraphIndex(maxDegrees, neighborOverflow, scoreProvider, alpha, BUILD_BATCH_SIZE); + this.searchers = ExplicitThreadLocal.withInitial(() -> { + var gs = new GraphSearcher(graph); + gs.usePruning(false); + return gs; + }); // in scratch we store candidates in reverse order: worse candidates are first - this.naturalScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(Math.max(beamWidth, M + 1))); - this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(Math.max(beamWidth, M + 1))); + this.naturalScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1))); + this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1))); + + this.rng = new Random(0); } + // used by Cassandra when it fine-tunes the PQ codebook public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvider newProvider) { var newBuilder = new GraphIndexBuilder(newProvider, - other.dimension, - other.graph.maxDegree(), - other.beamWidth, - other.neighborOverflow, - other.alpha, - other.simdExecutor, - other.parallelExecutor); + other.dimension, + other.graph.maxDegrees, + other.beamWidth, + other.neighborOverflow, + other.alpha, + other.addHierarchy, + other.simdExecutor, + other.parallelExecutor); // Copy each node and its neighbors from the old graph to the new one - for (int i = 0; i < other.graph.getIdUpperBound(); i++) { - if (!other.graph.containsNode(i)) { - continue; - } - - var neighbors = other.graph.getNeighbors(i); - var sf = newProvider.searchProviderFor(i).scoreFunction(); - var newNeighbors = new
NodeArray(neighbors.size()); - - // Copy edges, compute new scores - for (var it = neighbors.iterator(); it.hasNext(); ) { - int neighbor = it.nextInt(); - // since we're using a different score provider, use insertSorted instead of addInOrder - newNeighbors.insertSorted(neighbor, sf.similarityTo(neighbor)); - } + other.parallelExecutor.submit(() -> { + IntStream.range(0, other.graph.getIdUpperBound()).parallel().forEach(i -> { + // Find the highest layer this node exists in + int maxLayer = -1; + for (int lvl = 0; lvl < other.graph.layers.size(); lvl++) { + if (other.graph.getNeighbors(lvl, i) == null) { + break; + } + maxLayer = lvl; + } + if (maxLayer < 0) { + return; + } - newBuilder.graph.addNode(i, newNeighbors); - } + // Loop over 0..maxLayer, re-score neighbors for each layer + var sf = newProvider.searchProviderFor(i).scoreFunction(); + for (int lvl = 0; lvl <= maxLayer; lvl++) { + var oldNeighbors = other.graph.getNeighbors(lvl, i); + // Copy edges, compute new scores + var newNeighbors = new NodeArray(oldNeighbors.size()); + for (var it = oldNeighbors.iterator(); it.hasNext();) { + int neighbor = it.nextInt(); + // since we're using a different score provider, use insertSorted instead of addInOrder + newNeighbors.insertSorted(neighbor, sf.similarityTo(neighbor)); + } + newBuilder.graph.addNode(lvl, i, newNeighbors); + } + }); + }).join(); // Set the entry node newBuilder.graph.updateEntryNode(other.graph.entry()); @@ -233,7 +320,9 @@ public OnHeapGraphIndex build(RandomAccessVectorValues ravv) { int size = ravv.size(); simdExecutor.submit(() -> { - IntStream.range(0, size).parallel().forEach(node -> addGraphNode(node, vv.get().getVector(node))); + IntStream.range(0, size).parallel().forEach(node -> { + addGraphNode(node, vv.get().getVector(node)); + }); }).join(); cleanup(); @@ -251,7 +340,7 @@ public OnHeapGraphIndex build(RandomAccessVectorValues ravv) { * May be called multiple times, but should not be called during concurrent modifications to the graph. */ public void cleanup() { - if (graph.size() == 0) { + if (graph.size(0) == 0) { return; } graph.validateEntryNode(); // sanity check before we start @@ -260,148 +349,57 @@ public void cleanup() { // backlinks can cause neighbors to soft-overflow, so do this before neighbors cleanup removeDeletedNodes(); - if (graph.size() == 0) { + if (graph.size(0) == 0) { // After removing all the deleted nodes, we might end up with an empty graph. // The calls below expect a valid entry node, but we do not have one right now. return; } - // clean up overflowed neighbor lists and compute short edges - averageShortEdges = parallelExecutor.submit( - () -> IntStream.range(0, graph.getIdUpperBound()).parallel() - .mapToDouble(graph.nodes::enforceDegree) - .filter(Double::isFinite) - .average() - ).join().orElse(Double.NaN); - - // optimize entry node -- we do this before reconnecting, as otherwise, improving the entry node's - // connections will tend to disconnect any orphaned nodes reconnected to the entry node - updateEntryPoint(); - - // reconnect any orphaned nodes. this will maintain neighbors size - reconnectOrphanedNodes(); - } - - private void reconnectOrphanedNodes() { - // Set of nodes already used as connection targets, initialized to the entry point. Since reconnection edges are - // usually worse (by distance and/or diversity) than the original ones, we update this as edges are added to - // avoid reusing the same target node more than once. 
- AtomicFixedBitSet globalConnectionTargets = new AtomicFixedBitSet(graph.getIdUpperBound()); - globalConnectionTargets.set(graph.entry()); - // Reconnection is best-effort: reconnecting one node may result in disconnecting another, since we are maintaining - // the maxConnections invariant. So, we do a maximum of 5 loops. - for (int i = 0; i < 5; i++) { - // determine the nodes reachable from the entry point at the start of this pass - var connectedNodes = new AtomicFixedBitSet(graph.getIdUpperBound()); - var entryNeighbors = graph.getNeighbors(graph.entry()); - parallelExecutor.submit(() -> IntStream.range(0, entryNeighbors.size()).parallel().forEach(node -> findConnected(connectedNodes, entryNeighbors.getNode(node)))).join(); - - // Gather basic debug information about efficacy/efficiency of reconnection attempts - var nReconnectAttempts = new AtomicInteger(); - var nReconnectedViaNeighbors = new AtomicInteger(); - var nResumesRun = new AtomicInteger(); - var nReconnectedViaSearch = new AtomicInteger(); - - simdExecutor.submit(() -> IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(node -> { - if (connectedNodes.get(node) || !graph.containsNode(node)) { - return; - } - nReconnectAttempts.incrementAndGet(); - - // first, attempt to connect one of our own connected neighbors to us. Filtering - // to connected nodes tends to help for partitioned graphs with large partitions. - ConcurrentNeighborMap.Neighbors self = graph.getNeighbors(node); - var neighbors = (NodeArray) self; - if (connectToClosestNeighbor(node, neighbors, connectedNodes, globalConnectionTargets) != null) { - nReconnectedViaNeighbors.incrementAndGet(); - return; - } - - // if we can't find a connected neighbor to reconnect to, we'll have to search. We start with a small - // search, and we resume the search in a bounded loop to try to find an eligible connection target. - // This significantly improves behavior for large (1M+ node) partitioned graphs. We don't add - // connectionTargets to excludeBits because large partitions lead to excessively large excludeBits, - // greatly degrading search performance. 
- SearchResult result; - try (var gs = searchers.get()) { - var ssp = scoreProvider.searchProviderFor(node); - int ep = graph.entry(); - result = gs.searchInternal(ssp, beamWidth, beamWidth, 0.0f, 0.0f, ep, other -> other != node); - neighbors = new NodeArray(result.getNodes().length); - toScratchCandidates(result.getNodes(), neighbors); - var j = 0; - // no need to filter to connected nodes here, as they're connected by virtue of being reachable via - // search - var reconnectedTo = connectToClosestNeighbor(node, neighbors, Bits.ALL, globalConnectionTargets); - // if we can't find a valid connectionTarget within 2*degree of the search destination, give up - while (reconnectedTo == null && j < 2 * graph.maxDegree) { - j++; - nResumesRun.incrementAndGet(); - result = gs.resume(beamWidth, beamWidth); - toScratchCandidates(result.getNodes(), neighbors); - reconnectedTo = connectToClosestNeighbor(node, neighbors, Bits.ALL, globalConnectionTargets); - } - - if (reconnectedTo != null) { - nReconnectedViaSearch.incrementAndGet(); - // since we went to the trouble of finding the closest available neighbor, let `backlink` - // check to see if it should be added as an edge to the original node as well - var na = new NodeArray(1); - na.addInOrder(reconnectedTo.node, reconnectedTo.score); - graph.nodes.backlink(na, node, 1.0f); - } - } catch (IOException e) { - throw new UncheckedIOException(e); - } - })).join(); - - logger.debug("Reconnecting {} nodes out of {} on pass {}. {} neighbor reconnects. {} searches/resumes run. {} nodes reconnected via search", - nReconnectAttempts.get(), graph.size(), i, nReconnectedViaNeighbors.get(), nResumesRun.get(), nReconnectedViaSearch.get()); - - if (nReconnectAttempts.get() == 0) { - break; - } + // improve connections on everything in L1 + if (graph.getMaxLevel() > 0) { + parallelExecutor.submit(() -> { + graph.nodeStream(1).parallel().forEach(this::improveConnections); + }).join(); } - } - /** - * Connect `node` to the closest connected neighbor that is not already a connection target. - * - * @return the node score id if such a neighbor was found, else null. 
- */ - private SearchResult.NodeScore connectToClosestNeighbor(int node, NodeArray neighbors, Bits connectedNodes, BitSet connectionTargets) { - // connect this node to the closest connected neighbor that hasn't already been used as a connection target - // (since this edge is likely to be the "worst" one in that target's neighborhood, it's likely to be - // overwritten by the next node to need reconnection if we don't choose a unique target) - for (int i = 0; i < neighbors.size(); i++) { - var neighborNode = neighbors.getNode(i); - if (!connectedNodes.get(neighborNode) || connectionTargets.get(neighborNode)) - continue; - - var neighborScore = neighbors.getScore(i); - graph.nodes.insertEdgeNotDiverse(neighborNode, node, neighborScore); - connectionTargets.set(neighborNode); - return new SearchResult.NodeScore(neighborNode, neighborScore); - } - return null; + // clean up overflowed neighbor lists + parallelExecutor.submit(() -> { + IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(id -> { + for (int layer = 0; layer < graph.layers.size(); layer++) { + graph.layers.get(layer).enforceDegree(id); + } + }); + }).join(); } - private void findConnected(AtomicFixedBitSet connectedNodes, int start) { - var queue = new IntArrayQueue(); - queue.add(start); - try (var view = graph.getView()) { - while (!queue.isEmpty()) { - // DFS should result in less contention across findConnected threads than BFS - int next = queue.pollInt(); - if (connectedNodes.getAndSet(next)) { - continue; - } - for (var it = view.getNeighborsIterator(next); it.hasNext(); ) { - queue.addInt(it.nextInt()); + private void improveConnections(int node) { + var ssp = scoreProvider.searchProviderFor(node); + var bits = new ExcludingBits(node); + try (var gs = searchers.get()) { + gs.initializeInternal(ssp, graph.entry(), bits); + var acceptedBits = Bits.intersectionOf(bits, gs.getView().liveNodes()); + + // Move downward from entry.level to 0 + for (int lvl = graph.entry().level; lvl >= 0; lvl--) { + // This additional call seems redundant given that we have already initialized an ssp above. + // However, there is a subtle interplay between the ssp of the search and the ssp used in insertDiverse. + // Do not remove this line. + ssp = scoreProvider.searchProviderFor(node); + + if (graph.layers.get(lvl).get(node) != null) { + gs.searchOneLayer(ssp, beamWidth, 0.0f, lvl, acceptedBits); + + var candidates = new NodeArray(gs.approximateResults.size()); + gs.approximateResults.foreach(candidates::insertSorted); + var newNeighbors = graph.layers.get(lvl).insertDiverse(node, candidates); + graph.layers.get(lvl).backlink(newNeighbors, node, neighborOverflow); + } else { + gs.searchOneLayer(ssp, 1, 0.0f, lvl, acceptedBits); } + gs.setEntryPointsFromPreviousLayer(); } - } catch (Exception e) { - throw new RuntimeException(e); + } catch (IOException e) { + throw new UncheckedIOException(e); } } @@ -423,6 +421,25 @@ public long addGraphNode(int node, RandomAccessVectorValues ravv) { return addGraphNode(node, ravv.getVector(node)); } + /** + * Assigns a hierarchy level to a node at random. It follows the HNSW sampling strategy. + * @return The assigned level + */ + private int getRandomGraphLevel() { + double ml; + double randDouble; + if (addHierarchy) { + ml = graph.getDegree(0) == 1 ? 
1 : 1 / log(1.0 * graph.getDegree(0)); + do { + randDouble = this.rng.nextDouble(); // avoid 0 value, as log(0) is undefined + } while (randDouble == 0.0); + } else { + ml = 0; + randDouble = 0; + } + return ((int) (-log(randDouble) * ml)); + } + /** * Inserts a node with the given vector value to the graph. * @@ -434,110 +451,76 @@ public long addGraphNode(int node, RandomAccessVectorValues ravv) { * @return an estimate of the number of extra bytes used by the graph after adding the given node */ public long addGraphNode(int node, VectorFloat vector) { + var nodeLevel = new NodeAtLevel(getRandomGraphLevel(), node); // do this before adding to in-progress, so a concurrent writer checking // the in-progress set doesn't have to worry about uninitialized neighbor sets - graph.addNode(node); + graph.addNode(nodeLevel); - insertionsInProgress.add(node); - ConcurrentSkipListSet inProgressBefore = insertionsInProgress.clone(); + insertionsInProgress.add(nodeLevel); + var inProgressBefore = insertionsInProgress.clone(); try (var gs = searchers.get()) { + gs.setView(graph.getView()); // new snapshot var naturalScratchPooled = naturalScratch.get(); var concurrentScratchPooled = concurrentScratch.get(); - // find ANN of the new node by searching the graph - int ep = graph.entry(); - var bits = new ExcludingBits(node); - // find best "natural" candidates with a beam search + var bits = new ExcludingBits(nodeLevel.node); var ssp = scoreProvider.searchProviderFor(vector); - var result = gs.searchInternal(ssp, beamWidth, beamWidth, 0.0f, 0.0f, ep, bits); - - // Update neighbors with these candidates. - // The DiskANN paper calls for using the entire set of visited nodes along the search path as - // potential candidates, but in practice we observe neighbor lists being completely filled using - // just the topK results. (Since the Robust Prune algorithm prioritizes closer neighbors, - // this means that considering additional nodes from the search path, that are by definition - // farther away than the ones in the topK, would not change the result.) - // TODO if we made NeighborArray an interface we could wrap the NodeScore[] directly instead of copying - var natural = toScratchCandidates(result.getNodes(), naturalScratchPooled); - var concurrent = getConcurrentCandidates(node, inProgressBefore, concurrentScratchPooled, ssp.scoreFunction()); - updateNeighbors(node, natural, concurrent); - - maybeUpdateEntryPoint(node); - maybeImproveOlderNode(); - } catch (Exception e) { - throw new RuntimeException(e); - } finally { - insertionsInProgress.remove(node); - } - - return graph.ramBytesUsedOneNode(); - } - - /** - * Improve edge quality on very low-d indexes. This makes a big difference - * in the ability of search to escape local maxima to find better options. - *

- * This has negligible effect on ML embedding-sized vectors, starting at least with GloVe-25, so we don't bother. - * (Dimensions between 4 and 25 are untested but they get left out too.) - * For 2D vectors, this takes us to over 99% recall up to at least 4M nodes. (Higher counts untested.) - */ - private void maybeImproveOlderNode() { - // pick a node added earlier at random to improve its connections - // 20k threshold chosen because that's where recall starts to fall off from 100% for 2D vectors - if (dimension <= 3 && graph.size() > 20_000) { - // if we can't find a candidate in 3 tries, the graph is too sparse, - // we'll have to wait for more nodes to be added (this threshold has been tested w/ parallel build, - // which generates very sparse ids due to how spliterator works) - for (int i = 0; i < 3; i++) { - var olderNode = ThreadLocalRandom.current().nextInt(graph.size()); - if (graph.containsNode(olderNode) && !graph.getDeletedNodes().get(olderNode)) { - improveConnections(olderNode); - break; + var entry = graph.entry(); + SearchResult result; + if (entry == null) { + result = new SearchResult(new NodeScore[] {}, 0, 0, 0); + } else { + gs.initializeInternal(ssp, entry, bits); + + // Move downward from entry.level to 1 + for (int lvl = entry.level; lvl > 0; lvl--) { + if (lvl > nodeLevel.level) { + gs.searchOneLayer(ssp, 1, 0.0f, lvl, gs.getView().liveNodes()); + } else { + gs.searchOneLayer(ssp, beamWidth, 0.0f, lvl, gs.getView().liveNodes()); + NodeScore[] neighbors = new NodeScore[gs.approximateResults.size()]; + AtomicInteger index = new AtomicInteger(); + // TODO extract an interface that lets us avoid the copy here and in toScratchCandidates + gs.approximateResults.foreach((neighbor, score) -> { + neighbors[index.getAndIncrement()] = new NodeScore(neighbor, score); + }); + Arrays.sort(neighbors); + updateNeighborsOneLayer(lvl, nodeLevel.node, neighbors, naturalScratchPooled, inProgressBefore, concurrentScratchPooled, ssp); + } + gs.setEntryPointsFromPreviousLayer(); } + + // Now do the main search at layer 0 + result = gs.resume(0, beamWidth, beamWidth, 0.0f, 0.0f); } - } - } - private void maybeUpdateEntryPoint(int node) { - graph.maybeSetInitialEntryNode(node); // TODO it seems silly to call this long after we've set it the first time + updateNeighborsOneLayer(0, nodeLevel.node, result.getNodes(), naturalScratchPooled, inProgressBefore, concurrentScratchPooled, ssp); - if (updateEntryNodeIn.decrementAndGet() == 0) { - updateEntryPoint(); + graph.markComplete(nodeLevel); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + insertionsInProgress.remove(nodeLevel); } - } - @VisibleForTesting - public void setEntryPoint(int ep) { - graph.updateEntryNode(ep); + return IntStream.range(0, nodeLevel.level).mapToLong(graph::ramBytesUsedOneNode).sum(); } - private void updateEntryPoint() { - int newEntryNode = approximateMedioid(); - graph.updateEntryNode(newEntryNode); - if (newEntryNode >= 0) { - improveConnections(newEntryNode); - updateEntryNodeIn.addAndGet(graph.size()); - } else { - updateEntryNodeIn.addAndGet(10_000); - } + private void updateNeighborsOneLayer(int layer, int node, NodeScore[] neighbors, NodeArray naturalScratchPooled, ConcurrentSkipListSet inProgressBefore, NodeArray concurrentScratchPooled, SearchScoreProvider ssp) { + // Update neighbors with these candidates. 
+ // The DiskANN paper calls for using the entire set of visited nodes along the search path as + // potential candidates, but in practice we observe neighbor lists being completely filled using + // just the topK results. (Since the Robust Prune algorithm prioritizes closer neighbors, + // this means that considering additional nodes from the search path, that are by definition + // farther away than the ones in the topK, would not change the result.) + var natural = toScratchCandidates(neighbors, naturalScratchPooled); + var concurrent = getConcurrentCandidates(layer, node, inProgressBefore, concurrentScratchPooled, ssp.scoreFunction()); + updateNeighbors(layer, node, natural, concurrent); } - private void improveConnections(int node) { - NodeArray naturalScratchPooled; - SearchResult result; - try (var gs = searchers.get()) { - naturalScratchPooled = naturalScratch.get(); - int ep = graph.entry(); - var bits = new ExcludingBits(node); - var ssp = scoreProvider.searchProviderFor(node); - result = gs.searchInternal(ssp, beamWidth, beamWidth, 0.0f, 0.0f, ep, bits); - } catch (Exception e) { - throw new RuntimeException(e); - } - var natural = toScratchCandidates(result.getNodes(), naturalScratchPooled); - var neighbors = graph.nodes.insertDiverse(node, natural); - // no overflow -- this method gets called from cleanup - graph.nodes.backlink(neighbors, node, 1.0f); + @VisibleForTesting + public void setEntryPoint(int level, int node) { + graph.updateEntryNode(new NodeAtLevel(level, node)); } public void markNodeDeleted(int node) { @@ -558,7 +541,7 @@ public synchronized long removeDeletedNodes() { if (nRemoved == 0) { return 0; } - // make a list of remaining live nodes + // make a list of remaining live nodes var liveNodes = new IntArrayList(); for (int i = 0; i < graph.getIdUpperBound(); i++) { if (graph.containsNode(i) && !toDelete.get(i)) { @@ -566,121 +549,109 @@ public synchronized long removeDeletedNodes() { } } - // Compute new edges to insert. If node j is deleted, we add edges (i, k) - // whenever (i, j) and (j, k) are directed edges in the current graph. This - // strategy is proposed in "FreshDiskANN: A Fast and Accurate Graph-Based - // ANN Index for Streaming Similarity Search" section 4.2. - var newEdges = new ConcurrentHashMap>(); // new edges for key k are values v - parallelExecutor.submit(() -> { - IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(i -> { - var neighbors = graph.getNeighbors(i); - if (neighbors == null || toDelete.get(i)) { - return; - } - for (var it = neighbors.iterator(); it.hasNext(); ) { - var j = it.nextInt(); - if (toDelete.get(j)) { - var newEdgesForI = newEdges.computeIfAbsent(i, __ -> ConcurrentHashMap.newKeySet()); - for (var jt = graph.getNeighbors(j).iterator(); jt.hasNext(); ) { - int k = jt.nextInt(); - if (i != k && !toDelete.get(k)) { - newEdgesForI.add(k); + for (int currentLevel = 0; currentLevel < graph.layers.size(); currentLevel++) { + final int level = currentLevel; // Create effectively final copy for lambda + // Compute new edges to insert. If node j is deleted, we add edges (i, k) + // whenever (i, j) and (j, k) are directed edges in the current graph. This + // strategy is proposed in "FreshDiskANN: A Fast and Accurate Graph-Based + // ANN Index for Streaming Similarity Search" section 4.2. 
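+            // For example: if the graph contains edges 3 -> 7 and 7 -> 12, and node 7 is deleted,
+            // then (3, 12) becomes a candidate replacement edge.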
+ var newEdges = new ConcurrentHashMap<Integer, Set<Integer>>(); // new edges for key k are values v + parallelExecutor.submit(() -> { + IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(i -> { + var neighbors = graph.getNeighbors(level, i); + if (neighbors == null || toDelete.get(i)) { + return; + } + for (var it = neighbors.iterator(); it.hasNext(); ) { + var j = it.nextInt(); + if (toDelete.get(j)) { + var newEdgesForI = newEdges.computeIfAbsent(i, __ -> ConcurrentHashMap.newKeySet()); + for (var jt = graph.getNeighbors(level, j).iterator(); jt.hasNext(); ) { + int k = jt.nextInt(); + if (i != k && !toDelete.get(k)) { + newEdgesForI.add(k); + } } } } - }); - }).join(); - - // Remove deleted nodes from neighbors lists; - // Score the new edges, and connect the most diverse ones as neighbors - simdExecutor.submit(() -> { - newEdges.entrySet().stream().parallel().forEach(e -> { - // turn the new edges into a NodeArray - int node = e.getKey(); - // each deleted node has ALL of its neighbors added as candidates, so using approximate - // scoring and then re-scoring only the best options later makes sense here - var sf = scoreProvider.searchProviderFor(node).scoreFunction(); - var candidates = new NodeArray(graph.maxDegree); - for (var k : e.getValue()) { - candidates.insertSorted(k, sf.similarityTo(k)); - } + }); + }).join(); + + // Remove deleted nodes from neighbors lists; + // Score the new edges, and connect the most diverse ones as neighbors + simdExecutor.submit(() -> { + newEdges.entrySet().stream().parallel().forEach(e -> { + // turn the new edges into a NodeArray + int node = e.getKey(); + // each deleted node has ALL of its neighbors added as candidates, so using approximate + // scoring and then re-scoring only the best options later makes sense here + var sf = scoreProvider.searchProviderFor(node).scoreFunction(); + var candidates = new NodeArray(graph.getDegree(level)); + for (var k : e.getValue()) { + candidates.insertSorted(k, sf.similarityTo(k)); + } - // it's unlikely, but possible, that all the potential replacement edges were to nodes that have also - // been deleted. if that happens, keep the graph connected by adding random edges. - // (this is overly conservative -- really what we care about is that the end result of - // replaceDeletedNeighbors not be empty -- but we want to avoid having the node temporarily - // neighborless while concurrent searches run. empirically, this only results in a little extra work.)
+ if (candidates.size() == 0) { + var R = ThreadLocalRandom.current(); + // doing actual sampling-without-replacement is expensive so we'll loop a fixed number of times instead + for (int i = 0; i < 2 * graph.getDegree(level); i++) { + int randomNode = liveNodes.get(R.nextInt(liveNodes.size())); + if (randomNode != node && !candidates.contains(randomNode) && graph.layers.get(level).contains(randomNode)) { + float score = sf.similarityTo(randomNode); + candidates.insertSorted(randomNode, score); + } + if (candidates.size() == graph.getDegree(level)) { + break; + } } } - } - // remove edges to deleted nodes and add the new connections, maintaining diversity - graph.nodes.replaceDeletedNeighbors(node, toDelete, candidates); - }); - }).join(); + // remove edges to deleted nodes and add the new connections, maintaining diversity + graph.layers.get(level).replaceDeletedNeighbors(node, toDelete, candidates); + }); + }).join(); + } // Generally we want to keep entryPoint update and node removal distinct, because both can be expensive, // but if the entry point was deleted then we have no choice - if (toDelete.get(graph.entry())) { - updateEntryPoint(); + if (toDelete.get(graph.entry().node)) { + // pick a random node at the top layer + int newLevel = graph.getMaxLevel(); + int newEntry = -1; + outer: + while (newLevel >= 0) { + for (var it = graph.getNodes(newLevel); it.hasNext(); ){ + int i = it.nextInt(); + if (!toDelete.get(i)) { + newEntry = i; + break outer; + } + } + newLevel--; + } + + graph.updateEntryNode(newEntry >= 0 ? new NodeAtLevel(newLevel, newEntry) : null); } + long memorySize = 0; + // Remove the deleted nodes from the graph assert toDelete.cardinality() == nRemoved : "cardinality changed"; for (int i = toDelete.nextSetBit(0); i != NO_MORE_DOCS; i = toDelete.nextSetBit(i + 1)) { - graph.removeNode(i); - } - - return nRemoved * graph.ramBytesUsedOneNode(); - } - - /** - * Returns the ordinal of the node that is closest to the centroid of the graph, - * or NO_ENTRY_POINT if there are no live nodes in the graph. - */ - private int approximateMedioid() { - if (graph.size() == 0) { - return NO_ENTRY_POINT; - } - - var centroid = scoreProvider.approximateCentroid(); - // if the centroid is the zero vector, pick a random node - // (this is not a scenario likely to arise outside of small, contrived tests) - if (dotProduct(centroid, centroid) < 1E-6) { - return randomLiveNode(); - } - - int ep = graph.entry(); - var ssp = scoreProvider.searchProviderFor(centroid); - try (var gs = searchers.get()) { - // search for the centroid. if we can find a live node nearby, return it - var result = gs.searchInternal(ssp, beamWidth, beamWidth, 0.0f, 0.0f, ep, Bits.ALL); - if (result.getNodes().length != 0) { - return result.getNodes()[0].node; + int nDeletions = graph.removeNode(i); + for (var iLayer = 0; iLayer < nDeletions; iLayer++) { + memorySize += graph.ramBytesUsedOneNode(iLayer); } - - // No live nodes found in the search. Either no live nodes exist, or the graph is too - // poorly connected to find one. we'll do our best under the circumstances by picking - // a random live node, or NO_ENTRY_POINT if none exist. 
- return randomLiveNode(); - } catch (IOException e) { - throw new UncheckedIOException(e); } + return memorySize; } - private void updateNeighbors(int nodeId, NodeArray natural, NodeArray concurrent) { + private void updateNeighbors(int layer, int nodeId, NodeArray natural, NodeArray concurrent) { // if either natural or concurrent is empty, skip the merge NodeArray toMerge; if (concurrent.size() == 0) { @@ -691,11 +662,11 @@ private void updateNeighbors(int nodeId, NodeArray natural, NodeArray concurrent toMerge = NodeArray.merge(natural, concurrent); } // toMerge may be approximate-scored, but insertDiverse will compute exact scores for the diverse ones - var neighbors = graph.nodes.insertDiverse(nodeId, toMerge); - graph.nodes.backlink(neighbors, nodeId, neighborOverflow); + var neighbors = graph.layers.get(layer).insertDiverse(nodeId, toMerge); + graph.layers.get(layer).backlink(neighbors, nodeId, neighborOverflow); } - private static NodeArray toScratchCandidates(SearchResult.NodeScore[] candidates, NodeArray scratch) { + private static NodeArray toScratchCandidates(NodeScore[] candidates, NodeArray scratch) { scratch.clear(); for (var candidate : candidates) { scratch.addInOrder(candidate.node, candidate.score); @@ -703,17 +674,18 @@ private static NodeArray toScratchCandidates(SearchResult.NodeScore[] candidates return scratch; } - private NodeArray getConcurrentCandidates(int newNode, - Set inProgress, + private NodeArray getConcurrentCandidates(int layer, + int newNode, + Set inProgress, NodeArray scratch, ScoreFunction scoreFunction) { scratch.clear(); - for (var n : inProgress) { - if (n == newNode) { + for (NodeAtLevel n : inProgress) { + if (n.node == newNode || n.level < layer) { continue; } - scratch.insertSorted(n, scoreFunction.similarityTo(n)); + scratch.insertSorted(n.node, scoreFunction.similarityTo(n.node)); } return scratch; } @@ -727,81 +699,81 @@ public void close() throws IOException { } } - /** - * @return a random live node, or `NO_ENTRY_POINT` if no live nodes exist. 
- */ - @VisibleForTesting - int randomLiveNode() { - var R = ThreadLocalRandom.current(); - - // first, try doing it quickly by just picking a random node - for (int i = 0; i < 3; i++) { - var idUpperBound = graph.getIdUpperBound(); - if (idUpperBound == 0) { - return NO_ENTRY_POINT; - } - int n = R.nextInt(idUpperBound); - if (graph.containsNode(n) && !graph.getDeletedNodes().get(n)) { - return n; - } - } + private static class ExcludingBits implements Bits { + private final int excluded; - // lots of deletions and/or sparse node ids, so we do it the slow way - var L = new ArrayList(); - for (int i = 0; i < graph.getIdUpperBound(); i++) { - if (graph.containsNode(i) && !graph.getDeletedNodes().get(i)) { - L.add(i); - } + public ExcludingBits(int excluded) { + this.excluded = excluded; } - if (L.isEmpty()) { - return NO_ENTRY_POINT; + + @Override + public boolean get(int index) { + return index != excluded; } - return L.get(R.nextInt(L.size())); } - @VisibleForTesting - void validateAllNodesLive() { - assert graph.getDeletedNodes().cardinality() == 0; - // all edges should be valid, live nodes - for (int i = 0; i < graph.getIdUpperBound(); i++) { - if (!graph.containsNode(i)) { - continue; // holes are tolerated - } - var neighbors = graph.getNeighbors(i); - for (var it = neighbors.iterator(); it.hasNext(); ) { - var j = it.nextInt(); - assert graph.containsNode(j) : String.format("Edge %d -> %d is invalid", i, j); - } + public void load(RandomAccessReader in) throws IOException { + if (graph.size(0) != 0) { + throw new IllegalStateException("Cannot load into a non-empty graph"); } - } - /** - * @return the average short edges. Will be NaN if cleanup() has not been run, - * or if no edge lists in the graph needed to be trimmed at cleanup time. - */ - public double getAverageShortEdges() { - return averageShortEdges; + int maybeMagic = in.readInt(); + int version; // This is not used in V4 but may be useful in the future, putting it as a placeholder. 
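+        // Header dispatch: V4 files begin with the MAGIC constant followed by a version int,
+        // while V3 files have neither, so the first int read is actually the node count.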
+ if (maybeMagic != OnHeapGraphIndex.MAGIC) { + // JVector 3 format, no magic or version, starts straight off with the number of nodes + version = 3; + int size = maybeMagic; + loadV3(in, size); + } else { + version = in.readInt(); + loadV4(in); + } } - private static class ExcludingBits implements Bits { - private final int excluded; + private void loadV4(RandomAccessReader in) throws IOException { + if (graph.size(0) != 0) { + throw new IllegalStateException("Cannot load into a non-empty graph"); + } - public ExcludingBits(int excluded) { - this.excluded = excluded; + int layerCount = in.readInt(); + int entryNode = in.readInt(); + var layerDegrees = new ArrayList(layerCount); + + Map nodeLevelMap = new HashMap<>(); + + // Read layer info + for (int level = 0; level < layerCount; level++) { + int layerSize = in.readInt(); + layerDegrees.add(in.readInt()); + for (int i = 0; i < layerSize; i++) { + int nodeId = in.readInt(); + int nNeighbors = in.readInt(); + var sf = scoreProvider.searchProviderFor(nodeId).exactScoreFunction(); + var ca = new NodeArray(nNeighbors); + for (int j = 0; j < nNeighbors; j++) { + int neighbor = in.readInt(); + ca.addInOrder(neighbor, sf.similarityTo(neighbor)); + } + graph.addNode(level, nodeId, ca); + nodeLevelMap.put(nodeId, level); + } } - @Override - public boolean get(int index) { - return index != excluded; + for (var k : nodeLevelMap.keySet()) { + NodeAtLevel nal = new NodeAtLevel(nodeLevelMap.get(k), k); + graph.markComplete(nal); } + + graph.setDegrees(layerDegrees); + graph.updateEntryNode(new NodeAtLevel(graph.getMaxLevel(), entryNode)); } - public void load(RandomAccessReader in) throws IOException { + + private void loadV3(RandomAccessReader in, int size) throws IOException { if (graph.size() != 0) { throw new IllegalStateException("Cannot load into a non-empty graph"); } - int size = in.readInt(); int entryNode = in.readInt(); int maxDegree = in.readInt(); @@ -814,9 +786,11 @@ public void load(RandomAccessReader in) throws IOException { int neighbor = in.readInt(); ca.addInOrder(neighbor, sf.similarityTo(neighbor)); } - graph.addNode(nodeId, ca); + graph.addNode(0, nodeId, ca); + graph.markComplete(new NodeAtLevel(0, nodeId)); } - graph.updateEntryNode(entryNode); + graph.updateEntryNode(new NodeAtLevel(0, entryNode)); + graph.setDegrees(List.of(maxDegree)); } -} +} \ No newline at end of file diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java index a17af04bb..896ea16c9 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java @@ -25,6 +25,7 @@ package io.github.jbellis.jvector.graph; import io.github.jbellis.jvector.annotations.Experimental; +import io.github.jbellis.jvector.graph.GraphIndex.NodeAtLevel; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; import io.github.jbellis.jvector.util.Bits; @@ -45,14 +46,14 @@ * search algorithm, see {@link GraphIndex}. */ public class GraphSearcher implements Closeable { - private static final boolean PRUNE = Boolean.parseBoolean(System.getenv().getOrDefault("JVECTOR_PRUNE_SEARCH", "true")); + private boolean pruneSearch; - private final GraphIndex.View view; + private GraphIndex.View view; // Scratch data structures that are used in each {@link #searchInternal} call. 
These can be expensive // to allocate, so they're cleared and reused across calls. private final NodeQueue candidates; - private final NodeQueue approximateResults; + final NodeQueue approximateResults; private final NodeQueue rerankedResults; private final IntHashSet visited; private final NodesUnsorted evictedResults; @@ -76,6 +77,7 @@ private GraphSearcher(GraphIndex.View view) { this.approximateResults = new NodeQueue(new BoundedLongHeap(100), NodeQueue.Order.MIN_HEAP); this.rerankedResults = new NodeQueue(new BoundedLongHeap(100), NodeQueue.Order.MIN_HEAP); this.visited = new IntHashSet(); + this.pruneSearch = false; } private void initializeScoreProvider(SearchScoreProvider scoreProvider) { @@ -92,6 +94,15 @@ public GraphIndex.View getView() { return view; } + /** + * When using pruning, we are using a heuristic to terminate the search earlier. + * In certain cases, it can lead to speedups. This is set to false by default. + * @param usage a boolean that determines whether we do early termination or not. + */ + public void usePruning(boolean usage) { + pruneSearch = usage; + } + /** * Convenience function for simple one-off searches. It is caller's responsibility to make sure that it * is the unique owner of the vectors instance passed in here. @@ -105,6 +116,19 @@ public static SearchResult search(VectorFloat queryVector, int topK, RandomAc } } + /** + * Sets the view of the graph to be used by the searcher. + *
<p>
+ * This method should be used when the searcher operates over a view whose contents might not reflect all changes + * to the underlying graph, such as {@link OnHeapGraphIndex.ConcurrentGraphIndexView}. This is an optimization over + * creating a new graph searcher with every update to the view. + * + * @param view the new view + */ + public void setView(GraphIndex.View view) { + this.view = view; + } + /** * Call GraphSearcher constructor instead */ @@ -152,8 +176,33 @@ public SearchResult search(SearchScoreProvider scoreProvider, int rerankK, float threshold, float rerankFloor, - Bits acceptOrds) { - return searchInternal(scoreProvider, topK, rerankK, threshold, rerankFloor, view.entryNode(), acceptOrds); + Bits acceptOrds) + { + NodeAtLevel entry = view.entryNode(); + if (acceptOrds == null) { + throw new IllegalArgumentException("Use MatchAllBits to indicate that all ordinals are accepted, instead of null"); + } + if (rerankK < topK) { + throw new IllegalArgumentException(String.format("rerankK %d must be >= topK %d", rerankK, topK)); + } + + if (entry == null) { + return new SearchResult(new SearchResult.NodeScore[0], 0, 0, Float.POSITIVE_INFINITY); + } + + initializeInternal(scoreProvider, entry, acceptOrds); + + // Move downward from entry.level to 1 + int numVisited = 0; + for (int lvl = entry.level; lvl > 0; lvl--) { + // Search this layer with minimal parameters since we just want the best candidate + numVisited += searchOneLayer(scoreProvider, 1, 0.0f, lvl, Bits.ALL); + assert approximateResults.size() == 1 : approximateResults.size(); + setEntryPointsFromPreviousLayer(); + } + + // Now do the main search at layer 0 + return resume(numVisited, topK, rerankK, threshold, rerankFloor); } /** @@ -195,24 +244,16 @@ public SearchResult search(SearchScoreProvider scoreProvider, return search(scoreProvider, topK, 0.0f, acceptOrds); } - /** - * Set up the state for a new search and kick it off - */ - SearchResult searchInternal(SearchScoreProvider scoreProvider, - int topK, - int rerankK, - float threshold, - float rerankFloor, - int ep, - Bits rawAcceptOrds) - { - if (rawAcceptOrds == null) { - throw new IllegalArgumentException("Use MatchAllBits to indicate that all ordinals are accepted, instead of null"); - } - if (rerankK < topK) { - throw new IllegalArgumentException(String.format("rerankK %d must be >= topK %d", rerankK, topK)); - } + void setEntryPointsFromPreviousLayer() { + // push the candidates seen so far back onto the queue for the next layer + // at worst we save recomputing the similarity; at best we might connect to a more distant cluster + approximateResults.foreach(candidates::push); + evictedResults.foreach(candidates::push); + evictedResults.clear(); + approximateResults.clear(); + } + void initializeInternal(SearchScoreProvider scoreProvider, NodeAtLevel entry, Bits rawAcceptOrds) { // save search parameters for potential later resume initializeScoreProvider(scoreProvider); this.acceptOrds = Bits.intersectionOf(rawAcceptOrds, view.liveNodes()); @@ -222,22 +263,26 @@ SearchResult searchInternal(SearchScoreProvider scoreProvider, candidates.clear(); visited.clear(); - // no entry point -> empty results - if (ep < 0) { - return new SearchResult(new SearchResult.NodeScore[0], 0, 0, Float.POSITIVE_INFINITY); - } - - // kick off the actual search at the entry point - float score = scoreProvider.scoreFunction().similarityTo(ep); - visited.add(ep); - candidates.push(ep, score); - return resume(1, topK, rerankK, threshold, rerankFloor); + // Start with entry point + float 
score = scoreProvider.scoreFunction().similarityTo(entry.node); + visited.add(entry.node); + candidates.push(entry.node, score); } /** - * Resume the previous search where it left off and search for the best (new) `topK` neighbors. - *
<p>
- * SearchResult.visitedCount resets with each call to `search` or `resume`. + * Performs a single-layer ANN search, expanding from the given candidates queue. + * + * @param scoreProvider the current query's scoring/approximation logic + * @param rerankK how many results to over-query for approximate ranking + * @param threshold similarity threshold, or 0f if none + * @param level which layer to search + *
<p>
+ * Modifies the internal search state. + * When it's done, `approximateResults` contains the best `rerankK` results found at the given layer. + * @param acceptOrdsThisLayer a Bits instance indicating which nodes are acceptable results. + * If {@link Bits#ALL}, all nodes are acceptable. + * It is caller's responsibility to ensure that there are enough acceptable nodes + * that we don't search the entire graph trying to satisfy topK. */ // Since Astra / Cassandra's usage drives the design decisions here, it's worth being explicit // about how that works and why. @@ -258,28 +303,23 @@ SearchResult searchInternal(SearchScoreProvider scoreProvider, // incorrect and is discarded, and there is no reason to pass a rerankFloor parameter to resume(). // // Finally: resume() also drives the use of CachingReranker. - private SearchResult resume(int initialVisited, int topK, int rerankK, float threshold, float rerankFloor) { + int searchOneLayer(SearchScoreProvider scoreProvider, + int rerankK, + float threshold, + int level, + Bits acceptOrdsThisLayer) + { try { - assert approximateResults.size() == 0; // should be cleared out by extractScores - assert rerankedResults.size() == 0; // should be cleared out by extractScores + assert approximateResults.size() == 0; // should be cleared by setEntryPointsFromPreviousLayer approximateResults.setMaxSize(rerankK); - rerankedResults.setMaxSize(topK); - int numVisited = initialVisited; + int numVisited = 0; // track scores to predict when we are done with threshold queries var scoreTracker = threshold > 0 ? new ScoreTracker.TwoPhaseTracker(threshold) - : PRUNE ? new ScoreTracker.RelaxedMonotonicityTracker(rerankK) : new ScoreTracker.NoOpTracker(); + : pruneSearch ? new ScoreTracker.RelaxedMonotonicityTracker(rerankK) : new ScoreTracker.NoOpTracker(); VectorFloat similarities = null; - // add evicted results from the last call back to the candidates - var previouslyEvicted = evictedResults.size() > 0 ? 
new SparseBits() : Bits.NONE; - evictedResults.foreach((node, score) -> { - candidates.push(node, score); - ((SparseBits) previouslyEvicted).set(node); - }); - evictedResults.clear(); - // the main search loop while (candidates.size() > 0) { // we're done when we have K results and the best candidate is worse than the worst result so far @@ -294,15 +334,10 @@ private SearchResult resume(int initialVisited, int topK, int rerankK, float thr // process the top candidate int topCandidateNode = candidates.pop(); - if (acceptOrds.get(topCandidateNode) && topCandidateScore >= threshold) { + if (acceptOrdsThisLayer.get(topCandidateNode) && topCandidateScore >= threshold) { addTopCandidate(topCandidateNode, topCandidateScore, rerankK); } - // if this candidate came from evictedResults, we don't need to evaluate its neighbors again - if (previouslyEvicted.get(topCandidateNode)) { - continue; - } - // skip edge loading if we've found a local maximum and we have enough results if (scoreTracker.shouldStop() && candidates.size() >= rerankK - approximateResults.size()) { continue; @@ -314,8 +349,8 @@ private SearchResult resume(int initialVisited, int topK, int rerankK, float thr if (useEdgeLoading) { similarities = scoreFunction.edgeLoadingSimilarityTo(topCandidateNode); } - var it = view.getNeighborsIterator(topCandidateNode); - for (int i = 0; i < it.size(); i++) { + int i = 0; + for (var it = view.getNeighborsIterator(level, topCandidateNode); it.hasNext(); ) { var friendOrd = it.nextInt(); if (!visited.add(friendOrd)) { continue; @@ -327,52 +362,68 @@ private SearchResult resume(int initialVisited, int topK, int rerankK, float thr : scoreFunction.similarityTo(friendOrd); scoreTracker.track(friendSimilarity); candidates.push(friendOrd, friendSimilarity); + i++; } } - // rerank results - assert approximateResults.size() <= rerankK; - NodeQueue popFromQueue; - float worstApproximateInTopK; - int reranked; - if (cachingReranker == null) { - // save the worst candidates in evictedResults for potential resume() - while (approximateResults.size() > topK) { - var nScore = approximateResults.topScore(); - var n = approximateResults.pop(); - evictedResults.add(n, nScore); - } - - reranked = 0; - worstApproximateInTopK = Float.POSITIVE_INFINITY; - popFromQueue = approximateResults; - } else { - int oldReranked = cachingReranker.getRerankCalls(); - worstApproximateInTopK = approximateResults.rerank(topK, cachingReranker, rerankFloor, rerankedResults, evictedResults); - reranked = cachingReranker.getRerankCalls() - oldReranked; - approximateResults.clear(); - popFromQueue = rerankedResults; - } - // pop the top K results from the results queue, which has the worst candidates at the top - assert popFromQueue.size() <= topK; - var nodes = new SearchResult.NodeScore[popFromQueue.size()]; - for (int i = nodes.length - 1; i >= 0; i--) { - var nScore = popFromQueue.topScore(); - var n = popFromQueue.pop(); - nodes[i] = new SearchResult.NodeScore(n, nScore); - } - // that should be everything - assert popFromQueue.size() == 0; - - return new SearchResult(nodes, numVisited, reranked, worstApproximateInTopK); + return numVisited; } catch (Throwable t) { // clear scratch structures if terminated via throwable, as they may not have been drained approximateResults.clear(); - rerankedResults.clear(); throw t; } } + SearchResult resume(int numVisited, int topK, int rerankK, float threshold, float rerankFloor) { + // rR is persistent to save on allocations + rerankedResults.clear(); + rerankedResults.setMaxSize(topK); + + // 
add evicted results from the last call back to the candidates + evictedResults.foreach((node, score) -> { + candidates.push(node, score); + }); + evictedResults.clear(); + + numVisited += searchOneLayer(scoreProvider, rerankK, threshold, 0, acceptOrds); + + // rerank results + assert approximateResults.size() <= rerankK; + NodeQueue popFromQueue; + float worstApproximateInTopK; + int reranked; + if (cachingReranker == null) { + // save the worst candidates in evictedResults for potential resume() + while (approximateResults.size() > topK) { + var nScore = approximateResults.topScore(); + var n = approximateResults.pop(); + evictedResults.add(n, nScore); + } + + reranked = 0; + worstApproximateInTopK = Float.POSITIVE_INFINITY; + popFromQueue = approximateResults; + } else { + int oldReranked = cachingReranker.getRerankCalls(); + worstApproximateInTopK = approximateResults.rerank(topK, cachingReranker, rerankFloor, rerankedResults, evictedResults); + reranked = cachingReranker.getRerankCalls() - oldReranked; + approximateResults.clear(); + popFromQueue = rerankedResults; + } + // pop the top K results from the results queue, which has the worst candidates at the top + assert popFromQueue.size() <= topK; + var nodes = new SearchResult.NodeScore[popFromQueue.size()]; + for (int i = nodes.length - 1; i >= 0; i--) { + var nScore = popFromQueue.topScore(); + var n = popFromQueue.pop(); + nodes[i] = new SearchResult.NodeScore(n, nScore); + } + // that should be everything + assert popFromQueue.size() == 0; + + return new SearchResult(nodes, numVisited, reranked, worstApproximateInTopK); + } + @SuppressWarnings("StatementWithEmptyBody") private void addTopCandidate(int topCandidateNode, float topCandidateScore, int rerankK) { // add the new node to the results queue, and any evicted node to evictedResults in case we resume later diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java index c8e138142..fd0d4f792 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeArray.java @@ -238,7 +238,7 @@ public NodeArray copy() { public NodeArray copy(int newSize) { if (size > newSize) { - throw new IllegalArgumentException("Cannot copy to a smaller size"); + throw new IllegalArgumentException(String.format("Cannot copy %d nodes to a smaller size %d", size, newSize)); } NodeArray copy = new NodeArray(newSize); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java index 0bf3cd602..acf8d0528 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java @@ -30,6 +30,8 @@ import io.github.jbellis.jvector.util.NumericUtils; import org.agrona.collections.Int2ObjectHashMap; +import java.util.Arrays; + import static java.lang.Math.min; /** @@ -111,6 +113,7 @@ public boolean push(int newNode, float newScore) { * @return the encoded score, node ID */ private long encode(int node, float score) { + assert node >= 0 : node; return order.apply( (((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node)); } @@ -256,4 +259,19 @@ public void foreach(NodeConsumer consumer) { public interface NodeConsumer { void accept(int node, float score); } + + /** + * Copies the other NodeQueue to this 
one. If its order (MIN_HEAP or MAX_HEAP) is the same as this, + * it is copied verbatim. If it differs, every element is re-inserted into this. + * @param other the other node queue. + */ + public void copyFrom(NodeQueue other) { + if (this.order == other.order) { + this.heap.copyFrom(other.heap); + } else { + // can't avoid re-encoding since order influences it + clear(); + other.foreach(this::push); + } + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodesIterator.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodesIterator.java index eac8d9393..fa718ea02 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodesIterator.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodesIterator.java @@ -32,25 +32,19 @@ * number of nodes to be iterated over. The nodes are NOT guaranteed to be presented in any * particular order. */ -public abstract class NodesIterator implements PrimitiveIterator.OfInt { - protected final int size; - - /** - * Constructor for iterator based on the size - */ - public NodesIterator(int size) { - this.size = size; - } - +public interface NodesIterator extends PrimitiveIterator.OfInt { /** * The number of elements in this iterator * */ - public int size() { - return size; - } + int size(); + + static NodesIterator fromPrimitiveIterator(PrimitiveIterator.OfInt iterator, int size) { + return new NodesIterator() { + @Override + public int size() { + return size; + } - public static NodesIterator fromPrimitiveIterator(PrimitiveIterator.OfInt iterator, int size) { - return new NodesIterator(size) { @Override public int nextInt() { return iterator.nextInt(); @@ -63,18 +57,24 @@ public boolean hasNext() { }; } - public static class ArrayNodesIterator extends NodesIterator { + class ArrayNodesIterator implements NodesIterator { private final int[] nodes; private int cur = 0; + private final int size; /** Constructor for iterator based on integer array representing nodes */ public ArrayNodesIterator(int[] nodes, int size) { - super(size); assert nodes != null; assert size <= nodes.length; + this.size = size; this.nodes = nodes; } + @Override + public int size() { + return size; + } + public ArrayNodesIterator(int[] nodes) { this(nodes, nodes.length); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index 309609364..bd149483f 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -24,55 +24,91 @@ package io.github.jbellis.jvector.graph; +import io.github.jbellis.jvector.graph.ConcurrentNeighborMap.Neighbors; import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; +import io.github.jbellis.jvector.util.Accountable; import io.github.jbellis.jvector.util.Bits; +import io.github.jbellis.jvector.util.DenseIntMap; import io.github.jbellis.jvector.util.RamUsageEstimator; +import io.github.jbellis.jvector.util.SparseIntMap; import io.github.jbellis.jvector.util.ThreadSafeGrowableBitSet; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import org.agrona.collections.IntArrayList; import java.io.DataOutput; import java.io.IOException; import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import 
java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicIntegerArray; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.StampedLock; import java.util.stream.IntStream; /** * An {@link GraphIndex} that offers concurrent access; for typical graphs you will get significant * speedups in construction and searching as you add threads. * - *
<p>
To search this graph, you should use a View obtained from {@link #getView()} to perform `seek` - * and `nextNeighbor` operations. + *
<p>
The base layer (layer 0) contains all nodes, while higher layers are stored in sparse maps. + * For searching, use a view obtained from {@link #getView()} which supports level–aware operations. */ public class OnHeapGraphIndex implements GraphIndex { - static final int NO_ENTRY_POINT = -1; + // Used for saving and loading OnHeapGraphIndex + public static final int MAGIC = 0x75EC4012; // JVECTOR, with some imagination - // the current graph entry node, NO_ENTRY_POINT if not set - private final AtomicInteger entryPoint = new AtomicInteger(NO_ENTRY_POINT); + // The current entry node for searches + private final AtomicReference entryPoint; - final ConcurrentNeighborMap nodes; - private final ThreadSafeGrowableBitSet deletedNodes = new ThreadSafeGrowableBitSet(0); - private final AtomicInteger maxNodeId = new AtomicInteger(NO_ENTRY_POINT); - - // max neighbors/edges per node - final int maxDegree; + // Layers of the graph, with layer 0 as the bottom (dense) layer containing all nodes. + final List layers = new ArrayList<>(); - OnHeapGraphIndex(int M, int maxOverflowDegree, BuildScoreProvider scoreProvider, float alpha) { - this.maxDegree = M; - this.nodes = new ConcurrentNeighborMap(scoreProvider, maxDegree, maxOverflowDegree, alpha); + private final CompletionTracker completions; + private final ThreadSafeGrowableBitSet deletedNodes = new ThreadSafeGrowableBitSet(0); + private final AtomicInteger maxNodeId = new AtomicInteger(-1); + + // Maximum number of neighbors (edges) per node per layer + final IntArrayList maxDegrees; + // The ratio by which we can overflow the neighborhood of a node during construction. Since it is a multiplicative + // ratio, i.e., the maximum allowable degree if maxDegree * overflowRatio, it should be higher than 1. + private final double overflowRatio; + + public final ConcurrentMap> constructionBatch; + + OnHeapGraphIndex(List maxDegrees, double overflowRatio, BuildScoreProvider scoreProvider, float alpha, int batchSize) { + this.overflowRatio = overflowRatio; + this.maxDegrees = new IntArrayList(); + setDegrees(maxDegrees); + entryPoint = new AtomicReference<>(); + this.completions = new CompletionTracker(1024); + // Initialize the base layer (layer 0) with a dense map. + this.layers.add(new ConcurrentNeighborMap(new DenseIntMap<>(1024), + scoreProvider, + getDegree(0), + (int) (getDegree(0) * overflowRatio), + alpha)); + this.constructionBatch = new ConcurrentHashMap<>(batchSize); } /** - * Returns the neighbors connected to the given node, or null if the node does not exist. + * Returns the neighbors for the given node at the specified level, or null if the node does not exist. * - * @param node the node whose neighbors are returned, represented as an ordinal. + * @param level the layer + * @param node the node id + * @return the Neighbors structure or null */ - ConcurrentNeighborMap.Neighbors getNeighbors(int node) { - return nodes.get(node); + Neighbors getNeighbors(int level, int node) { + if (level >= layers.size()) { + return null; + } + return layers.get(level).get(node); } - @Override - public int size() { - return nodes.size(); + public int size(int level) { + return layers.get(level).size(); } /** @@ -85,20 +121,40 @@ public int size() { * responsibility of the caller. * *
<p>
It is also the responsibility of the caller to ensure that each node is only added once. - * - * @param nodeId the node to add, represented as an ordinal */ - public void addNode(int nodeId) { - nodes.addNode(nodeId); - maxNodeId.accumulateAndGet(nodeId, Math::max); + public void addNode(NodeAtLevel nodeLevel) { + ensureLayersExist(nodeLevel.level); + + // add the node to each layer + for (int i = 0; i <= nodeLevel.level; i++) { + layers.get(i).addNode(nodeLevel.node); + } + maxNodeId.accumulateAndGet(nodeLevel.node, Math::max); + } + + private void ensureLayersExist(int level) { + for (int i = layers.size(); i <= level; i++) { + synchronized (layers) { + if (i == layers.size()) { // doublecheck after locking + var denseMap = layers.get(0); + var map = new ConcurrentNeighborMap(new SparseIntMap<>(), + denseMap.scoreProvider, + getDegree(level), + (int) (getDegree(level) * overflowRatio), + denseMap.alpha); + layers.add(map); + } + } + } } /** * Only for use by Builder loading a saved graph */ - void addNode(int nodeId, NodeArray nodes) { + void addNode(int level, int nodeId, NodeArray nodes) { assert nodes != null; - this.nodes.addNode(nodeId, nodes); + ensureLayersExist(level); + this.layers.get(level).addNode(nodeId, nodes); maxNodeId.accumulateAndGet(nodeId, Math::max); } @@ -109,60 +165,74 @@ public void markDeleted(int node) { deletedNodes.set(node); } - /** must be called after addNode once neighbors are linked */ - void maybeSetInitialEntryNode(int node) { - entryPoint.accumulateAndGet(node, - (oldEntry, newEntry) -> { - if (oldEntry >= 0) { - return oldEntry; - } else { - return newEntry; - } - }); - } - - void updateEntryNode(int node) { - entryPoint.set(node); + /** must be called after addNode once neighbors are linked in all levels. */ + void markComplete(NodeAtLevel nodeLevel) { + entryPoint.accumulateAndGet( + nodeLevel, + (oldEntry, newEntry) -> { + if (oldEntry == null || newEntry.level > oldEntry.level) { + return newEntry; + } else { + return oldEntry; + } + }); + completions.markComplete(nodeLevel.node); } - @Override - public int maxDegree() { - return maxDegree; + void updateEntryNode(NodeAtLevel newEntry) { + entryPoint.set(newEntry); } - int entry() { + NodeAtLevel entry() { return entryPoint.get(); } @Override - public NodesIterator getNodes() { - return nodes.nodesIterator(); + public NodesIterator getNodes(int level) { + return NodesIterator.fromPrimitiveIterator(nodeStream(level).iterator(), + layers.get(level).size()); + } + + /** + * this does call get() internally to filter level 0, so if you're going to use it in a pipeline + * that also calls get(), consider using your own raw IntStream.range instead + */ + IntStream nodeStream(int level) { + var layer = layers.get(level); + return level == 0 + ? 
IntStream.range(0, getIdUpperBound()).filter(i -> layer.get(i) != null) + : ((SparseIntMap) layer.neighbors).keysStream(); } @Override public long ramBytesUsed() { + var graphBytesUsed = IntStream.range(0, layers.size()).mapToLong(this::ramBytesUsedOneLayer).sum(); + return graphBytesUsed + completions.ramBytesUsed(); + } + + public long ramBytesUsedOneLayer(int layer) { int OH_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; var REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF; var AH_BYTES = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; - long neighborSize = ramBytesUsedOneNode() * size(); + long neighborSize = ramBytesUsedOneNode(layer) * layers.get(layer).size(); return OH_BYTES + REF_BYTES * 2L + AH_BYTES + neighborSize; } - public long ramBytesUsedOneNode() { + public long ramBytesUsedOneNode(int layer) { // we include the REF_BYTES for the CNS reference here to make it self-contained for addGraphNode() int REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF; - return REF_BYTES + ConcurrentNeighborMap.Neighbors.ramBytesUsed(nodes.nodeArrayLength()); + return REF_BYTES + Neighbors.ramBytesUsed(layers.get(layer).nodeArrayLength()); } @Override public String toString() { - return String.format("OnHeapGraphIndex(size=%d, entryPoint=%d)", size(), entryPoint.get()); + return String.format("OnHeapGraphIndex(size=%d, entryPoint=%s)", size(0), entryPoint.get()); } @Override public void close() { - // no-op + // No resources to close. } /** @@ -176,13 +246,23 @@ public ConcurrentGraphIndexView getView() { return new ConcurrentGraphIndexView(); } + /** + * A View that assumes no concurrent modifications are made + */ + public GraphIndex.View getFrozenView() { + return new FrozenView(); + } + + /** + * Validates that the current entry node has been completely added. + */ void validateEntryNode() { - if (size() == 0) { + if (size(0) == 0) { return; } - var en = entryPoint.get(); - if (!(en >= 0 && getNeighbors(en) != null)) { - throw new IllegalStateException("Entry node was incompletely added! " + en); + NodeAtLevel entry = getView().entryNode(); + if (entry == null || getNeighbors(entry.level, entry.node) == null) { + throw new IllegalStateException("Entry node was incompletely added! " + entry); } } @@ -191,14 +271,20 @@ public ThreadSafeGrowableBitSet getDeletedNodes() { } /** - * @return true iff the node was present. + * Removes the given node from all layers. + * + * @param node the node id to remove + * @return the number of layers from which it was removed */ - boolean removeNode(int node) { - try { - return nodes.remove(node) != null; - } finally { - deletedNodes.clear(node); + int removeNode(int node) { + int found = 0; + for (var layer : layers) { + if (layer.remove(node) != null) { + found++; + } } + deletedNodes.clear(node); + return found; } @Override @@ -206,38 +292,130 @@ public int getIdUpperBound() { return maxNodeId.get() + 1; } + @Override public boolean containsNode(int nodeId) { - return nodes.contains(nodeId); + return layers.get(0).contains(nodeId); } - public double getAverageDegree() { - return IntStream.range(0, getIdUpperBound()) - .filter(this::containsNode) - .mapToDouble(i -> getNeighbors(i).size()) + /** + * Returns the average degree computed over nodes in the specified layer. + * + * @param level the level of interest. + * @return the average degree or NaN if no nodes are present. 
+     */
+    public double getAverageDegree(int level) {
+        return nodeStream(level)
+                .mapToDouble(i -> getNeighbors(level, i).size())
                 .average()
                 .orElse(Double.NaN);
     }
 
-    public class ConcurrentGraphIndexView implements GraphIndex.View {
-        public NodesIterator getNeighborsIterator(int node) {
-            var neighbors = getNeighbors(node);
-            assert neighbors != null : "Node " + node + " not found";
-            return neighbors.iterator();
+    @Override
+    public int getMaxLevel() {
+        for (int lvl = 0; lvl < layers.size(); lvl++) {
+            if (layers.get(lvl).size() == 0) {
+                return lvl - 1;
+            }
+        }
+        return layers.size() - 1;
+    }
+
+    @Override
+    public int getDegree(int level) {
+        if (level >= maxDegrees.size()) {
+            return maxDegrees.get(maxDegrees.size() - 1);
         }
+        return maxDegrees.get(level);
+    }
+
+    @Override
+    public int maxDegree() {
+        return maxDegrees.stream().mapToInt(i -> i).max().orElseThrow();
+    }
+
+    public int getLayerSize(int level) {
+        return layers.get(level).size();
+    }
+
+    public void setDegrees(List<Integer> layerDegrees) {
+        maxDegrees.clear();
+        maxDegrees.addAll(layerDegrees);
+    }
+
+    /**
+     * A concurrent View of the graph that is safe to search concurrently with updates and with other
+     * searches. The View provides a limited kind of snapshot isolation: only nodes completely added
+     * to the graph at the time the View was created will be visible (but the connections between them
+     * are allowed to change, so you could potentially get different top K results from the same query
+     * if concurrent updates are in progress.)
+     */
+    public class ConcurrentGraphIndexView extends FrozenView {
+        // It is tempting, but incorrect, to try to provide "adequate" isolation by
+        // (1) keeping a bitset of complete nodes and giving that to the searcher as nodes to
+        //     accept -- but we need to keep incomplete nodes out of the search path entirely,
+        //     not just out of the result set, or
+        // (2) keeping a bitset of complete nodes and restricting the View to those nodes
+        //     -- but we need to consider neighbor diversity separately for concurrent
+        //     inserts and completed nodes; this allows us to keep the former out of the latter,
+        //     but not the latter out of the former (i.e., a node that was in-progress when we
+        //     started may complete while we are working.)
+        // The only really foolproof solution is to implement snapshot isolation as
+        // we have done here.
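+
+        // Each View snapshots the logical clock at creation time; getNeighborsIterator below
+        // then skips any neighbor that was not yet complete when the snapshot was taken,
+        // keeping half-built nodes out of the search path entirely.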
+ private final int timestamp = completions.clock(); @Override - public int size() { - return OnHeapGraphIndex.this.size(); + public NodesIterator getNeighborsIterator(int level, int node) { + var it = getNeighbors(level, node).iterator(); + return new NodesIterator() { + int nextNode = advance(); + + private int advance() { + while (it.hasNext()) { + int n = it.nextInt(); + if (completions.completedAt(n) < timestamp) { + return n; + } + } + return Integer.MIN_VALUE; + } + + @Override + public int size() { + throw new UnsupportedOperationException(); + } + + @Override + public int nextInt() { + int current = nextNode; + if (current == Integer.MIN_VALUE) { + throw new IndexOutOfBoundsException(); + } + nextNode = advance(); + return current; + } + + @Override + public boolean hasNext() { + return nextNode != Integer.MIN_VALUE; + } + }; } + } + private class FrozenView implements View { @Override - public int entryNode() { - return entryPoint.get(); + public NodesIterator getNeighborsIterator(int level, int node) { + return getNeighbors(level, node).iterator(); } @Override - public String toString() { - return "OnHeapGraphIndexView(size=" + size() + ", entryPoint=" + entryPoint.get(); + public int size() { + return OnHeapGraphIndex.this.size(0); + } + + @Override + public NodeAtLevel entryNode() { + return entryPoint.get(); } @Override @@ -254,37 +432,124 @@ public int getIdUpperBound() { @Override public void close() { + // No resources to close + } + + @Override + public String toString() { + NodeAtLevel entry = entryNode(); + return String.format("%s(size=%d, entryNode=%s)", getClass().getSimpleName(), size(), entry); } } + /** + * Saves the graph to the given DataOutput for reloading into memory later + */ public void save(DataOutput out) { if (deletedNodes.cardinality() > 0) { - throw new IllegalStateException("Cannot save a graph that has deleted nodes. Call cleanup() first"); + throw new IllegalStateException("Cannot save a graph that has deleted nodes. Call cleanup() first"); } - // graph-level properties try (var view = getView()) { - out.writeInt(size()); - out.writeInt(view.entryNode()); - out.writeInt(maxDegree()); - - // neighbors - nodes.forEach((nodeId, value) -> { - try { - var neighbors = value.iterator(); - out.writeInt(nodeId); - - out.writeInt(neighbors.size()); - for (int n = 0; n < neighbors.size(); n++) { - out.writeInt(neighbors.nextInt()); + out.writeInt(OnHeapGraphIndex.MAGIC); // the magic number + out.writeInt(4); // The version + + // Write graph-level properties. + out.writeInt(layers.size()); + assert view.entryNode().level == getMaxLevel(); + out.writeInt(view.entryNode().node); + + for (int level = 0; level < layers.size(); level++) { + out.writeInt(size(level)); + out.writeInt(getDegree(level)); + + // Save neighbors from the layer. + var baseLayer = layers.get(level); + baseLayer.forEach((nodeId, neighbors) -> { + try { + NodesIterator iterator = neighbors.iterator(); + out.writeInt(nodeId); + out.writeInt(iterator.size()); + for (int n = 0; n < iterator.size(); n++) { + out.writeInt(iterator.nextInt()); + } + assert !iterator.hasNext(); + } catch (IOException e) { + throw new UncheckedIOException(e); } - assert !neighbors.hasNext(); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - }); + }); + } } catch (IOException e) { throw new UncheckedIOException(e); } } + + /** + * A helper class that tracks completion times for nodes. 
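+     * Completion times come from a monotonically increasing logical clock: each node's slot
+     * holds Integer.MAX_VALUE until markComplete() records the clock value for that node.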
+ */ + static final class CompletionTracker implements Accountable { + private final AtomicInteger logicalClock = new AtomicInteger(); + private volatile AtomicIntegerArray completionTimes; + private final StampedLock sl = new StampedLock(); + + public CompletionTracker(int initialSize) { + completionTimes = new AtomicIntegerArray(initialSize); + for (int i = 0; i < initialSize; i++) { + completionTimes.set(i, Integer.MAX_VALUE); + } + } + + void markComplete(int node) { + int completionClock = logicalClock.getAndIncrement(); + ensureCapacity(node); + long stamp; + do { + stamp = sl.tryOptimisticRead(); + completionTimes.set(node, completionClock); + } while (!sl.validate(stamp)); + } + + int clock() { + return logicalClock.get(); + } + + public int completedAt(int node) { + AtomicIntegerArray ct = completionTimes; + if (node >= ct.length()) { + return Integer.MAX_VALUE; + } + return ct.get(node); + } + + @Override + public long ramBytesUsed() { + int REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF; + return REF_BYTES + Integer.BYTES + REF_BYTES + + (long) Integer.BYTES * completionTimes.length(); + } + + private void ensureCapacity(int node) { + if (node < completionTimes.length()) { + return; + } + long stamp = sl.writeLock(); + try { + AtomicIntegerArray oldArray = completionTimes; + if (node >= oldArray.length()) { + int newSize = (node + 1) * 2; + AtomicIntegerArray newArr = new AtomicIntegerArray(newSize); + for (int i = 0; i < newSize; i++) { + if (i < oldArray.length()) { + newArr.set(i, oldArray.get(i)); + } else { + newArr.set(i, Integer.MAX_VALUE); + } + } + completionTimes = newArr; + } + } finally { + sl.unlockWrite(stamp); + } + } + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/SearchResult.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/SearchResult.java index 9505463fd..c31af00eb 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/SearchResult.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/SearchResult.java @@ -16,6 +16,9 @@ package io.github.jbellis.jvector.graph; +import java.util.Arrays; +import java.util.Objects; + /** * Container class for results of an ANN search, along with associated metrics about the behavior of the search. */ @@ -62,7 +65,7 @@ public float getWorstApproximateScoreInTopK() { return worstApproximateScoreInTopK; } - public static final class NodeScore { + public static final class NodeScore implements Comparable { public final int node; public final float score; @@ -75,5 +78,37 @@ public NodeScore(int node, float score) { public String toString() { return String.format("NodeScore(%d, %s)", node, score); } + + @Override + public int compareTo(NodeScore o) { + // Sort by score in descending order (highest score first) + int scoreCompare = Float.compare(o.score, this.score); + // If scores are equal, break ties using node id (ascending order) + return scoreCompare != 0 ? 
scoreCompare : Integer.compare(node, o.node); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + NodeScore nodeScore = (NodeScore) o; + return node == nodeScore.node && Float.compare(score, nodeScore.score) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(node, score); + } + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + SearchResult that = (SearchResult) o; + return visitedCount == that.visitedCount && rerankedCount == that.rerankedCount && Float.compare(worstApproximateScoreInTopK, that.worstApproximateScoreInTopK) == 0 && Objects.deepEquals(nodes, that.nodes); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(nodes), visitedCount, rerankedCount, worstApproximateScoreInTopK); } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CachingGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CachingGraphIndex.java deleted file mode 100644 index 59d551213..000000000 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CachingGraphIndex.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright DataStax, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.github.jbellis.jvector.graph.disk; - -import io.github.jbellis.jvector.graph.GraphIndex; -import io.github.jbellis.jvector.graph.NodesIterator; -import io.github.jbellis.jvector.graph.similarity.ScoreFunction; -import io.github.jbellis.jvector.util.Accountable; -import io.github.jbellis.jvector.util.Bits; -import io.github.jbellis.jvector.vector.VectorSimilarityFunction; -import io.github.jbellis.jvector.vector.types.VectorFloat; - -import java.io.IOException; - -public class CachingGraphIndex implements GraphIndex, Accountable -{ - private static final int CACHE_DISTANCE = 3; - - private final GraphCache cache_; - private final OnDiskGraphIndex graph; - - public CachingGraphIndex(OnDiskGraphIndex graph) - { - this(graph, CACHE_DISTANCE); - } - - public CachingGraphIndex(OnDiskGraphIndex graph, int cacheDistance) - { - this.graph = graph; - this.cache_ = GraphCache.load(graph, cacheDistance); - } - - @Override - public int size() { - return graph.size(); - } - - @Override - public NodesIterator getNodes() { - return graph.getNodes(); - } - - @Override - public ScoringView getView() { - return new View(cache_, graph.getView()); - } - - @Override - public int maxDegree() { - return graph.maxDegree(); - } - - @Override - public long ramBytesUsed() { - return graph.ramBytesUsed() + cache_.ramBytesUsed(); - } - - @Override - public void close() throws IOException { - graph.close(); - } - - @Override - public String toString() { - return String.format("CachingGraphIndex(graph=%s)", graph); - } - - public static class View implements ScoringView { - private final GraphCache cache; - protected final OnDiskGraphIndex.View view; - - public View(GraphCache cache, OnDiskGraphIndex.View view) { - this.cache = cache; - this.view = view; - } - - @Override - public NodesIterator getNeighborsIterator(int ordinal) { - var node = cache.getNode(ordinal); - if (node != null) { - return new NodesIterator.ArrayNodesIterator(node.neighbors, node.neighbors.length); - } - return view.getNeighborsIterator(ordinal); - } - - @Override - public int size() { - return view.size(); - } - - @Override - public int entryNode() { - return view.entryNode(); - } - - @Override - public Bits liveNodes() { - return view.liveNodes(); - } - - @Override - public void close() throws IOException { - view.close(); - } - - @Override - public ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat queryVector, VectorSimilarityFunction vsf) { - return view.rerankerFor(queryVector, vsf); - } - - @Override - public ScoreFunction.ApproximateScoreFunction approximateScoreFunctionFor(VectorFloat queryVector, VectorSimilarityFunction vsf) { - return view.approximateScoreFunctionFor(queryVector, vsf); - } - } -} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CommonHeader.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CommonHeader.java index be57781f2..8f7226a66 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CommonHeader.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CommonHeader.java @@ -16,67 +16,117 @@ package io.github.jbellis.jvector.graph.disk; +import io.github.jbellis.jvector.annotations.VisibleForTesting; import io.github.jbellis.jvector.disk.RandomAccessReader; +import io.github.jbellis.jvector.disk.RandomAccessWriter; +import io.github.jbellis.jvector.graph.GraphIndex; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.DataOutput; import java.io.IOException; +import 
java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.IntStream; /** * Base header for OnDiskGraphIndex functionality. */ -class CommonHeader { +public class CommonHeader { + private static final Logger logger = LoggerFactory.getLogger(CommonHeader.class); + + private static final int V4_MAX_LAYERS = 32; + public final int version; - public final int size; public final int dimension; public final int entryNode; - public final int maxDegree; + public final List layerInfo; public final int idUpperBound; - CommonHeader(int version, int size, int dimension, int entryNode, int maxDegree, int idUpperBound) { + CommonHeader(int version, int dimension, int entryNode, List layerInfo, int idUpperBound) { this.version = version; - this.size = size; this.dimension = dimension; this.entryNode = entryNode; - this.maxDegree = maxDegree; + this.layerInfo = layerInfo; this.idUpperBound = idUpperBound; } - void write(DataOutput out) throws IOException { + void write(RandomAccessWriter out) throws IOException { + logger.debug("Writing common header at position {}", out.position()); if (version >= 3) { out.writeInt(OnDiskGraphIndex.MAGIC); out.writeInt(version); } - out.writeInt(size); + out.writeInt(layerInfo.get(0).size); out.writeInt(dimension); out.writeInt(entryNode); - out.writeInt(maxDegree); - + out.writeInt(layerInfo.get(0).degree); if (version >= 4) { out.writeInt(idUpperBound); + + if (layerInfo.size() > V4_MAX_LAYERS) { + var msg = String.format("Number of layers %d exceeds maximum of %d", layerInfo.size(), V4_MAX_LAYERS); + throw new IllegalArgumentException(msg); + } + logger.debug("Writing {} layers", layerInfo.size()); + out.writeInt(layerInfo.size()); + // Write actual layer info + for (LayerInfo info : layerInfo) { + out.writeInt(info.size); + out.writeInt(info.degree); + } + // Pad remaining entries with zeros + for (int i = layerInfo.size(); i < V4_MAX_LAYERS; i++) { + out.writeInt(0); // size + out.writeInt(0); // degree + } + } else { + if (layerInfo.size() > 1) { + throw new IllegalArgumentException("Layer info is not supported in version " + version); + } } + logger.debug("Common header finished writing at position {}", out.position()); } - static CommonHeader load(RandomAccessReader reader) throws IOException { - int maybeMagic = reader.readInt(); + static CommonHeader load(RandomAccessReader in) throws IOException { + logger.debug("Loading common header at position {}", in.getPosition()); + int maybeMagic = in.readInt(); int version; int size; if (maybeMagic == OnDiskGraphIndex.MAGIC) { - version = reader.readInt(); - size = reader.readInt(); + version = in.readInt(); + size = in.readInt(); } else { version = 2; size = maybeMagic; } - int dimension = reader.readInt(); - int entryNode = reader.readInt(); - int maxDegree = reader.readInt(); - + int dimension = in.readInt(); + int entryNode = in.readInt(); + int maxDegree = in.readInt(); int idUpperBound = size; - if (version >= 4) { - idUpperBound = reader.readInt(); + List layerInfo; + if (version < 4) { + layerInfo = List.of(new LayerInfo(size, maxDegree)); + } else { + idUpperBound = in.readInt(); + int numLayers = in.readInt(); + logger.debug("{} layers", numLayers); + layerInfo = new ArrayList<>(); + for (int i = 0; i < numLayers; i++) { + LayerInfo info = new LayerInfo(in.readInt(), in.readInt()); + layerInfo.add(info); + } + // Skip over remaining padding entries + for (int i = numLayers; i < V4_MAX_LAYERS; i++) { + in.readInt(); + in.readInt(); + } 
} + logger.debug("Common header finished reading at position {}", in.getPosition()); - return new CommonHeader(version, size, dimension, entryNode, maxDegree, idUpperBound); + return new CommonHeader(version, dimension, entryNode, layerInfo, idUpperBound); } int size() { @@ -85,8 +135,46 @@ int size() { size += 2; } if (version >= 4) { - size += 1; + size += 2 + 2 * V4_MAX_LAYERS; } return size * Integer.BYTES; } -} + + @VisibleForTesting + public static class LayerInfo { + public final int size; + public final int degree; + + public LayerInfo(int size, int degree) { + this.size = size; + this.degree = degree; + } + + public static List fromGraph(GraphIndex graph, OrdinalMapper mapper) { + return IntStream.rangeClosed(0, graph.getMaxLevel()) + .mapToObj(i -> new LayerInfo(graph.size(i), graph.getDegree(i))) + .collect(Collectors.toList()); + } + + @Override + public String toString() { + return "LayerInfo{" + + "size=" + size + + ", degree=" + degree + + '}'; + } + + @Override + public int hashCode() { + return Objects.hash(size, degree); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + LayerInfo other = (LayerInfo) obj; + return size == other.size && degree == other.degree; + } + } +} \ No newline at end of file diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/GraphCache.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/GraphCache.java deleted file mode 100644 index bb4ec89c3..000000000 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/GraphCache.java +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright DataStax, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.github.jbellis.jvector.graph.disk; - -import io.github.jbellis.jvector.util.Accountable; -import io.github.jbellis.jvector.util.RamUsageEstimator; -import org.agrona.collections.Int2ObjectHashMap; - -public abstract class GraphCache implements Accountable -{ - public static class CachedNode implements Accountable { - public final int[] neighbors; - - public CachedNode( int[] neighbors) { - this.neighbors = neighbors; - } - - public long ramBytesUsed() { - return RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + (long) neighbors.length * Integer.BYTES; - } - } - - /** return the cached node if present, or null if not */ - public abstract CachedNode getNode(int ordinal); - - public static GraphCache load(OnDiskGraphIndex graph, int distance) - { - if (distance < 0) - return new EmptyGraphCache(); - return new HMGraphCache(graph, distance); - } - - public abstract long ramBytesUsed(); - - private static final class EmptyGraphCache extends GraphCache - { - @Override - public CachedNode getNode(int ordinal) { - return null; - } - - @Override - public long ramBytesUsed() - { - return 0; - } - } - - private static final class HMGraphCache extends GraphCache - { - // Map is created on construction and never modified - private final Int2ObjectHashMap cache; - private long ramBytesUsed = 0; - - public HMGraphCache(OnDiskGraphIndex graph, int distance) { - try (var view = graph.getView()) { - var tmpCache = new Int2ObjectHashMap(); - cacheNeighborsOf(tmpCache, view, view.entryNode(), distance); - // Assigning to a final value ensure it is safely published - cache = tmpCache; - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - private void cacheNeighborsOf(Int2ObjectHashMap tmpCache, OnDiskGraphIndex.View view, int ordinal, int distance) { - // cache this node - var it = view.getNeighborsIterator(ordinal); - int[] neighbors = new int[it.size()]; - int i = 0; - while (it.hasNext()) { - neighbors[i++] = it.nextInt(); - } - var node = new CachedNode(neighbors); - tmpCache.put(ordinal, node); - // ignores internal Map overhead but that should be negligible compared to the node contents - ramBytesUsed += Integer.BYTES + RamUsageEstimator.NUM_BYTES_OBJECT_REF + node.ramBytesUsed(); - - // call recursively on neighbors - if (distance > 0) { - for (var neighbor : neighbors) { - if (!tmpCache.containsKey(neighbor)) { - cacheNeighborsOf(tmpCache, view, neighbor, distance - 1); - } - } - } - } - - - @Override - public CachedNode getNode(int ordinal) { - return cache.get(ordinal); - } - - @Override - public long ramBytesUsed() - { - return ramBytesUsed; - } - } -} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/Header.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/Header.java index 5e1143fbd..bbc6ffb5f 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/Header.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/Header.java @@ -17,6 +17,9 @@ package io.github.jbellis.jvector.graph.disk; import io.github.jbellis.jvector.disk.RandomAccessReader; +import io.github.jbellis.jvector.disk.RandomAccessWriter; +import io.github.jbellis.jvector.graph.disk.feature.Feature; +import io.github.jbellis.jvector.graph.disk.feature.FeatureId; import java.io.DataOutput; import java.io.IOException; @@ -35,7 +38,7 @@ class Header { this.features = features; } - void write(DataOutput out) throws IOException { + void write(RandomAccessWriter out) throws IOException { 
common.write(out); if (common.version >= 3) { @@ -78,4 +81,4 @@ static Header load(RandomAccessReader reader, long offset) throws IOException { return new Header(common, features); } -} +} \ No newline at end of file diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java index b53df50eb..391f5f7cc 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java @@ -22,8 +22,17 @@ import io.github.jbellis.jvector.graph.GraphIndex; import io.github.jbellis.jvector.graph.NodesIterator; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; +import io.github.jbellis.jvector.graph.disk.feature.Feature; +import io.github.jbellis.jvector.graph.disk.feature.FeatureId; +import io.github.jbellis.jvector.graph.disk.feature.FeatureSource; +import io.github.jbellis.jvector.graph.disk.feature.FusedADC; +import io.github.jbellis.jvector.graph.disk.feature.InlineVectors; +import io.github.jbellis.jvector.graph.disk.feature.NVQ; +import io.github.jbellis.jvector.graph.disk.feature.SeparatedFeature; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.util.Accountable; +import org.agrona.collections.Int2ObjectHashMap; +import java.util.ArrayList; import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.util.RamUsageEstimator; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; @@ -35,9 +44,14 @@ import java.io.UncheckedIOException; import java.nio.file.Path; import java.util.EnumMap; +import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A class representing a graph index stored on disk. The base graph contains only graph structure. @@ -48,39 +62,78 @@ */ public class OnDiskGraphIndex implements GraphIndex, AutoCloseable, Accountable { + private static final Logger logger = LoggerFactory.getLogger(OnDiskGraphIndex.class); public static final int CURRENT_VERSION = 4; static final int MAGIC = 0xFFFF0D61; // FFFF to distinguish from old graphs, which should never start with a negative size "ODGI" static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); final ReaderSupplier readerSupplier; final int version; - final int size; - final int maxDegree; final int dimension; - final int entryNode; + final NodeAtLevel entryNode; final int idUpperBound; final int inlineBlockSize; // total size of all inline elements contributed by features final EnumMap features; final EnumMap inlineOffsets; + private final List layerInfo; + // offset of L0 adjacency data private final long neighborsOffset; + /** For layers > 0, store adjacency fully in memory. 
 */
+    private final AtomicReference<List<Int2ObjectHashMap<int[]>>> inMemoryNeighbors;
 
     OnDiskGraphIndex(ReaderSupplier readerSupplier, Header header, long neighborsOffset) {
         this.readerSupplier = readerSupplier;
         this.version = header.common.version;
-        this.size = header.common.size;
+        this.layerInfo = header.common.layerInfo;
         this.dimension = header.common.dimension;
-        this.entryNode = header.common.entryNode;
-        this.maxDegree = header.common.maxDegree;
+        this.entryNode = new NodeAtLevel(header.common.layerInfo.size() - 1, header.common.entryNode);
         this.idUpperBound = header.common.idUpperBound;
         this.features = header.features;
         this.neighborsOffset = neighborsOffset;
         var inlineBlockSize = 0;
         inlineOffsets = new EnumMap<>(FeatureId.class);
         for (var entry : features.entrySet()) {
-            inlineOffsets.put(entry.getKey(), inlineBlockSize);
-            inlineBlockSize += entry.getValue().inlineSize();
+            var feature = entry.getValue();
+            if (!(feature instanceof SeparatedFeature)) {
+                inlineOffsets.put(entry.getKey(), inlineBlockSize);
+                inlineBlockSize += feature.featureSize();
+            }
         }
         this.inlineBlockSize = inlineBlockSize;
+        inMemoryNeighbors = new AtomicReference<>(null);
+    }
+
+    private List<Int2ObjectHashMap<int[]>> loadInMemoryLayers(RandomAccessReader in) throws IOException {
+        var imn = new ArrayList<Int2ObjectHashMap<int[]>>(layerInfo.size());
+        // For levels > 0, we load adjacency into memory
+        imn.add(null); // L0 placeholder so we don't have to mangle indexing
+        long L0size = idUpperBound * (inlineBlockSize + Integer.BYTES * (1L + 1L + layerInfo.get(0).degree));
+        in.seek(neighborsOffset + L0size);
+
+        for (int lvl = 1; lvl < layerInfo.size(); lvl++) {
+            CommonHeader.LayerInfo info = layerInfo.get(lvl);
+            Int2ObjectHashMap<int[]> edges = new Int2ObjectHashMap<>();
+
+            for (int i = 0; i < info.size; i++) {
+                int nodeId = in.readInt();
+                assert nodeId >= 0 && nodeId < layerInfo.get(0).size :
+                        String.format("Node ID %d out of bounds for layer %d", nodeId, lvl);
+                int neighborCount = in.readInt();
+                assert neighborCount >= 0 && neighborCount <= info.degree
+                        : String.format("Node %d neighborCount %d > M %d", nodeId, neighborCount, info.degree);
+                int[] neighbors = new int[neighborCount];
+                in.read(neighbors, 0, neighborCount);
+
+                // skip any padding up to 'degree' neighbors
+                int skip = info.degree - neighborCount;
+                if (skip > 0) in.seek(in.getPosition() + ((long) skip * Integer.BYTES));
+
+                edges.put(nodeId, neighbors);
+            }
+            imn.add(edges);
+        }
+        return imn;
     }
 
     /**
@@ -88,7 +141,12 @@
      */
     public static OnDiskGraphIndex load(ReaderSupplier readerSupplier, long offset) {
         try (var reader = readerSupplier.get()) {
+            logger.debug("Loading OnDiskGraphIndex from offset={}", offset);
             var header = Header.load(reader, offset);
+            logger.debug("Header loaded: version={}, dimension={}, entryNode={}, layerInfoCount={}",
+                    header.common.version, header.common.dimension, header.common.entryNode, header.common.layerInfo.size());
+            logger.debug("Position after reading header={}",
+                    reader.getPosition());
             return new OnDiskGraphIndex(readerSupplier, header, reader.getPosition());
         } catch (Exception e) {
             throw new RuntimeException("Error initializing OnDiskGraph at offset " + offset, e);
@@ -111,13 +169,13 @@ public int getDimension() {
     }
 
     @Override
-    public int size() {
-        return size;
+    public int size(int level) {
+        return layerInfo.get(level).size;
     }
 
     @Override
-    public int maxDegree() {
-        return maxDegree;
+    public int getDegree(int level) {
+        return layerInfo.get(level).degree;
     }
 
     @Override
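The seek arithmetic in loadInMemoryLayers above assumes fixed-width node records. A minimal sketch of those record sizes, using hypothetical helper names that are not part of the patch:

```java
// Bytes per node record implied by the seeks in loadInMemoryLayers:
// layer 0 stores (id, inline features, neighbor count, degree slots);
// upper layers store only (id, neighbor count, degree slots), padded to `degree`.
static long level0NodeBytes(int inlineBlockSize, int degree) {
    return inlineBlockSize + Integer.BYTES * (1L + 1L + degree);
}

static long upperLayerNodeBytes(int degree) {
    return Integer.BYTES * (1L + 1L + degree);
}
```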
@@ -126,22 +184,35 @@ public int getIdUpperBound() {
     }
 
     @Override
-    public NodesIterator getNodes() {
+    public NodesIterator getNodes(int level) {
+        int size = size(level);
+        int maxDegree = getDegree(level);
+
+        long layer0NodeSize = (long) Integer.BYTES // ids
+                + inlineBlockSize // inline elements
+                + (Integer.BYTES * (long) (maxDegree + 1));
+        long layerUpperNodeSize = (long) Integer.BYTES // ids
+                + (Integer.BYTES * (long) (maxDegree + 1)); // neighbor count + neighbors
+        long thisLayerNodeSize = level == 0 ? layer0NodeSize : layerUpperNodeSize;
+
+        long layerOffset = neighborsOffset;
+        layerOffset += level > 0 ? layer0NodeSize * size(0) : 0;
+        for (int lvl = 1; lvl < level; lvl++) {
+            layerOffset += layerUpperNodeSize * size(lvl);
+        }
+
         try (var reader = readerSupplier.get()) {
-            int[] valid_nodes = new int[size];
+            int[] validNodes = new int[size(level)];
+            int upperBound = level == 0 ? getIdUpperBound() : size(level);
             int pos = 0;
-            for (int node = 0; node < getIdUpperBound(); node++) {
-                long node_offset = neighborsOffset +
-                        (node * ((long) Integer.BYTES // ids
-                                + inlineBlockSize // inline elements
-                                + (Integer.BYTES * (long) (maxDegree + 1)) // neighbor count + neighbors)
-                        ));
+            for (int node = 0; node < upperBound; node++) {
+                long node_offset = layerOffset + (node * thisLayerNodeSize);
                 reader.seek(node_offset);
                 if (reader.readInt() != -1) {
-                    valid_nodes[pos++] = node;
+                    validNodes[pos++] = node;
                 }
             }
-            return new NodesIterator.ArrayNodesIterator(valid_nodes, size);
+            return new NodesIterator.ArrayNodesIterator(validNodes, size);
         } catch (IOException e) {
             throw new UncheckedIOException(e);
         }
@@ -159,10 +230,20 @@ public void close() throws IOException {
     }
 
     @Override
     public String toString() {
-        return String.format("OnDiskGraphIndex(size=%d, entryPoint=%d, features=%s)", size, entryNode,
+        return String.format("OnDiskGraphIndex(layers=%s, entryPoint=%s, features=%s)", layerInfo, entryNode,
                              features.keySet().stream().map(Enum::name).collect(Collectors.joining(",")));
     }
 
+    @Override
+    public int getMaxLevel() {
+        return entryNode.level;
+    }
+
+    @Override
+    public int maxDegree() {
+        return layerInfo.stream().mapToInt(li -> li.degree).max().orElseThrow();
+    }
+
     // re-declared to specify type
     @Override
     public View getView() {
@@ -179,7 +260,7 @@ public class View implements FeatureSource, ScoringView, RandomAccessVectorValue
         public View(RandomAccessReader reader) {
             this.reader = reader;
-            this.neighbors = new int[maxDegree];
+            this.neighbors = new int[layerInfo.stream().mapToInt(li -> li.degree).max().orElse(0)];
         }
 
         @Override
@@ -198,36 +279,55 @@ public RandomAccessVectorValues copy() {
             throw new UnsupportedOperationException(); // need to copy reader
         }
 
-        protected long inlineOffsetFor(int node, FeatureId featureId) {
+        protected long offsetFor(int node, FeatureId featureId) {
+            Feature feature = features.get(featureId);
+
+            // Separated features are just global offset + node offset
+            if (feature instanceof SeparatedFeature) {
+                SeparatedFeature sf = (SeparatedFeature) feature;
+                return sf.getOffset() + (node * (long) feature.featureSize());
+            }
+
+            // Inline features are in layer 0 only
             return neighborsOffset +
                     (node * ((long) Integer.BYTES // ids
                             + inlineBlockSize // inline elements
                             + (Integer.BYTES * (long) (layerInfo.get(0).degree + 1)) // neighbor count + neighbors
                     )) +
                     Integer.BYTES + // id
                     inlineOffsets.get(featureId);
         }
 
-        long neighborsOffsetFor(int node) {
-            return neighborsOffset +
-                    (node + 1) * (Integer.BYTES + (long) inlineBlockSize) +
-                    (node * (long) Integer.BYTES * (maxDegree + 1));
+        private long neighborsOffsetFor(int level, int node) {
+            assert level == 0; // higher layers are in memory
+            int degree = layerInfo.get(level).degree;
+
+            // skip node ID + inline features
+            long skipInline = Integer.BYTES + inlineBlockSize;
+            long blockBytes = skipInline + (long) Integer.BYTES * (degree + 1);
+
+            long offsetWithinLayer = blockBytes * node;
+            return neighborsOffset + offsetWithinLayer + skipInline;
         }
 
         @Override
-        public RandomAccessReader inlineReaderForNode(int node, FeatureId featureId) throws IOException {
-            long offset = inlineOffsetFor(node, featureId);
+        public RandomAccessReader featureReaderForNode(int node, FeatureId featureId) throws IOException {
+            long offset = offsetFor(node, featureId);
             reader.seek(offset);
             return reader;
         }
 
         @Override
         public VectorFloat<?> getVector(int node) {
-            if (!features.containsKey(FeatureId.INLINE_VECTORS)) {
-                throw new UnsupportedOperationException("No inline vectors in this graph");
+            var feature = features.get(FeatureId.INLINE_VECTORS);
+            if (feature == null) {
+                feature = features.get(FeatureId.SEPARATED_VECTORS);
+            }
+            if (feature == null) {
+                throw new UnsupportedOperationException("No full-resolution vectors in this graph");
             }
 
             try {
-                long offset = inlineOffsetFor(node, FeatureId.INLINE_VECTORS);
+                long offset = offsetFor(node, feature.id());
                 reader.seek(offset);
                 return vectorTypeSupport.readFloatVector(reader, dimension);
             } catch (IOException e) {
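The INLINE_VECTORS-then-SEPARATED_VECTORS fallback above is repeated in getVectorInto below; the shared lookup amounts to the following sketch (a hypothetical helper on the View, not part of the patch, which inlines this logic at both call sites):

```java
// Resolve whichever full-resolution vector feature the graph was written with.
// INLINE_VECTORS lives next to the layer-0 adjacency lists; SEPARATED_VECTORS
// lives in its own region after the graph (see SeparatedFeature later in this diff).
private Feature fullResolutionFeature() {
    var feature = features.get(FeatureId.INLINE_VECTORS);
    if (feature == null) {
        feature = features.get(FeatureId.SEPARATED_VECTORS);
    }
    if (feature == null) {
        throw new UnsupportedOperationException("No full-resolution vectors in this graph");
    }
    return feature;
}
```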
@@ -237,12 +337,16 @@ public VectorFloat<?> getVector(int node) {
 
         @Override
         public void getVectorInto(int node, VectorFloat<?> vector, int offset) {
-            if (!features.containsKey(FeatureId.INLINE_VECTORS)) {
-                throw new UnsupportedOperationException("No inline vectors in this graph");
+            var feature = features.get(FeatureId.INLINE_VECTORS);
+            if (feature == null) {
+                feature = features.get(FeatureId.SEPARATED_VECTORS);
+            }
+            if (feature == null) {
+                throw new UnsupportedOperationException("No full-resolution vectors in this graph");
             }
 
             try {
-                long diskOffset = inlineOffsetFor(node, FeatureId.INLINE_VECTORS);
+                long diskOffset = offsetFor(node, feature.id());
                 reader.seek(diskOffset);
                 vectorTypeSupport.readFloatVector(reader, dimension, vector, offset);
             } catch (IOException e) {
@@ -250,13 +354,32 @@ public void getVectorInto(int node, VectorFloat<?> vector, int offset) {
             }
         }
 
-        public NodesIterator getNeighborsIterator(int node) {
+        public NodesIterator getNeighborsIterator(int level, int node) {
             try {
-                reader.seek(neighborsOffsetFor(node));
-                int neighborCount = reader.readInt();
-                assert neighborCount <= maxDegree : String.format("Node %d neighborCount %d > M %d", node, neighborCount, maxDegree);
-                reader.read(neighbors, 0, neighborCount);
-                return new NodesIterator.ArrayNodesIterator(neighbors, neighborCount);
+                if (level == 0) {
+                    // For layer 0, read from disk
+                    reader.seek(neighborsOffsetFor(level, node));
+                    int neighborCount = reader.readInt();
+                    assert neighborCount <= neighbors.length
+                            : String.format("Node %d neighborCount %d > M %d", node, neighborCount, neighbors.length);
+                    reader.read(neighbors, 0, neighborCount);
+                    return new NodesIterator.ArrayNodesIterator(neighbors, neighborCount);
+                } else {
+                    // For levels > 0, read from memory
+                    var imn = inMemoryNeighbors.updateAndGet(current -> {
+                        if (current != null) {
+                            return current;
+                        }
+                        try {
+                            return loadInMemoryLayers(reader);
+                        } catch (IOException e) {
+                            throw new UncheckedIOException(e);
+                        }
+                    });
+                    int[] stored = imn.get(level).get(node);
+                    assert stored != null : String.format("No neighbors found for node %d at level %d", node, level);
+                    return new NodesIterator.ArrayNodesIterator(stored, stored.length);
+                }
             } catch (IOException e) {
                 throw new UncheckedIOException(e);
             }
@@ -264,11 +387,12 @@ public NodesIterator getNeighborsIterator(int node) {
 
         @Override
         public int size() {
-            return size;
+            // For vector operations we only care about layer 0
+            return OnDiskGraphIndex.this.size(0);
         }
 
         @Override
-        public int entryNode() {
+        public NodeAtLevel entryNode() {
             return entryNode;
         }
 
@@ -282,7 +406,6 @@ public Bits liveNodes() {
             return Bits.ALL;
         }
 
-        @Override
         public void close() throws IOException {
             reader.close();
         }
@@ -326,7 +449,7 @@ public static void write(GraphIndex graph,
                 .build()) {
             var suppliers = Feature.singleStateFactory(FeatureId.INLINE_VECTORS,
-                                                       nodeId -> new InlineVectors.State(vectors.getVector(nodeId)));
+                    nodeId -> new InlineVectors.State(vectors.getVector(nodeId)));
             writer.write(suppliers);
         }
     }
@@ -334,9 +457,8 @@ public static void write(GraphIndex graph,
     @VisibleForTesting
     static boolean areHeadersEqual(OnDiskGraphIndex g1, OnDiskGraphIndex g2) {
         return g1.version == g2.version &&
-               g1.size == g2.size &&
-               g1.maxDegree == g2.maxDegree &&
-               g1.dimension == g2.dimension &&
-               g1.entryNode == g2.entryNode;
+               g1.dimension == g2.dimension &&
+               g1.entryNode.equals(g2.entryNode) &&
+               g1.layerInfo.equals(g2.layerInfo);
     }
-}
+}
\ No newline at end of file
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java
index 1d1eb01a5..2f53d8b41 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java
@@ -20,6 +20,13 @@
 import io.github.jbellis.jvector.disk.RandomAccessWriter;
 import io.github.jbellis.jvector.graph.GraphIndex;
 import io.github.jbellis.jvector.graph.OnHeapGraphIndex;
+import io.github.jbellis.jvector.graph.disk.feature.Feature;
+import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
+import io.github.jbellis.jvector.graph.disk.feature.InlineVectors;
+import io.github.jbellis.jvector.graph.disk.feature.NVQ;
+import io.github.jbellis.jvector.graph.disk.feature.SeparatedFeature;
+import io.github.jbellis.jvector.graph.disk.feature.SeparatedNVQ;
+import io.github.jbellis.jvector.graph.disk.feature.SeparatedVectors;
 import org.agrona.collections.Int2IntHashMap;
 
 import java.io.Closeable;
@@ -27,15 +34,24 @@
 import java.io.IOException;
 import java.nio.file.Path;
 import java.util.EnumMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.function.IntFunction;
+import java.util.stream.Collectors;
 
 /**
  * Write a graph index to disk, for later loading as an OnDiskGraphIndex.
  * <p>
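+ * A hypothetical minimal usage, mirroring the OnDiskGraphIndex.write() convenience
+ * method shown earlier in this diff (`graph`, `vectors`, and `indexPath` assumed):
+ *   try (var writer = new OnDiskGraphIndexWriter.Builder(graph, indexPath)
+ *           .with(new InlineVectors(vectors.dimension()))
+ *           .build()) {
+ *       var suppliers = Feature.singleStateFactory(FeatureId.INLINE_VECTORS,
+ *               nodeId -> new InlineVectors.State(vectors.getVector(nodeId)));
+ *       writer.write(suppliers);
+ *   }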

 * Implements `getFeatureSource` to allow incremental construction of a larger-than-memory graph
 * (using the writer as the source of INLINE_VECTORS).
+ *
+ * Layout:
+ * [CommonHeader]
+ * [Header with Features]
+ * [Edges + inline features for level 0]
+ * [Edges for levels 1..N]
+ * [Separated features]
  */
 public class OnDiskGraphIndexWriter implements Closeable {
     private final int version;
@@ -50,6 +66,7 @@ public class OnDiskGraphIndexWriter implements Closeable {
     private final long startOffset;
     private final int headerSize;
     private volatile int maxOrdinalWritten = -1;
+    private final List<Feature> inlineFeatures;
 
     private OnDiskGraphIndexWriter(RandomAccessWriter out,
                                    int version,
@@ -59,17 +76,22 @@ private OnDiskGraphIndexWriter(RandomAccessWriter out,
                                    int dimension,
                                    EnumMap<FeatureId, Feature> features) {
+        if (graph.getMaxLevel() > 0 && version < 4) {
+            throw new IllegalArgumentException("Multilayer graphs must be written with version 4 or higher");
+        }
         this.version = version;
         this.graph = graph;
-        this.view = graph.getView();
+        this.view = graph instanceof OnHeapGraphIndex ? ((OnHeapGraphIndex) graph).getFrozenView() : graph.getView();
         this.ordinalMapper = oldToNewOrdinals;
         this.dimension = dimension;
         this.featureMap = features;
+        this.inlineFeatures = features.values().stream().filter(f -> !(f instanceof SeparatedFeature)).collect(Collectors.toList());
         this.out = out;
         this.startOffset = startOffset;
 
         // create a mock Header to determine the correct size
-        var ch = new CommonHeader(version, 0, dimension, view.entryNode(), graph.maxDegree(), 0);
+        var layerInfo = CommonHeader.LayerInfo.fromGraph(graph, ordinalMapper);
+        var ch = new CommonHeader(version, dimension, 0, layerInfo, 0);
         var placeholderHeader = new Header(ch, featureMap);
         this.headerSize = placeholderHeader.size();
     }
@@ -112,10 +134,10 @@ public synchronized void writeInline(int ordinal, Map<FeatureId, Feature.State>
         out.seek(featureOffsetForOrdinal(ordinal));
-        for (var feature : featureMap.values()) {
+        for (var feature : inlineFeatures) {
             var state = stateMap.get(feature.id());
             if (state == null) {
-                out.seek(out.position() + feature.inlineSize());
+                out.seek(out.position() + feature.featureSize());
             } else {
                 feature.writeInline(out, state);
             }
@@ -132,8 +154,8 @@ public int getMaxOrdinal() {
     }
 
     private long featureOffsetForOrdinal(int ordinal) {
-        int edgeSize = Integer.BYTES * (1 + graph.maxDegree());
-        long inlineBytes = ordinal * (long) (Integer.BYTES + featureMap.values().stream().mapToInt(Feature::inlineSize).sum() + edgeSize);
+        int edgeSize = Integer.BYTES * (1 + graph.getDegree(0));
+        long inlineBytes = ordinal * (long) (Integer.BYTES + inlineFeatures.stream().mapToInt(Feature::featureSize).sum() + edgeSize);
         return startOffset + headerSize + inlineBytes // previous nodes
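For orientation, the arithmetic in featureOffsetForOrdinal can be restated as the following sketch, with hypothetical names and assuming the v4 layout from the class javadoc (this helper is not part of the patch):

```java
// Offset of node `ordinal`'s inline-feature block within the level-0 region:
// every preceding node contributes (id + inline features + neighbor count + degree slots),
// and we skip the current node's own id before its features begin.
static long featureOffset(long startOffset, int headerSize, int ordinal,
                          int inlineFeatureBytes, int degreeL0) {
    long perNode = Integer.BYTES                          // node id
                 + inlineFeatureBytes                     // inline features
                 + Integer.BYTES * (long) (degreeL0 + 1); // neighbor count + slots
    return startOffset + headerSize + perNode * ordinal + Integer.BYTES;
}
```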

@@ -147,6 +169,10 @@ private long featureOffsetForOrdinal(int ordinal) {
      *
      * Each supplier takes a node ordinal and returns a FeatureState suitable for Feature.writeInline.
      */
+    private boolean isSeparated(Feature feature) {
+        return feature instanceof SeparatedFeature;
+    }
+
     public synchronized void write(Map<FeatureId, IntFunction<Feature.State>> featureStateSuppliers) throws IOException
     {
         if (graph instanceof OnHeapGraphIndex) {
@@ -162,24 +188,24 @@ public synchronized void write(Map<FeatureId, IntFunction<Feature.State>> featur
         }
         if (ordinalMapper.maxOrdinal() < graph.size() - 1) {
             var msg = String.format("Ordinal mapper from [0..%d] does not cover all nodes in the graph of size %d",
-                    ordinalMapper.maxOrdinal(), graph.size());
+                                    ordinalMapper.maxOrdinal(), graph.size());
             throw new IllegalStateException(msg);
         }
 
-        writeHeader();
+        writeHeader(); // sets position to start writing features
 
-        // for each graph node, write the associated features, followed by its neighbors
+        // for each graph node, write the associated features, followed by its neighbors at L0
         for (int newOrdinal = 0; newOrdinal <= ordinalMapper.maxOrdinal(); newOrdinal++) {
             var originalOrdinal = ordinalMapper.newToOld(newOrdinal);
 
             // if no node exists with the given ordinal, write a placeholder
             if (originalOrdinal == OrdinalMapper.OMITTED) {
                 out.writeInt(-1);
-                for (var feature : featureMap.values()) {
-                    out.seek(out.position() + feature.inlineSize());
+                for (var feature : inlineFeatures) {
+                    out.seek(out.position() + feature.featureSize());
                 }
                 out.writeInt(0);
-                for (int n = 0; n < graph.maxDegree(); n++) {
+                for (int n = 0; n < graph.getDegree(0); n++) {
                     out.writeInt(-1);
                 }
                 continue;
@@ -191,19 +217,19 @@ public synchronized void write(Map<FeatureId, IntFunction<Feature.State>> featur
             }
             out.writeInt(newOrdinal); // unnecessary, but a reasonable sanity check
             assert out.position() == featureOffsetForOrdinal(newOrdinal) : String.format("%d != %d", out.position(), featureOffsetForOrdinal(newOrdinal));
-            for (var feature : featureMap.values()) {
+            for (var feature : inlineFeatures) {
                 var supplier = featureStateSuppliers.get(feature.id());
                 if (supplier == null) {
-                    out.seek(out.position() + feature.inlineSize());
+                    out.seek(out.position() + feature.featureSize());
                 } else {
                     feature.writeInline(out, supplier.apply(originalOrdinal));
                 }
             }
 
-            var neighbors = view.getNeighborsIterator(originalOrdinal);
-            if (neighbors.size() > graph.maxDegree()) {
+            var neighbors = view.getNeighborsIterator(0, originalOrdinal);
+            if (neighbors.size() > graph.getDegree(0)) {
                 var msg = String.format("Node %d has more neighbors %d than the graph's max degree %d -- run Builder.cleanup()!",
-                        originalOrdinal, neighbors.size(), graph.maxDegree());
+                                        originalOrdinal, neighbors.size(), graph.getDegree(0));
                 throw new IllegalStateException(msg);
             }
             // write neighbors list
@@ -220,10 +246,68 @@ public synchronized void write(Map<FeatureId, IntFunction<Feature.State>> featur
             assert !neighbors.hasNext();
 
             // pad out to maxEdgesPerNode
-            for (; n < graph.maxDegree(); n++) {
+            for (; n < graph.getDegree(0); n++) {
                 out.writeInt(-1);
             }
         }
+
+        // write sparse levels
+        for (int level = 1; level <= graph.getMaxLevel(); level++) {
+            int layerSize = graph.size(level);
+            int layerDegree = graph.getDegree(level);
+            int nodesWritten = 0;
+            for (var it = graph.getNodes(level); it.hasNext(); ) {
+                int originalOrdinal = it.nextInt();
+                // node id
+                out.writeInt(ordinalMapper.oldToNew(originalOrdinal));
+                // neighbors
+                var neighbors = view.getNeighborsIterator(level, originalOrdinal);
+                out.writeInt(neighbors.size());
+                int n = 0;
+                for ( ; n < neighbors.size(); n++) {
+                    out.writeInt(ordinalMapper.oldToNew(neighbors.nextInt()));
+                }
+                assert !neighbors.hasNext() : "Mismatch between neighbor's reported
size and actual size"; + // pad out to degree + for (; n < layerDegree; n++) { + out.writeInt(-1); + } + nodesWritten++; + } + if (nodesWritten != layerSize) { + throw new IllegalStateException("Mismatch between layer size and nodes written"); + } + } + + // Write separated features + for (var featureEntry : featureMap.entrySet()) { + if (isSeparated(featureEntry.getValue())) { + var fid = featureEntry.getKey(); + var supplier = featureStateSuppliers.get(fid); + if (supplier == null) { + throw new IllegalStateException("Supplier for feature " + fid + " not found"); + } + + // Set the offset for this feature + var feature = (SeparatedFeature) featureEntry.getValue(); + feature.setOffset(out.position()); + + // Write separated data for each node + for (int newOrdinal = 0; newOrdinal <= ordinalMapper.maxOrdinal(); newOrdinal++) { + int originalOrdinal = ordinalMapper.newToOld(newOrdinal); + if (originalOrdinal != OrdinalMapper.OMITTED) { + feature.writeSeparately(out, supplier.apply(originalOrdinal)); + } else { + out.seek(out.position() + feature.featureSize()); + } + } + } + } + + // Write the header again with updated offsets + long currentPosition = out.position(); + writeHeader(); + out.seek(currentPosition); out.flush(); } @@ -237,12 +321,12 @@ public synchronized void write(Map> featur public synchronized void writeHeader() throws IOException { // graph-level properties out.seek(startOffset); + var layerInfo = CommonHeader.LayerInfo.fromGraph(graph, ordinalMapper); var commonHeader = new CommonHeader(version, - graph.size(), - dimension, - ordinalMapper.oldToNew(view.entryNode()), - graph.maxDegree(), - ordinalMapper.maxOrdinal() + 1); + dimension, + ordinalMapper.oldToNew(view.entryNode().node), + layerInfo, + ordinalMapper.maxOrdinal() + 1); var header = new Header(commonHeader, featureMap); header.write(out); out.flush(); @@ -319,7 +403,7 @@ public Builder withMapper(OrdinalMapper ordinalMapper) { /** * Set the starting offset for the graph index in the output file. This is useful if you want to - * write the index to an existing file. + * append the index to an existing file. 
*/ public Builder withStartOffset(long startOffset) { this.startOffset = startOffset; @@ -336,8 +420,12 @@ public OnDiskGraphIndexWriter build() throws IOException { dimension = ((InlineVectors) features.get(FeatureId.INLINE_VECTORS)).dimension(); } else if (features.containsKey(FeatureId.NVQ_VECTORS)) { dimension = ((NVQ) features.get(FeatureId.NVQ_VECTORS)).dimension(); + } else if (features.containsKey(FeatureId.SEPARATED_VECTORS)) { + dimension = ((SeparatedVectors) features.get(FeatureId.SEPARATED_VECTORS)).dimension(); + } else if (features.containsKey(FeatureId.SEPARATED_NVQ)) { + dimension = ((SeparatedNVQ) features.get(FeatureId.SEPARATED_NVQ)).dimension(); } else { - throw new IllegalArgumentException("Inline or NVQ vectors must be provided."); + throw new IllegalArgumentException("Inline or separated vector feature must be provided"); } if (ordinalMapper == null) { @@ -354,4 +442,4 @@ public Feature getFeature(FeatureId featureId) { return features.get(featureId); } } -} +} \ No newline at end of file diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/Feature.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/Feature.java similarity index 86% rename from jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/Feature.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/Feature.java index bb56c83de..a72ff10b6 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/Feature.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/Feature.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.github.jbellis.jvector.graph.disk; +package io.github.jbellis.jvector.graph.disk.feature; import java.io.DataOutput; import java.io.IOException; @@ -29,13 +29,15 @@ public interface Feature { int headerSize(); - int inlineSize(); + int featureSize(); void writeHeader(DataOutput out) throws IOException; - void writeInline(DataOutput out, State state) throws IOException; + default void writeInline(DataOutput out, State state) throws IOException { + // default no-op + } - // Feature implementations should implement a State as well for use with writeInline + // Feature implementations should implement a State as well for use with writeInline/writeSeparately interface State { } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/FeatureId.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FeatureId.java similarity index 86% rename from jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/FeatureId.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FeatureId.java index 1964fe87a..dd0857834 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/FeatureId.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FeatureId.java @@ -14,9 +14,10 @@ * limitations under the License. 
*/ -package io.github.jbellis.jvector.graph.disk; +package io.github.jbellis.jvector.graph.disk.feature; import io.github.jbellis.jvector.disk.RandomAccessReader; +import io.github.jbellis.jvector.graph.disk.CommonHeader; import java.util.Collections; import java.util.EnumSet; @@ -32,7 +33,9 @@ public enum FeatureId { INLINE_VECTORS(InlineVectors::load), FUSED_ADC(FusedADC::load), - NVQ_VECTORS(NVQ::load); + NVQ_VECTORS(NVQ::load), + SEPARATED_VECTORS(SeparatedVectors::load), + SEPARATED_NVQ(SeparatedNVQ::load); public static final Set ALL = Collections.unmodifiableSet(EnumSet.allOf(FeatureId.class)); @@ -42,7 +45,7 @@ public enum FeatureId { this.loader = loader; } - Feature load(CommonHeader header, RandomAccessReader reader) { + public Feature load(CommonHeader header, RandomAccessReader reader) { return loader.apply(header, reader); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/FeatureSource.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FeatureSource.java similarity index 83% rename from jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/FeatureSource.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FeatureSource.java index 6e9d40ca9..b8b24a949 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/FeatureSource.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FeatureSource.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.github.jbellis.jvector.graph.disk; +package io.github.jbellis.jvector.graph.disk.feature; import io.github.jbellis.jvector.disk.RandomAccessReader; @@ -22,5 +22,5 @@ import java.io.IOException; public interface FeatureSource extends Closeable { - RandomAccessReader inlineReaderForNode(int node, FeatureId featureId) throws IOException; + RandomAccessReader featureReaderForNode(int node, FeatureId featureId) throws IOException; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/FusedADC.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedADC.java similarity index 85% rename from jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/FusedADC.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedADC.java index e76dd2738..d17e38fa1 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/FusedADC.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedADC.java @@ -14,10 +14,12 @@ * limitations under the License. 
*/ -package io.github.jbellis.jvector.graph.disk; +package io.github.jbellis.jvector.graph.disk.feature; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.GraphIndex; +import io.github.jbellis.jvector.graph.disk.CommonHeader; +import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.quantization.FusedADCPQDecoder; import io.github.jbellis.jvector.quantization.PQVectors; @@ -53,7 +55,7 @@ public FusedADC(int maxDegree, ProductQuantization pq) { } this.maxDegree = maxDegree; this.pq = pq; - this.reusableResults = ThreadLocal.withInitial(() -> OnDiskGraphIndex.vectorTypeSupport.createFloatVector(maxDegree)); + this.reusableResults = ThreadLocal.withInitial(() -> vectorTypeSupport.createFloatVector(maxDegree)); this.reusableNeighbors = ExplicitThreadLocal.withInitial(() -> vectorTypeSupport.createByteSequence(pq.compressedVectorSize() * maxDegree)); } @@ -68,19 +70,20 @@ public int headerSize() { } @Override - public int inlineSize() { + public int featureSize() { return pq.compressedVectorSize() * maxDegree; } static FusedADC load(CommonHeader header, RandomAccessReader reader) { + // TODO doesn't work with different degrees try { - return new FusedADC(header.maxDegree, ProductQuantization.load(reader)); + return new FusedADC(header.layerInfo.get(0).degree, ProductQuantization.load(reader)); } catch (IOException e) { throw new UncheckedIOException(e); } } - ScoreFunction.ApproximateScoreFunction approximateScoreFunctionFor(VectorFloat queryVector, VectorSimilarityFunction vsf, OnDiskGraphIndex.View view, ScoreFunction.ExactScoreFunction esf) { + public ScoreFunction.ApproximateScoreFunction approximateScoreFunctionFor(VectorFloat queryVector, VectorSimilarityFunction vsf, OnDiskGraphIndex.View view, ScoreFunction.ExactScoreFunction esf) { var neighbors = new PackedNeighbors(view); return FusedADCPQDecoder.newDecoder(neighbors, pq, queryVector, reusableResults.get(), vsf, esf); } @@ -100,7 +103,7 @@ public void writeInline(DataOutput out, Feature.State state_) throws IOException var state = (FusedADC.State) state_; var pqv = state.pqVectors; - var neighbors = state.view.getNeighborsIterator(state.nodeId); + var neighbors = state.view.getNeighborsIterator(0, state.nodeId); int n = 0; var neighborSize = neighbors.size(); compressedNeighbors.zero(); @@ -135,9 +138,9 @@ public PackedNeighbors(OnDiskGraphIndex.View view) { public ByteSequence getPackedNeighbors(int node) { try { - var reader = view.inlineReaderForNode(node, FeatureId.FUSED_ADC); + var reader = view.featureReaderForNode(node, FeatureId.FUSED_ADC); var tlNeighbors = reusableNeighbors.get(); - OnDiskGraphIndex.vectorTypeSupport.readByteSequence(reader, tlNeighbors); + vectorTypeSupport.readByteSequence(reader, tlNeighbors); return tlNeighbors; } catch (IOException e) { throw new RuntimeException(e); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/InlineVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/InlineVectors.java similarity index 94% rename from jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/InlineVectors.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/InlineVectors.java index d61fbe4f3..59e2b359c 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/InlineVectors.java +++ 
b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/InlineVectors.java @@ -14,9 +14,10 @@ * limitations under the License. */ -package io.github.jbellis.jvector.graph.disk; +package io.github.jbellis.jvector.graph.disk.feature; import io.github.jbellis.jvector.disk.RandomAccessReader; +import io.github.jbellis.jvector.graph.disk.CommonHeader; import io.github.jbellis.jvector.vector.VectorizationProvider; import io.github.jbellis.jvector.vector.types.VectorFloat; import io.github.jbellis.jvector.vector.types.VectorTypeSupport; @@ -45,7 +46,7 @@ public int headerSize() { return 0; } - public int inlineSize() { + public int featureSize() { return dimension * Float.BYTES; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NVQ.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/NVQ.java similarity index 84% rename from jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NVQ.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/NVQ.java index ef52fb10d..2489ada21 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NVQ.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/NVQ.java @@ -14,9 +14,11 @@ * limitations under the License. */ -package io.github.jbellis.jvector.graph.disk; +package io.github.jbellis.jvector.graph.disk.feature; import io.github.jbellis.jvector.disk.RandomAccessReader; +import io.github.jbellis.jvector.graph.disk.CommonHeader; +import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.quantization.NVQScorer; import io.github.jbellis.jvector.quantization.NVQuantization; @@ -53,7 +55,7 @@ public int headerSize() { } @Override - public int inlineSize() { return nvq.compressedVectorSize();} + public int featureSize() { return nvq.compressedVectorSize();} public int dimension() { return nvq.globalMean.length(); @@ -86,14 +88,14 @@ public State(QuantizedVector vector) { } } - ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat queryVector, - VectorSimilarityFunction vsf, - FeatureSource source) { + public ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat queryVector, + VectorSimilarityFunction vsf, + FeatureSource source) { var function = scorer.scoreFunctionFor(queryVector, vsf); return node2 -> { try { - var reader = source.inlineReaderForNode(node2, FeatureId.NVQ_VECTORS); + var reader = source.featureReaderForNode(node2, FeatureId.NVQ_VECTORS); QuantizedVector.loadInto(reader, reusableQuantizedVector.get()); } catch (IOException e) { throw new RuntimeException(e); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedFeature.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedFeature.java new file mode 100644 index 000000000..d90aee603 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedFeature.java @@ -0,0 +1,27 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph.disk.feature; + +import java.io.DataOutput; +import java.io.IOException; + +public interface SeparatedFeature extends Feature { + void setOffset(long offset); + long getOffset(); + + void writeSeparately(DataOutput out, State state) throws IOException; +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedNVQ.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedNVQ.java new file mode 100644 index 000000000..b5d4cc476 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedNVQ.java @@ -0,0 +1,118 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph.disk.feature; + +import io.github.jbellis.jvector.disk.RandomAccessReader; +import io.github.jbellis.jvector.graph.disk.CommonHeader; +import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; +import io.github.jbellis.jvector.graph.similarity.ScoreFunction; +import io.github.jbellis.jvector.quantization.NVQScorer; +import io.github.jbellis.jvector.quantization.NVQuantization; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.types.VectorFloat; + +import java.io.DataOutput; +import java.io.IOException; +import java.io.UncheckedIOException; + +public class SeparatedNVQ implements SeparatedFeature { + private final NVQuantization nvq; + private final NVQScorer scorer; + private final ThreadLocal reusableQuantizedVector; + private long offset; + + public SeparatedNVQ(NVQuantization nvq, long offset) { + this.nvq = nvq; + this.offset = offset; + scorer = new NVQScorer(this.nvq); + reusableQuantizedVector = ThreadLocal.withInitial(() -> NVQuantization.QuantizedVector.createEmpty(nvq.subvectorSizesAndOffsets, nvq.bitsPerDimension)); + } + + @Override + public void setOffset(long offset) { + this.offset = offset; + } + + @Override + public long getOffset() { + return offset; + } + + @Override + public FeatureId id() { + return FeatureId.SEPARATED_NVQ; + } + + @Override + public int headerSize() { + return nvq.compressorSize() + Long.BYTES; + } + + @Override + public int featureSize() { + return nvq.compressedVectorSize(); + } + + @Override + public void writeHeader(DataOutput out) throws IOException { + nvq.write(out, OnDiskGraphIndex.CURRENT_VERSION); + out.writeLong(offset); + } + + @Override + public void writeSeparately(DataOutput out, State state_) throws IOException { + var state = (NVQ.State) state_; + if 
(state.vector != null) { + state.vector.write(out); + } else { + // Write zeros for missing vector + NVQuantization.QuantizedVector.createEmpty(nvq.subvectorSizesAndOffsets, nvq.bitsPerDimension).write(out); + } + } + + // Using NVQ.State + + static SeparatedNVQ load(CommonHeader header, RandomAccessReader reader) { + try { + var nvq = NVQuantization.load(reader); + long offset = reader.readLong(); + return new SeparatedNVQ(nvq, offset); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public int dimension() { + return nvq.globalMean.length(); + } + + ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat queryVector, + VectorSimilarityFunction vsf, + FeatureSource source) { + var function = scorer.scoreFunctionFor(queryVector, vsf); + + return node2 -> { + try { + var reader = source.featureReaderForNode(node2, FeatureId.SEPARATED_NVQ); + NVQuantization.QuantizedVector.loadInto(reader, reusableQuantizedVector.get()); + } catch (IOException e) { + throw new RuntimeException(e); + } + return function.similarityTo(reusableQuantizedVector.get()); + }; + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedVectors.java new file mode 100644 index 000000000..f6bff8472 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedVectors.java @@ -0,0 +1,95 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.github.jbellis.jvector.graph.disk.feature; + +import io.github.jbellis.jvector.disk.RandomAccessReader; +import io.github.jbellis.jvector.graph.disk.CommonHeader; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; + +import java.io.DataOutput; +import java.io.IOException; +import java.io.UncheckedIOException; + +public class SeparatedVectors implements SeparatedFeature { + private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); + private final int dimension; + private long offset; + + public SeparatedVectors(int dimension, long offset) { + this.dimension = dimension; + this.offset = offset; + } + + @Override + public void setOffset(long offset) { + this.offset = offset; + } + + @Override + public long getOffset() { + return offset; + } + + @Override + public FeatureId id() { + return FeatureId.SEPARATED_VECTORS; + } + + @Override + public int headerSize() { + return Long.BYTES; + } + + @Override + public int featureSize() { + return dimension * Float.BYTES; + } + + @Override + public void writeHeader(DataOutput out) throws IOException { + out.writeLong(offset); + } + + @Override + public void writeSeparately(DataOutput out, State state_) throws IOException { + var state = (InlineVectors.State) state_; + if (state.vector != null) { + vectorTypeSupport.writeFloatVector(out, state.vector); + } else { + // Write zeros for missing vector + for (int j = 0; j < dimension; j++) { + out.writeFloat(0.0f); + } + } + } + + // Using InlineVectors.State + + static SeparatedVectors load(CommonHeader header, RandomAccessReader reader) { + try { + long offset = reader.readLong(); + return new SeparatedVectors(header.dimension, offset); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public int dimension() { + return dimension; + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedADCPQDecoder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedADCPQDecoder.java index 2b99d3329..d55ffbd8c 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedADCPQDecoder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedADCPQDecoder.java @@ -16,7 +16,7 @@ package io.github.jbellis.jvector.quantization; -import io.github.jbellis.jvector.graph.disk.FusedADC; +import io.github.jbellis.jvector.graph.disk.feature.FusedADC; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.VectorUtil; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/AbstractLongHeap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/AbstractLongHeap.java index d6edd9da6..2843c16ce 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/AbstractLongHeap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/AbstractLongHeap.java @@ -153,4 +153,16 @@ public long get(int i) { long[] getHeapArray() { return heap; } + + /** + * Copies the contents and current size from `other`. 
Does NOT copy subclass field like BLH's maxSize + */ + public void copyFrom(AbstractLongHeap other) + { + if (this.heap.length < other.size) { + this.heap = new long[other.heap.length]; + } + System.arraycopy(other.heap, 0, this.heap, 0, other.size); + this.size = other.size; + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java index 9c70862c1..683cfb5dc 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java @@ -31,7 +31,7 @@ * "Dense-ish" means that space is allocated for all keys from 0 to the highest key, but * it is valid to have gaps in the keys. The value associated with "gap" keys is null. */ -public class DenseIntMap { +public class DenseIntMap implements IntMap { // locking strategy: // - writelock to resize the array // - readlock to update the array with put or remove @@ -40,37 +40,12 @@ public class DenseIntMap { private volatile AtomicReferenceArray objects; private final AtomicInteger size; - public DenseIntMap(int initialSize) { - objects = new AtomicReferenceArray<>(initialSize); + public DenseIntMap(int initialCapacity) { + objects = new AtomicReferenceArray<>(initialCapacity); size = new AtomicInteger(); } - /** - * @param key ordinal - * Prefer compareAndPut() - */ - @Deprecated - public void put(int key, T value) { - if (value == null) { - throw new IllegalArgumentException("put() value cannot be null -- use remove() instead"); - } - - ensureCapacity(key); - rwl.readLock().lock(); - try { - var isInsert = objects.getAndSet(key, value) == null; - if (isInsert) { - size.incrementAndGet(); - } - } finally { - rwl.readLock().unlock(); - } - } - - /** - * @param key ordinal - * @return true if successful, false if the current value != `existing` - */ + @Override public boolean compareAndPut(int key, T existing, T value) { if (value == null) { throw new IllegalArgumentException("compareAndPut() value cannot be null -- use remove() instead"); @@ -90,17 +65,12 @@ public boolean compareAndPut(int key, T existing, T value) { } } - /** - * @return number of items that have been added - */ + @Override public int size() { return size.get(); } - /** - * @param key ordinal - * @return the value of the key, or null if not set - */ + @Override public T get(int key) { if (key >= objects.length()) { return null; @@ -130,9 +100,7 @@ private void ensureCapacity(int node) { } } - /** - * @return the former value of the key, or null if it was not set - */ + @Override public T remove(int key) { if (key >= objects.length()) { return null; @@ -155,21 +123,12 @@ public T remove(int key) { } } + @Override public boolean containsKey(int key) { return get(key) != null; } - public NodesIterator keysIterator() { - // implemented here because we can't make it threadsafe AND performant elsewhere - var minSize = size(); // if keys are added concurrently we will miss them - var ref = objects; - var keysInts = IntStream.range(0, ref.length()).filter(i -> ref.get(i) != null).iterator(); - return NodesIterator.fromPrimitiveIterator(keysInts, minSize); - } - - /** - * Iterates keys in ascending order and calls the consumer for each non-null key-value pair. 
- */ + @Override public void forEach(IntBiConsumer consumer) { var ref = objects; for (int i = 0; i < ref.length(); i++) { @@ -179,9 +138,4 @@ public void forEach(IntBiConsumer consumer) { } } } - - @FunctionalInterface - public interface IntBiConsumer { - void consume(int key, T value); - } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/GrowableLongHeap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/GrowableLongHeap.java index 9f190c72b..819567406 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/GrowableLongHeap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/GrowableLongHeap.java @@ -28,9 +28,6 @@ * An AbstractLongHeap that can grow in size (unbounded, except for memory and array size limits). */ public class GrowableLongHeap extends AbstractLongHeap { - - protected int size = 0; - /** * Create an empty heap with the configured initial size. * diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/IntMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/IntMap.java new file mode 100644 index 000000000..713e9a3ab --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/IntMap.java @@ -0,0 +1,60 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.util; + +import io.github.jbellis.jvector.graph.NodesIterator; + +import java.util.stream.IntStream; + +public interface IntMap { + /** + * @param key ordinal + * @return true if successful, false if the current value != `existing` + */ + boolean compareAndPut(int key, T existing, T value); + + /** + * @return number of items that have been added + */ + int size(); + + /** + * @param key ordinal + * @return the value of the key, or null if not set + */ + T get(int key); + + /** + * @return the former value of the key, or null if it was not set + */ + T remove(int key); + + /** + * @return true iff the given key is set in the map + */ + boolean containsKey(int key); + + /** + * Iterates keys in ascending order and calls the consumer for each non-null key-value pair. + */ + void forEach(IntBiConsumer consumer); + + @FunctionalInterface + interface IntBiConsumer { + void consume(int key, T2 value); + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseIntMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseIntMap.java new file mode 100644 index 000000000..a8fc555e5 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseIntMap.java @@ -0,0 +1,73 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.github.jbellis.jvector.util;
+
+import io.github.jbellis.jvector.graph.NodesIterator;
+
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.stream.IntStream;
+
+public class SparseIntMap<T> implements IntMap<T> {
+    private final ConcurrentHashMap<Integer, T> map;
+
+    public SparseIntMap() {
+        this.map = new ConcurrentHashMap<>();
+    }
+
+    @Override
+    public boolean compareAndPut(int key, T existing, T value) {
+        if (value == null) {
+            throw new IllegalArgumentException("compareAndPut() value cannot be null -- use remove() instead");
+        }
+
+        if (existing == null) {
+            T result = map.putIfAbsent(key, value);
+            return result == null;
+        }
+
+        return map.replace(key, existing, value);
+    }
+
+    @Override
+    public int size() {
+        return map.size();
+    }
+
+    @Override
+    public T get(int key) {
+        return map.get(key);
+    }
+
+    @Override
+    public T remove(int key) {
+        return map.remove(key);
+    }
+
+    @Override
+    public boolean containsKey(int key) {
+        return map.containsKey(key);
+    }
+
+    public IntStream keysStream() {
+        return map.keySet().stream().mapToInt(key -> key);
+    }
+
+    @Override
+    public void forEach(IntBiConsumer<T> consumer) {
+        map.forEach(consumer::consume);
+    }
+}
diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java
index 380e7182e..d487b2b47 100644
--- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java
+++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java
@@ -22,7 +22,7 @@
 import io.github.jbellis.jvector.example.util.DataSetCreator;
 import io.github.jbellis.jvector.example.util.DownloadHelper;
 import io.github.jbellis.jvector.example.util.Hdf5Loader;
-import io.github.jbellis.jvector.graph.disk.FeatureId;
+import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
 import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
 
 import java.io.IOException;
@@ -43,8 +43,11 @@ public static void main(String[] args) throws IOException {
         System.out.println("Heap space available is " + Runtime.getRuntime().maxMemory());
 
         var mGrid = List.of(32); // List.of(16, 24, 32, 48, 64, 96, 128);
-        var searchDepthConstructionGrid = List.of(100); // List.of(60, 80, 100, 120, 160, 200, 400, 600, 800);
-        var overqueryGrid = List.of(1.0, 2.0); // rerankK = oq * topK
+        var efConstructionGrid = List.of(100); // List.of(60, 80, 100, 120, 160, 200, 400, 600, 800);
+        var overqueryGrid = List.of(1.0, 2.0, 5.0); // rerankK = oq * topK
+        var neighborOverflowGrid = List.of(1.2f); // List.of(1.2f, 2.0f);
+        var addHierarchyGrid = List.of(true); // List.of(false, true);
+        var usePruningGrid = List.of(false); // List.of(false, true);
         List<Function<DataSet, CompressorParameters>> buildCompression = Arrays.asList(
                 ds -> new PQParameters(ds.getDimension() / 8, 256, ds.similarityFunction == VectorSimilarityFunction.EUCLIDEAN, UNWEIGHTED),
                 __ -> CompressorParameters.NONE
         );
@@ -74,7 +77,7 @@
                 "nv-qa-v4-100k",
                 "colbert-1M",
                 "gecko-100k");
-        executeNw(coreFiles, pattern, buildCompression, featureSets, searchCompression, mGrid, searchDepthConstructionGrid, overqueryGrid);
diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java
index 380e7182e..d487b2b47 100644
--- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java
+++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java
@@ -22,7 +22,7 @@
 import io.github.jbellis.jvector.example.util.DataSetCreator;
 import io.github.jbellis.jvector.example.util.DownloadHelper;
 import io.github.jbellis.jvector.example.util.Hdf5Loader;
-import io.github.jbellis.jvector.graph.disk.FeatureId;
+import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
 import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
 
 import java.io.IOException;
@@ -43,8 +43,11 @@ public static void main(String[] args) throws IOException {
         System.out.println("Heap space available is " + Runtime.getRuntime().maxMemory());
 
         var mGrid = List.of(32); // List.of(16, 24, 32, 48, 64, 96, 128);
-        var searchDepthConstructionGrid = List.of(100); // List.of(60, 80, 100, 120, 160, 200, 400, 600, 800);
-        var overqueryGrid = List.of(1.0, 2.0); // rerankK = oq * topK
+        var efConstructionGrid = List.of(100); // List.of(60, 80, 100, 120, 160, 200, 400, 600, 800);
+        var overqueryGrid = List.of(1.0, 2.0, 5.0); // rerankK = oq * topK
+        var neighborOverflowGrid = List.of(1.2f); // List.of(1.2f, 2.0f);
+        var addHierarchyGrid = List.of(true); // List.of(false, true);
+        var usePruningGrid = List.of(false); // List.of(false, true);
         List<Function<DataSet, CompressorParameters>> buildCompression = Arrays.asList(
                 ds -> new PQParameters(ds.getDimension() / 8, 256, ds.similarityFunction == VectorSimilarityFunction.EUCLIDEAN, UNWEIGHTED),
                 __ -> CompressorParameters.NONE
@@ -74,7 +77,7 @@ public static void main(String[] args) throws IOException {
                 "nv-qa-v4-100k",
                 "colbert-1M",
                 "gecko-100k");
-        executeNw(coreFiles, pattern, buildCompression, featureSets, searchCompression, mGrid, searchDepthConstructionGrid, overqueryGrid);
+        executeNw(coreFiles, pattern, buildCompression, featureSets, searchCompression, mGrid, efConstructionGrid, neighborOverflowGrid, addHierarchyGrid, overqueryGrid, usePruningGrid);
 
         var extraFiles = List.of(
                 "openai-v3-large-3072-100k",
@@ -82,7 +85,7 @@ public static void main(String[] args) throws IOException {
                 "e5-small-v2-100k",
                 "e5-base-v2-100k",
                 "e5-large-v2-100k");
-        executeNw(extraFiles, pattern, buildCompression, featureSets, searchCompression, mGrid, searchDepthConstructionGrid, overqueryGrid);
+        executeNw(extraFiles, pattern, buildCompression, featureSets, searchCompression, mGrid, efConstructionGrid, neighborOverflowGrid, addHierarchyGrid, overqueryGrid, usePruningGrid);
 
         // smaller vectors from ann-benchmarks
         var hdf5Files = List.of(
@@ -99,7 +102,7 @@ public static void main(String[] args) throws IOException {
         for (var f : hdf5Files) {
             if (pattern.matcher(f).find()) {
                 DownloadHelper.maybeDownloadHdf5(f);
-                Grid.runAll(Hdf5Loader.load(f), mGrid, searchDepthConstructionGrid, featureSets, buildCompression, searchCompression, overqueryGrid);
+                Grid.runAll(Hdf5Loader.load(f), mGrid, efConstructionGrid, neighborOverflowGrid, addHierarchyGrid, featureSets, buildCompression, searchCompression, overqueryGrid, usePruningGrid);
             }
         }
 
@@ -109,15 +112,15 @@ public static void main(String[] args) throws IOException {
                 ds -> new PQParameters(ds.getDimension(), 256, true, UNWEIGHTED));
         buildCompression = Arrays.asList(__ -> CompressorParameters.NONE);
         var grid2d = DataSetCreator.create2DGrid(4_000_000, 10_000, 100);
-        Grid.runAll(grid2d, mGrid, searchDepthConstructionGrid, featureSets, buildCompression, searchCompression, overqueryGrid);
+        Grid.runAll(grid2d, mGrid, efConstructionGrid, neighborOverflowGrid, addHierarchyGrid, featureSets, buildCompression, searchCompression, overqueryGrid, usePruningGrid);
         }
     }
 
-    private static void executeNw(List<String> coreFiles, Pattern pattern, List<Function<DataSet, CompressorParameters>> buildCompression, List<EnumSet<FeatureId>> featureSets, List<Function<DataSet, CompressorParameters>> compressionGrid, List<Integer> mGrid, List<Integer> searchDepthConstructionGrid, List<Double> overquerySearchGrid) throws IOException {
+    private static void executeNw(List<String> coreFiles, Pattern pattern, List<Function<DataSet, CompressorParameters>> buildCompression, List<EnumSet<FeatureId>> featureSets, List<Function<DataSet, CompressorParameters>> compressionGrid, List<Integer> mGrid, List<Integer> efConstructionGrid, List<Float> neighborOverflowGrid, List<Boolean> addHierarchyGrid, List<Double> efSearchGrid, List<Boolean> usePruningGrid) throws IOException {
         for (var nwDatasetName : coreFiles) {
             if (pattern.matcher(nwDatasetName).find()) {
                 var mfd = DownloadHelper.maybeDownloadFvecs(nwDatasetName);
-                Grid.runAll(mfd.load(), mGrid, searchDepthConstructionGrid, featureSets, buildCompression, compressionGrid, overquerySearchGrid);
+                Grid.runAll(mfd.load(), mGrid, efConstructionGrid, neighborOverflowGrid, addHierarchyGrid, featureSets, buildCompression, compressionGrid, efSearchGrid, usePruningGrid);
             }
         }
     }
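`Grid.buildInMemory` (in the hunk below) drives the new knobs through a `GraphIndexBuilder` overload that also takes explicit executor pools. Here is a sketch of building one cell of the parameter grid by hand; `ravv` stands in for an already-populated `RandomAccessVectorValues`, and the parameter values are illustrative, not recommendations.

```java
import io.github.jbellis.jvector.graph.GraphIndexBuilder;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
import io.github.jbellis.jvector.util.PhysicalCoreExecutor;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;

import java.util.concurrent.ForkJoinPool;

public class OneGridCellSketch {
    static void buildOne(RandomAccessVectorValues ravv) {
        var bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, VectorSimilarityFunction.EUCLIDEAN);
        try (var builder = new GraphIndexBuilder(bsp,
                                                 ravv.dimension(),
                                                 32,    // M (graph degree)
                                                 100,   // efConstruction
                                                 1.2f,  // neighborOverflow
                                                 1.2f,  // alpha
                                                 true,  // addHierarchy
                                                 PhysicalCoreExecutor.pool(),
                                                 ForkJoinPool.commonPool())) {
            var index = builder.build(ravv);
            // per-layer stats, as Grid prints them
            for (int level = 0; level <= index.getMaxLevel(); level++) {
                System.out.printf("L%d: %d nodes, %.2f avg degree%n",
                                  level, index.getLayerSize(level), index.getAverageDegree(level));
            }
        }
    }
}
```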
diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java
index 848d90998..194067953 100644
--- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java
+++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java
@@ -25,12 +25,11 @@
 import io.github.jbellis.jvector.graph.OnHeapGraphIndex;
 import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
 import io.github.jbellis.jvector.graph.SearchResult;
-import io.github.jbellis.jvector.graph.disk.CachingGraphIndex;
-import io.github.jbellis.jvector.graph.disk.Feature;
-import io.github.jbellis.jvector.graph.disk.FeatureId;
-import io.github.jbellis.jvector.graph.disk.FusedADC;
-import io.github.jbellis.jvector.graph.disk.InlineVectors;
-import io.github.jbellis.jvector.graph.disk.NVQ;
+import io.github.jbellis.jvector.graph.disk.feature.Feature;
+import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
+import io.github.jbellis.jvector.graph.disk.feature.FusedADC;
+import io.github.jbellis.jvector.graph.disk.feature.InlineVectors;
+import io.github.jbellis.jvector.graph.disk.feature.NVQ;
 import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
 import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter;
 import io.github.jbellis.jvector.graph.disk.OrdinalMapper;
@@ -63,6 +62,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ForkJoinPool;
 import java.util.concurrent.atomic.LongAdder;
 import java.util.function.Function;
 import java.util.function.IntFunction;
@@ -80,19 +80,26 @@ public class Grid {
 
     static void runAll(DataSet ds,
                        List<Integer> mGrid,
-                       List<Integer> searchDepthConstructionGrid,
+                       List<Integer> efConstructionGrid,
+                       List<Float> neighborOverflowGrid,
+                       List<Boolean> addHierarchyGrid,
                        List<EnumSet<FeatureId>> featureSets,
                        List<Function<DataSet, CompressorParameters>> buildCompressors,
                        List<Function<DataSet, CompressorParameters>> compressionGrid,
-                       List<Double> overquerySearchFactor) throws IOException
+                       List<Double> efSearchFactor,
+                       List<Boolean> usePruningGrid) throws IOException
     {
         var testDirectory = Files.createTempDirectory(dirPrefix);
         try {
-            for (int M : mGrid) {
-                for (int overqueryC : searchDepthConstructionGrid) {
-                    for (var bc : buildCompressors) {
-                        var compressor = getCompressor(bc, ds);
-                        runOneGraph(featureSets, M, overqueryC, compressor, compressionGrid, overquerySearchFactor, ds, testDirectory);
+            for (var addHierarchy : addHierarchyGrid) {
+                for (int M : mGrid) {
+                    for (float neighborOverflow: neighborOverflowGrid) {
+                        for (int efC : efConstructionGrid) {
+                            for (var bc : buildCompressors) {
+                                var compressor = getCompressor(bc, ds);
+                                runOneGraph(featureSets, M, efC, neighborOverflow, addHierarchy, compressor, compressionGrid, efSearchFactor, usePruningGrid, ds, testDirectory);
+                            }
+                        }
                     }
                 }
             }
@@ -111,18 +118,21 @@ static void runAll(DataSet ds,
 
     static void runOneGraph(List<EnumSet<FeatureId>> featureSets,
                             int M,
-                            int searchDepthConstruction,
+                            int efConstruction,
+                            float neighborOverflow,
+                            boolean addHierarchy,
                             VectorCompressor<?> buildCompressor,
                             List<Function<DataSet, CompressorParameters>> compressionGrid,
-                            List<Double> overquerySearchOptions,
+                            List<Double> efSearchOptions,
+                            List<Boolean> usePruningGrid,
                             DataSet ds,
                             Path testDirectory) throws IOException
     {
         Map<Set<FeatureId>, GraphIndex> indexes;
         if (buildCompressor == null) {
-            indexes = buildInMemory(featureSets, M, searchDepthConstruction, ds, testDirectory);
+            indexes = buildInMemory(featureSets, M, efConstruction, neighborOverflow, addHierarchy, ds, testDirectory);
         } else {
-            indexes = buildOnDisk(featureSets, M, searchDepthConstruction, ds, testDirectory, buildCompressor);
+            indexes = buildOnDisk(featureSets, M, efConstruction, neighborOverflow, addHierarchy, ds, testDirectory, buildCompressor);
         }
 
         try {
@@ -139,9 +149,9 @@ static void runOneGraph(List<EnumSet<FeatureId>> featureSets,
         }
 
         indexes.forEach((features, index) -> {
-            try (var cs = new ConfiguredSystem(ds, index instanceof OnDiskGraphIndex ? new CachingGraphIndex((OnDiskGraphIndex) index) : index, cv,
+            try (var cs = new ConfiguredSystem(ds, index, cv,
                                                index instanceof OnDiskGraphIndex ?
((OnDiskGraphIndex) index).getFeatureSet() : Set.of())) { - testConfiguration(cs, overquerySearchOptions); + testConfiguration(cs, efSearchOptions, usePruningGrid); } catch (Exception e) { throw new RuntimeException(e); } @@ -159,7 +169,9 @@ static void runOneGraph(List> featureSets, private static Map, GraphIndex> buildOnDisk(List> featureSets, int M, - int searchDepthConstruction, + int efConstruction, + float neighborOverflow, + boolean addHierarchy, DataSet ds, Path testDirectory, VectorCompressor buildCompressor) @@ -169,7 +181,7 @@ private static Map, GraphIndex> buildOnDisk(List, OnDiskGraphIndexWriter> writers = new HashMap<>(); @@ -252,7 +264,8 @@ private static BuilderWithSuppliers builderWithSuppliers(Set features throws FileNotFoundException { var identityMapper = new OrdinalMapper.IdentityMapper(floatVectors.size() - 1); - var builder = new OnDiskGraphIndexWriter.Builder(onHeapGraph, outPath).withMapper(identityMapper); + var builder = new OnDiskGraphIndexWriter.Builder(onHeapGraph, outPath) + .withMapper(identityMapper); Map> suppliers = new EnumMap<>(FeatureId.class); for (var featureId : features) { switch (featureId) { @@ -291,7 +304,9 @@ public BuilderWithSuppliers(OnDiskGraphIndexWriter.Builder builder, Map, GraphIndex> buildInMemory(List> featureSets, int M, - int searchDepthConstruction, + int efConstruction, + float neighborOverflow, + boolean addHierarchy, DataSet ds, Path testDirectory) throws IOException @@ -300,16 +315,29 @@ private static Map, GraphIndex> buildInMemory(List, GraphIndex> indexes = new HashMap<>(); long start; var bsp = BuildScoreProvider.randomAccessScoreProvider(floatVectors, ds.similarityFunction); - GraphIndexBuilder builder = new GraphIndexBuilder(bsp, floatVectors.dimension(), M, searchDepthConstruction, 1.2f, 1.2f); + GraphIndexBuilder builder = new GraphIndexBuilder(bsp, + floatVectors.dimension(), + M, + efConstruction, + neighborOverflow, + 1.2f, + addHierarchy, + PhysicalCoreExecutor.pool(), + ForkJoinPool.commonPool()); start = System.nanoTime(); var onHeapGraph = builder.build(floatVectors); - System.out.format("Build (%s) M=%d overquery=%d in %.2fs with avg degree %.2f and %.2f short edges%n", + System.out.format("Build (%s) M=%d overflow=%.2f ef=%d in %.2fs%n", "full res", M, - searchDepthConstruction, - (System.nanoTime() - start) / 1_000_000_000.0, - onHeapGraph.getAverageDegree(), - builder.getAverageShortEdges()); + neighborOverflow, + efConstruction, + (System.nanoTime() - start) / 1_000_000_000.0); + for (int i = 0; i <= onHeapGraph.getMaxLevel(); i++) { + System.out.format(" L%d: %d nodes, %.2f avg degree%n", + i, + onHeapGraph.getLayerSize(i), + onHeapGraph.getAverageDegree(i)); + } int n = 0; for (var features : featureSets) { if (features.contains(FeatureId.FUSED_ADC)) { @@ -333,17 +361,18 @@ private static Map, GraphIndex> buildInMemory(List> cachedCompressors = new IdentityHashMap<>(); - private static void testConfiguration(ConfiguredSystem cs, List overquerySearchOptions) { + private static void testConfiguration(ConfiguredSystem cs, List efSearchOptions, List usePruningGrid) { var topK = cs.ds.groundTruth.get(0).size(); System.out.format("Using %s:%n", cs.index); - for (var overquery : overquerySearchOptions) { + for (var overquery : efSearchOptions) { var start = System.nanoTime(); int rerankK = (int) (topK * overquery); - var pqr = performQueries(cs, topK, rerankK, 2); - var recall = ((double) pqr.topKFound) / (2 * cs.ds.queryVectors.size() * topK); - System.out.format(" Query top %d/%d recall %.4f in %.2fs after %,d 
nodes visited%n", - topK, rerankK, recall, (System.nanoTime() - start) / 1_000_000_000.0, pqr.nodesVisited); - + for (var usePruning : usePruningGrid) { + var pqr = performQueries(cs, topK, rerankK, usePruning, 2); + var recall = ((double) pqr.topKFound) / (2 * cs.ds.queryVectors.size() * topK); + System.out.format(" Query top %d/%d recall %.4f in %.2fms after %,d nodes visited with pruning=%b%n", + topK, rerankK, recall, (System.nanoTime() - start) / 1_000_000.0, pqr.nodesVisited, usePruning); + } } } @@ -401,7 +430,7 @@ private static long topKCorrect(int topK, SearchResult.NodeScore[] nn, Set queryVector = vectorTypeSupport.createFloatVector(ctx.dimension); @@ -260,7 +262,7 @@ String search(String input, SessionContext ctx) { ? ((GraphIndex.ScoringView) view).rerankerFor(queryVector, ctx.similarityFunction) : ctx.ravv.rerankerFor(queryVector, ctx.similarityFunction); var ssp = new SearchScoreProvider(sf, rr); - r = new GraphSearcher(ctx.index).search(ssp, searchOverquery, Bits.ALL); + r = new GraphSearcher(ctx.index).search(ssp, searchEf, Bits.ALL); } catch (Exception e) { throw new RuntimeException(e); } diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java index 9daaf809b..3a206de59 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java @@ -27,10 +27,10 @@ import io.github.jbellis.jvector.graph.OnHeapGraphIndex; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; import io.github.jbellis.jvector.graph.SearchResult; -import io.github.jbellis.jvector.graph.disk.Feature; -import io.github.jbellis.jvector.graph.disk.FeatureId; -import io.github.jbellis.jvector.graph.disk.InlineVectors; -import io.github.jbellis.jvector.graph.disk.NVQ; +import io.github.jbellis.jvector.graph.disk.feature.Feature; +import io.github.jbellis.jvector.graph.disk.feature.FeatureId; +import io.github.jbellis.jvector.graph.disk.feature.InlineVectors; +import io.github.jbellis.jvector.graph.disk.feature.NVQ; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter; import io.github.jbellis.jvector.graph.disk.OrdinalMapper; @@ -87,7 +87,8 @@ public static void siftInMemory(ArrayList> baseVectors) throws IO 16, // graph degree 100, // construction search depth 1.2f, // allow degree overflow during construction by this factor - 1.2f)) // relax neighbor diversity requirement by this factor + 1.2f, // relax neighbor diversity requirement by this factor + false)) { // build the index (in memory) OnHeapGraphIndex index = builder.build(ravv); @@ -112,7 +113,7 @@ public static void siftInMemoryWithSearcher(ArrayList> baseVector RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(baseVectors, originalDimension); BuildScoreProvider bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, VectorSimilarityFunction.EUCLIDEAN); - try (GraphIndexBuilder builder = new GraphIndexBuilder(bsp, ravv.dimension(), 16, 100, 1.2f, 1.2f)) { + try (GraphIndexBuilder builder = new GraphIndexBuilder(bsp, ravv.dimension(), 16, 100, 1.2f, 1.2f, false)) { OnHeapGraphIndex index = builder.build(ravv); // search for a random vector using a GraphSearcher and SearchScoreProvider @@ -133,7 +134,7 @@ public static void siftInMemoryWithRecall(List> baseVectors, List RandomAccessVectorValues ravv = new 
ListRandomAccessVectorValues(baseVectors, originalDimension); BuildScoreProvider bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, VectorSimilarityFunction.EUCLIDEAN); - try (GraphIndexBuilder builder = new GraphIndexBuilder(bsp, ravv.dimension(), 16, 100, 1.2f, 1.2f)) { + try (GraphIndexBuilder builder = new GraphIndexBuilder(bsp, ravv.dimension(), 16, 100, 1.2f, 1.2f, false)) { OnHeapGraphIndex index = builder.build(ravv); // measure our recall against the (exactly computed) ground truth Function, SearchScoreProvider> sspFactory = q -> SearchScoreProvider.exact(q, VectorSimilarityFunction.EUCLIDEAN, ravv); @@ -148,7 +149,7 @@ public static void siftPersisted(List> baseVectors, List> baseVectors, List> baseVectors, List> baseVectors, List< Path indexPath = Files.createTempFile("siftsmall", ".inline"); Path pqPath = Files.createTempFile("siftsmall", ".pq"); // Builder creation looks mostly the same - try (GraphIndexBuilder builder = new GraphIndexBuilder(bsp, ravv.dimension(), 16, 100, 1.2f, 1.2f); + try (GraphIndexBuilder builder = new GraphIndexBuilder(bsp, ravv.dimension(), 16, 100, 1.2f, 1.2f, false); // explicit Writer for the first time, this is what's behind OnDiskGraphIndex.write OnDiskGraphIndexWriter writer = new OnDiskGraphIndexWriter.Builder(builder.getGraph(), indexPath) .with(new NVQ(nvq)) diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MMapReader.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MMapReader.java index d2c0cbb88..679fe8dac 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MMapReader.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MMapReader.java @@ -54,7 +54,16 @@ public int readInt() { } @Override - public float readFloat() throws IOException { + public long readLong() { + try { + return buffer.memory().getLong(position); + } finally { + position += Long.BYTES; + } + } + + @Override + public float readFloat() { try { return buffer.memory().getFloat(position); } finally { diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MultiFileDatasource.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MultiFileDatasource.java index a3d26a4a3..530fa0565 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MultiFileDatasource.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/MultiFileDatasource.java @@ -54,6 +54,10 @@ public DataSet load() throws IOException { } public static Map byName = new HashMap<>() {{ + put("degen-200k", new MultiFileDatasource("degen-200k", + "ada-degen/degen_base_vectors.fvec", + "ada-degen/degen_query_vectors.fvec", + "ada-degen/degen_ground_truth.ivec")); put("cohere-english-v3-100k", new MultiFileDatasource("cohere-english-v3-100k", "wikipedia_squad/100k/cohere_embed-english-v3.0_1024_base_vectors_100000.fvec", "wikipedia_squad/100k/cohere_embed-english-v3.0_1024_query_vectors_10000.fvec", diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/disk/MemorySegmentReader.java b/jvector-native/src/main/java/io/github/jbellis/jvector/disk/MemorySegmentReader.java index cbe93ce93..e46e6db5d 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/disk/MemorySegmentReader.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/disk/MemorySegmentReader.java @@ -100,6 +100,13 @@ public int readInt() { return k; } + @Override + public long readLong() { + var l = 
memory.get(longLayout, position); + position += 8; + return l; + } + @Override public float readFloat() { var f = memory.get(floatLayout, position); diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java index 1068c86e1..a97e4858f 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java @@ -22,10 +22,11 @@ import io.github.jbellis.jvector.graph.NodesIterator; import io.github.jbellis.jvector.graph.OnHeapGraphIndex; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; -import io.github.jbellis.jvector.graph.disk.Feature; -import io.github.jbellis.jvector.graph.disk.FeatureId; -import io.github.jbellis.jvector.graph.disk.FusedADC; -import io.github.jbellis.jvector.graph.disk.InlineVectors; +import io.github.jbellis.jvector.graph.disk.CommonHeader; +import io.github.jbellis.jvector.graph.disk.feature.Feature; +import io.github.jbellis.jvector.graph.disk.feature.FeatureId; +import io.github.jbellis.jvector.graph.disk.feature.FusedADC; +import io.github.jbellis.jvector.graph.disk.feature.InlineVectors; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter; import io.github.jbellis.jvector.quantization.PQVectors; @@ -162,17 +163,17 @@ public static void writeFusedGraph(GraphIndex graph, RandomAccessVectorValues ra } } - public static Set getNeighborNodes(GraphIndex.View g, int node) { + public static Set getNeighborNodes(GraphIndex.View g, int level, int node) { Set neighbors = new HashSet<>(); - for (var it = g.getNeighborsIterator(node); it.hasNext(); ) { + for (var it = g.getNeighborsIterator(level, node); it.hasNext(); ) { int n = it.nextInt(); neighbors.add(n); } return neighbors; } - static List sortedNodes(GraphIndex h) { - var graphNodes = h.getNodes(); + static List sortedNodes(GraphIndex h, int level) { + var graphNodes = h.getNodes(level); // TODO List nodes = new ArrayList<>(); while (graphNodes.hasNext()) { nodes.add(graphNodes.next()); @@ -182,38 +183,41 @@ static List sortedNodes(GraphIndex h) { } public static void assertGraphEquals(GraphIndex g, GraphIndex h) { - // construct these up front since they call seek which will mess up our test loop - String prettyG = GraphIndex.prettyPrint(g); - String prettyH = GraphIndex.prettyPrint(h); - assertEquals(String.format("the number of nodes in the graphs are different:%n%s%n%s", - prettyG, - prettyH), - g.size(), - h.size()); - - // assert equal nodes in each graph - List hNodes = sortedNodes(h); - List gNodes = sortedNodes(g); - assertEquals(String.format("nodes in the graphs are different:%n%s%n%s", - prettyG, - prettyH), - gNodes, - hNodes); - - // assert equal nodes' neighbours in each graph - NodesIterator gNodesIterator = g.getNodes(); - var gv = g.getView(); - var hv = h.getView(); - while (gNodesIterator.hasNext()) { - int node = gNodesIterator.nextInt(); - assertEqualsLazy(() -> String.format("arcs differ for node %d%n%s%n%s", - node, - prettyG, - prettyH), - getNeighborNodes(gv, node), - getNeighborNodes(hv, node)); - } - } + // construct these up front since they call seek which will mess up our test loop + String prettyG = GraphIndex.prettyPrint(g); + String prettyH = GraphIndex.prettyPrint(h); + assertEquals(String.format("the number of nodes in the graphs are different:%n%s%n%s", + prettyG, + prettyH), + g.size(), + h.size()); + + 
assertEquals(g.getView().entryNode(), h.getView().entryNode()); + for (int level = 0; level <= g.getMaxLevel(); level++) { + // assert equal nodes in each graph + List hNodes = sortedNodes(h, level); + List gNodes = sortedNodes(g, level); + assertEquals(String.format("nodes in the graphs are different:%n%s%n%s", + prettyG, + prettyH), + gNodes, + hNodes); + + // assert equal nodes' neighbours in each graph + NodesIterator gNodesIterator = g.getNodes(level); + var gv = g.getView(); + var hv = h.getView(); + while (gNodesIterator.hasNext()) { + int node = gNodesIterator.nextInt(); + assertEqualsLazy(() -> String.format("arcs differ for node %d%n%s%n%s", + node, + prettyG, + prettyH), + getNeighborNodes(gv, level, node), + getNeighborNodes(hv, level, node)); + } + } + } /** * For when building the failure message is expensive @@ -234,22 +238,31 @@ public static OnHeapGraphIndex buildSequentially(GraphIndexBuilder builder, Rand public static class FullyConnectedGraphIndex implements GraphIndex { private final int entryNode; - private final int size; + private final List layerSizes; public FullyConnectedGraphIndex(int entryNode, int size) { + this(entryNode, List.of(size)); + } + + public FullyConnectedGraphIndex(int entryNode, List layerSizes) { this.entryNode = entryNode; - this.size = size; + this.layerSizes = layerSizes; } @Override - public int size() - { - return size; + public int size(int level) { + return layerSizes.get(level); + } + + @Override + public int maxDegree() { + return layerSizes.stream().mapToInt(i -> i).max().orElseThrow(); } @Override - public NodesIterator getNodes() { - return new NodesIterator.ArrayNodesIterator(IntStream.range(0, size).toArray(), size); + public NodesIterator getNodes(int level) { + int n = layerSizes.get(level); + return new NodesIterator.ArrayNodesIterator(IntStream.range(0, n).toArray(), n); } @Override @@ -258,8 +271,13 @@ public View getView() { } @Override - public int maxDegree() { - return size - 1; + public int getDegree(int level) { + return layerSizes.get(level) - 1; + } + + @Override + public int getMaxLevel() { + return layerSizes.size() - 1; } @Override @@ -267,18 +285,20 @@ public void close() { } private class FullyConnectedGraphIndexView implements View { @Override - public NodesIterator getNeighborsIterator(int node) { - return new NodesIterator.ArrayNodesIterator(IntStream.range(0, size).filter(i -> i != node).toArray() , size - 1); + public NodesIterator getNeighborsIterator(int level, int node) { + return new NodesIterator.ArrayNodesIterator(IntStream.range(0, layerSizes.get(level)) + .filter(i -> i != node).toArray(), + layerSizes.get(level) - 1); } @Override public int size() { - return size; + return FullyConnectedGraphIndex.this.size(0); } @Override - public int entryNode() { - return entryNode; + public NodeAtLevel entryNode() { + return new NodeAtLevel(layerSizes.size() - 1, entryNode); } @Override @@ -297,41 +317,60 @@ public long ramBytesUsed() { } public static class RandomlyConnectedGraphIndex implements GraphIndex { - private final int size; - private final Map nodes; + private final List layerInfo; + private final List> layerAdjacency; private final int entryNode; - public RandomlyConnectedGraphIndex(int size, int M, Random random) { - this.size = size; - this.nodes = new ConcurrentHashMap<>(); - - var maxNeighbors = Math.min(M, size - 1); - var nodeIds = IntStream.range(0, size).boxed().collect(Collectors.toCollection(ArrayList::new)); - Collections.shuffle(nodeIds, random); - - for (int i = 0; i < size; i++) { - Set 
neighborSet = new HashSet<>(); - while (neighborSet.size() < maxNeighbors) { - var neighborIdx = random.nextInt(size); - if (neighborIdx != i) { - neighborSet.add(nodeIds.get(neighborIdx)); + public RandomlyConnectedGraphIndex(List layerInfo, Random random) { + this.layerInfo = layerInfo; + this.layerAdjacency = new ArrayList<>(layerInfo.size()); + + // Build adjacency for each layer + for (int level = 0; level < layerInfo.size(); level++) { + int size = layerInfo.get(level).size; + int maxNeighbors = layerInfo.get(level).degree; + Map adjacency = new ConcurrentHashMap<>(); + + // Generate node IDs in random order + var nodeIds = IntStream.range(0, size).boxed().collect(Collectors.toCollection(ArrayList::new)); + Collections.shuffle(nodeIds, random); + + // Fill adjacency + for (int i = 0; i < size; i++) { + Set neighborSet = new HashSet<>(); + while (neighborSet.size() < maxNeighbors) { + int neighborIdx = random.nextInt(size); + if (neighborIdx != i) { + neighborSet.add(nodeIds.get(neighborIdx)); + } } - nodes.put(nodeIds.get(i), neighborSet.stream().mapToInt(Integer::intValue).toArray()); + adjacency.put(nodeIds.get(i), neighborSet.stream().mapToInt(Integer::intValue).toArray()); } + layerAdjacency.add(adjacency); } + + // Pick an entry node from the top layer + this.entryNode = random.nextInt(layerInfo.get(layerInfo.size() - 1).size); + } - this.entryNode = random.nextInt(size); + public RandomlyConnectedGraphIndex(int size, int M, Random random) { + this(List.of(new CommonHeader.LayerInfo(size, M)), random); } @Override - public int size() - { - return size; + public int getMaxLevel() { + return layerInfo.size() - 1; } @Override - public NodesIterator getNodes() { - return new NodesIterator.ArrayNodesIterator(IntStream.range(0, size).toArray(), size); + public int size(int level) { + return layerInfo.get(level).size; + } + + @Override + public NodesIterator getNodes(int level) { + int sz = layerInfo.get(level).size; + return new NodesIterator.ArrayNodesIterator(IntStream.range(0, sz).toArray(), sz); } @Override @@ -339,9 +378,14 @@ public View getView() { return new RandomlyConnectedGraphIndexView(); } + @Override + public int getDegree(int level) { + return layerInfo.get(level).degree; + } + @Override public int maxDegree() { - return nodes.get(0).length; + return layerInfo.stream().mapToInt(li -> li.degree).max().orElseThrow(); } @Override @@ -349,17 +393,18 @@ public void close() { } private class RandomlyConnectedGraphIndexView implements View { @Override - public NodesIterator getNeighborsIterator(int node) { - return new NodesIterator.ArrayNodesIterator(nodes.get(node)); + public NodesIterator getNeighborsIterator(int level, int node) { + var adjacency = layerAdjacency.get(level); + return new NodesIterator.ArrayNodesIterator(adjacency.get(node)); } public int size() { - return size; + return layerInfo.get(0).size; } @Override - public int entryNode() { - return entryNode; + public NodeAtLevel entryNode() { + return new NodeAtLevel(getMaxLevel(), entryNode); } @Override diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java index 460b95fc7..6300c2908 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java @@ -35,6 +35,7 @@ import java.util.ArrayList; import java.util.function.Supplier; +import static 
io.github.jbellis.jvector.TestUtil.assertGraphEquals; import static io.github.jbellis.jvector.graph.TestVectorGraph.createRandomFloatVectors; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -55,8 +56,13 @@ public void tearDown() { TestUtil.deleteQuietly(testDirectory); } - @Test + @Test public void testRescore() { + testRescore(false); + testRescore(true); + } + + public void testRescore(boolean addHierarchy) { // Create test vectors where each vector is [node_id, 0] var vectors = new ArrayList>(); vectors.add(vts.createFloatVector(new float[] {0, 0})); @@ -66,13 +72,13 @@ public void testRescore() { // Initial score provider uses dot product, so scores will equal node IDs var bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, VectorSimilarityFunction.EUCLIDEAN); - var builder = new GraphIndexBuilder(bsp, 2, 2, 10, 1.0f, 1.0f); + var builder = new GraphIndexBuilder(bsp, 2, 2, 10, 1.0f, 1.0f, addHierarchy); // Add 3 nodes builder.addGraphNode(0, ravv.getVector(0)); builder.addGraphNode(1, ravv.getVector(1)); builder.addGraphNode(2, ravv.getVector(2)); - var neighbors = builder.graph.getNeighbors(0); + var neighbors = builder.graph.getNeighbors(0, 0); // TODO assertEquals(1, neighbors.getNode(0)); assertEquals(2, neighbors.getNode(1)); assertEquals(0.5f, neighbors.getScore(0), 1E-6); @@ -95,7 +101,7 @@ public void testRescore() { assertTrue(newGraph.containsNode(2)); // Check node 0's neighbors, score and order should be different - var newNeighbors = newGraph.getNeighbors(0); + var newNeighbors = newGraph.getNeighbors(0, 0); // TODO assertEquals(2, newNeighbors.getNode(0)); assertEquals(1, newNeighbors.getNode(1)); assertEquals(0.2f, newNeighbors.getScore(0), 1E-6); @@ -106,16 +112,18 @@ public void testRescore() { @Test public void testSaveAndLoad() throws IOException { int dimension = randomIntBetween(2, 32); - var ravv = MockVectorValues.fromValues(createRandomFloatVectors(randomIntBetween(10, 100), dimension, getRandom())); + int size = randomIntBetween(10, 100); + var ravv = MockVectorValues.fromValues(createRandomFloatVectors(size, dimension, getRandom())); + Supplier newBuilder = () -> - new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f); + new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f, true); var indexDataPath = testDirectory.resolve("index_builder.data"); var builder = newBuilder.get(); - try (var graph = TestUtil.buildSequentially(builder, ravv); - var out = TestUtil.openDataOutputStream(indexDataPath)) - { + var graph = TestUtil.buildSequentially(builder, ravv); + + try (var out = TestUtil.openDataOutputStream(indexDataPath)) { graph.save(out); } @@ -128,5 +136,6 @@ public void testSaveAndLoad() throws IOException { for (int i = 0; i < ravv.size(); i++) { assertTrue(builder.graph.containsNode(i)); } + assertGraphEquals(graph, builder.graph); } } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/Test2DThreshold.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/Test2DThreshold.java index 3e6f7bcdc..3b9fad594 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/Test2DThreshold.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/Test2DThreshold.java @@ -34,24 +34,26 @@ public class Test2DThreshold extends LuceneTestCase { @Test public void testThreshold10k() throws IOException { for (int i = 0; i < 10; i++) { - testThreshold(10_000, 16, 0.85f, 0.9f); + testThreshold(10_000, 16, 0.85f, 0.9f, 
false); + testThreshold(10_000, 16, 0.85f, 0.9f, true); } } @Test public void testThreshold20k() throws IOException { for (int i = 0; i < 10; i++) { - testThreshold(20_000, 24, 0.80f, 0.9f); + testThreshold(20_000, 24, 0.80f, 0.9f, false); + testThreshold(20_000, 24, 0.80f, 0.9f, true); } } - public void testThreshold(int graphSize, int maxDegree, float visitedRatioThreshold, float recallThreshold) throws IOException { + public void testThreshold(int graphSize, int maxDegree, float visitedRatioThreshold, float recallThreshold, boolean addHierarchy) throws IOException { var R = getRandom(); // build index VectorFloat[] vectors = TestVectorGraph.createRandomFloatVectors(graphSize, 2, R); var ravv = new ListRandomAccessVectorValues(List.of(vectors), 2); - var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.EUCLIDEAN, maxDegree, 2 * maxDegree, 1.2f, 1.4f); + var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.EUCLIDEAN, maxDegree, 2 * maxDegree, 1.2f, 1.4f, addHierarchy); var onHeapGraph = builder.build(ravv); // test raw vectors diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java index 8d841ac26..56a35fe42 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java @@ -20,39 +20,35 @@ import io.github.jbellis.jvector.LuceneTestCase; import io.github.jbellis.jvector.TestUtil; import io.github.jbellis.jvector.disk.SimpleMappedReader; -import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; import io.github.jbellis.jvector.util.Bits; -import io.github.jbellis.jvector.util.PhysicalCoreExecutor; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.VectorizationProvider; import io.github.jbellis.jvector.vector.types.VectorFloat; import org.junit.Test; import java.io.IOException; -import java.io.RandomAccessFile; import java.nio.file.Files; -import java.util.concurrent.ForkJoinPool; -import java.util.concurrent.Semaphore; -import java.util.concurrent.ThreadLocalRandom; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.locks.Condition; -import java.util.concurrent.locks.Lock; import java.util.stream.IntStream; import static io.github.jbellis.jvector.TestUtil.assertGraphEquals; import static io.github.jbellis.jvector.graph.TestVectorGraph.createRandomFloatVectors; -import static io.github.jbellis.jvector.graph.TestVectorGraph.createRandomFloatVectorsParallel; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNull; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class TestDeletions extends LuceneTestCase { @Test public void testMarkDeleted() { + testMarkDeleted(false); + testMarkDeleted(true); + } + + public void testMarkDeleted(boolean addHierarchy) { // graph of 10 vectors int dimension = 2; var ravv = MockVectorValues.fromValues(createRandomFloatVectors(10, dimension, getRandom())); - var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f); + var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 4, 10, 1.0f, 1.0f, addHierarchy); var graph = TestUtil.buildSequentially(builder, ravv); // delete a random entry @@ -67,7 +63,7 @@ public void testMarkDeleted() { } } // check that asking for the entire graph back still 
doesn't surface the deleted one - var v = ravv.getVector(n); + var v = ravv.getVector(n).copy(); var results = GraphSearcher.search(v, ravv.size(), ravv, VectorSimilarityFunction.COSINE, graph, Bits.ALL); assertEquals(GraphIndex.prettyPrint(graph), ravv.size() - 1, results.getNodes().length); for (var ns : results.getNodes()) { @@ -77,10 +73,15 @@ public void testMarkDeleted() { @Test public void testCleanup() throws IOException { + testCleanup(false); + testCleanup(true); + } + + public void testCleanup(boolean addHierarchy) throws IOException { // graph of 10 vectors int dimension = 2; var ravv = MockVectorValues.fromValues(createRandomFloatVectors(10, dimension, getRandom())); - var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f); + var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 4, 10, 1.0f, 1.0f, addHierarchy); var graph = TestUtil.buildSequentially(builder, ravv); // delete all nodes that connect to a random node @@ -88,7 +89,7 @@ public void testCleanup() throws IOException { int nDeleted = 0; try (var view = graph.getView()) { for (var i = 0; i < view.size(); i++) { - for (var it = view.getNeighborsIterator(i); it.hasNext(); ) { + for (var it = view.getNeighborsIterator(0, i); it.hasNext(); ) { // TODO hardcoded level if (nodeToIsolate == it.nextInt()) { builder.markNodeDeleted(i); nDeleted++; @@ -114,7 +115,8 @@ public void testCleanup() throws IOException { try (var out = TestUtil.openDataOutputStream(outputPath)) { graph.save(out); } - var b2 = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f); + + var b2 = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 4, 10, 1.0f, 1.0f, addHierarchy); try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath)) { b2.load(readerSupplier.get()); } @@ -124,10 +126,15 @@ public void testCleanup() throws IOException { @Test public void testMarkingAllNodesAsDeleted() { + testMarkingAllNodesAsDeleted(false); + testMarkingAllNodesAsDeleted(true); + } + + public void testMarkingAllNodesAsDeleted(boolean addHierarchy) { // build graph int dimension = 2; var ravv = MockVectorValues.fromValues(createRandomFloatVectors(10, dimension, getRandom())); - var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f); + var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f, addHierarchy); var graph = TestUtil.buildSequentially(builder, ravv); // mark all deleted @@ -138,11 +145,16 @@ public void testMarkingAllNodesAsDeleted() { // removeDeletedNodes should leave the graph empty builder.removeDeletedNodes(); assertEquals(0, graph.size()); - assertEquals(OnHeapGraphIndex.NO_ENTRY_POINT, graph.entry()); + assertNull(graph.entry()); } @Test public void testNoPathToLiveNodesWhenRemovingDeletedNodes2() throws IOException { + testNoPathToLiveNodesWhenRemovingDeletedNodes2(false); + testNoPathToLiveNodesWhenRemovingDeletedNodes2(true); + } + + public void testNoPathToLiveNodesWhenRemovingDeletedNodes2(boolean addHierarchy) throws IOException { var vts = VectorizationProvider.getInstance().getVectorTypeSupport(); var random = getRandom(); // generate two clusters of vectors @@ -157,7 +169,7 @@ public void testNoPathToLiveNodesWhenRemovingDeletedNodes2() throws IOException ); // add the vectors, then delete all the ones from the first (larger) cluster - try (var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f)) { + try (var builder = new 
GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 4, 10, 1.0f, 1.0f, addHierarchy)) { for (int i = 0; i < 1100; i++) { builder.addGraphNode(i, ravv.getVector(i)); } @@ -167,7 +179,7 @@ public void testNoPathToLiveNodesWhenRemovingDeletedNodes2() throws IOException } builder.cleanup(); - assert builder.graph.getView().entryNode() != OnHeapGraphIndex.NO_ENTRY_POINT; + assert builder.graph.getView().entryNode() != null; } } } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java index bceb5247f..8e446fc13 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java @@ -89,12 +89,18 @@ VectorFloat getTargetVector() { @Test public void testSearchWithSkewedAcceptOrds() { + testSearchWithSkewedAcceptOrds(false); + testSearchWithSkewedAcceptOrds(true); + } + + public void testSearchWithSkewedAcceptOrds(boolean addHierarchy) { int nDoc = 1000; similarityFunction = VectorSimilarityFunction.EUCLIDEAN; RandomAccessVectorValues vectors = circularVectorValues(nDoc); getRandom().nextInt(); - GraphIndexBuilder builder = new GraphIndexBuilder(vectors, similarityFunction, 32, 100, 1.0f, 1.0f); + GraphIndexBuilder builder = new GraphIndexBuilder(vectors, similarityFunction, 32, 100, 1.0f, 1.0f, addHierarchy); var graph = TestUtil.buildSequentially(builder, vectors); + validateIndex(graph); // Skip over half of the documents that are closest to the query vector FixedBitSet acceptOrds = new FixedBitSet(nDoc); @@ -126,11 +132,18 @@ public void testSearchWithSkewedAcceptOrds() { @Test // build a random graph and check that resuming a search finds the same nodes as an equivalent from-scratch search public void testResume() { + testResume(false); + testResume(true); + } + + // build a random graph and check that resuming a search finds the same nodes as an equivalent from-scratch search + public void testResume(boolean addHierarchy) { int size = 1000; int dim = 2; var vectors = vectorValues(size, dim); - var builder = new GraphIndexBuilder(vectors, similarityFunction, 20, 30, 1.0f, 1.4f); + var builder = new GraphIndexBuilder(vectors, similarityFunction, 20, 30, 1.0f, 1.4f, addHierarchy); var graph = builder.build(vectors); + validateIndex(graph); Bits acceptOrds = getRandom().nextBoolean() ? 
Bits.ALL : createRandomAcceptOrds(0, size); int initialTopK = 10; @@ -160,11 +173,18 @@ public void testResume() { @Test // resuming a search should not need to rerank the nodes that were already evaluated public void testRerankCaching() { + testRerankCaching(false); + testRerankCaching(true); + } + + // resuming a search should not need to rerank the nodes that were already evaluated + public void testRerankCaching(boolean addHierarchy) { int size = 1000; int dim = 2; var vectors = vectorValues(size, dim); - var builder = new GraphIndexBuilder(vectors, similarityFunction, 20, 30, 1.0f, 1.4f); + var builder = new GraphIndexBuilder(vectors, similarityFunction, 20, 30, 1.0f, 1.4f, addHierarchy); var graph = builder.build(vectors); + validateIndex(graph); var pq = ProductQuantization.compute(vectors, 2, 256, false); var pqv = pq.encodeAll(vectors); @@ -187,12 +207,19 @@ public void testRerankCaching() { // If an exception is thrown during search, the next search should still function @Test public void testExceptionalTermination() { + testExceptionalTermination(false); + testExceptionalTermination(true); + } + + // If an exception is thrown during search, the next search should still function + public void testExceptionalTermination(boolean addHierarchy) { int nDoc = 100; similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; RandomAccessVectorValues vectors = circularVectorValues(nDoc); GraphIndexBuilder builder = - new GraphIndexBuilder(vectors, similarityFunction, 20, 100, 1.0f, 1.4f); + new GraphIndexBuilder(vectors, similarityFunction, 20, 100, 1.0f, 1.4f, addHierarchy); var graph = TestUtil.buildSequentially(builder, vectors); + validateIndex(graph); // wrap vectors so that the second access to a vector throws an exception var wrappedVectors = new RandomAccessVectorValues() { @@ -245,15 +272,30 @@ public int size() { // We expect to get approximately 100% recall; // the lowest docIds are closest to zero; sum(0,9) = 45 assertTrue("sum(result docs)=" + sum + " for " + GraphIndex.prettyPrint(builder.graph), sum < 75); + } - for (int i = 0; i < nDoc; i++) { - ConcurrentNeighborMap.Neighbors neighbors = graph.getNeighbors(i); - Iterator it = neighbors.iterator(); - while (it.hasNext()) { - // all neighbors should be valid node ids. - assertTrue(it.next() < nDoc); + private static void validateIndex(OnHeapGraphIndex graph) { + for (int level = graph.getMaxLevel(); level > 0; level--) { + for (var nodeIt = graph.getNodes(level); nodeIt.hasNext(); ) { + var nodeInLevel = nodeIt.nextInt(); + + // node's neighbors should also exist in the same level + var neighbors = graph.getNeighbors(level, nodeInLevel); + for (int neighbor : neighbors.copyDenseNodes()) { + assertNotNull(graph.getNeighbors(level, neighbor)); + } + + // node should exist at every layer below it + for (int lowerLevel = level - 1; lowerLevel >= 0; lowerLevel--) { + assertNotNull(graph.getNeighbors(lowerLevel, nodeInLevel)); + } } } + + // no holes in lowest level (not true for all graphs but true for the ones constructed here) + for (int i = 0; i < graph.getIdUpperBound(); i++) { + assertNotNull(graph.getNeighbors(0, i)); + } } // Make sure we actually approximately find the closest k elements. Mostly this is about @@ -261,12 +303,21 @@ public int size() { // oriented in the right directions @Test public void testAknnDiverse() { + testAknnDiverse(false); + testAknnDiverse(true); + } + + // Make sure we actually approximately find the closest k elements. 
Mostly this is about + // ensuring that we have all the distance functions, comparators, priority queues and so on + // oriented in the right directions + public void testAknnDiverse(boolean addHierarchy) { int nDoc = 100; similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; RandomAccessVectorValues vectors = circularVectorValues(nDoc); GraphIndexBuilder builder = - new GraphIndexBuilder(vectors, similarityFunction, 20, 100, 1.0f, 1.4f); + new GraphIndexBuilder(vectors, similarityFunction, 20, 100, 1.0f, 1.4f, addHierarchy); var graph = TestUtil.buildSequentially(builder, vectors); + validateIndex(graph); // run some searches SearchResult.NodeScore[] nn = GraphSearcher.search(getTargetVector(), 10, @@ -284,25 +335,22 @@ public void testAknnDiverse() { // We expect to get approximately 100% recall; // the lowest docIds are closest to zero; sum(0,9) = 45 assertTrue("sum(result docs)=" + sum + " for " + GraphIndex.prettyPrint(builder.graph), sum < 75); - - for (int i = 0; i < nDoc; i++) { - ConcurrentNeighborMap.Neighbors neighbors = graph.getNeighbors(i); - Iterator it = neighbors.iterator(); - while (it.hasNext()) { - // all neighbors should be valid node ids. - assertTrue(it.next() < nDoc); - } - } } @Test public void testSearchWithAcceptOrds() { + testSearchWithAcceptOrds(false); + testSearchWithAcceptOrds(true); + } + + public void testSearchWithAcceptOrds(boolean addHierarchy) { int nDoc = 100; RandomAccessVectorValues vectors = circularVectorValues(nDoc); similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; GraphIndexBuilder builder = - new GraphIndexBuilder(vectors, similarityFunction, 32, 100, 1.0f, 1.4f); + new GraphIndexBuilder(vectors, similarityFunction, 32, 100, 1.0f, 1.4f, addHierarchy); var graph = TestUtil.buildSequentially(builder, vectors); + validateIndex(graph); // the first 10 docs must not be deleted to ensure the expected recall Bits acceptOrds = createRandomAcceptOrds(10, nDoc); SearchResult.NodeScore[] nn = GraphSearcher.search(getTargetVector(), @@ -326,12 +374,18 @@ public void testSearchWithAcceptOrds() { @Test public void testSearchWithSelectiveAcceptOrds() { + testSearchWithSelectiveAcceptOrds(false); + testSearchWithSelectiveAcceptOrds(true); + } + + public void testSearchWithSelectiveAcceptOrds(boolean addHierarchy) { int nDoc = 100; RandomAccessVectorValues vectors = circularVectorValues(nDoc); similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; GraphIndexBuilder builder = - new GraphIndexBuilder(vectors, similarityFunction, 32, 100, 1.0f, 1.4f); + new GraphIndexBuilder(vectors, similarityFunction, 32, 100, 1.0f, 1.4f, addHierarchy); var graph = TestUtil.buildSequentially(builder, vectors); + validateIndex(graph); // Only mark a few vectors as accepted var acceptOrds = new FixedBitSet(nDoc); for (int i = 0; i < nDoc; i += nextInt(15, 20)) { @@ -364,19 +418,24 @@ public void testSearchWithSelectiveAcceptOrds() { @Test public void testGraphIndexBuilderInvalid() { + testGraphIndexBuilderInvalid(false); + testGraphIndexBuilderInvalid(true); + } + + public void testGraphIndexBuilderInvalid(boolean addHierarchy) { assertThrows(NullPointerException.class, - () -> new GraphIndexBuilder(null, null, 0, 0, 1.0f, 1.0f)); + () -> new GraphIndexBuilder(null, null, 0, 0, 1.0f, 1.0f, addHierarchy)); // M must be > 0 assertThrows(IllegalArgumentException.class, () -> { RandomAccessVectorValues vectors = vectorValues(1, 1); - new GraphIndexBuilder(vectors, similarityFunction, 0, 10, 1.0f, 1.0f); + new GraphIndexBuilder(vectors, similarityFunction, 0, 10, 
1.0f, 1.0f, addHierarchy); }); // beamWidth must be > 0 assertThrows(IllegalArgumentException.class, () -> { RandomAccessVectorValues vectors = vectorValues(1, 1); - new GraphIndexBuilder(vectors, similarityFunction, 10, 0, 1.0f, 1.0f); + new GraphIndexBuilder(vectors, similarityFunction, 10, 0, 1.0f, 1.0f, addHierarchy); }); } @@ -387,6 +446,11 @@ public void testRamUsageEstimate() { @Test public void testDiversity() { + testDiversity(false); + testDiversity(true); + } + + public void testDiversity(boolean addHierarchy) { similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; // Some carefully checked test cases with simple 2d vectors on the unit circle: VectorFloat[] values = { @@ -401,7 +465,7 @@ public void testDiversity() { MockVectorValues vectors = vectorValues(values); // First add nodes until everybody gets a full neighbor list GraphIndexBuilder builder = - new GraphIndexBuilder(vectors, similarityFunction, 4, 10, 1.0f, 1.0f); + new GraphIndexBuilder(vectors, similarityFunction, 4, 10, 1.0f, 1.0f, addHierarchy); // node 0 is added by the builder constructor builder.addGraphNode(0, vectors.getVector(0)); builder.addGraphNode(1, vectors.getVector(1)); @@ -441,6 +505,11 @@ public void testDiversity() { @Test public void testDiversityFallback() { + testDiversityFallback(false); + testDiversityFallback(true); + } + + public void testDiversityFallback(boolean addHierarchy) { similarityFunction = VectorSimilarityFunction.EUCLIDEAN; // Some test cases can't be exercised in two dimensions; // in particular if a new neighbor displaces an existing neighbor @@ -456,7 +525,7 @@ public void testDiversityFallback() { MockVectorValues vectors = vectorValues(values); // First add nodes until everybody gets a full neighbor list GraphIndexBuilder builder = - new GraphIndexBuilder(vectors, similarityFunction, 2, 10, 1.0f, 1.0f); + new GraphIndexBuilder(vectors, similarityFunction, 2, 10, 1.0f, 1.0f, addHierarchy); builder.addGraphNode(0, vectors.getVector(0)); builder.addGraphNode(1, vectors.getVector(1)); builder.addGraphNode(2, vectors.getVector(2)); @@ -476,6 +545,11 @@ public void testDiversityFallback() { @Test public void testDiversity3d() { + testDiversity3d(false); + testDiversity3d(true); + } + + public void testDiversity3d(boolean addHierarchy) { similarityFunction = VectorSimilarityFunction.EUCLIDEAN; // test the case when a neighbor *becomes* non-diverse when a newer better neighbor arrives VectorFloat[] values = { @@ -487,7 +561,7 @@ public void testDiversity3d() { MockVectorValues vectors = vectorValues(values); // First add nodes until everybody gets a full neighbor list GraphIndexBuilder builder = - new GraphIndexBuilder(vectors, similarityFunction, 2, 10, 1.0f, 1.0f); + new GraphIndexBuilder(vectors, similarityFunction, 2, 10, 1.0f, 1.0f, addHierarchy); builder.addGraphNode(0, vectors.getVector(0)); builder.addGraphNode(1, vectors.getVector(1)); builder.addGraphNode(2, vectors.getVector(2)); @@ -507,7 +581,7 @@ public void testDiversity3d() { private void assertNeighbors(OnHeapGraphIndex graph, int node, int... expected) { Arrays.sort(expected); - ConcurrentNeighborMap.Neighbors nn = graph.getNeighbors(node); + ConcurrentNeighborMap.Neighbors nn = graph.getNeighbors(0, node); // TODO Iterator it = nn.iterator(); int[] actual = new int[nn.size()]; for (int i = 0; i < actual.length; i++) { @@ -520,12 +594,19 @@ private void assertNeighbors(OnHeapGraphIndex graph, int node, int... 
expected) @Test // build a random graph, then check that it has at least 90% recall public void testRandom() { + testRandom(false); + testRandom(true); + } + + // build a random graph, then check that it has at least 90% recall + public void testRandom(boolean addHierarchy) { int size = between(100, 150); int dim = between(2, 15); MockVectorValues vectors = vectorValues(size, dim); int topK = 5; - GraphIndexBuilder builder = new GraphIndexBuilder(vectors, similarityFunction, 20, 30, 1.0f, 1.4f); + GraphIndexBuilder builder = new GraphIndexBuilder(vectors, similarityFunction, 20, 30, 1.0f, 1.4f, addHierarchy); var graph = builder.build(vectors); + validateIndex(graph); Bits acceptOrds = getRandom().nextBoolean() ? Bits.ALL : createRandomAcceptOrds(0, size); int efSearch = 100; @@ -579,22 +660,33 @@ private int computeOverlap(int[] a, int[] b) { @Test public void testConcurrentNeighbors() { + testConcurrentNeighbors(false); + testConcurrentNeighbors(true); + } + + public void testConcurrentNeighbors(boolean addHierarchy) { RandomAccessVectorValues vectors = circularVectorValues(100); - GraphIndexBuilder builder = new GraphIndexBuilder(vectors, similarityFunction, 2, 30, 1.0f, 1.4f); + GraphIndexBuilder builder = new GraphIndexBuilder(vectors, similarityFunction, 2, 30, 1.0f, 1.4f, addHierarchy); var graph = builder.build(vectors); + validateIndex(graph); for (int i = 0; i < vectors.size(); i++) { - assertTrue(graph.getNeighbors(i).size() <= 2); + assertTrue(graph.getNeighbors(0, i).size() <= 2); // TODO } } @Test - public void testZeroCentroid() - { + public void testZeroCentroid() { + testZeroCentroid(false); + testZeroCentroid(true); + } + + public void testZeroCentroid(boolean addHierarchy) { var rawVectors = List.of(vectorTypeSupport.createFloatVector(new float[] {-1, -1}), vectorTypeSupport.createFloatVector(new float[] {1, 1})); var vectors = new ListRandomAccessVectorValues(rawVectors, 2); - var builder = new GraphIndexBuilder(vectors, VectorSimilarityFunction.COSINE, 2, 2, 1.0f, 1.0f); + var builder = new GraphIndexBuilder(vectors, VectorSimilarityFunction.COSINE, 2, 2, 1.0f, 1.0f, addHierarchy); try (var graph = builder.build(vectors)) { + validateIndex(graph); var qv = vectorTypeSupport.createFloatVector(new float[] {0.5f, 0.5f}); var results = GraphSearcher.search(qv, 1, vectors, VectorSimilarityFunction.COSINE, graph, Bits.ALL); assertEquals(1, results.getNodes().length); diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestGraphCache.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestGraphCache.java deleted file mode 100644 index 1fd6cb899..000000000 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestGraphCache.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright DataStax, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.github.jbellis.jvector.graph.disk; - -import com.carrotsearch.randomizedtesting.RandomizedTest; -import io.github.jbellis.jvector.TestUtil; -import io.github.jbellis.jvector.disk.SimpleMappedReader; -import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; -import io.github.jbellis.jvector.graph.RandomAccessVectorValues; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; - -import java.io.IOException; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -import static io.github.jbellis.jvector.TestUtil.writeGraph; -import static org.junit.Assert.*; - -public class TestGraphCache extends RandomizedTest { - private Path testDirectory; - private Path onDiskGraphIndexPath; - private RandomAccessVectorValues vectors; - - - @Before - public void setup() throws IOException { - var fullyConnectedGraph = new TestUtil.FullyConnectedGraphIndex(0, 6); - vectors = new ListRandomAccessVectorValues(IntStream.range(0, 6).mapToObj(i -> TestUtil.randomVector(getRandom(), 2)).collect(Collectors.toList()), 2); - testDirectory = Files.createTempDirectory(this.getClass().getSimpleName()); - onDiskGraphIndexPath = testDirectory.resolve("fullyConnectedGraph"); - writeGraph(fullyConnectedGraph, vectors, onDiskGraphIndexPath); - } - - @After - public void tearDown() { - TestUtil.deleteQuietly(testDirectory); - } - - @Test - public void testGraphCacheLoading() throws Exception { - try (var readerSupplier = new SimpleMappedReader.Supplier(onDiskGraphIndexPath); - var onDiskGraph = OnDiskGraphIndex.load(readerSupplier)) - { - var none = GraphCache.load(onDiskGraph, -1); - assertEquals(0, none.ramBytesUsed()); - assertNull(none.getNode(0)); - var zero = GraphCache.load(onDiskGraph, 0); - assertNotNull(zero.getNode(0)); - assertNull(zero.getNode(1)); - var one = GraphCache.load(onDiskGraph, 1); - // move from caching entry node to entry node + all its neighbors (5) - assertEquals(one.ramBytesUsed(), zero.ramBytesUsed() * (onDiskGraph.size())); - for (int i = 0; i < 6; i++) { - // fully connected, - assertEquals(one.getNode(i).neighbors.length, onDiskGraph.maxDegree()); - } - } - } -} diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java index daaa736af..020930301 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java @@ -21,12 +21,22 @@ import io.github.jbellis.jvector.TestUtil; import io.github.jbellis.jvector.disk.SimpleMappedReader; import io.github.jbellis.jvector.graph.GraphIndexBuilder; +import io.github.jbellis.jvector.graph.GraphSearcher; import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; import io.github.jbellis.jvector.graph.NodesIterator; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; import io.github.jbellis.jvector.graph.TestVectorGraph; +import io.github.jbellis.jvector.graph.disk.feature.Feature; +import io.github.jbellis.jvector.graph.disk.feature.FeatureId; +import io.github.jbellis.jvector.graph.disk.feature.FusedADC; +import io.github.jbellis.jvector.graph.disk.feature.InlineVectors; +import io.github.jbellis.jvector.graph.disk.feature.NVQ; +import io.github.jbellis.jvector.graph.disk.feature.SeparatedNVQ; +import 
+import io.github.jbellis.jvector.quantization.NVQuantization;
 import io.github.jbellis.jvector.quantization.PQVectors;
 import io.github.jbellis.jvector.quantization.ProductQuantization;
+import io.github.jbellis.jvector.util.Bits;
 import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
 import org.junit.After;
 import org.junit.Before;
@@ -37,8 +47,10 @@
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.util.*;
+import java.util.concurrent.ThreadLocalRandom;

 import static io.github.jbellis.jvector.TestUtil.getNeighborNodes;
+import static io.github.jbellis.jvector.TestUtil.randomVector;
 import static org.junit.Assert.*;

 @ThreadLeakScope(ThreadLeakScope.Scope.NONE)
@@ -81,25 +93,30 @@ public void testSimpleGraphs() throws Exception {

     @Test
     public void testRenumberingOnDelete() throws IOException {
+        testRenumberingOnDelete(false);
+        testRenumberingOnDelete(true);
+    }
+
+    public void testRenumberingOnDelete(boolean addHierarchy) throws IOException {
         // graph of 3 vectors
         var ravv = new TestVectorGraph.CircularFloatVectorValues(3);
-        var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f);
+        var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f, addHierarchy);
         var original = TestUtil.buildSequentially(builder, ravv);

         // delete the first node
         builder.markNodeDeleted(0);
         builder.cleanup();
-        builder.setEntryPoint(builder.getGraph().getIdUpperBound() - 1);
+        builder.setEntryPoint(0, builder.getGraph().getIdUpperBound() - 1); // TODO

         // check
         assertEquals(2, original.size());
         var originalView = original.getView();
         // 1 -> 2
-        assertEquals(1, getNeighborNodes(originalView, 1).size());
-        assertTrue(getNeighborNodes(originalView, 1).contains(2));
+        assertEquals(1, getNeighborNodes(originalView, 0, 1).size());
+        assertTrue(getNeighborNodes(originalView, 0, 1).contains(2));
         // 2 -> 1
-        assertEquals(1, getNeighborNodes(originalView, 2).size());
-        assertTrue(getNeighborNodes(originalView, 2).contains(1));
+        assertEquals(1, getNeighborNodes(originalView, 0, 2).size());
+        assertTrue(getNeighborNodes(originalView, 0, 2).contains(1));

         // create renumbering map
         Map oldToNewMap = OnDiskGraphIndexWriter.sequentialRenumbering(original);
@@ -116,21 +133,25 @@ public void testRenumberingOnDelete() throws IOException {
              var onDiskView = onDiskGraph.getView())
         {
             // entry point renumbering
-            assertNotNull(onDiskView.getVector(onDiskGraph.entryNode));
+            assertNotNull(onDiskView.getVector(onDiskGraph.entryNode.node));
             // 0 -> 1
-            assertTrue(getNeighborNodes(onDiskView, 0).contains(1));
+            assertTrue(getNeighborNodes(onDiskView, 0, 0).contains(1));
             // 1 -> 0
-            assertTrue(getNeighborNodes(onDiskView, 1).contains(0));
+            assertTrue(getNeighborNodes(onDiskView, 0, 1).contains(0));
         } catch (Exception e) {
             throw new RuntimeException(e);
         }
     }

-
     @Test
     public void testReorderingRenumbering() throws IOException {
+        testReorderingRenumbering(false);
+        testReorderingRenumbering(true);
+    }
+
+    public void testReorderingRenumbering(boolean addHierarchy) throws IOException {
         // graph of 3 vectors
         var ravv = new TestVectorGraph.CircularFloatVectorValues(3);
-        var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f);
+        var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f, addHierarchy);
         var original = TestUtil.buildSequentially(builder, ravv);

         // create renumbering map
@@ -157,9 +178,14 @@ public void testReorderingRenumbering() throws IOException {

     @Test
     public void testReorderingWithHoles() throws IOException {
+        testReorderingWithHoles(false);
+        testReorderingWithHoles(true);
+    }
+
+    public void testReorderingWithHoles(boolean addHierarchy) throws IOException {
         // graph of 3 vectors
         var ravv = new TestVectorGraph.CircularFloatVectorValues(3);
-        var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f);
+        var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f, addHierarchy);
         var original = TestUtil.buildSequentially(builder, ravv);

         // create renumbering map
@@ -179,7 +205,7 @@ public void testReorderingWithHoles() throws IOException {
             assertEquals(11, onDiskGraph.getIdUpperBound());

             Set nodesInGraph = new HashSet<>();
-            for (NodesIterator it = onDiskGraph.getNodes(); it.hasNext(); ) {
+            for (NodesIterator it = onDiskGraph.getNodes(0); it.hasNext(); ) {
                 nodesInGraph.add(it.next());
             }
             assertEquals(nodesInGraph, Set.of(0, 2, 10));
@@ -194,7 +220,93 @@

     private static void validateVectors(OnDiskGraphIndex.View view, RandomAccessVectorValues ravv) {
         for (int i = 0; i < view.size(); i++) {
-            assertEquals("Incorrect vector at " + i, view.getVector(i), ravv.getVector(i));
+            assertEquals("Incorrect vector at " + i, ravv.getVector(i), view.getVector(i));
+        }
+    }
+
+    private static void validateSeparatedNVQ(OnDiskGraphIndex.View view,
+                                             RandomAccessVectorValues ravv,
+                                             NVQuantization nvq) throws IOException
+    {
+        assertEquals("Sizes differ", ravv.size(), view.size());
+        // Reusable buffer for decoding
+        var quantized = NVQuantization.QuantizedVector.createEmpty(nvq.subvectorSizesAndOffsets,
+                                                                   nvq.bitsPerDimension);
+        for (int i = 0; i < view.size(); i++) {
+            try (var reader = view.featureReaderForNode(i, FeatureId.SEPARATED_NVQ)) {
+                NVQuantization.QuantizedVector.loadInto(reader, quantized);
+            }
+            // NVQ is lossy, so the check here is that every node's vector decodes cleanly
+        }
+    }
+
+    @Test
+    public void testSimpleGraphSeparated() throws Exception {
+        for (var graph : List.of(fullyConnectedGraph, randomlyConnectedGraph)) {
+            var outputPath = testDirectory.resolve("test_graph_separated_" + graph.getClass().getSimpleName());
+            var ravv = new TestVectorGraph.CircularFloatVectorValues(graph.size());
+
+            // Write graph with SEPARATED_VECTORS
+            try (var writer = new OnDiskGraphIndexWriter.Builder(graph, outputPath)
+                    .with(new SeparatedVectors(ravv.dimension(), 0L))
+                    .build())
+            {
+                writer.write(Feature.singleStateFactory(
+                        FeatureId.SEPARATED_VECTORS,
+                        nodeId -> new InlineVectors.State(ravv.getVector(nodeId))
+                ));
+            }
+
+            // Read and validate
+            try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath.toAbsolutePath());
+                 var onDiskGraph = OnDiskGraphIndex.load(readerSupplier);
+                 var onDiskView = onDiskGraph.getView())
+            {
+                TestUtil.assertGraphEquals(graph, onDiskGraph);
+                validateVectors(onDiskView, ravv);
+            }
+        }
+    }
+
+    @Test
+    public void testLargeGraphSeparatedNVQ() throws Exception {
+        // Build a large-ish graph
+        var nodeCount = 100_000;
+        var maxDegree = 32;
+        var graph = new TestUtil.RandomlyConnectedGraphIndex(nodeCount, maxDegree, getRandom());
+        var outputPath = testDirectory.resolve("large_graph_nvq");
+
+        // Create random vectors
+        var dimension = 64;
+        var vectors = TestUtil.createRandomVectors(nodeCount, dimension);
+        var ravv = new ListRandomAccessVectorValues(vectors, dimension);
+
+        // Compute NVQ and build a SeparatedNVQ feature
+        var nvq = NVQuantization.compute(ravv, /* e.g. subquantizers=2 */ 2);
+        var separatedNVQ = new SeparatedNVQ(nvq, 0L);
+
+        // Write the graph with SEPARATED_NVQ
+        try (var writer = new OnDiskGraphIndexWriter.Builder(graph, outputPath)
+                .with(separatedNVQ)
+                .build())
+        {
+            // Provide the states for each node
+            writer.write(Feature.singleStateFactory(
+                    FeatureId.SEPARATED_NVQ,
+                    nodeId -> new NVQ.State(nvq.encode(ravv.getVector(nodeId)))
+            ));
+        }
+
+        // Read back the graph & check structure, then decode vectors
+        try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath.toAbsolutePath());
+             var onDiskGraph = OnDiskGraphIndex.load(readerSupplier);
+             var onDiskView = onDiskGraph.getView())
+        {
+            // structure check
+            TestUtil.assertGraphEquals(graph, onDiskGraph);
+
+            // decode and compare vectors
+            validateSeparatedNVQ(onDiskView, ravv, nvq);
         }
     }

@@ -207,16 +319,12 @@ public void testLargeGraph() throws Exception
         TestUtil.writeGraph(graph, ravv, outputPath);
         try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath.toAbsolutePath());
-             var onDiskGraph = OnDiskGraphIndex.load(readerSupplier);
-             var cachedOnDiskGraph = new CachingGraphIndex(onDiskGraph))
+             var onDiskGraph = OnDiskGraphIndex.load(readerSupplier))
         {
             TestUtil.assertGraphEquals(graph, onDiskGraph);
-            TestUtil.assertGraphEquals(graph, cachedOnDiskGraph);
-            try (var onDiskView = onDiskGraph.getView();
-                 var cachedOnDiskView = onDiskGraph.getView())
+            try (var onDiskView = onDiskGraph.getView())
             {
                 validateVectors(onDiskView, ravv);
-                validateVectors(cachedOnDiskView, ravv);
             }
         }
     }
@@ -229,13 +337,13 @@ public void testV0Read() throws IOException {
              var onDiskGraph = OnDiskGraphIndex.load(readerSupplier);
              var onDiskView = onDiskGraph.getView())
         {
-            assertEquals(32, onDiskGraph.maxDegree);
+            assertEquals(32, onDiskGraph.getDegree(0));
             assertEquals(2, onDiskGraph.version);
-            assertEquals(100_000, onDiskGraph.size);
+            assertEquals(100_000, onDiskGraph.size(0));
             assertEquals(2, onDiskGraph.dimension);
-            assertEquals(99779, onDiskGraph.entryNode);
+            assertEquals(99779, onDiskGraph.entryNode.node);
             assertEquals(EnumSet.of(FeatureId.INLINE_VECTORS), onDiskGraph.features.keySet());
-            var actualNeighbors = getNeighborNodes(onDiskView, 12345);
+            var actualNeighbors = getNeighborNodes(onDiskView, 0, 12345);
             var expectedNeighbors = Set.of(67461, 9540, 85444, 13638, 89415, 21255, 73737, 46985, 71373, 47436, 94863, 91343, 27215, 59730, 69911, 91867, 89373, 6621, 59106, 98922, 69679, 47728, 60722, 56052, 28854, 38902, 21561, 20665, 41722, 57917, 34495, 5183);
             assertEquals(expectedNeighbors, actualNeighbors);
         }
@@ -265,6 +373,64 @@ public void testV0Write() throws IOException {
         assertArrayEquals(contents1, contents2);
     }

+    @Test
+    public void testMultiLayerFullyConnected() throws Exception {
+        // Suppose we have 3 layers of sizes 5, 4, 3
+        var graph = new TestUtil.FullyConnectedGraphIndex(1, List.of(5, 4, 3));
+        var ravv = new TestVectorGraph.CircularFloatVectorValues(graph.size(0));
+        var outputPath = testDirectory.resolve("fully_connected_multilayer");
+        TestUtil.writeGraph(graph, ravv, outputPath);
+
+        // read back
+        try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath);
+             var onDiskGraph = OnDiskGraphIndex.load(readerSupplier))
+        {
+            // verify the multi-layer structure
+            assertEquals(2, onDiskGraph.getMaxLevel());
+            assertEquals(5, onDiskGraph.size(0));
+            assertEquals(4, onDiskGraph.size(1));
+            assertEquals(3, onDiskGraph.size(2));
+            TestUtil.assertGraphEquals(graph, onDiskGraph);

+            var q = randomVector(ThreadLocalRandom.current(), ravv.dimension());
+            var results1 = GraphSearcher.search(q, 10, ravv, VectorSimilarityFunction.EUCLIDEAN, graph, Bits.ALL);
+            var results2 = GraphSearcher.search(q, 10, ravv, VectorSimilarityFunction.EUCLIDEAN, onDiskGraph, Bits.ALL);
+            assertEquals(results1, results2);
+        }
+    }
+
+    @Test
+    public void testMultiLayerRandomlyConnected() throws Exception {
+        // 3 layers
+        var layerInfo = List.of(
+                new CommonHeader.LayerInfo(100, 8),
+                new CommonHeader.LayerInfo(10, 3),
+                new CommonHeader.LayerInfo(5, 2)
+        );
+        var graph = new TestUtil.RandomlyConnectedGraphIndex(layerInfo, getRandom());
+        var ravv = new TestVectorGraph.CircularFloatVectorValues(graph.size(0));
+        var outputPath = testDirectory.resolve("random_multilayer");
+
+        TestUtil.writeGraph(graph, ravv, outputPath);
+
+        // read back
+        try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath);
+             var onDiskGraph = OnDiskGraphIndex.load(readerSupplier))
+        {
+            // confirm multi-layer
+            assertEquals(2, onDiskGraph.getMaxLevel());
+            assertEquals(100, onDiskGraph.size(0));
+            assertEquals(10, onDiskGraph.size(1));
+            assertEquals(5, onDiskGraph.size(2));
+            TestUtil.assertGraphEquals(graph, onDiskGraph);
+
+            var q = randomVector(ThreadLocalRandom.current(), ravv.dimension());
+            var results1 = GraphSearcher.search(q, 10, ravv, VectorSimilarityFunction.EUCLIDEAN, graph, Bits.ALL);
+            var results2 = GraphSearcher.search(q, 10, ravv, VectorSimilarityFunction.EUCLIDEAN, onDiskGraph, Bits.ALL);
+            assertEquals(results1, results2);
+        }
+    }
+
     @Test
     public void testV0WriteIncremental() throws IOException {
         var fileIn = new File("resources/version0.odgi");
@@ -331,7 +497,7 @@ public void testIncrementalWrites() throws IOException {
         var pqv = (PQVectors) pq.encodeAll(ravv);
         try (var writer = new OnDiskGraphIndexWriter.Builder(graph, incrementalFadcPath)
                 .with(new InlineVectors(ravv.dimension()))
-                .with(new FusedADC(graph.maxDegree(), pq))
+                .with(new FusedADC(graph.getDegree(0), pq))
                 .build())
         {
             // write inline vectors incrementally
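The multi-layer tests above only verify that the in-memory and on-disk forms return identical search results. Conceptually, a hierarchical search greedily descends the sparse upper layers before handing off to the usual beam search on layer 0. A rough sketch of that descent, using the level-aware accessors this patch introduces (getMaxLevel(), entryNode.node, View.getNeighborsIterator(level, node)); the score function stands in for query similarity, and this is illustrative only, not the library's actual search loop:

```java
import java.util.function.IntToDoubleFunction;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;

class HierarchicalDescentSketch {
    // Greedy descent from the top layer to layer 0 (illustrative only).
    static int descendToBaseLayer(OnDiskGraphIndex graph, OnDiskGraphIndex.View view,
                                  IntToDoubleFunction score) {
        int node = graph.entryNode.node; // entry point lives on the top layer
        for (int level = graph.getMaxLevel(); level > 0; level--) {
            boolean improved = true;
            while (improved) {
                improved = false;
                // hop to any neighbor on this layer that scores better; repeat until stuck
                for (var it = view.getNeighborsIterator(level, node); it.hasNext(); ) {
                    int candidate = it.next();
                    if (score.applyAsDouble(candidate) > score.applyAsDouble(node)) {
                        node = candidate;
                        improved = true;
                    }
                }
            }
        }
        return node; // starting point for the layer-0 beam search (GraphSearcher)
    }
}
```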
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/microbench/GraphBuildBench.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/microbench/GraphBuildBench.java
index 3a5ab4037..9db25e642 100644
--- a/jvector-tests/src/test/java/io/github/jbellis/jvector/microbench/GraphBuildBench.java
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/microbench/GraphBuildBench.java
@@ -54,10 +54,20 @@ public Parameters() {
     @OutputTimeUnit(TimeUnit.SECONDS)
     public void testGraphBuild(Blackhole bh, Parameters p) {
         long start = System.nanoTime();
-        GraphIndexBuilder graphIndexBuilder = new GraphIndexBuilder(p.ravv, p.ds.similarityFunction, 8, 60, 1.2f, 1.4f);
+        GraphIndexBuilder graphIndexBuilder = new GraphIndexBuilder(p.ravv, p.ds.similarityFunction, 8, 60, 1.2f, 1.4f, false);
         graphIndexBuilder.build(p.ravv);
-        var avgShortEdges = graphIndexBuilder.getAverageShortEdges();
-        System.out.format("Build M=%d ef=%d in %.2fs with %.2f short edges%n",
-                          32, 600, (System.nanoTime() - start) / 1_000_000_000.0, avgShortEdges);
+        System.out.format("Build M=%d ef=%d in %.2fs%n",
+                          8, 60, (System.nanoTime() - start) / 1_000_000_000.0);
+    }
+
+    @Benchmark
+    @BenchmarkMode(Mode.Throughput)
+    @OutputTimeUnit(TimeUnit.SECONDS)
+    public void testGraphBuildWithHierarchy(Blackhole bh, Parameters p) {
+        long start = System.nanoTime();
+        GraphIndexBuilder graphIndexBuilder = new GraphIndexBuilder(p.ravv, p.ds.similarityFunction, 8, 60, 1.2f, 1.4f, true);
+        graphIndexBuilder.build(p.ravv);
+        System.out.format("Build M=%d ef=%d in %.2fs%n",
+                          8, 60, (System.nanoTime() - start) / 1_000_000_000.0);
     }
 }
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/microbench/GraphIndexBench.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/microbench/GraphIndexBench.java
index 5bbd0ee9b..20b292dca 100644
--- a/jvector-tests/src/test/java/io/github/jbellis/jvector/microbench/GraphIndexBench.java
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/microbench/GraphIndexBench.java
@@ -96,8 +96,17 @@ public Parameters() {
     @BenchmarkMode(Mode.Throughput)
     @OutputTimeUnit(TimeUnit.SECONDS)
     public void testGraphBuild(Blackhole bh, Parameters p) {
-        GraphIndexBuilder graphIndexBuilder = new GraphIndexBuilder(p.vectors, VectorSimilarityFunction.DOT_PRODUCT, 8, 60, 1.2f, 1.4f);
+        GraphIndexBuilder graphIndexBuilder = new GraphIndexBuilder(p.vectors, VectorSimilarityFunction.DOT_PRODUCT, 8, 60, 1.2f, 1.4f, false);
         bh.consume(graphIndexBuilder.build(p.vectors));
     }
+
+    @Benchmark
+    @BenchmarkMode(Mode.Throughput)
+    @OutputTimeUnit(TimeUnit.SECONDS)
+    public void testGraphBuildWithHierarchy(Blackhole bh, Parameters p) {
+        GraphIndexBuilder graphIndexBuilder = new GraphIndexBuilder(p.vectors, VectorSimilarityFunction.DOT_PRODUCT, 8, 60, 1.2f, 1.4f, true);
+        bh.consume(graphIndexBuilder.build(p.vectors));
+    }
+
 }
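With flat and hierarchical variants of each build benchmark in place, the two can be compared head-to-head with a standard JMH runner. An illustrative harness, not part of the patch (the include pattern is a regex matching both methods above; `BuildBenchRunner` is a hypothetical name):

```java
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.OptionsBuilder;

public class BuildBenchRunner {
    public static void main(String[] args) throws RunnerException {
        var opts = new OptionsBuilder()
                .include("GraphIndexBench\\.testGraphBuild.*") // flat and WithHierarchy variants
                .forks(1)
                .build();
        new Runner(opts).run();
    }
}
```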
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestADCGraphIndex.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestADCGraphIndex.java
index 8240699a7..a476e41d6 100644
--- a/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestADCGraphIndex.java
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestADCGraphIndex.java
@@ -21,7 +21,6 @@
 import io.github.jbellis.jvector.TestUtil;
 import io.github.jbellis.jvector.disk.SimpleMappedReader;
 import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
-import io.github.jbellis.jvector.graph.disk.CachingGraphIndex;
 import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
 import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
 import org.junit.After;
@@ -63,12 +62,10 @@ public void testFusedGraph() throws Exception {
         TestUtil.writeFusedGraph(graph, ravv, pqv, outputPath);

         try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath);
-             var onDiskGraph = OnDiskGraphIndex.load(readerSupplier, 0);
-             var cachedOnDiskGraph = new CachingGraphIndex(onDiskGraph, 5))
+             var onDiskGraph = OnDiskGraphIndex.load(readerSupplier, 0))
         {
             TestUtil.assertGraphEquals(graph, onDiskGraph);
-            TestUtil.assertGraphEquals(graph, cachedOnDiskGraph);
-            try (var cachedOnDiskView = cachedOnDiskGraph.getView())
+            try (var cachedOnDiskView = onDiskGraph.getView())
             {
                 for (var similarityFunction : VectorSimilarityFunction.values()) {
                     var queryVector = TestUtil.randomVector(getRandom(), 512);
@@ -78,14 +75,14 @@ public void testFusedGraph() throws Exception {
                     var fusedScoreFunction = cachedOnDiskView.approximateScoreFunctionFor(queryVector, similarityFunction);
                     var ordinal = getRandom().nextInt(graph.size());
                     // first pass compares fused ADC's direct similarity to reranker's similarity, used for comparisons to a specific node
-                    var neighbors = cachedOnDiskView.getNeighborsIterator(ordinal);
+                    var neighbors = cachedOnDiskView.getNeighborsIterator(0, ordinal); // TODO
                     for (; neighbors.hasNext(); ) {
                         var neighbor = neighbors.next();
                         var similarity = fusedScoreFunction.similarityTo(neighbor);
                         assertEquals(reranker.similarityTo(neighbor), similarity, 0.01);
                     }
                     // second pass compares fused ADC's edge similarity prior to having enough information for quantization to PQ
-                    neighbors = cachedOnDiskView.getNeighborsIterator(ordinal);
+                    neighbors = cachedOnDiskView.getNeighborsIterator(0, ordinal); // TODO
                     var edgeSimilarities = fusedScoreFunction.edgeLoadingSimilarityTo(ordinal);
                     for (int j = 0; neighbors.hasNext(); j++) {
                         var neighbor = neighbors.next();
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestDenseIntMap.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestIntMap.java
similarity index 77%
rename from jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestDenseIntMap.java
rename to jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestIntMap.java
index 74826e37e..ac4695bbc 100644
--- a/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestDenseIntMap.java
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestIntMap.java
@@ -32,16 +32,20 @@
 import java.util.concurrent.CountDownLatch;

 @ThreadLeakScope(ThreadLeakScope.Scope.NONE)
-public class TestDenseIntMap extends RandomizedTest {
+public class TestIntMap extends RandomizedTest {
     @Test
     public void testInsert() {
-        var map = new DenseIntMap(100);
+        testInsertInternal(new DenseIntMap<>(100));
+        testInsertInternal(new SparseIntMap<>());
+    }
+
+    private void testInsertInternal(IntMap map) {
         for (int i = 0; i < 3; i++) {
             Assert.assertNull(map.get(i));
             Assert.assertFalse(map.containsKey(i));
-            map.put(i, "value" + i);
+            map.compareAndPut(i, null, "value" + i);
             Assert.assertEquals("value" + i, map.get(i));
             Assert.assertTrue(map.containsKey(i));
             Assert.assertEquals(i + 1, map.size());
@@ -50,14 +54,18 @@ public void testInsert() {

     @Test
     public void testUpdate() {
-        var map = new DenseIntMap(100);
+        testUpdateInternal(new DenseIntMap<>(100));
+        testUpdateInternal(new SparseIntMap<>());
+    }
+
+    private void testUpdateInternal(IntMap map) {
         for (int i = 0; i < 3; i++) {
-            map.put(i, "value" + i);
+            map.compareAndPut(i, null, "value" + i);
         }
         Assert.assertEquals(3, map.size());
         for (int i = 0; i < 3; i++) {
-            map.put(i, "new-value" + i);
+            map.compareAndPut(i, map.get(i), "new-value" + i);
             Assert.assertEquals("new-value" + i, map.get(i));
             Assert.assertEquals(3, map.size());
         }
@@ -65,9 +73,13 @@ public void testUpdate() {

     @Test
     public void testRemove() {
-        var map = new DenseIntMap(100);
+        testRemoveInternal(new DenseIntMap<>(100));
+        testRemoveInternal(new SparseIntMap<>());
+    }
+
+    private void testRemoveInternal(IntMap map) {
         for (int i = 0; i < 3; i++) {
-            map.put(i, "value" + i);
+            map.compareAndPut(i, null, "value" + i);
         }
         Assert.assertEquals(3, map.size());
@@ -82,12 +94,12 @@

     @Test
     public void testConcurrency() throws InterruptedException {
         for (int i = 0; i < 100; i++) {
-            testConcurrencyOnce();
+            testConcurrencyOnce(new DenseIntMap<>(100));
+            testConcurrencyOnce(new SparseIntMap<>());
         }
     }

-    private static void testConcurrencyOnce() throws InterruptedException {
-        var map = new DenseIntMap(100);
+    private static void testConcurrencyOnce(IntMap map) throws InterruptedException {
         var source = new ConcurrentHashMap();

         int nThreads = randomIntBetween(2, 16);
@@ -103,7 +115,7 @@ private static void testConcurrencyOnce() throws InterruptedException {
                 } else {
                     String value = randomAsciiAlphanumOfLength(20);
                     source.put(key, value);
-                    map.put(key, value);
+                    map.compareAndPut(key, map.get(key), value);
                 }
             }
         } finally {