diff --git a/coltra/trainers.py b/coltra/trainers.py index b8d1e8c4..e081ae0c 100644 --- a/coltra/trainers.py +++ b/coltra/trainers.py @@ -185,6 +185,7 @@ def train( trial.report(mean_reward, step) if trial.should_prune(): print("Trial was pruned at step {}".format(step)) + self.env.close() raise optuna.TrialPruned() return metrics diff --git a/scripts/optimize/optuna_crowd.py b/scripts/optimize/optuna_crowd.py index 37f8de7c..3e6e9107 100644 --- a/scripts/optimize/optuna_crowd.py +++ b/scripts/optimize/optuna_crowd.py @@ -153,7 +153,7 @@ def objective(trial: optuna.Trial, worker_id: int, path: str) -> float: disable_tqdm=False, save_path=trainer.path, collect_kwargs=config["environment"], - trial=trial, + # trial=trial, ) env.close() @@ -166,7 +166,7 @@ def objective(trial: optuna.Trial, worker_id: int, path: str) -> float: file_name=args.env, virtual_display=(1600, 900), no_graphics=False, - worker_id=worker_id, + worker_id=worker_id+5, ) config["environment"]["evaluation_mode"] = 1.0 env.reset(**config["environment"]) @@ -217,21 +217,22 @@ def objective(trial: optuna.Trial, worker_id: int, path: str) -> float: # Generate the dashboard print("Generating dashboard") - trajectory = du.read_trajectory(trajectory_path) - - plt.clf() - du.make_dashboard(trajectory, save_path=dashboard_path) + print("Skipping dashboard") + # trajectory = du.read_trajectory(trajectory_path) + # + # plt.clf() + # du.make_dashboard(trajectory, save_path=dashboard_path) # Upload to wandb - print("Uploading dashboard") - wandb.log( - { - "dashboard": wandb.Image( - dashboard_path, - caption=f"Dashboard {mode} {'det' if d else 'rng'} {i}", - ) - } - ) + # print("Uploading dashboard") + # wandb.log( + # { + # "dashboard": wandb.Image( + # dashboard_path, + # caption=f"Dashboard {mode} {'det' if d else 'rng'} {i}", + # ) + # } + # ) frame_size = renders.shape[1:3] diff --git a/scripts/optimize/optuna_setup.py b/scripts/optimize/optuna_setup.py index b28c0ced..825ba05f 100644 --- a/scripts/optimize/optuna_setup.py +++ b/scripts/optimize/optuna_setup.py @@ -30,6 +30,6 @@ class Parser(BaseParser): study = optuna.create_study( storage=f"sqlite:///{args.name}.db", study_name=args.name, - pruner=optuna.pruners.HyperbandPruner(), + pruner=None, direction="maximize", )