
Commit d477a07

Upgrade SB3 and fix type hints (#420)

* Upgrade to latest SB3 version
* Fix hyperparam opt type hints
* Fix exp manager type hints
* Fix key passed to sampler
* Ignore mypy

1 parent e98c00e · commit d477a07

17 files changed: +122 −126 lines

.github/workflows/ci.yml (+1 −3)

@@ -39,7 +39,7 @@ jobs:
           AutoROM --accept-license --source-file Roms.tar.gz

           # cpu version of pytorch - faster to download
-          pip install torch==1.13.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
           pip install pybullet==3.2.5
           # for v4 MuJoCo envs:
           pip install mujoco
@@ -61,8 +61,6 @@ jobs:
       - name: Type check
         run: |
           make type
-        # skip pytype type check for python 3.11 (not supported)
-        if: "!(matrix.python-version == '3.11')"
       - name: Test with pytest
         run: |
           make pytest
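
A note on the PyTorch pin change (the same edit appears in the next workflow file): with `--index-url` pointing pip directly at the CPU-only wheel index, the plain `torch==2.1.0` specifier is enough to resolve the CPU build there, so the explicit `+cpu` local-version tag that was previously needed alongside `--extra-index-url` can be dropped.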

.github/workflows/trained_agents.yml (+1 −1)

@@ -40,7 +40,7 @@ jobs:
           AutoROM --accept-license --source-file Roms.tar.gz

           # cpu version of pytorch - faster to download
-          pip install torch==1.13.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
           pip install pybullet==3.2.5
           pip install -r requirements.txt
           # Use headless version

CHANGELOG.md (+3 −1)

@@ -1,4 +1,4 @@
-## Release 2.2.0a8 (WIP)
+## Release 2.2.0a11 (WIP)

 ### Breaking Changes
 - Removed `gym` dependency, the package is still required for some pretrained agents.
@@ -18,6 +18,8 @@
 - Replaced deprecated `optuna.suggest_uniform(...)` by `optuna.suggest_float(..., low=..., high=...)`
 - Switched to ruff for sorting imports
 - Updated tests to use `shlex.split()`
+- Fixed `rl_zoo3/hyperparams_opt.py` type hints
+- Fixed `rl_zoo3/exp_manager.py` type hints

 ## Release 2.1.0 (2023-08-17)

Makefile (+1 −4)

@@ -8,13 +8,10 @@ pytest:
 check-trained-agents:
 	python -m pytest -v tests/test_enjoy.py -k trained_agent --color=yes

-pytype:
-	pytype -j auto ${LINT_PATHS} -d import-error
-
 mypy:
 	mypy ${LINT_PATHS} --install-types --non-interactive

-type: pytype mypy
+type: mypy

 lint:
 	# stop the build if there are Python syntax errors or undefined names

pyproject.toml (+1 −9)

@@ -20,20 +20,12 @@ max-complexity = 15
 [tool.black]
 line-length = 127

-
-[tool.pytype]
-inputs = ["."]
-exclude = ["tests/dummy_env"]
-# disable = []
-
 [tool.mypy]
 ignore_missing_imports = true
 follow_imports = "silent"
 show_error_codes = true
 exclude = """(?x)(
-    rl_zoo3/hyperparams_opt.py$
-    | rl_zoo3/exp_manager.py$
-    | tests/dummy_env/*$
+    tests/dummy_env/*$
 )"""

 [tool.pytest.ini_options]
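
Net effect of the Makefile and pyproject.toml changes: pytype is retired entirely, and `rl_zoo3/hyperparams_opt.py` and `rl_zoo3/exp_manager.py` leave mypy's exclude list, so `make type` now checks both modules. The type-hint fixes in the diffs below are what make that pass.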

requirements.txt (+2 −2)

@@ -1,6 +1,6 @@
 gym==0.26.2
-stable-baselines3[extra_no_roms,tests,docs]>=2.2.0a8,<3.0
-sb3-contrib>=2.2.0a8,<3.0
+stable-baselines3[extra_no_roms,tests,docs]>=2.2.0a11,<3.0
+sb3-contrib>=2.2.0a11,<3.0
 box2d-py==2.3.8
 pybullet
 pybullet_envs_gymnasium

rl_zoo3/enjoy.py (+1 −1)

@@ -147,7 +147,7 @@ def enjoy() -> None:  # noqa: C901
     args_path = os.path.join(log_path, env_name, "args.yml")
     if os.path.isfile(args_path):
         with open(args_path) as f:
-            loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader)  # pytype: disable=module-attr
+            loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader)
         if loaded_args["env_kwargs"] is not None:
             env_kwargs = loaded_args["env_kwargs"]
     # overwrite with command line arguments

rl_zoo3/exp_manager.py (+37 −24)

@@ -46,7 +46,7 @@
 from torch import nn as nn

 # Register custom envs
-import rl_zoo3.import_envs  # noqa: F401 pytype: disable=import-error
+import rl_zoo3.import_envs  # noqa: F401
 from rl_zoo3.callbacks import SaveVecNormalizeCallback, TrialEvalCallback
 from rl_zoo3.hyperparams_opt import HYPERPARAMS_SAMPLER
 from rl_zoo3.utils import ALGOS, get_callback_list, get_class_by_name, get_latest_run_id, get_wrapper_class, linear_schedule
@@ -116,13 +116,13 @@ def __init__(
         self.n_timesteps = n_timesteps
         self.normalize = False
         self.normalize_kwargs: Dict[str, Any] = {}
-        self.env_wrapper = None
+        self.env_wrapper: Optional[Callable] = None
         self.frame_stack = None
         self.seed = seed
         self.optimization_log_path = optimization_log_path

         self.vec_env_class = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_type]
-        self.vec_env_wrapper = None
+        self.vec_env_wrapper: Optional[Callable] = None

         self.vec_env_kwargs: Dict[str, Any] = {}
         # self.vec_env_kwargs = {} if vec_env_type == "dummy" else {"start_method": "fork"}
@@ -138,7 +138,7 @@ def __init__(
         self.n_eval_envs = n_eval_envs

         self.n_envs = 1  # it will be updated when reading hyperparams
-        self.n_actions = None  # For DDPG/TD3 action noise objects
+        self.n_actions = 0  # For DDPG/TD3 action noise objects
         self._hyperparams: Dict[str, Any] = {}
         self.monitor_kwargs: Dict[str, Any] = {}
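
Most changes in this file follow a few recurring mypy patterns. The first, visible in the `__init__` hunks above: an attribute initialized to `None` needs an explicit `Optional[...]` annotation, otherwise mypy infers its type as `None` and rejects any later assignment. A minimal sketch of the pattern (class and method names are hypothetical):

    from typing import Callable, Optional


    class Manager:
        def __init__(self) -> None:
            # Without the annotation, mypy infers the type `None` here
            # and flags the assignment in `configure` below.
            self.env_wrapper: Optional[Callable] = None

        def configure(self, wrapper: Callable) -> None:
            self.env_wrapper = wrapper  # OK: Callable fits Optional[Callable]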

@@ -186,8 +186,10 @@ def setup_experiment(self) -> Optional[Tuple[BaseAlgorithm, Dict[str, Any]]]:

         :return: the initialized RL model
         """
-        hyperparams, saved_hyperparams = self.read_hyperparameters()
-        hyperparams, self.env_wrapper, self.callbacks, self.vec_env_wrapper = self._preprocess_hyperparams(hyperparams)
+        unprocessed_hyperparams, saved_hyperparams = self.read_hyperparameters()
+        hyperparams, self.env_wrapper, self.callbacks, self.vec_env_wrapper = self._preprocess_hyperparams(
+            unprocessed_hyperparams
+        )

         self.create_log_folder()
         self.create_callbacks()
@@ -221,7 +223,7 @@ def learn(self, model: BaseAlgorithm) -> None:
         """
         :param model: an initialized RL model
         """
-        kwargs = {}
+        kwargs: Dict[str, Any] = {}
         if self.log_interval > -1:
             kwargs = {"log_interval": self.log_interval}

@@ -245,6 +247,7 @@ def learn(self, model: BaseAlgorithm) -> None:
             self.callbacks[0].on_training_end()
         # Release resources
         try:
+            assert model.env is not None
             model.env.close()
         except EOFError:
             pass
@@ -265,7 +268,9 @@ def save_trained_model(self, model: BaseAlgorithm) -> None:

         if self.normalize:
             # Important: save the running average, for testing the agent we need that normalization
-            model.get_vec_normalize_env().save(os.path.join(self.params_path, "vecnormalize.pkl"))
+            vec_normalize = model.get_vec_normalize_env()
+            assert vec_normalize is not None
+            vec_normalize.save(os.path.join(self.params_path, "vecnormalize.pkl"))

     def _save_config(self, saved_hyperparams: Dict[str, Any]) -> None:
         """
@@ -293,7 +298,7 @@ def read_hyperparameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
             with open(self.config) as f:
                 hyperparams_dict = yaml.safe_load(f)
         elif self.config.endswith(".py"):
-            global_variables = {}
+            global_variables: Dict = {}
             # Load hyperparameters from python file
             exec(Path(self.config).read_text(), global_variables)
             hyperparams_dict = global_variables["hyperparams"]
@@ -452,6 +457,9 @@ def _preprocess_action_noise(
         noise_std = hyperparams["noise_std"]

         # Save for later (hyperparameter optimization)
+        assert isinstance(
+            env.action_space, spaces.Box
+        ), f"Action noise can only be used with Box action space, not {env.action_space}"
         self.n_actions = env.action_space.shape[0]

         if "normal" in noise_type:
@@ -516,7 +524,7 @@ def create_callbacks(self):

     @staticmethod
     def entry_point(env_id: str) -> str:
-        return str(gym.envs.registry[env_id].entry_point)  # pytype: disable=module-attr
+        return str(gym.envs.registry[env_id].entry_point)

     @staticmethod
     def is_atari(env_id: str) -> bool:
@@ -618,7 +626,7 @@ def make_env(**kwargs) -> gym.Env:
             env_kwargs=env_kwargs,
             monitor_dir=log_dir,
             wrapper_class=self.env_wrapper,
-            vec_env_cls=self.vec_env_class,
+            vec_env_cls=self.vec_env_class,  # type: ignore[arg-type]
             vec_env_kwargs=self.vec_env_kwargs,
             monitor_kwargs=self.monitor_kwargs,
         )
@@ -645,11 +653,11 @@ def make_env(**kwargs) -> gym.Env:
             # the other channel last); VecTransposeImage will throw an error
             for space in env.observation_space.spaces.values():
                 wrap_with_vectranspose = wrap_with_vectranspose or (
-                    is_image_space(space) and not is_image_space_channels_first(space)
+                    is_image_space(space) and not is_image_space_channels_first(space)  # type: ignore[arg-type]
                )
         else:
             wrap_with_vectranspose = is_image_space(env.observation_space) and not is_image_space_channels_first(
-                env.observation_space
+                env.observation_space  # type: ignore[arg-type]
             )

         if wrap_with_vectranspose:
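
Where precise annotations are not practical (the "Ignore mypy" line in the commit message), the fix is a targeted `# type: ignore[arg-type]` on the offending line rather than a file-wide exclusion; the bracketed error code keeps the suppression narrow, and mypy's `warn_unused_ignores` option (not enabled here) could later flag any suppression that becomes stale.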
@@ -683,13 +691,16 @@ def _load_pretrained_agent(self, hyperparams: Dict[str, Any], env: VecEnv) -> BaseAlgorithm:
         if os.path.exists(replay_buffer_path):
             print("Loading replay buffer")
             # `truncate_last_traj` will be taken into account only if we use HER replay buffer
+            assert hasattr(
+                model, "load_replay_buffer"
+            ), "The current model doesn't have a `load_replay_buffer` to load the replay buffer"
             model.load_replay_buffer(replay_buffer_path, truncate_last_traj=self.truncate_last_trajectory)
         return model

     def _create_sampler(self, sampler_method: str) -> BaseSampler:
         # n_warmup_steps: Disable pruner until the trial reaches the given number of steps.
         if sampler_method == "random":
-            sampler = RandomSampler(seed=self.seed)
+            sampler: BaseSampler = RandomSampler(seed=self.seed)
         elif sampler_method == "tpe":
             sampler = TPESampler(n_startup_trials=self.n_startup_trials, seed=self.seed, multivariate=True)
         elif sampler_method == "skopt":
@@ -705,7 +716,7 @@ def _create_sampler(self, sampler_method: str) -> BaseSampler:

     def _create_pruner(self, pruner_method: str) -> BasePruner:
         if pruner_method == "halving":
-            pruner = SuccessiveHalvingPruner(min_resource=1, reduction_factor=4, min_early_stopping_rate=0)
+            pruner: BasePruner = SuccessiveHalvingPruner(min_resource=1, reduction_factor=4, min_early_stopping_rate=0)
         elif pruner_method == "median":
             pruner = MedianPruner(n_startup_trials=self.n_startup_trials, n_warmup_steps=self.n_evaluations // 3)
         elif pruner_method == "none":
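
A third pattern, used in `_create_sampler` and `_create_pruner` above: when one variable receives different subclasses across an `if`/`elif` chain, annotating the first assignment with the common base class declares the variable's type up front, so mypy accepts the other branches. A minimal generic sketch:

    class Animal: ...
    class Dog(Animal): ...
    class Cat(Animal): ...


    def pick(kind: str) -> Animal:
        if kind == "dog":
            # Annotate the first assignment with the base class; otherwise
            # mypy infers `Dog` and rejects the `Cat` assignment below.
            animal: Animal = Dog()
        else:
            animal = Cat()
        return animal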
@@ -718,17 +729,17 @@ def _create_pruner(self, pruner_method: str) -> BasePruner:
     def objective(self, trial: optuna.Trial) -> float:
         kwargs = self._hyperparams.copy()

-        # Hack to use DDPG/TD3 noise sampler
-        trial.n_actions = self.n_actions
-        # Hack when using HerReplayBuffer
-        trial.using_her_replay_buffer = kwargs.get("replay_buffer_class") == HerReplayBuffer
-        if trial.using_her_replay_buffer:
-            trial.her_kwargs = kwargs.get("replay_buffer_kwargs", {})
+        n_envs = 1 if self.algo == "ars" else self.n_envs
+
+        additional_args = {
+            "using_her_replay_buffer": kwargs.get("replay_buffer_class") == HerReplayBuffer,
+            "her_kwargs": kwargs.get("replay_buffer_kwargs", {}),
+        }
+        # Pass n_actions to initialize DDPG/TD3 noise sampler
         # Sample candidate hyperparameters
-        sampled_hyperparams = HYPERPARAMS_SAMPLER[self.algo](trial)
+        sampled_hyperparams = HYPERPARAMS_SAMPLER[self.algo](trial, self.n_actions, n_envs, additional_args)
         kwargs.update(sampled_hyperparams)

-        n_envs = 1 if self.algo == "ars" else self.n_envs
         env = self.create_envs(n_envs, no_log=True)

         # By default, do not activate verbose output to keep
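
This hunk is the substantive refactor behind "Fix key passed to sampler" in the commit message: instead of monkey-patching untyped attributes onto the Optuna `trial`, the samplers in `rl_zoo3/hyperparams_opt.py` (not shown in this diff) now receive the extra context as explicit, typed arguments. A hedged sketch of what a sampler presumably looks like after this change (function body is illustrative only):

    from typing import Any, Dict

    import optuna


    def sample_td3_params(
        trial: optuna.Trial,
        n_actions: int,
        n_envs: int,
        additional_args: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Sketch: sample TD3 hyperparameters for one Optuna trial."""
        gamma = trial.suggest_categorical("gamma", [0.95, 0.98, 0.99, 0.999])
        # `n_actions` replaces the old `trial.n_actions` hack and would be
        # used here to build action-noise objects of the right dimension.
        return {"gamma": gamma}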
@@ -778,13 +789,15 @@ def objective(self, trial: optuna.Trial) -> float:
         )

         try:
-            model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs)
+            model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs)  # type: ignore[arg-type]
             # Free memory
+            assert model.env is not None
             model.env.close()
             eval_env.close()
         except (AssertionError, ValueError) as e:
             # Sometimes, random hyperparams can generate NaN
             # Free memory
+            assert model.env is not None
             model.env.close()
             eval_env.close()
             # Prune hyperparams that generate NaNs
