CNDB-14317: Optimize doc freq computation for memtable BM25 queries #1789

Merged: 8 commits, Jun 9, 2025
149 changes: 73 additions & 76 deletions src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java
@@ -23,11 +23,10 @@
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Stream;
import javax.annotation.Nullable;
@@ -39,17 +38,14 @@

import org.apache.cassandra.cql3.Operator;
import org.apache.cassandra.db.Clustering;
import org.apache.cassandra.db.DataRange;
import org.apache.cassandra.db.DecoratedKey;
import org.apache.cassandra.db.PartitionPosition;
import org.apache.cassandra.db.RegularAndStaticColumns;
import org.apache.cassandra.db.filter.ColumnFilter;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.memtable.Memtable;
import org.apache.cassandra.db.memtable.ShardBoundaries;
import org.apache.cassandra.db.memtable.TrieMemtable;
import org.apache.cassandra.db.rows.Row;
import org.apache.cassandra.dht.AbstractBounds;
import org.apache.cassandra.dht.Range;
import org.apache.cassandra.index.sai.IndexContext;
import org.apache.cassandra.index.sai.QueryContext;
import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer;
@@ -130,7 +126,7 @@ public int indexedRows()

/**
* Approximate total count of terms in the memory index.
* The count is approximate because deletions are not accounted for.
* The count is approximate because some deletions are not accounted for in the current implementation.
*
* @return total count of terms for indexed rows.
*/
@@ -290,7 +286,7 @@ public void update(DecoratedKey key, Clustering clustering, Iterator<ByteBuffer>
public KeyRangeIterator search(QueryContext queryContext, Expression expression, AbstractBounds<PartitionPosition> keyRange, int limit)
{
int startShard = boundaries.getShardForToken(keyRange.left.getToken());
int endShard = keyRange.right.isMinimum() ? boundaries.shardCount() - 1 : boundaries.getShardForToken(keyRange.right.getToken());
int endShard = getEndShardForBounds(keyRange);

KeyRangeConcatIterator.Builder builder = KeyRangeConcatIterator.builder(endShard - startShard + 1);

@@ -320,6 +316,20 @@ public KeyRangeIterator search(QueryContext queryContext, Expression expression,
return builder.build();
}

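/**
 * Searches every shard covered by {@code keyRange}. The BM25 path below relies on the returned
 * iterator's {@code getMaxKeys()} as a per-term document-frequency estimate, so no per-shard
 * limit is taken here. (Descriptive comment added for review; not part of the original diff.)
 */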
public KeyRangeIterator eagerSearch(Expression expression, AbstractBounds<PartitionPosition> keyRange)
{
int startShard = boundaries.getShardForToken(keyRange.left.getToken());
int endShard = getEndShardForBounds(keyRange);

KeyRangeConcatIterator.Builder builder = KeyRangeConcatIterator.builder(endShard - startShard + 1);
for (int shard = startShard; shard <= endShard; ++shard)
{
assert rangeIndexes[shard] != null;
builder.add(rangeIndexes[shard].search(expression, keyRange));
}
return builder.build();
}

@Override
public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(QueryContext queryContext,
Orderer orderer,
@@ -328,16 +338,29 @@ public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(QueryContext query
int limit)
{
int startShard = boundaries.getShardForToken(keyRange.left.getToken());
int endShard = keyRange.right.isMinimum() ? boundaries.shardCount() - 1 : boundaries.getShardForToken(keyRange.right.getToken());
int endShard = getEndShardForBounds(keyRange);

if (orderer.isBM25())
{
// Intersect iterators to find documents containing all terms
List<ByteBuffer> queryTerms = orderer.getQueryTerms();
List<KeyRangeIterator> termIterators = keyIteratorsPerTerm(queryContext, keyRange, queryTerms);
Map<ByteBuffer, Long> documentFrequencies = new HashMap<>();
List<KeyRangeIterator> termIterators = new ArrayList<>(queryTerms.size());
for (ByteBuffer term : queryTerms)
{
Expression expr = new Expression(indexContext).add(Operator.ANALYZER_MATCHES, term);
// getMaxKeys() counts all rows that match the expression for shards within the key range. The key
// range is not applied to the search results yet, so there is a small chance of overcounting if
// the key range filters within a shard. This is assumed to be acceptable because the on-disk
// estimate also uses the key range to skip irrelevant sstable segments but does not apply the key
// range when getting the estimate within a segment.
KeyRangeIterator iterator = eagerSearch(expr, keyRange);
documentFrequencies.put(term, iterator.getMaxKeys());
termIterators.add(iterator);
}
KeyRangeIterator intersectedIterator = KeyRangeIntersectionIterator.builder(termIterators).build();

return List.of(orderByBM25(Streams.stream(intersectedIterator), orderer));
return List.of(orderByBM25(Streams.stream(intersectedIterator), documentFrequencies, orderer));
}
else
{
@@ -351,36 +374,50 @@ public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(QueryContext query
}
}
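To make the overcounting caveat concrete with illustrative numbers (hypothetical, not from this change): if a term matches 8 rows in one shard and 6 in the next, but the key range covers only 2 of those 6, getMaxKeys() on the concatenated iterator reports 14 where the true in-range count is 10. The frequency only feeds the logarithmic IDF component of the BM25 score, and the on-disk estimate accepts the same per-segment imprecision, so the overcount is treated as tolerable.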

private List<KeyRangeIterator> keyIteratorsPerTerm(QueryContext queryContext, AbstractBounds<PartitionPosition> keyRange, List<ByteBuffer> queryTerms)
{
List<KeyRangeIterator> termIterators = new ArrayList<>(queryTerms.size());
for (ByteBuffer term : queryTerms)
{
Expression expr = new Expression(indexContext);
expr.add(Operator.ANALYZER_MATCHES, term);
KeyRangeIterator iterator = search(queryContext, expr, keyRange, Integer.MAX_VALUE);
termIterators.add(iterator);
}
return termIterators;
}

@Override
public long estimateMatchingRowsCount(Expression expression, AbstractBounds<PartitionPosition> keyRange)
{
int startShard = boundaries.getShardForToken(keyRange.left.getToken());
int endShard = keyRange.right.isMinimum() ? boundaries.shardCount() - 1 : boundaries.getShardForToken(keyRange.right.getToken());
int endShard = getEndShardForBounds(keyRange);
return rangeIndexes[startShard].estimateMatchingRowsCount(expression, keyRange) * (endShard - startShard + 1);
}

// In the BM25 logic, estimateMatchingRowsCount is not accurate enough, because it extrapolates
// from a single shard and the result is used to compute the document score.
private long completeEstimateMatchingRowsCount(Expression expression, AbstractBounds<PartitionPosition> keyRange)
{
int startShard = boundaries.getShardForToken(keyRange.left.getToken());
int endShard = getEndShardForBounds(keyRange);
long count = 0;
for (int shard = startShard; shard <= endShard; ++shard)
{
assert rangeIndexes[shard] != null;
count += rangeIndexes[shard].estimateMatchingRowsCount(expression, keyRange);
}
return count;
}
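The difference between the two estimates is easy to see with hypothetical shard counts: if four shards hold 10, 0, 7, and 3 matching rows, estimateMatchingRowsCount extrapolates from the first shard and reports 10 × 4 = 40, while completeEstimateMatchingRowsCount sums the shards and reports 20. The cheap extrapolation is adequate for query planning, but BM25 feeds the count straight into per-term IDF, where a 2× error would shift every document's score.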

@Override
public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(QueryContext queryContext, List<PrimaryKey> keys, Orderer orderer, int limit)
{
if (keys.isEmpty())
return CloseableIterator.emptyIterator();

if (orderer.isBM25())
return orderByBM25(keys.stream(), orderer);
{
HashMap<ByteBuffer, Long> documentFrequencies = new HashMap<>();
// We only need to get the document frequencies for the shards that contain the keys.
Range<PartitionPosition> range = Range.makeRowRange(keys.get(0).partitionKey().getToken(),
keys.get(keys.size() - 1).partitionKey().getToken());
for (ByteBuffer term : orderer.getQueryTerms())
{
Expression expression = new Expression(indexContext).add(Operator.ANALYZER_MATCHES, term);
documentFrequencies.put(term, completeEstimateMatchingRowsCount(expression, range));
}
return orderByBM25(keys.stream(), documentFrequencies, orderer);
}
else
{
return SortingIterator.createCloseable(
orderer.getComparator(),
keys,
@@ -403,14 +440,15 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(QueryContext quer
},
Runnables.doNothing()
);
}
}

private CloseableIterator<PrimaryKeyWithSortKey> orderByBM25(Stream<PrimaryKey> stream, Orderer orderer)
private CloseableIterator<PrimaryKeyWithSortKey> orderByBM25(Stream<PrimaryKey> stream, Map<ByteBuffer, Long> documentFrequencies, Orderer orderer)
{
assert orderer.isBM25();
List<ByteBuffer> queryTerms = orderer.getQueryTerms();
AbstractAnalyzer analyzer = indexContext.getAnalyzerFactory().create();
BM25Utils.DocStats docStats = computeDocumentFrequencies(queryTerms, analyzer);
BM25Utils.DocStats docStats = new BM25Utils.DocStats(documentFrequencies, indexedRows(), approximateTotalTermCount());
Iterator<BM25Utils.DocTF> it = stream
.map(pk -> BM25Utils.EagerDocTF.createFromDocument(pk, getCellForKey(pk), analyzer, queryTerms))
.filter(Objects::nonNull)
@@ -422,54 +460,6 @@ private CloseableIterator<PrimaryKeyWithSortKey> orderByBM25(Stream<PrimaryKey>
memtable);
}
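For reviewers tracing how the precomputed frequencies are consumed: DocStats is built from the per-term frequencies, indexedRows() as the document count, and approximateTotalTermCount() (presumably yielding the average document length). A minimal sketch of conventional Lucene-style BM25 scoring follows; the exact formula and the constants k1 = 1.2 and b = 0.75 are assumptions about BM25Utils, not taken from this diff.

final class Bm25ScoringSketch
{
    // Robertson-style inverse document frequency; the +1 inside the log keeps it non-negative.
    static double idf(long totalDocs, long docFreq)
    {
        return Math.log(1 + (totalDocs - docFreq + 0.5) / (docFreq + 0.5));
    }

    // Contribution of a single query term to a document's score.
    static double termScore(long totalDocs, long docFreq, int termFreq, int docLen, double avgDocLen)
    {
        double k1 = 1.2;  // term-frequency saturation
        double b = 0.75;  // document-length normalization
        double norm = termFreq + k1 * (1 - b + b * docLen / avgDocLen);
        return idf(totalDocs, docFreq) * termFreq * (k1 + 1) / norm;
    }
}

An inflated docFreq lowers idf() and thus every occurrence's contribution for that term, which is why the PR replaces the brute-force document scan with exact per-iterator counts rather than the planner's sampled estimate.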

/**
* Count document frequencies for each term using brute force
*/
private BM25Utils.DocStats computeDocumentFrequencies(List<ByteBuffer> queryTerms, AbstractAnalyzer docAnalyzer)
{
var documentFrequencies = new HashMap<ByteBuffer, Long>();

// count all documents in the queried column
try (var it = memtable.makePartitionIterator(ColumnFilter.selection(RegularAndStaticColumns.of(indexContext.getDefinition())),
DataRange.allData(memtable.metadata().partitioner)))
{
while (it.hasNext())
{
var partitions = it.next();
while (partitions.hasNext())
{
var unfiltered = partitions.next();
if (!unfiltered.isRow())
continue;
var row = (Row) unfiltered;
var cell = row.getCell(indexContext.getDefinition());
if (cell == null)
continue;

Set<ByteBuffer> queryTermsPerDoc = new HashSet<>(queryTerms.size());
docAnalyzer.reset(cell.buffer());
try
{
while (docAnalyzer.hasNext())
{
ByteBuffer term = docAnalyzer.next();
if (queryTerms.contains(term))
queryTermsPerDoc.add(term);
}
}
finally
{
docAnalyzer.end();
}
for (ByteBuffer term : queryTermsPerDoc)
documentFrequencies.merge(term, 1L, Long::sum);

}
}
}
return new BM25Utils.DocStats(documentFrequencies, indexedRows(), approximateTotalTermCount());
}

@Nullable
private org.apache.cassandra.db.rows.Cell<?> getCellForKey(PrimaryKey key)
{
@@ -487,6 +477,13 @@ private ByteComparable encode(ByteBuffer input)
return Version.current().onDiskFormat().encodeForTrie(input, indexContext.getValidator());
}

private int getEndShardForBounds(AbstractBounds<PartitionPosition> bounds)
{
PartitionPosition position = bounds.right;
return position.isMinimum() ? boundaries.shardCount() - 1
: boundaries.getShardForToken(position.getToken());
}

/**
* NOTE: returned data may contain partition keys not within the provided min and max, which are only used to find
* corresponding subranges. We don't do filtering here to avoid unnecessary token comparisons. In case of JBOD,
28 changes: 28 additions & 0 deletions test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java
@@ -135,6 +135,34 @@ public void testDeletedRowWithPredicate() throws Throwable
beforeAndAfterFlush(() -> assertRows(execute(select), row(1)));
}

@Test
public void testRangeRestrictedBM25OnlyQuery() throws Throwable
{
createTable("CREATE TABLE %s (k int PRIMARY KEY, v text, n int)");
createIndex("CREATE CUSTOM INDEX ON %s(n) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'");
createAnalyzedIndex();
execute("INSERT INTO %s (k, v, n) VALUES (1, 'apple', 0)");
execute("INSERT INTO %s (k, v, n) VALUES (2, 'apple juice', 0)");
String select = "SELECT k FROM %s WHERE token(k) > token(1) AND token(k) < token(3) ORDER BY v BM25 OF 'apple' LIMIT 3";
beforeAndAfterFlush(() -> assertRows(execute(select), row(2)));
}

@Test
public void testRangeRestrictedHybridQuery() throws Throwable
{
createTable("CREATE TABLE %s (k int PRIMARY KEY, v text, n int)");
createIndex("CREATE CUSTOM INDEX ON %s(n) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'");
createAnalyzedIndex();
execute("INSERT INTO %s (k, v, n) VALUES (1, 'apple', 0)");
execute("INSERT INTO %s (k, v, n) VALUES (2, 'apple juice', 0)");
// Insert many unrelated rows so we do search-then-sort
for (int i = 3; i < 100; i++)
execute("INSERT INTO %s (k, v, n) VALUES (?, 'apple juice', 1)", i);
String select = "SELECT k FROM %s WHERE token(k) > token(1) AND token(k) < token(3) " +
"AND n = 0 ORDER BY v BM25 OF 'apple' LIMIT 3";
beforeAndAfterFlush(() -> assertRows(execute(select), row(2)));
}

@Test
public void testTwoIndexesAmbiguousPredicate() throws Throwable
{