From 114fc93a9f2884c918c1728d5aab1f688bba5fa7 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 12 Dec 2023 14:04:19 +0100 Subject: [PATCH 1/6] Fix replay buffer device at load time --- sbx/common/off_policy_algorithm.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py index 6069bc9..2dbdf05 100644 --- a/sbx/common/off_policy_algorithm.py +++ b/sbx/common/off_policy_algorithm.py @@ -9,6 +9,7 @@ from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, Schedule +from stable_baselines3.common.utils import get_device class OffPolicyAlgorithmJax(OffPolicyAlgorithm): @@ -116,3 +117,13 @@ def _setup_model(self) -> None: ) # Convert train freq parameter to TrainFreq object self._convert_train_freq() + + def load_replay_buffer( + self, + path: Union[str, pathlib.Path, io.BufferedIOBase], + truncate_last_traj: bool = True, + ) -> None: + super().load_replay_buffer(path, truncate_last_traj) + # Override replay buffer device to be always cpu for conversion to numpy + self.replay_buffer.device = get_device("cpu") + From 35438009cad6c43235ba3dc845fed6f569e06027 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 12 Dec 2023 14:08:49 +0100 Subject: [PATCH 2/6] Fix imports --- sbx/common/off_policy_algorithm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py index 2dbdf05..7e2e61d 100644 --- a/sbx/common/off_policy_algorithm.py +++ b/sbx/common/off_policy_algorithm.py @@ -1,3 +1,5 @@ +import io +import pathlib from typing import Any, Dict, List, Optional, Tuple, Type, Union import jax From 23f63c022cf9dbd2e2e2f0ee3fbfb966b17b4a45 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 12 Dec 2023 14:09:36 +0100 Subject: [PATCH 3/6] Update version --- sbx/version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbx/version.txt b/sbx/version.txt index ac39a10..f374f66 100644 --- a/sbx/version.txt +++ b/sbx/version.txt @@ -1 +1 @@ -0.9.0 +0.9.1 From 897469077bd8aaea1d959b4e64a9bd25b250d84e Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 13 Dec 2023 10:42:40 +0100 Subject: [PATCH 4/6] Reformat and add test --- sbx/common/off_policy_algorithm.py | 3 +-- tests/test_buffers.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) create mode 100644 tests/test_buffers.py diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py index 7e2e61d..05516ea 100644 --- a/sbx/common/off_policy_algorithm.py +++ b/sbx/common/off_policy_algorithm.py @@ -119,7 +119,7 @@ def _setup_model(self) -> None: ) # Convert train freq parameter to TrainFreq object self._convert_train_freq() - + def load_replay_buffer( self, path: Union[str, pathlib.Path, io.BufferedIOBase], @@ -128,4 +128,3 @@ def load_replay_buffer( super().load_replay_buffer(path, truncate_last_traj) # Override replay buffer device to be always cpu for conversion to numpy self.replay_buffer.device = get_device("cpu") - diff --git a/tests/test_buffers.py b/tests/test_buffers.py new file mode 100644 index 0000000..059c17c --- /dev/null +++ b/tests/test_buffers.py @@ -0,0 +1,14 @@ +import pytest +import torch as th + +from sbx import SAC + + +def test_force_cpu_device(tmp_path): + if not th.cuda.is_available(): + pytest.skip("No CUDA device") + model = SAC("MlpPolicy", "Pendulum-v1", buffer_size=200) + assert model.replay_buffer.device == th.device("cuda") + model.save_replay_buffer(tmp_path / "replay") + model.load_replay_buffer(tmp_path / "replay") + assert model.replay_buffer.device == th.device("cpu") From 6ee174d24abab94bba76cda57ec3250edd1dc304 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 13 Dec 2023 10:45:42 +0100 Subject: [PATCH 5/6] Fix test --- tests/test_buffers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 059c17c..0d711d2 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -8,7 +8,7 @@ def test_force_cpu_device(tmp_path): if not th.cuda.is_available(): pytest.skip("No CUDA device") model = SAC("MlpPolicy", "Pendulum-v1", buffer_size=200) - assert model.replay_buffer.device == th.device("cuda") + assert model.replay_buffer.device == th.device("cpu") model.save_replay_buffer(tmp_path / "replay") model.load_replay_buffer(tmp_path / "replay") assert model.replay_buffer.device == th.device("cpu") From 1a02e95bf91cb2e84565d9fe997652de0f86d4c1 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 13 Dec 2023 10:49:34 +0100 Subject: [PATCH 6/6] Fix for mypy --- sbx/common/off_policy_algorithm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py index 05516ea..ba0b9ed 100644 --- a/sbx/common/off_policy_algorithm.py +++ b/sbx/common/off_policy_algorithm.py @@ -127,4 +127,5 @@ def load_replay_buffer( ) -> None: super().load_replay_buffer(path, truncate_last_traj) # Override replay buffer device to be always cpu for conversion to numpy + assert self.replay_buffer is not None self.replay_buffer.device = get_device("cpu")