Skip to content

Commit 848b6d6

Browse files
authored
Metrics: categories, empty predictions, tests, fnmr@fmr
**CHANGELOG** * Added support of empty predictions to the retrieval metrics. For example, it may be useful when we cut retrieval results by distance threshold). * Moved categories handling to functional metrics from `EmbeddingMetrics` class, also updated `.md` example to show how to deal with categories. * Added `calc_fnmr_at_fmr_rr`, removed `extract_pos_neg_dists` and `calc_fnmr_at_fmr_from_matrices`. Returned `fnmr@fmr` to `EmbeddingMetrics` (there was a todo). TESTS * Moved tests that use old formats of retrieval metrics to a separate folder: `...test_metrics/test_outdated/...`. * Added a few new tests on retrieval metrics: test handling categories and empty predictions. * Added test on `calc_fnmr_at_fmr_rr`. * Added test that `EmbeddingMetrics` calculate all expected metrics.
1 parent eeac2b9 commit 848b6d6

File tree

19 files changed

+350
-137
lines changed

19 files changed

+350
-137
lines changed

Makefile

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,12 @@ wandb_login:
5959

6060
.PHONY: run_all_tests
6161
run_all_tests: download_mock_dataset wandb_login
62-
export PYTORCH_ENABLE_MPS_FALLBACK=1; pytest --disable-warnings -sv tests
62+
export PYTORCH_ENABLE_MPS_FALLBACK=1; export PYTHONPATH=.; pytest --disable-warnings -sv tests
6363
pytest --disable-warnings --doctest-modules --doctest-continue-on-failure -sv oml
6464

6565
.PHONY: run_short_tests
6666
run_short_tests: download_mock_dataset
67-
export PYTORCH_ENABLE_MPS_FALLBACK=1; pytest --disable-warnings -sv -m "not long and not needs_optional_dependency" tests
67+
export PYTORCH_ENABLE_MPS_FALLBACK=1; export PYTHONPATH=.; pytest --disable-warnings -sv -m "not long and not needs_optional_dependency" tests
6868
pytest --disable-warnings --doctest-modules --doctest-continue-on-failure -sv oml
6969

7070
.PHONY: test_converters
@@ -122,6 +122,7 @@ clean:
122122
find . -type d -name "ml-runs" -exec rm -r {} +
123123
find . -type d -name "logs" -exec rm -r {} +
124124
find . -type d -name ".ipynb_checkpoints" -exec rm -r {} +
125+
find . -type d -name ".hydra" -exec rm -r {} +
125126
find . -type f -name "*.log" -exec rm {} +
126127
find . -type f -name "*predictions.json" -exec rm {} +
127128
rm -rf docs/build

README.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,8 @@ for batch in tqdm(train_loader):
334334
[comment]:vanilla-validation-start
335335
```python
336336

337+
import numpy as np
338+
337339
from oml.datasets import ImageQueryGalleryLabeledDataset
338340
from oml.inference import inference
339341
from oml.metrics import calc_retrieval_metrics_rr
@@ -345,17 +347,19 @@ from oml.registry.transforms import get_transforms_for_pretrained
345347
extractor = ViTExtractor.from_pretrained("vits16_dino")
346348
transform, _ = get_transforms_for_pretrained("vits16_dino")
347349

348-
_, df_val = download_mock_dataset(global_paths=True)
350+
_, df_val = download_mock_dataset(global_paths=True, df_name="df_with_category.csv")
349351
dataset = ImageQueryGalleryLabeledDataset(df_val, transform=transform)
350352

353+
# you can optionally provide categories to have category wise metrics
354+
query_categories = np.array(dataset.extra_data["category"])[dataset.get_query_ids()]
355+
351356
embeddings = inference(extractor, dataset, batch_size=4)
352357

353358
rr = RetrievalResults.compute_from_embeddings(embeddings, dataset, n_items_to_retrieve=5)
354-
metrics = calc_retrieval_metrics_rr(rr, map_top_k=(3, 5), precision_top_k=(5,), cmc_top_k=(3,))
359+
metrics = calc_retrieval_metrics_rr(rr, query_categories, map_top_k=(3, 5), precision_top_k=(5,), cmc_top_k=(3,))
355360

356361
print(rr, "\n", metrics)
357-
rr.visualize(query_ids=[2, 1], dataset=dataset).show()
358-
362+
rr.visualize(query_ids=[2, 1], dataset=dataset)
359363
```
360364
[comment]:vanilla-validation-end
361365
</p>

docs/readme/examples_source/extractor/val.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
[comment]:vanilla-validation-start
66
```python
77

8+
import numpy as np
9+
810
from oml.datasets import ImageQueryGalleryLabeledDataset
911
from oml.inference import inference
1012
from oml.metrics import calc_retrieval_metrics_rr
@@ -16,17 +18,19 @@ from oml.registry.transforms import get_transforms_for_pretrained
1618
extractor = ViTExtractor.from_pretrained("vits16_dino")
1719
transform, _ = get_transforms_for_pretrained("vits16_dino")
1820

19-
_, df_val = download_mock_dataset(global_paths=True)
21+
_, df_val = download_mock_dataset(global_paths=True, df_name="df_with_category.csv")
2022
dataset = ImageQueryGalleryLabeledDataset(df_val, transform=transform)
2123

24+
# you can optionally provide categories to have category wise metrics
25+
query_categories = np.array(dataset.extra_data["category"])[dataset.get_query_ids()]
26+
2227
embeddings = inference(extractor, dataset, batch_size=4)
2328

2429
rr = RetrievalResults.compute_from_embeddings(embeddings, dataset, n_items_to_retrieve=5)
25-
metrics = calc_retrieval_metrics_rr(rr, map_top_k=(3, 5), precision_top_k=(5,), cmc_top_k=(3,))
30+
metrics = calc_retrieval_metrics_rr(rr, query_categories, map_top_k=(3, 5), precision_top_k=(5,), cmc_top_k=(3,))
2631

2732
print(rr, "\n", metrics)
28-
rr.visualize(query_ids=[2, 1], dataset=dataset).show()
29-
33+
rr.visualize(query_ids=[2, 1], dataset=dataset)
3034
```
3135
[comment]:vanilla-validation-end
3236
</p>

docs/source/contents/metrics.rst

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ calc_retrieval_metrics
2626
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2727
.. autofunction:: oml.functional.metrics.calc_retrieval_metrics
2828

29-
calc_topological_metrics
29+
calc_retrieval_metrics_rr
3030
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
31-
.. autofunction:: oml.functional.metrics.calc_topological_metrics
31+
.. autofunction:: oml.metrics.embeddings.calc_retrieval_metrics_rr
3232

3333
calc_cmc
3434
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -46,6 +46,14 @@ calc_fnmr_at_fmr
4646
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4747
.. autofunction:: oml.functional.metrics.calc_fnmr_at_fmr
4848

49+
calc_fnmr_at_fmr_rr
50+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
51+
.. autofunction:: oml.metrics.embeddings.calc_fnmr_at_fmr_rr
52+
53+
calc_topological_metrics
54+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
55+
.. autofunction:: oml.functional.metrics.calc_topological_metrics
56+
4957
calc_pcf
5058
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
5159
.. autofunction:: oml.functional.metrics.calc_pcf

oml/functional/knn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import torch
44
from torch import BoolTensor, FloatTensor, LongTensor
5+
from tqdm.auto import tqdm
56

67
from oml.const import BS_KNN
78
from oml.utils.misc_torch import pairwise_dist
@@ -53,7 +54,7 @@ def batched_knn(
5354
gt_ids = []
5455

5556
# we do batching over first (queries) dimension
56-
for i in range(0, nq, bs):
57+
for i in tqdm(range(0, nq, bs), desc="Finding nearest neighbors."):
5758
distances_b = pairwise_dist(x1=embeddings_query[i : i + bs, :], x2=embeddings_gallery, p=2)
5859
ids_query_b = ids_query[i : i + bs]
5960

0 commit comments

Comments
 (0)