
Commit 8cccc37

CNDB-14317: Optimize doc freq computation for memtable BM25 queries (#1789)
- **CNDB-14317: Optimize doc freq computation for memtable BM25 queries**
- **Update approximateTotalTermCount() javadoc**

### What is the issue

Fixes riptano/cndb#14317

### What does this PR fix and why was it fixed

We optimize the in-memory BM25 computation by using the trie to get the number of rows matching a query term. This change removes a memtable scan-and-analyze pass, replacing it with calls to the index to get the number of matching docs. There are no new tests because the change is expected to maintain the same semantics. I will review the test coverage to verify that assertion.
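To make the shape of the change concrete, here is a minimal, self-contained Java sketch. It is illustrative only — `DocFreqSketch` and its toy maps are not Cassandra classes: the old path re-analyzes every document on every query to count how many contain each term, while the new path reads each term's document frequency from a structure the index already maintains. In the actual patch the counts come from `eagerSearch(expr, keyRange).getMaxKeys()` per term (or `completeEstimateMatchingRowsCount` in `orderResultsBy`), as the diff below shows.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

// Toy illustration only -- not Cassandra code. It contrasts the old brute-force
// document-frequency computation with reading the frequency from an index that
// is already maintained (as the memtable trie index is).
public class DocFreqSketch
{
    public static void main(String[] args)
    {
        Map<Integer, String> docs = Map.of(1, "apple", 2, "apple juice", 3, "orange juice");

        // The index side: term -> ids of documents containing it. The memtable
        // index keeps an equivalent structure in its trie, so it is "free" here.
        Map<String, Set<Integer>> index = new HashMap<>();
        docs.forEach((id, text) -> {
            for (String term : text.split(" "))
                index.computeIfAbsent(term, t -> new TreeSet<>()).add(id);
        });

        List<String> queryTerms = List.of("apple", "juice");

        // Old approach: re-analyze every document on every BM25 query.
        Map<String, Long> bruteForce = new HashMap<>();
        for (String text : docs.values())
            for (String term : new HashSet<>(Arrays.asList(text.split(" "))))
                if (queryTerms.contains(term))
                    bruteForce.merge(term, 1L, Long::sum);

        // New approach: ask the index for the number of matching documents per term.
        Map<String, Long> fromIndex = new HashMap<>();
        for (String term : queryTerms)
            fromIndex.put(term, (long) index.getOrDefault(term, Set.of()).size());

        System.out.println(bruteForce); // {apple=2, juice=2} (iteration order may vary)
        System.out.println(fromIndex);
    }
}
```

The saving is the same in the toy and in the real code: no per-query pass over every row and no re-tokenization of each document just to obtain the counts.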
1 parent 431b764 commit 8cccc37

File tree

2 files changed (+101, -76 lines)

src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java

Lines changed: 73 additions & 76 deletions
@@ -23,11 +23,10 @@
 import java.util.Arrays;
 import java.util.Comparator;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
-import java.util.Set;
 import java.util.concurrent.atomic.LongAdder;
 import java.util.stream.Stream;
 import javax.annotation.Nullable;
@@ -39,17 +38,14 @@
 
 import org.apache.cassandra.cql3.Operator;
 import org.apache.cassandra.db.Clustering;
-import org.apache.cassandra.db.DataRange;
 import org.apache.cassandra.db.DecoratedKey;
 import org.apache.cassandra.db.PartitionPosition;
-import org.apache.cassandra.db.RegularAndStaticColumns;
-import org.apache.cassandra.db.filter.ColumnFilter;
 import org.apache.cassandra.db.marshal.AbstractType;
 import org.apache.cassandra.db.memtable.Memtable;
 import org.apache.cassandra.db.memtable.ShardBoundaries;
 import org.apache.cassandra.db.memtable.TrieMemtable;
-import org.apache.cassandra.db.rows.Row;
 import org.apache.cassandra.dht.AbstractBounds;
+import org.apache.cassandra.dht.Range;
 import org.apache.cassandra.index.sai.IndexContext;
 import org.apache.cassandra.index.sai.QueryContext;
 import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer;
@@ -130,7 +126,7 @@ public int indexedRows()
 
     /**
      * Approximate total count of terms in the memory index.
-     * The count is approximate because deletions are not accounted for.
+     * The count is approximate because some deletions are not accounted for in the current implementation.
      *
      * @return total count of terms for indexes rows.
      */
@@ -290,7 +286,7 @@ public void update(DecoratedKey key, Clustering clustering, Iterator<ByteBuffer>
     public KeyRangeIterator search(QueryContext queryContext, Expression expression, AbstractBounds<PartitionPosition> keyRange, int limit)
     {
         int startShard = boundaries.getShardForToken(keyRange.left.getToken());
-        int endShard = keyRange.right.isMinimum() ? boundaries.shardCount() - 1 : boundaries.getShardForToken(keyRange.right.getToken());
+        int endShard = getEndShardForBounds(keyRange);
 
         KeyRangeConcatIterator.Builder builder = KeyRangeConcatIterator.builder(endShard - startShard + 1);
 
@@ -320,6 +316,20 @@ public KeyRangeIterator search(QueryContext queryContext, Expression expression,
         return builder.build();
     }
 
+    public KeyRangeIterator eagerSearch(Expression expression, AbstractBounds<PartitionPosition> keyRange)
+    {
+        int startShard = boundaries.getShardForToken(keyRange.left.getToken());
+        int endShard = getEndShardForBounds(keyRange);
+
+        KeyRangeConcatIterator.Builder builder = KeyRangeConcatIterator.builder(endShard - startShard + 1);
+        for (int shard = startShard; shard <= endShard; ++shard)
+        {
+            assert rangeIndexes[shard] != null;
+            builder.add(rangeIndexes[shard].search(expression, keyRange));
+        }
+        return builder.build();
+    }
+
     @Override
     public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(QueryContext queryContext,
                                                                   Orderer orderer,
@@ -328,16 +338,29 @@ public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(QueryContext query
                                                                   int limit)
     {
         int startShard = boundaries.getShardForToken(keyRange.left.getToken());
-        int endShard = keyRange.right.isMinimum() ? boundaries.shardCount() - 1 : boundaries.getShardForToken(keyRange.right.getToken());
+        int endShard = getEndShardForBounds(keyRange);
 
         if (orderer.isBM25())
         {
             // Intersect iterators to find documents containing all terms
             List<ByteBuffer> queryTerms = orderer.getQueryTerms();
-            List<KeyRangeIterator> termIterators = keyIteratorsPerTerm(queryContext, keyRange, queryTerms);
+            Map<ByteBuffer, Long> documentFrequencies = new HashMap<>();
+            List<KeyRangeIterator> termIterators = new ArrayList<>(queryTerms.size());
+            for (ByteBuffer term : queryTerms)
+            {
+                Expression expr = new Expression(indexContext).add(Operator.ANALYZER_MATCHES, term);
+                // getMaxKeys() counts all rows that match the expression for shards within the key range. The key
+                // range is not applied to the search results yet, so there is a small chance for overcounting if
+                // the key range filters within a shard. This is assumed to be acceptable because the on disk
+                // estimate also uses the key range to skip irrelevant sstable segments but does not apply the key
+                // range when getting the estimate within a segment.
+                KeyRangeIterator iterator = eagerSearch(expr, keyRange);
+                documentFrequencies.put(term, iterator.getMaxKeys());
+                termIterators.add(iterator);
+            }
            KeyRangeIterator intersectedIterator = KeyRangeIntersectionIterator.builder(termIterators).build();
 
-            return List.of(orderByBM25(Streams.stream(intersectedIterator), orderer));
+            return List.of(orderByBM25(Streams.stream(intersectedIterator), documentFrequencies, orderer));
         }
         else
         {
@@ -351,36 +374,50 @@ public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(QueryContext query
         }
     }
 
-    private List<KeyRangeIterator> keyIteratorsPerTerm(QueryContext queryContext, AbstractBounds<PartitionPosition> keyRange, List<ByteBuffer> queryTerms)
-    {
-        List<KeyRangeIterator> termIterators = new ArrayList<>(queryTerms.size());
-        for (ByteBuffer term : queryTerms)
-        {
-            Expression expr = new Expression(indexContext);
-            expr.add(Operator.ANALYZER_MATCHES, term);
-            KeyRangeIterator iterator = search(queryContext, expr, keyRange, Integer.MAX_VALUE);
-            termIterators.add(iterator);
-        }
-        return termIterators;
-    }
-
     @Override
     public long estimateMatchingRowsCount(Expression expression, AbstractBounds<PartitionPosition> keyRange)
     {
         int startShard = boundaries.getShardForToken(keyRange.left.getToken());
-        int endShard = keyRange.right.isMinimum() ? boundaries.shardCount() - 1 : boundaries.getShardForToken(keyRange.right.getToken());
+        int endShard = getEndShardForBounds(keyRange);
         return rangeIndexes[startShard].estimateMatchingRowsCount(expression, keyRange) * (endShard - startShard + 1);
     }
 
+    // In the BM25 logic, estimateMatchingRowsCount is not accurate enough because we use the result to compute the
+    // document score.
+    private long completeEstimateMatchingRowsCount(Expression expression, AbstractBounds<PartitionPosition> keyRange)
+    {
+        int startShard = boundaries.getShardForToken(keyRange.left.getToken());
+        int endShard = getEndShardForBounds(keyRange);
+        long count = 0;
+        for (int shard = startShard; shard <= endShard; ++shard)
+        {
+            assert rangeIndexes[shard] != null;
+            count += rangeIndexes[shard].estimateMatchingRowsCount(expression, keyRange);
+        }
+        return count;
+    }
+
     @Override
     public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(QueryContext queryContext, List<PrimaryKey> keys, Orderer orderer, int limit)
     {
         if (keys.isEmpty())
             return CloseableIterator.emptyIterator();
 
         if (orderer.isBM25())
-            return orderByBM25(keys.stream(), orderer);
+        {
+            HashMap<ByteBuffer, Long> documentFrequencies = new HashMap<>();
+            // We only need to get the document frequencies for the shards that contain the keys.
+            Range<PartitionPosition> range = Range.makeRowRange(keys.get(0).partitionKey().getToken(),
+                                                                keys.get(keys.size() - 1).partitionKey().getToken());
+            for (ByteBuffer term : orderer.getQueryTerms())
+            {
+                Expression expression = new Expression(indexContext).add(Operator.ANALYZER_MATCHES, term);
+                documentFrequencies.put(term, completeEstimateMatchingRowsCount(expression, range));
+            }
+            return orderByBM25(keys.stream(), documentFrequencies, orderer);
+        }
         else
+        {
             return SortingIterator.createCloseable(
                 orderer.getComparator(),
                 keys,
@@ -403,14 +440,15 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(QueryContext quer
                 },
                 Runnables.doNothing()
             );
+        }
     }
 
-    private CloseableIterator<PrimaryKeyWithSortKey> orderByBM25(Stream<PrimaryKey> stream, Orderer orderer)
+    private CloseableIterator<PrimaryKeyWithSortKey> orderByBM25(Stream<PrimaryKey> stream, Map<ByteBuffer, Long> documentFrequencies, Orderer orderer)
     {
         assert orderer.isBM25();
         List<ByteBuffer> queryTerms = orderer.getQueryTerms();
         AbstractAnalyzer analyzer = indexContext.getAnalyzerFactory().create();
-        BM25Utils.DocStats docStats = computeDocumentFrequencies(queryTerms, analyzer);
+        BM25Utils.DocStats docStats = new BM25Utils.DocStats(documentFrequencies, indexedRows(), approximateTotalTermCount());
         Iterator<BM25Utils.DocTF> it = stream
             .map(pk -> BM25Utils.EagerDocTF.createFromDocument(pk, getCellForKey(pk), analyzer, queryTerms))
             .filter(Objects::nonNull)
@@ -422,54 +460,6 @@ private CloseableIterator<PrimaryKeyWithSortKey> orderByBM25(Stream<PrimaryKey>
                                       memtable);
     }
 
-    /**
-     * Count document frequencies for each term using brute force
-     */
-    private BM25Utils.DocStats computeDocumentFrequencies(List<ByteBuffer> queryTerms, AbstractAnalyzer docAnalyzer)
-    {
-        var documentFrequencies = new HashMap<ByteBuffer, Long>();
-
-        // count all documents in the queried column
-        try (var it = memtable.makePartitionIterator(ColumnFilter.selection(RegularAndStaticColumns.of(indexContext.getDefinition())),
-                                                     DataRange.allData(memtable.metadata().partitioner)))
-        {
-            while (it.hasNext())
-            {
-                var partitions = it.next();
-                while (partitions.hasNext())
-                {
-                    var unfiltered = partitions.next();
-                    if (!unfiltered.isRow())
-                        continue;
-                    var row = (Row) unfiltered;
-                    var cell = row.getCell(indexContext.getDefinition());
-                    if (cell == null)
-                        continue;
-
-                    Set<ByteBuffer> queryTermsPerDoc = new HashSet<>(queryTerms.size());
-                    docAnalyzer.reset(cell.buffer());
-                    try
-                    {
-                        while (docAnalyzer.hasNext())
-                        {
-                            ByteBuffer term = docAnalyzer.next();
-                            if (queryTerms.contains(term))
-                                queryTermsPerDoc.add(term);
-                        }
-                    }
-                    finally
-                    {
-                        docAnalyzer.end();
-                    }
-                    for (ByteBuffer term : queryTermsPerDoc)
-                        documentFrequencies.merge(term, 1L, Long::sum);
-
-                }
-            }
-        }
-        return new BM25Utils.DocStats(documentFrequencies, indexedRows(), approximateTotalTermCount());
-    }
-
     @Nullable
     private org.apache.cassandra.db.rows.Cell<?> getCellForKey(PrimaryKey key)
     {
@@ -487,6 +477,13 @@ private ByteComparable encode(ByteBuffer input)
         return Version.current().onDiskFormat().encodeForTrie(input, indexContext.getValidator());
     }
 
+    private int getEndShardForBounds(AbstractBounds<PartitionPosition> bounds)
+    {
+        PartitionPosition position = bounds.right;
+        return position.isMinimum() ? boundaries.shardCount() - 1
+                                    : boundaries.getShardForToken(position.getToken());
+    }
+
     /**
      * NOTE: returned data may contain partition key not within the provided min and max which are only used to find
      * corresponding subranges. We don't do filtering here to avoid unnecessary token comparison. In case of JBOD,
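For context on why these counts feed a score rather than a filter: the DocStats built above packages the per-term document frequencies together with indexedRows() and approximateTotalTermCount(). Assuming BM25Utils follows the standard Okapi BM25 formulation (its exact constants are not shown in this diff), those values enter the score roughly as:

\[
\mathrm{score}(D, Q) = \sum_{t \in Q} \mathrm{IDF}(t)\,\frac{f(t, D)\,(k_1 + 1)}{f(t, D) + k_1\left(1 - b + b\,\frac{|D|}{\mathrm{avgdl}}\right)},
\qquad
\mathrm{IDF}(t) = \ln\!\left(\frac{N - \mathrm{df}(t) + 0.5}{\mathrm{df}(t) + 0.5} + 1\right)
\]

with df(t) taken from documentFrequencies, N from indexedRows(), and avgdl ≈ approximateTotalTermCount() / N. Since IDF decreases as df(t) grows, the small per-shard overcounting noted in the getMaxKeys() comment only deflates a term's weight slightly, which is why the diff treats it as acceptable.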

test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java

Lines changed: 28 additions & 0 deletions
@@ -128,6 +128,34 @@ public void testDeletedRowWithPredicate() throws Throwable
         beforeAndAfterFlush(() -> assertRows(execute(select), row(1)));
     }
 
+    @Test
+    public void testRangeRestrictedBM25OnlyQuery() throws Throwable
+    {
+        createTable("CREATE TABLE %s (k int PRIMARY KEY, v text, n int)");
+        createIndex("CREATE CUSTOM INDEX ON %s(n) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'");
+        createAnalyzedIndex();
+        execute("INSERT INTO %s (k, v, n) VALUES (1, 'apple', 0)");
+        execute("INSERT INTO %s (k, v, n) VALUES (2, 'apple juice', 0)");
+        String select = "SELECT k FROM %s WHERE token(k) > token(1) AND token(k) < token(3) ORDER BY v BM25 OF 'apple' LIMIT 3";
+        beforeAndAfterFlush(() -> assertRows(execute(select), row(2)));
+    }
+
+    @Test
+    public void testRangeRestrictedHybridQuery() throws Throwable
+    {
+        createTable("CREATE TABLE %s (k int PRIMARY KEY, v text, n int)");
+        createIndex("CREATE CUSTOM INDEX ON %s(n) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'");
+        createAnalyzedIndex();
+        execute("INSERT INTO %s (k, v, n) VALUES (1, 'apple', 0)");
+        execute("INSERT INTO %s (k, v, n) VALUES (2, 'apple juice', 0)");
+        // Insert many unrelated rows so we do search-then-sort
+        for (int i = 3; i < 100; i++)
+            execute("INSERT INTO %s (k, v, n) VALUES (?, 'apple juice', 1)", i);
+        String select = "SELECT k FROM %s WHERE token(k) > token(1) AND token(k) < token(3) " +
+                        "AND n = 0 ORDER BY v BM25 OF 'apple' LIMIT 3";
+        beforeAndAfterFlush(() -> assertRows(execute(select), row(2)));
+    }
+
     @Test
     public void testTwoIndexesAmbiguousPredicate() throws Throwable
     {
