diff --git a/build.xml b/build.xml
index 04073f37b024..75d3cddda0e1 100644
--- a/build.xml
+++ b/build.xml
@@ -743,7 +743,7 @@
-
+
diff --git a/src/java/org/apache/cassandra/cache/ChunkCache.java b/src/java/org/apache/cassandra/cache/ChunkCache.java
index 6def74cc21b2..206614742973 100644
--- a/src/java/org/apache/cassandra/cache/ChunkCache.java
+++ b/src/java/org/apache/cassandra/cache/ChunkCache.java
@@ -297,7 +297,7 @@ public void invalidateFileNow(File file)
synchronousCache.invalidateAll(Iterables.filter(cache.asMap().keySet(), x -> (x.readerId & mask) == fileId));
}
- static class Key
+ static class Key implements Comparable<Key>
{
final long readerId;
final long position;
@@ -312,11 +312,15 @@ private Key(long readerId, long position)
@Override
public int hashCode()
{
- final int prime = 31;
- int result = 1;
- result = prime * result + Long.hashCode(readerId);
- result = prime * result + Long.hashCode(position);
- return result;
+ // Mix readerId and position into a single long using a 64-bit mixing constant
+ // derived from the golden ratio (floor(2^64 / phi), the SplitMix64 constant)
+ long mixed = (readerId + position) * 0x9E3779B97F4A7C15L;
+
+ // Spread the bits (XOR-shift) to ensure high bits affect low bits
+ mixed ^= (mixed >>> 32);
+ mixed ^= (mixed >>> 16);
+
+ return (int) mixed;
}
@Override
@@ -331,6 +335,17 @@ public boolean equals(Object obj)
return (position == other.position)
&& readerId == other.readerId;
}
+
+ @Override
+ public int compareTo(Key other)
+ {
+ // order by readerId first, then by position
+ int cmp = Long.compare(this.readerId, other.readerId);
+ if (cmp != 0)
+ return cmp;
+ return Long.compare(this.position, other.position);
+ }
}
/**
diff --git a/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java b/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java
index 1ec5752d6a25..5b2c01c7b2a1 100644
--- a/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java
+++ b/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java
@@ -436,12 +436,16 @@ public enum CassandraRelevantProperties
SAI_VECTOR_FLUSH_THRESHOLD_MAX_ROWS("cassandra.sai.vector_flush_threshold_max_rows", "-1"),
// Use non-positive value to disable it. Period in millis to trigger a flush for SAI vector memtable index.
SAI_VECTOR_FLUSH_PERIOD_IN_MILLIS("cassandra.sai.vector_flush_period_in_millis", "-1"),
+ // Whether compaction should build vector indexes using the fused PQ feature
+ SAI_VECTOR_ENABLE_FUSED("cassandra.sai.vector.enable_fused", "true"),
// Use nvq when building graphs in compaction. Disabled by default for now. Enabling will reduce recall slightly
// while also reducing the storage footprint.
SAI_VECTOR_ENABLE_NVQ("cassandra.sai.vector.enable_nvq", "false"),
// NVQ number of subvectors. This isn't really expected to change much so we're only exposing
// it as a global variable in case it's needed.
SAI_VECTOR_NVQ_NUM_SUB_VECTORS("cassandra.sai.vector.nvq_num_sub_vectors", "2"),
+ // When building a compaction graph, encode layer-0 nodes in parallel and then use async I/O for writes.
+ SAI_ENCODE_AND_WRITE_VECTOR_GRAPH_IN_PARALLEL("cassandra.sai.vector.encode_write_graph_parallel", "true"),
/**
* Whether to disable auto-compaction
*/
diff --git a/src/java/org/apache/cassandra/index/sai/disk/format/Version.java b/src/java/org/apache/cassandra/index/sai/disk/format/Version.java
index 6edaa21812f8..e8f51905442d 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/format/Version.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/format/Version.java
@@ -39,6 +39,7 @@
import org.apache.cassandra.index.sai.disk.v5.V5OnDiskFormat;
import org.apache.cassandra.index.sai.disk.v6.V6OnDiskFormat;
import org.apache.cassandra.index.sai.disk.v7.V7OnDiskFormat;
+import org.apache.cassandra.index.sai.disk.v8.V8OnDiskFormat;
import org.apache.cassandra.index.sai.utils.TypeUtil;
import org.apache.cassandra.io.sstable.format.SSTableFormat;
import org.apache.cassandra.schema.SchemaConstants;
@@ -75,10 +76,12 @@ public class Version implements Comparable<Version>
public static final Version EC = new Version("ec", V7OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "ec"));
// total terms count serialization in index metadata, enables ANN_USE_SYNTHETIC_SCORE by default
public static final Version ED = new Version("ed", V7OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "ed"));
+ // jvector file format version 6 (skipped 5)
+ public static final Version FA = new Version("fa", V8OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "fa"));
// These are in reverse-chronological order so that the latest version is first. Version matching tests
// are more likely to match the latest version, so we want to test that one first.
- public static final List<Version> ALL = Lists.newArrayList(ED, EC, EB, DC, DB, CA, BA, AA);
+ public static final List<Version> ALL = Lists.newArrayList(FA, ED, EC, EB, DC, DB, CA, BA, AA);
public static final Version EARLIEST = AA;
public static final Version VECTOR_EARLIEST = BA;
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v8/V8OnDiskFormat.java b/src/java/org/apache/cassandra/index/sai/disk/v8/V8OnDiskFormat.java
new file mode 100644
index 000000000000..aca3d530ce97
--- /dev/null
+++ b/src/java/org/apache/cassandra/index/sai/disk/v8/V8OnDiskFormat.java
@@ -0,0 +1,32 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.cassandra.index.sai.disk.v8;
+
+import org.apache.cassandra.index.sai.disk.v7.V7OnDiskFormat;
+
+public class V8OnDiskFormat extends V7OnDiskFormat
+{
+ public static final V8OnDiskFormat instance = new V8OnDiskFormat();
+
+ @Override
+ public int jvectorFileFormatVersion()
+ {
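+ // jvector on-disk format version 6 (version 5 was skipped)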
+ return 6;
+ }
+}
\ No newline at end of file
diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraDiskAnn.java b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraDiskAnn.java
index 8745ece4983f..c158a08e6129 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraDiskAnn.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraDiskAnn.java
@@ -95,7 +95,7 @@ public CassandraDiskAnn(SSTableContext sstableContext, SegmentMetadata segmentMe
SegmentMetadata.ComponentMetadata termsMetadata = this.componentMetadatas.get(IndexComponentType.TERMS_DATA);
graphHandle = indexFiles.termsData();
- var rawGraph = OnDiskGraphIndex.load(graphHandle::createReader, termsMetadata.offset);
+ var rawGraph = OnDiskGraphIndex.load(graphHandle::createReader, termsMetadata.offset, false);
features = rawGraph.getFeatureSet();
graph = rawGraph;
usesNVQ = features.contains(FeatureId.NVQ_VECTORS);
@@ -123,7 +123,7 @@ public CassandraDiskAnn(SSTableContext sstableContext, SegmentMetadata segmentMe
}
VectorCompression.CompressionType compressionType = VectorCompression.CompressionType.values()[reader.readByte()];
- if (features.contains(FeatureId.FUSED_ADC))
+ if (features.contains(FeatureId.FUSED_PQ))
{
assert compressionType == VectorCompression.CompressionType.PRODUCT_QUANTIZATION;
compressedVectors = null;
@@ -239,9 +239,7 @@ public CloseableIterator search(VectorFloat<?> queryVector,
{
var view = (ImmutableGraphIndex.ScoringView) searcher.getView();
SearchScoreProvider ssp;
- // FusedADC can no longer be written due to jvector upgrade. However, it's possible these index files
- // still exist, so we have to support them.
- if (features.contains(FeatureId.FUSED_ADC))
+ if (features.contains(FeatureId.FUSED_PQ))
{
var asf = view.approximateScoreFunctionFor(queryVector, similarityFunction);
var rr = isRerankless ? null : view.rerankerFor(queryVector, similarityFunction);
diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java
index 12529c20c36c..8a16c5b999fc 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java
@@ -25,12 +25,14 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
+import java.util.EnumMap;
import java.util.Map;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
+import java.util.function.IntFunction;
import java.util.function.IntUnaryOperator;
import java.util.function.ToIntFunction;
@@ -41,17 +43,21 @@
import io.github.jbellis.jvector.graph.GraphIndexBuilder;
import io.github.jbellis.jvector.graph.GraphSearcher;
+import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.SearchResult;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter;
import io.github.jbellis.jvector.graph.disk.OrdinalMapper;
import io.github.jbellis.jvector.graph.disk.feature.Feature;
import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
+import io.github.jbellis.jvector.graph.disk.feature.FusedPQ;
import io.github.jbellis.jvector.graph.disk.feature.InlineVectors;
import io.github.jbellis.jvector.graph.disk.feature.NVQ;
import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider;
import io.github.jbellis.jvector.quantization.BinaryQuantization;
import io.github.jbellis.jvector.quantization.CompressedVectors;
+import io.github.jbellis.jvector.quantization.ImmutablePQVectors;
import io.github.jbellis.jvector.quantization.NVQuantization;
+import io.github.jbellis.jvector.quantization.PQVectors;
import io.github.jbellis.jvector.quantization.ProductQuantization;
import io.github.jbellis.jvector.quantization.VectorCompressor;
@@ -63,6 +69,7 @@
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
+import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import org.agrona.collections.IntHashSet;
@@ -89,6 +96,7 @@
import org.apache.cassandra.index.sai.metrics.ColumnQueryMetrics;
import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey;
import org.apache.cassandra.index.sai.utils.SAICodecUtils;
+import org.apache.cassandra.io.util.File;
import org.apache.cassandra.io.util.SequentialWriter;
import org.apache.cassandra.service.StorageService;
import org.apache.cassandra.tracing.Tracing;
@@ -106,6 +114,9 @@ public enum PQVersion {
V1, // includes unit vector calculation
}
+ /** Whether to write the fused PQ feature when flushing indexes (assuming all other conditions are met) */
+ private static boolean ENABLE_FUSED = CassandraRelevantProperties.SAI_VECTOR_ENABLE_FUSED.getBoolean();
+
/** minimum number of rows to perform PQ codebook generation */
public static final int MIN_PQ_ROWS = 1024;
@@ -449,36 +460,24 @@ public SegmentMetadata.ComponentMetadataMap flush(IndexComponents.ForWrite perIn
OrdinalMapper ordinalMapper = remappedPostings.ordinalMapper;
- // Write the NVQ feature. We could compute this at insert time, but because the graph allows for parallel
- // insertions, it would be a bit more complicated. All vectors are in memory, so the computation to build the
- // mean vector should be pretty fast, and this path is only used when we don't have an existing
- // ProductQuantization or when we're using BQ.
- NVQuantization nvq = writeNvq ? NVQuantization.compute(vectorValues, NUM_SUB_VECTORS) : null;
-
IndexComponent.ForWrite termsDataComponent = perIndexComponents.addOrGet(IndexComponentType.TERMS_DATA);
var indexFile = termsDataComponent.file();
long termsOffset = SAICodecUtils.headerSize();
if (indexFile.exists())
termsOffset += indexFile.length();
try (var pqOutput = perIndexComponents.addOrGet(IndexComponentType.PQ).openOutput(true);
- var postingsOutput = perIndexComponents.addOrGet(IndexComponentType.POSTING_LISTS).openOutput(true);
- var indexWriter = new OnDiskGraphIndexWriter.Builder(builder.getGraph(), indexFile.toPath())
- .withStartOffset(termsOffset)
- .withVersion(perIndexComponents.version().onDiskFormat().jvectorFileFormatVersion())
- .withMapper(ordinalMapper)
- .with(nvq != null ? new NVQ(nvq) : new InlineVectors(vectorValues.dimension()))
- .build())
+ var postingsOutput = perIndexComponents.addOrGet(IndexComponentType.POSTING_LISTS).openOutput(true))
{
SAICodecUtils.writeHeader(pqOutput);
SAICodecUtils.writeHeader(postingsOutput);
- indexWriter.getOutput().seek(indexFile.length()); // position at the end of the previous segment before writing our own header
- SAICodecUtils.writeHeader(SAICodecUtils.toLuceneOutput(indexWriter.getOutput()), perIndexComponents.version());
- assert indexWriter.getOutput().position() == termsOffset : "termsOffset " + termsOffset + " != " + indexWriter.getOutput().position();
+
+ // Attempt the fused PQ feature only when it is enabled and the on-disk format's jvector version supports it (>= 6)
+ boolean attemptWritingFused = ENABLE_FUSED && perIndexComponents.version().onDiskFormat().jvectorFileFormatVersion() >= 6;
// compute and write PQ
long pqOffset = pqOutput.getFilePointer();
- long pqPosition = writePQ(pqOutput.asSequentialWriter(), remappedPostings, perIndexComponents.context());
- long pqLength = pqPosition - pqOffset;
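+ // writePQ returns the compressor when its codes may later be fused into the graph, and null otherwise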
+ var compressor = writePQ(pqOutput.asSequentialWriter(), remappedPostings, perIndexComponents.context(), attemptWritingFused);
+ long pqLength = pqOutput.asSequentialWriter().position() - pqOffset;
// write postings
long postingsOffset = postingsOutput.getFilePointer();
@@ -497,25 +496,50 @@ public SegmentMetadata.ComponentMetadataMap flush(IndexComponents.ForWrite perIn
}
long postingsLength = postingsPosition - postingsOffset;
- // write the graph
- var start = System.nanoTime();
- var supplier = nvq != null
- ? Feature.singleStateFactory(FeatureId.NVQ_VECTORS, nodeId -> new NVQ.State(nvq.encode(vectorValues.getVector(nodeId))))
- : Feature.singleStateFactory(FeatureId.INLINE_VECTORS, nodeId -> new InlineVectors.State(vectorValues.getVector(nodeId)));
- indexWriter.write(supplier);
- SAICodecUtils.writeFooter(indexWriter.getOutput(), indexWriter.checksum());
- logger.info("Writing graph took {}ms", (System.nanoTime() - start) / 1_000_000);
- long termsLength = indexWriter.getOutput().position() - termsOffset;
-
- // write remaining footers/checksums
- SAICodecUtils.writeFooter(pqOutput);
- SAICodecUtils.writeFooter(postingsOutput);
-
- // add components to the metadata map
- return createMetadataMap(termsOffset, termsLength, postingsOffset, postingsLength, pqOffset, pqLength);
+ // Write the NVQ feature. We could compute this at insert time, but because the graph allows for parallel
+ // insertions, it would be a bit more complicated. All vectors are in memory, so the computation to build the
+ // mean vector should be pretty fast, and this path is only used when we don't have an existing
+ // ProductQuantization or when we're using BQ.
+ NVQuantization nvq = writeNvq ? NVQuantization.compute(vectorValues, NUM_SUB_VECTORS) : null;
+
+ try (var indexWriter = createIndexWriter(indexFile, termsOffset, perIndexComponents.context(), ordinalMapper, compressor, nvq);
+ var view = builder.getGraph().getView())
+ {
+ indexWriter.getOutput().seek(indexFile.length()); // position at the end of the previous segment before writing our own header
+ SAICodecUtils.writeHeader(SAICodecUtils.toLuceneOutput(indexWriter.getOutput()), perIndexComponents.version());
+ assert indexWriter.getOutput().position() == termsOffset : "termsOffset " + termsOffset + " != " + indexWriter.getOutput().position();
+
+ // write the graph
+ var start = System.nanoTime();
+ indexWriter.write(suppliers(perIndexComponents.context(), view, compressor, nvq));
+ SAICodecUtils.writeFooter(indexWriter.getOutput(), indexWriter.checksum());
+ logger.info("Writing graph took {}ms", (System.nanoTime() - start) / 1_000_000);
+ long termsLength = indexWriter.getOutput().position() - termsOffset;
+
+ // write remaining footers/checksums
+ SAICodecUtils.writeFooter(pqOutput);
+ SAICodecUtils.writeFooter(postingsOutput);
+
+ // add components to the metadata map
+ return createMetadataMap(termsOffset, termsLength, postingsOffset, postingsLength, pqOffset, pqLength);
+ }
}
}
+ private OnDiskGraphIndexWriter createIndexWriter(File indexFile, long termsOffset, IndexContext context, OrdinalMapper ordinalMapper, VectorCompressor<?> compressor, NVQuantization nvq) throws IOException
+ {
+ var indexWriterBuilder = new OnDiskGraphIndexWriter.Builder(builder.getGraph(), indexFile.toPath())
+ .withStartOffset(termsOffset)
+ .withVersion(context.version().onDiskFormat().jvectorFileFormatVersion())
+ .withMapper(ordinalMapper)
+ .with(nvq != null ? new NVQ(nvq) : new InlineVectors(vectorValues.dimension()));
+
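+ // attach the FusedPQ feature only when it is enabled, the compressor is PQ, and the jvector format is >= 6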
+ if (ENABLE_FUSED && compressor instanceof ProductQuantization && context.version().onDiskFormat().jvectorFileFormatVersion() >= 6)
+ indexWriterBuilder.with(new FusedPQ(context.getIndexWriterConfig().getAnnMaxDegree(), (ProductQuantization) compressor));
+
+ return indexWriterBuilder.build();
+ }
+
static SegmentMetadata.ComponentMetadataMap createMetadataMap(long termsOffset, long termsLength, long postingsOffset, long postingsLength, long pqOffset, long pqLength)
{
SegmentMetadata.ComponentMetadataMap metadataMap = new SegmentMetadata.ComponentMetadataMap();
@@ -526,6 +550,29 @@ static SegmentMetadata.ComponentMetadataMap createMetadataMap(long termsOffset,
return metadataMap;
}
+ private EnumMap<FeatureId, IntFunction<Feature.State>> suppliers(IndexContext context, ImmutableGraphIndex.View view, VectorCompressor<?> compressor, NVQuantization nvq)
+ {
+ var features = new EnumMap<FeatureId, IntFunction<Feature.State>>(FeatureId.class);
+
+ // We write either NVQ or inline (full-precision) vectors in the graph; nvq is null when NVQ is disabled.
+ if (nvq != null)
+ features.put(FeatureId.NVQ_VECTORS, nodeId -> new NVQ.State(nvq.encode(vectorValues.getVector(nodeId))));
+ else
+ features.put(FeatureId.INLINE_VECTORS, nodeId -> new InlineVectors.State(vectorValues.getVector(nodeId)));
+
+ if (ENABLE_FUSED && context.version().onDiskFormat().jvectorFileFormatVersion() >= 6)
+ {
+ if (compressor instanceof ProductQuantization)
+ {
+ // TODO temporary hack to parallelize computation
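+ // encodeAll compresses every vector eagerly so the per-node supplier below can look codes up by ordinal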
+ PQVectors pqVectors = (PQVectors) compressor.encodeAll(vectorValues);
+ features.put(FeatureId.FUSED_PQ, nodeId -> new FusedPQ.State(view, pqVectors::get, nodeId));
+ }
+ }
+
+ return features;
+ }
+
/**
* Return the best previous CompressedVectors for this column that matches the `matcher` predicate.
* "Best" means the most recent one that hits the row count target of {@link ProductQuantization#MAX_PQ_TRAINING_SET_SIZE},
@@ -578,7 +625,7 @@ public static PqInfo getPqIfPresent(IndexContext indexContext, Function
- private long writePQ(SequentialWriter writer, V5VectorPostingsWriter.RemappedPostings remapped, IndexContext indexContext) throws IOException
+ private VectorCompressor<?> writePQ(SequentialWriter writer, V5VectorPostingsWriter.RemappedPostings remapped, IndexContext indexContext, boolean attemptWritingFused) throws IOException
{
var preferredCompression = sourceModel.compressionProvider.apply(vectorValues.dimension());
@@ -602,18 +649,24 @@ private long writePQ(SequentialWriter writer, V5VectorPostingsWriter.RemappedPos
}
assert !vectorValues.isValueShared();
// encode (compress) the vectors to save
- if (compressor != null)
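+ // when writing fused, full encoding is deferred to the graph writer (see suppliers()); only the codebook is written here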
+ if ((compressor instanceof ProductQuantization && !attemptWritingFused) || compressor instanceof BinaryQuantization)
cv = compressor.encodeAll(new RemappedVectorValues(remapped, remapped.maxNewOrdinal, vectorValues));
}
var actualType = compressor == null ? CompressionType.NONE : preferredCompression.type;
writePqHeader(writer, allVectorsAreUnitLength, actualType, indexContext.version());
if (actualType == CompressionType.NONE)
- return writer.position();
+ return null;
+
+ if (attemptWritingFused)
+ {
+ compressor.write(writer, indexContext.version().onDiskFormat().jvectorFileFormatVersion());
+ return compressor;
+ }
// save (outside the synchronized block, this is io-bound not CPU)
cv.write(writer, indexContext.version().onDiskFormat().jvectorFileFormatVersion());
- return writer.position();
+ return null; // Don't need compressor in this case
}
static void writePqHeader(DataOutput writer, boolean unitVectors, CompressionType type, Version version)
diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java b/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java
index 46c0912cb50b..4fa14554ed18 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/vector/CompactionGraph.java
@@ -26,6 +26,7 @@
import java.util.EnumMap;
import java.util.Map;
import java.util.Set;
+import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
@@ -45,6 +46,7 @@
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.disk.feature.Feature;
import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
+import io.github.jbellis.jvector.graph.disk.feature.FusedPQ;
import io.github.jbellis.jvector.graph.disk.feature.InlineVectors;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter;
@@ -75,6 +77,7 @@
import net.openhft.chronicle.map.ChronicleMapBuilder;
import org.agrona.collections.Int2ObjectHashMap;
import org.apache.cassandra.concurrent.NamedThreadFactory;
+import org.apache.cassandra.config.CassandraRelevantProperties;
import org.apache.cassandra.db.Keyspace;
import org.apache.cassandra.db.marshal.VectorType;
import org.apache.cassandra.exceptions.InvalidRequestException;
@@ -119,6 +122,9 @@ public class CompactionGraph implements Closeable, Accountable
@VisibleForTesting
public static int PQ_TRAINING_SIZE = ProductQuantization.MAX_PQ_TRAINING_SET_SIZE;
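+ /** Whether to write the FusedPQ feature when building compaction graphs; see SAI_VECTOR_ENABLE_FUSED. */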
+ private static boolean ENABLE_FUSED = CassandraRelevantProperties.SAI_VECTOR_ENABLE_FUSED.getBoolean();
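+ /** When true, encode layer-0 nodes in parallel and use async I/O for graph writes; see SAI_ENCODE_AND_WRITE_VECTOR_GRAPH_IN_PARALLEL. */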
+ private static boolean PARALLEL_ENCODING_WRITING = CassandraRelevantProperties.SAI_ENCODE_AND_WRITE_VECTOR_GRAPH_IN_PARALLEL.getBoolean();
+
private final VectorType.VectorSerializer serializer;
private final VectorSimilarityFunction similarityFunction;
private final ChronicleMap<VectorFloat<?>, CompactionVectorPostings> postingsMap;
@@ -184,8 +190,8 @@ public CompactionGraph(IndexComponents.ForWrite perIndexComponents, VectorCompre
this.useSyntheticOrdinals = !V5OnDiskFormat.writeV5VectorPostings(context.version()) || !allRowsHaveVectors;
// the extension here is important to signal to CFS.scrubDataDirectories that it should be removed if present at restart
- Component tmpComponent = new Component(Component.Type.CUSTOM, "chronicle" + Descriptor.TMP_EXT);
- postingsFile = dd.fileFor(tmpComponent);
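+ // include a random UUID in the name so each graph build gets its own tmp file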
+ Component tmpComponent = new Component(Component.Type.CUSTOM, "chronicle" + UUID.randomUUID() + Descriptor.TMP_EXT);
+ postingsFile = dd.tmpFileFor(tmpComponent);
postingsMap = ChronicleMapBuilder.of((Class<VectorFloat<?>>) (Class) VectorFloat.class, (Class<CompactionVectorPostings>) (Class) CompactionVectorPostings.class)
.averageKeySize(dimension * Float.BYTES)
.averageValueSize(VectorPostings.emptyBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF + 2 * Integer.BYTES)
@@ -195,7 +201,7 @@ public CompactionGraph(IndexComponents.ForWrite perIndexComponents, VectorCompre
.createPersistedTo(postingsFile.toJavaIOFile());
// Formatted so that the full resolution vector is written at the ordinal * vector dimension offset
- Component vectorsByOrdinalComponent = new Component(Component.Type.CUSTOM, "vectors_by_ordinal");
+ Component vectorsByOrdinalComponent = new Component(Component.Type.CUSTOM, "vectors_by_ordinal" + UUID.randomUUID() + Descriptor.TMP_EXT);
vectorsByOrdinalTmpFile = dd.tmpFileFor(vectorsByOrdinalComponent);
vectorsByOrdinalBufferedWriter = new BufferedRandomAccessWriter(vectorsByOrdinalTmpFile.toPath());
@@ -243,12 +249,16 @@ else if (compressor instanceof BinaryQuantization)
private OnDiskGraphIndexWriter createTermsWriter(OrdinalMapper ordinalMapper, NVQuantization nvq) throws IOException
{
var feature = nvq != null ? new NVQ(nvq) : new InlineVectors(dimension);
- return new OnDiskGraphIndexWriter.Builder(builder.getGraph(), termsFile.toPath())
- .withStartOffset(termsOffset)
- .with(feature)
- .withVersion(context.version().onDiskFormat().jvectorFileFormatVersion())
- .withMapper(ordinalMapper)
- .build();
+ // TODO is this hack to use a local path safe?
+ var writerBuilder = new OnDiskGraphIndexWriter.Builder(builder.getGraph(), termsFile.toJavaIOFile().toPath())
+ .withParallelWrites(PARALLEL_ENCODING_WRITING)
+ .withStartOffset(termsOffset)
+ .with(feature)
+ .withVersion(context.version().onDiskFormat().jvectorFileFormatVersion())
+ .withMapper(ordinalMapper);
+ if (ENABLE_FUSED && compressor instanceof ProductQuantization && context.version().onDiskFormat().jvectorFileFormatVersion() >= 6)
+ writerBuilder.with(new FusedPQ(context.getIndexWriterConfig().getAnnMaxDegree(), (ProductQuantization) compressor));
+ return writerBuilder.build();
}
@Override
@@ -437,13 +447,17 @@ public SegmentMetadata.ComponentMetadataMap flush() throws IOException
var ordinalMapper = new AtomicReference<OrdinalMapper>();
long postingsOffset = postingsOutput.getFilePointer();
var es = Executors.newSingleThreadExecutor(new NamedThreadFactory("CompactionGraphPostingsWriter"));
+
var postingsFuture = es.submit(() -> {
// V2 doesn't support ONE_TO_MANY so force it to ZERO_OR_ONE_TO_MANY if necessary;
// similarly, if we've been using synthetic ordinals then we can't map to ONE_TO_MANY
// (ending up at ONE_TO_MANY when the source sstables were not is unusual, but possible,
// if a row with null vector in sstable A gets updated with a vector in sstable B)
+ // If the ordinal mapping would contain too many holes, fall back to ZERO_OR_ONE_TO_MANY and write it explicitly.
if (postingsStructure == Structure.ONE_TO_MANY
- && (!V5OnDiskFormat.writeV5VectorPostings(version) || useSyntheticOrdinals))
+ && (!V5OnDiskFormat.writeV5VectorPostings(version)
+ || useSyntheticOrdinals
+ || V5VectorPostingsWriter.GLOBAL_HOLES_ALLOWED < (double) lastRowId / postingsMap.size()))
{
postingsStructure = Structure.ZERO_OR_ONE_TO_MANY;
}
@@ -465,7 +479,6 @@ public SegmentMetadata.ComponentMetadataMap flush() throws IOException
long postingsLength = postingsEnd - postingsOffset;
es.shutdown();
- // write the graph edge lists and optionally fused adc features
var start = System.nanoTime();
// Null if we are not using NVQ
@@ -491,7 +504,18 @@ public SegmentMetadata.ComponentMetadataMap flush() throws IOException
return new InlineVectors.State(threadLocalReaders.get().getVector(ordinal));
});
}
- writer.write(supplier);
+ if (writer.getFeatureSet().contains(FeatureId.FUSED_PQ))
+ {
+ try (var view = builder.getGraph().getView())
+ {
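+ // FUSED_PQ is only in the writer's feature set when the compressor is ProductQuantization, so compressedVectors is a PQVectors here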
+ supplier.put(FeatureId.FUSED_PQ, ordinal -> new FusedPQ.State(view, (PQVectors) compressedVectors, ordinal));
+ writer.write(supplier);
+ }
+ }
+ else
+ {
+ writer.write(supplier);
+ }
}
catch (Exception e)
{
diff --git a/test/unit/org/apache/cassandra/index/sai/cql/VectorCompactionTest.java b/test/unit/org/apache/cassandra/index/sai/cql/VectorCompactionTest.java
index 807d8b600a4b..1cebc4e54969 100644
--- a/test/unit/org/apache/cassandra/index/sai/cql/VectorCompactionTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/cql/VectorCompactionTest.java
@@ -31,9 +31,12 @@
import org.junit.runners.Parameterized;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
+import org.apache.cassandra.db.Keyspace;
import org.apache.cassandra.db.marshal.FloatType;
import org.apache.cassandra.index.sai.SAIUtil;
+import org.apache.cassandra.index.sai.StorageAttachedIndex;
import org.apache.cassandra.index.sai.disk.format.Version;
+import org.apache.cassandra.index.sai.disk.v2.V2VectorIndexSearcher;
import org.apache.cassandra.index.sai.disk.v5.V5VectorPostingsWriter;
import static org.apache.cassandra.index.sai.disk.vector.CassandraOnHeapGraph.MIN_PQ_ROWS;
@@ -254,7 +257,7 @@ private void insertZeroOrOneToManyRows(int vectorsPerSstable, int sstables)
public void testOneToManyCompactionHolesInternal(int vectorsPerSstable, int sstables)
{
- createTable();
+ var indexName = createTableAndReturnIndexName();
disableCompaction();
@@ -268,6 +271,18 @@ public void testOneToManyCompactionHolesInternal(int vectorsPerSstable, int ssta
validateQueries();
compact();
validateQueries();
+
+ // Validate that every segment of each sstable index has the expected postings structure.
+ var sai = (StorageAttachedIndex) Keyspace.open(KEYSPACE).getColumnFamilyStore(currentTable()).getIndexManager().getIndexByName(indexName);
+ var indexes = sai.getIndexContext().getView().getIndexes();
+ for (var index : indexes)
+ {
+ for (var segment : index.getSegments())
+ {
+ var searcher = (V2VectorIndexSearcher) segment.getIndexSearcher();
+ assertEquals(V5VectorPostingsWriter.Structure.ZERO_OR_ONE_TO_MANY, searcher.getPostingsStructure());
+ }
+ }
}
private void insertOneToManyRows(int vectorsPerSstable, int sstables)
@@ -300,9 +315,14 @@ private void insertOneToManyRows(int vectorsPerSstable, int sstables)
}
private void createTable()
+ {
+ createTableAndReturnIndexName();
+ }
+
+ private String createTableAndReturnIndexName()
{
createTable("CREATE TABLE %s (pk int, v vector, PRIMARY KEY(pk))");
- createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'");
+ return createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'");
}
private void validateQueries()
diff --git a/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java b/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java
index 8373ac003218..739ae264714f 100644
--- a/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java
@@ -162,6 +162,7 @@ public void testCompaction() throws Throwable
// Take the CassandraOnHeapGraph code path.
compact();
+ compact();
for (int topK : List.of(1, 100))
{
var recall = testRecall(topK, queryVectors, groundTruth);
@@ -326,7 +327,7 @@ private void createTable()
private void createIndex()
{
// we need a long timeout because we are adding many vectors
- String index = createIndexAsync("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex' WITH OPTIONS = {'similarity_function' : 'euclidean'}");
+ String index = createIndexAsync("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex' WITH OPTIONS = {'similarity_function' : 'euclidean', 'enable_hierarchy': 'true'}");
waitForIndexQueryable(KEYSPACE, index, 5, TimeUnit.MINUTES);
}
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIteratorTest.java b/test/unit/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIteratorTest.java
index 752f517e2e32..f1a13a2fd70d 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIteratorTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIteratorTest.java
@@ -19,6 +19,7 @@
package org.apache.cassandra.index.sai.disk.vector;
import java.util.NoSuchElementException;
+import java.util.function.Function;
import org.junit.Test;
@@ -95,6 +96,12 @@ public NodesIterator getNeighborsIterator(int i, int i1)
throw new UnsupportedOperationException();
}
+ @Override
+ public void processNeighbors(int i, int i1, ScoreFunction scoreFunction, ImmutableGraphIndex.IntMarker intMarker, ImmutableGraphIndex.NeighborProcessor neighborProcessor)
+ {
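+ // no-op: this test double has no neighbor data to process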
+
+ }
+
@Override
public int size()
{