From 6579c81a61d05239b8024c95c8ae4a0ae7da5fd8 Mon Sep 17 00:00:00 2001 From: CinquilCinquil <106356391+CinquilCinquil@users.noreply.github.com> Date: Wed, 27 Nov 2024 21:07:50 -0300 Subject: [PATCH] feat: Wandb integration #112 Note: lint formatting issues remain unresolved. --- experiments/solves/solve_collectables_sb3.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/experiments/solves/solve_collectables_sb3.py b/experiments/solves/solve_collectables_sb3.py index 55e64bf..f4f3f6b 100644 --- a/experiments/solves/solve_collectables_sb3.py +++ b/experiments/solves/solve_collectables_sb3.py @@ -7,6 +7,8 @@ from gymnasium import spaces from pysc2.env import sc2_env from stable_baselines3 import PPO +import wandb +from wandb.integration.sb3 import WandbCallback from urnai.environments.stablebaselines3.custom_env import CustomEnv from urnai.sc2.actions.collectables import CollectablesActionSpace @@ -36,6 +38,16 @@ def declare_trainer(): models_dir = "saves/models/PPO" logdir = "saves/logs" + conf_dict = { + "policy":"MlpPolicy", + "model_save_name": "PPO"} + + run = wandb.init( + project='solve_collectables', + config=conf_dict, + sync_tensorboard=True, # auto-upload sb3's tensorboard metrics + ) + model=PPO("MlpPolicy", custom_env, verbose=1, tensorboard_log=logdir) trainer = SB3Trainer(custom_env, models_dir, logdir, model) @@ -46,7 +58,8 @@ def main(unused_argv): try: trainer = declare_trainer() trainer.train_model(timesteps=10000, reset_num_timesteps=False, - tb_log_name="PPO", repeat_times=30) + tb_log_name="PPO", repeat_times=30, + callback=WandbCallback()) # trainer.load_model(f"{trainer.models_dir}/290000") # trainer.test_model(total_steps=10000, deterministic=True) except KeyboardInterrupt: