5
5
"""
6
6
7
7
# Standard
8
- import os , shutil
9
- import yaml
10
8
from uuid import uuid4
9
+ import os
10
+ import shutil
11
11
12
12
# Third Party
13
13
from lm_eval .tasks .unitxt import task
14
+ import yaml
14
15
15
16
# First Party
16
17
from instructlab .eval .mmlu import MMLUBranchEvaluator
20
21
21
22
logger = setup_logger (__name__ )
22
23
23
- TEMP_DIR_PREFIX = 'unitxt_temp'
24
+ TEMP_DIR_PREFIX = "unitxt_temp"
25
+
24
26
25
27
class UnitxtEvaluator (MMLUBranchEvaluator ):
26
28
"""
@@ -29,45 +31,50 @@ class UnitxtEvaluator(MMLUBranchEvaluator):
29
31
Attributes:
30
32
model_path absolute path to or name of a huggingface model
31
33
unitxt_recipe unitxt recipe (see unitxt.ai for more information)
32
- A Recipe holds a complete specification of a unitxt pipeline
34
+ A Recipe holds a complete specification of a unitxt pipeline
33
35
Example: card=cards.wnli,template=templates.classification.multi_class.relation.default,max_train_instances=5,loader_limit=20,num_demos=3,demos_pool_size=10
34
-
36
+
35
37
"""
38
+
36
39
name = "unitxt"
40
+
37
41
def __init__(self, model_path, unitxt_recipe: str):
    """Set up an evaluator for a single, freshly named unitxt task.

    Args:
        model_path: forwarded unchanged to the parent evaluator.
        unitxt_recipe: unitxt recipe string, stored on the instance for
            later use when the task files are written.
    """
    task_name = self.assign_task_name()
    super().__init__(
        model_path=model_path,
        tasks_dir=self.assign_tasks_dir(task_name),
        tasks=[task_name],
        few_shots=0,
    )
    self.unitxt_recipe = unitxt_recipe
51
52
52
53
def assign_tasks_dir(self, task):
    """Return the temporary tasks directory name for ``task``.

    The name is ``TEMP_DIR_PREFIX`` followed by an underscore and the task
    name. It must begin exactly with ``TEMP_DIR_PREFIX`` (no padding), since
    ``remove_unitxt_files`` only deletes directories whose name starts with
    that prefix.
    """
    return f"{TEMP_DIR_PREFIX}_{task}"
54
55
55
56
def assign_task_name(self) -> str:
    """Return a new random task name: the string form of a version-4 UUID."""
    unique_id = uuid4()
    return str(unique_id)
57
58
58
def prepare_unitxt_files(self) -> None:
    """Write the task files for this evaluator's first task into ``self.tasks_dir``.

    Creates the unitxt pointer file and a ``<task>.yaml`` file holding the
    recipe. Nothing is returned; the files on disk are the result.
    """
    task = self.tasks[0]
    yaml_file = os.path.join(self.tasks_dir, f"{task}.yaml")
    create_unitxt_pointer(self.tasks_dir)
    create_unitxt_yaml(
        yaml_file=yaml_file, unitxt_recipe=self.unitxt_recipe, task_name=task
    )
63
66
64
67
def remove_unitxt_files(self):
    """Delete ``self.tasks_dir``, but only when its name carries the temp prefix.

    A directory whose name does not start with ``TEMP_DIR_PREFIX`` is left
    in place and a warning is logged instead.
    """
    # Guard against unintended deletion if this class is inherited and a
    # subclass points tasks_dir somewhere else.
    if not self.tasks_dir.startswith(TEMP_DIR_PREFIX):
        logger.warning(
            f"unitxt tasks dir did not start with '{TEMP_DIR_PREFIX}' and therefor was not deleted"
        )
        return
    shutil.rmtree(self.tasks_dir)
69
76
70
- def run (self ,server_url : str | None = None ) -> tuple :
77
+ def run (self , server_url : str | None = None ) -> tuple :
71
78
"""
72
79
Runs evaluation
73
80
@@ -80,13 +87,16 @@ def run(self,server_url: str | None = None) -> tuple:
80
87
os .environ ["TOKENIZERS_PARALLELISM" ] = "true"
81
88
results = self ._run_mmlu (server_url = server_url , return_all_results = True )
82
89
taskname = self .tasks [0 ]
83
- global_scores = results [' results' ][taskname ]
84
- global_scores .pop (' alias' )
90
+ global_scores = results [" results" ][taskname ]
91
+ global_scores .pop (" alias" )
85
92
try :
86
- instances = results [' samples' ][taskname ]
93
+ instances = results [" samples" ][taskname ]
87
94
instance_scores = {}
88
- metrics = [metric .replace ('metrics.' ,'' ) for metric in instances [0 ]['doc' ]['metrics' ]]
89
- for i ,instance in enumerate (instances ):
95
+ metrics = [
96
+ metric .replace ("metrics." , "" )
97
+ for metric in instances [0 ]["doc" ]["metrics" ]
98
+ ]
99
+ for i , instance in enumerate (instances ):
90
100
scores = {}
91
101
for metric in metrics :
92
102
scores [metric ] = instance [metric ][0 ]
@@ -97,23 +107,20 @@ def run(self,server_url: str | None = None) -> tuple:
97
107
logger .error (e .__traceback__ )
98
108
instance_scores = None
99
109
self .remove_unitxt_files ()
100
- return global_scores ,instance_scores
110
+ return global_scores , instance_scores
101
111
102
112
103
def create_unitxt_yaml(yaml_file, unitxt_recipe, task_name):
    """Write a task YAML file that ties ``task_name`` to a unitxt recipe.

    The file contains three keys: ``task`` (the task name), ``include``
    (always ``unitxt``) and ``recipe`` (the recipe), each stored as a string.

    Args:
        yaml_file: path of the YAML file to write; opened in write mode.
        unitxt_recipe: recipe value, formatted to a string for the file.
        task_name: task name, formatted to a string for the file.
    """
    data = {
        "task": f"{task_name}",
        "include": "unitxt",
        "recipe": f"{unitxt_recipe}",
    }
    with open(yaml_file, "w", encoding="utf-8") as file:
        yaml.dump(data, file, default_flow_style=False)
    # Report the task_name argument: the bare name ``task`` in this module is
    # the imported lm_eval unitxt module, not the task being written.
    logger.debug(f"task {task_name} unitxt recipe written to {yaml_file}")
112
118
119
+
113
120
def create_unitxt_pointer(tasks_dir):
    """Write the ``unitxt`` pointer file into ``tasks_dir``.

    The file holds a single ``class: !function <path>`` line, where the path
    is the imported unitxt task module's file with ``task.py`` replaced by
    ``task.Unitxt``. The directory is created if it does not exist yet.

    Args:
        tasks_dir: directory that receives the pointer file.
    """
    class_line = "class: !function " + task.__file__.replace("task.py", "task.Unitxt")
    # The file name has to be exactly "unitxt" (no padding) so that it matches
    # the ``include: unitxt`` entry written into the task YAML.
    output_file = os.path.join(tasks_dir, "unitxt")
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(class_line)
    logger.debug(f"Unitxt task pointer written to {output_file}")
0 commit comments