Skip to content

Commit 9ca436c

Browse files
committed
format: ruff
Signed-off-by: Roni Friedman-Melamed <[email protected]>
1 parent c22397b commit 9ca436c

File tree

3 files changed

+44
-37
lines changed

3 files changed

+44
-37
lines changed

src/instructlab/eval/mmlu.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,9 @@ def run(self, server_url: str | None = None) -> tuple:
153153

154154
return overall_score, individual_scores
155155

156-
def _run_mmlu(self, server_url: str | None = None, return_all_results:bool = False) -> dict:
156+
def _run_mmlu(
157+
self, server_url: str | None = None, return_all_results: bool = False
158+
) -> dict:
157159
if server_url is not None:
158160
# Requires lm_eval >= 0.4.4
159161
model_args = f"base_url={server_url}/completions,model={self.model_path},tokenizer_backend=huggingface"

src/instructlab/eval/unitxt.py

+39-32
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
"""
66

77
# Standard
8-
import os, shutil
9-
import yaml
108
from uuid import uuid4
9+
import os
10+
import shutil
1111

1212
# Third Party
1313
from lm_eval.tasks.unitxt import task
14+
import yaml
1415

1516
# First Party
1617
from instructlab.eval.mmlu import MMLUBranchEvaluator
@@ -20,7 +21,8 @@
2021

2122
logger = setup_logger(__name__)
2223

23-
TEMP_DIR_PREFIX = 'unitxt_temp'
24+
TEMP_DIR_PREFIX = "unitxt_temp"
25+
2426

2527
class UnitxtEvaluator(MMLUBranchEvaluator):
2628
"""
@@ -29,45 +31,50 @@ class UnitxtEvaluator(MMLUBranchEvaluator):
2931
Attributes:
3032
model_path absolute path to or name of a huggingface model
3133
unitxt_recipe unitxt recipe (see unitxt.ai for more information)
32-
A Recipe holds a complete specification of a unitxt pipeline
34+
A Recipe holds a complete specification of a unitxt pipeline
3335
Example: card=cards.wnli,template=templates.classification.multi_class.relation.default,max_train_instances=5,loader_limit=20,num_demos=3,demos_pool_size=10
34-
36+
3537
"""
38+
3639
name = "unitxt"
40+
3741
def __init__(
3842
self,
39-
model_path,
43+
model_path,
4044
unitxt_recipe: str,
4145
):
4246
task = self.assign_task_name()
4347
tasks_dir = self.assign_tasks_dir(task)
4448
super().__init__(
45-
model_path = model_path,
46-
tasks_dir = tasks_dir,
47-
tasks = [task],
48-
few_shots = 0
49+
model_path=model_path, tasks_dir=tasks_dir, tasks=[task], few_shots=0
4950
)
5051
self.unitxt_recipe = unitxt_recipe
5152

5253
def assign_tasks_dir(self, task):
53-
return f'{TEMP_DIR_PREFIX}_{task}'
54+
return f"{TEMP_DIR_PREFIX}_{task}"
5455

5556
def assign_task_name(self):
5657
return str(uuid4())
5758

58-
def prepare_unitxt_files(self)->tuple:
59+
def prepare_unitxt_files(self) -> tuple:
5960
task = self.tasks[0]
60-
yaml_file = os.path.join(self.tasks_dir,f"{task}.yaml")
61+
yaml_file = os.path.join(self.tasks_dir, f"{task}.yaml")
6162
create_unitxt_pointer(self.tasks_dir)
62-
create_unitxt_yaml(yaml_file=yaml_file, unitxt_recipe=self.unitxt_recipe, task_name=task)
63+
create_unitxt_yaml(
64+
yaml_file=yaml_file, unitxt_recipe=self.unitxt_recipe, task_name=task
65+
)
6366

6467
def remove_unitxt_files(self):
65-
if self.tasks_dir.startswith(TEMP_DIR_PREFIX): #to avoid unintended deletion if this class is inherited
68+
if self.tasks_dir.startswith(
69+
TEMP_DIR_PREFIX
70+
): # to avoid unintended deletion if this class is inherited
6671
shutil.rmtree(self.tasks_dir)
6772
else:
68-
logger.warning(f"unitxt tasks dir did not start with '{TEMP_DIR_PREFIX}' and therefor was not deleted")
73+
logger.warning(
74+
f"unitxt tasks dir did not start with '{TEMP_DIR_PREFIX}' and therefor was not deleted"
75+
)
6976

70-
def run(self,server_url: str | None = None) -> tuple:
77+
def run(self, server_url: str | None = None) -> tuple:
7178
"""
7279
Runs evaluation
7380
@@ -80,13 +87,16 @@ def run(self,server_url: str | None = None) -> tuple:
8087
os.environ["TOKENIZERS_PARALLELISM"] = "true"
8188
results = self._run_mmlu(server_url=server_url, return_all_results=True)
8289
taskname = self.tasks[0]
83-
global_scores = results['results'][taskname]
84-
global_scores.pop('alias')
90+
global_scores = results["results"][taskname]
91+
global_scores.pop("alias")
8592
try:
86-
instances = results['samples'][taskname]
93+
instances = results["samples"][taskname]
8794
instance_scores = {}
88-
metrics = [metric.replace('metrics.','') for metric in instances[0]['doc']['metrics']]
89-
for i,instance in enumerate(instances):
95+
metrics = [
96+
metric.replace("metrics.", "")
97+
for metric in instances[0]["doc"]["metrics"]
98+
]
99+
for i, instance in enumerate(instances):
90100
scores = {}
91101
for metric in metrics:
92102
scores[metric] = instance[metric][0]
@@ -97,23 +107,20 @@ def run(self,server_url: str | None = None) -> tuple:
97107
logger.error(e.__traceback__)
98108
instance_scores = None
99109
self.remove_unitxt_files()
100-
return global_scores,instance_scores
110+
return global_scores, instance_scores
101111

102112

103-
def create_unitxt_yaml(yaml_file,unitxt_recipe, task_name):
104-
data = {
105-
'task': f'{task_name}',
106-
'include': 'unitxt',
107-
'recipe': f'{unitxt_recipe}'
108-
}
109-
with open(yaml_file, 'w') as file:
113+
def create_unitxt_yaml(yaml_file, unitxt_recipe, task_name):
114+
data = {"task": f"{task_name}", "include": "unitxt", "recipe": f"{unitxt_recipe}"}
115+
with open(yaml_file, "w") as file:
110116
yaml.dump(data, file, default_flow_style=False)
111117
logger.debug(f"task {task} unitxt recipe written to {yaml_file}")
112118

119+
113120
def create_unitxt_pointer(tasks_dir):
114121
class_line = "class: !function " + task.__file__.replace("task.py", "task.Unitxt")
115-
output_file = os.path.join(tasks_dir,'unitxt')
122+
output_file = os.path.join(tasks_dir, "unitxt")
116123
os.makedirs(os.path.dirname(output_file), exist_ok=True)
117-
with open(output_file, 'w') as f:
124+
with open(output_file, "w") as f:
118125
f.write(class_line)
119126
logger.debug(f"Unitxt task pointer written to {output_file}")

tests/test_unitxt.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@ def test_unitxt():
77
try:
88
model_path = "instructlab/granite-7b-lab"
99
unitxt_recipe = "card=cards.wnli,template=templates.classification.multi_class.relation.default,max_train_instances=5,loader_limit=20,num_demos=3,demos_pool_size=10"
10-
unitxt = UnitxtEvaluator(
11-
model_path=model_path, unitxt_recipe=unitxt_recipe
12-
)
10+
unitxt = UnitxtEvaluator(model_path=model_path, unitxt_recipe=unitxt_recipe)
1311
overall_score, single_scores = unitxt.run()
1412
print(overall_score)
1513
except Exception as exc:
@@ -19,4 +17,4 @@ def test_unitxt():
1917

2018

2119
if __name__ == "__main__":
22-
assert test_unitxt() == True
20+
assert test_unitxt() == True

0 commit comments

Comments
 (0)