From 5a2bd8b9756004516907387de9d61489bffb3fbc Mon Sep 17 00:00:00 2001
From: RickFqt
Date: Fri, 16 Aug 2024 11:39:43 -0300
Subject: [PATCH] fix: Fixed choose_action method from agent and model

---
 urnai/agents/agent_base.py    | 12 +++----
 urnai/models/dqn_pytorch.py   | 54 +++--------------------------
 urnai/models/model_base.py    | 11 ------
 urnai/sc2/agents/sc2_agent.py | 64 ++++++++++++++++++++++++++++++++---
 urnai/trainers/trainer.py     | 43 +++++++++++++----------
 5 files changed, 95 insertions(+), 89 deletions(-)

diff --git a/urnai/agents/agent_base.py b/urnai/agents/agent_base.py
index 4772c78..54e17d4 100644
--- a/urnai/agents/agent_base.py
+++ b/urnai/agents/agent_base.py
@@ -28,12 +28,12 @@ def __init__(self, action_space : ActionSpaceBase,
     def step(self) -> None:
         ...
 
-    # @abstractmethod
-    # def choose_action(self, action_space : ActionSpaceBase) -> ActionBase:
-    #     """
-    #     Method that contains the agent's strategy for choosing actions
-    #     """
-    #     ...
+    @abstractmethod
+    def choose_action(self, action_space : ActionSpaceBase) -> int:
+        """
+        Method that contains the agent's strategy for choosing actions
+        """
+        ...
 
     def reset(self, episode=0) -> None:
         """
diff --git a/urnai/models/dqn_pytorch.py b/urnai/models/dqn_pytorch.py
index 389c0a8..1f75593 100644
--- a/urnai/models/dqn_pytorch.py
+++ b/urnai/models/dqn_pytorch.py
@@ -64,12 +64,11 @@ class name
     def __init__(self, action_wrapper: ActionSpaceBase, state_builder: StateBase,
                  gamma=0.99, learning_rate=0.001, learning_rate_min=0.0001,
                  learning_rate_decay=0.99995, learning_rate_decay_ep_cutoff=0,
-                 name='DQNPytorch', epsilon_start=1.0, epsilon_min=0.01,
-                 epsilon_decay_rate=0.995, per_episode_epsilon_decay=False,
+                 name='DQNPytorch',
                  batch_size=64, memory_maxlen=50000, min_memory_size=1000,
                  build_model=ModelBuilder.DEFAULT_BUILD_MODEL, seed_value=None,
-                 cpu_only=False, epsilon_linear_decay=False,
-                 lr_linear_decay=False, epsilon_decay_ep_start=0):
+                 cpu_only=False,
+                 lr_linear_decay=False):
 
         # -----------------------------------------------
         self.seed_value = seed_value
@@ -90,14 +89,6 @@ def __init__(self, action_wrapper: ActionSpaceBase, state_builder: StateBase,
         self.action_size = action_wrapper.size
         self.state_size = state_builder.dimension
 
-        # EXPLORATION PARAMETERS FOR EPSILON GREEDY STRATEGY
-        self.epsilon_greedy = epsilon_start
-        self.epsilon_min = epsilon_min
-        self.epsilon_decay_rate = epsilon_decay_rate
-        self.per_episode_epsilon_decay = per_episode_epsilon_decay
-        self.epsilon_linear_decay = epsilon_linear_decay
-        self.epsilon_decay_ep_start = epsilon_decay_ep_start
-
         # self.tensorboard_callback_logdir = ''
         self.tensorboard_callback = None
         # -----------------------------------------------------------------
@@ -174,9 +165,6 @@ def learn(self, s, a, r, s_, done):
         # to do: add tau to model definition so that it can be passed here
         self.soft_update(self.model, self.target_model)
 
-        if not self.per_episode_epsilon_decay:
-            self.decay_epsilon()
-
     def soft_update(self, local_model, target_model, tau=1e-3):
         """Soft update model parameters.
         θ_target = τ*θ_local + (1 - τ)*θ_target
@@ -191,25 +179,6 @@ def soft_update(self, local_model, target_model, tau=1e-3):
             target_param.data.copy_(
                 tau * local_param.data + (1 - tau) * target_param.data)
 
-    def choose_action(self, state, excluded_actions, is_training=True):
-        """
-        If current epsilon greedy strategy is reached a random action will
-        be returned. If not, self.predict will be called to choose the action
-        with the highest Q-Value.
-        """
-        if not is_training:
-            return self.predict(state, excluded_actions)
-
-        else:
-            if np.random.rand() <= self.epsilon_greedy:
-                random_action = random.choice(self.actions)
-                # Removing excluded actions
-                while random_action in excluded_actions:
-                    random_action = random.choice(self.actions)
-                return random_action
-            else:
-                return self.predict(state, excluded_actions)
-
     def predict(self, state, excluded_actions):
         """Gets the action with the highest Q-value from our DQN PyTorch model"""
         state = torch.from_numpy(state).float().unsqueeze(0).to(device)
@@ -264,11 +233,9 @@ def set_seeds(self):
 
     def ep_reset(self, episode=0):
         """
-        This method is mainly used to enact the decay_epsilon and decay_lr
+        This method is mainly used to enact the decay_lr
        at the end of every episode.
         """
-        if self.per_episode_epsilon_decay and episode >= self.epsilon_decay_ep_start:
-            self.decay_epsilon()
 
         if (episode > self.learning_rate_decay_ep_cutoff
                 and self.learning_rate_decay != 1):
@@ -286,19 +253,6 @@ def decay_lr(self):
         else:
             if self.learning_rate > self.learning_rate_min:
                 self.learning_rate *= self.learning_rate_decay
-
-    def decay_epsilon(self):
-        """
-        Implements the epsilon greedy strategy, effectivelly lowering the current
-        epsilon greedy value by multiplying it by the epsilon_decay_rate
-        (the higher the value, the less it lowers the epsilon_decay).
-        """
-        if self.epsilon_linear_decay:
-            if self.epsilon_greedy > self.epsilon_min:
-                self.epsilon_greedy -= (1 - self.epsilon_decay_rate)
-        else:
-            if self.epsilon_greedy > self.epsilon_min:
-                self.epsilon_greedy *= self.epsilon_decay_rate
 
 
 class QNetwork(nn.Module):
diff --git a/urnai/models/model_base.py b/urnai/models/model_base.py
index 21563bb..a8201dd 100644
--- a/urnai/models/model_base.py
+++ b/urnai/models/model_base.py
@@ -20,17 +20,6 @@ def predict(self, state) -> int:
         """Returns the best action for this given state"""
         ...
 
-    @abstractmethod
-    def choose_action(self, state, excluded_actions, is_testing=False) -> int:
-        """
-        Implements the logic for choosing an action while training and while
-        testing an Agent. For most Reinforcement Learning Algorithms, this method
-        will choose an action directly When is_testing=True, and will implement the
-        exploration algorithm for when is_testing=False.
-        One such exploration algorithm commonly used it the epsilon greedy strategy.
-        """
-        pass
-
     def save(self, persist_path) -> None:
         self.persistence.save(persist_path)
 
diff --git a/urnai/sc2/agents/sc2_agent.py b/urnai/sc2/agents/sc2_agent.py
index ceb72d1..5471f70 100644
--- a/urnai/sc2/agents/sc2_agent.py
+++ b/urnai/sc2/agents/sc2_agent.py
@@ -1,6 +1,9 @@
 import os
+import random
 import sys
 
+import numpy as np
+
 from urnai.actions.action_space_base import ActionSpaceBase
 from urnai.agents.agent_base import AgentBase
 from urnai.models.model_base import ModelBase
@@ -12,15 +15,29 @@ class SC2Agent(AgentBase):
 
     def __init__(self, action_space : ActionSpaceBase, state_builder : StateBase,
-                 model: ModelBase, reward_builder: RewardBase):
+                 model: ModelBase, reward_builder: RewardBase, epsilon_start=1.0,
+                 epsilon_min=0.01, epsilon_decay_rate=0.995,
+                 per_episode_epsilon_decay=False, epsilon_linear_decay=False,
+                 epsilon_decay_ep_start=0):
         super().__init__(action_space, state_builder, model, reward_builder)
         # self.reward = 0
         self.episodes = 0
         self.steps = 0
 
+        # EXPLORATION PARAMETERS FOR EPSILON GREEDY STRATEGY
+        self.epsilon_greedy = epsilon_start
+        self.epsilon_min = epsilon_min
+        self.epsilon_decay_rate = epsilon_decay_rate
+        self.per_episode_epsilon_decay = per_episode_epsilon_decay
+        self.epsilon_linear_decay = epsilon_linear_decay
+        self.epsilon_decay_ep_start = epsilon_decay_ep_start
+
+
     def reset(self, episode=0):
         super().reset(episode)
         self.episodes += 1
+        if self.per_episode_epsilon_decay and episode >= self.epsilon_decay_ep_start:
+            self.decay_epsilon()
 
     def step(self, obs, done, is_training=True):
         self.steps += 1
 
@@ -28,10 +45,9 @@ def step(self, obs, done, is_training=True):
         if self.action_space.is_action_done():
             current_state = self.state_space.update(obs)
             excluded_actions = self.action_space.get_excluded_actions(obs)
-            predicted_action_idx = self.model.choose_action(current_state,
-                                                            excluded_actions,
-                                                            is_training)
-            self.previous_action = predicted_action_idx
+            chosen_action_idx = self.choose_action(current_state, excluded_actions,
+                                                   is_training)
+            self.previous_action = chosen_action_idx
             self.previous_state = current_state
 
         selected_action = [self.action_space.get_action(self.previous_action, obs)]
@@ -42,3 +58,41 @@ def step(self, obs, done, is_training=True):
         # raise error.ActionError(
         #     'Invalid function structure. Function name: %s.' % selected_action[0])
         return selected_action
+
+    def choose_action(self, state, excluded_actions, is_training=True):
+        if is_training:
+            if np.random.rand() <= self.epsilon_greedy:
+                random_action = random.choice(self.action_space.get_actions())
+                # Removing excluded actions
+                while random_action in excluded_actions:
+                    random_action = random.choice(self.action_space.get_actions())
+                return random_action
+            else:
+                return self.model.predict(state, excluded_actions)
+        else:
+            return self.model.predict(state, excluded_actions)
+
+    def learn(self, obs, reward, done) -> None:
+        """
+        If it is not the very first step in an episode, this method will
+        call the model's learn method.
+        """
+        if self.previous_state is not None:
+            next_state = self.state_space.update(obs)
+            self.model.learn(self.previous_state, self.previous_action,
+                             reward, next_state, done)
+        if not self.per_episode_epsilon_decay:
+            self.decay_epsilon()
+
+    def decay_epsilon(self):
+        """
+        Implements the epsilon greedy strategy, effectively lowering the current
+        epsilon greedy value by multiplying it by the epsilon_decay_rate
+        (the higher the value, the less it lowers the epsilon_decay).
+        """
+        if self.epsilon_linear_decay:
+            if self.epsilon_greedy > self.epsilon_min:
+                self.epsilon_greedy -= (1 - self.epsilon_decay_rate)
+        else:
+            if self.epsilon_greedy > self.epsilon_min:
+                self.epsilon_greedy *= self.epsilon_decay_rate
\ No newline at end of file
diff --git a/urnai/trainers/trainer.py b/urnai/trainers/trainer.py
index 8f8660d..2d8193c 100644
--- a/urnai/trainers/trainer.py
+++ b/urnai/trainers/trainer.py
@@ -1,6 +1,7 @@
 import inspect
 import os
 import sys
+from datetime import datetime
 
 currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
 parentdir = os.path.dirname(currentdir)
@@ -11,9 +12,11 @@ class Trainer:
 
     # TODO: Add an option to play every x episodes, instead of just training non-stop
 
-    def __init__(self, env, agent, max_training_episodes, max_playing_episodes,
-                 max_steps_training, max_steps_playing,
-                 ):
+    def __init__(
+            self, env, agent, max_training_episodes, max_playing_episodes,
+            max_steps_training, max_steps_playing, enable_save=True, save_every=10,
+            save_path= None, file_name=None,
+    ):
 
         self.env = env
         self.agent = agent
@@ -21,6 +24,16 @@ def __init__(self, env, agent, max_training_episodes, max_playing_episodes,
         self.max_playing_episodes = max_playing_episodes
         self.max_steps_training = max_steps_training
         self.max_steps_playing = max_steps_playing
+        self.enable_save = enable_save
+        self.save_every = save_every
+        if save_path is None:
+            save_path = os.path.expanduser('~') + os.path.sep + 'urnai_saved_trainings'
+        if file_name is None:
+            file_name=str(datetime.now()).replace(' ','_').replace(':','_').replace('.',
+                                                                                    '_')
+        # self.full_save_path = save_path + os.path.sep + file_name
+        # TODO: Change this to a user defined path
+        self.full_save_path = "saves/"
 
     def train(self, reward_from_agent=True):
         self.training_loop(is_training=True, reward_from_agent=reward_from_agent)
@@ -92,24 +105,20 @@ def training_loop(self, is_training, reward_from_agent=True):
 
                 if done:
                     print("Episode: %d, Reward: %d" % (current_episodes, ep_reward))
-                    self.agent.model.save("saves/")
+                    # self.agent.model.save("saves/")
                     break
 
-            # if this is not a test (evaluation), saving is enabled and we are in a
-            # multiple of our save_every variable then we save the model and generate
-            # graphs
-            # TODO
-            # if is_training \
-            #     and self.enable_save \
-            #     and current_episodes > 0 \
-            #     and current_episodes % self.save_every == 0:
-            #     self.save(self.full_save_path)
+            # if the agent is training, saving is enabled and we are in a
+            # multiple of our save_every variable then we save the model
+            if is_training \
+                and self.enable_save \
+                and current_episodes > 0 \
+                and current_episodes % self.save_every == 0:
+                self.agent.model.save(self.full_save_path)
 
 
         self.env.close()
 
         # Saving the model at the end of the training loop
-        # TODO
-        # if self.enable_save:
-        #     if is_training:
-        #         self.save(self.full_save_path)
+        if self.enable_save and is_training:
+            self.agent.model.save(self.full_save_path)
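
Reviewer note (not part of the patch): the exploration schedule this change moves from DQNPytorch into SC2Agent (choose_action picks a random non-excluded action with probability epsilon during training, decay_epsilon shrinks epsilon linearly or exponentially, and the decay is triggered from learn(), or from reset() when per_episode_epsilon_decay is set) can be sanity-checked in isolation with the standalone Python sketch below. EpsilonGreedySketch and its dummy predict callback are illustrative names only, not urnai APIs, and the sketch filters excluded actions up front instead of resampling in a while loop as the patched code does.

import random

import numpy as np


class EpsilonGreedySketch:
    """Illustrative stand-in for the exploration logic moved into SC2Agent."""

    def __init__(self, epsilon_start=1.0, epsilon_min=0.01,
                 epsilon_decay_rate=0.995, epsilon_linear_decay=False):
        self.epsilon_greedy = epsilon_start
        self.epsilon_min = epsilon_min
        self.epsilon_decay_rate = epsilon_decay_rate
        self.epsilon_linear_decay = epsilon_linear_decay

    def choose_action(self, actions, excluded_actions, predict, is_training=True):
        # Exploit the model when evaluating; otherwise explore with probability epsilon.
        if is_training and np.random.rand() <= self.epsilon_greedy:
            candidates = [a for a in actions if a not in excluded_actions]
            return random.choice(candidates)
        return predict()

    def decay_epsilon(self):
        # Linear decay subtracts a fixed step; exponential decay multiplies by the rate.
        if self.epsilon_greedy > self.epsilon_min:
            if self.epsilon_linear_decay:
                self.epsilon_greedy -= (1 - self.epsilon_decay_rate)
            else:
                self.epsilon_greedy *= self.epsilon_decay_rate


if __name__ == '__main__':
    # Toy usage: five action indices, index 3 excluded, greedy policy always returns 0.
    policy = EpsilonGreedySketch(epsilon_start=1.0, epsilon_decay_rate=0.99)
    for step in range(10):
        action = policy.choose_action(list(range(5)), excluded_actions=[3],
                                      predict=lambda: 0)
        policy.decay_epsilon()  # per-step decay, mirroring the call in SC2Agent.learn()
        print(step, action, round(policy.epsilon_greedy, 4))

With the patch defaults (epsilon_decay_rate=0.995 and exponential decay on every learn() call), epsilon drops from 1.0 to the 0.01 floor after roughly 900 updates; the sketch makes that horizon easy to eyeball before reviewing the full training loop.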