Multiple models for optimization #225

Draft
wants to merge 9 commits into base: master
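This draft adds an --n-models option to the hyperparameter search: each Optuna trial trains n_models models with the same sampled hyperparameters, and the value reported to Optuna is the median of their final evaluation rewards. Pruning is not currently supported when more than one model is trained per trial.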
train.py: 2 additions, 0 deletions
@@ -64,6 +64,7 @@
type=int,
default=500,
)
parser.add_argument("--n-models", help="Number of models for optimizing hyperparameters.", type=int, default=1)
parser.add_argument(
"-optimize", "--optimize-hyperparameters", action="store_true", default=False, help="Run hyperparameters search"
)
@@ -201,6 +202,7 @@
args.storage,
args.study_name,
args.n_trials,
+args.n_models,
args.n_jobs,
args.sampler,
args.pruner,
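As a minimal, self-contained sketch of how the new flag behaves on the command line: only the --n-models definition below is taken from this diff, the rest is illustrative scaffolding rather than a reproduction of train.py, and the help text is paraphrased.

    # Standalone sketch: only the --n-models argument comes from this diff;
    # the parser and the parse_args() call are illustrative scaffolding.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--n-models",
        help="Number of models trained per trial when optimizing hyperparameters.",
        type=int,
        default=1,
    )

    if __name__ == "__main__":
        args = parser.parse_args(["--n-models", "3"])
        # argparse exposes --n-models as args.n_models, which train.py then
        # forwards to ExperimentManager as its n_models argument.
        print(args.n_models)  # -> 3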
utils/exp_manager.py: 65 additions, 51 deletions
@@ -73,6 +73,7 @@ def __init__(
storage: Optional[str] = None,
study_name: Optional[str] = None,
n_trials: int = 1,
+n_models: int = 1,
n_jobs: int = 1,
sampler: str = "tpe",
pruner: str = "median",
@@ -133,10 +134,13 @@ def __init__(
self.no_optim_plots = no_optim_plots
# maximum number of trials for finding the best hyperparams
self.n_trials = n_trials
+# number of models trained per trial; the trial result is the median of their scores
+self.n_models = n_models
# number of parallel jobs when doing hyperparameter search
self.n_jobs = n_jobs
self.sampler = sampler
self.pruner = pruner
+assert not (self.n_models > 1 and self.pruner != "none"), "Pruner is not currently supported for multiple models"
self.n_startup_trials = n_startup_trials
self.n_evaluations = n_evaluations
self.deterministic_eval = not self.is_atari(self.env_id)
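The median over all models is only known once every model has finished training, so intermediate pruning does not fit naturally here, and the constructor rejects the combination. A rough standalone sketch of the same guard follows, reusing the n_models and pruner names from this diff; the ValueError variant is an illustrative alternative to the assert, not what the PR does.

    # Hypothetical standalone guard mirroring the assert added in __init__.
    def check_multi_model_config(n_models: int, pruner: str) -> None:
        if n_models > 1 and pruner != "none":
            raise ValueError(
                "Pruning is not supported when training multiple models per trial; "
                "set the pruner to 'none' or keep n_models at 1."
            )

    check_multi_model_config(n_models=1, pruner="median")    # fine
    check_multi_model_config(n_models=3, pruner="none")      # fine
    # check_multi_model_config(n_models=3, pruner="median")  # would raise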
@@ -649,15 +653,18 @@ def objective(self, trial: optuna.Trial) -> float:
        if self.verbose >= 2:
            trial_verbosity = self.verbose

-        model = ALGOS[self.algo](
-            env=env,
-            tensorboard_log=None,
-            # We do not seed the trial
-            seed=None,
-            verbose=trial_verbosity,
-            device=self.device,
-            **kwargs,
-        )
+        models = [
+            ALGOS[self.algo](
+                env=env,
+                tensorboard_log=None,
+                # We do not seed the trial
+                seed=None,
+                verbose=trial_verbosity if model_idx == 0 else 0,
+                device=self.device,
+                **kwargs,
+            )
+            for model_idx in range(self.n_models)
+        ]

        eval_env = self.create_envs(n_envs=self.n_eval_envs, eval_env=True)
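Note that all models in a trial are built from the same sampled hyperparameters and share a single evaluation environment; only the first model keeps the requested verbosity, presumably so the training log is not repeated for every copy.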

@@ -668,51 +675,58 @@ def objective(self, trial: optuna.Trial) -> float:
        path = None
        if self.optimization_log_path is not None:
            path = os.path.join(self.optimization_log_path, f"trial_{str(trial.number)}")
-        callbacks = get_callback_list({"callback": self.specified_callbacks})
-        eval_callback = TrialEvalCallback(
-            eval_env,
-            trial,
-            best_model_save_path=path,
-            log_path=path,
-            n_eval_episodes=self.n_eval_episodes,
-            eval_freq=optuna_eval_freq,
-            deterministic=self.deterministic_eval,
-        )
-        callbacks.append(eval_callback)
-
-        learn_kwargs = {}
-        # Special case for ARS
-        if self.algo == "ars" and self.n_envs > 1:
-            learn_kwargs["async_eval"] = AsyncEval(
-                [lambda: self.create_envs(n_envs=1, no_log=True) for _ in range(self.n_envs)], model.policy
-            )
-
-        try:
-            model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs)
-            # Free memory
-            model.env.close()
-            eval_env.close()
-        except (AssertionError, ValueError) as e:
-            # Sometimes, random hyperparams can generate NaN
-            # Free memory
-            model.env.close()
-            eval_env.close()
-            # Prune hyperparams that generate NaNs
-            print(e)
-            print("============")
-            print("Sampled hyperparams:")
-            pprint(sampled_hyperparams)
-            raise optuna.exceptions.TrialPruned()
-        is_pruned = eval_callback.is_pruned
-        reward = eval_callback.last_mean_reward
-
-        del model.env, eval_env
-        del model
-
-        if is_pruned:
-            raise optuna.exceptions.TrialPruned()
-
-        return reward
+        rewards = np.zeros(self.n_models)
+        for model_idx, model in enumerate(models):
+            callbacks = get_callback_list({"callback": self.specified_callbacks})
+            eval_callback = TrialEvalCallback(
+                eval_env,
+                trial,
+                best_model_save_path=path,
+                log_path=path,
+                n_eval_episodes=self.n_eval_episodes,
+                eval_freq=optuna_eval_freq,
+                deterministic=self.deterministic_eval,
+            )
+            callbacks.append(eval_callback)
+
+            learn_kwargs = {}
+            # Special case for ARS
+            if self.algo == "ars" and self.n_envs > 1:
+                learn_kwargs["async_eval"] = AsyncEval(
+                    [lambda: self.create_envs(n_envs=1, no_log=True) for _ in range(self.n_envs)], model.policy
+                )
+
+            try:
+                model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs)
+                # Free memory
+                model.env.close()
+            except (AssertionError, ValueError) as e:
+                # Sometimes, random hyperparams can generate NaN
+                # Free memory
+                model.env.close()
+                eval_env.close()
+                # Prune hyperparams that generate NaNs
+                print(e)
+                print("============")
+                print("Sampled hyperparams:")
+                pprint(sampled_hyperparams)
+                raise optuna.exceptions.TrialPruned()
+            is_pruned = eval_callback.is_pruned
+            rewards[model_idx] = eval_callback.last_mean_reward
+
+            del model.env
+            del model
+
+            if is_pruned:
+                eval_env.close()
+                del eval_env
+                raise optuna.exceptions.TrialPruned()
+
+        eval_env.close()
+        del eval_env
+
+        return np.median(rewards)

    def hyperparameters_optimization(self) -> None:

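Condensed to its essentials, the change turns each Optuna trial into an aggregate over several independent training runs. The following is a self-contained sketch of that pattern outside the zoo, assuming PPO on CartPole-v1 and a made-up two-parameter search space; none of the hyperparameter ranges or budgets are taken from the repository's actual tuning setup.

    # Sketch of "train N models per trial, report the median reward".
    # PPO, CartPole-v1, the search space and the budgets are illustrative
    # assumptions; only the median-over-models idea comes from this PR.
    import numpy as np
    import optuna
    from stable_baselines3 import PPO
    from stable_baselines3.common.evaluation import evaluate_policy

    N_MODELS = 3  # would come from the new --n-models flag


    def objective(trial: optuna.Trial) -> float:
        # One set of hyperparameters is sampled per trial and shared by all models
        learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
        gamma = trial.suggest_float("gamma", 0.95, 0.9999)

        rewards = np.zeros(N_MODELS)
        for model_idx in range(N_MODELS):
            # Each model is an independent (differently seeded) run
            model = PPO(
                "MlpPolicy",
                "CartPole-v1",
                learning_rate=learning_rate,
                gamma=gamma,
                verbose=0,
            )
            model.learn(total_timesteps=10_000)
            mean_reward, _ = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
            rewards[model_idx] = mean_reward
            model.env.close()

        # The median is robust to a single diverging or unusually lucky run
        return float(np.median(rewards))


    if __name__ == "__main__":
        # Pruning is disabled, matching the assert added to ExperimentManager
        study = optuna.create_study(direction="maximize", pruner=optuna.pruners.NopPruner())
        study.optimize(objective, n_trials=10)
        print("Best hyperparameters:", study.best_params)

Evaluating on the training environment is a simplification made for brevity here; the PR itself evaluates on a separate eval_env through a TrialEvalCallback, as shown in the diff above.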