@@ -299,9 +299,8 @@ private CloseableIterator<RowIdWithScore> orderByBruteForce(CompressedVectors cv
// Store the index of the (rowId, ordinal) pair from the segmentOrdinalPairs in the NodeQueue so that we can
// retrieve both values with O(1) lookup when we need to resolve the full resolution score in the
// BruteForceRowIdIterator.
segmentOrdinalPairs.forEachIndexOrdinalPair((i, ordinal) -> {
approximateScores.push(i, scoreFunction.similarityTo(ordinal));
});
var iter = segmentOrdinalPairs.mapToIndexScoreIterator(scoreFunction);
approximateScores.pushMany(iter, segmentOrdinalPairs.size());
columnQueryMetrics.onBruteForceNodesVisited(segmentOrdinalPairs.size());
var reranker = new CloseableReranker(similarityFunction, queryVector, graph.getView());
return new BruteForceRowIdIterator(approximateScores, segmentOrdinalPairs, reranker, limit, rerankK, columnQueryMetrics);
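For reference, a minimal standalone sketch of the bulk-push pattern this hunk switches to, using the NodeQueue.NodeScoreIterator contract (hasNext/topScore/pop) exactly as it appears elsewhere in this PR. The nodes/scores arrays are made-up illustration data, and the BoundedLongHeap import path is an assumption, not taken from the diff:

import io.github.jbellis.jvector.graph.NodeQueue;
import io.github.jbellis.jvector.util.BoundedLongHeap; // import path assumed

// Illustration only: builds a max-heap NodeQueue from precomputed (node, score)
// pairs with a single pushMany call instead of one push per element.
class PushManySketch
{
    static NodeQueue buildQueue(int[] nodes, float[] scores)
    {
        var iter = new NodeQueue.NodeScoreIterator()
        {
            int i = 0;

            @Override
            public boolean hasNext()
            {
                return i < nodes.length;
            }

            @Override
            public float topScore()
            {
                return scores[i]; // score of the current element, queried before pop()
            }

            @Override
            public int pop()
            {
                return nodes[i++]; // current element id, then advance
            }
        };
        var queue = new NodeQueue(new BoundedLongHeap(nodes.length), NodeQueue.Order.MAX_HEAP);
        queue.pushMany(iter, nodes.length); // bulk load; cheaper than n individual push calls
        return queue;
    }
}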
@@ -320,9 +319,8 @@ private CloseableIterator<RowIdWithScore> orderByBruteForce(VectorFloat<?> query
var similarityFunction = indexContext.getIndexWriterConfig().getSimilarityFunction();
var esf = vectorsView.rerankerFor(queryVector, similarityFunction);
// Because the scores are exact, we only store the rowid, score pair.
segmentOrdinalPairs.forEachSegmentRowIdOrdinalPair((segmentRowId, ordinal) -> {
scoredRowIds.push(segmentRowId, esf.similarityTo(ordinal));
});
var iter = segmentOrdinalPairs.mapToSegmentRowIdScoreIterator(esf);
scoredRowIds.pushMany(iter, segmentOrdinalPairs.size());
columnQueryMetrics.onBruteForceNodesReranked(segmentOrdinalPairs.size());
return new NodeQueueRowIdIterator(scoredRowIds);
}
@@ -26,12 +26,10 @@
import org.apache.cassandra.index.sai.utils.RowIdWithScore;
import org.apache.cassandra.io.util.FileUtils;
import org.apache.cassandra.utils.AbstractIterator;
import org.apache.cassandra.utils.SortingIterator;


/**
* An iterator over {@link RowIdWithMeta} that lazily consumes from a {@link SortingIterator} of
* {@link RowWithApproximateScore}.
* An iterator over {@link RowIdWithMeta} that lazily consumes from a {@link NodeQueue} of approximate scores.
* <p>
* The idea is that we maintain the same level of accuracy as we would get from a graph search, by re-ranking the top
* `k` best approximate scores at a time with the full resolution vectors to return the top `limit`.
63 changes: 43 additions & 20 deletions src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java
@@ -170,32 +170,55 @@ public static CloseableIterator<PrimaryKeyWithSortKey> computeScores(CloseableIt
// Calculate average document length
double avgDocLength = totalTermCount / documents.size();

// Calculate BM25 scores. Uses a nodequeue that avoids additional allocations and has heap time complexity
// Calculate BM25 scores.
// Uses a NodeQueue that avoids allocating an object for each document.
var nodeQueue = new NodeQueue(new BoundedLongHeap(documents.size()), NodeQueue.Order.MAX_HEAP);
for (int i = 0; i < documents.size(); i++)
{
var doc = documents.get(i);
double score = 0.0;
for (var queryTerm : queryTerms)
{
int tf = doc.getTermFrequency(queryTerm);
Long df = docStats.frequencies.get(queryTerm);
// we shouldn't have more hits for a term than we counted total documents
assert df <= docStats.docCount : String.format("df=%d, totalDocs=%d", df, docStats.docCount);

double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount() / avgDocLength));
double idf = Math.log(1 + (docStats.docCount - df + 0.5) / (df + 0.5));
double deltaScore = normalizedTf * idf;
assert deltaScore >= 0 : String.format("BM25 score for tf=%d, df=%d, tc=%d, totalDocs=%d is %f",
tf, df, doc.termCount(), docStats.docCount, deltaScore);
score += deltaScore;
// Create an anonymous NodeScoreIterator that computes each document's BM25 score lazily as the queue consumes it
var iter = new NodeQueue.NodeScoreIterator() {
int current = 0;

@Override
public boolean hasNext() {
return current < documents.size();
}
nodeQueue.push(i, (float) score);
}

@Override
public int pop() {
return current++;
}

@Override
public float topScore() {
// Compute BM25 for the current document
return scoreDoc(documents.get(current), docStats, queryTerms, avgDocLength);
}
};
// pushMany is an O(n) operation, where n is the final size of the queue; repeated calls to push would be O(n log n).
nodeQueue.pushMany(iter, documents.size());

return new NodeQueueDocTFIterator(nodeQueue, documents, indexContext, source, docIterator);
}
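To unpack the complexity comment above: pushing n elements one at a time costs O(log n) per push, O(n log n) in total, while a bulk load can build the heap bottom-up in O(n). The sketch below is the textbook bottom-up construction for a plain float max-heap; it only illustrates why the bulk operation is linear and does not reflect jvector's NodeQueue internals.

// Generic illustration of Floyd's bottom-up heap construction (O(n)).
// Not jvector code: it just shows why building the heap in one pass is
// cheaper than n successive pushes.
class HeapifySketch
{
    static void heapify(float[] heap)
    {
        for (int i = heap.length / 2 - 1; i >= 0; i--)
            siftDown(heap, i);
    }

    static void siftDown(float[] heap, int i)
    {
        int n = heap.length;
        while (true)
        {
            int left = 2 * i + 1;
            int right = left + 1;
            int largest = i;
            if (left < n && heap[left] > heap[largest])
                largest = left;
            if (right < n && heap[right] > heap[largest])
                largest = right;
            if (largest == i)
                return; // heap property restored for this subtree
            float tmp = heap[i];
            heap[i] = heap[largest];
            heap[largest] = tmp;
            i = largest;
        }
    }
}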

private static float scoreDoc(DocTF doc, DocStats docStats, List<ByteBuffer> queryTerms, double avgDocLength)
{
double score = 0.0;
for (var queryTerm : queryTerms)
{
int tf = doc.getTermFrequency(queryTerm);
Long df = docStats.frequencies.get(queryTerm);
// we shouldn't have more hits for a term than we counted total documents
assert df <= docStats.docCount : String.format("df=%d, totalDocs=%d", df, docStats.docCount);

double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount() / avgDocLength));
double idf = Math.log(1 + (docStats.docCount - df + 0.5) / (df + 0.5));
double deltaScore = normalizedTf * idf;
assert deltaScore >= 0 : String.format("BM25 score for tf=%d, df=%d, tc=%d, totalDocs=%d is %f",
tf, df, doc.termCount(), docStats.docCount, deltaScore);
score += deltaScore;
}
return (float) score;
}
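As a worked example of the expressions in scoreDoc above: the concrete inputs below are invented, and K1 = 1.2, B = 0.75 are the conventional BM25 defaults, assumed here because the constants' values are not shown in this diff.

// Worked example of a single term's BM25 contribution, using the same
// expressions as scoreDoc. All inputs are made up for illustration.
class Bm25Example
{
    public static void main(String[] args)
    {
        double K1 = 1.2, B = 0.75;   // assumed defaults; not taken from the diff
        int tf = 3;                  // term frequency in this document
        long df = 10;                // number of documents containing the term
        long docCount = 100;         // total number of documents
        int termCount = 50;          // number of terms in this document
        double avgDocLength = 40.0;  // average terms per document

        double normalizedTf = tf / (tf + K1 * (1 - B + B * termCount / avgDocLength));
        // = 3 / (3 + 1.2 * (0.25 + 0.9375)) = 3 / 4.425 ≈ 0.678
        double idf = Math.log(1 + (docCount - df + 0.5) / (df + 0.5));
        // = ln(1 + 90.5 / 10.5) ≈ 2.264
        double deltaScore = normalizedTf * idf; // ≈ 1.53, this term's contribution to the score
        System.out.println(deltaScore);
    }
}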

private static class NodeQueueDocTFIterator extends AbstractIterator<PrimaryKeyWithSortKey>
{
private final NodeQueue nodeQueue;
@@ -20,6 +20,8 @@

import java.util.function.IntConsumer;

import io.github.jbellis.jvector.graph.NodeQueue;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import org.agrona.collections.IntIntConsumer;

/**
@@ -33,7 +35,7 @@ public class SegmentRowIdOrdinalPairs
private final int[] array;

/**
* Create a new IntIntPairArray with the given capacity.
* Create a new SegmentRowIdOrdinalPairs with the given capacity.
* @param capacity the capacity
*/
public SegmentRowIdOrdinalPairs(int capacity)
@@ -102,15 +104,53 @@ public void forEachSegmentRowIdOrdinalPair(IntIntConsumer consumer)
}

/**
* Iterate over the pairs in the array, calling the consumer for each pair passing (index, x, y).
* @param consumer the consumer to call for each pair
* Create an iterator that yields each segment row id paired with the score computed from its ordinal.
* @param scoreFunction the score function used to compute a score from each ordinal
*/
public void forEachIndexOrdinalPair(IntIntConsumer consumer)
public NodeQueue.NodeScoreIterator mapToSegmentRowIdScoreIterator(ScoreFunction scoreFunction)
{
for (int i = 0; i < size; i++)
consumer.accept(i, array[i * 2 + 1]);
return mapToScoreIterator(scoreFunction, false);
}

/**
* Create an iterator that yields each pair's index in this array paired with the score computed from its ordinal.
* @param scoreFunction the score function used to compute a score from each ordinal
*/
public NodeQueue.NodeScoreIterator mapToIndexScoreIterator(ScoreFunction scoreFunction)
{
return mapToScoreIterator(scoreFunction, true);
}

/**
* Create an iterator that yields either the pair's index or its segment row id, along with the score computed from its ordinal.
* @param scoreFunction the score function used to compute a score from each ordinal
* @param mapToIndex if true, yield the pair's index; otherwise yield the segment row id
*/
private NodeQueue.NodeScoreIterator mapToScoreIterator(ScoreFunction scoreFunction, boolean mapToIndex)
{
return new NodeQueue.NodeScoreIterator()
{
int i = 0;

@Override
public boolean hasNext()
{
return i < size;
}

@Override
public int pop()
{
return mapToIndex ? i++ : array[i++ * 2];
}

@Override
public float topScore()
{
return scoreFunction.similarityTo(array[i * 2 + 1]);
}
};
}

/**
* Calls the consumer for each right value in each pair of the array.
@@ -16,14 +16,18 @@

package org.apache.cassandra.index.sai.utils;

import io.github.jbellis.jvector.graph.NodeQueue;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;

public class SegmentRowIdOrdinalPairsTest
{
@@ -95,32 +99,6 @@ public void testForEachSegmentRowIdOrdinalPair()
assertEquals(Integer.valueOf(30), ordinals.get(2));
}

@Test
public void testForEachIndexOrdinalPair()
{
SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(3);
pairs.add(1, 10);
pairs.add(2, 20);
pairs.add(3, 30);

List<Integer> indices = new ArrayList<>();
List<Integer> ordinals = new ArrayList<>();

pairs.forEachIndexOrdinalPair((index, ordinal) -> {
indices.add(index);
ordinals.add(ordinal);
});

assertEquals(3, indices.size());
assertEquals(3, ordinals.size());
assertEquals(Integer.valueOf(0), indices.get(0));
assertEquals(Integer.valueOf(10), ordinals.get(0));
assertEquals(Integer.valueOf(1), indices.get(1));
assertEquals(Integer.valueOf(20), ordinals.get(1));
assertEquals(Integer.valueOf(2), indices.get(2));
assertEquals(Integer.valueOf(30), ordinals.get(2));
}

@Test
public void testGetSegmentRowIdAndOrdinalBoundaryChecks()
{
@@ -158,9 +136,6 @@ public void testOperationsOnEmptyArray()

pairs.forEachSegmentRowIdOrdinalPair((x, y) -> count.incrementAndGet());
assertEquals(0, count.get());

pairs.forEachIndexOrdinalPair((x, y) -> count.incrementAndGet());
assertEquals(0, count.get());
}

@Test
@@ -170,4 +145,81 @@ public void testZeroCapacity()
assertEquals(0, pairs.size());
assertThrows(IndexOutOfBoundsException.class, () -> pairs.add(1, 10));
}

@Test
public void testMapToSegmentRowIdScoreIterator()
{
SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(3);
pairs.add(1, 10);
pairs.add(2, 20);
pairs.add(3, 30);

// Create a simple score function that returns the ordinal value divided by 10 as the score
ScoreFunction.ExactScoreFunction scoreFunction = ordinal -> ordinal / 10.0f;

NodeQueue.NodeScoreIterator iterator = pairs.mapToSegmentRowIdScoreIterator(scoreFunction);

// Test first pair
assertTrue(iterator.hasNext());
assertEquals(1.0f, iterator.topScore(), 0.001f);
assertEquals(1, iterator.pop());

// Test second pair
assertTrue(iterator.hasNext());
assertEquals(2.0f, iterator.topScore(), 0.001f);
assertEquals(2, iterator.pop());

// Test third pair
assertTrue(iterator.hasNext());
assertEquals(3.0f, iterator.topScore(), 0.001f);
assertEquals(3, iterator.pop());

// Test end of iteration
assertFalse(iterator.hasNext());
}

@Test
public void testMapToIndexScoreIterator()
{
SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(3);
pairs.add(1, 10);
pairs.add(2, 20);
pairs.add(3, 30);

// Create a simple score function that returns the ordinal value divided by 10 as the score
ScoreFunction.ExactScoreFunction scoreFunction = ordinal -> ordinal / 10.0f;

NodeQueue.NodeScoreIterator iterator = pairs.mapToIndexScoreIterator(scoreFunction);

// Test first pair
assertTrue(iterator.hasNext());
assertEquals(1.0f, iterator.topScore(), 0.001f);
assertEquals(0, iterator.pop());

// Test second pair
assertTrue(iterator.hasNext());
assertEquals(2.0f, iterator.topScore(), 0.001f);
assertEquals(1, iterator.pop());

// Test third pair
assertTrue(iterator.hasNext());
assertEquals(3.0f, iterator.topScore(), 0.001f);
assertEquals(2, iterator.pop());

// Test end of iteration
assertFalse(iterator.hasNext());
}

@Test
public void testEmptyIterators()
{
SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(0);
ScoreFunction.ExactScoreFunction scoreFunction = ordinal -> ordinal / 10.0f;

NodeQueue.NodeScoreIterator segmentRowIdIterator = pairs.mapToSegmentRowIdScoreIterator(scoreFunction);
assertFalse(segmentRowIdIterator.hasNext());

NodeQueue.NodeScoreIterator indexIterator = pairs.mapToIndexScoreIterator(scoreFunction);
assertFalse(indexIterator.hasNext());
}
}