
CNDB-14317: Optimize doc freq computation for memtable BM25 queries #1789


Merged: 8 commits, Jun 9, 2025
126 changes: 59 additions & 67 deletions src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java
@@ -26,6 +26,7 @@
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.LongAdder;
@@ -130,7 +131,7 @@ public int indexedRows()

/**
* Approximate total count of terms in the memory index.
* The count is approximate because deletions are not accounted for.
* The count is approximate because range deletions are not accounted for.
*
* @return total count of terms for indexed rows.
*/
@@ -320,6 +321,20 @@ public KeyRangeIterator search(QueryContext queryContext, Expression expression,
return builder.build();
}

public KeyRangeIterator eagerSearch(Expression expression, AbstractBounds<PartitionPosition> keyRange)
{
int startShard = boundaries.getShardForToken(keyRange.left.getToken());
int endShard = keyRange.right.isMinimum() ? boundaries.shardCount() - 1 : boundaries.getShardForToken(keyRange.right.getToken());

KeyRangeConcatIterator.Builder builder = KeyRangeConcatIterator.builder(endShard - startShard + 1);
for (int shard = startShard; shard <= endShard; ++shard)
{
assert rangeIndexes[shard] != null;
builder.add(rangeIndexes[shard].search(expression, keyRange));
}
return builder.build();
}
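
Note (context, not part of the diff): KeyRangeConcatIterator concatenates disjoint per-shard iterators, so getMaxKeys() on the result is the sum of the per-shard match counts rather than an estimate. A minimal usage sketch for deriving a per-term document frequency, assuming a single analyzed query term:

    // Sketch only: `term` and `keyRange` stand in for caller-supplied values.
    Expression expr = new Expression(indexContext);
    expr.add(Operator.ANALYZER_MATCHES, term);               // match rows whose analyzed value contains the term
    KeyRangeIterator matches = eagerSearch(expr, keyRange);  // every shard is searched up front
    long docFrequency = matches.getMaxKeys();                // exact and cumulative over shards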

@Override
public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(QueryContext queryContext,
Orderer orderer,
@@ -334,10 +349,21 @@ public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(QueryContext query
{
// Intersect iterators to find documents containing all terms
List<ByteBuffer> queryTerms = orderer.getQueryTerms();
List<KeyRangeIterator> termIterators = keyIteratorsPerTerm(queryContext, keyRange, queryTerms);
Map<ByteBuffer, Long> documentFrequencies = new HashMap<>();
List<KeyRangeIterator> termIterators = new ArrayList<>(queryTerms.size());
for (ByteBuffer term : queryTerms)
{
Expression expr = new Expression(indexContext);
expr.add(Operator.ANALYZER_MATCHES, term);
// Because this is an in-memory, eager search, the max keys is exact and cumulative over all shards.
// Also, the key range is not eagerly applied, so the count may include keys outside the requested range.
KeyRangeIterator iterator = eagerSearch(expr, keyRange);
documentFrequencies.put(term, iterator.getMaxKeys());
termIterators.add(iterator);
}
KeyRangeIterator intersectedIterator = KeyRangeIntersectionIterator.builder(termIterators).build();

return List.of(orderByBM25(Streams.stream(intersectedIterator), orderer));
return List.of(orderByBM25(Streams.stream(intersectedIterator), documentFrequencies, orderer));
}
else
{
@@ -351,19 +377,6 @@ public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(QueryContext query
}
}

private List<KeyRangeIterator> keyIteratorsPerTerm(QueryContext queryContext, AbstractBounds<PartitionPosition> keyRange, List<ByteBuffer> queryTerms)
{
List<KeyRangeIterator> termIterators = new ArrayList<>(queryTerms.size());
for (ByteBuffer term : queryTerms)
{
Expression expr = new Expression(indexContext);
expr.add(Operator.ANALYZER_MATCHES, term);
KeyRangeIterator iterator = search(queryContext, expr, keyRange, Integer.MAX_VALUE);
termIterators.add(iterator);
}
return termIterators;
}

@Override
public long estimateMatchingRowsCount(Expression expression, AbstractBounds<PartitionPosition> keyRange)
{
@@ -372,15 +385,41 @@ public long estimateMatchingRowsCount(Expression expression, AbstractBounds<Part
return rangeIndexes[startShard].estimateMatchingRowsCount(expression, keyRange) * (endShard - startShard + 1);
}

// In the BM25 logic, estimateMatchingRowsCount is not accurate enough because we use the result to compute the
// document score.
private long completeEstimateMatchingRowsCount(Expression expression, AbstractBounds<PartitionPosition> keyRange)
{
int startShard = boundaries.getShardForToken(keyRange.left.getToken());
int endShard = keyRange.right.isMinimum() ? boundaries.shardCount() - 1 : boundaries.getShardForToken(keyRange.right.getToken());
long count = 0;
for (int shard = startShard; shard <= endShard; ++shard)
{
assert rangeIndexes[shard] != null;
count += rangeIndexes[shard].estimateMatchingRowsCount(expression, keyRange);
}
return count;
}
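
Note (context, not part of the diff): estimateMatchingRowsCount above extrapolates from a single shard (first-shard count multiplied by the shard span), which is cheap but sensitive to skew, while completeEstimateMatchingRowsCount sums every shard. For example, if four shards match 10, 0, 3 and 7 rows, extrapolation gives 10 * 4 = 40 where the true total is 20; BM25 needs the accurate figure because the document frequency feeds directly into each term's IDF weight.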

@Override
public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(QueryContext queryContext, List<PrimaryKey> keys, Orderer orderer, int limit)
{
if (keys.isEmpty())
return CloseableIterator.emptyIterator();

if (orderer.isBM25())
return orderByBM25(keys.stream(), orderer);
{
HashMap<ByteBuffer, Long> documentFrequencies = new HashMap<>();
// We don't want to filter the document frequencies, so we use the whole range
DataRange dataRange = DataRange.allData(memtable.metadata().partitioner);
for (ByteBuffer term : orderer.getQueryTerms())
{
Expression expression = new Expression(indexContext).add(Operator.ANALYZER_MATCHES, term);
documentFrequencies.put(term, completeEstimateMatchingRowsCount(expression, dataRange.keyRange()));
}
return orderByBM25(keys.stream(), documentFrequencies, orderer);
}
else
{
return SortingIterator.createCloseable(
orderer.getComparator(),
keys,
@@ -403,14 +442,15 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(QueryContext quer
},
Runnables.doNothing()
);
}
}

private CloseableIterator<PrimaryKeyWithSortKey> orderByBM25(Stream<PrimaryKey> stream, Orderer orderer)
private CloseableIterator<PrimaryKeyWithSortKey> orderByBM25(Stream<PrimaryKey> stream, Map<ByteBuffer, Long> documentFrequencies, Orderer orderer)
{
assert orderer.isBM25();
List<ByteBuffer> queryTerms = orderer.getQueryTerms();
AbstractAnalyzer analyzer = indexContext.getAnalyzerFactory().create();
BM25Utils.DocStats docStats = computeDocumentFrequencies(queryTerms, analyzer);
BM25Utils.DocStats docStats = new BM25Utils.DocStats(documentFrequencies, indexedRows(), approximateTotalTermCount());
Iterator<BM25Utils.DocTF> it = stream
.map(pk -> BM25Utils.EagerDocTF.createFromDocument(pk, getCellForKey(pk), analyzer, queryTerms))
.filter(Objects::nonNull)
@@ -422,54 +462,6 @@ private CloseableIterator<PrimaryKeyWithSortKey> orderByBM25(Stream<PrimaryKey>
memtable);
}
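
For reference, assuming BM25Utils implements the standard Okapi BM25 formulation (its implementation is not shown in this diff), the statistics assembled above enter the score as:

    idf(t) = ln(1 + (N - df(t) + 0.5) / (df(t) + 0.5))
    score(D, Q) = sum over t in Q of idf(t) * tf(t, D) * (k1 + 1)
                  / (tf(t, D) + k1 * (1 - b + b * |D| / avgdl))

where df(t) is the entry in documentFrequencies, N is indexedRows(), avgdl is approximateTotalTermCount() / N, and k1 and b are the usual free parameters.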

/**
* Count document frequencies for each term using brute force
*/
private BM25Utils.DocStats computeDocumentFrequencies(List<ByteBuffer> queryTerms, AbstractAnalyzer docAnalyzer)
{
var documentFrequencies = new HashMap<ByteBuffer, Long>();

// count all documents in the queried column
try (var it = memtable.makePartitionIterator(ColumnFilter.selection(RegularAndStaticColumns.of(indexContext.getDefinition())),
DataRange.allData(memtable.metadata().partitioner)))
{
while (it.hasNext())
{
var partitions = it.next();
while (partitions.hasNext())
{
var unfiltered = partitions.next();
if (!unfiltered.isRow())
continue;
var row = (Row) unfiltered;
var cell = row.getCell(indexContext.getDefinition());
if (cell == null)
continue;

Set<ByteBuffer> queryTermsPerDoc = new HashSet<>(queryTerms.size());
docAnalyzer.reset(cell.buffer());
try
{
while (docAnalyzer.hasNext())
{
ByteBuffer term = docAnalyzer.next();
if (queryTerms.contains(term))
queryTermsPerDoc.add(term);
}
}
finally
{
docAnalyzer.end();
}
for (ByteBuffer term : queryTermsPerDoc)
documentFrequencies.merge(term, 1L, Long::sum);

}
}
}
return new BM25Utils.DocStats(documentFrequencies, indexedRows(), approximateTotalTermCount());
}

@Nullable
private org.apache.cassandra.db.rows.Cell<?> getCellForKey(PrimaryKey key)
{