diff --git a/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java b/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java index 8e8d86a3f41a..5187f4c355d1 100644 --- a/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java +++ b/src/java/org/apache/cassandra/config/CassandraRelevantProperties.java @@ -720,6 +720,7 @@ public enum CassandraRelevantProperties SAI_VALIDATE_TERMS_AT_COORDINATOR("cassandra.sai.validate_terms_at_coordinator", "true"), /** Controls the maximum top-k limit for vector search */ SAI_VECTOR_SEARCH_MAX_TOP_K("cassandra.sai.vector_search.max_top_k", "1000"), + SAI_VECTOR_USE_PRUNING_DEFAULT("cassandra.sai.jvector.use_pruning_default", "true"), SAI_WRITE_JVECTOR3_FORMAT("cassandra.sai.write_jv3_format", "false"), SCHEMA_PULL_INTERVAL_MS("cassandra.schema_pull_interval_ms", "60000"), diff --git a/src/java/org/apache/cassandra/db/filter/ANNOptions.java b/src/java/org/apache/cassandra/db/filter/ANNOptions.java index f73e1f13e056..87934f225d0b 100644 --- a/src/java/org/apache/cassandra/db/filter/ANNOptions.java +++ b/src/java/org/apache/cassandra/db/filter/ANNOptions.java @@ -38,8 +38,9 @@ public class ANNOptions { public static final String RERANK_K_OPTION_NAME = "rerank_k"; + public static final String USE_PRUNING_OPTION_NAME = "use_pruning"; - public static final ANNOptions NONE = new ANNOptions(null); + public static final ANNOptions NONE = new ANNOptions(null, null); public static final Serializer serializer = new Serializer(); @@ -51,15 +52,22 @@ public class ANNOptions @Nullable public final Integer rerankK; - private ANNOptions(@Nullable Integer rerankK) + /** + * Whether to use pruning to speed up the ANN search. If {@code null}, the default value is used. + */ + @Nullable + public final Boolean usePruning; + + private ANNOptions(@Nullable Integer rerankK, @Nullable Boolean usePruning) { this.rerankK = rerankK; + this.usePruning = usePruning; } - public static ANNOptions create(@Nullable Integer rerankK) + public static ANNOptions create(@Nullable Integer rerankK, @Nullable Boolean usePruning) { // if all the options are null, return the NONE instance - return rerankK == null ? NONE : new ANNOptions(rerankK); + return rerankK == null && usePruning == null ? NONE : new ANNOptions(rerankK, usePruning); }
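For illustration, a minimal sketch of how the factory above behaves once both options exist (names as in this patch; not part of the diff itself):

    ANNOptions none = ANNOptions.create(null, null);     // collapses to the shared ANNOptions.NONE instance
    ANNOptions noPrune = ANNOptions.create(null, false); // distinct instance: use_pruning=false, rerank_k unset
    ANNOptions both = ANNOptions.create(40, true);       // rerank_k=40, use_pruning=true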
/** @@ -67,13 +75,16 @@ public static ANNOptions create(@Nullable Integer rerankK) */ public void validate(ClientState state, String keyspace, int limit) { - if (rerankK == null) + if (rerankK == null && usePruning == null) return; - if (rerankK < limit) - throw new InvalidRequestException(String.format("Invalid rerank_k value %d lesser than limit %d", rerankK, limit)); + if (rerankK != null) + { + if (rerankK > 0 && rerankK < limit) + throw new InvalidRequestException(String.format("Invalid rerank_k value %d greater than 0 and less than limit %d", rerankK, limit)); - Guardrails.annRerankKMaxValue.guard(rerankK, "ANN options", false, state); + Guardrails.annRerankKMaxValue.guard(rerankK, "ANN options", false, state); + } // Ensure that all nodes in the cluster are in a version that supports ANN options, including this one assert keyspace != null; @@ -93,6 +104,7 @@ public void validate(ClientState state, String keyspace, int limit) public static ANNOptions fromMap(Map<String, String> map) { Integer rerankK = null; + Boolean usePruning = null; for (Map.Entry<String, String> entry : map.entrySet()) { @@ -103,13 +115,17 @@ public static ANNOptions fromMap(Map<String, String> map) { rerankK = parseRerankK(value); } + else if (name.equals(USE_PRUNING_OPTION_NAME)) + { + usePruning = parseUsePruning(value); + } else { throw new InvalidRequestException("Unknown ANN option: " + name); } } - return ANNOptions.create(rerankK); + return ANNOptions.create(rerankK, usePruning); } private static int parseRerankK(String value) @@ -129,9 +145,28 @@ private static int parseRerankK(String value) return rerankK; } + private static boolean parseUsePruning(String value) + { + value = value.toLowerCase(); + if (!value.equals("true") && !value.equals("false")) + throw new InvalidRequestException(String.format("Invalid '%s' ANN option. Expected a boolean but found: %s", + USE_PRUNING_OPTION_NAME, value)); + return Boolean.parseBoolean(value); + } + public String toCQLString() { - return String.format("{'%s': %d}", RERANK_K_OPTION_NAME, rerankK); + StringBuilder sb = new StringBuilder("{"); + if (rerankK != null) + sb.append(String.format("'%s': %d", RERANK_K_OPTION_NAME, rerankK)); + if (usePruning != null) + { + if (rerankK != null) + sb.append(", "); + sb.append(String.format("'%s': %b", USE_PRUNING_OPTION_NAME, usePruning)); + } + sb.append("}"); + return sb.toString(); } @Override @@ -140,13 +175,14 @@ public boolean equals(Object o) if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; ANNOptions that = (ANNOptions) o; - return Objects.equals(rerankK, that.rerankK); + return Objects.equals(rerankK, that.rerankK) && + Objects.equals(usePruning, that.usePruning); } @Override public int hashCode() { - return Objects.hash(rerankK); + return Objects.hash(rerankK, usePruning); } /** @@ -166,9 +202,9 @@ public static class Serializer { /** Bit flags mask to check if the rerank K option is present. */ private static final int RERANK_K_MASK = 1; - + private static final int USE_PRUNING_MASK = 2; /** Bit flags mask to check if there are any unknown options. It's the negation of all the known flags. */
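The presence of each option travels in a leading flags field ahead of the option values; a minimal sketch of the scheme these masks implement (illustrative, using the names above):

    int flags = 0;
    if (rerankK != null)    flags |= RERANK_K_MASK;    // bit 0 (0x1)
    if (usePruning != null) flags |= USE_PRUNING_MASK; // bit 1 (0x2)
    // A node that only knows some of the options can still detect newer ones:
    boolean hasUnknownOptions = (flags & UNKNOWN_OPTIONS_MASK) != 0;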
- private static final int UNKNOWN_OPTIONS_MASK = ~RERANK_K_MASK; + private static final int UNKNOWN_OPTIONS_MASK = ~(RERANK_K_MASK | USE_PRUNING_MASK); /* * If you add a new option, then update ANNOptionsTest.FutureANNOptions and possibly add a new test verifying @@ -190,6 +226,8 @@ public void serialize(ANNOptions options, DataOutputPlus out, int version) throw if (options.rerankK != null) out.writeUnsignedVInt32(options.rerankK); + if (options.usePruning != null) + out.writeBoolean(options.usePruning); } public ANNOptions deserialize(DataInputPlus in, int version) throws IOException @@ -206,8 +244,9 @@ public ANNOptions deserialize(DataInputPlus in, int version) throws IOException "new options that are not supported by this node."); Integer rerankK = hasRerankK(flags) ? (int) in.readUnsignedVInt() : null; + Boolean usePruning = hasUsePruning(flags) ? in.readBoolean() : null; - return ANNOptions.create(rerankK); + return ANNOptions.create(rerankK, usePruning); } public long serializedSize(ANNOptions options, int version) @@ -221,6 +260,8 @@ public long serializedSize(ANNOptions options, int version) if (options.rerankK != null) size += TypeSizes.sizeofUnsignedVInt(options.rerankK); + if (options.usePruning != null) + size += TypeSizes.sizeof(options.usePruning); return size; } @@ -234,6 +275,8 @@ private static int flags(ANNOptions options) if (options.rerankK != null) flags |= RERANK_K_MASK; + if (options.usePruning != null) + flags |= USE_PRUNING_MASK; return flags; } @@ -242,5 +285,10 @@ private static boolean hasRerankK(int flags) { return (flags & RERANK_K_MASK) == RERANK_K_MASK; } + + private static boolean hasUsePruning(int flags) + { + return (flags & USE_PRUNING_MASK) == USE_PRUNING_MASK; + } } } diff --git a/src/java/org/apache/cassandra/dht/Murmur3Partitioner.java b/src/java/org/apache/cassandra/dht/Murmur3Partitioner.java index 9a1e07523f78..ca55cd9c7833 100644 --- a/src/java/org/apache/cassandra/dht/Murmur3Partitioner.java +++ b/src/java/org/apache/cassandra/dht/Murmur3Partitioner.java @@ -376,6 +376,12 @@ public Token fromByteArray(ByteBuffer bytes) return new LongToken(ByteBufferUtil.toLong(bytes)); } + @Override + public Token fromLongValue(long token) + { + return new LongToken(token); + } + @Override public Token fromByteBuffer(ByteBuffer bytes, int position, int length) { diff --git a/src/java/org/apache/cassandra/dht/Token.java b/src/java/org/apache/cassandra/dht/Token.java index fda7f307513d..0fcd0d1c7be1 100644 --- a/src/java/org/apache/cassandra/dht/Token.java +++ b/src/java/org/apache/cassandra/dht/Token.java @@ -40,6 +40,20 @@ public static abstract class TokenFactory public abstract ByteBuffer toByteArray(Token token); public abstract Token fromByteArray(ByteBuffer bytes); + /** + * This method exists so that callers can create tokens from the primitive {@code long} value for this {@link Token}, if + * one exists. It is especially useful to skip ByteBuffer serde operations where performance is critical. + * + * @param token the primitive {@code long} value of this token + * @return the {@link Token} instance corresponding to the given primitive {@code long} value + * + * @throws UnsupportedOperationException if this {@link Token} is not backed by a primitive {@code long} value + */ + public Token fromLongValue(long token) + { + throw new UnsupportedOperationException(); + } +
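A sketch of the allocation this default method lets long-backed partitioners skip (illustrative; the old code path appears in the PartitionAwarePrimaryKeyMap hunk further down):

    Token.TokenFactory factory = Murmur3Partitioner.instance.getTokenFactory();
    // Old path: wrap the long in a ByteBuffer only to decode it again
    ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES);
    buffer.putLong(42L);
    buffer.rewind();
    Token viaBytes = factory.fromByteArray(buffer);
    // New path: construct the token directly, with no intermediate buffer
    Token viaLong = factory.fromLongValue(42L); // equals viaBytes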
/** * Produce a byte-comparable representation of the token. * See {@link Token#asComparableBytes} diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java b/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java index 266100a76789..58e4cf3367d6 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/MemtableIndexWriter.java @@ -146,6 +146,7 @@ public void complete(Stopwatch stopwatch) throws IOException private long flush(DecoratedKey minKey, DecoratedKey maxKey, AbstractType<?> termComparator, MemtableTermsIterator terms, int maxSegmentRowId) throws IOException { + long numPostings; long numRows; SegmentMetadataBuilder metadataBuilder = new SegmentMetadataBuilder(0, perIndexComponents); SegmentMetadata.ComponentMetadataMap indexMetas; @@ -166,7 +167,8 @@ private long flush(DecoratedKey minKey, DecoratedKey maxKey, AbstractType<?> ter ); indexMetas = writer.writeAll(metadataBuilder.intercept(terms), docLengths); - numRows = writer.getPostingsCount(); + numPostings = writer.getPostingsCount(); + numRows = docLengths.size(); } } else @@ -180,18 +182,20 @@ private long flush(DecoratedKey minKey, DecoratedKey maxKey, AbstractType<?> ter { ImmutableOneDimPointValues values = ImmutableOneDimPointValues.fromTermEnum(terms, termComparator); indexMetas = writer.writeAll(metadataBuilder.intercept(values)); - numRows = writer.getPointCount(); + numPostings = writer.getPointCount(); + numRows = numPostings; } } // If no rows were written we need to delete any created column index components // so that the index is correctly identified as being empty (only having a completion marker) - if (numRows == 0) + if (numPostings == 0) { perIndexComponents.forceDeleteAllComponents(); return 0; } + metadataBuilder.setNumRows(numRows); metadataBuilder.setKeyRange(pkFactory.createPartitionKeyOnly(minKey), pkFactory.createPartitionKeyOnly(maxKey)); metadataBuilder.setRowIdRange(terms.getMinSSTableRowId(), terms.getMaxSSTableRowId()); metadataBuilder.setTermRange(terms.getMinTerm(), terms.getMaxTerm()); @@ -203,7 +207,7 @@ private long flush(DecoratedKey minKey, DecoratedKey maxKey, AbstractType<?> ter SegmentMetadata.write(writer, Collections.singletonList(metadata)); } - return numRows; + return numPostings; } private boolean writeFrequencies()
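The split between the two counters matters once a single row can emit several index entries, e.g. a collection column or an analyzed text column. A hypothetical illustration of the counting above:

    // One flushed row whose indexed list column holds three values:
    //   INSERT INTO ks.t (pk, v) VALUES (0, ['a', 'b', 'c'])
    // numPostings += 3  (index entries written; decides whether the segment is empty)
    // numRows     += 1  (rows covered; recorded via metadataBuilder.setNumRows)
    // Before this change the row count reported for the segment was the postings count.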
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyMap.java b/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyMap.java index ae09fc0e52d0..51d337ea8ea8 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyMap.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/PartitionAwarePrimaryKeyMap.java @@ -19,7 +19,6 @@ package org.apache.cassandra.index.sai.disk.v1; import java.io.IOException; -import java.nio.ByteBuffer; import javax.annotation.concurrent.NotThreadSafe; import javax.annotation.concurrent.ThreadSafe; @@ -142,7 +141,6 @@ public void close() throws IOException private final IKeyFetcher keyFetcher; private final PrimaryKey.Factory primaryKeyFactory; private final SSTableId sstableId; - private final ByteBuffer tokenBuffer = ByteBuffer.allocate(Long.BYTES); private PartitionAwarePrimaryKeyMap(LongArray rowIdToToken, LongArray rowIdToOffset, @@ -168,9 +166,8 @@ public SSTableId getSSTableId() @Override public PrimaryKey primaryKeyFromRowId(long sstableRowId) { - tokenBuffer.putLong(rowIdToToken.get(sstableRowId)); - tokenBuffer.rewind(); - return primaryKeyFactory.createDeferred(partitioner.getTokenFactory().fromByteArray(tokenBuffer), () -> supplier(sstableRowId)); + long token = rowIdToToken.get(sstableRowId); + return primaryKeyFactory.createDeferred(partitioner.getTokenFactory().fromLongValue(token), () -> supplier(sstableRowId)); } @Override diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/SSTableIndexWriter.java b/src/java/org/apache/cassandra/index/sai/disk/v1/SSTableIndexWriter.java index 5884b49501d6..077ce75d6ab9 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/SSTableIndexWriter.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/SSTableIndexWriter.java @@ -113,6 +113,7 @@ public void addRow(PrimaryKey key, Row row, long sstableRowId) throws IOExceptio if (indexContext.getDefinition().isStatic() != row.isStatic()) return; + boolean addedRow = false; if (indexContext.isNonFrozenCollection()) { Iterator<ByteBuffer> valueIterator = indexContext.getValuesOf(row, nowInSec); if (valueIterator != null) { while (valueIterator.hasNext()) { ByteBuffer value = valueIterator.next(); - addTerm(TypeUtil.asIndexBytes(value.duplicate(), indexContext.getValidator()), key, sstableRowId, indexContext.getValidator()); + addedRow |= addTerm(TypeUtil.asIndexBytes(value.duplicate(), indexContext.getValidator()), key, sstableRowId, indexContext.getValidator()); } } } @@ -129,8 +130,13 @@ { ByteBuffer value = indexContext.getValueOf(key.partitionKey(), row, nowInSec); if (value != null) - addTerm(TypeUtil.asIndexBytes(value.duplicate(), indexContext.getValidator()), key, sstableRowId, indexContext.getValidator()); + { + addedRow = addTerm(TypeUtil.asIndexBytes(value.duplicate(), indexContext.getValidator()), key, sstableRowId, indexContext.getValidator()); + } } + if (addedRow) + currentBuilder.incRowCount(); + } @Override @@ -225,10 +231,10 @@ private boolean maybeAbort() return true; } - private void addTerm(ByteBuffer term, PrimaryKey key, long sstableRowId, AbstractType<?> type) throws IOException + private boolean addTerm(ByteBuffer term, PrimaryKey key, long sstableRowId, AbstractType<?> type) throws IOException { if (!indexContext.validateMaxTermSize(key.partitionKey(), term)) - return; + return false; if (currentBuilder == null) { @@ -241,10 +247,11 @@ else if (shouldFlush(sstableRowId)) } if (term.remaining() == 0 && TypeUtil.skipsEmptyValue(indexContext.getValidator())) - return; + return false; long allocated = currentBuilder.analyzeAndAdd(term, type, key, sstableRowId); limiter.increment(allocated); + return true; } private boolean shouldFlush(long sstableRowId) diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java index 70997ccedcd1..701112a222d6 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java @@ -464,6 +464,7 @@ public SegmentMetadata flush() throws IOException metadataBuilder.setKeyRange(minKey, maxKey); metadataBuilder.setRowIdRange(minSSTableRowId, maxSSTableRowId); metadataBuilder.setTermRange(minTerm, maxTerm); + metadataBuilder.setNumRows(getRowCount()); flushInternal(metadataBuilder); return metadataBuilder.build(); @@ -502,8 +503,6 @@ private long add(List<ByteBuffer> terms, PrimaryKey key, long sstableRowId) maxTerm = TypeUtil.max(term, maxTerm, termComparator, 
Version.latest()); } - rowCount++; - // segmentRowIdOffset should encode sstableRowId into Integer int segmentRowId = Math.toIntExact(sstableRowId - segmentRowIdOffset); @@ -600,6 +599,11 @@ int getRowCount() return rowCount; } + void incRowCount() + { + rowCount++; + } + /** * @return true if next SSTable row ID exceeds max segment row ID */ diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadataBuilder.java b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadataBuilder.java index 26326fc91b6a..2ace29728086 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadataBuilder.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadataBuilder.java @@ -86,6 +86,11 @@ public SegmentMetadataBuilder(long segmentRowIdOffset, IndexComponents.ForWrite this.termsDistributionBuilder = new TermsDistribution.Builder(context.getValidator(), byteComparableVersion, histogramSize, mostFrequentTermsCount); } + public void setNumRows(long numRows) + { + this.numRows = numRows; + } + public void setKeyRange(@Nonnull PrimaryKey minKey, @Nonnull PrimaryKey maxKey) { assert minKey.compareTo(maxKey) <= 0: "minKey (" + minKey + ") must not be greater than (" + maxKey + ')'; @@ -129,7 +134,6 @@ void add(ByteComparable term, int rowCount) if (built) throw new IllegalStateException("Segment metadata already built, no more additions allowed"); - numRows += rowCount; termsDistributionBuilder.add(term, rowCount); } @@ -360,7 +364,7 @@ public void intersect(IntersectVisitor visitor) throws IOException } @Override - public void close() throws IOException + public void close() { if (lastTerm != null) { @@ -371,5 +375,3 @@ public void close() throws IOException } } - - diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyMap.java b/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyMap.java index 2f1681abef34..220aa388ed60 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyMap.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyMap.java @@ -152,7 +152,6 @@ public void close() throws IOException private final PrimaryKey.Factory primaryKeyFactory; private final ClusteringComparator clusteringComparator; private final SSTableId sstableId; - private final ByteBuffer tokenBuffer = ByteBuffer.allocate(Long.BYTES); private RowAwarePrimaryKeyMap(LongArray rowIdToToken, SortedTermsReader sortedTermsReader, @@ -185,9 +184,8 @@ public long count() @Override public PrimaryKey primaryKeyFromRowId(long sstableRowId) { - tokenBuffer.putLong(rowIdToToken.get(sstableRowId)); - tokenBuffer.rewind(); - return primaryKeyFactory.createDeferred(partitioner.getTokenFactory().fromByteArray(tokenBuffer), () -> supplier(sstableRowId)); + long token = rowIdToToken.get(sstableRowId); + return primaryKeyFactory.createDeferred(partitioner.getTokenFactory().fromLongValue(token), () -> supplier(sstableRowId)); } private long skinnyExactRowIdOrInvertedCeiling(PrimaryKey key) diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java index 0af96fd7f268..4e654c78edd9 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java @@ -153,7 +153,8 @@ private PostingList searchPosting(QueryContext context, Expression exp, Abstract var queryVector = vts.createFloatVector(exp.lower.value.vector); // 
this is a thresholded query, so pass graph.size() as top k to get all results satisfying the threshold - var result = searchInternal(keyRange, context, queryVector, graph.size(), graph.size(), exp.getEuclideanSearchThreshold()); + // Threshold queries do not use pruning. + var result = searchInternal(keyRange, context, queryVector, graph.size(), graph.size(), exp.getEuclideanSearchThreshold(), false); return new ReorderingPostingList(result, RowIdWithMeta::getSegmentRowId); } @@ -169,7 +170,7 @@ public CloseableIterator orderBy(Orderer orderer, Express int rerankK = orderer.rerankKFor(limit, graph.getCompression()); var queryVector = vts.createFloatVector(orderer.getVectorTerm()); - var result = searchInternal(keyRange, context, queryVector, limit, rerankK, 0); + var result = searchInternal(keyRange, context, queryVector, limit, rerankK, 0, orderer.usePruning()); return toMetaSortedIterator(result, context); } @@ -183,13 +184,15 @@ public CloseableIterator orderBy(Orderer orderer, Express * @param rerankK the amplified limit for the query to get more accurate results * @param threshold the threshold for the query. When the threshold is greater than 0 and brute force logic is used, * the results will be filtered by the threshold. + * @param usePruning whether to use pruning to speed up the ANN search */ private CloseableIterator searchInternal(AbstractBounds keyRange, QueryContext context, VectorFloat queryVector, int limit, int rerankK, - float threshold) throws IOException + float threshold, + boolean usePruning) throws IOException { try (PrimaryKeyMap primaryKeyMap = primaryKeyMapFactory.newPerSSTablePrimaryKeyMap()) { @@ -197,7 +200,7 @@ private CloseableIterator searchInternal(AbstractBounds searchInternal(AbstractBounds= metadata.maxSSTableRowId) - return graph.search(queryVector, limit, rerankK, threshold, Bits.ALL, context, visited -> {}); + return graph.search(queryVector, limit, rerankK, threshold, usePruning, Bits.ALL, context, visited -> {}); minSSTableRowId = Math.max(minSSTableRowId, metadata.minSSTableRowId); maxSSTableRowId = min(maxSSTableRowId, metadata.maxSSTableRowId); @@ -263,12 +266,14 @@ private CloseableIterator searchInternal(AbstractBounds orderByBruteForce(VectorFloat queryVector, SegmentRowIdOrdinalPairs segmentOrdinalPairs, int limit, int rerankK) throws IOException { + // We allow for negative rerankK, but for our cost calculations, it only makes sense to use 0 here. + rerankK = Math.max(0, rerankK); // If we use compressed vectors, we still have to order rerankK results using full resolution similarity // scores, so only use the compressed vectors when there are enough vectors to make it worthwhile. double twoPassCost = segmentOrdinalPairs.size() * CostCoefficients.ANN_SIMILARITY_COST @@ -288,21 +293,28 @@ private CloseableIterator orderByBruteForce(CompressedVectors cv VectorFloat queryVector, SegmentRowIdOrdinalPairs segmentOrdinalPairs, int limit, - int rerankK) throws IOException + int rerankK) { // Use the jvector NodeQueue to avoid unnecessary object allocations since this part of the code operates on // many rows. 
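    // Shape of this brute-force path, for orientation (illustrative note, using the names below):
    //   1. heapify an approximate (compressed-vector) score for every candidate — pushMany is O(n),
    //      where n individual push() calls would cost O(n log n);
    //   2. pop candidates best-first and rescore the top rerankK with full-resolution vectors;
    //   3. emit the best `limit` rescored rows (BruteForceRowIdIterator), or, when rerankK <= 0,
    //      return the approximate ordering directly (NodeQueueRowIdIterator, rerankless).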
var approximateScores = new NodeQueue(new BoundedLongHeap(segmentOrdinalPairs.size()), NodeQueue.Order.MAX_HEAP); var similarityFunction = indexContext.getIndexWriterConfig().getSimilarityFunction(); var scoreFunction = cv.precomputedScoreFunctionFor(queryVector, similarityFunction); + columnQueryMetrics.onBruteForceNodesVisited(segmentOrdinalPairs.size()); + + if (rerankK <= 0) + { + // Rerankless search, so we go straight to the NodeQueueRowIdIterator. + var iter = segmentOrdinalPairs.mapToSegmentRowIdScoreIterator(scoreFunction); + approximateScores.pushMany(iter, segmentOrdinalPairs.size()); + return new NodeQueueRowIdIterator(approximateScores); + } // Store the index of the (rowId, ordinal) pair from the segmentOrdinalPairs in the NodeQueue so that we can // retrieve both values with O(1) lookup when we need to resolve the full resolution score in the // BruteForceRowIdIterator. - segmentOrdinalPairs.forEachIndexOrdinalPair((i, ordinal) -> { - approximateScores.push(i, scoreFunction.similarityTo(ordinal)); - }); - columnQueryMetrics.onBruteForceNodesVisited(segmentOrdinalPairs.size()); + var iter = segmentOrdinalPairs.mapToIndexScoreIterator(scoreFunction); + approximateScores.pushMany(iter, segmentOrdinalPairs.size()); var reranker = new CloseableReranker(similarityFunction, queryVector, graph.getView()); return new BruteForceRowIdIterator(approximateScores, segmentOrdinalPairs, reranker, limit, rerankK, columnQueryMetrics); } @@ -320,9 +332,8 @@ private CloseableIterator orderByBruteForce(VectorFloat query var similarityFunction = indexContext.getIndexWriterConfig().getSimilarityFunction(); var esf = vectorsView.rerankerFor(queryVector, similarityFunction); // Because the scores are exact, we only store the rowid, score pair. - segmentOrdinalPairs.forEachSegmentRowIdOrdinalPair((segmentRowId, ordinal) -> { - scoredRowIds.push(segmentRowId, esf.similarityTo(ordinal)); - }); + var iter = segmentOrdinalPairs.mapToSegmentRowIdScoreIterator(esf); + scoredRowIds.pushMany(iter, segmentOrdinalPairs.size()); columnQueryMetrics.onBruteForceNodesReranked(segmentOrdinalPairs.size()); return new NodeQueueRowIdIterator(scoredRowIds); } @@ -495,7 +506,7 @@ public CloseableIterator orderResultsBy(SSTableReader rea segmentOrdinalPairs.forEachOrdinal(bits::set); // else ask the index to perform a search limited to the bits we created var queryVector = vts.createFloatVector(orderer.getVectorTerm()); - var results = graph.search(queryVector, limit, rerankK, 0, bits, context, cost::updateStatistics); + var results = graph.search(queryVector, limit, rerankK, 0, orderer.usePruning(), bits, context, cost::updateStatistics); return toMetaSortedIterator(results, context); } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v3/V3OnDiskFormat.java b/src/java/org/apache/cassandra/index/sai/disk/v3/V3OnDiskFormat.java index 05789de225e7..f229b0c200a6 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v3/V3OnDiskFormat.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v3/V3OnDiskFormat.java @@ -36,13 +36,7 @@ import org.apache.cassandra.index.sai.disk.v1.SegmentMetadata; import org.apache.cassandra.index.sai.disk.v2.V2OnDiskFormat; -import static org.apache.cassandra.config.CassandraRelevantProperties.SAI_ENABLE_EDGES_CACHE; -import static org.apache.cassandra.config.CassandraRelevantProperties.SAI_ENABLE_JVECTOR_DELETES; -import static org.apache.cassandra.config.CassandraRelevantProperties.SAI_ENABLE_LTM_CONSTRUCTION; -import static 
org.apache.cassandra.config.CassandraRelevantProperties.SAI_ENABLE_RERANK_FLOOR; -import static org.apache.cassandra.config.CassandraRelevantProperties.SAI_JVECTOR_VERSION; -import static org.apache.cassandra.config.CassandraRelevantProperties.SAI_REDUCE_TOPK_ACROSS_SSTABLES; -import static org.apache.cassandra.config.CassandraRelevantProperties.SAI_WRITE_JVECTOR3_FORMAT; +import static org.apache.cassandra.config.CassandraRelevantProperties.*; /** * Different vector components compared to V2OnDiskFormat (supporting DiskANN/jvector instead of HNSW/lucene). @@ -56,6 +50,8 @@ public class V3OnDiskFormat extends V2OnDiskFormat public static volatile boolean WRITE_JVECTOR3_FORMAT = SAI_WRITE_JVECTOR3_FORMAT.getBoolean(); public static final boolean ENABLE_LTM_CONSTRUCTION = SAI_ENABLE_LTM_CONSTRUCTION.getBoolean(); + // JVector doesn't give us a way to access its default, so we set it here, but allow it to be overridden. + public static boolean JVECTOR_USE_PRUNING_DEFAULT = SAI_VECTOR_USE_PRUNING_DEFAULT.getBoolean(); // These are built to be backwards and forwards compatible. Not final only for testing. public static int JVECTOR_VERSION = SAI_JVECTOR_VERSION.getInt(); diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIterator.java b/src/java/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIterator.java index e2b2aa22b93f..392341fb1298 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIterator.java +++ b/src/java/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIterator.java @@ -26,12 +26,10 @@ import org.apache.cassandra.index.sai.utils.RowIdWithScore; import org.apache.cassandra.io.util.FileUtils; import org.apache.cassandra.utils.AbstractIterator; -import org.apache.cassandra.utils.SortingIterator; /** - * An iterator over {@link RowIdWithMeta} that lazily consumes from a {@link SortingIterator} of - * {@link RowWithApproximateScore}. + * An iterator over {@link RowIdWithMeta} that lazily consumes from a {@link NodeQueue} of approximate scores. *

* The idea is that we maintain the same level of accuracy as we would get from a graph search, by re-ranking the top * `k` best approximate scores at a time with the full resolution vectors to return the top `limit`. diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraDiskAnn.java b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraDiskAnn.java index 748cf48e582b..40d4d363faa0 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraDiskAnn.java +++ b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraDiskAnn.java @@ -200,8 +200,11 @@ public int size() /** * @param queryVector the query vector * @param limit the number of results to return from the search - * @param rerankK the number of results to look for in the index (>= limit) + * @param rerankK the number of quantized results to look for in the index (>= limit or <= 0). If rerankK is + * non-positive, then we will use limit as the value and will skip reranking. Rerankless search + * only applies when the graph has compressed vectors. * @param threshold the minimum similarity score to accept + * @param usePruning whether to use pruning to speed up the search * @param acceptBits a Bits indicating which row IDs are acceptable, or null if no constraints * @param context unused (vestige from HNSW, retained in signature to allow calling both easily) * @param nodesVisitedConsumer a consumer that will be called with the number of nodes visited during the search @@ -212,14 +215,21 @@ public CloseableIterator search(VectorFloat queryVector, int limit, int rerankK, float threshold, + boolean usePruning, Bits acceptBits, QueryContext context, IntConsumer nodesVisitedConsumer) { VectorValidation.validateIndexable(queryVector, similarityFunction); + boolean isRerankless = rerankK <= 0; + if (isRerankless) + rerankK = limit; var graphAccessManager = searchers.get(); var searcher = graphAccessManager.get(); + // This searcher is reused across searches. We set it here every time to ensure it is configured correctly + // for this search. Note that resume search in AutoResumingNodeScoreIterator will continue to use this setting. + searcher.usePruning(usePruning); try { var view = (GraphIndex.ScoringView) searcher.getView(); @@ -229,11 +239,12 @@ public CloseableIterator search(VectorFloat queryVector, if (features.contains(FeatureId.FUSED_ADC)) { var asf = view.approximateScoreFunctionFor(queryVector, similarityFunction); - var rr = view.rerankerFor(queryVector, similarityFunction); + var rr = isRerankless ? null : view.rerankerFor(queryVector, similarityFunction); ssp = new SearchScoreProvider(asf, rr); } else if (compressedVectors == null) { + // no compression, so we ignore isRerankless (except for setting rerankK to limit) ssp = new SearchScoreProvider(view.rerankerFor(queryVector, similarityFunction)); } else @@ -244,7 +255,7 @@ else if (compressedVectors == null) ? VectorSimilarityFunction.COSINE : similarityFunction; var asf = compressedVectors.precomputedScoreFunctionFor(queryVector, sf); - var rr = view.rerankerFor(queryVector, similarityFunction); + var rr = isRerankless ? null : view.rerankerFor(queryVector, similarityFunction); ssp = new SearchScoreProvider(asf, rr); } long start = nanoTime(); @@ -252,8 +263,8 @@ else if (compressedVectors == null) long elapsed = nanoTime() - start; if (V3OnDiskFormat.ENABLE_RERANK_FLOOR) context.updateAnnRerankFloor(result.getWorstApproximateScoreInTopK()); - Tracing.trace("DiskANN search for {}/{} visited {} nodes, reranked {} to return {} results from {}", - limit, rerankK, result.getVisitedCount(), result.getRerankedCount(), result.getNodes().length, source); + Tracing.trace("DiskANN search for {}/{} rerankless={}, usePruning={} visited {} nodes, reranked {} to return {} results from {}", + limit, rerankK, isRerankless, usePruning, result.getVisitedCount(), result.getRerankedCount(), result.getNodes().length, source); columnQueryMetrics.onSearchResult(result, elapsed, false); context.addAnnGraphSearchLatency(elapsed); if (threshold > 0)
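Condensed, the provider selection above reduces to the following sketch (illustrative, using the jvector types already referenced in this file; "rerankless" simply means the second-pass reranker is omitted):

    boolean isRerankless = rerankK <= 0;          // e.g. WITH ann_options = {'rerank_k': 0}
    var asf = compressedVectors.precomputedScoreFunctionFor(queryVector, similarityFunction);
    var rr  = isRerankless ? null : view.rerankerFor(queryVector, similarityFunction);
    var ssp = new SearchScoreProvider(asf, rr);   // null reranker: results keep approximate scores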
diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java index 4744f5fc363d..b65764d149ce 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java +++ b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java @@ -322,25 +322,30 @@ public void remove(ByteBuffer term, T key) /** * @return an iterator over {@link PrimaryKeyWithSortKey} in the graph's {@link SearchResult} order */ - public CloseableIterator search(QueryContext context, VectorFloat queryVector, int limit, int rerankK, float threshold, Bits toAccept) + public CloseableIterator search(QueryContext context, VectorFloat queryVector, int limit, int rerankK, float threshold, boolean usePruning, Bits toAccept) { VectorValidation.validateIndexable(queryVector, similarityFunction); // search() errors out when an empty graph is passed to it if (vectorValues.size() == 0) return CloseableIterator.emptyIterator(); + // This configuration indicates rerankless search, but that is only applicable to disk search, so we set + // rerankK to limit and otherwise ignore the setting. + if (rerankK <= 0) + rerankK = limit; Bits bits = hasDeletions ? 
BitsUtil.bitsIgnoringDeleted(toAccept, postingsByOrdinal) : toAccept; var graphAccessManager = searchers.get(); var searcher = graphAccessManager.get(); + searcher.usePruning(usePruning); try { var ssf = SearchScoreProvider.exact(queryVector, similarityFunction, vectorValues); long start = nanoTime(); var result = searcher.search(ssf, limit, rerankK, threshold, 0.0f, bits); long elapsed = nanoTime() - start; - Tracing.trace("ANN search for {}/{} visited {} nodes, reranked {} to return {} results from {}", - limit, rerankK, result.getVisitedCount(), result.getRerankedCount(), result.getNodes().length, source); + Tracing.trace("ANN search for {}/{} (usePruning: {}) visited {} nodes, reranked {} to return {} results from {}", + limit, rerankK, usePruning, result.getVisitedCount(), result.getRerankedCount(), result.getNodes().length, source); columnQueryMetrics.onSearchResult(result, elapsed, false); context.addAnnGraphSearchLatency(elapsed); if (threshold > 0) diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java index b7e2ecc234d7..293fd0946fc0 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java @@ -203,7 +203,8 @@ public KeyRangeIterator search(QueryContext context, Expression expr, AbstractBo float threshold = expr.getEuclideanSearchThreshold(); SortingIterator.Builder keyQueue; - try (var pkIterator = searchInternal(context, qv, keyRange, graph.size(), graph.size(), threshold)) + // Threshold queries do not use pruning. + try (var pkIterator = searchInternal(context, qv, keyRange, graph.size(), graph.size(), threshold, false)) { keyQueue = new SortingIterator.Builder<>(); while (pkIterator.hasNext()) @@ -235,15 +236,16 @@ public List> orderBy(QueryContext conte var qv = vts.createFloatVector(orderer.getVectorTerm()); var rerankK = orderer.rerankKFor(limit, VectorCompression.NO_COMPRESSION); - return List.of(searchInternal(context, qv, keyRange, limit, rerankK, 0)); + return List.of(searchInternal(context, qv, keyRange, limit, rerankK, 0, orderer.usePruning())); } private CloseableIterator searchInternal(QueryContext context, - VectorFloat queryVector, - AbstractBounds keyRange, - int limit, - int rerankK, - float threshold) + VectorFloat queryVector, + AbstractBounds keyRange, + int limit, + int rerankK, + float threshold, + boolean usePruning) { Bits bits; if (RangeUtil.coversFullRing(keyRange)) @@ -288,7 +290,7 @@ private CloseableIterator searchInternal(QueryContext con } } - var nodeScoreIterator = graph.search(context, queryVector, limit, rerankK, threshold, bits); + var nodeScoreIterator = graph.search(context, queryVector, limit, rerankK, threshold, usePruning, bits); return new NodeScoreToScoredPrimaryKeyIterator(nodeScoreIterator); } @@ -336,7 +338,7 @@ public CloseableIterator orderResultsBy(QueryContext cont return orderByBruteForce(qv, keysInGraph); } // indexed path - var nodeScoreIterator = graph.search(context, qv, limit, rerankK, 0, relevantOrdinals::contains); + var nodeScoreIterator = graph.search(context, qv, limit, rerankK, 0, orderer.usePruning(), relevantOrdinals::contains); return new NodeScoreToScoredPrimaryKeyIterator(nodeScoreIterator); } diff --git a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java index 7fd70502ce4f..21e746d28a52 100644 --- 
a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java @@ -30,6 +30,7 @@ import javax.annotation.Nullable; import org.apache.cassandra.cql3.Operator; +import org.apache.cassandra.db.filter.ANNOptions; import org.apache.cassandra.db.filter.RowFilter; import org.apache.cassandra.index.SecondaryIndexManager; import org.apache.cassandra.index.sai.IndexContext; @@ -38,6 +39,8 @@ import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; import org.apache.cassandra.index.sai.utils.TypeUtil; +import static org.apache.cassandra.index.sai.disk.v3.V3OnDiskFormat.JVECTOR_USE_PRUNING_DEFAULT; + /** * An SAI Orderer represents an index based order by clause. */ @@ -55,7 +58,7 @@ public class Orderer // Vector search parameters private float[] vector; - private final Integer rerankK; + private final ANNOptions annOptions; // BM25 search parameter private List<ByteBuffer> queryTerms; @@ -65,14 +68,14 @@ public class Orderer * @param context the index context, used to build the view of memtables and sstables for query execution. * @param operator the operator for the order by clause. * @param term the term to order by (not always relevant) - * @param rerankK optional rerank K parameter for ANN queries + * @param annOptions optional options for ANN queries */ - public Orderer(IndexContext context, Operator operator, ByteBuffer term, @Nullable Integer rerankK) + public Orderer(IndexContext context, Operator operator, ByteBuffer term, ANNOptions annOptions) { this.context = context; assert ORDER_BY_OPERATORS.contains(operator) : "Invalid operator for order by clause " + operator; this.operator = operator; - this.rerankK = rerankK; + this.annOptions = annOptions; this.term = term; } @@ -115,11 +118,23 @@ public boolean isANN() public int rerankKFor(int limit, VectorCompression vc) { assert isANN() : "rerankK is only valid for ANN queries"; - return rerankK != null - ? rerankK + return annOptions.rerankK != null + ? annOptions.rerankK : context.getIndexWriterConfig().getSourceModel().rerankKFor(limit, vc); } + /** + * Whether to use pruning to speed up the ANN search. If the ANN options do not specify a value for usePruning, + * we use the default value, which is configured via the cassandra.sai.jvector.use_pruning_default system property. + * + * @return the usePruning value to use in ANN search + */ + public boolean usePruning() + { + assert isANN() : "usePruning is only valid for ANN queries"; + return annOptions.usePruning != null ? annOptions.usePruning : JVECTOR_USE_PRUNING_DEFAULT; + } + public boolean isBM25() { return operator == Operator.BM25; @@ -135,9 +150,7 @@ public static Orderer from(SecondaryIndexManager indexManager, RowFilter filter) var index = indexManager.getBestIndexFor(orderExpression, StorageAttachedIndex.class) .orElseThrow(() -> new IllegalStateException("No index found for order by clause")); - // Null if not specified explicitly in the CQL query. - Integer rerankK = filter.annOptions().rerankK; - return new Orderer(index.getIndexContext(), orderExpression.operator(), orderExpression.getIndexValue(), rerankK); + return new Orderer(index.getIndexContext(), orderExpression.operator(), orderExpression.getIndexValue(), filter.annOptions()); } public static boolean isFilterExpressionOrderer(RowFilter.Expression expression) @@ -149,9 +162,9 @@ public static boolean isFilterExpressionOrderer(RowFilter.Expression expression) public String toString() { String direction = isAscending() ? "ASC" : "DESC"; - String rerankInfo = rerankK != null ? 
String.format(" (rerank_k=%d)", rerankK) : ""; + String annOptionsString = annOptions != null ? annOptions.toCQLString() : ""; if (isANN()) - return context.getColumnName() + " ANN OF " + Arrays.toString(getVectorTerm()) + ' ' + direction + rerankInfo; + return context.getColumnName() + " ANN OF " + Arrays.toString(getVectorTerm()) + ' ' + direction + annOptionsString; if (isBM25()) return context.getColumnName() + " BM25 OF " + TypeUtil.getString(term, context.getValidator()) + ' ' + direction; return context.getColumnName() + ' ' + direction; diff --git a/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java index ba8bfbcbd0ad..29039affe662 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java +++ b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java @@ -170,32 +170,55 @@ public static CloseableIterator computeScores(CloseableIt // Calculate average document length double avgDocLength = totalTermCount / documents.size(); - // Calculate BM25 scores. Uses a nodequeue that avoids additional allocations and has heap time complexity + // Calculate BM25 scores. + // Uses a NodeQueue that avoids allocating an object for each document. var nodeQueue = new NodeQueue(new BoundedLongHeap(documents.size()), NodeQueue.Order.MAX_HEAP); - for (int i = 0; i < documents.size(); i++) - { - var doc = documents.get(i); - double score = 0.0; - for (var queryTerm : queryTerms) - { - int tf = doc.getTermFrequency(queryTerm); - Long df = docStats.frequencies.get(queryTerm); - // we shouldn't have more hits for a term than we counted total documents - assert df <= docStats.docCount : String.format("df=%d, totalDocs=%d", df, docStats.docCount); - - double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount() / avgDocLength)); - double idf = Math.log(1 + (docStats.docCount - df + 0.5) / (df + 0.5)); - double deltaScore = normalizedTf * idf; - assert deltaScore >= 0 : String.format("BM25 score for tf=%d, df=%d, tc=%d, totalDocs=%d is %f", - tf, df, doc.termCount(), docStats.docCount, deltaScore); - score += deltaScore; + // Create an anonymous NodeScoreIterator that holds the logic for computing BM25 + var iter = new NodeQueue.NodeScoreIterator() { + int current = 0; + + @Override + public boolean hasNext() { + return current < documents.size(); } - nodeQueue.push(i, (float) score); - } + + @Override + public int pop() { + return current++; + } + + @Override + public float topScore() { + // Compute BM25 for the current document + return scoreDoc(documents.get(current), docStats, queryTerms, avgDocLength); + } + }; + // pushMany is an O(n) operation where n is the final size of the queue. Iterative calls to push is O(n log n). 
+ nodeQueue.pushMany(iter, documents.size()); return new NodeQueueDocTFIterator(nodeQueue, documents, indexContext, source, docIterator); } + private static float scoreDoc(DocTF doc, DocStats docStats, List queryTerms, double avgDocLength) + { + double score = 0.0; + for (var queryTerm : queryTerms) + { + int tf = doc.getTermFrequency(queryTerm); + Long df = docStats.frequencies.get(queryTerm); + // we shouldn't have more hits for a term than we counted total documents + assert df <= docStats.docCount : String.format("df=%d, totalDocs=%d", df, docStats.docCount); + + double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount() / avgDocLength)); + double idf = Math.log(1 + (docStats.docCount - df + 0.5) / (df + 0.5)); + double deltaScore = normalizedTf * idf; + assert deltaScore >= 0 : String.format("BM25 score for tf=%d, df=%d, tc=%d, totalDocs=%d is %f", + tf, df, doc.termCount(), docStats.docCount, deltaScore); + score += deltaScore; + } + return (float) score; + } + private static class NodeQueueDocTFIterator extends AbstractIterator { private final NodeQueue nodeQueue; diff --git a/src/java/org/apache/cassandra/index/sai/utils/SegmentRowIdOrdinalPairs.java b/src/java/org/apache/cassandra/index/sai/utils/SegmentRowIdOrdinalPairs.java index 0f87fc279185..0c47de40e181 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/SegmentRowIdOrdinalPairs.java +++ b/src/java/org/apache/cassandra/index/sai/utils/SegmentRowIdOrdinalPairs.java @@ -20,6 +20,8 @@ import java.util.function.IntConsumer; +import io.github.jbellis.jvector.graph.NodeQueue; +import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import org.agrona.collections.IntIntConsumer; /** @@ -33,7 +35,7 @@ public class SegmentRowIdOrdinalPairs private final int[] array; /** - * Create a new IntIntPairArray with the given capacity. + * Create a new SegmentRowIdOrdinalPairs with the given capacity. * @param capacity the capacity */ public SegmentRowIdOrdinalPairs(int capacity) @@ -102,15 +104,53 @@ public void forEachSegmentRowIdOrdinalPair(IntIntConsumer consumer) } /** - * Iterate over the pairs in the array, calling the consumer for each pair passing (index, x, y). - * @param consumer the consumer to call for each pair + * Create an iterator over the segment row id and scored ordinal pairs in the array. + * @param scoreFunction the score function to use to compute the next score based on the ordinal */ - public void forEachIndexOrdinalPair(IntIntConsumer consumer) + public NodeQueue.NodeScoreIterator mapToSegmentRowIdScoreIterator(ScoreFunction scoreFunction) { - for (int i = 0; i < size; i++) - consumer.accept(i, array[i * 2 + 1]); + return mapToScoreIterator(scoreFunction, false); + } + + /** + * Create an iterator over the index and scored ordinal pairs in the array. + * @param scoreFunction the score function to use to compute the next score based on the ordinal + */ + public NodeQueue.NodeScoreIterator mapToIndexScoreIterator(ScoreFunction scoreFunction) + { + return mapToScoreIterator(scoreFunction, true); } + /** + * Create an iterator over the index or the segment row id and the score for the ordinal. 
+ * @param scoreFunction the score function to use to compute the next score based on the ordinal + * @param mapToIndex whether to map to the index or the segment row id + */ + private NodeQueue.NodeScoreIterator mapToScoreIterator(ScoreFunction scoreFunction, boolean mapToIndex) + { + return new NodeQueue.NodeScoreIterator() + { + int i = 0; + + @Override + public boolean hasNext() + { + return i < size; + } + + @Override + public int pop() + { + return mapToIndex ? i++ : array[i++ * 2]; + } + + @Override + public float topScore() + { + return scoreFunction.similarityTo(array[i * 2 + 1]); + } + }; + } /** * Calls the consumer for each right value in each pair of the array. diff --git a/src/java/org/apache/cassandra/io/sstable/format/FilterComponent.java b/src/java/org/apache/cassandra/io/sstable/format/FilterComponent.java index eab8f4ce825f..488564c7cd60 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/FilterComponent.java +++ b/src/java/org/apache/cassandra/io/sstable/format/FilterComponent.java @@ -74,6 +74,12 @@ public static IFilter load(Descriptor descriptor) throws IOException public static void save(IFilter filter, Descriptor descriptor, boolean deleteOnFailure) throws IOException { + if (!(filter instanceof BloomFilter)) + { + logger.info("Skipped saving bloom filter {} for {} to disk", filter, descriptor); + return; + } + File filterFile = descriptor.fileFor(Components.FILTER); try (FileOutputStreamPlus stream = filterFile.newOutputStream(File.WriteMode.OVERWRITE)) { diff --git a/src/java/org/apache/cassandra/utils/FilterFactory.java b/src/java/org/apache/cassandra/utils/FilterFactory.java index 87db90bd8de9..bedaa10076ee 100644 --- a/src/java/org/apache/cassandra/utils/FilterFactory.java +++ b/src/java/org/apache/cassandra/utils/FilterFactory.java @@ -161,6 +161,12 @@ public boolean isInformative() { return false; } + + @Override + public String toString() + { + return "AlwaysPresentFilter"; + } } public interface FilterFactoryMetrics diff --git a/src/java/org/apache/cassandra/utils/obs/MemoryLimiter.java b/src/java/org/apache/cassandra/utils/obs/MemoryLimiter.java index bb2eb28a341d..295a223ae25e 100644 --- a/src/java/org/apache/cassandra/utils/obs/MemoryLimiter.java +++ b/src/java/org/apache/cassandra/utils/obs/MemoryLimiter.java @@ -39,7 +39,8 @@ public void increment(long bytesCount) throws ReachedMemoryLimitException { assert bytesCount >= 0; long bytesCountAfterAllocation = this.currentMemory.addAndGet(bytesCount); - if (bytesCountAfterAllocation >= maxMemory) + // if overflow or exceeded max memory + if (bytesCountAfterAllocation < 0 || bytesCountAfterAllocation >= maxMemory) { this.currentMemory.addAndGet(-bytesCount); diff --git a/test/distributed/org/apache/cassandra/distributed/test/sai/ANNOptionsDistributedTest.java b/test/distributed/org/apache/cassandra/distributed/test/sai/ANNOptionsDistributedTest.java index a73e8b05a3b4..7e793706feb0 100644 --- a/test/distributed/org/apache/cassandra/distributed/test/sai/ANNOptionsDistributedTest.java +++ b/test/distributed/org/apache/cassandra/distributed/test/sai/ANNOptionsDistributedTest.java @@ -98,16 +98,24 @@ private static void testSelectWithAnnOptions(Cluster cluster, String expectedErr cluster.schemaChange(withKeyspace("CREATE CUSTOM INDEX ON %s.t(v) USING 'StorageAttachedIndex'")); SAIUtil.waitForIndexQueryable(cluster, KEYSPACE); - String select = withKeyspace("SELECT * FROM %s.t ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': 10}"); + String selectRerankk = withKeyspace("SELECT * FROM 
%s.t ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': 10}"); + String selectUsePruning = withKeyspace("SELECT * FROM %s.t ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'use_pruning': false}"); for (int i = 1; i <= cluster.size(); i++) { ICoordinator coordinator = cluster.coordinator(i); if (expectedErrorMessage == null) - coordinator.execute(select, ConsistencyLevel.ONE); + { + coordinator.execute(selectRerankk, ConsistencyLevel.ONE); + coordinator.execute(selectUsePruning, ConsistencyLevel.ONE); + } else - Assertions.assertThatThrownBy(() -> coordinator.execute(select, ConsistencyLevel.ONE)) + { + Assertions.assertThatThrownBy(() -> coordinator.execute(selectRerankk, ConsistencyLevel.ONE)) .hasMessageContaining(expectedErrorMessage); + Assertions.assertThatThrownBy(() -> coordinator.execute(selectUsePruning, ConsistencyLevel.ONE)) + .hasMessageContaining(expectedErrorMessage); + } } } diff --git a/test/simulator/main/org/apache/cassandra/simulator/utils/CountingCollection.java b/test/simulator/main/org/apache/cassandra/simulator/utils/CountingCollection.java index 7e04fcdec0f3..d20993a49c1e 100644 --- a/test/simulator/main/org/apache/cassandra/simulator/utils/CountingCollection.java +++ b/test/simulator/main/org/apache/cassandra/simulator/utils/CountingCollection.java @@ -51,4 +51,10 @@ public int size() { return count; } + + @Override + public String toString() + { + return "CountingCollection{count=" + count + '}'; + } } diff --git a/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java b/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java index 9c1a2d6a7bad..deb6ca0ba2fc 100644 --- a/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java +++ b/test/unit/org/apache/cassandra/db/filter/ANNOptionsTest.java @@ -90,19 +90,23 @@ public void testParseAndValidate() execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] ALLOW FILTERING WITH ann_options = {}"); execute("SELECT * FROM %s WHERE k=0 ORDER BY v ANN OF [1, 1] WITH ann_options = {}"); - // correct queries with specific ANN options + // correct queries with specific ANN options - rerank_k + execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 100 WITH ann_options = {'rerank_k': -1}"); + execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 100 WITH ann_options = {'rerank_k': 0}"); execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': 10}"); execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': 11}"); execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': 1000}"); execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': '1000'}"); - // Queries with invalid ann options that will eventually be valid when we support disabling reranking - assertInvalidThrowMessage("Invalid rerank_k value -1 lesser than limit 100", - InvalidRequestException.class, - "SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 100 WITH ann_options = {'rerank_k': -1}"); - assertInvalidThrowMessage("Invalid rerank_k value 0 lesser than limit 100", - InvalidRequestException.class, - "SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 100 WITH ann_options = {'rerank_k': 0}"); + // correct queries with specific ANN options - use_pruning + execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] WITH ann_options = {'use_pruning': true}"); + execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] WITH ann_options = {'use_pruning': false}"); + execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] WITH ann_options = 
{'use_pruning': 'true'}"); + execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] WITH ann_options = {'use_pruning': 'false'}"); + + // correct queries with both options + execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': 10, 'use_pruning': true}"); + execute("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'use_pruning': false, 'rerank_k': 20}"); // Queries that exceed the failure threshold for the guardrail. Specifies a protocol version to trigger // validation in the coordinator. @@ -144,10 +148,18 @@ public void testParseAndValidate() baseQuery + " WITH ann_options = {'rerank_k': 'a'}"); // ANN options with rerank lesser than limit - assertInvalidThrowMessage("Invalid rerank_k value 10 lesser than limit 100", + assertInvalidThrowMessage("Invalid rerank_k value 10 greater than 0 and less than limit 100", InvalidRequestException.class, baseQuery + "LIMIT 100 WITH ann_options = {'rerank_k': 10}"); + // invalid use_pruning values + assertInvalidThrowMessage("Invalid 'use_pruning' ANN option. Expected a boolean but found: notaboolean", + InvalidRequestException.class, + baseQuery + " WITH ann_options = {'use_pruning': 'notaboolean'}"); + assertInvalidThrowMessage("Invalid 'use_pruning' ANN option. Expected a boolean but found: 42", + InvalidRequestException.class, + baseQuery + " WITH ann_options = {'use_pruning': '42'}"); + // ANN options without ORDER BY ANN with empty options assertInvalidThrowMessage(StatementRestrictions.ANN_OPTIONS_WITHOUT_ORDER_BY_ANN, InvalidRequestException.class, @@ -178,10 +190,20 @@ public void testToCQLString() ReadCommand command = parseReadCommand(formattedQuery); Assertions.assertThat(command.toCQLString()).doesNotContain("WITH ann_options"); - // with ANN options + // with rerank_k option formattedQuery = formatQuery("SELECT * FROM %%s ORDER BY v ANN OF [1, 1] LIMIT 1 WITH ann_options = {'rerank_k': 2}"); command = parseReadCommand(formattedQuery); Assertions.assertThat(command.toCQLString()).contains("WITH ann_options = {'rerank_k': 2}"); + + // with use_pruning option + formattedQuery = formatQuery("SELECT * FROM %%s ORDER BY v ANN OF [1, 1] WITH ann_options = {'use_pruning': true}"); + command = parseReadCommand(formattedQuery); + Assertions.assertThat(command.toCQLString()).contains("WITH ann_options = {'use_pruning': true}"); + + // with both options + formattedQuery = formatQuery("SELECT * FROM %%s ORDER BY v ANN OF [1, 1] LIMIT 1 WITH ann_options = {'rerank_k': 2, 'use_pruning': false}"); + command = parseReadCommand(formattedQuery); + Assertions.assertThat(command.toCQLString()).contains("WITH ann_options = {'rerank_k': 2, 'use_pruning': false}"); } /** @@ -194,24 +216,35 @@ public void testTransport() createTable("CREATE TABLE %s (k int PRIMARY KEY, n int, v vector)"); createIndex(String.format("CREATE CUSTOM INDEX ON %%s(v) USING '%s'", ANNIndex.class.getName())); - // unespecified ANN options, should be mapped to NONE + // unspecified ANN options, should be mapped to NONE testTransport("SELECT * FROM %s ORDER BY v ANN OF [1, 1]", ANNOptions.NONE); testTransport("SELECT * FROM %s ORDER BY v ANN OF [1, 1] WITH ann_options = {}", ANNOptions.NONE); - // TODO re-enable this test when we support negative rerank_k values // some random negative values, all should be accepted and not be mapped to NONE -// String negativeQuery = "SELECT * FROM %%s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': %d}"; -// QuickTheory.qt() -// .withExamples(100) -// 
-//                   .forAll(integers().allPositive())
-//                   .checkAssert(i -> testTransport(String.format(negativeQuery, -i), ANNOptions.create(-i)));
-
-        // some random positive values, all should be accepted
-        String positiveQuery = "SELECT * FROM %%s ORDER BY v ANN OF [1, 1] LIMIT %d WITH ann_options = {'rerank_k': %d}";
+        String negativeQuery = "SELECT * FROM %%s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': %d}";
         QuickTheory.qt()
                    .withExamples(100)
                    .forAll(integers().allPositive())
-                   .checkAssert(i -> testTransport(String.format(positiveQuery, i), ANNOptions.create(i)));
+                   .checkAssert(i -> testTransport(String.format(negativeQuery, -i), ANNOptions.create(-i, null)));
+
+        // rerankK = 0 must also work
+        testTransport("SELECT * FROM %s ORDER BY v ANN OF [1, 1] LIMIT 10 WITH ann_options = {'rerank_k': 0}", ANNOptions.create(0, null));
+
+        // test use_pruning values
+        testTransport("SELECT * FROM %s ORDER BY v ANN OF [1, 1] WITH ann_options = {'use_pruning': true}",
+                      ANNOptions.create(null, true));
+        testTransport("SELECT * FROM %s ORDER BY v ANN OF [1, 1] WITH ann_options = {'use_pruning': false}",
+                      ANNOptions.create(null, false));
+
+        // test combinations of rerank_k and use_pruning
+        String combinedQuery = "SELECT * FROM %%s ORDER BY v ANN OF [1, 1] LIMIT %d WITH ann_options = {'rerank_k': %<d, 'use_pruning': %b}";
+        QuickTheory.qt()
+                   .withExamples(100)
+                   .forAll(integers().allPositive())
+                   .checkAssert(i ->
+                   {
+                       testTransport(String.format(combinedQuery, i, true), ANNOptions.create(i, true));
+                       testTransport(String.format(combinedQuery, i, false), ANNOptions.create(i, false));
+                   });
     }
 
     private void testTransport(String query, ANNOptions expectedOptions)
@@ -269,25 +302,32 @@ private void testTransport(String query, ANNOptions expectedOptions)
     }
 
     /**
-     * Tests that any future versions of {@link ANNOptions} can be able to read the current options.
+     * Tests that the current version of {@link ANNOptions} can correctly serialize and deserialize all combinations
+     * of current options.
      */
     @Test
-    public void testSerializationForFutureVersions() throws IOException
+    public void testCurrentVersionSerialization() throws IOException
     {
-        // the current version of the ANN options...
-        ANNOptions sentOptions = ANNOptions.create(7);
-        DataOutputBuffer out = new DataOutputBuffer();
-        ANNOptions.serializer.serialize(sentOptions, out, MessagingService.current_version);
-        int serializedSize = out.buffer().remaining();
-        Assertions.assertThat(ANNOptions.serializer.serializedSize(sentOptions, MessagingService.current_version))
-                  .isEqualTo(serializedSize);
+        // Test different combinations of options
+        ANNOptions[] optionsToTest = {
+            ANNOptions.NONE,
+            ANNOptions.create(7, null),
+            ANNOptions.create(null, true),
+            ANNOptions.create(7, false)
+        };
+
+        for (ANNOptions options : optionsToTest)
+        {
+            DataOutputBuffer out = new DataOutputBuffer();
+            ANNOptions.serializer.serialize(options, out, MessagingService.current_version);
+            int serializedSize = out.buffer().remaining();
+            Assertions.assertThat(ANNOptions.serializer.serializedSize(options, MessagingService.current_version))
+                      .isEqualTo(serializedSize);
 
-        // ...should be readable with the future serializer
-        DataInputBuffer in = new DataInputBuffer(out.buffer(), true);
-        FutureANNOptions receivedOptions = FutureANNOptions.serializer.deserialize(in);
-        Assertions.assertThat(receivedOptions).isEqualTo(new FutureANNOptions(sentOptions));
-        Assertions.assertThat(FutureANNOptions.serializer.serializedSize(receivedOptions))
-                  .isEqualTo(serializedSize);
+            DataInputBuffer in = new DataInputBuffer(out.buffer(), true);
+            ANNOptions deserialized = ANNOptions.serializer.deserialize(in, MessagingService.current_version);
+            Assertions.assertThat(deserialized).isEqualTo(options);
+        }
     }
 
     /**
@@ -308,7 +348,7 @@ public void testDeserializationOfCompatibleFutureVersions() throws IOException
         // ...should be readable with the current serializer
         DataInputBuffer in = new DataInputBuffer(out.buffer(), true);
         ANNOptions receivedOptions = ANNOptions.serializer.deserialize(in, MessagingService.current_version);
-        Assertions.assertThat(receivedOptions).isEqualTo(ANNOptions.create(sentOptions.rerankK));
+        Assertions.assertThat(receivedOptions).isEqualTo(ANNOptions.create(sentOptions.rerankK, null));
         Assertions.assertThat(ANNOptions.serializer.serializedSize(receivedOptions, MessagingService.current_version))
                   .isEqualTo(serializedSize);
     }
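Note for reviewers: to make the flag-guarded wire format these tests exercise concrete, here is a minimal, self-contained sketch of the encode/decode idea. It is an illustration only, not this patch's serializer: the flag constants merely mirror the presence-mask approach, and the stream types, class name, and method names are assumptions invented for the example.

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.DataInputStream;
    import java.io.DataOutputStream;
    import java.io.IOException;
    import java.util.Objects;

    // Illustrative sketch: two optional fields, each guarded by a presence bit.
    final class AnnOptionsWireSketch
    {
        static final int RERANK_K_FLAG = 1;     // assumed: plays the role of the rerank_k presence mask
        static final int USE_PRUNING_FLAG = 2;  // assumed: plays the role of the use_pruning presence mask

        static byte[] encode(Integer rerankK, Boolean usePruning) throws IOException
        {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            DataOutputStream out = new DataOutputStream(bytes);
            int flags = (rerankK == null ? 0 : RERANK_K_FLAG) | (usePruning == null ? 0 : USE_PRUNING_FLAG);
            out.writeInt(flags);           // presence flags first...
            if (rerankK != null)
                out.writeInt(rerankK);     // ...then each present field, in flag order
            if (usePruning != null)
                out.writeBoolean(usePruning);
            return bytes.toByteArray();
        }

        static void roundTrip(Integer rerankK, Boolean usePruning) throws IOException
        {
            DataInputStream in = new DataInputStream(new ByteArrayInputStream(encode(rerankK, usePruning)));
            int flags = in.readInt();
            Integer decodedRerankK = (flags & RERANK_K_FLAG) != 0 ? in.readInt() : null;
            Boolean decodedUsePruning = (flags & USE_PRUNING_FLAG) != 0 ? in.readBoolean() : null;
            assert Objects.equals(rerankK, decodedRerankK);
            assert Objects.equals(usePruning, decodedUsePruning);
        }
    }

The useful property of this shape, and the one the version-compatibility tests above probe, is that a decoder can detect bits it does not recognize and decide whether to reject or skip them.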
diff --git a/test/unit/org/apache/cassandra/dht/LengthPartitioner.java b/test/unit/org/apache/cassandra/dht/LengthPartitioner.java
index a5d8c3c42ab9..4d7e24f45288 100644
--- a/test/unit/org/apache/cassandra/dht/LengthPartitioner.java
+++ b/test/unit/org/apache/cassandra/dht/LengthPartitioner.java
@@ -125,6 +125,11 @@ public Token fromByteArray(ByteBuffer bytes)
         return new BigIntegerToken(ByteBufferUtil.toLong(bytes));
     }
 
+    public Token fromLongValue(long longValue)
+    {
+        return new BigIntegerToken(longValue);
+    }
+
     @Override
     public Token fromComparableBytes(ByteSource.Peekable comparableBytes, ByteComparable.Version version)
     {
diff --git a/test/unit/org/apache/cassandra/dht/Murmur3PartitionerTest.java b/test/unit/org/apache/cassandra/dht/Murmur3PartitionerTest.java
index 9a02fb086cd0..bd5bf2f9d247 100644
--- a/test/unit/org/apache/cassandra/dht/Murmur3PartitionerTest.java
+++ b/test/unit/org/apache/cassandra/dht/Murmur3PartitionerTest.java
@@ -77,5 +77,16 @@ public void testLongTokenInverse()
             return Murmur3Partitioner.instance.getToken(key).token == token;
         });
     }
+
+    @Test
+    public void testFromLongValue()
+    {
+        qt().forAll(longs().between(Long.MIN_VALUE + 1, Long.MAX_VALUE))
+            .check(token -> {
+                Token fromLongValue = Murmur3Partitioner.instance.getTokenFactory().fromLongValue(token);
+                Token constructed = new Murmur3Partitioner.LongToken(token);
+                return constructed.equals(fromLongValue);
+            });
+    }
 }
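As a usage note, the new fromLongValue factory gives callers a token directly from its long value, without hashing a key; it should agree with direct LongToken construction, which is exactly what the property test above asserts. A minimal sketch (the value 42 is arbitrary):

    import org.apache.cassandra.dht.Murmur3Partitioner;
    import org.apache.cassandra.dht.Token;

    // Sketch: the factory path and direct construction yield equal tokens.
    public class FromLongValueSketch
    {
        public static void main(String[] args)
        {
            Token viaFactory = Murmur3Partitioner.instance.getTokenFactory().fromLongValue(42L);
            Token direct = new Murmur3Partitioner.LongToken(42L);
            System.out.println(viaFactory.equals(direct)); // expected: true
        }
    }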
diff --git a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java
index 8f43b0f66ab3..0660e702062e 100644
--- a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java
+++ b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java
@@ -21,12 +21,17 @@
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Objects;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 
+import org.apache.cassandra.index.sai.SSTableIndex;
+import org.apache.cassandra.index.sai.memory.MemtableIndex;
+import org.apache.cassandra.index.sai.memory.TrieMemoryIndex;
+import org.apache.cassandra.index.sai.memory.TrieMemtableIndex;
 import org.assertj.core.api.Assertions;
 import org.junit.Before;
 import org.junit.Test;
@@ -816,6 +821,71 @@ public void testOrderingSeveralSegments() throws Throwable
                      "climate");
     }
 
+    /**
+     * Asserts that the memtable SAI index maintains the expected row count, and that the same count is then
+     * persisted as the row count of the SSTable SAI index and of its segments, which is also asserted.
+     * Note that the second deletion does not reduce the per-SSTable counts until compaction, because its
+     * tombstone is flushed into a separate SSTable.
+     */
+    @Test
+    public void testIndexMetaForNumRows()
+    {
+        createTable("CREATE TABLE %s (id int PRIMARY KEY, category text, score int, " +
+                    "title text, body text, bodyset set<text>, " +
+                    "map_category map<text, text>, map_body map<text, text>)");
+        String bodyIndexName = createAnalyzedIndex("body", true);
+        String scoreIndexName = createIndex("CREATE CUSTOM INDEX ON %s (score) USING 'StorageAttachedIndex'");
+        String mapIndexName = createIndex("CREATE CUSTOM INDEX ON %s (map_category) USING 'StorageAttachedIndex'");
+        insertCollectionData();
+
+        assertNumRowsMemtable(scoreIndexName, DATASET.length);
+        assertNumRowsMemtable(bodyIndexName, DATASET.length);
+        assertNumRowsMemtable(mapIndexName, DATASET.length);
+        execute("DELETE FROM %s WHERE id = ?", 5);
+        flush();
+        assertNumRowsSSTable(scoreIndexName, DATASET.length - 1);
+        assertNumRowsSSTable(bodyIndexName, DATASET.length - 1);
+        assertNumRowsSSTable(mapIndexName, DATASET.length - 1);
+        execute("DELETE FROM %s WHERE id = ?", 10);
+        flush();
+        assertNumRowsSSTable(scoreIndexName, DATASET.length - 1);
+        assertNumRowsSSTable(bodyIndexName, DATASET.length - 1);
+        assertNumRowsSSTable(mapIndexName, DATASET.length - 1);
+        compact();
+        assertNumRowsSSTable(scoreIndexName, DATASET.length - 2);
+        assertNumRowsSSTable(bodyIndexName, DATASET.length - 2);
+        assertNumRowsSSTable(mapIndexName, DATASET.length - 2);
+    }
+
+    private void assertNumRowsMemtable(String indexName, int expectedNumRows)
+    {
+        int rowCount = 0;
+
+        for (var memtable : getCurrentColumnFamilyStore().getAllMemtables())
+        {
+            MemtableIndex memIndex = getIndexContext(indexName).getMemtableIndex(memtable);
+            assert memIndex instanceof TrieMemtableIndex;
+            rowCount = Arrays.stream(((TrieMemtableIndex) memIndex).getRangeIndexes())
+                             .map(index -> ((TrieMemoryIndex) index).getDocLengths().size())
+                             .mapToInt(Integer::intValue).sum();
+        }
+        assertEquals(expectedNumRows, rowCount);
+    }
+
+    private void assertNumRowsSSTable(String indexName, int expectedNumRows)
+    {
+        long indexRowCount = 0;
+        long segmentRowCount = 0;
+        for (SSTableIndex sstableIndex : getIndexContext(indexName).getView())
+        {
+            indexRowCount += sstableIndex.getRowCount();
+            segmentRowCount += sstableIndex.getSegments().stream()
+                                           .map(s -> Objects.requireNonNull(s.metadata).numRows)
+                                           .mapToLong(Long::longValue).sum();
+        }
+        assertEquals(indexRowCount, segmentRowCount);
+        assertEquals(expectedNumRows, indexRowCount);
+    }
+
     private final static Object[][] DATASET = {
     { 1, "Climate", 5, "Climate change is a pressing issue. Climate patterns are shifting globally. Scientists study climate data daily.", 1 },
diff --git a/test/unit/org/apache/cassandra/index/sai/cql/VectorHybridSearchTest.java b/test/unit/org/apache/cassandra/index/sai/cql/VectorHybridSearchTest.java
index b5952d8ce348..4f14e84c6a9f 100644
--- a/test/unit/org/apache/cassandra/index/sai/cql/VectorHybridSearchTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/cql/VectorHybridSearchTest.java
@@ -18,10 +18,14 @@
  */
 package org.apache.cassandra.index.sai.cql;
 
+import java.util.stream.IntStream;
+
 import org.junit.Test;
 
 import org.apache.cassandra.index.sai.plan.QueryController;
 
+import static org.apache.cassandra.index.sai.disk.vector.CassandraOnHeapGraph.MIN_PQ_ROWS;
+
 public class VectorHybridSearchTest extends VectorTester.VersionedWithChecksums
 {
     @Test
@@ -214,4 +218,36 @@ public void testHybridQueryWithMissingVectorValuesForMaxSegmentRow() throws Thro
             assertRows(execute("SELECT i FROM %s WHERE c >= 1 ORDER BY v ANN OF [1,1] LIMIT 1"), row(1));
         });
     }
+
+    @Test
+    public void testReranklessHybridSearch()
+    {
+        // Want to test the search-then-order path
+        QueryController.QUERY_OPT_LEVEL = 0;
+
+        createTable("CREATE TABLE %s (pk int, val int, vec vector<float, 128>, PRIMARY KEY(pk))");
+        createIndex("CREATE CUSTOM INDEX ON %s(vec) USING 'StorageAttachedIndex'");
+        createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'");
+
+        // Insert many rows in parallel
+        IntStream.range(0, MIN_PQ_ROWS * 2).parallel().forEach(i -> {
+            execute("INSERT INTO %s (pk, val, vec) VALUES (?, ?, ?)", i, i, randomVectorBoxed(128));
+        });
+
+        flush();
+
+        // Search the graph with rerankless search. We restrict val to the first quarter of the dataset.
+        setMaxBruteForceRows(0);
+        var result = execute("SELECT pk FROM %s WHERE val < ? ORDER BY vec ANN OF ? LIMIT 10 with ann_options = { 'rerank_k': 0 }", MIN_PQ_ROWS / 2, randomVectorBoxed(128));
+        // Just testing that we can run the query, so only assert that we got results
+        assertRowCount(result, 10);
+
+        // Now search with brute force (skipping the graph). We restrict val to a tenth of the dataset to trigger brute force.
+        setMaxBruteForceRows(MIN_PQ_ROWS * 2);
+        result = execute("SELECT pk FROM %s WHERE val < ? ORDER BY vec ANN OF ? LIMIT 10 with ann_options = { 'rerank_k': 0 }", MIN_PQ_ROWS / 5, randomVectorBoxed(128));
+        assertRowCount(result, 10);
+        // Also ensure that a negative rerank_k value works
+        result = execute("SELECT pk FROM %s WHERE val < ? ORDER BY vec ANN OF ? LIMIT 10 with ann_options = { 'rerank_k': -1 }", MIN_PQ_ROWS / 5, randomVectorBoxed(128));
+        assertRowCount(result, 10);
+    }
 }
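The two knobs compose freely at query time. For reference, a sketch of the option spellings used across these tests; the keyspace, table, and vector literal below are placeholders, only the WITH ann_options clauses matter:

    // Placeholder schema; illustrates the clause spellings exercised by the tests.
    final class AnnOptionsSpellings
    {
        static final String BASE = "SELECT pk FROM ks.tbl ORDER BY vec ANN OF [1, 2, 3] LIMIT 10";

        static final String RERANKLESS    = BASE + " WITH ann_options = {'rerank_k': 0}";        // 0 or a negative value disables reranking
        static final String DEEPER_RERANK = BASE + " WITH ann_options = {'rerank_k': 50}";       // rerank a larger candidate pool
        static final String NO_PRUNING    = BASE + " WITH ann_options = {'use_pruning': false}"; // trade speed for recall
        static final String COMBINED      = BASE + " WITH ann_options = {'rerank_k': 50, 'use_pruning': false}";
    }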
diff --git a/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java b/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java
index c66f6889a7a3..8fc9b37c17e8 100644
--- a/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/cql/VectorSiftSmallTest.java
@@ -71,6 +71,9 @@ public void testSiftSmall() throws Throwable
         // Run a few queries with increasing rerank_k to validate that recall increases
         ensureIncreasingRerankKIncreasesRecall(queryVectors, groundTruth);
 
+        // Run queries with and without pruning to validate that disabling pruning does not reduce recall
+        ensureDisablingPruningDoesNotReduceRecall(queryVectors, groundTruth);
+
         flush();
         var diskRecall = testRecall(100, queryVectors, groundTruth);
         assertTrue("Disk recall is " + diskRecall, diskRecall > 0.975);
@@ -85,10 +88,18 @@ private void ensureIncreasingRerankKIncreasesRecall(List<float[]> queryVectors,
         double previousRecall = 0;
         int limit = 10;
         int strictlyIncreasedCount = 0;
-        // Testing shows that we acheive 100% recall at about rerank_k = 45, so no need to go higher
+
+        // First test with rerank_k = 0, which should have the worst recall
+        double zeroRerankRecall = testRecall(limit, queryVectors, groundTruth, 0, null);
+
+        // Testing shows that we achieve 100% recall at about rerank_k = 45, so no need to go higher
         for (int rerankK = limit; rerankK <= 50; rerankK += 5)
         {
-            var recall = testRecall(limit, queryVectors, groundTruth, rerankK);
+            var recall = testRecall(limit, queryVectors, groundTruth, rerankK, null);
+            // All recalls should be at least as good as with rerank_k = 0
+            assertTrue("Recall for rerank_k = " + rerankK + " should be at least as good as with rerank_k = 0",
+                       recall >= zeroRerankRecall);
+            // Recall varies, so we can only assert that it does not get worse on a per-run basis. However, it should
+            // get better strictly at least some of the time
             assertTrue("Recall for rerank_k = " + rerankK + " is " + recall, recall >= previousRecall);
@@ -102,6 +113,21 @@ private void ensureIncreasingRerankKIncreasesRecall(List<float[]> queryVectors,
                    strictlyIncreasedCount > 3);
     }
 
+    private void ensureDisablingPruningDoesNotReduceRecall(List<float[]> queryVectors, List<HashSet<Integer>> groundTruth)
+    {
+        int limit = 10;
+        // Test with pruning enabled
+        double recallWithPruning = testRecall(limit, queryVectors, groundTruth, null, true);
+
+        // Test with pruning disabled
+        double recallWithoutPruning = testRecall(limit, queryVectors, groundTruth, null, false);
+
+        // Recall should be at least as good when pruning is disabled
+        assertTrue("Recall without pruning (" + recallWithoutPruning +
+                   ") should be at least as good as recall with pruning (" + recallWithPruning + ')',
+                   recallWithoutPruning >= recallWithPruning);
+    }
+
     @Test
     public void testCompaction() throws Throwable
     {
@@ -228,10 +254,10 @@ private static ArrayList<HashSet<Integer>> readIvecs(String filename)
 
     public double testRecall(int topK, List<float[]> queryVectors, List<HashSet<Integer>> groundTruth)
     {
-        return testRecall(topK, queryVectors, groundTruth, null);
+        return testRecall(topK, queryVectors, groundTruth, null, null);
     }
 
-    public double testRecall(int topK, List<float[]> queryVectors, List<HashSet<Integer>> groundTruth, Integer rerankK)
+    public double testRecall(int topK, List<float[]> queryVectors, List<HashSet<Integer>> groundTruth, Integer rerankK, Boolean usePruning)
     {
         AtomicInteger topKfound = new AtomicInteger(0);
@@ -242,11 +268,28 @@ public double testRecall(int topK, List<float[]> queryVectors, List<HashSet<Integer>> groundTruth, Integer rerankK)
+            if (rerankK != null || usePruning != null)
+            {
+                query.append(" WITH ann_options = {");
+                List<String> options = new ArrayList<>();
+
+                if (rerankK != null)
+                    options.add("'rerank_k': " + rerankK);
+
+                if (usePruning != null)
+                    options.add("'use_pruning': " + usePruning);
+
+                query.append(String.join(", ", options));
+                query.append('}');
+            }
 
-            UntypedResultSet result = execute(query);
+            UntypedResultSet result = execute(query.toString());
             var gt = groundTruth.get(i);
             assert topK <= gt.size();
             // we don't care about order within the topK but we do need to restrict the size first
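For clarity, here is what the clause assembly in testRecall produces, shown as an isolated, runnable mirror of the same logic (the class and method names are invented for this illustration; the option strings match the test):

    import java.util.ArrayList;
    import java.util.List;

    final class AnnOptionsClauseSketch
    {
        // Mirrors the StringBuilder logic in testRecall: emit the clause only when an option is set.
        static String clause(Integer rerankK, Boolean usePruning)
        {
            if (rerankK == null && usePruning == null)
                return "";
            List<String> options = new ArrayList<>();
            if (rerankK != null)
                options.add("'rerank_k': " + rerankK);
            if (usePruning != null)
                options.add("'use_pruning': " + usePruning);
            return " WITH ann_options = {" + String.join(", ", options) + "}";
        }

        public static void main(String[] args)
        {
            System.out.println(clause(45, null));    //  WITH ann_options = {'rerank_k': 45}
            System.out.println(clause(null, false)); //  WITH ann_options = {'use_pruning': false}
            System.out.println(clause(45, false));   //  WITH ann_options = {'rerank_k': 45, 'use_pruning': false}
        }
    }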
diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java
index a15b11f6abf3..d6c27fb1df4c 100644
--- a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java
@@ -49,9 +49,8 @@
 import org.apache.cassandra.service.StorageService;
 import org.apache.cassandra.utils.bytecomparable.ByteComparable;
 import org.apache.cassandra.utils.bytecomparable.ByteSourceInverse;
+import org.assertj.core.api.Assertions;
 
-import static org.hamcrest.Matchers.instanceOf;
-import static org.hamcrest.Matchers.is;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -194,6 +193,7 @@ private IndexSearcher buildIndexAndOpenSearcher(int terms, List<InvertedIndexBuilder.TermsEnum> termsEnum) throws IOException
+            metadataBuilder.setNumRows(termsEnum.stream().map(te -> te.postings.size()).mapToInt(i -> i).sum());
         }
 
         final SegmentMetadata segmentMetadata = metadataBuilder.build();
@@ -207,7 +207,7 @@
                                     UTF8Type.instance.fromString("d"));
         metadataBuilder.setKeyRange(SAITester.TEST_FACTORY.createTokenOnly(Murmur3Partitioner.instance.decorateKey(UTF8Type.instance.fromString("a")).getToken()),
                                     SAITester.TEST_FACTORY.createTokenOnly(Murmur3Partitioner.instance.decorateKey(UTF8Type.instance.fromString("b")).getToken()));
+        metadataBuilder.setNumRows(size);
         metadata = metadataBuilder.build();
     }
@@ -192,7 +191,7 @@
         when(sstableContext.usedPerSSTableComponents()).thenReturn(indexDescriptor.perSSTableComponents());
 
         IndexSearcher searcher = Version.latest().onDiskFormat().newIndexSearcher(sstableContext, indexContext, indexFiles, metadata);
-        assertThat(searcher, is(instanceOf(KDTreeIndexSearcher.class)));
+        assertThat(searcher).isInstanceOf(KDTreeIndexSearcher.class);
         return (KDTreeIndexSearcher) searcher;
     }
 }
@@ -293,7 +292,7 @@ public static IndexSearcher buildShortSearcher(IndexDescriptor indexDescriptor,
      */
     public static AbstractGuavaIterator<Pair<ByteComparable, IntArrayList>> singleOrd(Iterator<ByteComparable> terms, AbstractType<?> type, int segmentRowIdOffset, int size)
     {
-        return new AbstractGuavaIterator<Pair<ByteComparable, IntArrayList>>()
+        return new AbstractGuavaIterator<>()
         {
             private long currentTerm = 0;
             private int currentSegmentRowId = segmentRowIdOffset;
@@ -343,11 +342,13 @@ public static Iterator<Long> longRange(long startInclusive, long endExclus
     public static Iterator<BigDecimal> decimalRange(final BigDecimal startInclusive, final BigDecimal endExclusive)
     {
         int n = endExclusive.subtract(startInclusive).intValueExact() * 10;
-        final Supplier<BigDecimal> generator = new Supplier<BigDecimal>() {
+        final Supplier<BigDecimal> generator = new Supplier<>()
+        {
             BigDecimal current = startInclusive;
 
             @Override
-            public BigDecimal get() {
+            public BigDecimal get()
+            {
                 BigDecimal result = current;
                 current = current.add(ONE_TENTH);
                 return result;
@@ -363,11 +364,13 @@ public BigDecimal get() {
     public static Iterator<BigInteger> bigIntegerRange(final BigInteger startInclusive, final BigInteger endExclusive)
     {
         int n = endExclusive.subtract(startInclusive).intValueExact();
-        final Supplier<BigInteger> generator = new Supplier<BigInteger>() {
+        final Supplier<BigInteger> generator = new Supplier<>()
+        {
             BigInteger current = startInclusive;
 
             @Override
-            public BigInteger get() {
+            public BigInteger get()
+            {
                 BigInteger result = current;
                 current = current.add(BigInteger.ONE);
                 return result;
diff --git a/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java b/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java
index 1f38c80a0cb5..11fee5ff2f13 100644
--- a/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java
@@ -41,6 +41,7 @@
 import org.apache.cassandra.db.ColumnFamilyStore;
 import org.apache.cassandra.db.DecoratedKey;
 import org.apache.cassandra.db.PartitionPosition;
+import org.apache.cassandra.db.filter.ANNOptions;
 import org.apache.cassandra.db.marshal.FloatType;
 import org.apache.cassandra.db.marshal.Int32Type;
 import org.apache.cassandra.db.marshal.VectorType;
@@ -212,7 +213,7 @@ public void indexIteratorTest()
 
     private Orderer randomVectorOrderer()
     {
-        return new Orderer(indexContext, Operator.ANN, randomVectorSerialized(), null);
+        return new Orderer(indexContext, Operator.ANN, randomVectorSerialized(), ANNOptions.NONE);
     }
 
     private ByteBuffer randomVectorSerialized()
     {
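The next file tests mapping (segmentRowId, ordinal) pairs onto jvector score iterators. As a reference for the contract those tests rely on, here is a minimal consumer that uses only the three methods the tests themselves call (hasNext, topScore, pop); nothing else about the NodeScoreIterator interface is assumed:

    import io.github.jbellis.jvector.graph.NodeQueue;

    final class NodeScoreIteratorSketch
    {
        // Drains an iterator in its native order, reading each head's score before popping it.
        static void drain(NodeQueue.NodeScoreIterator it)
        {
            while (it.hasNext())
            {
                float score = it.topScore(); // score of the current head element
                int node = it.pop();         // returns the head and advances
                System.out.printf("node=%d score=%.3f%n", node, score);
            }
        }
    }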
diff --git a/test/unit/org/apache/cassandra/index/sai/utils/SegmentRowIdOrdinalPairsTest.java b/test/unit/org/apache/cassandra/index/sai/utils/SegmentRowIdOrdinalPairsTest.java
index 87daed0e609e..2f05cd4cb0a2 100644
--- a/test/unit/org/apache/cassandra/index/sai/utils/SegmentRowIdOrdinalPairsTest.java
+++ b/test/unit/org/apache/cassandra/index/sai/utils/SegmentRowIdOrdinalPairsTest.java
@@ -16,6 +16,8 @@
 package org.apache.cassandra.index.sai.utils;
 
+import io.github.jbellis.jvector.graph.NodeQueue;
+import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -23,7 +25,9 @@
 import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
 
 public class SegmentRowIdOrdinalPairsTest
 {
@@ -95,32 +99,6 @@ public void testForEachSegmentRowIdOrdinalPair()
         assertEquals(Integer.valueOf(30), ordinals.get(2));
     }
 
-    @Test
-    public void testForEachIndexOrdinalPair()
-    {
-        SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(3);
-        pairs.add(1, 10);
-        pairs.add(2, 20);
-        pairs.add(3, 30);
-
-        List<Integer> indices = new ArrayList<>();
-        List<Integer> ordinals = new ArrayList<>();
-
-        pairs.forEachIndexOrdinalPair((index, ordinal) -> {
-            indices.add(index);
-            ordinals.add(ordinal);
-        });
-
-        assertEquals(3, indices.size());
-        assertEquals(3, ordinals.size());
-        assertEquals(Integer.valueOf(0), indices.get(0));
-        assertEquals(Integer.valueOf(10), ordinals.get(0));
-        assertEquals(Integer.valueOf(1), indices.get(1));
-        assertEquals(Integer.valueOf(20), ordinals.get(1));
-        assertEquals(Integer.valueOf(2), indices.get(2));
-        assertEquals(Integer.valueOf(30), ordinals.get(2));
-    }
-
     @Test
     public void testGetSegmentRowIdAndOrdinalBoundaryChecks()
     {
@@ -158,9 +136,6 @@ public void testOperationsOnEmptyArray()
 
         pairs.forEachSegmentRowIdOrdinalPair((x, y) -> count.incrementAndGet());
         assertEquals(0, count.get());
-
-        pairs.forEachIndexOrdinalPair((x, y) -> count.incrementAndGet());
-        assertEquals(0, count.get());
     }
 
     @Test
@@ -170,4 +145,81 @@ public void testZeroCapacity()
         assertEquals(0, pairs.size());
         assertThrows(IndexOutOfBoundsException.class, () -> pairs.add(1, 10));
     }
+
+    @Test
+    public void testMapToSegmentRowIdScoreIterator()
+    {
+        SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(3);
+        pairs.add(1, 10);
+        pairs.add(2, 20);
+        pairs.add(3, 30);
+
+        // Create a simple score function that returns the ordinal value divided by 10 as the score
+        ScoreFunction.ExactScoreFunction scoreFunction = ordinal -> ordinal / 10.0f;
+
+        NodeQueue.NodeScoreIterator iterator = pairs.mapToSegmentRowIdScoreIterator(scoreFunction);
+
+        // Test first pair
+        assertTrue(iterator.hasNext());
+        assertEquals(1.0f, iterator.topScore(), 0.001f);
+        assertEquals(1, iterator.pop());
+
+        // Test second pair
+        assertTrue(iterator.hasNext());
+        assertEquals(2.0f, iterator.topScore(), 0.001f);
+        assertEquals(2, iterator.pop());
+
+        // Test third pair
+        assertTrue(iterator.hasNext());
+        assertEquals(3.0f, iterator.topScore(), 0.001f);
+        assertEquals(3, iterator.pop());
+
+        // Test end of iteration
+        assertFalse(iterator.hasNext());
+    }
+
+    @Test
+    public void testMapToIndexScoreIterator()
+    {
+        SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(3);
+        pairs.add(1, 10);
+        pairs.add(2, 20);
+        pairs.add(3, 30);
+
+        // Create a simple score function that returns the ordinal value divided by 10 as the score
+        ScoreFunction.ExactScoreFunction scoreFunction = ordinal -> ordinal / 10.0f;
+
+        NodeQueue.NodeScoreIterator iterator = pairs.mapToIndexScoreIterator(scoreFunction);
+
+        // Test first pair
+        assertTrue(iterator.hasNext());
+        assertEquals(1.0f, iterator.topScore(), 0.001f);
+        assertEquals(0, iterator.pop());
+
+        // Test second pair
+        assertTrue(iterator.hasNext());
+        assertEquals(2.0f, iterator.topScore(), 0.001f);
+        assertEquals(1, iterator.pop());
+
+        // Test third pair
+        assertTrue(iterator.hasNext());
+        assertEquals(3.0f, iterator.topScore(), 0.001f);
+        assertEquals(2, iterator.pop());
+
+        // Test end of iteration
+        assertFalse(iterator.hasNext());
+    }
+
+    @Test
+    public void testEmptyIterators()
+    {
+        SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(0);
+        ScoreFunction.ExactScoreFunction scoreFunction = ordinal -> ordinal / 10.0f;
+
+        NodeQueue.NodeScoreIterator segmentRowIdIterator = pairs.mapToSegmentRowIdScoreIterator(scoreFunction);
+        assertFalse(segmentRowIdIterator.hasNext());
+
+        NodeQueue.NodeScoreIterator indexIterator = pairs.mapToIndexScoreIterator(scoreFunction);
+        assertFalse(indexIterator.hasNext());
+    }
 }
diff --git a/test/unit/org/apache/cassandra/inject/Injections.java b/test/unit/org/apache/cassandra/inject/Injections.java
index c19ee5ec067b..63328e038e42 100644
--- a/test/unit/org/apache/cassandra/inject/Injections.java
+++ b/test/unit/org/apache/cassandra/inject/Injections.java
@@ -612,7 +612,7 @@ public static class PauseBuilder extends SingleActionBuilder
                           sstable.onDiskLength());
     }
 
+    @Test
+    public void testSSTableFlushBloomFilterReachedLimit() throws Exception
+    {
+        final int numKeys = 100; // will use about 128 bytes
+        final Keyspace keyspace = Keyspace.open(KEYSPACE1);
+        final ColumnFamilyStore cfs = keyspace.getColumnFamilyStore(CF_STANDARD);
+
+        SSTableReader sstable;
+        long bfSpace = BloomFilter.memoryLimiter.maxMemory - BloomFilter.memoryLimiter.memoryAllocated() - 100;
+        try
+        {
+            BloomFilter.memoryLimiter.increment(bfSpace);
+            sstable = getNewSSTable(cfs, numKeys, 1);
+            Assert.assertFalse(PathUtils.exists(sstable.descriptor.pathFor(Components.FILTER)));
+            Assert.assertSame(FilterFactory.AlwaysPresent, getFilter(sstable));
+        }
+        finally
+        {
+            // reset
+            BloomFilter.memoryLimiter.decrement(bfSpace);
+        }
+    }
+
     private void checkSSTableOpenedWithGivenFPChance(ColumnFamilyStore cfs, SSTableReader sstable, double fpChance,
                                                      boolean bfShouldExist, int numKeys, boolean expectRecreated) throws IOException
     {
         Descriptor desc = sstable.descriptor;
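One practical note on the bloom-filter test above: it reserves nearly the whole shared memory budget up front so the flush must trip the limit (the sstable is then written without a FILTER component and opens with the AlwaysPresent fallback), and it releases the reservation in finally so the shared limiter cannot leak into other tests. The guard pattern in isolation, as a hedged sketch; the wrapper class and method are invented here, and only the limiter calls the test itself makes are assumed:

    import org.apache.cassandra.utils.BloomFilter;

    final class BloomFilterBudgetGuardSketch
    {
        static void withNearlyFullBudget(Runnable overLimitScenario) throws Exception
        {
            // Reserve all but ~100 bytes of headroom so the next filter allocation must exceed the limit.
            long reserved = BloomFilter.memoryLimiter.maxMemory - BloomFilter.memoryLimiter.memoryAllocated() - 100;
            try
            {
                BloomFilter.memoryLimiter.increment(reserved);
                overLimitScenario.run(); // e.g. flush; the sstable then opens with the AlwaysPresent filter
            }
            finally
            {
                BloomFilter.memoryLimiter.decrement(reserved); // never leak budget into other tests
            }
        }
    }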