Skip to content

Commit 56651aa

Browse files
authored
Fix item access of some _TASKS_TO_AUTOMODELS (#642)
fix item access
1 parent 60e25c9 commit 56651aa

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

optimum/exporters/tasks.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -646,17 +646,17 @@ def get_model_class_for_task(task: str, framework: str = "pt") -> Type:
646646
task = TasksManager.format_task(task)
647647
TasksManager._validate_framework_choice(framework)
648648
if framework == "pt":
649-
task_to_automodel = TasksManager._TASKS_TO_AUTOMODELS
649+
tasks_to_automodel = TasksManager._TASKS_TO_AUTOMODELS
650650
else:
651-
task_to_automodel = TasksManager._TASKS_TO_TF_AUTOMODELS
652-
if task not in task_to_automodel:
651+
tasks_to_automodel = TasksManager._TASKS_TO_TF_AUTOMODELS
652+
if task not in tasks_to_automodel:
653653
raise KeyError(
654654
f"Unknown task: {task}. Possible values are: "
655-
+ ", ".join([f"`{key}` for {task_to_automodel[key].__name__}" for key in task_to_automodel])
655+
+ ", ".join([f"`{key}` for {tasks_to_automodel[key]}" for key in tasks_to_automodel])
656656
)
657657

658658
module = importlib.import_module(TasksManager._TASKS_TO_LIBRARY[task])
659-
return getattr(module, task_to_automodel[task])
659+
return getattr(module, tasks_to_automodel[task])
660660

661661
@staticmethod
662662
def determine_framework(

0 commit comments

Comments
 (0)