Skip to content

Commit 26437f1

Browse files
michaeljmarshalldjatnieks
authored andcommitted
CNDB-13689: use NodeQueue::pushMany to decrease time complexity to build heap (#1693)
### What is the issue Fixes: riptano/cndb#13689 ### What does this PR fix and why was it fixed This PR utilizes the NodeQueue::pushMany method to decrease the time complexity required to build the NodeQueue from `O(n log(n))` to `O(n)`. This is likely only significant for sufficiently large hybrid queries. For example, we have seen cases of the search producing 400k rows, which means that we do 400k insertions into these NodeQueue objects.
1 parent 6e759fa commit 26437f1

File tree

5 files changed

+175
-64
lines changed

5 files changed

+175
-64
lines changed

src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,8 @@ private CloseableIterator<RowIdWithScore> orderByBruteForce(CompressedVectors cv
299299
// Store the index of the (rowId, ordinal) pair from the segmentOrdinalPairs in the NodeQueue so that we can
300300
// retrieve both values with O(1) lookup when we need to resolve the full resolution score in the
301301
// BruteForceRowIdIterator.
302-
segmentOrdinalPairs.forEachIndexOrdinalPair((i, ordinal) -> {
303-
approximateScores.push(i, scoreFunction.similarityTo(ordinal));
304-
});
302+
var iter = segmentOrdinalPairs.mapToIndexScoreIterator(scoreFunction);
303+
approximateScores.pushMany(iter, segmentOrdinalPairs.size());
305304
columnQueryMetrics.onBruteForceNodesVisited(segmentOrdinalPairs.size());
306305
var reranker = new CloseableReranker(similarityFunction, queryVector, graph.getView());
307306
return new BruteForceRowIdIterator(approximateScores, segmentOrdinalPairs, reranker, limit, rerankK, columnQueryMetrics);
@@ -320,9 +319,8 @@ private CloseableIterator<RowIdWithScore> orderByBruteForce(VectorFloat<?> query
320319
var similarityFunction = indexContext.getIndexWriterConfig().getSimilarityFunction();
321320
var esf = vectorsView.rerankerFor(queryVector, similarityFunction);
322321
// Because the scores are exact, we only store the rowid, score pair.
323-
segmentOrdinalPairs.forEachSegmentRowIdOrdinalPair((segmentRowId, ordinal) -> {
324-
scoredRowIds.push(segmentRowId, esf.similarityTo(ordinal));
325-
});
322+
var iter = segmentOrdinalPairs.mapToSegmentRowIdScoreIterator(esf);
323+
scoredRowIds.pushMany(iter, segmentOrdinalPairs.size());
326324
columnQueryMetrics.onBruteForceNodesReranked(segmentOrdinalPairs.size());
327325
return new NodeQueueRowIdIterator(scoredRowIds);
328326
}

src/java/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIterator.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,10 @@
2626
import org.apache.cassandra.index.sai.utils.RowIdWithScore;
2727
import org.apache.cassandra.io.util.FileUtils;
2828
import org.apache.cassandra.utils.AbstractIterator;
29-
import org.apache.cassandra.utils.SortingIterator;
3029

3130

3231
/**
33-
* An iterator over {@link RowIdWithMeta} that lazily consumes from a {@link SortingIterator} of
34-
* {@link RowWithApproximateScore}.
32+
* An iterator over {@link RowIdWithMeta} that lazily consumes from a {@link NodeQueue} of approximate scores.
3533
* <p>
3634
* The idea is that we maintain the same level of accuracy as we would get from a graph search, by re-ranking the top
3735
* `k` best approximate scores at a time with the full resolution vectors to return the top `limit`.

src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -170,32 +170,55 @@ public static CloseableIterator<PrimaryKeyWithSortKey> computeScores(CloseableIt
170170
// Calculate average document length
171171
double avgDocLength = totalTermCount / documents.size();
172172

173-
// Calculate BM25 scores. Uses a nodequeue that avoids additional allocations and has heap time complexity
173+
// Calculate BM25 scores.
174+
// Uses a NodeQueue that avoids allocating an object for each document.
174175
var nodeQueue = new NodeQueue(new BoundedLongHeap(documents.size()), NodeQueue.Order.MAX_HEAP);
175-
for (int i = 0; i < documents.size(); i++)
176-
{
177-
var doc = documents.get(i);
178-
double score = 0.0;
179-
for (var queryTerm : queryTerms)
180-
{
181-
int tf = doc.getTermFrequency(queryTerm);
182-
Long df = docStats.frequencies.get(queryTerm);
183-
// we shouldn't have more hits for a term than we counted total documents
184-
assert df <= docStats.docCount : String.format("df=%d, totalDocs=%d", df, docStats.docCount);
185-
186-
double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount() / avgDocLength));
187-
double idf = Math.log(1 + (docStats.docCount - df + 0.5) / (df + 0.5));
188-
double deltaScore = normalizedTf * idf;
189-
assert deltaScore >= 0 : String.format("BM25 score for tf=%d, df=%d, tc=%d, totalDocs=%d is %f",
190-
tf, df, doc.termCount(), docStats.docCount, deltaScore);
191-
score += deltaScore;
176+
// Create an anonymous NodeScoreIterator that holds the logic for computing BM25
177+
var iter = new NodeQueue.NodeScoreIterator() {
178+
int current = 0;
179+
180+
@Override
181+
public boolean hasNext() {
182+
return current < documents.size();
192183
}
193-
nodeQueue.push(i, (float) score);
194-
}
184+
185+
@Override
186+
public int pop() {
187+
return current++;
188+
}
189+
190+
@Override
191+
public float topScore() {
192+
// Compute BM25 for the current document
193+
return scoreDoc(documents.get(current), docStats, queryTerms, avgDocLength);
194+
}
195+
};
196+
// pushMany is an O(n) operation where n is the final size of the queue. Iterative calls to push is O(n log n).
197+
nodeQueue.pushMany(iter, documents.size());
195198

196199
return new NodeQueueDocTFIterator(nodeQueue, documents, indexContext, source, docIterator);
197200
}
198201

202+
private static float scoreDoc(DocTF doc, DocStats docStats, List<ByteBuffer> queryTerms, double avgDocLength)
203+
{
204+
double score = 0.0;
205+
for (var queryTerm : queryTerms)
206+
{
207+
int tf = doc.getTermFrequency(queryTerm);
208+
Long df = docStats.frequencies.get(queryTerm);
209+
// we shouldn't have more hits for a term than we counted total documents
210+
assert df <= docStats.docCount : String.format("df=%d, totalDocs=%d", df, docStats.docCount);
211+
212+
double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount() / avgDocLength));
213+
double idf = Math.log(1 + (docStats.docCount - df + 0.5) / (df + 0.5));
214+
double deltaScore = normalizedTf * idf;
215+
assert deltaScore >= 0 : String.format("BM25 score for tf=%d, df=%d, tc=%d, totalDocs=%d is %f",
216+
tf, df, doc.termCount(), docStats.docCount, deltaScore);
217+
score += deltaScore;
218+
}
219+
return (float) score;
220+
}
221+
199222
private static class NodeQueueDocTFIterator extends AbstractIterator<PrimaryKeyWithSortKey>
200223
{
201224
private final NodeQueue nodeQueue;

src/java/org/apache/cassandra/index/sai/utils/SegmentRowIdOrdinalPairs.java

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
import java.util.function.IntConsumer;
2222

23+
import io.github.jbellis.jvector.graph.NodeQueue;
24+
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
2325
import org.agrona.collections.IntIntConsumer;
2426

2527
/**
@@ -33,7 +35,7 @@ public class SegmentRowIdOrdinalPairs
3335
private final int[] array;
3436

3537
/**
36-
* Create a new IntIntPairArray with the given capacity.
38+
* Create a new SegmentRowIdOrdinalPairs with the given capacity.
3739
* @param capacity the capacity
3840
*/
3941
public SegmentRowIdOrdinalPairs(int capacity)
@@ -102,15 +104,53 @@ public void forEachSegmentRowIdOrdinalPair(IntIntConsumer consumer)
102104
}
103105

104106
/**
105-
* Iterate over the pairs in the array, calling the consumer for each pair passing (index, x, y).
106-
* @param consumer the consumer to call for each pair
107+
* Create an iterator over the segment row id and scored ordinal pairs in the array.
108+
* @param scoreFunction the score function to use to compute the next score based on the ordinal
107109
*/
108-
public void forEachIndexOrdinalPair(IntIntConsumer consumer)
110+
public NodeQueue.NodeScoreIterator mapToSegmentRowIdScoreIterator(ScoreFunction scoreFunction)
109111
{
110-
for (int i = 0; i < size; i++)
111-
consumer.accept(i, array[i * 2 + 1]);
112+
return mapToScoreIterator(scoreFunction, false);
113+
}
114+
115+
/**
116+
* Create an iterator over the index and scored ordinal pairs in the array.
117+
* @param scoreFunction the score function to use to compute the next score based on the ordinal
118+
*/
119+
public NodeQueue.NodeScoreIterator mapToIndexScoreIterator(ScoreFunction scoreFunction)
120+
{
121+
return mapToScoreIterator(scoreFunction, true);
112122
}
113123

124+
/**
125+
* Create an iterator over the index or the segment row id and the score for the ordinal.
126+
* @param scoreFunction the score function to use to compute the next score based on the ordinal
127+
* @param mapToIndex whether to map to the index or the segment row id
128+
*/
129+
private NodeQueue.NodeScoreIterator mapToScoreIterator(ScoreFunction scoreFunction, boolean mapToIndex)
130+
{
131+
return new NodeQueue.NodeScoreIterator()
132+
{
133+
int i = 0;
134+
135+
@Override
136+
public boolean hasNext()
137+
{
138+
return i < size;
139+
}
140+
141+
@Override
142+
public int pop()
143+
{
144+
return mapToIndex ? i++ : array[i++ * 2];
145+
}
146+
147+
@Override
148+
public float topScore()
149+
{
150+
return scoreFunction.similarityTo(array[i * 2 + 1]);
151+
}
152+
};
153+
}
114154

115155
/**
116156
* Calls the consumer for each right value in each pair of the array.

test/unit/org/apache/cassandra/index/sai/utils/SegmentRowIdOrdinalPairsTest.java

Lines changed: 81 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,18 @@
1616

1717
package org.apache.cassandra.index.sai.utils;
1818

19+
import io.github.jbellis.jvector.graph.NodeQueue;
20+
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
1921
import java.util.ArrayList;
2022
import java.util.List;
2123
import java.util.concurrent.atomic.AtomicInteger;
2224

2325
import org.junit.Test;
2426

2527
import static org.junit.Assert.assertEquals;
28+
import static org.junit.Assert.assertFalse;
2629
import static org.junit.Assert.assertThrows;
30+
import static org.junit.Assert.assertTrue;
2731

2832
public class SegmentRowIdOrdinalPairsTest
2933
{
@@ -95,32 +99,6 @@ public void testForEachSegmentRowIdOrdinalPair()
9599
assertEquals(Integer.valueOf(30), ordinals.get(2));
96100
}
97101

98-
@Test
99-
public void testForEachIndexOrdinalPair()
100-
{
101-
SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(3);
102-
pairs.add(1, 10);
103-
pairs.add(2, 20);
104-
pairs.add(3, 30);
105-
106-
List<Integer> indices = new ArrayList<>();
107-
List<Integer> ordinals = new ArrayList<>();
108-
109-
pairs.forEachIndexOrdinalPair((index, ordinal) -> {
110-
indices.add(index);
111-
ordinals.add(ordinal);
112-
});
113-
114-
assertEquals(3, indices.size());
115-
assertEquals(3, ordinals.size());
116-
assertEquals(Integer.valueOf(0), indices.get(0));
117-
assertEquals(Integer.valueOf(10), ordinals.get(0));
118-
assertEquals(Integer.valueOf(1), indices.get(1));
119-
assertEquals(Integer.valueOf(20), ordinals.get(1));
120-
assertEquals(Integer.valueOf(2), indices.get(2));
121-
assertEquals(Integer.valueOf(30), ordinals.get(2));
122-
}
123-
124102
@Test
125103
public void testGetSegmentRowIdAndOrdinalBoundaryChecks()
126104
{
@@ -158,9 +136,6 @@ public void testOperationsOnEmptyArray()
158136

159137
pairs.forEachSegmentRowIdOrdinalPair((x, y) -> count.incrementAndGet());
160138
assertEquals(0, count.get());
161-
162-
pairs.forEachIndexOrdinalPair((x, y) -> count.incrementAndGet());
163-
assertEquals(0, count.get());
164139
}
165140

166141
@Test
@@ -170,4 +145,81 @@ public void testZeroCapacity()
170145
assertEquals(0, pairs.size());
171146
assertThrows(IndexOutOfBoundsException.class, () -> pairs.add(1, 10));
172147
}
148+
149+
@Test
150+
public void testMapToSegmentRowIdScoreIterator()
151+
{
152+
SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(3);
153+
pairs.add(1, 10);
154+
pairs.add(2, 20);
155+
pairs.add(3, 30);
156+
157+
// Create a simple score function that returns the ordinal value divided by 10 as the score
158+
ScoreFunction.ExactScoreFunction scoreFunction = ordinal -> ordinal / 10.0f;
159+
160+
NodeQueue.NodeScoreIterator iterator = pairs.mapToSegmentRowIdScoreIterator(scoreFunction);
161+
162+
// Test first pair
163+
assertTrue(iterator.hasNext());
164+
assertEquals(1.0f, iterator.topScore(), 0.001f);
165+
assertEquals(1, iterator.pop());
166+
167+
// Test second pair
168+
assertTrue(iterator.hasNext());
169+
assertEquals(2.0f, iterator.topScore(), 0.001f);
170+
assertEquals(2, iterator.pop());
171+
172+
// Test third pair
173+
assertTrue(iterator.hasNext());
174+
assertEquals(3.0f, iterator.topScore(), 0.001f);
175+
assertEquals(3, iterator.pop());
176+
177+
// Test end of iteration
178+
assertFalse(iterator.hasNext());
179+
}
180+
181+
@Test
182+
public void testMapToIndexScoreIterator()
183+
{
184+
SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(3);
185+
pairs.add(1, 10);
186+
pairs.add(2, 20);
187+
pairs.add(3, 30);
188+
189+
// Create a simple score function that returns the ordinal value divided by 10 as the score
190+
ScoreFunction.ExactScoreFunction scoreFunction = ordinal -> ordinal / 10.0f;
191+
192+
NodeQueue.NodeScoreIterator iterator = pairs.mapToIndexScoreIterator(scoreFunction);
193+
194+
// Test first pair
195+
assertTrue(iterator.hasNext());
196+
assertEquals(1.0f, iterator.topScore(), 0.001f);
197+
assertEquals(0, iterator.pop());
198+
199+
// Test second pair
200+
assertTrue(iterator.hasNext());
201+
assertEquals(2.0f, iterator.topScore(), 0.001f);
202+
assertEquals(1, iterator.pop());
203+
204+
// Test third pair
205+
assertTrue(iterator.hasNext());
206+
assertEquals(3.0f, iterator.topScore(), 0.001f);
207+
assertEquals(2, iterator.pop());
208+
209+
// Test end of iteration
210+
assertFalse(iterator.hasNext());
211+
}
212+
213+
@Test
214+
public void testEmptyIterators()
215+
{
216+
SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(0);
217+
ScoreFunction.ExactScoreFunction scoreFunction = ordinal -> ordinal / 10.0f;
218+
219+
NodeQueue.NodeScoreIterator segmentRowIdIterator = pairs.mapToSegmentRowIdScoreIterator(scoreFunction);
220+
assertFalse(segmentRowIdIterator.hasNext());
221+
222+
NodeQueue.NodeScoreIterator indexIterator = pairs.mapToIndexScoreIterator(scoreFunction);
223+
assertFalse(indexIterator.hasNext());
224+
}
173225
}

0 commit comments

Comments
 (0)