23
23
import java .util .Arrays ;
24
24
import java .util .Comparator ;
25
25
import java .util .HashMap ;
26
- import java .util .HashSet ;
27
26
import java .util .Iterator ;
28
27
import java .util .List ;
28
+ import java .util .Map ;
29
29
import java .util .Objects ;
30
- import java .util .Set ;
31
30
import java .util .concurrent .atomic .LongAdder ;
32
31
import java .util .stream .Stream ;
33
32
import javax .annotation .Nullable ;
39
38
40
39
import org .apache .cassandra .cql3 .Operator ;
41
40
import org .apache .cassandra .db .Clustering ;
42
- import org .apache .cassandra .db .DataRange ;
43
41
import org .apache .cassandra .db .DecoratedKey ;
44
42
import org .apache .cassandra .db .PartitionPosition ;
45
- import org .apache .cassandra .db .RegularAndStaticColumns ;
46
- import org .apache .cassandra .db .filter .ColumnFilter ;
47
43
import org .apache .cassandra .db .marshal .AbstractType ;
48
44
import org .apache .cassandra .db .memtable .Memtable ;
49
45
import org .apache .cassandra .db .memtable .ShardBoundaries ;
50
46
import org .apache .cassandra .db .memtable .TrieMemtable ;
51
- import org .apache .cassandra .db .rows .Row ;
52
47
import org .apache .cassandra .dht .AbstractBounds ;
48
+ import org .apache .cassandra .dht .Range ;
53
49
import org .apache .cassandra .index .sai .IndexContext ;
54
50
import org .apache .cassandra .index .sai .QueryContext ;
55
51
import org .apache .cassandra .index .sai .analyzer .AbstractAnalyzer ;
@@ -130,7 +126,7 @@ public int indexedRows()
130
126
131
127
/**
132
128
* Approximate total count of terms in the memory index.
133
- * The count is approximate because deletions are not accounted for.
129
+ * The count is approximate because some deletions are not accounted for in the current implementation .
134
130
*
135
131
* @return total count of terms for indexed rows.
136
132
*/
@@ -290,7 +286,7 @@ public void update(DecoratedKey key, Clustering clustering, Iterator<ByteBuffer>
290
286
public KeyRangeIterator search (QueryContext queryContext , Expression expression , AbstractBounds <PartitionPosition > keyRange , int limit )
291
287
{
292
288
int startShard = boundaries .getShardForToken (keyRange .left .getToken ());
293
- int endShard = keyRange . right . isMinimum () ? boundaries . shardCount () - 1 : boundaries . getShardForToken ( keyRange . right . getToken () );
289
+ int endShard = getEndShardForBounds ( keyRange );
294
290
295
291
KeyRangeConcatIterator .Builder builder = KeyRangeConcatIterator .builder (endShard - startShard + 1 );
296
292
@@ -320,6 +316,20 @@ public KeyRangeIterator search(QueryContext queryContext, Expression expression,
320
316
return builder .build ();
321
317
}
322
318
319
+ public KeyRangeIterator eagerSearch (Expression expression , AbstractBounds <PartitionPosition > keyRange )
320
+ {
321
+ int startShard = boundaries .getShardForToken (keyRange .left .getToken ());
322
+ int endShard = getEndShardForBounds (keyRange );
323
+
324
+ KeyRangeConcatIterator .Builder builder = KeyRangeConcatIterator .builder (endShard - startShard + 1 );
325
+ for (int shard = startShard ; shard <= endShard ; ++shard )
326
+ {
327
+ assert rangeIndexes [shard ] != null ;
328
+ builder .add (rangeIndexes [shard ].search (expression , keyRange ));
329
+ }
330
+ return builder .build ();
331
+ }
332
+
323
333
@ Override
324
334
public List <CloseableIterator <PrimaryKeyWithSortKey >> orderBy (QueryContext queryContext ,
325
335
Orderer orderer ,
@@ -328,16 +338,29 @@ public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(QueryContext query
328
338
int limit )
329
339
{
330
340
int startShard = boundaries .getShardForToken (keyRange .left .getToken ());
331
- int endShard = keyRange . right . isMinimum () ? boundaries . shardCount () - 1 : boundaries . getShardForToken ( keyRange . right . getToken () );
341
+ int endShard = getEndShardForBounds ( keyRange );
332
342
333
343
if (orderer .isBM25 ())
334
344
{
335
345
// Intersect iterators to find documents containing all terms
336
346
List <ByteBuffer > queryTerms = orderer .getQueryTerms ();
337
- List <KeyRangeIterator > termIterators = keyIteratorsPerTerm (queryContext , keyRange , queryTerms );
347
+ Map <ByteBuffer , Long > documentFrequencies = new HashMap <>();
348
+ List <KeyRangeIterator > termIterators = new ArrayList <>(queryTerms .size ());
349
+ for (ByteBuffer term : queryTerms )
350
+ {
351
+ Expression expr = new Expression (indexContext ).add (Operator .ANALYZER_MATCHES , term );
352
+ // getMaxKeys() counts all rows that match the expression for shards within the key range. The key
353
+ // range is not applied to the search results yet, so there is a small chance for overcounting if
354
+ // the key range filters within a shard. This is assumed to be acceptable because the on disk
355
+ // estimate also uses the key range to skip irrelevant sstable segments but does not apply the key
356
+ // range when getting the estimate within a segment.
357
+ KeyRangeIterator iterator = eagerSearch (expr , keyRange );
358
+ documentFrequencies .put (term , iterator .getMaxKeys ());
359
+ termIterators .add (iterator );
360
+ }
338
361
KeyRangeIterator intersectedIterator = KeyRangeIntersectionIterator .builder (termIterators ).build ();
339
362
340
- return List .of (orderByBM25 (Streams .stream (intersectedIterator ), orderer ));
363
+ return List .of (orderByBM25 (Streams .stream (intersectedIterator ), documentFrequencies , orderer ));
341
364
}
342
365
else
343
366
{
@@ -351,36 +374,50 @@ public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(QueryContext query
351
374
}
352
375
}
353
376
354
- private List <KeyRangeIterator > keyIteratorsPerTerm (QueryContext queryContext , AbstractBounds <PartitionPosition > keyRange , List <ByteBuffer > queryTerms )
355
- {
356
- List <KeyRangeIterator > termIterators = new ArrayList <>(queryTerms .size ());
357
- for (ByteBuffer term : queryTerms )
358
- {
359
- Expression expr = new Expression (indexContext );
360
- expr .add (Operator .ANALYZER_MATCHES , term );
361
- KeyRangeIterator iterator = search (queryContext , expr , keyRange , Integer .MAX_VALUE );
362
- termIterators .add (iterator );
363
- }
364
- return termIterators ;
365
- }
366
-
367
377
@ Override
368
378
public long estimateMatchingRowsCount (Expression expression , AbstractBounds <PartitionPosition > keyRange )
369
379
{
370
380
int startShard = boundaries .getShardForToken (keyRange .left .getToken ());
371
- int endShard = keyRange . right . isMinimum () ? boundaries . shardCount () - 1 : boundaries . getShardForToken ( keyRange . right . getToken () );
381
+ int endShard = getEndShardForBounds ( keyRange );
372
382
return rangeIndexes [startShard ].estimateMatchingRowsCount (expression , keyRange ) * (endShard - startShard + 1 );
373
383
}
374
384
385
+ // In the BM25 logic, estimateMatchingRowsCount is not accurate enough because we use the result to compute the
386
+ // document score.
387
+ private long completeEstimateMatchingRowsCount (Expression expression , AbstractBounds <PartitionPosition > keyRange )
388
+ {
389
+ int startShard = boundaries .getShardForToken (keyRange .left .getToken ());
390
+ int endShard = getEndShardForBounds (keyRange );
391
+ long count = 0 ;
392
+ for (int shard = startShard ; shard <= endShard ; ++shard )
393
+ {
394
+ assert rangeIndexes [shard ] != null ;
395
+ count += rangeIndexes [shard ].estimateMatchingRowsCount (expression , keyRange );
396
+ }
397
+ return count ;
398
+ }
399
+
375
400
@ Override
376
401
public CloseableIterator <PrimaryKeyWithSortKey > orderResultsBy (QueryContext queryContext , List <PrimaryKey > keys , Orderer orderer , int limit )
377
402
{
378
403
if (keys .isEmpty ())
379
404
return CloseableIterator .emptyIterator ();
380
405
381
406
if (orderer .isBM25 ())
382
- return orderByBM25 (keys .stream (), orderer );
407
+ {
408
+ HashMap <ByteBuffer , Long > documentFrequencies = new HashMap <>();
409
+ // We only need to get the document frequencies for the shards that contain the keys.
410
+ Range <PartitionPosition > range = Range .makeRowRange (keys .get (0 ).partitionKey ().getToken (),
411
+ keys .get (keys .size () - 1 ).partitionKey ().getToken ());
412
+ for (ByteBuffer term : orderer .getQueryTerms ())
413
+ {
414
+ Expression expression = new Expression (indexContext ).add (Operator .ANALYZER_MATCHES , term );
415
+ documentFrequencies .put (term , completeEstimateMatchingRowsCount (expression , range ));
416
+ }
417
+ return orderByBM25 (keys .stream (), documentFrequencies , orderer );
418
+ }
383
419
else
420
+ {
384
421
return SortingIterator .createCloseable (
385
422
orderer .getComparator (),
386
423
keys ,
@@ -403,14 +440,15 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(QueryContext quer
403
440
},
404
441
Runnables .doNothing ()
405
442
);
443
+ }
406
444
}
407
445
408
- private CloseableIterator <PrimaryKeyWithSortKey > orderByBM25 (Stream <PrimaryKey > stream , Orderer orderer )
446
+ private CloseableIterator <PrimaryKeyWithSortKey > orderByBM25 (Stream <PrimaryKey > stream , Map < ByteBuffer , Long > documentFrequencies , Orderer orderer )
409
447
{
410
448
assert orderer .isBM25 ();
411
449
List <ByteBuffer > queryTerms = orderer .getQueryTerms ();
412
450
AbstractAnalyzer analyzer = indexContext .getAnalyzerFactory ().create ();
413
- BM25Utils .DocStats docStats = computeDocumentFrequencies ( queryTerms , analyzer );
451
+ BM25Utils .DocStats docStats = new BM25Utils . DocStats ( documentFrequencies , indexedRows (), approximateTotalTermCount () );
414
452
Iterator <BM25Utils .DocTF > it = stream
415
453
.map (pk -> BM25Utils .EagerDocTF .createFromDocument (pk , getCellForKey (pk ), analyzer , queryTerms ))
416
454
.filter (Objects ::nonNull )
@@ -422,54 +460,6 @@ private CloseableIterator<PrimaryKeyWithSortKey> orderByBM25(Stream<PrimaryKey>
422
460
memtable );
423
461
}
424
462
425
- /**
426
- * Count document frequencies for each term using brute force
427
- */
428
- private BM25Utils .DocStats computeDocumentFrequencies (List <ByteBuffer > queryTerms , AbstractAnalyzer docAnalyzer )
429
- {
430
- var documentFrequencies = new HashMap <ByteBuffer , Long >();
431
-
432
- // count all documents in the queried column
433
- try (var it = memtable .makePartitionIterator (ColumnFilter .selection (RegularAndStaticColumns .of (indexContext .getDefinition ())),
434
- DataRange .allData (memtable .metadata ().partitioner )))
435
- {
436
- while (it .hasNext ())
437
- {
438
- var partitions = it .next ();
439
- while (partitions .hasNext ())
440
- {
441
- var unfiltered = partitions .next ();
442
- if (!unfiltered .isRow ())
443
- continue ;
444
- var row = (Row ) unfiltered ;
445
- var cell = row .getCell (indexContext .getDefinition ());
446
- if (cell == null )
447
- continue ;
448
-
449
- Set <ByteBuffer > queryTermsPerDoc = new HashSet <>(queryTerms .size ());
450
- docAnalyzer .reset (cell .buffer ());
451
- try
452
- {
453
- while (docAnalyzer .hasNext ())
454
- {
455
- ByteBuffer term = docAnalyzer .next ();
456
- if (queryTerms .contains (term ))
457
- queryTermsPerDoc .add (term );
458
- }
459
- }
460
- finally
461
- {
462
- docAnalyzer .end ();
463
- }
464
- for (ByteBuffer term : queryTermsPerDoc )
465
- documentFrequencies .merge (term , 1L , Long ::sum );
466
-
467
- }
468
- }
469
- }
470
- return new BM25Utils .DocStats (documentFrequencies , indexedRows (), approximateTotalTermCount ());
471
- }
472
-
473
463
@ Nullable
474
464
private org .apache .cassandra .db .rows .Cell <?> getCellForKey (PrimaryKey key )
475
465
{
@@ -487,6 +477,13 @@ private ByteComparable encode(ByteBuffer input)
487
477
return Version .current ().onDiskFormat ().encodeForTrie (input , indexContext .getValidator ());
488
478
}
489
479
480
+ private int getEndShardForBounds (AbstractBounds <PartitionPosition > bounds )
481
+ {
482
+ PartitionPosition position = bounds .right ;
483
+ return position .isMinimum () ? boundaries .shardCount () - 1
484
+ : boundaries .getShardForToken (position .getToken ());
485
+ }
486
+
490
487
/**
491
488
* NOTE: returned data may contain partition key not within the provided min and max which are only used to find
492
489
* corresponding subranges. We don't do filtering here to avoid unnecessary token comparison. In case of JBOD,
0 commit comments