From d477a07d19fa725e4c8ec38cacb5438d0c6fc1b9 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 8 Nov 2023 12:02:42 +0100 Subject: [PATCH] Upgrade SB3 and fix type hints (#420) * Upgrade to latest SB3 version * Fix hyperparam opt type hints * Fix exp manager type hints * Fix key passed to sampler * Ignore mypy --- .github/workflows/ci.yml | 4 +- .github/workflows/trained_agents.yml | 2 +- CHANGELOG.md | 4 +- Makefile | 5 +- pyproject.toml | 10 +-- requirements.txt | 4 +- rl_zoo3/enjoy.py | 2 +- rl_zoo3/exp_manager.py | 61 +++++++------ rl_zoo3/hyperparams_opt.py | 124 +++++++++++++-------------- rl_zoo3/import_envs.py | 14 +-- rl_zoo3/plots/plot_from_file.py | 4 +- rl_zoo3/push_to_hub.py | 2 +- rl_zoo3/record_video.py | 2 +- rl_zoo3/train.py | 4 +- rl_zoo3/utils.py | 2 +- rl_zoo3/version.txt | 2 +- setup.py | 2 +- 17 files changed, 122 insertions(+), 126 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fbf302fb8..212a6ef19 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,7 +39,7 @@ jobs: AutoROM --accept-license --source-file Roms.tar.gz # cpu version of pytorch - faster to download - pip install torch==1.13.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu + pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu pip install pybullet==3.2.5 # for v4 MuJoCo envs: pip install mujoco @@ -61,8 +61,6 @@ jobs: - name: Type check run: | make type - # skip pytype type check for python 3.11 (not supported) - if: "!(matrix.python-version == '3.11')" - name: Test with pytest run: | make pytest diff --git a/.github/workflows/trained_agents.yml b/.github/workflows/trained_agents.yml index ab04789cb..2d4dd3f01 100644 --- a/.github/workflows/trained_agents.yml +++ b/.github/workflows/trained_agents.yml @@ -40,7 +40,7 @@ jobs: AutoROM --accept-license --source-file Roms.tar.gz # cpu version of pytorch - faster to download - pip install torch==1.13.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu + pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu pip install pybullet==3.2.5 pip install -r requirements.txt # Use headless version diff --git a/CHANGELOG.md b/CHANGELOG.md index 401515994..91171159b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -## Release 2.2.0a8 (WIP) +## Release 2.2.0a11 (WIP) ### Breaking Changes - Removed `gym` dependency, the package is still required for some pretrained agents. @@ -18,6 +18,8 @@ - Replaced deprecated `optuna.suggest_uniform(...)` by `optuna.suggest_float(..., low=..., high=...)` - Switched to ruff for sorting imports - Updated tests to use `shlex.split()` +- Fixed `rl_zoo3/hyperparams_opt.py` type hints +- Fixed `rl_zoo3/exp_manager.py` type hints ## Release 2.1.0 (2023-08-17) diff --git a/Makefile b/Makefile index 7718e0524..3b0a463b3 100644 --- a/Makefile +++ b/Makefile @@ -8,13 +8,10 @@ pytest: check-trained-agents: python -m pytest -v tests/test_enjoy.py -k trained_agent --color=yes -pytype: - pytype -j auto ${LINT_PATHS} -d import-error - mypy: mypy ${LINT_PATHS} --install-types --non-interactive -type: pytype mypy +type: mypy lint: # stop the build if there are Python syntax errors or undefined names diff --git a/pyproject.toml b/pyproject.toml index 86f821905..aeffa65fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,20 +20,12 @@ max-complexity = 15 [tool.black] line-length = 127 - -[tool.pytype] -inputs = ["."] -exclude = ["tests/dummy_env"] -# disable = [] - [tool.mypy] ignore_missing_imports = true follow_imports = "silent" show_error_codes = true exclude = """(?x)( - rl_zoo3/hyperparams_opt.py$ - | rl_zoo3/exp_manager.py$ - | tests/dummy_env/*$ + tests/dummy_env/*$ )""" [tool.pytest.ini_options] diff --git a/requirements.txt b/requirements.txt index 00bb83af8..698aabf41 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ gym==0.26.2 -stable-baselines3[extra_no_roms,tests,docs]>=2.2.0a8,<3.0 -sb3-contrib>=2.2.0a8,<3.0 +stable-baselines3[extra_no_roms,tests,docs]>=2.2.0a11,<3.0 +sb3-contrib>=2.2.0a11,<3.0 box2d-py==2.3.8 pybullet pybullet_envs_gymnasium diff --git a/rl_zoo3/enjoy.py b/rl_zoo3/enjoy.py index 341150ac1..71cdd8467 100644 --- a/rl_zoo3/enjoy.py +++ b/rl_zoo3/enjoy.py @@ -147,7 +147,7 @@ def enjoy() -> None: # noqa: C901 args_path = os.path.join(log_path, env_name, "args.yml") if os.path.isfile(args_path): with open(args_path) as f: - loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader) # pytype: disable=module-attr + loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader) if loaded_args["env_kwargs"] is not None: env_kwargs = loaded_args["env_kwargs"] # overwrite with command line arguments diff --git a/rl_zoo3/exp_manager.py b/rl_zoo3/exp_manager.py index 39d3b920d..b61786f72 100644 --- a/rl_zoo3/exp_manager.py +++ b/rl_zoo3/exp_manager.py @@ -46,7 +46,7 @@ from torch import nn as nn # Register custom envs -import rl_zoo3.import_envs # noqa: F401 pytype: disable=import-error +import rl_zoo3.import_envs # noqa: F401 from rl_zoo3.callbacks import SaveVecNormalizeCallback, TrialEvalCallback from rl_zoo3.hyperparams_opt import HYPERPARAMS_SAMPLER from rl_zoo3.utils import ALGOS, get_callback_list, get_class_by_name, get_latest_run_id, get_wrapper_class, linear_schedule @@ -116,13 +116,13 @@ def __init__( self.n_timesteps = n_timesteps self.normalize = False self.normalize_kwargs: Dict[str, Any] = {} - self.env_wrapper = None + self.env_wrapper: Optional[Callable] = None self.frame_stack = None self.seed = seed self.optimization_log_path = optimization_log_path self.vec_env_class = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_type] - self.vec_env_wrapper = None + self.vec_env_wrapper: Optional[Callable] = None self.vec_env_kwargs: Dict[str, Any] = {} # self.vec_env_kwargs = {} if vec_env_type == "dummy" else {"start_method": "fork"} @@ -138,7 +138,7 @@ def __init__( self.n_eval_envs = n_eval_envs self.n_envs = 1 # it will be updated when reading hyperparams - self.n_actions = None # For DDPG/TD3 action noise objects + self.n_actions = 0 # For DDPG/TD3 action noise objects self._hyperparams: Dict[str, Any] = {} self.monitor_kwargs: Dict[str, Any] = {} @@ -186,8 +186,10 @@ def setup_experiment(self) -> Optional[Tuple[BaseAlgorithm, Dict[str, Any]]]: :return: the initialized RL model """ - hyperparams, saved_hyperparams = self.read_hyperparameters() - hyperparams, self.env_wrapper, self.callbacks, self.vec_env_wrapper = self._preprocess_hyperparams(hyperparams) + unprocessed_hyperparams, saved_hyperparams = self.read_hyperparameters() + hyperparams, self.env_wrapper, self.callbacks, self.vec_env_wrapper = self._preprocess_hyperparams( + unprocessed_hyperparams + ) self.create_log_folder() self.create_callbacks() @@ -221,7 +223,7 @@ def learn(self, model: BaseAlgorithm) -> None: """ :param model: an initialized RL model """ - kwargs = {} + kwargs: Dict[str, Any] = {} if self.log_interval > -1: kwargs = {"log_interval": self.log_interval} @@ -245,6 +247,7 @@ def learn(self, model: BaseAlgorithm) -> None: self.callbacks[0].on_training_end() # Release resources try: + assert model.env is not None model.env.close() except EOFError: pass @@ -265,7 +268,9 @@ def save_trained_model(self, model: BaseAlgorithm) -> None: if self.normalize: # Important: save the running average, for testing the agent we need that normalization - model.get_vec_normalize_env().save(os.path.join(self.params_path, "vecnormalize.pkl")) + vec_normalize = model.get_vec_normalize_env() + assert vec_normalize is not None + vec_normalize.save(os.path.join(self.params_path, "vecnormalize.pkl")) def _save_config(self, saved_hyperparams: Dict[str, Any]) -> None: """ @@ -293,7 +298,7 @@ def read_hyperparameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]: with open(self.config) as f: hyperparams_dict = yaml.safe_load(f) elif self.config.endswith(".py"): - global_variables = {} + global_variables: Dict = {} # Load hyperparameters from python file exec(Path(self.config).read_text(), global_variables) hyperparams_dict = global_variables["hyperparams"] @@ -452,6 +457,9 @@ def _preprocess_action_noise( noise_std = hyperparams["noise_std"] # Save for later (hyperparameter optimization) + assert isinstance( + env.action_space, spaces.Box + ), f"Action noise can only be used with Box action space, not {env.action_space}" self.n_actions = env.action_space.shape[0] if "normal" in noise_type: @@ -516,7 +524,7 @@ def create_callbacks(self): @staticmethod def entry_point(env_id: str) -> str: - return str(gym.envs.registry[env_id].entry_point) # pytype: disable=module-attr + return str(gym.envs.registry[env_id].entry_point) @staticmethod def is_atari(env_id: str) -> bool: @@ -618,7 +626,7 @@ def make_env(**kwargs) -> gym.Env: env_kwargs=env_kwargs, monitor_dir=log_dir, wrapper_class=self.env_wrapper, - vec_env_cls=self.vec_env_class, + vec_env_cls=self.vec_env_class, # type: ignore[arg-type] vec_env_kwargs=self.vec_env_kwargs, monitor_kwargs=self.monitor_kwargs, ) @@ -645,11 +653,11 @@ def make_env(**kwargs) -> gym.Env: # the other channel last); VecTransposeImage will throw an error for space in env.observation_space.spaces.values(): wrap_with_vectranspose = wrap_with_vectranspose or ( - is_image_space(space) and not is_image_space_channels_first(space) + is_image_space(space) and not is_image_space_channels_first(space) # type: ignore[arg-type] ) else: wrap_with_vectranspose = is_image_space(env.observation_space) and not is_image_space_channels_first( - env.observation_space + env.observation_space # type: ignore[arg-type] ) if wrap_with_vectranspose: @@ -683,13 +691,16 @@ def _load_pretrained_agent(self, hyperparams: Dict[str, Any], env: VecEnv) -> Ba if os.path.exists(replay_buffer_path): print("Loading replay buffer") # `truncate_last_traj` will be taken into account only if we use HER replay buffer + assert hasattr( + model, "load_replay_buffer" + ), "The current model doesn't have a `load_replay_buffer` to load the replay buffer" model.load_replay_buffer(replay_buffer_path, truncate_last_traj=self.truncate_last_trajectory) return model def _create_sampler(self, sampler_method: str) -> BaseSampler: # n_warmup_steps: Disable pruner until the trial reaches the given number of steps. if sampler_method == "random": - sampler = RandomSampler(seed=self.seed) + sampler: BaseSampler = RandomSampler(seed=self.seed) elif sampler_method == "tpe": sampler = TPESampler(n_startup_trials=self.n_startup_trials, seed=self.seed, multivariate=True) elif sampler_method == "skopt": @@ -705,7 +716,7 @@ def _create_sampler(self, sampler_method: str) -> BaseSampler: def _create_pruner(self, pruner_method: str) -> BasePruner: if pruner_method == "halving": - pruner = SuccessiveHalvingPruner(min_resource=1, reduction_factor=4, min_early_stopping_rate=0) + pruner: BasePruner = SuccessiveHalvingPruner(min_resource=1, reduction_factor=4, min_early_stopping_rate=0) elif pruner_method == "median": pruner = MedianPruner(n_startup_trials=self.n_startup_trials, n_warmup_steps=self.n_evaluations // 3) elif pruner_method == "none": @@ -718,17 +729,17 @@ def _create_pruner(self, pruner_method: str) -> BasePruner: def objective(self, trial: optuna.Trial) -> float: kwargs = self._hyperparams.copy() - # Hack to use DDPG/TD3 noise sampler - trial.n_actions = self.n_actions - # Hack when using HerReplayBuffer - trial.using_her_replay_buffer = kwargs.get("replay_buffer_class") == HerReplayBuffer - if trial.using_her_replay_buffer: - trial.her_kwargs = kwargs.get("replay_buffer_kwargs", {}) + n_envs = 1 if self.algo == "ars" else self.n_envs + + additional_args = { + "using_her_replay_buffer": kwargs.get("replay_buffer_class") == HerReplayBuffer, + "her_kwargs": kwargs.get("replay_buffer_kwargs", {}), + } + # Pass n_actions to initialize DDPG/TD3 noise sampler # Sample candidate hyperparameters - sampled_hyperparams = HYPERPARAMS_SAMPLER[self.algo](trial) + sampled_hyperparams = HYPERPARAMS_SAMPLER[self.algo](trial, self.n_actions, n_envs, additional_args) kwargs.update(sampled_hyperparams) - n_envs = 1 if self.algo == "ars" else self.n_envs env = self.create_envs(n_envs, no_log=True) # By default, do not activate verbose output to keep @@ -778,13 +789,15 @@ def objective(self, trial: optuna.Trial) -> float: ) try: - model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs) + model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs) # type: ignore[arg-type] # Free memory + assert model.env is not None model.env.close() eval_env.close() except (AssertionError, ValueError) as e: # Sometimes, random hyperparams can generate NaN # Free memory + assert model.env is not None model.env.close() eval_env.close() # Prune hyperparams that generate NaNs diff --git a/rl_zoo3/hyperparams_opt.py b/rl_zoo3/hyperparams_opt.py index 360734cd9..1ff6708a0 100644 --- a/rl_zoo3/hyperparams_opt.py +++ b/rl_zoo3/hyperparams_opt.py @@ -8,7 +8,7 @@ from rl_zoo3 import linear_schedule -def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]: +def sample_ppo_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> Dict[str, Any]: """ Sampler for PPO hyperparams. @@ -19,16 +19,13 @@ def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]: n_steps = trial.suggest_categorical("n_steps", [8, 16, 32, 64, 128, 256, 512, 1024, 2048]) gamma = trial.suggest_categorical("gamma", [0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 0.9999]) learning_rate = trial.suggest_float("learning_rate", 1e-5, 1, log=True) - lr_schedule = "constant" - # Uncomment to enable learning rate schedule - # lr_schedule = trial.suggest_categorical('lr_schedule', ['linear', 'constant']) ent_coef = trial.suggest_float("ent_coef", 0.00000001, 0.1, log=True) clip_range = trial.suggest_categorical("clip_range", [0.1, 0.2, 0.3, 0.4]) n_epochs = trial.suggest_categorical("n_epochs", [1, 5, 10, 20]) gae_lambda = trial.suggest_categorical("gae_lambda", [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0]) max_grad_norm = trial.suggest_categorical("max_grad_norm", [0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 2, 5]) vf_coef = trial.suggest_float("vf_coef", 0, 1) - net_arch = trial.suggest_categorical("net_arch", ["tiny", "small", "medium"]) + net_arch_type = trial.suggest_categorical("net_arch", ["tiny", "small", "medium"]) # Uncomment for gSDE (continuous actions) # log_std_init = trial.suggest_float("log_std_init", -4, 1) # Uncomment for gSDE (continuous action) @@ -37,24 +34,26 @@ def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]: ortho_init = False # ortho_init = trial.suggest_categorical('ortho_init', [False, True]) # activation_fn = trial.suggest_categorical('activation_fn', ['tanh', 'relu', 'elu', 'leaky_relu']) - activation_fn = trial.suggest_categorical("activation_fn", ["tanh", "relu"]) + activation_fn_name = trial.suggest_categorical("activation_fn", ["tanh", "relu"]) + # lr_schedule = "constant" + # Uncomment to enable learning rate schedule + # lr_schedule = trial.suggest_categorical('lr_schedule', ['linear', 'constant']) + # if lr_schedule == "linear": + # learning_rate = linear_schedule(learning_rate) # TODO: account when using multiple envs if batch_size > n_steps: batch_size = n_steps - if lr_schedule == "linear": - learning_rate = linear_schedule(learning_rate) - # Independent networks usually work best # when not working with images net_arch = { "tiny": dict(pi=[64], vf=[64]), "small": dict(pi=[64, 64], vf=[64, 64]), "medium": dict(pi=[256, 256], vf=[256, 256]), - }[net_arch] + }[net_arch_type] - activation_fn = {"tanh": nn.Tanh, "relu": nn.ReLU, "elu": nn.ELU, "leaky_relu": nn.LeakyReLU}[activation_fn] + activation_fn = {"tanh": nn.Tanh, "relu": nn.ReLU, "elu": nn.ELU, "leaky_relu": nn.LeakyReLU}[activation_fn_name] return { "n_steps": n_steps, @@ -77,14 +76,14 @@ def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]: } -def sample_ppo_lstm_params(trial: optuna.Trial) -> Dict[str, Any]: +def sample_ppo_lstm_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> Dict[str, Any]: """ Sampler for RecurrentPPO hyperparams. uses sample_ppo_params(), this function samples for the policy_kwargs :param trial: :return: """ - hyperparams = sample_ppo_params(trial) + hyperparams = sample_ppo_params(trial, n_actions, n_envs, additional_args) enable_critic_lstm = trial.suggest_categorical("enable_critic_lstm", [False, True]) lstm_hidden_size = trial.suggest_categorical("lstm_hidden_size", [16, 32, 64, 128, 256, 512]) @@ -99,7 +98,7 @@ def sample_ppo_lstm_params(trial: optuna.Trial) -> Dict[str, Any]: return hyperparams -def sample_trpo_params(trial: optuna.Trial) -> Dict[str, Any]: +def sample_trpo_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> Dict[str, Any]: """ Sampler for TRPO hyperparams. @@ -110,16 +109,13 @@ def sample_trpo_params(trial: optuna.Trial) -> Dict[str, Any]: n_steps = trial.suggest_categorical("n_steps", [8, 16, 32, 64, 128, 256, 512, 1024, 2048]) gamma = trial.suggest_categorical("gamma", [0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 0.9999]) learning_rate = trial.suggest_float("learning_rate", 1e-5, 1, log=True) - lr_schedule = "constant" - # Uncomment to enable learning rate schedule - # lr_schedule = trial.suggest_categorical('lr_schedule', ['linear', 'constant']) # line_search_shrinking_factor = trial.suggest_categorical("line_search_shrinking_factor", [0.6, 0.7, 0.8, 0.9]) n_critic_updates = trial.suggest_categorical("n_critic_updates", [5, 10, 20, 25, 30]) cg_max_steps = trial.suggest_categorical("cg_max_steps", [5, 10, 20, 25, 30]) # cg_damping = trial.suggest_categorical("cg_damping", [0.5, 0.2, 0.1, 0.05, 0.01]) target_kl = trial.suggest_categorical("target_kl", [0.1, 0.05, 0.03, 0.02, 0.01, 0.005, 0.001]) gae_lambda = trial.suggest_categorical("gae_lambda", [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0]) - net_arch = trial.suggest_categorical("net_arch", ["small", "medium"]) + net_arch_type = trial.suggest_categorical("net_arch", ["small", "medium"]) # Uncomment for gSDE (continuous actions) # log_std_init = trial.suggest_float("log_std_init", -4, 1) # Uncomment for gSDE (continuous action) @@ -128,23 +124,25 @@ def sample_trpo_params(trial: optuna.Trial) -> Dict[str, Any]: ortho_init = False # ortho_init = trial.suggest_categorical('ortho_init', [False, True]) # activation_fn = trial.suggest_categorical('activation_fn', ['tanh', 'relu', 'elu', 'leaky_relu']) - activation_fn = trial.suggest_categorical("activation_fn", ["tanh", "relu"]) + activation_fn_name = trial.suggest_categorical("activation_fn", ["tanh", "relu"]) + # lr_schedule = "constant" + # Uncomment to enable learning rate schedule + # lr_schedule = trial.suggest_categorical('lr_schedule', ['linear', 'constant']) + # if lr_schedule == "linear": + # learning_rate = linear_schedule(learning_rate) # TODO: account when using multiple envs if batch_size > n_steps: batch_size = n_steps - if lr_schedule == "linear": - learning_rate = linear_schedule(learning_rate) - # Independent networks usually work best # when not working with images net_arch = { "small": dict(pi=[64, 64], vf=[64, 64]), "medium": dict(pi=[256, 256], vf=[256, 256]), - }[net_arch] + }[net_arch_type] - activation_fn = {"tanh": nn.Tanh, "relu": nn.ReLU, "elu": nn.ELU, "leaky_relu": nn.LeakyReLU}[activation_fn] + activation_fn = {"tanh": nn.Tanh, "relu": nn.ReLU, "elu": nn.ELU, "leaky_relu": nn.LeakyReLU}[activation_fn_name] return { "n_steps": n_steps, @@ -167,7 +165,7 @@ def sample_trpo_params(trial: optuna.Trial) -> Dict[str, Any]: } -def sample_a2c_params(trial: optuna.Trial) -> Dict[str, Any]: +def sample_a2c_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> Dict[str, Any]: """ Sampler for A2C hyperparams. @@ -188,19 +186,19 @@ def sample_a2c_params(trial: optuna.Trial) -> Dict[str, Any]: # Uncomment for gSDE (continuous actions) # log_std_init = trial.suggest_float("log_std_init", -4, 1) ortho_init = trial.suggest_categorical("ortho_init", [False, True]) - net_arch = trial.suggest_categorical("net_arch", ["small", "medium"]) + net_arch_type = trial.suggest_categorical("net_arch", ["small", "medium"]) # sde_net_arch = trial.suggest_categorical("sde_net_arch", [None, "tiny", "small"]) # full_std = trial.suggest_categorical("full_std", [False, True]) # activation_fn = trial.suggest_categorical('activation_fn', ['tanh', 'relu', 'elu', 'leaky_relu']) - activation_fn = trial.suggest_categorical("activation_fn", ["tanh", "relu"]) + activation_fn_name = trial.suggest_categorical("activation_fn", ["tanh", "relu"]) if lr_schedule == "linear": - learning_rate = linear_schedule(learning_rate) + learning_rate = linear_schedule(learning_rate) # type: ignore[assignment] net_arch = { "small": dict(pi=[64, 64], vf=[64, 64]), "medium": dict(pi=[256, 256], vf=[256, 256]), - }[net_arch] + }[net_arch_type] # sde_net_arch = { # None: None, @@ -208,7 +206,7 @@ def sample_a2c_params(trial: optuna.Trial) -> Dict[str, Any]: # "small": [64, 64], # }[sde_net_arch] - activation_fn = {"tanh": nn.Tanh, "relu": nn.ReLU, "elu": nn.ELU, "leaky_relu": nn.LeakyReLU}[activation_fn] + activation_fn = {"tanh": nn.Tanh, "relu": nn.ReLU, "elu": nn.ELU, "leaky_relu": nn.LeakyReLU}[activation_fn_name] return { "n_steps": n_steps, @@ -231,7 +229,7 @@ def sample_a2c_params(trial: optuna.Trial) -> Dict[str, Any]: } -def sample_sac_params(trial: optuna.Trial) -> Dict[str, Any]: +def sample_sac_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> Dict[str, Any]: """ Sampler for SAC hyperparams. @@ -255,7 +253,7 @@ def sample_sac_params(trial: optuna.Trial) -> Dict[str, Any]: # You can comment that out when not using gSDE log_std_init = trial.suggest_float("log_std_init", -4, 1) # NOTE: Add "verybig" to net_arch when tuning HER - net_arch = trial.suggest_categorical("net_arch", ["small", "medium", "big"]) + net_arch_type = trial.suggest_categorical("net_arch", ["small", "medium", "big"]) # activation_fn = trial.suggest_categorical('activation_fn', [nn.Tanh, nn.ReLU, nn.ELU, nn.LeakyReLU]) net_arch = { @@ -265,7 +263,7 @@ def sample_sac_params(trial: optuna.Trial) -> Dict[str, Any]: # Uncomment for tuning HER # "large": [256, 256, 256], # "verybig": [512, 512, 512], - }[net_arch] + }[net_arch_type] target_entropy = "auto" # if ent_coef == 'auto': @@ -286,13 +284,13 @@ def sample_sac_params(trial: optuna.Trial) -> Dict[str, Any]: "policy_kwargs": dict(log_std_init=log_std_init, net_arch=net_arch), } - if trial.using_her_replay_buffer: - hyperparams = sample_her_params(trial, hyperparams) + if additional_args["using_her_replay_buffer"]: + hyperparams = sample_her_params(trial, hyperparams, additional_args["her_kwargs"]) return hyperparams -def sample_td3_params(trial: optuna.Trial) -> Dict[str, Any]: +def sample_td3_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> Dict[str, Any]: """ Sampler for TD3 hyperparams. @@ -313,7 +311,7 @@ def sample_td3_params(trial: optuna.Trial) -> Dict[str, Any]: noise_std = trial.suggest_float("noise_std", 0, 1) # NOTE: Add "verybig" to net_arch when tuning HER - net_arch = trial.suggest_categorical("net_arch", ["small", "medium", "big"]) + net_arch_type = trial.suggest_categorical("net_arch", ["small", "medium", "big"]) # activation_fn = trial.suggest_categorical('activation_fn', [nn.Tanh, nn.ReLU, nn.ELU, nn.LeakyReLU]) net_arch = { @@ -322,7 +320,7 @@ def sample_td3_params(trial: optuna.Trial) -> Dict[str, Any]: "big": [400, 300], # Uncomment for tuning HER # "verybig": [256, 256, 256], - }[net_arch] + }[net_arch_type] hyperparams = { "gamma": gamma, @@ -336,21 +334,19 @@ def sample_td3_params(trial: optuna.Trial) -> Dict[str, Any]: } if noise_type == "normal": - hyperparams["action_noise"] = NormalActionNoise( - mean=np.zeros(trial.n_actions), sigma=noise_std * np.ones(trial.n_actions) - ) + hyperparams["action_noise"] = NormalActionNoise(mean=np.zeros(n_actions), sigma=noise_std * np.ones(n_actions)) elif noise_type == "ornstein-uhlenbeck": hyperparams["action_noise"] = OrnsteinUhlenbeckActionNoise( - mean=np.zeros(trial.n_actions), sigma=noise_std * np.ones(trial.n_actions) + mean=np.zeros(n_actions), sigma=noise_std * np.ones(n_actions) ) - if trial.using_her_replay_buffer: - hyperparams = sample_her_params(trial, hyperparams) + if additional_args["using_her_replay_buffer"]: + hyperparams = sample_her_params(trial, hyperparams, additional_args["her_kwargs"]) return hyperparams -def sample_ddpg_params(trial: optuna.Trial) -> Dict[str, Any]: +def sample_ddpg_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> Dict[str, Any]: """ Sampler for DDPG hyperparams. @@ -371,14 +367,14 @@ def sample_ddpg_params(trial: optuna.Trial) -> Dict[str, Any]: noise_std = trial.suggest_float("noise_std", 0, 1) # NOTE: Add "verybig" to net_arch when tuning HER (see TD3) - net_arch = trial.suggest_categorical("net_arch", ["small", "medium", "big"]) + net_arch_type = trial.suggest_categorical("net_arch", ["small", "medium", "big"]) # activation_fn = trial.suggest_categorical('activation_fn', [nn.Tanh, nn.ReLU, nn.ELU, nn.LeakyReLU]) net_arch = { "small": [64, 64], "medium": [256, 256], "big": [400, 300], - }[net_arch] + }[net_arch_type] hyperparams = { "gamma": gamma, @@ -392,21 +388,19 @@ def sample_ddpg_params(trial: optuna.Trial) -> Dict[str, Any]: } if noise_type == "normal": - hyperparams["action_noise"] = NormalActionNoise( - mean=np.zeros(trial.n_actions), sigma=noise_std * np.ones(trial.n_actions) - ) + hyperparams["action_noise"] = NormalActionNoise(mean=np.zeros(n_actions), sigma=noise_std * np.ones(n_actions)) elif noise_type == "ornstein-uhlenbeck": hyperparams["action_noise"] = OrnsteinUhlenbeckActionNoise( - mean=np.zeros(trial.n_actions), sigma=noise_std * np.ones(trial.n_actions) + mean=np.zeros(n_actions), sigma=noise_std * np.ones(n_actions) ) - if trial.using_her_replay_buffer: - hyperparams = sample_her_params(trial, hyperparams) + if additional_args["using_her_replay_buffer"]: + hyperparams = sample_her_params(trial, hyperparams, additional_args["her_kwargs"]) return hyperparams -def sample_dqn_params(trial: optuna.Trial) -> Dict[str, Any]: +def sample_dqn_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> Dict[str, Any]: """ Sampler for DQN hyperparams. @@ -426,9 +420,9 @@ def sample_dqn_params(trial: optuna.Trial) -> Dict[str, Any]: subsample_steps = trial.suggest_categorical("subsample_steps", [1, 2, 4, 8]) gradient_steps = max(train_freq // subsample_steps, 1) - net_arch = trial.suggest_categorical("net_arch", ["tiny", "small", "medium"]) + net_arch_type = trial.suggest_categorical("net_arch", ["tiny", "small", "medium"]) - net_arch = {"tiny": [64], "small": [64, 64], "medium": [256, 256]}[net_arch] + net_arch = {"tiny": [64], "small": [64, 64], "medium": [256, 256]}[net_arch_type] hyperparams = { "gamma": gamma, @@ -444,13 +438,13 @@ def sample_dqn_params(trial: optuna.Trial) -> Dict[str, Any]: "policy_kwargs": dict(net_arch=net_arch), } - if trial.using_her_replay_buffer: - hyperparams = sample_her_params(trial, hyperparams) + if additional_args["using_her_replay_buffer"]: + hyperparams = sample_her_params(trial, hyperparams, additional_args["her_kwargs"]) return hyperparams -def sample_her_params(trial: optuna.Trial, hyperparams: Dict[str, Any]) -> Dict[str, Any]: +def sample_her_params(trial: optuna.Trial, hyperparams: Dict[str, Any], her_kwargs: Dict[str, Any]) -> Dict[str, Any]: """ Sampler for HerReplayBuffer hyperparams. @@ -458,7 +452,7 @@ def sample_her_params(trial: optuna.Trial, hyperparams: Dict[str, Any]) -> Dict[ :parma hyperparams: :return: """ - her_kwargs = trial.her_kwargs.copy() + her_kwargs = her_kwargs.copy() her_kwargs["n_sampled_goal"] = trial.suggest_int("n_sampled_goal", 1, 5) her_kwargs["goal_selection_strategy"] = trial.suggest_categorical( "goal_selection_strategy", ["final", "episode", "future"] @@ -467,7 +461,7 @@ def sample_her_params(trial: optuna.Trial, hyperparams: Dict[str, Any]) -> Dict[ return hyperparams -def sample_tqc_params(trial: optuna.Trial) -> Dict[str, Any]: +def sample_tqc_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> Dict[str, Any]: """ Sampler for TQC hyperparams. @@ -475,7 +469,7 @@ def sample_tqc_params(trial: optuna.Trial) -> Dict[str, Any]: :return: """ # TQC is SAC + Distributional RL - hyperparams = sample_sac_params(trial) + hyperparams = sample_sac_params(trial, n_actions, n_envs, additional_args) n_quantiles = trial.suggest_int("n_quantiles", 5, 50) top_quantiles_to_drop_per_net = trial.suggest_int("top_quantiles_to_drop_per_net", 0, n_quantiles - 1) @@ -486,7 +480,7 @@ def sample_tqc_params(trial: optuna.Trial) -> Dict[str, Any]: return hyperparams -def sample_qrdqn_params(trial: optuna.Trial) -> Dict[str, Any]: +def sample_qrdqn_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> Dict[str, Any]: """ Sampler for QR-DQN hyperparams. @@ -494,7 +488,7 @@ def sample_qrdqn_params(trial: optuna.Trial) -> Dict[str, Any]: :return: """ # TQC is DQN + Distributional RL - hyperparams = sample_dqn_params(trial) + hyperparams = sample_dqn_params(trial, n_actions, n_envs, additional_args) n_quantiles = trial.suggest_int("n_quantiles", 5, 200) hyperparams["policy_kwargs"].update({"n_quantiles": n_quantiles}) @@ -502,7 +496,7 @@ def sample_qrdqn_params(trial: optuna.Trial) -> Dict[str, Any]: return hyperparams -def sample_ars_params(trial: optuna.Trial) -> Dict[str, Any]: +def sample_ars_params(trial: optuna.Trial, n_actions: int, n_envs: int, additional_args: dict) -> Dict[str, Any]: """ Sampler for ARS hyperparams. :param trial: diff --git a/rl_zoo3/import_envs.py b/rl_zoo3/import_envs.py index 4cdee2262..7d2b40447 100644 --- a/rl_zoo3/import_envs.py +++ b/rl_zoo3/import_envs.py @@ -6,12 +6,12 @@ from rl_zoo3.wrappers import MaskVelocityWrapper try: - import pybullet_envs_gymnasium # pytype: disable=import-error + import pybullet_envs_gymnasium except ImportError: pass try: - import highway_env # pytype: disable=import-error + import highway_env except ImportError: pass else: @@ -21,27 +21,27 @@ np.float = np.float32 # type: ignore[attr-defined] try: - import custom_envs # pytype: disable=import-error + import custom_envs except ImportError: pass try: - import gym_donkeycar # pytype: disable=import-error + import gym_donkeycar except ImportError: pass try: - import panda_gym # pytype: disable=import-error + import panda_gym except ImportError: pass try: - import rocket_lander_gym # pytype: disable=import-error + import rocket_lander_gym except ImportError: pass try: - import minigrid # pytype: disable=import-error + import minigrid except ImportError: pass diff --git a/rl_zoo3/plots/plot_from_file.py b/rl_zoo3/plots/plot_from_file.py index 9c007cc23..d412d6001 100644 --- a/rl_zoo3/plots/plot_from_file.py +++ b/rl_zoo3/plots/plot_from_file.py @@ -10,8 +10,8 @@ from matplotlib import pyplot as plt try: - from rliable import library as rly # pytype: disable=import-error - from rliable import metrics, plot_utils # pytype: disable=import-error + from rliable import library as rly + from rliable import metrics, plot_utils except ImportError: rly = None diff --git a/rl_zoo3/push_to_hub.py b/rl_zoo3/push_to_hub.py index b31931cfa..50f354478 100644 --- a/rl_zoo3/push_to_hub.py +++ b/rl_zoo3/push_to_hub.py @@ -354,7 +354,7 @@ def package_to_hub( args_path = os.path.join(log_path, env_name, "args.yml") if os.path.isfile(args_path): with open(args_path) as f: - loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader) # pytype: disable=module-attr + loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader) if loaded_args["env_kwargs"] is not None: env_kwargs = loaded_args["env_kwargs"] diff --git a/rl_zoo3/record_video.py b/rl_zoo3/record_video.py index 7dd6eed4d..a2b2071d4 100644 --- a/rl_zoo3/record_video.py +++ b/rl_zoo3/record_video.py @@ -83,7 +83,7 @@ args_path = os.path.join(log_path, env_name, "args.yml") if os.path.isfile(args_path): with open(args_path) as f: - loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader) # pytype: disable=module-attr + loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader) if loaded_args["env_kwargs"] is not None: env_kwargs = loaded_args["env_kwargs"] # overwrite with command line arguments diff --git a/rl_zoo3/train.py b/rl_zoo3/train.py index 9da093dd0..53be1683b 100644 --- a/rl_zoo3/train.py +++ b/rl_zoo3/train.py @@ -12,7 +12,7 @@ from stable_baselines3.common.utils import set_random_seed # Register custom envs -import rl_zoo3.import_envs # noqa: F401 pytype: disable=import-error +import rl_zoo3.import_envs # noqa: F401 from rl_zoo3.exp_manager import ExperimentManager from rl_zoo3.utils import ALGOS, StoreDict @@ -164,7 +164,7 @@ def train() -> None: importlib.import_module(env_module) env_id = args.env - registered_envs = set(gym.envs.registry.keys()) # pytype: disable=module-attr + registered_envs = set(gym.envs.registry.keys()) # If the environment is not found, suggest the closest match if env_id not in registered_envs: diff --git a/rl_zoo3/utils.py b/rl_zoo3/utils.py index dfc4fb754..575f27974 100644 --- a/rl_zoo3/utils.py +++ b/rl_zoo3/utils.py @@ -406,7 +406,7 @@ def get_saved_hyperparams( if os.path.isfile(config_file): # Load saved hyperparameters with open(os.path.join(stats_path, "config.yml")) as f: - hyperparams = yaml.load(f, Loader=yaml.UnsafeLoader) # pytype: disable=module-attr + hyperparams = yaml.load(f, Loader=yaml.UnsafeLoader) hyperparams["normalize"] = hyperparams.get("normalize", False) else: obs_rms_path = os.path.join(stats_path, "obs_rms.pkl") diff --git a/rl_zoo3/version.txt b/rl_zoo3/version.txt index f1f23b30a..13ce6d730 100644 --- a/rl_zoo3/version.txt +++ b/rl_zoo3/version.txt @@ -1 +1 @@ -2.2.0a8 +2.2.0a11 diff --git a/setup.py b/setup.py index e776e929c..4b9dd62d3 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ }, entry_points={"console_scripts": ["rl_zoo3=rl_zoo3.cli:main"]}, install_requires=[ - "sb3_contrib>=2.2.0a8,<3.0", + "sb3_contrib>=2.2.0a11,<3.0", "gymnasium~=0.29.1", "huggingface_sb3>=3.0,<4.0", "tqdm",