@@ -720,6 +720,7 @@ public enum CassandraRelevantProperties
SAI_VALIDATE_TERMS_AT_COORDINATOR("cassandra.sai.validate_terms_at_coordinator", "true"),
/** Controls the maximum top-k limit for vector search */
SAI_VECTOR_SEARCH_MAX_TOP_K("cassandra.sai.vector_search.max_top_k", "1000"),
SAI_VECTOR_USE_PRUNING_DEFAULT("cassandra.sai.jvector.use_pruning_default", "true"),
SAI_WRITE_JVECTOR3_FORMAT("cassandra.sai.write_jv3_format", "false"),

SCHEMA_PULL_INTERVAL_MS("cassandra.schema_pull_interval_ms", "60000"),
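Note: a minimal sketch of how the new flag might be consumed when a query does not set the option explicitly. The accessor getBoolean() exists on CassandraRelevantProperties; the helper and its call site are illustrative, not part of this patch.

    // Illustrative helper (assumed shape, not from this diff): the per-query
    // option wins, otherwise fall back to -Dcassandra.sai.jvector.use_pruning_default.
    static boolean effectiveUsePruning(ANNOptions options)
    {
        boolean fallback = CassandraRelevantProperties.SAI_VECTOR_USE_PRUNING_DEFAULT.getBoolean();
        return options.usePruning != null ? options.usePruning : fallback;
    }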
78 changes: 63 additions & 15 deletions src/java/org/apache/cassandra/db/filter/ANNOptions.java
@@ -38,8 +38,9 @@
public class ANNOptions
{
public static final String RERANK_K_OPTION_NAME = "rerank_k";
public static final String USE_PRUNING_OPTION_NAME = "use_pruning";

public static final ANNOptions NONE = new ANNOptions(null);
public static final ANNOptions NONE = new ANNOptions(null, null);

public static final Serializer serializer = new Serializer();

@@ -51,29 +52,39 @@ public class ANNOptions
@Nullable
public final Integer rerankK;

private ANNOptions(@Nullable Integer rerankK)
/**
* Whether to use pruning to speed up the ANN search. If {@code null}, the default value is used.
*/
@Nullable
public final Boolean usePruning;

private ANNOptions(@Nullable Integer rerankK, @Nullable Boolean usePruning)
{
this.rerankK = rerankK;
this.usePruning = usePruning;
}

public static ANNOptions create(@Nullable Integer rerankK)
public static ANNOptions create(@Nullable Integer rerankK, @Nullable Boolean usePruning)
{
// if all the options are null, return the NONE instance
return rerankK == null ? NONE : new ANNOptions(rerankK);
return rerankK == null && usePruning == null ? NONE : new ANNOptions(rerankK, usePruning);
}

/**
* Validates the ANN options by checking that they are within the guardrails and that peers support the options.
*/
public void validate(ClientState state, String keyspace, int limit)
{
if (rerankK == null)
if (rerankK == null && usePruning == null)
return;

if (rerankK < limit)
throw new InvalidRequestException(String.format("Invalid rerank_k value %d lesser than limit %d", rerankK, limit));
if (rerankK != null)
{
if (rerankK > 0 && rerankK < limit)
throw new InvalidRequestException(String.format("Invalid rerank_k value %d greater than 0 and less than limit %d", rerankK, limit));

Guardrails.annRerankKMaxValue.guard(rerankK, "ANN options", false, state);
Guardrails.annRerankKMaxValue.guard(rerankK, "ANN options", false, state);
}

// Ensure that all nodes in the cluster are in a version that supports ANN options, including this one
assert keyspace != null;
@@ -93,6 +104,7 @@ public void validate(ClientState state, String keyspace, int limit)
public static ANNOptions fromMap(Map<String, String> map)
{
Integer rerankK = null;
Boolean usePruning = null;

for (Map.Entry<String, String> entry : map.entrySet())
{
@@ -103,13 +115,17 @@ public static ANNOptions fromMap(Map<String, String> map)
{
rerankK = parseRerankK(value);
}
else if (name.equals(USE_PRUNING_OPTION_NAME))
{
usePruning = parseUsePruning(value);
}
else
{
throw new InvalidRequestException("Unknown ANN option: " + name);
}
}

return ANNOptions.create(rerankK);
return ANNOptions.create(rerankK, usePruning);
}

private static int parseRerankK(String value)
@@ -129,9 +145,28 @@ private static int parseRerankK(String value)
return rerankK;
}

private static boolean parseUsePruning(String value)
{
value = value.toLowerCase();
if (!value.equals("true") && !value.equals("false"))
throw new InvalidRequestException(String.format("Invalid '%s' ANN option. Expected a boolean but found: %s",
USE_PRUNING_OPTION_NAME, value));
return Boolean.parseBoolean(value);
}
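Note: parsing is strict — only the literals true/false are accepted, case-insensitively. For example (illustrative calls):

    parseUsePruning("TRUE");   // -> true
    parseUsePruning("false");  // -> false
    parseUsePruning("1");      // -> InvalidRequestException: Invalid 'use_pruning' ANN option. Expected a boolean but found: 1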

public String toCQLString()
{
return String.format("{'%s': %d}", RERANK_K_OPTION_NAME, rerankK);
StringBuilder sb = new StringBuilder("{");
if (rerankK != null)
sb.append(String.format("'%s': %d", RERANK_K_OPTION_NAME, rerankK));
if (usePruning != null)
{
if (rerankK != null)
sb.append(", ");
sb.append(String.format("'%s': %b", USE_PRUNING_OPTION_NAME, usePruning));
}
sb.append("}");
return sb.toString();
}

@Override
@@ -140,13 +175,14 @@ public boolean equals(Object o)
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ANNOptions that = (ANNOptions) o;
return Objects.equals(rerankK, that.rerankK);
return Objects.equals(rerankK, that.rerankK) &&
Objects.equals(usePruning, that.usePruning);
}

@Override
public int hashCode()
{
return Objects.hash(rerankK);
return Objects.hash(rerankK, usePruning);
}

/**
@@ -166,9 +202,9 @@ public static class Serializer
{
/** Bit flags mask to check if the rerank K option is present. */
private static final int RERANK_K_MASK = 1;

/** Bit flags mask to check if the use pruning option is present. */
private static final int USE_PRUNING_MASK = 2;
/** Bit flags mask to check if there are any unknown options. It's the negation of all the known flags. */
private static final int UNKNOWN_OPTIONS_MASK = ~RERANK_K_MASK;
private static final int UNKNOWN_OPTIONS_MASK = ~(RERANK_K_MASK | USE_PRUNING_MASK);

/*
* If you add a new option, then update ANNOptionsTest.FutureANNOptions and possibly add a new test verifying
@@ -190,6 +226,8 @@ public void serialize(ANNOptions options, DataOutputPlus out, int version) throw

if (options.rerankK != null)
out.writeUnsignedVInt32(options.rerankK);
if (options.usePruning != null)
out.writeBoolean(options.usePruning);
}

public ANNOptions deserialize(DataInputPlus in, int version) throws IOException
@@ -206,8 +244,9 @@ public ANNOptions deserialize(DataInputPlus in, int version) throws IOException
"new options that are not supported by this node.");

Integer rerankK = hasRerankK(flags) ? (int) in.readUnsignedVInt() : null;
Boolean usePruning = hasUsePruning(flags) ? in.readBoolean() : null;

return ANNOptions.create(rerankK);
return ANNOptions.create(rerankK, usePruning);
}

public long serializedSize(ANNOptions options, int version)
@@ -221,6 +260,8 @@ public long serializedSize(ANNOptions options, int version)

if (options.rerankK != null)
size += TypeSizes.sizeofUnsignedVInt(options.rerankK);
if (options.usePruning != null)
size += TypeSizes.sizeof(options.usePruning);

return size;
}
@@ -234,6 +275,8 @@ private static int flags(ANNOptions options)

if (options.rerankK != null)
flags |= RERANK_K_MASK;
if (options.usePruning != null)
flags |= USE_PRUNING_MASK;

return flags;
}
@@ -242,5 +285,10 @@ private static boolean hasRerankK(int flags)
{
return (flags & RERANK_K_MASK) == RERANK_K_MASK;
}

private static boolean hasUsePruning(int flags)
{
return (flags & USE_PRUNING_MASK) == USE_PRUNING_MASK;
}
}
}
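Note: a quick sanity sketch of the option plumbing in this file — map parsing, CQL rendering, and the NONE collapse — using only the public surface added above (the literal values are hypothetical):

    Map<String, String> map = Map.of(ANNOptions.RERANK_K_OPTION_NAME, "50",
                                     ANNOptions.USE_PRUNING_OPTION_NAME, "true");
    ANNOptions options = ANNOptions.fromMap(map);
    assert options.rerankK == 50 && options.usePruning;
    assert options.toCQLString().equals("{'rerank_k': 50, 'use_pruning': true}");
    assert ANNOptions.fromMap(Map.of()) == ANNOptions.NONE;   // all-null options reuse the NONE instance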
6 changes: 6 additions & 0 deletions src/java/org/apache/cassandra/dht/Murmur3Partitioner.java
@@ -376,6 +376,12 @@ public Token fromByteArray(ByteBuffer bytes)
return new LongToken(ByteBufferUtil.toLong(bytes));
}

@Override
public Token fromLongValue(long token)
{
return new LongToken(token);
}

@Override
public Token fromByteBuffer(ByteBuffer bytes, int position, int length)
{
14 changes: 14 additions & 0 deletions src/java/org/apache/cassandra/dht/Token.java
@@ -40,6 +40,20 @@ public static abstract class TokenFactory
public abstract ByteBuffer toByteArray(Token token);
public abstract Token fromByteArray(ByteBuffer bytes);

/**
* This method exists so that callers can create tokens from the primitive {@code long} value for this {@link Token}, if
* one exists. It is especially useful for skipping ByteBuffer serde operations where performance is critical.
*
* @param token the primitive {@code long} value of this token
* @return the {@link Token} instance corresponding to the given primitive {@code long} value
*
* @throws UnsupportedOperationException if this {@link Token} is not backed by a primitive {@code long} value
*/
public Token fromLongValue(long token)
{
throw new UnsupportedOperationException();
}

/**
* Produce a byte-comparable representation of the token.
* See {@link Token#asComparableBytes}
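Note: the base TokenFactory deliberately throws, so only long-backed partitioners (Murmur3 above) opt in. A small sketch of the expected behavior:

    Token token = Murmur3Partitioner.instance.getTokenFactory().fromLongValue(42L);
    assert token.equals(new Murmur3Partitioner.LongToken(42L));

    // A factory that is not backed by a primitive long keeps the default and throws:
    // ByteOrderedPartitioner.instance.getTokenFactory().fromLongValue(42L)
    //   -> UnsupportedOperationException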
@@ -146,6 +146,7 @@ public void complete(Stopwatch stopwatch) throws IOException

private long flush(DecoratedKey minKey, DecoratedKey maxKey, AbstractType<?> termComparator, MemtableTermsIterator terms, int maxSegmentRowId) throws IOException
{
long numPostings;
long numRows;
SegmentMetadataBuilder metadataBuilder = new SegmentMetadataBuilder(0, perIndexComponents);
SegmentMetadata.ComponentMetadataMap indexMetas;
@@ -166,7 +167,8 @@ private long flush(DecoratedKey minKey, DecoratedKey maxKey, AbstractType<?> ter
);

indexMetas = writer.writeAll(metadataBuilder.intercept(terms), docLengths);
numRows = writer.getPostingsCount();
numPostings = writer.getPostingsCount();
numRows = docLengths.size();
}
}
else
@@ -180,18 +182,20 @@ private long flush(DecoratedKey minKey, DecoratedKey maxKey, AbstractType<?> ter
{
ImmutableOneDimPointValues values = ImmutableOneDimPointValues.fromTermEnum(terms, termComparator);
indexMetas = writer.writeAll(metadataBuilder.intercept(values));
numRows = writer.getPointCount();
numPostings = writer.getPointCount();
numRows = numPostings;
}
}

// If no rows were written we need to delete any created column index components
// so that the index is correctly identified as being empty (only having a completion marker)
if (numRows == 0)
if (numPostings == 0)
{
perIndexComponents.forceDeleteAllComponents();
return 0;
}

metadataBuilder.setNumRows(numRows);
metadataBuilder.setKeyRange(pkFactory.createPartitionKeyOnly(minKey), pkFactory.createPartitionKeyOnly(maxKey));
metadataBuilder.setRowIdRange(terms.getMinSSTableRowId(), terms.getMaxSSTableRowId());
metadataBuilder.setTermRange(terms.getMinTerm(), terms.getMaxTerm());
@@ -203,7 +207,7 @@ private long flush(DecoratedKey minKey, DecoratedKey maxKey, AbstractType<?> ter
SegmentMetadata.write(writer, Collections.singletonList(metadata));
}

return numRows;
return numPostings;
}

private boolean writeFrequencies()
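Note: the numRows/numPostings split matters for analyzed and collection indexes, where a single row can emit several postings; the empty-segment check and the return value stay on postings, while setNumRows(numRows) records distinct rows. A toy illustration (numbers hypothetical):

    int[] termsPerRow = { 3, 1, 2 };                                     // e.g. collection sizes of three indexed rows
    long numPostings = java.util.stream.IntStream.of(termsPerRow).sum(); // 6 -> drives the empty-index check and return value
    long numRows = termsPerRow.length;                                   // 3 -> recorded via metadataBuilder.setNumRows(numRows)
    assert numPostings >= numRows;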
@@ -19,7 +19,6 @@
package org.apache.cassandra.index.sai.disk.v1;

import java.io.IOException;
import java.nio.ByteBuffer;
import javax.annotation.concurrent.NotThreadSafe;
import javax.annotation.concurrent.ThreadSafe;

@@ -142,7 +141,6 @@ public void close() throws IOException
private final IKeyFetcher keyFetcher;
private final PrimaryKey.Factory primaryKeyFactory;
private final SSTableId<?> sstableId;
private final ByteBuffer tokenBuffer = ByteBuffer.allocate(Long.BYTES);

private PartitionAwarePrimaryKeyMap(LongArray rowIdToToken,
LongArray rowIdToOffset,
@@ -168,9 +166,8 @@ public SSTableId<?> getSSTableId()
@Override
public PrimaryKey primaryKeyFromRowId(long sstableRowId)
{
tokenBuffer.putLong(rowIdToToken.get(sstableRowId));
tokenBuffer.rewind();
return primaryKeyFactory.createDeferred(partitioner.getTokenFactory().fromByteArray(tokenBuffer), () -> supplier(sstableRowId));
long token = rowIdToToken.get(sstableRowId);
return primaryKeyFactory.createDeferred(partitioner.getTokenFactory().fromLongValue(token), () -> supplier(sstableRowId));
}

@Override
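Note: the point of fromLongValue here is removing a per-lookup ByteBuffer round trip (and a shared mutable buffer) from a hot path; schematically, with the field names used in this file:

    // Before: serialize the long into a shared buffer, then parse it back out.
    tokenBuffer.putLong(rowIdToToken.get(sstableRowId));
    tokenBuffer.rewind();
    Token before = partitioner.getTokenFactory().fromByteArray(tokenBuffer);

    // After: hand the primitive straight to the factory, no serde.
    Token after = partitioner.getTokenFactory().fromLongValue(rowIdToToken.get(sstableRowId));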
@@ -113,6 +113,7 @@ public void addRow(PrimaryKey key, Row row, long sstableRowId) throws IOExceptio
if (indexContext.getDefinition().isStatic() != row.isStatic())
return;

boolean addedRow = false;
if (indexContext.isNonFrozenCollection())
{
Iterator<ByteBuffer> valueIterator = indexContext.getValuesOf(row, nowInSec);
@@ -121,16 +122,21 @@ public void addRow(PrimaryKey key, Row row, long sstableRowId) throws IOExceptio
while (valueIterator.hasNext())
{
ByteBuffer value = valueIterator.next();
addTerm(TypeUtil.asIndexBytes(value.duplicate(), indexContext.getValidator()), key, sstableRowId, indexContext.getValidator());
addedRow |= addTerm(TypeUtil.asIndexBytes(value.duplicate(), indexContext.getValidator()), key, sstableRowId, indexContext.getValidator());
}
}
}
else
{
ByteBuffer value = indexContext.getValueOf(key.partitionKey(), row, nowInSec);
if (value != null)
addTerm(TypeUtil.asIndexBytes(value.duplicate(), indexContext.getValidator()), key, sstableRowId, indexContext.getValidator());
{
addedRow = addTerm(TypeUtil.asIndexBytes(value.duplicate(), indexContext.getValidator()), key, sstableRowId, indexContext.getValidator());
}
}
if (addedRow)
currentBuilder.incRowCount();

}

@Override
Expand Down Expand Up @@ -225,10 +231,10 @@ private boolean maybeAbort()
return true;
}

private void addTerm(ByteBuffer term, PrimaryKey key, long sstableRowId, AbstractType<?> type) throws IOException
private boolean addTerm(ByteBuffer term, PrimaryKey key, long sstableRowId, AbstractType<?> type) throws IOException
{
if (!indexContext.validateMaxTermSize(key.partitionKey(), term))
return;
return false;

if (currentBuilder == null)
{
@@ -241,10 +247,11 @@ else if (shouldFlush(sstableRowId))
}

if (term.remaining() == 0 && TypeUtil.skipsEmptyValue(indexContext.getValidator()))
return;
return false;

long allocated = currentBuilder.analyzeAndAdd(term, type, key, sstableRowId);
limiter.increment(allocated);
return true;
}

private boolean shouldFlush(long sstableRowId)
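Note: row counting now lives with the caller — addTerm(...) reports whether it actually indexed the term, and addRow(...) increments the per-segment row count at most once per row. For non-frozen collections the per-value results are OR-ed, so a row is counted when at least one of its values was indexed; schematically:

    boolean addedRow = false;
    while (valueIterator.hasNext())
        addedRow |= addTerm(asIndexBytes(valueIterator.next()), key, sstableRowId, validator); // true if the term landed
    if (addedRow)
        currentBuilder.incRowCount();   // once per row, never per term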
@@ -464,6 +464,7 @@ public SegmentMetadata flush() throws IOException
metadataBuilder.setKeyRange(minKey, maxKey);
metadataBuilder.setRowIdRange(minSSTableRowId, maxSSTableRowId);
metadataBuilder.setTermRange(minTerm, maxTerm);
metadataBuilder.setNumRows(getRowCount());

flushInternal(metadataBuilder);
return metadataBuilder.build();
@@ -502,8 +503,6 @@ private long add(List<ByteBuffer> terms, PrimaryKey key, long sstableRowId)
maxTerm = TypeUtil.max(term, maxTerm, termComparator, Version.latest());
}

rowCount++;

// segmentRowIdOffset should encode sstableRowId into Integer
int segmentRowId = Math.toIntExact(sstableRowId - segmentRowIdOffset);

@@ -600,6 +599,11 @@ int getRowCount()
return rowCount;
}

void incRowCount()
{
rowCount++;
}

/**
* @return true if next SSTable row ID exceeds max segment row ID
*/
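Note: net effect across SegmentBuilder — add() no longer bumps rowCount; the writer drives incRowCount() once per indexed row, and flush() persists the total through setNumRows(getRowCount()), so the row count in SegmentMetadata means rows rather than postings. A compressed view of the new flow (bodies elided, names from this diff):

    boolean addedRow = addTerm(term, key, sstableRowId, type); // writer side
    if (addedRow)
        currentBuilder.incRowCount();                          // replaces the old rowCount++ inside add()
    SegmentMetadata metadata = currentBuilder.flush();         // flush() now calls setNumRows(getRowCount())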