feat: Added method to alternate training and testing
RickFqt committed Dec 13, 2024
1 parent 86b4dd5 commit c25ced0
Showing 2 changed files with 74 additions and 29 deletions.
30 changes: 11 additions & 19 deletions experiments/solves/solve_collectables_sb3.py
@@ -1,14 +1,14 @@
import os
import sys

import numpy as np

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

from absl import app
from gymnasium import spaces
from pysc2.env import sc2_env
from stable_baselines3 import PPO
import wandb
from wandb.integration.sb3 import WandbCallback

from urnai.environments.stablebaselines3.custom_env import CustomEnv
from urnai.sc2.actions.collectables import CollectablesActionSpace
@@ -28,40 +28,32 @@ def declare_trainer():

# Define action and observation space
action_space = spaces.Discrete(n=4, start=0)
observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=float)
observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)

# Create the custom environment
custom_env = CustomEnv(env, state, urnai_action_space, reward, observation_space,
action_space)

# models_dir = "saves/models/DQN"
models_dir = "saves/models/PPO"
model_name = "PPOMlp"
models_dir = f"saves/models/{model_name}"
logdir = "saves/logs"

conf_dict = {
"policy":"MlpPolicy",
"model_save_name": "PPO"}

run = wandb.init(
project='solve_collectables',
config=conf_dict,
sync_tensorboard=True, # auto-upload sb3's tensorboard metrics
)
"model_save_name": model_name}

model=PPO("MlpPolicy", custom_env, verbose=1, tensorboard_log=logdir)
model=PPO("CnnPolicy", custom_env, verbose=1, tensorboard_log=logdir)

trainer = SB3Trainer(custom_env, models_dir, logdir, model)
trainer = SB3Trainer(custom_env, models_dir, logdir, model, model_name,
"solve_collectables", conf_dict)

return trainer

def main(unused_argv):
try:
trainer = declare_trainer()
trainer.train_model(timesteps=10000, reset_num_timesteps=False,
tb_log_name="PPO", repeat_times=30,
callback=WandbCallback())
# trainer.load_model(f"{trainer.models_dir}/290000")
# trainer.test_model(total_steps=10000, deterministic=True)
# trainer.load_model(f"{trainer.models_dir}/100000")
trainer.alternate_train_test(iterations=100, train_steps=10000, test_steps=1000)
except KeyboardInterrupt:
print("Training interrupted by user")
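
Side note on the observation-space change above: stable-baselines3 only treats a Box space as an image (and hence routes it through the CNN feature extractor behind CnnPolicy) when the dtype is np.uint8 with 0-255 bounds, which is likely why dtype=float becomes dtype=np.uint8 together with the switch from MlpPolicy to CnnPolicy. A minimal standalone sketch of that check using SB3's is_image_space helper (not part of the commit):

import numpy as np
from gymnasium import spaces
from stable_baselines3.common.preprocessing import is_image_space

# Old space: float dtype, so SB3 does not recognise it as an image.
float_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=float)
# New space: uint8 with 0-255 bounds, recognised as an image and usable with CnnPolicy.
image_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)

print(is_image_space(float_space))  # False
print(is_image_space(image_space))  # True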

73 changes: 63 additions & 10 deletions urnai/trainers/stablebaselines3_trainer.py
@@ -1,14 +1,22 @@
import os

from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.type_aliases import MaybeCallback

import wandb
from urnai.environments.stablebaselines3.custom_env import CustomEnv
from wandb.integration.sb3 import WandbCallback


class SB3Trainer:
def __init__(self, custom_env, models_dir, logdir, model : BaseAlgorithm):
def __init__(self, custom_env : CustomEnv, models_dir : str, logdir : str,
model : BaseAlgorithm, model_name : str,
wandb_project : str= "default_project", wandb_conf_dict : dict = None):
self.custom_env = custom_env
self.models_dir = models_dir
self.model = model
self.model_name = model_name
self.wandb_project = wandb_project
self.wandb_conf_dict = wandb_conf_dict

if not os.path.exists(models_dir):
os.makedirs(models_dir)
@@ -19,18 +27,44 @@ def __init__(self, custom_env, models_dir, logdir, model : BaseAlgorithm):
def load_model(self, model_path):
self.model = self.model.load(model_path, env = self.custom_env)

def train_model(self, timesteps: int = 10000, callback: MaybeCallback = None,
log_interval: int = 1, tb_log_name: str = "run",
reset_num_timesteps: bool = True, progress_bar: bool = False,
repeat_times: int = 1, start_from: int = 1):
def train_model(self, timesteps: int = 10000, log_interval: int = 1,
reset_num_timesteps: bool = False, progress_bar: bool = False,
repeat_times:int = 1, start_from:int = 1, train_run_id:str = None,
) -> str:

wandb_run = wandb.init(
project=self.wandb_project,
config=self.wandb_conf_dict,
name=f"train_{self.model_name}",
group="train",
sync_tensorboard=True,
resume="must" if train_run_id else None,
id=train_run_id
)

for repeat_time in range(repeat_times):
self.model.learn(total_timesteps = timesteps, callback = callback,
log_interval = log_interval, tb_log_name = tb_log_name,
self.model.learn(total_timesteps = timesteps, callback = WandbCallback(),
log_interval = log_interval,
reset_num_timesteps = reset_num_timesteps,
progress_bar = progress_bar)
progress_bar = progress_bar,
tb_log_name = self.model_name)
self.model.save(f"{self.models_dir}/{timesteps*(repeat_time + start_from)}")

wandb_run.finish()
return wandb_run.id

def test_model(self, total_steps: int = 10000, deterministic: bool = True):
def test_model(self, total_steps: int = 10000,
deterministic: bool = True,
test_run_id : str = None
) -> str:
wandb_run = wandb.init(
project=self.wandb_project,
name=f"test_{self.model_name}",
group="test",
resume="must" if test_run_id else None,
id=test_run_id
)

vec_env = self.model.get_env()
obs = vec_env.reset()

@@ -44,5 +78,24 @@ def test_model(self, total_steps: int = 10000, deterministic: bool = True):
total_reward += rewards
if done:
total_episodes += 1
wandb.log({
"test/total_reward": total_reward,
"test/episode": total_episodes,
})
total_reward = 0 # Reset reward for the new episode

wandb_run.finish()
return wandb_run.id

def alternate_train_test(self, iterations : int = 100, train_steps : int = 10000,
train_repeat_times : int = 1, test_steps : int = 10000,
train_run_id = None, test_run_id = None
) -> None:

for iteration in range(iterations):
train_run_id = self.train_model(timesteps=train_steps,
repeat_times=train_repeat_times,
train_run_id=train_run_id,
start_from=iteration*train_repeat_times +1)
test_run_id= self.test_model(total_steps=test_steps,test_run_id=test_run_id)
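
For reference, a minimal usage sketch of the new alternation loop. The environment, model and conf_dict construction are elided (see the solve script above for the real setup); the variable names are illustrative, and the argument values mirror solve_collectables_sb3.py:

# Assumes custom_env, model and conf_dict are built as in solve_collectables_sb3.py.
trainer = SB3Trainer(custom_env, "saves/models/PPOMlp", "saves/logs",
                     model, "PPOMlp", "solve_collectables", conf_dict)

# 100 cycles of 10k training steps followed by 1k evaluation steps.
# train_model/test_model return their wandb run ids and alternate_train_test
# feeds them back in, so every cycle resumes the same "train" and "test" runs.
trainer.alternate_train_test(iterations=100, train_steps=10000, test_steps=1000)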
