diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py index 6069bc9..ba0b9ed 100644 --- a/sbx/common/off_policy_algorithm.py +++ b/sbx/common/off_policy_algorithm.py @@ -1,3 +1,5 @@ +import io +import pathlib from typing import Any, Dict, List, Optional, Tuple, Type, Union import jax @@ -9,6 +11,7 @@ from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, Schedule +from stable_baselines3.common.utils import get_device class OffPolicyAlgorithmJax(OffPolicyAlgorithm): @@ -116,3 +119,13 @@ def _setup_model(self) -> None: ) # Convert train freq parameter to TrainFreq object self._convert_train_freq() + + def load_replay_buffer( + self, + path: Union[str, pathlib.Path, io.BufferedIOBase], + truncate_last_traj: bool = True, + ) -> None: + super().load_replay_buffer(path, truncate_last_traj) + # Override replay buffer device to be always cpu for conversion to numpy + assert self.replay_buffer is not None + self.replay_buffer.device = get_device("cpu") diff --git a/sbx/version.txt b/sbx/version.txt index ac39a10..f374f66 100644 --- a/sbx/version.txt +++ b/sbx/version.txt @@ -1 +1 @@ -0.9.0 +0.9.1 diff --git a/tests/test_buffers.py b/tests/test_buffers.py new file mode 100644 index 0000000..0d711d2 --- /dev/null +++ b/tests/test_buffers.py @@ -0,0 +1,14 @@ +import pytest +import torch as th + +from sbx import SAC + + +def test_force_cpu_device(tmp_path): + if not th.cuda.is_available(): + pytest.skip("No CUDA device") + model = SAC("MlpPolicy", "Pendulum-v1", buffer_size=200) + assert model.replay_buffer.device == th.device("cpu") + model.save_replay_buffer(tmp_path / "replay") + model.load_replay_buffer(tmp_path / "replay") + assert model.replay_buffer.device == th.device("cpu")