Commit 5e94738

refactor(config): Move device & amp args to PreTrainedConfig (#812)
Co-authored-by: Simon Alibert <[email protected]>
1 parent 10706ed commit 5e94738

19 files changed (+62, -136 lines)

Makefile (+9, -9)
@@ -47,6 +47,7 @@ test-act-ete-train:
 	--policy.dim_model=64 \
 	--policy.n_action_steps=20 \
 	--policy.chunk_size=20 \
+	--policy.device=$(DEVICE) \
 	--env.type=aloha \
 	--env.episode_length=5 \
 	--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
@@ -61,7 +62,6 @@ test-act-ete-train:
 	--save_checkpoint=true \
 	--log_freq=1 \
 	--wandb.enable=false \
-	--device=$(DEVICE) \
 	--output_dir=tests/outputs/act/

 test-act-ete-train-resume:
@@ -72,18 +72,19 @@ test-act-ete-train-resume:
 test-act-ete-eval:
 	python lerobot/scripts/eval.py \
 	--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
+	--policy.device=$(DEVICE) \
 	--env.type=aloha \
 	--env.episode_length=5 \
 	--eval.n_episodes=1 \
-	--eval.batch_size=1 \
-	--device=$(DEVICE)
+	--eval.batch_size=1

 test-diffusion-ete-train:
 	python lerobot/scripts/train.py \
 	--policy.type=diffusion \
 	--policy.down_dims='[64,128,256]' \
 	--policy.diffusion_step_embed_dim=32 \
 	--policy.num_inference_steps=10 \
+	--policy.device=$(DEVICE) \
 	--env.type=pusht \
 	--env.episode_length=5 \
 	--dataset.repo_id=lerobot/pusht \
@@ -98,21 +99,21 @@ test-diffusion-ete-train:
 	--save_freq=2 \
 	--log_freq=1 \
 	--wandb.enable=false \
-	--device=$(DEVICE) \
 	--output_dir=tests/outputs/diffusion/

 test-diffusion-ete-eval:
 	python lerobot/scripts/eval.py \
 	--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
+	--policy.device=$(DEVICE) \
 	--env.type=pusht \
 	--env.episode_length=5 \
 	--eval.n_episodes=1 \
-	--eval.batch_size=1 \
-	--device=$(DEVICE)
+	--eval.batch_size=1

 test-tdmpc-ete-train:
 	python lerobot/scripts/train.py \
 	--policy.type=tdmpc \
+	--policy.device=$(DEVICE) \
 	--env.type=xarm \
 	--env.task=XarmLift-v0 \
 	--env.episode_length=5 \
@@ -128,15 +129,14 @@ test-tdmpc-ete-train:
 	--save_freq=2 \
 	--log_freq=1 \
 	--wandb.enable=false \
-	--device=$(DEVICE) \
 	--output_dir=tests/outputs/tdmpc/

 test-tdmpc-ete-eval:
 	python lerobot/scripts/eval.py \
 	--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
+	--policy.device=$(DEVICE) \
 	--env.type=xarm \
 	--env.episode_length=5 \
 	--env.task=XarmLift-v0 \
 	--eval.n_episodes=1 \
-	--eval.batch_size=1 \
-	--device=$(DEVICE)
+	--eval.batch_size=1

lerobot/common/policies/factory.py (+3, -6)
@@ -16,7 +16,6 @@

 import logging

-import torch
 from torch import nn

 from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
@@ -76,7 +75,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:

 def make_policy(
     cfg: PreTrainedConfig,
-    device: str | torch.device,
     ds_meta: LeRobotDatasetMetadata | None = None,
     env_cfg: EnvConfig | None = None,
 ) -> PreTrainedPolicy:
@@ -88,15 +86,14 @@ def make_policy(
     Args:
         cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
             be loaded with the weights from that path.
-        device (str): the device to load the policy onto.
         ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
             statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
         env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
             provided if ds_meta is not. Defaults to None.

     Raises:
         ValueError: Either ds_meta or env and env_cfg must be provided.
-        NotImplementedError: if the policy.type is 'vqbet' and the device 'mps' (due to an incompatibility)
+        NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)

     Returns:
         PreTrainedPolicy: _description_
@@ -111,7 +108,7 @@ def make_policy(
     # https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
     # variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
     # slower than running natively on MPS.
-    if cfg.type == "vqbet" and str(device) == "mps":
+    if cfg.type == "vqbet" and cfg.device == "mps":
         raise NotImplementedError(
             "Current implementation of VQBeT does not support `mps` backend. "
             "Please use `cpu` or `cuda` backend."
@@ -145,7 +142,7 @@ def make_policy(
     # Make a fresh policy.
     policy = policy_cls(**kwargs)

-    policy.to(device)
+    policy.to(cfg.device)
     assert isinstance(policy, nn.Module)

     # policy = torch.compile(policy, mode="reduce-overhead")

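Callers no longer pass a device into `make_policy`; it is read from `cfg.device`. A minimal sketch of the new call pattern (the checkpoint directory is a placeholder, the dataset repo id is the one used in the Makefile targets above):

from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.policies.factory import make_policy
from lerobot.configs.policies import PreTrainedConfig

ckpt_dir = "path/to/pretrained_model"  # placeholder checkpoint directory
cfg = PreTrainedConfig.from_pretrained(ckpt_dir)
cfg.pretrained_path = ckpt_dir
cfg.device = "cuda"  # the device now lives on the policy config

ds_meta = LeRobotDatasetMetadata("lerobot/aloha_sim_transfer_cube_human")
policy = make_policy(cfg, ds_meta=ds_meta)  # no separate `device` argument
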
lerobot/common/policies/pi0/configuration_pi0.py (+1)
@@ -90,6 +90,7 @@ class PI0Config(PreTrainedConfig):
     def __post_init__(self):
         super().__post_init__()

+        # TODO(Steven): Validate device and amp? in all policy configs?
         """Input validation (not exhaustive)."""
         if self.n_action_steps > self.chunk_size:
             raise ValueError(

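The TODO above asks whether device and AMP should be validated in every policy config's `__post_init__`. Purely as an illustrative sketch (not part of this commit), such a check could reuse the existing helpers from `lerobot.common.utils.utils`:

import logging

from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available


def validate_device_and_amp(device: str, use_amp: bool) -> tuple[str, bool]:
    # Hypothetical helper: fall back to an available device and disable AMP when unsupported.
    if not is_torch_device_available(device):
        fallback = auto_select_torch_device()
        logging.warning(f"Device '{device}' is not available. Switching to '{fallback}'.")
        device = str(fallback)
    if use_amp and not is_amp_available(device):
        logging.warning(f"AMP is not available on device '{device}'. Deactivating AMP.")
        use_amp = False
    return device, use_amp
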
lerobot/common/policies/pi0/conversion_scripts/benchmark.py (+1, -1)
@@ -45,7 +45,7 @@ def main():

     cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
     cfg.pretrained_path = ckpt_torch_dir
-    policy = make_policy(cfg, device, ds_meta=dataset.meta)
+    policy = make_policy(cfg, ds_meta=dataset.meta)

     # policy = torch.compile(policy, mode="reduce-overhead")

lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py (+1, -1)
@@ -101,7 +101,7 @@ def main():

     cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
     cfg.pretrained_path = ckpt_torch_dir
-    policy = make_policy(cfg, device, dataset_meta)
+    policy = make_policy(cfg, dataset_meta)

     # loss_dict = policy.forward(batch, noise=noise, time=time_beta)
     # loss_dict["loss"].backward()

lerobot/common/policies/pretrained.py (+3, -4)
@@ -86,7 +86,6 @@ def from_pretrained(
         cache_dir: str | Path | None = None,
         local_files_only: bool = False,
         revision: str | None = None,
-        map_location: str = "cpu",
         strict: bool = False,
         **kwargs,
     ) -> T:
@@ -111,7 +110,7 @@ def from_pretrained(
         if os.path.isdir(model_id):
             print("Loading weights from local directory")
             model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
-            policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
+            policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
         else:
             try:
                 model_file = hf_hub_download(
@@ -125,13 +124,13 @@ def from_pretrained(
                     token=token,
                     local_files_only=local_files_only,
                 )
-                policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
+                policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
             except HfHubHTTPError as e:
                 raise FileNotFoundError(
                     f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
                 ) from e

-        policy.to(map_location)
+        policy.to(config.device)
         policy.eval()
         return policy
lerobot/common/robot_devices/control_configs.py (-29)
@@ -12,17 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
 from dataclasses import dataclass
 from pathlib import Path

 import draccus

 from lerobot.common.robot_devices.robots.configs import RobotConfig
-from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
 from lerobot.configs import parser
 from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.train import TrainPipelineConfig


 @dataclass
@@ -57,11 +54,6 @@ class RecordControlConfig(ControlConfig):
     # Root directory where the dataset will be stored (e.g. 'dataset/path').
     root: str | Path | None = None
     policy: PreTrainedConfig | None = None
-    # TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
-    device: str | None = None  # cuda | cpu | mps
-    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
-    # automatic gradient scaling is used.
-    use_amp: bool | None = None
     # Limit the frames per second. By default, uses the policy fps.
     fps: int | None = None
     # Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
@@ -104,27 +96,6 @@ def __post_init__(self):
             self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
             self.policy.pretrained_path = policy_path

-            # When no device or use_amp are given, use the one from training config.
-            if self.device is None or self.use_amp is None:
-                train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
-                if self.device is None:
-                    self.device = train_cfg.device
-                if self.use_amp is None:
-                    self.use_amp = train_cfg.use_amp
-
-            # Automatically switch to available device if necessary
-            if not is_torch_device_available(self.device):
-                auto_device = auto_select_torch_device()
-                logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
-                self.device = auto_device
-
-            # Automatically deactivate AMP if necessary
-            if self.use_amp and not is_amp_available(self.device):
-                logging.warning(
-                    f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
-                )
-                self.use_amp = False
-

 @ControlConfig.register_subclass("replay")
 @dataclass

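With `device` and `use_amp` removed from `RecordControlConfig`, a recording session that runs a policy takes both from the checkpoint's policy config, and any override is applied to that config (on the CLI this presumably nests under the policy flags, e.g. `--control.policy.device=cuda`, following the same draccus nesting as `--policy.device` in the Makefile). A minimal programmatic sketch with a placeholder path:

from lerobot.configs.policies import PreTrainedConfig

policy_cfg = PreTrainedConfig.from_pretrained("path/to/pretrained_model")  # placeholder
policy_cfg.device = "cuda"   # was RecordControlConfig.device
policy_cfg.use_amp = False   # was RecordControlConfig.use_amp
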
lerobot/common/robot_devices/control_utils.py (+5, -11)
@@ -32,6 +32,7 @@
 from lerobot.common.datasets.image_writer import safe_stop_image_writer
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.utils import get_features_from_robot
+from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.robot_devices.utils import busy_wait
 from lerobot.common.utils.utils import get_safe_torch_device, has_method
@@ -193,8 +194,6 @@ def record_episode(
     episode_time_s,
     display_cameras,
     policy,
-    device,
-    use_amp,
     fps,
     single_task,
 ):
@@ -205,8 +204,6 @@ def record_episode(
         dataset=dataset,
         events=events,
         policy=policy,
-        device=device,
-        use_amp=use_amp,
         fps=fps,
         teleoperate=policy is None,
         single_task=single_task,
@@ -221,9 +218,7 @@ def control_loop(
     display_cameras=False,
     dataset: LeRobotDataset | None = None,
     events=None,
-    policy=None,
-    device: torch.device | str | None = None,
-    use_amp: bool | None = None,
+    policy: PreTrainedPolicy = None,
     fps: int | None = None,
     single_task: str | None = None,
 ):
@@ -246,9 +241,6 @@ def control_loop(
     if dataset is not None and fps is not None and dataset.fps != fps:
         raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")

-    if isinstance(device, str):
-        device = get_safe_torch_device(device)
-
     timestamp = 0
     start_episode_t = time.perf_counter()
     while timestamp < control_time_s:
@@ -260,7 +252,9 @@ def control_loop(
         observation = robot.capture_observation()

         if policy is not None:
-            pred_action = predict_action(observation, policy, device, use_amp)
+            pred_action = predict_action(
+                observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
+            )
             # Action can eventually be clipped using `max_relative_target`,
             # so action actually sent is saved in the dataset.
             action = robot.send_action(pred_action)

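`control_loop` and `record_episode` no longer thread `device`/`use_amp` through their signatures; the loop derives both from `policy.config`. A small illustrative helper (not in the diff) showing that resolution:

import torch

from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.utils.utils import get_safe_torch_device


def policy_runtime_settings(policy: PreTrainedPolicy) -> tuple[torch.device, bool]:
    # Mirror what control_loop now does: read the device and the AMP flag from the policy's own config.
    return get_safe_torch_device(policy.config.device), policy.config.use_amp
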
lerobot/common/utils/utils.py (+4, -1)
@@ -51,8 +51,10 @@ def auto_select_torch_device() -> torch.device:
     return torch.device("cpu")


+# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
 def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
     """Given a string, return a torch.device with checks on whether the device is available."""
+    try_device = str(try_device)
     match try_device:
         case "cuda":
             assert torch.cuda.is_available()
@@ -85,14 +87,15 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):


 def is_torch_device_available(try_device: str) -> bool:
+    try_device = str(try_device)  # Ensure try_device is a string
     if try_device == "cuda":
         return torch.cuda.is_available()
     elif try_device == "mps":
         return torch.backends.mps.is_available()
     elif try_device == "cpu":
         return True
     else:
-        raise ValueError(f"Unknown device '{try_device}.")
+        raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")


 def is_amp_available(device: str):

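Both helpers now coerce their argument to `str`, so a `torch.device` is accepted as well as a plain string. A quick sketch:

import torch

from lerobot.common.utils.utils import get_safe_torch_device, is_torch_device_available

print(is_torch_device_available(torch.device("cpu")))  # True; the string "cpu" works too
device = get_safe_torch_device("cpu")
print(device)  # device(type='cpu')
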
lerobot/configs/eval.py (-33)
@@ -18,11 +18,9 @@
 from pathlib import Path

 from lerobot.common import envs, policies  # noqa: F401
-from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
 from lerobot.configs import parser
 from lerobot.configs.default import EvalConfig
 from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.train import TrainPipelineConfig


 @dataclass
@@ -35,11 +33,6 @@ class EvalPipelineConfig:
     policy: PreTrainedConfig | None = None
     output_dir: Path | None = None
     job_name: str | None = None
-    # TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
-    device: str | None = None  # cuda | cpu | mps
-    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
-    # automatic gradient scaling is used.
-    use_amp: bool = False
     seed: int | None = 1000

     def __post_init__(self):
@@ -50,27 +43,6 @@ def __post_init__(self):
             self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
             self.policy.pretrained_path = policy_path

-            # When no device or use_amp are given, use the one from training config.
-            if self.device is None or self.use_amp is None:
-                train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
-                if self.device is None:
-                    self.device = train_cfg.device
-                if self.use_amp is None:
-                    self.use_amp = train_cfg.use_amp
-
-            # Automatically switch to available device if necessary
-            if not is_torch_device_available(self.device):
-                auto_device = auto_select_torch_device()
-                logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
-                self.device = auto_device
-
-            # Automatically deactivate AMP if necessary
-            if self.use_amp and not is_amp_available(self.device):
-                logging.warning(
-                    f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
-                )
-                self.use_amp = False
-
         else:
             logging.warning(
                 "No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
@@ -87,11 +59,6 @@ def __post_init__(self):
             eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
             self.output_dir = Path("outputs/eval") / eval_dir

-        if self.device is None:
-            raise ValueError("Set one of the following device: cuda, cpu or mps")
-        elif self.device == "cuda" and self.use_amp is None:
-            raise ValueError("Set 'use_amp' to True or False.")
-
     @classmethod
     def __get_path_fields__(cls) -> list[str]:
         """This enables the parser to load config from the policy using `--policy.path=local/dir`"""

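`EvalPipelineConfig` drops its top-level `device`/`use_amp` fields along with the associated fallback logic; the CLI override moves from `--device` to `--policy.device`, as the Makefile targets above show. If an explicit availability check is still wanted before launching an eval, it can be done against the policy config (a hedged sketch, placeholder path):

from lerobot.common.utils.utils import is_torch_device_available
from lerobot.configs.policies import PreTrainedConfig

cfg = PreTrainedConfig.from_pretrained("path/to/pretrained_model")  # placeholder
assert is_torch_device_available(cfg.device), f"Device '{cfg.device}' is not available on this machine"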