diff --git a/train.py b/train.py
index bbf85795e..93e6bd49c 100644
--- a/train.py
+++ b/train.py
@@ -64,6 +64,7 @@
         type=int,
         default=500,
     )
+    parser.add_argument("--n-models", help="Number of models for optimizing hyperparameters.", type=int, default=1)
     parser.add_argument(
         "-optimize", "--optimize-hyperparameters", action="store_true", default=False, help="Run hyperparameters search"
     )
@@ -201,6 +202,7 @@
         args.storage,
         args.study_name,
         args.n_trials,
+        args.n_models,
         args.n_jobs,
         args.sampler,
         args.pruner,
diff --git a/utils/exp_manager.py b/utils/exp_manager.py
index a9c9fe6f8..55f02433a 100644
--- a/utils/exp_manager.py
+++ b/utils/exp_manager.py
@@ -73,6 +73,7 @@ def __init__(
         storage: Optional[str] = None,
         study_name: Optional[str] = None,
         n_trials: int = 1,
+        n_models: int = 1,
         n_jobs: int = 1,
         sampler: str = "tpe",
         pruner: str = "median",
@@ -133,10 +134,13 @@ def __init__(
         self.no_optim_plots = no_optim_plots
         # maximum number of trials for finding the best hyperparams
         self.n_trials = n_trials
+        # number of models trained per trial, the reported result is the median score
+        self.n_models = n_models
         # number of parallel jobs when doing hyperparameter search
         self.n_jobs = n_jobs
         self.sampler = sampler
         self.pruner = pruner
+        assert not (self.n_models > 1 and self.pruner != "none"), "Pruner is not currently supported for multiple models"
         self.n_startup_trials = n_startup_trials
         self.n_evaluations = n_evaluations
         self.deterministic_eval = not self.is_atari(self.env_id)
@@ -649,15 +653,18 @@ def objective(self, trial: optuna.Trial) -> float:
         if self.verbose >= 2:
             trial_verbosity = self.verbose
 
-        model = ALGOS[self.algo](
-            env=env,
-            tensorboard_log=None,
-            # We do not seed the trial
-            seed=None,
-            verbose=trial_verbosity,
-            device=self.device,
-            **kwargs,
-        )
+        models = [
+            ALGOS[self.algo](
+                env=env,
+                tensorboard_log=None,
+                # We do not seed the trial
+                seed=None,
+                verbose=trial_verbosity if model_idx == 0 else 0,
+                device=self.device,
+                **kwargs,
+            )
+            for model_idx in range(self.n_models)
+        ]
 
         eval_env = self.create_envs(n_envs=self.n_eval_envs, eval_env=True)
 
@@ -668,51 +675,58 @@ def objective(self, trial: optuna.Trial) -> float:
         path = None
         if self.optimization_log_path is not None:
             path = os.path.join(self.optimization_log_path, f"trial_{str(trial.number)}")
-        callbacks = get_callback_list({"callback": self.specified_callbacks})
-        eval_callback = TrialEvalCallback(
-            eval_env,
-            trial,
-            best_model_save_path=path,
-            log_path=path,
-            n_eval_episodes=self.n_eval_episodes,
-            eval_freq=optuna_eval_freq,
-            deterministic=self.deterministic_eval,
-        )
-        callbacks.append(eval_callback)
 
-        learn_kwargs = {}
-        # Special case for ARS
-        if self.algo == "ars" and self.n_envs > 1:
-            learn_kwargs["async_eval"] = AsyncEval(
-                [lambda: self.create_envs(n_envs=1, no_log=True) for _ in range(self.n_envs)], model.policy
+        rewards = np.zeros(self.n_models)
+        for model_idx, model in enumerate(models):
+            callbacks = get_callback_list({"callback": self.specified_callbacks})
+            eval_callback = TrialEvalCallback(
+                eval_env,
+                trial,
+                best_model_save_path=path,
+                log_path=path,
+                n_eval_episodes=self.n_eval_episodes,
+                eval_freq=optuna_eval_freq,
+                deterministic=self.deterministic_eval,
             )
+            callbacks.append(eval_callback)
 
-        try:
-            model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs)
-            # Free memory
-            model.env.close()
-            eval_env.close()
-        except (AssertionError, ValueError) as e:
-            # Sometimes, random hyperparams can generate NaN
-            # Free memory
-            model.env.close()
-            eval_env.close()
-            # Prune hyperparams that generate NaNs
-            print(e)
-            print("============")
-            print("Sampled hyperparams:")
-            pprint(sampled_hyperparams)
-            raise optuna.exceptions.TrialPruned()
-        is_pruned = eval_callback.is_pruned
-        reward = eval_callback.last_mean_reward
-
-        del model.env, eval_env
-        del model
-
-        if is_pruned:
-            raise optuna.exceptions.TrialPruned()
-
-        return reward
+            learn_kwargs = {}
+            # Special case for ARS
+            if self.algo == "ars" and self.n_envs > 1:
+                learn_kwargs["async_eval"] = AsyncEval(
+                    [lambda: self.create_envs(n_envs=1, no_log=True) for _ in range(self.n_envs)], model.policy
+                )
+
+            try:
+                model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs)
+                # Free memory
+                model.env.close()
+            except (AssertionError, ValueError) as e:
+                # Sometimes, random hyperparams can generate NaN
+                # Free memory
+                model.env.close()
+                eval_env.close()
+                # Prune hyperparams that generate NaNs
+                print(e)
+                print("============")
+                print("Sampled hyperparams:")
+                pprint(sampled_hyperparams)
+                raise optuna.exceptions.TrialPruned()
+            is_pruned = eval_callback.is_pruned
+            rewards[model_idx] = eval_callback.last_mean_reward
+
+            del model.env
+            del model
+
+            if is_pruned:
+                eval_env.close()
+                del eval_env
+                raise optuna.exceptions.TrialPruned()
+
+        eval_env.close()
+        del eval_env
+
+        return np.median(rewards)
 
     def hyperparameters_optimization(self) -> None:
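A hypothetical invocation of the new option (the other flags are assumed from the existing train.py argument parser; --pruner none is needed because the added assertion disallows pruning when more than one model is trained per trial):

    python train.py --algo ppo --env CartPole-v1 -optimize --n-trials 100 --n-models 3 --pruner none

Each Optuna trial then trains three models with the same sampled hyperparameters and reports the median of their final mean rewards, trading extra training time per trial for a lower-variance objective.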