RetrievalResults as sequence of tensors #565

Conversation
```python
distances_b_sorted, retrieved_ids_b = torch.topk(distances_b, k=top_n, largest=False, sorted=True)

# every query may have an arbitrary number of retrieved items, so we are forced to use a loop to store the results
for dist, ids in zip(distances_b_sorted, retrieved_ids_b):
```
I don't know if it simplifies or improves anything, but you can split the tensor into chunks to get a tuple of tensors:

```python
import torch

# inf marks the "nothing retrieved" slots; the second query retrieved nothing at all
distances_b_sorted = torch.tensor(
    [
        [1.0, 2.0, 3.0, float("inf")],
        [float("inf"), float("inf"), float("inf"), float("inf")],
        [3.0, 5.0, 6.0, float("inf")],
    ]
)
retrieved_ids_b = torch.tensor([[10, 1, 2, 7], [3, 14, 5, 6], [4, 8, 9, 0]])

mask_to_keep = ~distances_b_sorted.isinf()
elems_per_query = mask_to_keep.sum(dim=1)

# tuples of per-query tensors with lengths 3, 0, 3 — an empty retrieval is represented naturally
distances = torch.split(distances_b_sorted[mask_to_keep], elems_per_query.tolist())
retrieved_ids = torch.split(retrieved_ids_b[mask_to_keep], elems_per_query.tolist())
```

You can play with different amounts of infs. Or just leave it as is.
I've checked with TQDM that this function is not the bottleneck, so let's do the optimization later (after more urgent stuff).
```python
top_k = _clip_max_with_warning(top_k, gt_tops.shape[1])

def precision_single(is_correct: BoolTensor, n_gt_: int, k_: int) -> float:
    k_ = min(k_, len(is_correct))
```
Might `len(is_correct)` be zero? 🤔 Previously the denominator was defined only by `k` or the shape of `gt_tops`, which are non-zero. But now it might be empty, which is okay after postprocessing.
You are right. But support for this case is added in the next PR: #566
```python
def precision_single(is_correct: BoolTensor, n_gt_: int, k_: int) -> float:
    k_ = min(k_, len(is_correct))
    value = torch.cumsum(is_correct, dim=0)[k_ - 1] / min(n_gt_, k_)
    return float(value)
```
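A quick worked example of what this helper computes (the inputs are made up; `precision_single` is repeated verbatim so the snippet runs standalone):

```python
import torch
from torch import BoolTensor

def precision_single(is_correct: BoolTensor, n_gt_: int, k_: int) -> float:
    k_ = min(k_, len(is_correct))
    value = torch.cumsum(is_correct, dim=0)[k_ - 1] / min(n_gt_, k_)
    return float(value)

# two of the first three retrieved items are correct and there are two ground truths,
# so precision@3 = 2 / min(2, 3) = 1.0
print(precision_single(BoolTensor([True, False, True, False]), n_gt_=2, k_=3))  # 1.0
```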
Not sure that converting `torch.float` (inside the function, via `torch.cumsum()` / division) → `float` (on return) → `torch.float` (outside, after the function) makes sense when using the inner function.
I return a `float` from the inner function because I expect just a single value here, without any extra dimensions which I might have if I kept it as a tensor.
```python
assert retrieved_ids.shape[1] <= len(dataset.get_gallery_ids())
assert len(dataset.get_query_ids()) == len(
    rr.retrieved_ids
), "RetrievalResults and dataset must have the same number of queries."
```
Use `RetrievalResults.__name__` or `rr.__class__.__name__` instead of hard-coding the name in the message.
changed, thx
`oml/retrieval/retrieval_results.py` (outdated)
```diff
-    gt_ids: List[LongTensor] = None,
+    distances: Sequence[FloatTensor],
+    retrieved_ids: Sequence[LongTensor],
+    gt_ids: Sequence[LongTensor] = None,
```
`Optional[Sequence[LongTensor]] = None`
thx, done
```python
assert distances.shape == retrieved_ids.shape
assert distances.ndim == 2
for d, r in zip(distances, retrieved_ids):
    if not (d[:-1] <= d[1:]).all():
```
Logically it's okay to have empty `distances` and `retrieved_ids` for some queries (e.g. after filtering by threshold), and it's funny that all these checks pass 😁 I was not sure how one of the checks should work for empty tensors.
You are right, it's expected. I've added a test for that:

```python
# we retrieved nothing, but it's not an error
RetrievalResults(
    distances=[FloatTensor([])],
    retrieved_ids=[LongTensor([])],
    gt_ids=[LongTensor([1])],
)
```
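For reference, the sortedness check discussed above is vacuously true on empty tensors, which is why such a construction passes:

```python
import torch

d = torch.tensor([])            # empty distances for one query
print((d[:-1] <= d[1:]).all())  # tensor(True): a comparison over zero elements
```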
`oml/retrieval/retrieval_results.py` (outdated)
```diff
     """
     assert len(embeddings) == len(dataset), "Embeddings and dataset must have the same size."

     if SEQUENCE_COLUMN in dataset.extra_data:
-        sequence_ids = LongTensor(pd.factorize(dataset.extra_data[SEQUENCE_COLUMN])[0])
+        sequence = pd.Series(dataset.extra_data[SEQUENCE_COLUMN])
+        sequence_ids = LongTensor(pd.factorize(sequence)[0])
```
Maybe it's not related to this PR, but I recommend adding the parameter `sort=True` to `pd.factorize`. I worked in sim-api with categories or some handcrafted column, and my `df` was recreated at runtime each time, depending on params. Each time the raw values were encoded to different codes (I used them later for debugging); the unique values of such "categories" were the same, just in different amounts and orders. It doesn't happen if the `df` remains the same, but why not ...
```python
import pandas as pd
from random import shuffle

value = [1, 1, 2, 3, 1, 2, 3, 4]
shuffle(value)

df = pd.DataFrame({"raw": value})
# without sort=True, codes follow the order of first appearance,
# so they change from run to run after the shuffle
df["encode"] = pd.factorize(df["raw"])[0]
df = df.sort_values(by="raw")
print(df)
```
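For contrast, a sketch of the same snippet with `sort=True`: the code for each raw value then depends only on its rank among the unique values, not on row order, so the mapping stays stable across shuffles.

```python
import pandas as pd
from random import shuffle

value = [1, 1, 2, 3, 1, 2, 3, 4]
shuffle(value)

df = pd.DataFrame({"raw": value})
# with sort=True, code i always corresponds to the i-th smallest unique value
df["encode"] = pd.factorize(df["raw"], sort=True)[0]
print(df.sort_values(by="raw"))
```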
Nice! I didn't know about this. And thank you for the snippet. Changed.
`oml/retrieval/retrieval_results.py` (outdated)
```diff
@@ -89,17 +105,19 @@ def compute_from_embeddings(
         return RetrievalResults(distances=distances, retrieved_ids=retrieved_ids, gt_ids=gt_ids)

     def __str__(self) -> str:
         max_el_to_show = 100
```
class attribute?
agree, done
also changed the naming: `_max_elements_in_str_repr`
`oml/retrieval/retrieval_results.py` (outdated)
```diff
@@ -125,13 +143,18 @@ def visualize(
         if not isinstance(dataset, IQueryGalleryDataset):
             raise TypeError(f"Dataset has to support {IQueryGalleryDataset.__name__}. Got {type(dataset)}.")

         nq1, nq2 = len(self.retrieved_ids), len(dataset.get_query_ids())
         if nq1 != nq2:
             raise RuntimeError(f"Number of queries in RetrievalResults and Dataset must match: {nq1} != {nq2}")
```
Extract the class names programmatically, as above.
`type(dataset)` will show you the full path to the class, which is not friendly; `dataset.__class__.__name__` shows only the name. And use `self.__class__.__name__` for `RetrievalResults`.
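A tiny sketch of the difference (with a made-up class):

```python
class MyDataset:
    pass

d = MyDataset()
print(type(d))               # <class '__main__.MyDataset'> -- includes the module path
print(d.__class__.__name__)  # MyDataset -- just the short name
```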
done, thx!
CHANGELOG

- `RetrievalResults` uses a Sequence of Tensors which may have different sizes. In other words, it allows us to support the case when queries have different numbers of retrieved items.
- Updated `batched_knn`, `retrieval_metrics` and `PairwiseReranker` to support the new input type.
- Added checks to `RetrievalResults`: retrieved ids are unique, and other checks.

New tests:

- `RetrievalResults` creation.
- `RetrievalResults` where queries have different numbers of retrieved items.
- Reworked `batched_knn` tests to make debugging easier.
- Added `sequence` to datasets so queries have different numbers of retrieved items and we actually test the new functionality.

@leoromanovich and I also checked that using a Sequence of Tensors doesn't lead to poor performance on validation.
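To make the headline change concrete, here is a minimal sketch of constructing `RetrievalResults` with per-query tensors of different lengths (the values are made up; the import path is assumed from `oml/retrieval/retrieval_results.py` above):

```python
from torch import FloatTensor, LongTensor

from oml.retrieval.retrieval_results import RetrievalResults

# query 0 retrieved three items, query 1 retrieved only one;
# per-query distances stay sorted in ascending order
rr = RetrievalResults(
    distances=[FloatTensor([0.1, 0.2, 0.5]), FloatTensor([0.3])],
    retrieved_ids=[LongTensor([4, 2, 7]), LongTensor([0])],
    gt_ids=[LongTensor([2]), LongTensor([0, 5])],
)
```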