18
18
from pathlib import Path
19
19
20
20
from lerobot .common import envs , policies # noqa: F401
21
- from lerobot .common .utils .utils import auto_select_torch_device , is_amp_available , is_torch_device_available
22
21
from lerobot .configs import parser
23
22
from lerobot .configs .default import EvalConfig
24
23
from lerobot .configs .policies import PreTrainedConfig
25
- from lerobot .configs .train import TrainPipelineConfig
26
24
27
25
28
26
@dataclass
@@ -35,11 +33,6 @@ class EvalPipelineConfig:
35
33
policy : PreTrainedConfig | None = None
36
34
output_dir : Path | None = None
37
35
job_name : str | None = None
38
- # TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
39
- device : str | None = None # cuda | cpu | mps
40
- # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
41
- # automatic gradient scaling is used.
42
- use_amp : bool = False
43
36
seed : int | None = 1000
44
37
45
38
def __post_init__ (self ):
@@ -50,27 +43,6 @@ def __post_init__(self):
50
43
self .policy = PreTrainedConfig .from_pretrained (policy_path , cli_overrides = cli_overrides )
51
44
self .policy .pretrained_path = policy_path
52
45
53
- # When no device or use_amp are given, use the one from training config.
54
- if self .device is None or self .use_amp is None :
55
- train_cfg = TrainPipelineConfig .from_pretrained (policy_path )
56
- if self .device is None :
57
- self .device = train_cfg .device
58
- if self .use_amp is None :
59
- self .use_amp = train_cfg .use_amp
60
-
61
- # Automatically switch to available device if necessary
62
- if not is_torch_device_available (self .device ):
63
- auto_device = auto_select_torch_device ()
64
- logging .warning (f"Device '{ self .device } ' is not available. Switching to '{ auto_device } '." )
65
- self .device = auto_device
66
-
67
- # Automatically deactivate AMP if necessary
68
- if self .use_amp and not is_amp_available (self .device ):
69
- logging .warning (
70
- f"Automatic Mixed Precision (amp) is not available on device '{ self .device } '. Deactivating AMP."
71
- )
72
- self .use_amp = False
73
-
74
46
else :
75
47
logging .warning (
76
48
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
@@ -87,11 +59,6 @@ def __post_init__(self):
87
59
eval_dir = f"{ now :%Y-%m-%d} /{ now :%H-%M-%S} _{ self .job_name } "
88
60
self .output_dir = Path ("outputs/eval" ) / eval_dir
89
61
90
- if self .device is None :
91
- raise ValueError ("Set one of the following device: cuda, cpu or mps" )
92
- elif self .device == "cuda" and self .use_amp is None :
93
- raise ValueError ("Set 'use_amp' to True or False." )
94
-
95
62
@classmethod
96
63
def __get_path_fields__ (cls ) -> list [str ]:
97
64
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
0 commit comments