File tree 1 file changed +5
-5
lines changed
1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -646,17 +646,17 @@ def get_model_class_for_task(task: str, framework: str = "pt") -> Type:
646
646
task = TasksManager .format_task (task )
647
647
TasksManager ._validate_framework_choice (framework )
648
648
if framework == "pt" :
649
- task_to_automodel = TasksManager ._TASKS_TO_AUTOMODELS
649
+ tasks_to_automodel = TasksManager ._TASKS_TO_AUTOMODELS
650
650
else :
651
- task_to_automodel = TasksManager ._TASKS_TO_TF_AUTOMODELS
652
- if task not in task_to_automodel :
651
+ tasks_to_automodel = TasksManager ._TASKS_TO_TF_AUTOMODELS
652
+ if task not in tasks_to_automodel :
653
653
raise KeyError (
654
654
f"Unknown task: { task } . Possible values are: "
655
- + ", " .join ([f"`{ key } ` for { task_to_automodel [key ]. __name__ } " for key in task_to_automodel ])
655
+ + ", " .join ([f"`{ key } ` for { tasks_to_automodel [key ]} " for key in tasks_to_automodel ])
656
656
)
657
657
658
658
module = importlib .import_module (TasksManager ._TASKS_TO_LIBRARY [task ])
659
- return getattr (module , task_to_automodel [task ])
659
+ return getattr (module , tasks_to_automodel [task ])
660
660
661
661
@staticmethod
662
662
def determine_framework (
You can’t perform that action at this time.
0 commit comments