Skip to content

Commit 08495a0

Browse files
authored
Fix reading custom task configs (#3425)
- Fix the check for missing custom task config files. - Support glob pattern in given file paths. - Add deduplication for custom configs based on their content.
1 parent b7a8cf7 commit 08495a0

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

lm_eval/__main__.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -395,24 +395,39 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
395395
print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
396396
sys.exit()
397397
else:
398-
if os.path.isdir(args.tasks):
399-
import glob
398+
import glob
400399

400+
if os.path.isdir(args.tasks):
401401
task_names = []
402402
yaml_path = os.path.join(args.tasks, "*.yaml")
403403
for yaml_file in glob.glob(yaml_path):
404404
config = utils.load_yaml_config(yaml_file)
405405
task_names.append(config)
406406
else:
407+
import itertools
408+
407409
task_list = args.tasks.split(",")
408-
task_names = task_manager.match_tasks(task_list)
409-
for task in [task for task in task_list if task not in task_names]:
410-
if os.path.isfile(task):
411-
config = utils.load_yaml_config(task)
412-
task_names.append(config)
413-
task_missing = [
414-
task for task in task_list if task not in task_names and "*" not in task
415-
] # we don't want errors if a wildcard ("*") task name was used
410+
task_list = [
411+
os.path.abspath(task) if task.endswith(".yaml") else task
412+
for task in task_list
413+
]
414+
match_dict = dict.fromkeys(task_list) # deduplicate file paths
415+
416+
for task in match_dict.keys():
417+
if not task.endswith(".yaml"): # provided task names
418+
matches = task_manager.match_tasks(task)
419+
else: # custom config files
420+
matches = []
421+
for yaml_file in glob.glob(task):
422+
config = utils.load_yaml_config(yaml_file)
423+
matches.append(config)
424+
match_dict[task] = matches
425+
426+
task_names = []
427+
for task in itertools.chain.from_iterable(match_dict.values()):
428+
if task not in task_names:
429+
task_names.append(task)
430+
task_missing = [task for task, matches in match_dict.items() if not matches]
416431

417432
if task_missing:
418433
missing = ", ".join(task_missing)

0 commit comments

Comments
 (0)