From c25ced0ef6dfeb3e5f145be1f19bacd030d65d2b Mon Sep 17 00:00:00 2001
From: RickFqt
Date: Fri, 13 Dec 2024 12:42:48 -0300
Subject: [PATCH] feat: Added method to alternate training and testing

---
 experiments/solves/solve_collectables_sb3.py | 30 +++-----
 urnai/trainers/stablebaselines3_trainer.py   | 73 +++++++++++++++++---
 2 files changed, 74 insertions(+), 29 deletions(-)

diff --git a/experiments/solves/solve_collectables_sb3.py b/experiments/solves/solve_collectables_sb3.py
index f4f3f6b..b93896a 100644
--- a/experiments/solves/solve_collectables_sb3.py
+++ b/experiments/solves/solve_collectables_sb3.py
@@ -1,14 +1,14 @@
 import os
 import sys
 
+import numpy as np
+
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
 
 from absl import app
 from gymnasium import spaces
 from pysc2.env import sc2_env
 from stable_baselines3 import PPO
-import wandb
-from wandb.integration.sb3 import WandbCallback
 
 from urnai.environments.stablebaselines3.custom_env import CustomEnv
 from urnai.sc2.actions.collectables import CollectablesActionSpace
@@ -28,40 +28,32 @@ def declare_trainer():
 
     # Define action and observation space
     action_space = spaces.Discrete(n=4, start=0)
-    observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=float)
+    observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
 
     # Create the custom environment
     custom_env = CustomEnv(env, state, urnai_action_space, reward, observation_space,
                            action_space)
 
-    # models_dir = "saves/models/DQN"
-    models_dir = "saves/models/PPO"
+    model_name = "PPOMlp"
+    models_dir = f"saves/models/{model_name}"
     logdir = "saves/logs"
 
     conf_dict = {
         "policy":"MlpPolicy",
-        "model_save_name": "PPO"}
-
-    run = wandb.init(
-        project='solve_collectables',
-        config=conf_dict,
-        sync_tensorboard=True, # auto-upload sb3's tensorboard metrics
-    )
+        "model_save_name": model_name}
 
-    model=PPO("MlpPolicy", custom_env, verbose=1, tensorboard_log=logdir)
+    model=PPO("CnnPolicy", custom_env, verbose=1, tensorboard_log=logdir)
 
-    trainer = SB3Trainer(custom_env, models_dir, logdir, model)
+    trainer = SB3Trainer(custom_env, models_dir, logdir, model, model_name,
+                         "solve_collectables", conf_dict)
 
     return trainer
 
 
 def main(unused_argv):
     try:
         trainer = declare_trainer()
-        trainer.train_model(timesteps=10000, reset_num_timesteps=False,
-                            tb_log_name="PPO", repeat_times=30,
-                            callback=WandbCallback())
-        # trainer.load_model(f"{trainer.models_dir}/290000")
-        trainer.test_model(total_steps=10000, deterministic=True)
+        # trainer.load_model(f"{trainer.models_dir}/100000")
+        trainer.alternate_train_test(iterations=100, train_steps=10000, test_steps=1000)
     except KeyboardInterrupt:
         print("Training interrupted by user")
diff --git a/urnai/trainers/stablebaselines3_trainer.py b/urnai/trainers/stablebaselines3_trainer.py
index 9a848c6..bbbc16b 100644
--- a/urnai/trainers/stablebaselines3_trainer.py
+++ b/urnai/trainers/stablebaselines3_trainer.py
@@ -1,14 +1,22 @@
 import os
 
 from stable_baselines3.common.base_class import BaseAlgorithm
-from stable_baselines3.common.type_aliases import MaybeCallback
+
+import wandb
+from urnai.environments.stablebaselines3.custom_env import CustomEnv
+from wandb.integration.sb3 import WandbCallback
 
 
 class SB3Trainer:
-    def __init__(self, custom_env, models_dir, logdir, model : BaseAlgorithm):
+    def __init__(self, custom_env : CustomEnv, models_dir : str, logdir : str,
+                 model : BaseAlgorithm, model_name : str,
+                 wandb_project : str= "default_project", wandb_conf_dict : dict = None):
         self.custom_env = custom_env
         self.models_dir = models_dir
         self.model = model
+        self.model_name = model_name
+        self.wandb_project = wandb_project
+        self.wandb_conf_dict = wandb_conf_dict
 
         if not os.path.exists(models_dir):
             os.makedirs(models_dir)
@@ -19,18 +27,44 @@ def __init__(self, custom_env, models_dir, logdir, model : BaseAlgorithm):
 
     def load_model(self, model_path):
         self.model = self.model.load(model_path, env = self.custom_env)
 
-    def train_model(self, timesteps: int = 10000, callback: MaybeCallback = None,
-                    log_interval: int = 1, tb_log_name: str = "run",
-                    reset_num_timesteps: bool = True, progress_bar: bool = False,
-                    repeat_times: int = 1, start_from: int = 1):
+    def train_model(self, timesteps: int = 10000, log_interval: int = 1,
+                    reset_num_timesteps: bool = False, progress_bar: bool = False,
+                    repeat_times:int = 1, start_from:int = 1, train_run_id:str = None,
+                    ) -> str:
+
+        wandb_run = wandb.init(
+            project=self.wandb_project,
+            config=self.wandb_conf_dict,
+            name=f"train_{self.model_name}",
+            group="train",
+            sync_tensorboard=True,
+            resume="must" if train_run_id else None,
+            id=train_run_id
+        )
+
         for repeat_time in range(repeat_times):
-            self.model.learn(total_timesteps = timesteps, callback = callback,
-                             log_interval = log_interval, tb_log_name = tb_log_name,
+            self.model.learn(total_timesteps = timesteps, callback = WandbCallback(),
+                             log_interval = log_interval,
                              reset_num_timesteps = reset_num_timesteps,
-                             progress_bar = progress_bar)
+                             progress_bar = progress_bar,
+                             tb_log_name = self.model_name)
             self.model.save(f"{self.models_dir}/{timesteps*(repeat_time + start_from)}")
+
+        wandb_run.finish()
+        return wandb_run.id
 
-    def test_model(self, total_steps: int = 10000, deterministic: bool = True):
+    def test_model(self, total_steps: int = 10000,
+                   deterministic: bool = True,
+                   test_run_id : str = None
+                   ) -> str:
+        wandb_run = wandb.init(
+            project=self.wandb_project,
+            name=f"test_{self.model_name}",
+            group="test",
+            resume="must" if test_run_id else None,
+            id=test_run_id
+        )
+
         vec_env = self.model.get_env()
         obs = vec_env.reset()
@@ -44,5 +78,24 @@ def test_model(self, total_steps: int = 10000, deterministic: bool = True):
 
             total_reward += rewards
             if done:
                 total_episodes += 1
+                wandb.log({
+                    "test/total_reward": total_reward,
+                    "test/episode": total_episodes,
+                })
                 total_reward = 0 # Reset reward for the new episode
+
+        wandb_run.finish()
+        return wandb_run.id
+
+    def alternate_train_test(self, iterations : int = 100, train_steps : int = 10000,
+                             train_repeat_times : int = 1, test_steps : int = 10000,
+                             train_run_id = None, test_run_id = None
+                             ) -> None:
+
+        for iteration in range(iterations):
+            train_run_id = self.train_model(timesteps=train_steps,
+                                            repeat_times=train_repeat_times,
+                                            train_run_id=train_run_id,
+                                            start_from=iteration*train_repeat_times +1)
+            test_run_id= self.test_model(total_steps=test_steps,test_run_id=test_run_id)
\ No newline at end of file
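
Note (not part of the commit): below is a minimal, hypothetical sketch of the
wandb resume-by-id pattern that the new train_model / test_model /
alternate_train_test methods rely on. The first call creates a run and returns
its id; later calls pass that id back with resume="must", so every iteration of
the train/test loop keeps logging into the same W&B run instead of creating a
new run per iteration. The project name and the log_block helper are invented
for illustration; only the wandb calls mirror the patch.

    import wandb

    def log_block(values, run_id=None, project="demo_project"):
        """Open (or resume) a W&B run, log the given values, return the run id."""
        run = wandb.init(
            project=project,                    # hypothetical project name
            name="train_demo",
            group="train",
            resume="must" if run_id else None,  # "must" requires the run to already exist
            id=run_id,
        )
        for value in values:
            wandb.log({"demo/value": value})
        run.finish()
        return run.id

    run_id = None
    for iteration in range(3):
        # Each pass resumes the same run, mirroring how alternate_train_test
        # threads train_run_id / test_run_id through train_model and test_model.
        run_id = log_block([iteration, iteration + 1], run_id=run_id)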