Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.xml
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@
<dependency groupId="org.apache.lucene" artifactId="lucene-core" version="9.8.0-5ea8bb4f21" />
<dependency groupId="org.apache.lucene" artifactId="lucene-analysis-common" version="9.8.0-5ea8bb4f21" />
<dependency groupId="org.apache.lucene" artifactId="lucene-backward-codecs" version="9.8.0-5ea8bb4f21" />
<dependency groupId="io.github.jbellis" artifactId="jvector" version="4.0.0-beta.3" />
<dependency groupId="io.github.jbellis" artifactId="jvector" version="4.0.0-beta.4" />
<dependency groupId="com.bpodgursky" artifactId="jbool_expressions" version="1.14" scope="test"/>

<dependency groupId="com.carrotsearch.randomizedtesting" artifactId="randomizedtesting-runner" version="2.1.2" scope="test">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,8 @@ private CloseableIterator<RowIdWithScore> orderByBruteForce(CompressedVectors cv
// Store the index of the (rowId, ordinal) pair from the segmentOrdinalPairs in the NodeQueue so that we can
// retrieve both values with O(1) lookup when we need to resolve the full resolution score in the
// BruteForceRowIdIterator.
segmentOrdinalPairs.forEachIndexOrdinalPair((i, ordinal) -> {
approximateScores.push(i, scoreFunction.similarityTo(ordinal));
});
var iter = segmentOrdinalPairs.mapToIndexScoreIterator(scoreFunction);
approximateScores.pushMany(iter, segmentOrdinalPairs.size());
columnQueryMetrics.onBruteForceNodesVisited(segmentOrdinalPairs.size());
var reranker = new CloseableReranker(similarityFunction, queryVector, graph.getView());
return new BruteForceRowIdIterator(approximateScores, segmentOrdinalPairs, reranker, limit, rerankK, columnQueryMetrics);
Expand All @@ -320,9 +319,8 @@ private CloseableIterator<RowIdWithScore> orderByBruteForce(VectorFloat<?> query
var similarityFunction = indexContext.getIndexWriterConfig().getSimilarityFunction();
var esf = vectorsView.rerankerFor(queryVector, similarityFunction);
// Because the scores are exact, we only store the rowid, score pair.
segmentOrdinalPairs.forEachSegmentRowIdOrdinalPair((segmentRowId, ordinal) -> {
scoredRowIds.push(segmentRowId, esf.similarityTo(ordinal));
});
var iter = segmentOrdinalPairs.mapToSegmentRowIdScoreIterator(esf);
scoredRowIds.pushMany(iter, segmentOrdinalPairs.size());
columnQueryMetrics.onBruteForceNodesReranked(segmentOrdinalPairs.size());
return new NodeQueueRowIdIterator(scoredRowIds);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@
import org.apache.cassandra.index.sai.utils.RowIdWithScore;
import org.apache.cassandra.io.util.FileUtils;
import org.apache.cassandra.utils.AbstractIterator;
import org.apache.cassandra.utils.SortingIterator;


/**
* An iterator over {@link RowIdWithMeta} that lazily consumes from a {@link SortingIterator} of
* {@link RowWithApproximateScore}.
* An iterator over {@link RowIdWithMeta} that lazily consumes from a {@link NodeQueue} of approximate scores.
* <p>
* The idea is that we maintain the same level of accuracy as we would get from a graph search, by re-ranking the top
* `k` best approximate scores at a time with the full resolution vectors to return the top `limit`.
Expand Down
63 changes: 43 additions & 20 deletions src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -170,32 +170,55 @@ public static CloseableIterator<PrimaryKeyWithSortKey> computeScores(CloseableIt
// Calculate average document length
double avgDocLength = totalTermCount / documents.size();

// Calculate BM25 scores. Uses a nodequeue that avoids additional allocations and has heap time complexity
// Calculate BM25 scores.
// Uses a NodeQueue that avoids allocating an object for each document.
var nodeQueue = new NodeQueue(new BoundedLongHeap(documents.size()), NodeQueue.Order.MAX_HEAP);
for (int i = 0; i < documents.size(); i++)
{
var doc = documents.get(i);
double score = 0.0;
for (var queryTerm : queryTerms)
{
int tf = doc.getTermFrequency(queryTerm);
Long df = docStats.frequencies.get(queryTerm);
// we shouldn't have more hits for a term than we counted total documents
assert df <= docStats.docCount : String.format("df=%d, totalDocs=%d", df, docStats.docCount);

double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount() / avgDocLength));
double idf = Math.log(1 + (docStats.docCount - df + 0.5) / (df + 0.5));
double deltaScore = normalizedTf * idf;
assert deltaScore >= 0 : String.format("BM25 score for tf=%d, df=%d, tc=%d, totalDocs=%d is %f",
tf, df, doc.termCount(), docStats.docCount, deltaScore);
score += deltaScore;
// Create an anonymous NodeScoreIterator that holds the logic for computing BM25
var iter = new NodeQueue.NodeScoreIterator() {
int current = 0;

@Override
public boolean hasNext() {
return current < documents.size();
}
nodeQueue.push(i, (float) score);
}

@Override
public int pop() {
return current++;
}

@Override
public float topScore() {
// Compute BM25 for the current document
return scoreDoc(documents.get(current), docStats, queryTerms, avgDocLength);
}
};
// pushMany is an O(n) operation where n is the final size of the queue. Iterative calls to push is O(n log n).
nodeQueue.pushMany(iter, documents.size());

return new NodeQueueDocTFIterator(nodeQueue, documents, indexContext, source, docIterator);
}

private static float scoreDoc(DocTF doc, DocStats docStats, List<ByteBuffer> queryTerms, double avgDocLength)
{
double score = 0.0;
for (var queryTerm : queryTerms)
{
int tf = doc.getTermFrequency(queryTerm);
Long df = docStats.frequencies.get(queryTerm);
// we shouldn't have more hits for a term than we counted total documents
assert df <= docStats.docCount : String.format("df=%d, totalDocs=%d", df, docStats.docCount);

double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount() / avgDocLength));
double idf = Math.log(1 + (docStats.docCount - df + 0.5) / (df + 0.5));
double deltaScore = normalizedTf * idf;
assert deltaScore >= 0 : String.format("BM25 score for tf=%d, df=%d, tc=%d, totalDocs=%d is %f",
tf, df, doc.termCount(), docStats.docCount, deltaScore);
score += deltaScore;
}
return (float) score;
}

private static class NodeQueueDocTFIterator extends AbstractIterator<PrimaryKeyWithSortKey>
{
private final NodeQueue nodeQueue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import java.util.function.IntConsumer;

import io.github.jbellis.jvector.graph.NodeQueue;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import org.agrona.collections.IntIntConsumer;

/**
Expand All @@ -33,7 +35,7 @@ public class SegmentRowIdOrdinalPairs
private final int[] array;

/**
* Create a new IntIntPairArray with the given capacity.
* Create a new SegmentRowIdOrdinalPairs with the given capacity.
* @param capacity the capacity
*/
public SegmentRowIdOrdinalPairs(int capacity)
Expand Down Expand Up @@ -102,15 +104,52 @@ public void forEachSegmentRowIdOrdinalPair(IntIntConsumer consumer)
}

/**
* Iterate over the pairs in the array, calling the consumer for each pair passing (index, x, y).
* @param consumer the consumer to call for each pair
* Create an iterator over the segment row id and scored ordinal pairs in the array.
* @param scoreFunction the score function to use to compute the next score based on the ordinal
*/
public void forEachIndexOrdinalPair(IntIntConsumer consumer)
public NodeQueue.NodeScoreIterator mapToSegmentRowIdScoreIterator(ScoreFunction scoreFunction)
{
for (int i = 0; i < size; i++)
consumer.accept(i, array[i * 2 + 1]);
return mapToScoreIterator(scoreFunction, false);
}

/**
* Create an iterator over the index and scored ordinal pairs in the array.
* @param scoreFunction the score function to use to compute the next score based on the ordinal
*/
public NodeQueue.NodeScoreIterator mapToIndexScoreIterator(ScoreFunction scoreFunction)
{
return mapToScoreIterator(scoreFunction, true);
}

/**
* Create an iterator over the index or the segment row id and the score for the ordinal.
* @param scoreFunction the score function to use to compute the next score based on the ordinal
*/
private NodeQueue.NodeScoreIterator mapToScoreIterator(ScoreFunction scoreFunction, boolean mapToIndex)
{
return new NodeQueue.NodeScoreIterator()
{
int i = 0;

@Override
public boolean hasNext()
{
return i < size;
}

@Override
public int pop()
{
return mapToIndex ? i++ : array[i++ * 2];
}

@Override
public float topScore()
{
return scoreFunction.similarityTo(array[i * 2 + 1]);
}
};
}

/**
* Calls the consumer for each right value in each pair of the array.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,18 @@

package org.apache.cassandra.index.sai.utils;

import io.github.jbellis.jvector.graph.NodeQueue;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;

public class SegmentRowIdOrdinalPairsTest
{
Expand Down Expand Up @@ -95,32 +99,6 @@ public void testForEachSegmentRowIdOrdinalPair()
assertEquals(Integer.valueOf(30), ordinals.get(2));
}

@Test
public void testForEachIndexOrdinalPair()
{
SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(3);
pairs.add(1, 10);
pairs.add(2, 20);
pairs.add(3, 30);

List<Integer> indices = new ArrayList<>();
List<Integer> ordinals = new ArrayList<>();

pairs.forEachIndexOrdinalPair((index, ordinal) -> {
indices.add(index);
ordinals.add(ordinal);
});

assertEquals(3, indices.size());
assertEquals(3, ordinals.size());
assertEquals(Integer.valueOf(0), indices.get(0));
assertEquals(Integer.valueOf(10), ordinals.get(0));
assertEquals(Integer.valueOf(1), indices.get(1));
assertEquals(Integer.valueOf(20), ordinals.get(1));
assertEquals(Integer.valueOf(2), indices.get(2));
assertEquals(Integer.valueOf(30), ordinals.get(2));
}

@Test
public void testGetSegmentRowIdAndOrdinalBoundaryChecks()
{
Expand Down Expand Up @@ -158,9 +136,6 @@ public void testOperationsOnEmptyArray()

pairs.forEachSegmentRowIdOrdinalPair((x, y) -> count.incrementAndGet());
assertEquals(0, count.get());

pairs.forEachIndexOrdinalPair((x, y) -> count.incrementAndGet());
assertEquals(0, count.get());
}

@Test
Expand All @@ -170,4 +145,81 @@ public void testZeroCapacity()
assertEquals(0, pairs.size());
assertThrows(IndexOutOfBoundsException.class, () -> pairs.add(1, 10));
}

@Test
public void testMapToSegmentRowIdScoreIterator()
{
SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(3);
pairs.add(1, 10);
pairs.add(2, 20);
pairs.add(3, 30);

// Create a simple score function that returns the ordinal value divided by 10 as the score
ScoreFunction.ExactScoreFunction scoreFunction = ordinal -> ordinal / 10.0f;

NodeQueue.NodeScoreIterator iterator = pairs.mapToSegmentRowIdScoreIterator(scoreFunction);

// Test first pair
assertTrue(iterator.hasNext());
assertEquals(1.0f, iterator.topScore(), 0.001f);
assertEquals(1, iterator.pop());

// Test second pair
assertTrue(iterator.hasNext());
assertEquals(2.0f, iterator.topScore(), 0.001f);
assertEquals(2, iterator.pop());

// Test third pair
assertTrue(iterator.hasNext());
assertEquals(3.0f, iterator.topScore(), 0.001f);
assertEquals(3, iterator.pop());

// Test end of iteration
assertFalse(iterator.hasNext());
}

@Test
public void testMapToIndexScoreIterator()
{
SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(3);
pairs.add(1, 10);
pairs.add(2, 20);
pairs.add(3, 30);

// Create a simple score function that returns the ordinal value divided by 10 as the score
ScoreFunction.ExactScoreFunction scoreFunction = ordinal -> ordinal / 10.0f;

NodeQueue.NodeScoreIterator iterator = pairs.mapToIndexScoreIterator(scoreFunction);

// Test first pair
assertTrue(iterator.hasNext());
assertEquals(1.0f, iterator.topScore(), 0.001f);
assertEquals(0, iterator.pop());

// Test second pair
assertTrue(iterator.hasNext());
assertEquals(2.0f, iterator.topScore(), 0.001f);
assertEquals(1, iterator.pop());

// Test third pair
assertTrue(iterator.hasNext());
assertEquals(3.0f, iterator.topScore(), 0.001f);
assertEquals(2, iterator.pop());

// Test end of iteration
assertFalse(iterator.hasNext());
}

@Test
public void testEmptyIterators()
{
SegmentRowIdOrdinalPairs pairs = new SegmentRowIdOrdinalPairs(0);
ScoreFunction.ExactScoreFunction scoreFunction = ordinal -> ordinal / 10.0f;

NodeQueue.NodeScoreIterator segmentRowIdIterator = pairs.mapToSegmentRowIdScoreIterator(scoreFunction);
assertFalse(segmentRowIdIterator.hasNext());

NodeQueue.NodeScoreIterator indexIterator = pairs.mapToIndexScoreIterator(scoreFunction);
assertFalse(indexIterator.hasNext());
}
}