diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java index b53ff7460eaf..22217c38f2c0 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java @@ -299,9 +299,8 @@ private CloseableIterator orderByBruteForce(CompressedVectors cv // Store the index of the (rowId, ordinal) pair from the segmentOrdinalPairs in the NodeQueue so that we can // retrieve both values with O(1) lookup when we need to resolve the full resolution score in the // BruteForceRowIdIterator. - segmentOrdinalPairs.forEachIndexOrdinalPair((i, ordinal) -> { - approximateScores.push(i, scoreFunction.similarityTo(ordinal)); - }); + var iter = segmentOrdinalPairs.mapToIndexScoreIterator(scoreFunction); + approximateScores.pushMany(iter, segmentOrdinalPairs.size()); columnQueryMetrics.onBruteForceNodesVisited(segmentOrdinalPairs.size()); var reranker = new CloseableReranker(similarityFunction, queryVector, graph.getView()); return new BruteForceRowIdIterator(approximateScores, segmentOrdinalPairs, reranker, limit, rerankK, columnQueryMetrics); @@ -320,9 +319,8 @@ private CloseableIterator orderByBruteForce(VectorFloat query var similarityFunction = indexContext.getIndexWriterConfig().getSimilarityFunction(); var esf = vectorsView.rerankerFor(queryVector, similarityFunction); // Because the scores are exact, we only store the rowid, score pair. 
- segmentOrdinalPairs.forEachSegmentRowIdOrdinalPair((segmentRowId, ordinal) -> { - scoredRowIds.push(segmentRowId, esf.similarityTo(ordinal)); - }); + var iter = segmentOrdinalPairs.mapToSegmentRowIdScoreIterator(esf); + scoredRowIds.pushMany(iter, segmentOrdinalPairs.size()); columnQueryMetrics.onBruteForceNodesReranked(segmentOrdinalPairs.size()); return new NodeQueueRowIdIterator(scoredRowIds); } diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIterator.java b/src/java/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIterator.java index e2b2aa22b93f..392341fb1298 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIterator.java +++ b/src/java/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIterator.java @@ -26,12 +26,10 @@ import org.apache.cassandra.index.sai.utils.RowIdWithScore; import org.apache.cassandra.io.util.FileUtils; import org.apache.cassandra.utils.AbstractIterator; -import org.apache.cassandra.utils.SortingIterator; /** - * An iterator over {@link RowIdWithMeta} that lazily consumes from a {@link SortingIterator} of - * {@link RowWithApproximateScore}. + * An iterator over {@link RowIdWithMeta} that lazily consumes from a {@link NodeQueue} of approximate scores. *

* The idea is that we maintain the same level of accuracy as we would get from a graph search, by re-ranking the top * `k` best approximate scores at a time with the full resolution vectors to return the top `limit`. diff --git a/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java index ba8bfbcbd0ad..29039affe662 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java +++ b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java @@ -170,32 +170,55 @@ public static CloseableIterator computeScores(CloseableIt // Calculate average document length double avgDocLength = totalTermCount / documents.size(); - // Calculate BM25 scores. Uses a nodequeue that avoids additional allocations and has heap time complexity + // Calculate BM25 scores. + // Uses a NodeQueue that avoids allocating an object for each document. var nodeQueue = new NodeQueue(new BoundedLongHeap(documents.size()), NodeQueue.Order.MAX_HEAP); - for (int i = 0; i < documents.size(); i++) - { - var doc = documents.get(i); - double score = 0.0; - for (var queryTerm : queryTerms) - { - int tf = doc.getTermFrequency(queryTerm); - Long df = docStats.frequencies.get(queryTerm); - // we shouldn't have more hits for a term than we counted total documents - assert df <= docStats.docCount : String.format("df=%d, totalDocs=%d", df, docStats.docCount); - - double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount() / avgDocLength)); - double idf = Math.log(1 + (docStats.docCount - df + 0.5) / (df + 0.5)); - double deltaScore = normalizedTf * idf; - assert deltaScore >= 0 : String.format("BM25 score for tf=%d, df=%d, tc=%d, totalDocs=%d is %f", - tf, df, doc.termCount(), docStats.docCount, deltaScore); - score += deltaScore; + // Create an anonymous NodeScoreIterator that holds the logic for computing BM25 + var iter = new NodeQueue.NodeScoreIterator() { + int current = 0; + + @Override + public boolean hasNext() { + 
return current < documents.size(); } - nodeQueue.push(i, (float) score); - } + + @Override + public int pop() { + return current++; + } + + @Override + public float topScore() { + // Compute BM25 for the current document + return scoreDoc(documents.get(current), docStats, queryTerms, avgDocLength); + } + }; + // pushMany is an O(n) operation where n is the final size of the queue. Iterative calls to push are O(n log n). + nodeQueue.pushMany(iter, documents.size()); return new NodeQueueDocTFIterator(nodeQueue, documents, indexContext, source, docIterator); } + private static float scoreDoc(DocTF doc, DocStats docStats, List queryTerms, double avgDocLength) + { + double score = 0.0; + for (var queryTerm : queryTerms) + { + int tf = doc.getTermFrequency(queryTerm); + Long df = docStats.frequencies.get(queryTerm); + // we shouldn't have more hits for a term than we counted total documents + assert df <= docStats.docCount : String.format("df=%d, totalDocs=%d", df, docStats.docCount); + + double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount() / avgDocLength)); + double idf = Math.log(1 + (docStats.docCount - df + 0.5) / (df + 0.5)); + double deltaScore = normalizedTf * idf; + assert deltaScore >= 0 : String.format("BM25 score for tf=%d, df=%d, tc=%d, totalDocs=%d is %f", + tf, df, doc.termCount(), docStats.docCount, deltaScore); + score += deltaScore; + } + return (float) score; + } + private static class NodeQueueDocTFIterator extends AbstractIterator { private final NodeQueue nodeQueue; diff --git a/src/java/org/apache/cassandra/index/sai/utils/SegmentRowIdOrdinalPairs.java b/src/java/org/apache/cassandra/index/sai/utils/SegmentRowIdOrdinalPairs.java index 0f87fc279185..0c47de40e181 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/SegmentRowIdOrdinalPairs.java +++ b/src/java/org/apache/cassandra/index/sai/utils/SegmentRowIdOrdinalPairs.java @@ -20,6 +20,8 @@ import java.util.function.IntConsumer; +import io.github.jbellis.jvector.graph.NodeQueue;
+import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import org.agrona.collections.IntIntConsumer; /** @@ -33,7 +35,7 @@ public class SegmentRowIdOrdinalPairs private final int[] array; /** - * Create a new IntIntPairArray with the given capacity. + * Create a new SegmentRowIdOrdinalPairs with the given capacity. * @param capacity the capacity */ public SegmentRowIdOrdinalPairs(int capacity) @@ -102,15 +104,53 @@ public void forEachSegmentRowIdOrdinalPair(IntIntConsumer consumer) } /** - * Iterate over the pairs in the array, calling the consumer for each pair passing (index, x, y). - * @param consumer the consumer to call for each pair + * Create an iterator over the segment row id and scored ordinal pairs in the array. + * @param scoreFunction the score function to use to compute the next score based on the ordinal */ - public void forEachIndexOrdinalPair(IntIntConsumer consumer) + public NodeQueue.NodeScoreIterator mapToSegmentRowIdScoreIterator(ScoreFunction scoreFunction) { - for (int i = 0; i < size; i++) - consumer.accept(i, array[i * 2 + 1]); + return mapToScoreIterator(scoreFunction, false); + } + + /** + * Create an iterator over the index and scored ordinal pairs in the array. + * @param scoreFunction the score function to use to compute the next score based on the ordinal + */ + public NodeQueue.NodeScoreIterator mapToIndexScoreIterator(ScoreFunction scoreFunction) + { + return mapToScoreIterator(scoreFunction, true); } + /** + * Create an iterator over the index or the segment row id and the score for the ordinal. 
+ * @param scoreFunction the score function to use to compute the next score based on the ordinal + * @param mapToIndex whether to map to the index or the segment row id + */ + private NodeQueue.NodeScoreIterator mapToScoreIterator(ScoreFunction scoreFunction, boolean mapToIndex) + { + return new NodeQueue.NodeScoreIterator() + { + int i = 0; + + @Override + public boolean hasNext() + { + return i < size; + } + + @Override + public int pop() + { + return mapToIndex ? i++ : array[i++ * 2]; + } + + @Override + public float topScore() + { + return scoreFunction.similarityTo(array[i * 2 + 1]); + } + }; + } /** * Calls the consumer for each right value in each pair of the array. diff --git a/test/unit/org/apache/cassandra/index/sai/utils/SegmentRowIdOrdinalPairsTest.java b/test/unit/org/apache/cassandra/index/sai/utils/SegmentRowIdOrdinalPairsTest.java index 87daed0e609e..2f05cd4cb0a2 100644 --- a/test/unit/org/apache/cassandra/index/sai/utils/SegmentRowIdOrdinalPairsTest.java +++ b/test/unit/org/apache/cassandra/index/sai/utils/SegmentRowIdOrdinalPairsTest.java @@ -16,6 +16,8 @@ package org.apache.cassandra.index.sai.utils; +import io.github.jbellis.jvector.graph.NodeQueue; +import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; @@ -23,7 +25,9 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; public class SegmentRowIdOrdinalPairsTest { @@ -95,32 +99,6 @@ public void testForEachSegmentRowIdOrdinalPair() assertEquals(Integer.valueOf(30), ordinals.get(2)); } - @Test - public void testForEachIndexOrdinalPair() - { - SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(3); - pairs.add(1, 10); - pairs.add(2, 20); - pairs.add(3, 30); - - List indices = new ArrayList<>(); - List ordinals = new 
ArrayList<>(); - - pairs.forEachIndexOrdinalPair((index, ordinal) -> { - indices.add(index); - ordinals.add(ordinal); - }); - - assertEquals(3, indices.size()); - assertEquals(3, ordinals.size()); - assertEquals(Integer.valueOf(0), indices.get(0)); - assertEquals(Integer.valueOf(10), ordinals.get(0)); - assertEquals(Integer.valueOf(1), indices.get(1)); - assertEquals(Integer.valueOf(20), ordinals.get(1)); - assertEquals(Integer.valueOf(2), indices.get(2)); - assertEquals(Integer.valueOf(30), ordinals.get(2)); - } - @Test public void testGetSegmentRowIdAndOrdinalBoundaryChecks() { @@ -158,9 +136,6 @@ public void testOperationsOnEmptyArray() pairs.forEachSegmentRowIdOrdinalPair((x, y) -> count.incrementAndGet()); assertEquals(0, count.get()); - - pairs.forEachIndexOrdinalPair((x, y) -> count.incrementAndGet()); - assertEquals(0, count.get()); } @Test @@ -170,4 +145,81 @@ public void testZeroCapacity() assertEquals(0, pairs.size()); assertThrows(IndexOutOfBoundsException.class, () -> pairs.add(1, 10)); } + + @Test + public void testMapToSegmentRowIdScoreIterator() + { + SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(3); + pairs.add(1, 10); + pairs.add(2, 20); + pairs.add(3, 30); + + // Create a simple score function that returns the ordinal value divided by 10 as the score + ScoreFunction.ExactScoreFunction scoreFunction = ordinal -> ordinal / 10.0f; + + NodeQueue.NodeScoreIterator iterator = pairs.mapToSegmentRowIdScoreIterator(scoreFunction); + + // Test first pair + assertTrue(iterator.hasNext()); + assertEquals(1.0f, iterator.topScore(), 0.001f); + assertEquals(1, iterator.pop()); + + // Test second pair + assertTrue(iterator.hasNext()); + assertEquals(2.0f, iterator.topScore(), 0.001f); + assertEquals(2, iterator.pop()); + + // Test third pair + assertTrue(iterator.hasNext()); + assertEquals(3.0f, iterator.topScore(), 0.001f); + assertEquals(3, iterator.pop()); + + // Test end of iteration + assertFalse(iterator.hasNext()); + } + + @Test + public 
void testMapToIndexScoreIterator() + { + SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(3); + pairs.add(1, 10); + pairs.add(2, 20); + pairs.add(3, 30); + + // Create a simple score function that returns the ordinal value divided by 10 as the score + ScoreFunction.ExactScoreFunction scoreFunction = ordinal -> ordinal / 10.0f; + + NodeQueue.NodeScoreIterator iterator = pairs.mapToIndexScoreIterator(scoreFunction); + + // Test first pair + assertTrue(iterator.hasNext()); + assertEquals(1.0f, iterator.topScore(), 0.001f); + assertEquals(0, iterator.pop()); + + // Test second pair + assertTrue(iterator.hasNext()); + assertEquals(2.0f, iterator.topScore(), 0.001f); + assertEquals(1, iterator.pop()); + + // Test third pair + assertTrue(iterator.hasNext()); + assertEquals(3.0f, iterator.topScore(), 0.001f); + assertEquals(2, iterator.pop()); + + // Test end of iteration + assertFalse(iterator.hasNext()); + } + + @Test + public void testEmptyIterators() + { + SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(0); + ScoreFunction.ExactScoreFunction scoreFunction = ordinal -> ordinal / 10.0f; + + NodeQueue.NodeScoreIterator segmentRowIdIterator = pairs.mapToSegmentRowIdScoreIterator(scoreFunction); + assertFalse(segmentRowIdIterator.hasNext()); + + NodeQueue.NodeScoreIterator indexIterator = pairs.mapToIndexScoreIterator(scoreFunction); + assertFalse(indexIterator.hasNext()); + } }