
RetrievalResults as sequence of tensors #565


Merged: 13 commits into oml_3.0_release on May 21, 2024

Conversation

@AlekseySh (Contributor) commented on May 18, 2024

CHANGELOG

  • RetrievalResults now uses a Sequence of Tensors that may have different sizes. In other words, it supports the case when queries have different numbers of retrieved items.
  • Consequently, batched_knn, retrieval_metrics, and PairwiseReranker were changed to support the new input type.
  • Added asserts to RetrievalResults that distances arrive sorted, that retrieved ids are unique, and other checks.

New tests:

  • Added tests on corner cases for RetrievalResults creation.
  • Added tests on visualization when queries in RetrievalResults have different numbers of retrieved items.
  • Added a new test with predefined values for batched_knn to make debugging easier.
  • Changed existing postprocessor tests: used sequences in the datasets so queries have different numbers of retrieved items and we actually test the new functionality.

@leoromanovich and I also checked that using a Sequence of Tensors doesn't lead to poor performance on validation.
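As a rough illustration of the new input type, a minimal sketch of constructing RetrievalResults from per-query tensors of different lengths (a sketch only: the import path is an assumption, the values are made up, and the keyword arguments mirror the test snippet quoted later in this conversation):

from torch import FloatTensor, LongTensor

# the import path below is an assumption for this sketch
from oml.retrieval.retrieval_results import RetrievalResults

# two queries: the first has 3 retrieved items, the second has only 1
rr = RetrievalResults(
    distances=[FloatTensor([0.1, 0.3, 0.7]), FloatTensor([0.2])],  # sorted within each query
    retrieved_ids=[LongTensor([5, 2, 8]), LongTensor([4])],        # unique ids within each query
    gt_ids=[LongTensor([2]), LongTensor([4, 9])],
)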

@AlekseySh AlekseySh changed the base branch from main to oml_3.0_release May 18, 2024 21:48
@AlekseySh AlekseySh changed the title RR as lists RetrievalResults as lists May 18, 2024
@AlekseySh AlekseySh self-assigned this May 18, 2024
@AlekseySh AlekseySh changed the title RetrievalResults as lists RetrievalResults as sequence of tensors May 19, 2024
@AlekseySh AlekseySh requested a review from DaloroAT May 20, 2024 23:28
distances_b_sorted, retrieved_ids_b = torch.topk(distances_b, k=top_n, largest=False, sorted=True)

# every query may have an arbitrary number of retrieved items, so we are forced to use a loop to store the results
for dist, ids in zip(distances_b_sorted, retrieved_ids_b):
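For context, a hedged sketch of what such a loop body might look like, assuming inf distances mark padded slots (the accumulator names are illustrative, not necessarily the PR's):

# sketch only: collect per-query results of varying length by dropping inf-padded slots
distances, retrieved_ids = [], []
for dist, ids in zip(distances_b_sorted, retrieved_ids_b):
    mask = ~dist.isinf()            # inf marks gallery slots that are not valid for this query
    distances.append(dist[mask])
    retrieved_ids.append(ids[mask])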
@DaloroAT (Collaborator):
I don't know if it simplifies or improves something, but you can split the tensor into chunks to get a tuple of tensors:

import torch

distances_b_sorted = torch.tensor(
    [
        [1.0, 2.0, 3.0, float("inf")],
        [float("inf"), float("inf"), float("inf"), float("inf")],
        [3.0, 5.0, 6.0, float("inf")],
    ]
)
# retrieved ids must have the same shape as the distances (3 x 4 here)
retrieved_ids_b = torch.tensor([[10, 1, 2, 7], [3, 14, 5, 6], [8, 9, 0, 11]])

# keep only finite distances and count how many survive for each query
mask_to_keep = ~distances_b_sorted.isinf()
elems_per_query = mask_to_keep.sum(dim=1)

# split the flat masked values back into one (possibly empty) tensor per query
distances = torch.split(distances_b_sorted[mask_to_keep], elems_per_query.tolist())
retrieved_ids = torch.split(retrieved_ids_b[mask_to_keep], elems_per_query.tolist())

You can play with different numbers of infs. Or just leave it as is.

@AlekseySh (Author):

I've checked with TQDM that this function is not the bottleneck, so let's do the optimization later (after more urgent stuff).

top_k = _clip_max_with_warning(top_k, gt_tops.shape[1])

def precision_single(is_correct: BoolTensor, n_gt_: int, k_: int) -> float:
    k_ = min(k_, len(is_correct))
@DaloroAT (Collaborator):
Might len(is_correct) be zero? 🤔 Previously the denominator was defined only by k or by the shape of gt_tops, which are non-zero. But now it might be empty, which is okay after postprocessing.

@AlekseySh (Author):
You are right. But support for this case is added in the next PR: #566

def precision_single(is_correct: BoolTensor, n_gt_: int, k_: int) -> float:
    k_ = min(k_, len(is_correct))
    value = torch.cumsum(is_correct, dim=0)[k_ - 1] / min(n_gt_, k_)
    return float(value)
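For intuition, a small usage sketch of the precision_single helper above (a sketch only; torch and BoolTensor are assumed to be imported as in the surrounding module):

from torch import BoolTensor

# 3 relevant gallery items exist (n_gt_=3); hits at ranks 1 and 3 among 4 retrieved items
is_correct = BoolTensor([True, False, True, False])

print(precision_single(is_correct, n_gt_=3, k_=2))  # 1 hit in top-2 / min(3, 2) = 0.5
print(precision_single(is_correct, n_gt_=3, k_=4))  # 2 hits in top-4 / min(3, 4) ~= 0.667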
@DaloroAT (Collaborator):
Not sure that converting torch.float [torch.cumsum()/... inside the function] -> float [on return] -> torch.float [outside, after the function] makes sense when using the inner function.

@AlekseySh (Author):
I return a float from the inner function because I expect just a single value here, without the extra dimensions I might get if I kept it as a tensor.

assert retrieved_ids.shape[1] <= len(dataset.get_gallery_ids())
assert len(dataset.get_query_ids()) == len(
    rr.retrieved_ids
), "RetrievalResults and dataset must have the same number of queries."
@DaloroAT (Collaborator):
RetrievalResults.__name__ or rr.__class__.__name__

@AlekseySh (Author):
changed, thx

gt_ids: List[LongTensor] = None,
distances: Sequence[FloatTensor],
retrieved_ids: Sequence[LongTensor],
gt_ids: Sequence[LongTensor] = None,
@DaloroAT (Collaborator):
Optional[Sequence[LongTensor]] = None

@AlekseySh (Author):
thx, done

assert distances.shape == retrieved_ids.shape
assert distances.ndim == 2
for d, r in zip(distances, retrieved_ids):
    if not (d[:-1] <= d[1:]).all():
@DaloroAT (Collaborator):
Logically it's okay to have empty distances and retrieved_ids (e.g. after filtering by a threshold) for some queries, and it's funny that all these checks pass 😁 I was not sure how this check should work for empty tensors.

@AlekseySh (Author):
You are right, it's expected.

I've added a test for that:

    # we retrieved nothing, but it's not an error
    RetrievalResults(
        distances=[FloatTensor([])],
        retrieved_ids=[LongTensor([])],
        gt_ids=[LongTensor([1])],
    )


"""
assert len(embeddings) == len(dataset), "Embeddings and dataset must have the same size."

if SEQUENCE_COLUMN in dataset.extra_data:
-    sequence_ids = LongTensor(pd.factorize(dataset.extra_data[SEQUENCE_COLUMN])[0])
+    sequence = pd.Series(dataset.extra_data[SEQUENCE_COLUMN])
+    sequence_ids = LongTensor(pd.factorize(sequence)[0])
@DaloroAT (Collaborator):
Maybe it's not related to this PR, but I recommend adding the param sort=True to pd.factorize. I worked in sim-api with categories or some handcrafted columns, and my df was recreated at runtime each time depending on the params. And each time the raw values were encoded to different codes (I used them later for debugging). The unique values of such "categories" were the same, just in different amounts and orders.

It doesn't happen if the df remains the same, but why not ...

import pandas as pd
from random import shuffle

value = [1, 1, 2, 3, 1, 2, 3, 4]
shuffle(value)

df = pd.DataFrame({"raw": value})
df["encode"] = pd.factorize(df["raw"])[0]

df = df.sort_values(by="raw")

print(df)
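To make the recommendation concrete, a small sketch (not the PR's code) showing that sort=True makes the codes independent of row order:

import pandas as pd

values = pd.Series(["b", "a", "b", "c"])

# default: codes follow first-appearance order -> b=0, a=1, c=2
print(pd.factorize(values)[0])             # [0 1 0 2]

# sort=True: codes follow sorted unique values -> a=0, b=1, c=2 (stable across reorderings)
print(pd.factorize(values, sort=True)[0])  # [1 0 1 2]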

@AlekseySh (Author):
Nice! I didn't know about this. And thank you for the snippet. Changed.

@@ -89,17 +105,19 @@ def compute_from_embeddings(
        return RetrievalResults(distances=distances, retrieved_ids=retrieved_ids, gt_ids=gt_ids)

    def __str__(self) -> str:
        max_el_to_show = 100
@DaloroAT (Collaborator):
class attribute?

@AlekseySh (Author):
Agree, done. Also changed the naming: _max_elements_in_str_repr

@@ -125,13 +143,18 @@ def visualize(
        if not isinstance(dataset, IQueryGalleryDataset):
            raise TypeError(f"Dataset has to support {IQueryGalleryDataset.__name__}. Got {type(dataset)}.")

        nq1, nq2 = len(self.retrieved_ids), len(dataset.get_query_ids())
        if nq1 != nq2:
            raise RuntimeError(f"Number of queries in RetrievalResults and Dataset must match: {nq1} != {nq2}")
@DaloroAT (Collaborator):
Extract the class names programmatically, as above.

@DaloroAT (Collaborator):
type(dataset) will show you the full path to the class, which is not friendly.
dataset.__class__.__name__ shows only the name,
and self.__class__.__name__ works for RetrievalResults.
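A quick illustration of the difference (the class name here is purely hypothetical, used only to show the formatting):

class ImageQueryGalleryDataset:  # hypothetical class for illustration
    pass

dataset = ImageQueryGalleryDataset()

print(f"{type(dataset)}")               # <class '__main__.ImageQueryGalleryDataset'>
print(f"{dataset.__class__.__name__}")  # ImageQueryGalleryDataset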

@AlekseySh (Author):
done, thx!

@DaloroAT (Collaborator) commented, quoting the changelog:

  • RetrievalResults now uses a Sequence of Tensors that may have different sizes. In other words, it supports the case when queries have different numbers of retrieved items.
  • Consequently, batched_knn, retrieval_metrics, and PairwiseReranker were changed to support the new input type.

@AlekseySh AlekseySh merged commit eeac2b9 into oml_3.0_release May 21, 2024
8 checks passed
@AlekseySh AlekseySh deleted the rr_as_lists branch May 21, 2024 20:09
@AlekseySh AlekseySh restored the rr_as_lists branch May 21, 2024 20:10