Upgrade SB3 and fix type hints (#420)
* Upgrade to latest SB3 version

* Fix hyperparam opt type hints

* Fix exp manager type hints

* Fix key passed to sampler

* Ignore mypy
araffin authored Nov 8, 2023
1 parent e98c00e commit d477a07
Showing 17 changed files with 122 additions and 126 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/ci.yml
```diff
@@ -39,7 +39,7 @@ jobs:
           AutoROM --accept-license --source-file Roms.tar.gz
           # cpu version of pytorch - faster to download
-          pip install torch==1.13.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
           pip install pybullet==3.2.5
           # for v4 MuJoCo envs:
           pip install mujoco
@@ -61,8 +61,6 @@ jobs:
       - name: Type check
         run: |
           make type
-        # skip pytype type check for python 3.11 (not supported)
-        if: "!(matrix.python-version == '3.11')"
      - name: Test with pytest
        run: |
          make pytest
```
2 changes: 1 addition & 1 deletion .github/workflows/trained_agents.yml
```diff
@@ -40,7 +40,7 @@ jobs:
           AutoROM --accept-license --source-file Roms.tar.gz
           # cpu version of pytorch - faster to download
-          pip install torch==1.13.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
           pip install pybullet==3.2.5
           pip install -r requirements.txt
           # Use headless version
```
4 changes: 3 additions & 1 deletion CHANGELOG.md
```diff
@@ -1,4 +1,4 @@
-## Release 2.2.0a8 (WIP)
+## Release 2.2.0a11 (WIP)
 
 ### Breaking Changes
 - Removed `gym` dependency, the package is still required for some pretrained agents.
@@ -18,6 +18,8 @@
 - Replaced deprecated `optuna.suggest_uniform(...)` by `optuna.suggest_float(..., low=..., high=...)`
 - Switched to ruff for sorting imports
 - Updated tests to use `shlex.split()`
+- Fixed `rl_zoo3/hyperparams_opt.py` type hints
+- Fixed `rl_zoo3/exp_manager.py` type hints
 
 ## Release 2.1.0 (2023-08-17)
 
```
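Note on the `optuna.suggest_uniform` deprecation listed in the changelog: the replacement passes explicit bounds to `suggest_float`. A minimal runnable sketch (study direction, hyperparameter name, and range are illustrative, not taken from this repo):

```python
import optuna

def objective(trial: optuna.Trial) -> float:
    # Deprecated: trial.suggest_uniform("gamma", 0.9, 0.9999)
    gamma = trial.suggest_float("gamma", low=0.9, high=0.9999)
    return gamma

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=5)
```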
5 changes: 1 addition & 4 deletions Makefile
```diff
@@ -8,13 +8,10 @@ pytest:
 check-trained-agents:
 	python -m pytest -v tests/test_enjoy.py -k trained_agent --color=yes
 
-pytype:
-	pytype -j auto ${LINT_PATHS} -d import-error
-
 mypy:
 	mypy ${LINT_PATHS} --install-types --non-interactive
 
-type: pytype mypy
+type: mypy
 
 lint:
 	# stop the build if there are Python syntax errors or undefined names
```
10 changes: 1 addition & 9 deletions pyproject.toml
```diff
@@ -20,20 +20,12 @@ max-complexity = 15
 [tool.black]
 line-length = 127
 
-
-[tool.pytype]
-inputs = ["."]
-exclude = ["tests/dummy_env"]
-# disable = []
-
 [tool.mypy]
 ignore_missing_imports = true
 follow_imports = "silent"
 show_error_codes = true
 exclude = """(?x)(
-    rl_zoo3/hyperparams_opt.py$
-    | rl_zoo3/exp_manager.py$
-    | tests/dummy_env/*$
+    tests/dummy_env/*$
 )"""
 
 [tool.pytest.ini_options]
```
4 changes: 2 additions & 2 deletions requirements.txt
```diff
@@ -1,6 +1,6 @@
 gym==0.26.2
-stable-baselines3[extra_no_roms,tests,docs]>=2.2.0a8,<3.0
-sb3-contrib>=2.2.0a8,<3.0
+stable-baselines3[extra_no_roms,tests,docs]>=2.2.0a11,<3.0
+sb3-contrib>=2.2.0a11,<3.0
 box2d-py==2.3.8
 pybullet
 pybullet_envs_gymnasium
```
2 changes: 1 addition & 1 deletion rl_zoo3/enjoy.py
```diff
@@ -147,7 +147,7 @@ def enjoy() -> None:  # noqa: C901
     args_path = os.path.join(log_path, env_name, "args.yml")
     if os.path.isfile(args_path):
         with open(args_path) as f:
-            loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader)  # pytype: disable=module-attr
+            loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader)
         if loaded_args["env_kwargs"] is not None:
             env_kwargs = loaded_args["env_kwargs"]
     # overwrite with command line arguments
```
61 changes: 37 additions & 24 deletions rl_zoo3/exp_manager.py
```diff
@@ -46,7 +46,7 @@
 from torch import nn as nn
 
 # Register custom envs
-import rl_zoo3.import_envs  # noqa: F401 pytype: disable=import-error
+import rl_zoo3.import_envs  # noqa: F401
 from rl_zoo3.callbacks import SaveVecNormalizeCallback, TrialEvalCallback
 from rl_zoo3.hyperparams_opt import HYPERPARAMS_SAMPLER
 from rl_zoo3.utils import ALGOS, get_callback_list, get_class_by_name, get_latest_run_id, get_wrapper_class, linear_schedule
```
```diff
@@ -116,13 +116,13 @@ def __init__(
         self.n_timesteps = n_timesteps
         self.normalize = False
         self.normalize_kwargs: Dict[str, Any] = {}
-        self.env_wrapper = None
+        self.env_wrapper: Optional[Callable] = None
         self.frame_stack = None
         self.seed = seed
         self.optimization_log_path = optimization_log_path
 
         self.vec_env_class = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_type]
-        self.vec_env_wrapper = None
+        self.vec_env_wrapper: Optional[Callable] = None
 
         self.vec_env_kwargs: Dict[str, Any] = {}
         # self.vec_env_kwargs = {} if vec_env_type == "dummy" else {"start_method": "fork"}
```
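The two `Optional[Callable]` annotations above follow a standard mypy pattern: an attribute initialized to `None` but assigned a callable later needs an explicit type, otherwise mypy infers `None` as the attribute's only type and rejects the later assignment. A minimal sketch (class and method names are illustrative):

```python
from typing import Callable, Optional

class Manager:
    def __init__(self) -> None:
        # Annotated so mypy knows a callable may be assigned later
        self.env_wrapper: Optional[Callable] = None

    def configure(self, wrapper: Callable) -> None:
        self.env_wrapper = wrapper  # type-checks thanks to the annotation
```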
```diff
@@ -138,7 +138,7 @@ def __init__(
         self.n_eval_envs = n_eval_envs
 
         self.n_envs = 1  # it will be updated when reading hyperparams
-        self.n_actions = None  # For DDPG/TD3 action noise objects
+        self.n_actions = 0  # For DDPG/TD3 action noise objects
         self._hyperparams: Dict[str, Any] = {}
         self.monitor_kwargs: Dict[str, Any] = {}
 
```
```diff
@@ -186,8 +186,10 @@ def setup_experiment(self) -> Optional[Tuple[BaseAlgorithm, Dict[str, Any]]]:
         :return: the initialized RL model
         """
-        hyperparams, saved_hyperparams = self.read_hyperparameters()
-        hyperparams, self.env_wrapper, self.callbacks, self.vec_env_wrapper = self._preprocess_hyperparams(hyperparams)
+        unprocessed_hyperparams, saved_hyperparams = self.read_hyperparameters()
+        hyperparams, self.env_wrapper, self.callbacks, self.vec_env_wrapper = self._preprocess_hyperparams(
+            unprocessed_hyperparams
+        )
 
         self.create_log_folder()
         self.create_callbacks()
```
```diff
@@ -221,7 +223,7 @@ def learn(self, model: BaseAlgorithm) -> None:
         """
         :param model: an initialized RL model
         """
-        kwargs = {}
+        kwargs: Dict[str, Any] = {}
         if self.log_interval > -1:
             kwargs = {"log_interval": self.log_interval}
 
```
```diff
@@ -245,6 +247,7 @@ def learn(self, model: BaseAlgorithm) -> None:
             self.callbacks[0].on_training_end()
         # Release resources
         try:
+            assert model.env is not None
             model.env.close()
         except EOFError:
             pass
```
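The `assert model.env is not None` added here (and again in `objective` below) relies on mypy's type narrowing: SB3 declares the `env` attribute as optional, and the assert narrows it before `.close()` is called. A minimal sketch of the pattern (names are illustrative):

```python
from typing import Optional

class Env:
    def close(self) -> None:
        print("closed")

class Model:
    env: Optional[Env] = None

def release(model: Model) -> None:
    # mypy narrows Optional[Env] to Env after this assert,
    # so the .close() call type-checks
    assert model.env is not None
    model.env.close()
```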
```diff
@@ -265,7 +268,9 @@ def save_trained_model(self, model: BaseAlgorithm) -> None:
 
         if self.normalize:
             # Important: save the running average, for testing the agent we need that normalization
-            model.get_vec_normalize_env().save(os.path.join(self.params_path, "vecnormalize.pkl"))
+            vec_normalize = model.get_vec_normalize_env()
+            assert vec_normalize is not None
+            vec_normalize.save(os.path.join(self.params_path, "vecnormalize.pkl"))
 
     def _save_config(self, saved_hyperparams: Dict[str, Any]) -> None:
         """
```
```diff
@@ -293,7 +298,7 @@ def read_hyperparameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
             with open(self.config) as f:
                 hyperparams_dict = yaml.safe_load(f)
         elif self.config.endswith(".py"):
-            global_variables = {}
+            global_variables: Dict = {}
             # Load hyperparameters from python file
             exec(Path(self.config).read_text(), global_variables)
             hyperparams_dict = global_variables["hyperparams"]
```
```diff
@@ -452,6 +457,9 @@ def _preprocess_action_noise(
             noise_std = hyperparams["noise_std"]
 
             # Save for later (hyperparameter optimization)
+            assert isinstance(
+                env.action_space, spaces.Box
+            ), f"Action noise can only be used with Box action space, not {env.action_space}"
             self.n_actions = env.action_space.shape[0]
 
             if "normal" in noise_type:
```
```diff
@@ -516,7 +524,7 @@ def create_callbacks(self):
 
     @staticmethod
     def entry_point(env_id: str) -> str:
-        return str(gym.envs.registry[env_id].entry_point)  # pytype: disable=module-attr
+        return str(gym.envs.registry[env_id].entry_point)
 
     @staticmethod
     def is_atari(env_id: str) -> bool:
```
```diff
@@ -618,7 +626,7 @@ def make_env(**kwargs) -> gym.Env:
                 env_kwargs=env_kwargs,
                 monitor_dir=log_dir,
                 wrapper_class=self.env_wrapper,
-                vec_env_cls=self.vec_env_class,
+                vec_env_cls=self.vec_env_class,  # type: ignore[arg-type]
                 vec_env_kwargs=self.vec_env_kwargs,
                 monitor_kwargs=self.monitor_kwargs,
             )
```
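The ignores added in this hunk and the next are scoped to a single mypy error code (`arg-type`), so any other error on the same line still surfaces, unlike a bare `# type: ignore`. A self-contained illustration (the function is hypothetical):

```python
from typing import List

def mean(xs: List[float]) -> float:
    return sum(xs) / len(xs)

ints: List[int] = [1, 2, 3]
# List is invariant, so List[int] is not a List[float]: mypy reports an
# [arg-type] error here; the scoped comment silences only that error code.
value = mean(ints)  # type: ignore[arg-type]
```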
```diff
@@ -645,11 +653,11 @@ def make_env(**kwargs) -> gym.Env:
             # the other channel last); VecTransposeImage will throw an error
             for space in env.observation_space.spaces.values():
                 wrap_with_vectranspose = wrap_with_vectranspose or (
-                    is_image_space(space) and not is_image_space_channels_first(space)
+                    is_image_space(space) and not is_image_space_channels_first(space)  # type: ignore[arg-type]
                 )
         else:
             wrap_with_vectranspose = is_image_space(env.observation_space) and not is_image_space_channels_first(
-                env.observation_space
+                env.observation_space  # type: ignore[arg-type]
             )
 
         if wrap_with_vectranspose:
```
```diff
@@ -683,13 +691,16 @@ def _load_pretrained_agent(self, hyperparams: Dict[str, Any], env: VecEnv) -> BaseAlgorithm:
             if os.path.exists(replay_buffer_path):
                 print("Loading replay buffer")
                 # `truncate_last_traj` will be taken into account only if we use HER replay buffer
+                assert hasattr(
+                    model, "load_replay_buffer"
+                ), "The current model doesn't have a `load_replay_buffer` to load the replay buffer"
                 model.load_replay_buffer(replay_buffer_path, truncate_last_traj=self.truncate_last_trajectory)
         return model
 
     def _create_sampler(self, sampler_method: str) -> BaseSampler:
         # n_warmup_steps: Disable pruner until the trial reaches the given number of steps.
         if sampler_method == "random":
-            sampler = RandomSampler(seed=self.seed)
+            sampler: BaseSampler = RandomSampler(seed=self.seed)
         elif sampler_method == "tpe":
             sampler = TPESampler(n_startup_trials=self.n_startup_trials, seed=self.seed, multivariate=True)
         elif sampler_method == "skopt":
```
```diff
@@ -705,7 +716,7 @@ def _create_sampler(self, sampler_method: str) -> BaseSampler:
 
     def _create_pruner(self, pruner_method: str) -> BasePruner:
         if pruner_method == "halving":
-            pruner = SuccessiveHalvingPruner(min_resource=1, reduction_factor=4, min_early_stopping_rate=0)
+            pruner: BasePruner = SuccessiveHalvingPruner(min_resource=1, reduction_factor=4, min_early_stopping_rate=0)
         elif pruner_method == "median":
             pruner = MedianPruner(n_startup_trials=self.n_startup_trials, n_warmup_steps=self.n_evaluations // 3)
         elif pruner_method == "none":
```
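Annotating the first assignment with the base class (`sampler: BaseSampler = ...`, `pruner: BasePruner = ...`) is what lets mypy accept the differently-typed branches that follow. For context, a sketch of how the sampler and pruner are consumed by Optuna (parameter values are illustrative):

```python
import optuna
from optuna.pruners import BasePruner, MedianPruner
from optuna.samplers import BaseSampler, TPESampler

sampler: BaseSampler = TPESampler(n_startup_trials=10, seed=0, multivariate=True)
pruner: BasePruner = MedianPruner(n_startup_trials=10, n_warmup_steps=3)
study = optuna.create_study(sampler=sampler, pruner=pruner, direction="maximize")
```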
```diff
@@ -718,17 +729,17 @@
     def objective(self, trial: optuna.Trial) -> float:
         kwargs = self._hyperparams.copy()
 
-        # Hack to use DDPG/TD3 noise sampler
-        trial.n_actions = self.n_actions
-        # Hack when using HerReplayBuffer
-        trial.using_her_replay_buffer = kwargs.get("replay_buffer_class") == HerReplayBuffer
-        if trial.using_her_replay_buffer:
-            trial.her_kwargs = kwargs.get("replay_buffer_kwargs", {})
+        n_envs = 1 if self.algo == "ars" else self.n_envs
+
+        additional_args = {
+            "using_her_replay_buffer": kwargs.get("replay_buffer_class") == HerReplayBuffer,
+            "her_kwargs": kwargs.get("replay_buffer_kwargs", {}),
+        }
+        # Pass n_actions to initialize DDPG/TD3 noise sampler
         # Sample candidate hyperparameters
-        sampled_hyperparams = HYPERPARAMS_SAMPLER[self.algo](trial)
+        sampled_hyperparams = HYPERPARAMS_SAMPLER[self.algo](trial, self.n_actions, n_envs, additional_args)
         kwargs.update(sampled_hyperparams)
 
-        n_envs = 1 if self.algo == "ars" else self.n_envs
         env = self.create_envs(n_envs, no_log=True)
 
         # By default, do not activate verbose output to keep
```
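This hunk replaces the old hack of stashing attributes on the `optuna.Trial` object (invisible to mypy) with explicit arguments to the sampler. A sketch of a sampler matching the new calling convention (the body is illustrative, not the zoo's actual TD3 sampler):

```python
from typing import Any, Dict

import numpy as np
import optuna
from stable_baselines3.common.noise import NormalActionNoise

def sample_td3_params(
    trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: Dict[str, Any]
) -> Dict[str, Any]:
    gamma = trial.suggest_float("gamma", low=0.95, high=0.9999)
    # n_actions is now passed explicitly instead of via `trial.n_actions`
    noise_std = trial.suggest_float("noise_std", low=0.0, high=1.0)
    return {
        "gamma": gamma,
        "action_noise": NormalActionNoise(np.zeros(n_actions), noise_std * np.ones(n_actions)),
    }
```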
```diff
@@ -778,13 +789,15 @@ def objective(self, trial: optuna.Trial) -> float:
         )
 
         try:
-            model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs)
+            model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs)  # type: ignore[arg-type]
             # Free memory
+            assert model.env is not None
             model.env.close()
             eval_env.close()
         except (AssertionError, ValueError) as e:
             # Sometimes, random hyperparams can generate NaN
             # Free memory
+            assert model.env is not None
             model.env.close()
             eval_env.close()
             # Prune hyperparams that generate NaNs
```