Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/java/org/apache/cassandra/cql3/UntypedResultSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ public double getDouble(String column)
return DoubleType.instance.compose(data.get(column));
}

public double getFloat(String column)
public float getFloat(String column)
{
return FloatType.instance.compose(data.get(column));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ private CloseableIterator<RowIdWithScore> orderByBruteForce(CompressedVectors cv
// Rerankless search, so we go straight to the NodeQueueRowIdIterator.
var iter = segmentOrdinalPairs.mapToSegmentRowIdScoreIterator(scoreFunction);
approximateScores.pushMany(iter, segmentOrdinalPairs.size());
return new NodeQueueRowIdIterator(approximateScores, graph.usesNVQ());
return new NodeQueueRowIdIterator(approximateScores, true);
}

// Store the index of the (rowId, ordinal) pair from the segmentOrdinalPairs in the NodeQueue so that we can
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,18 +272,19 @@ else if (compressedVectors == null)
limit, rerankK, isRerankless, usePruning, result.getVisitedCount(), result.getRerankedCount(), result.getNodes().length, source);
columnQueryMetrics.onSearchResult(result, elapsed, false);
context.addAnnGraphSearchLatency(elapsed);
boolean isScoreApproximate = usesNVQ || isRerankless;
if (threshold > 0)
{
// Threshold-based searches are comprehensive and do not need to resume the search.
graphAccessManager.release();
nodesVisitedConsumer.accept(result.getVisitedCount());
var nodeScores = CloseableIterator.wrap(Arrays.stream(result.getNodes()).iterator());
return new NodeScoreToRowIdWithScoreIterator(nodeScores, ordinalsMap.getRowIdsView(), usesNVQ);
return new NodeScoreToRowIdWithScoreIterator(nodeScores, ordinalsMap.getRowIdsView(), isScoreApproximate);
}
else
{
var nodeScores = new AutoResumingNodeScoreIterator(searcher, graphAccessManager, result, context, columnQueryMetrics, nodesVisitedConsumer, limit, rerankK, false, source.toString());
return new NodeScoreToRowIdWithScoreIterator(nodeScores, ordinalsMap.getRowIdsView(), usesNVQ);
return new NodeScoreToRowIdWithScoreIterator(nodeScores, ordinalsMap.getRowIdsView(), isScoreApproximate);
}
}
catch (Throwable t)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ private void validateQueries()
assertEquals(10, r.size());
for (var row : r)
{
float similarity = (float) row.getFloat("similarity");
float similarity = row.getFloat("similarity");
assertTrue(similarity <= lastSimilarity);
lastSimilarity = similarity;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,52 @@ private void ensureDisablingPruningIncreasesRecall(List<float[]> queryVectors, L
recallWithoutPruning >= recallWithPruning);
}

/**
 * Runs ANN queries with {@code rerank_k = 0} against a flushed index and asserts that the
 * {@code similarity_euclidean} scores selected alongside each row never increase from one
 * returned row to the next.
 * <p>
 * Note: test only fails when scores are sent from replica to coordinator.
 */
@Test
public void testRerankKZeroOrderMatchesFullPrecisionSimilarity() throws Throwable
{
    var base = readFvecs(String.format("test/data/%s/%s_base.fvecs", DATASET, DATASET));
    var queries = readFvecs(String.format("test/data/%s/%s_query.fvecs", DATASET, DATASET));

    // Schema setup: table first, then the index on it
    createTable();
    createIndex();

    // The in memory index uses FP vectors and therefore ignores rerank_k = 0, so the data is
    // flushed before querying
    insertVectors(base, 0);
    flush();

    // A high limit increases the probability of observing an incorrect ordering, while only a
    // subset of the query vectors is used to keep the test runtime reasonable
    final int rowLimit = 100;
    final int queryCount = 10;

    final String cqlTemplate = "SELECT pk, similarity_euclidean(val, %s) as similarity FROM %%s ORDER BY val ANN OF %s LIMIT %d WITH ann_options = {'rerank_k': 0}";
    final String orderingMessage = "Query %d: Similarity scores should be in descending order (higher score = more similar). " +
                                   "Previous: %.10f, Current: %.10f";

    for (int q = 0; q < queryCount; q++)
    {
        String vectorLiteral = Arrays.toString(queries.get(q));

        // Query with rerank_k = 0, selecting the similarity score next to each primary key
        UntypedResultSet rows = execute(String.format(cqlTemplate, vectorLiteral, vectorLiteral, rowLimit));
        assertEquals(rowLimit, rows.size());

        // Euclidean similarity is 1.0 / (1.0 + distance²): a higher score means more similar,
        // so every score must be less than or equal to the one before it
        float previous = Float.MAX_VALUE;
        for (var row : rows)
        {
            float current = row.getFloat("similarity");
            String failure = String.format(orderingMessage, q, previous, current);
            assertTrue(failure, current <= previous);
            previous = current;
        }
    }
}

@Test
public void testCompaction() throws Throwable
{
Expand Down
Loading