@@ -3,6 +3,7 @@
 
 from natsort import natsorted, ns
 from tabulate import tabulate
+from tqdm.auto import tqdm
 
 from tti_eval.common import EmbeddingDefinition, Split
 from tti_eval.constants import OUTPUT_PATH
@@ -47,9 +48,9 @@ def run_evaluation(
     embedding_definitions: list[EmbeddingDefinition],
 ) -> dict[EmbeddingDefinition, dict[str, float]]:
     embeddings_performance: dict[EmbeddingDefinition, dict[str, float]] = {}
-    model_keys: set[str] = set()
+    used_evaluators: set[str] = set()
 
-    for def_ in embedding_definitions:
+    for def_ in tqdm(embedding_definitions, desc="Evaluating embedding definitions", leave=False):
         train_embeddings = def_.load_embeddings(Split.TRAIN)
         validation_embeddings = def_.load_embeddings(Split.VALIDATION)
 
@@ -68,11 +69,13 @@ def run_evaluation(
             train_embeddings=train_embeddings,
             validation_embeddings=validation_embeddings,
         )
-        evaluator_performance[evaluator.title] = evaluator.evaluate()
-        model_keys.add(evaluator.title)
+        evaluator_performance[evaluator.title()] = evaluator.evaluate()
+        used_evaluators.add(evaluator.title())
 
-    for n in model_keys:
-        print_evaluation_results(embeddings_performance, n)
+    for evaluator_type in evaluators:
+        evaluator_title = evaluator_type.title()
+        if evaluator_title in used_evaluators:
+            print_evaluation_results(embeddings_performance, evaluator_title)
 
     return embeddings_performance
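
In short: the evaluation loop now reports progress via tqdm, `title` is called as a method rather than read as an attribute (it is invoked on both an evaluator instance and on `evaluator_type`, so presumably a classmethod), and results are printed by iterating the `evaluators` list and filtering against `used_evaluators`, which yields a stable print order instead of arbitrary set order. A minimal sketch of the tqdm pattern used above, with a hypothetical item list standing in for `embedding_definitions`:

from tqdm.auto import tqdm

# Hypothetical stand-in for `embedding_definitions`.
items = ["clip/dataset-a", "clip/dataset-b", "siglip/dataset-a"]

# tqdm wraps any iterable and yields its items unchanged while drawing a
# progress bar; `desc` labels the bar, and `leave=False` clears it once the
# loop finishes, so the result tables printed afterwards stay uncluttered.
for item in tqdm(items, desc="Evaluating embedding definitions", leave=False):
    pass  # per-definition work: load embeddings, run each evaluator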