Skip to content

Commit f8d9233

Browse files
authored
fix: set fixed evaluators order and title as classmethod (#76)
1 parent cab13d8 commit f8d9233

File tree

7 files changed

+53
-27
lines changed

7 files changed

+53
-27
lines changed

tti_eval/evaluation/base.py

+6-12
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,7 @@
66
from .utils import normalize
77

88

9-
class EvaluationModelTitleInterface:
10-
# Enforce the evaluation models' title while removing its explicit mention in the `EvaluationModel` init.
11-
# This way, when `EvaluationModel` is used as a type hint, there won't be a warning about unfilled title.
12-
def __init__(self, title: str, **kwargs) -> None:
13-
self._title = title
14-
15-
@property
16-
def title(self) -> str:
17-
return self._title
18-
19-
20-
class EvaluationModel(EvaluationModelTitleInterface, ABC):
9+
class EvaluationModel(ABC):
2110
def __init__(
2211
self,
2312
train_embeddings: Embeddings,
@@ -95,6 +84,11 @@ def num_classes(self) -> int:
9584
def evaluate(self) -> float:
9685
...
9786

87+
@classmethod
88+
@abstractmethod
89+
def title(cls) -> str:
90+
...
91+
9892

9993
class ClassificationModel(EvaluationModel):
10094
@abstractmethod

tti_eval/evaluation/evaluator.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from natsort import natsorted, ns
55
from tabulate import tabulate
6+
from tqdm.auto import tqdm
67

78
from tti_eval.common import EmbeddingDefinition, Split
89
from tti_eval.constants import OUTPUT_PATH
@@ -47,9 +48,9 @@ def run_evaluation(
4748
embedding_definitions: list[EmbeddingDefinition],
4849
) -> dict[EmbeddingDefinition, dict[str, float]]:
4950
embeddings_performance: dict[EmbeddingDefinition, dict[str, float]] = {}
50-
model_keys: set[str] = set()
51+
used_evaluators: set[str] = set()
5152

52-
for def_ in embedding_definitions:
53+
for def_ in tqdm(embedding_definitions, desc="Evaluating embedding definitions", leave=False):
5354
train_embeddings = def_.load_embeddings(Split.TRAIN)
5455
validation_embeddings = def_.load_embeddings(Split.VALIDATION)
5556

@@ -68,11 +69,13 @@ def run_evaluation(
6869
train_embeddings=train_embeddings,
6970
validation_embeddings=validation_embeddings,
7071
)
71-
evaluator_performance[evaluator.title] = evaluator.evaluate()
72-
model_keys.add(evaluator.title)
72+
evaluator_performance[evaluator.title()] = evaluator.evaluate()
73+
used_evaluators.add(evaluator.title())
7374

74-
for n in model_keys:
75-
print_evaluation_results(embeddings_performance, n)
75+
for evaluator_type in evaluators:
76+
evaluator_title = evaluator_type.title()
77+
if evaluator_title in used_evaluators:
78+
print_evaluation_results(embeddings_performance, evaluator_title)
7679
return embeddings_performance
7780

7881

tti_eval/evaluation/image_retrieval.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,18 @@
55
from autofaiss import build_index
66

77
from tti_eval.common import Embeddings
8+
from tti_eval.utils import disable_tqdm, enable_tqdm
89

910
from .base import EvaluationModel
1011

1112
logger = logging.getLogger("multiclips")
1213

1314

1415
class I2IRetrievalEvaluator(EvaluationModel):
16+
@classmethod
17+
def title(cls) -> str:
18+
return "I2IR"
19+
1520
def __init__(
1621
self,
1722
train_embeddings: Embeddings,
@@ -33,14 +38,16 @@ def __init__(
3338
3439
:raises ValueError: If the build of the faiss index for similarity search fails.
3540
"""
36-
super().__init__(train_embeddings, validation_embeddings, num_classes, title="I2IR")
41+
super().__init__(train_embeddings, validation_embeddings, num_classes)
3742
self.k = min(k, len(validation_embeddings.images))
3843

3944
class_ids, counts = np.unique(self._val_embeddings.labels, return_counts=True)
4045
self._class_counts = np.zeros(self.num_classes, dtype=np.int32)
4146
self._class_counts[class_ids] = counts
4247

48+
disable_tqdm() # Disable tqdm progress bar when building the index
4349
index, self.index_infos = build_index(self._val_embeddings.images, save_on_disk=False, verbose=logging.ERROR)
50+
enable_tqdm()
4451
if index is None:
4552
raise ValueError("Failed to build an index for knn search")
4653
self._index = index

tti_eval/evaluation/knn.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from autofaiss import build_index
66

77
from tti_eval.common import ClassArray, Embeddings, ProbabilityArray
8+
from tti_eval.utils import disable_tqdm, enable_tqdm
89

910
from .base import ClassificationModel
1011
from .utils import softmax
@@ -13,6 +14,10 @@
1314

1415

1516
class WeightedKNNClassifier(ClassificationModel):
17+
@classmethod
18+
def title(cls) -> str:
19+
return "wKNN"
20+
1621
def __init__(
1722
self,
1823
train_embeddings: Embeddings,
@@ -36,15 +41,13 @@ def __init__(
3641
3742
:raises ValueError: If the build of the faiss index for KNN fails.
3843
"""
39-
super().__init__(train_embeddings, validation_embeddings, num_classes, title="wKNN")
44+
super().__init__(train_embeddings, validation_embeddings, num_classes)
4045
self.k = k
41-
46+
disable_tqdm() # Disable tqdm progress bar when building the index
4247
index, self.index_infos = build_index(
43-
train_embeddings.images,
44-
metric_type="l2",
45-
save_on_disk=False,
46-
verbose=logging.ERROR,
48+
train_embeddings.images, metric_type="l2", save_on_disk=False, verbose=logging.ERROR
4749
)
50+
enable_tqdm()
4851
if index is None:
4952
raise ValueError("Failed to build an index for knn search")
5053
self._index = index

tti_eval/evaluation/linear_probe.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111

1212

1313
class LinearProbeClassifier(ClassificationModel):
14+
@classmethod
15+
def title(cls) -> str:
16+
return "linear_probe"
17+
1418
def __init__(
1519
self,
1620
train_embeddings: Embeddings,
@@ -28,7 +32,7 @@ def __init__(
2832
:param log_reg_params: Parameters for the Logistic Regression model.
2933
:param use_cross_validation: Flag that indicated whether to use cross-validation when training the model.
3034
"""
31-
super().__init__(train_embeddings, validation_embeddings, num_classes, title="linear_probe")
35+
super().__init__(train_embeddings, validation_embeddings, num_classes)
3236

3337
params = log_reg_params or {}
3438
self.classifier: LogisticRegressionCV | LogisticRegression

tti_eval/evaluation/zero_shot.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66

77

88
class ZeroShotClassifier(ClassificationModel):
9+
@classmethod
10+
def title(cls) -> str:
11+
return "zero_shot"
12+
913
def __init__(
1014
self,
1115
train_embeddings: Embeddings,
@@ -19,7 +23,7 @@ def __init__(
1923
:param validation_embeddings: Embeddings and their labels used for evaluating the search space.
2024
:param num_classes: Number of classes. If not specified, it will be inferred from the train labels.
2125
"""
22-
super().__init__(train_embeddings, validation_embeddings, num_classes, title="zero_shot")
26+
super().__init__(train_embeddings, validation_embeddings, num_classes)
2327
if self._train_embeddings.classes is None:
2428
raise ValueError("Expected class embeddings in `train_embeddings`, got `None`")
2529

tti_eval/utils.py

+11
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
1+
from functools import partialmethod
12
from itertools import chain
23
from typing import Literal, overload
34

5+
from tqdm import tqdm
6+
47
from tti_eval.common import EmbeddingDefinition
58
from tti_eval.constants import PROJECT_PATHS
69

710

11+
def disable_tqdm():
12+
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
13+
14+
15+
def enable_tqdm():
16+
tqdm.__init__ = partialmethod(tqdm.__init__, disable=False)
17+
18+
819
@overload
920
def read_all_cached_embeddings(as_list: Literal[True]) -> list[EmbeddingDefinition]:
1021
...

0 commit comments

Comments
 (0)