
Commit 5a2bd8b

fix: Fixed choose_action method from agent and model
1 parent 63d565a commit 5a2bd8b


5 files changed (+95, −89 lines)


urnai/agents/agent_base.py

Lines changed: 6 additions & 6 deletions
@@ -28,12 +28,12 @@ def __init__(self, action_space : ActionSpaceBase,
     def step(self) -> None:
         ...
 
-    # @abstractmethod
-    # def choose_action(self, action_space : ActionSpaceBase) -> ActionBase:
-    #     """
-    #     Method that contains the agent's strategy for choosing actions
-    #     """
-    #     ...
+    @abstractmethod
+    def choose_action(self, action_space : ActionSpaceBase) -> int:
+        """
+        Method that contains the agent's strategy for choosing actions
+        """
+        ...
 
     def reset(self, episode=0) -> None:
         """

urnai/models/dqn_pytorch.py

Lines changed: 4 additions & 50 deletions
@@ -64,12 +64,11 @@ class name
     def __init__(self, action_wrapper: ActionSpaceBase, state_builder: StateBase,
                  gamma=0.99, learning_rate=0.001, learning_rate_min=0.0001,
                  learning_rate_decay=0.99995, learning_rate_decay_ep_cutoff=0,
-                 name='DQNPytorch', epsilon_start=1.0, epsilon_min=0.01,
-                 epsilon_decay_rate=0.995, per_episode_epsilon_decay=False,
+                 name='DQNPytorch',
                  batch_size=64, memory_maxlen=50000, min_memory_size=1000,
                  build_model=ModelBuilder.DEFAULT_BUILD_MODEL, seed_value=None,
-                 cpu_only=False, epsilon_linear_decay=False,
-                 lr_linear_decay=False, epsilon_decay_ep_start=0):
+                 cpu_only=False,
+                 lr_linear_decay=False):
 
         # -----------------------------------------------
         self.seed_value = seed_value
@@ -90,14 +89,6 @@ def __init__(self, action_wrapper: ActionSpaceBase, state_builder: StateBase,
         self.action_size = action_wrapper.size
         self.state_size = state_builder.dimension
 
-        # EXPLORATION PARAMETERS FOR EPSILON GREEDY STRATEGY
-        self.epsilon_greedy = epsilon_start
-        self.epsilon_min = epsilon_min
-        self.epsilon_decay_rate = epsilon_decay_rate
-        self.per_episode_epsilon_decay = per_episode_epsilon_decay
-        self.epsilon_linear_decay = epsilon_linear_decay
-        self.epsilon_decay_ep_start = epsilon_decay_ep_start
-
         # self.tensorboard_callback_logdir = ''
         self.tensorboard_callback = None
         # -----------------------------------------------------------------
@@ -174,9 +165,6 @@ def learn(self, s, a, r, s_, done):
             # to do: add tau to model definition so that it can be passed here
             self.soft_update(self.model, self.target_model)
 
-        if not self.per_episode_epsilon_decay:
-            self.decay_epsilon()
-
     def soft_update(self, local_model, target_model, tau=1e-3):
         """Soft update model parameters.
         θ_target = τ*θ_local + (1 - τ)*θ_target
@@ -191,25 +179,6 @@ def soft_update(self, local_model, target_model, tau=1e-3):
             target_param.data.copy_(
                 tau * local_param.data + (1 - tau) * target_param.data)
 
-    def choose_action(self, state, excluded_actions, is_training=True):
-        """
-        If current epsilon greedy strategy is reached a random action will
-        be returned. If not, self.predict will be called to choose the action
-        with the highest Q-Value.
-        """
-        if not is_training:
-            return self.predict(state, excluded_actions)
-
-        else:
-            if np.random.rand() <= self.epsilon_greedy:
-                random_action = random.choice(self.actions)
-                # Removing excluded actions
-                while random_action in excluded_actions:
-                    random_action = random.choice(self.actions)
-                return random_action
-            else:
-                return self.predict(state, excluded_actions)
-
     def predict(self, state, excluded_actions):
         """Gets the action with the highest Q-value from our DQN PyTorch model"""
         state = torch.from_numpy(state).float().unsqueeze(0).to(device)
@@ -264,11 +233,9 @@ def set_seeds(self):
 
     def ep_reset(self, episode=0):
         """
-        This method is mainly used to enact the decay_epsilon and decay_lr
+        This method is mainly used to enact the decay_lr
        at the end of every episode.
         """
-        if self.per_episode_epsilon_decay and episode >= self.epsilon_decay_ep_start:
-            self.decay_epsilon()
 
        if (episode > self.learning_rate_decay_ep_cutoff
                and self.learning_rate_decay != 1):
@@ -286,19 +253,6 @@ def decay_lr(self):
         else:
             if self.learning_rate > self.learning_rate_min:
                 self.learning_rate *= self.learning_rate_decay
-
-    def decay_epsilon(self):
-        """
-        Implements the epsilon greedy strategy, effectivelly lowering the current
-        epsilon greedy value by multiplying it by the epsilon_decay_rate
-        (the higher the value, the less it lowers the epsilon_decay).
-        """
-        if self.epsilon_linear_decay:
-            if self.epsilon_greedy > self.epsilon_min:
-                self.epsilon_greedy -= (1 - self.epsilon_decay_rate)
-        else:
-            if self.epsilon_greedy > self.epsilon_min:
-                self.epsilon_greedy *= self.epsilon_decay_rate
 
 
 class QNetwork(nn.Module):
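
Note: with the epsilon-greedy bookkeeping removed, DQNPytorch is now only queried for greedy actions via predict(). The sketch below illustrates the general pattern described by predict() above (tensor conversion, forward pass, masking of excluded actions, argmax); it is not the repository's actual implementation, and the small linear layer merely stands in for the real QNetwork.

import numpy as np
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

state_size, action_size = 10, 4
q_network = nn.Linear(state_size, action_size).to(device)  # stand-in for QNetwork

state = np.zeros(state_size, dtype=np.float32)
excluded_actions = [2]  # indices the agent is not allowed to pick

# Same conversion as the predict() shown above, then greedy selection.
state_t = torch.from_numpy(state).float().unsqueeze(0).to(device)
with torch.no_grad():
    q_values = q_network(state_t).squeeze(0)
q_values[excluded_actions] = float('-inf')  # mask excluded actions
best_action = int(torch.argmax(q_values).item())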

urnai/models/model_base.py

Lines changed: 0 additions & 11 deletions
@@ -20,17 +20,6 @@ def predict(self, state) -> int:
         """Returns the best action for this given state"""
         ...
 
-    @abstractmethod
-    def choose_action(self, state, excluded_actions, is_testing=False) -> int:
-        """
-        Implements the logic for choosing an action while training and while
-        testing an Agent. For most Reinforcement Learning Algorithms, this method
-        will choose an action directly When is_testing=True, and will implement the
-        exploration algorithm for when is_testing=False.
-        One such exploration algorithm commonly used it the epsilon greedy strategy.
-        """
-        pass
-
     def save(self, persist_path) -> None:
         self.persistence.save(persist_path)

urnai/sc2/agents/sc2_agent.py

Lines changed: 59 additions & 5 deletions
@@ -1,6 +1,9 @@
 import os
+import random
 import sys
 
+import numpy as np
+
 from urnai.actions.action_space_base import ActionSpaceBase
 from urnai.agents.agent_base import AgentBase
 from urnai.models.model_base import ModelBase
@@ -12,26 +15,39 @@
 
 class SC2Agent(AgentBase):
     def __init__(self, action_space : ActionSpaceBase, state_builder : StateBase,
-                 model: ModelBase, reward_builder: RewardBase):
+                 model: ModelBase, reward_builder: RewardBase, epsilon_start=1.0,
+                 epsilon_min=0.01, epsilon_decay_rate=0.995,
+                 per_episode_epsilon_decay=False, epsilon_linear_decay=False,
+                 epsilon_decay_ep_start=0):
         super().__init__(action_space, state_builder, model, reward_builder)
         # self.reward = 0
         self.episodes = 0
         self.steps = 0
 
+        # EXPLORATION PARAMETERS FOR EPSILON GREEDY STRATEGY
+        self.epsilon_greedy = epsilon_start
+        self.epsilon_min = epsilon_min
+        self.epsilon_decay_rate = epsilon_decay_rate
+        self.per_episode_epsilon_decay = per_episode_epsilon_decay
+        self.epsilon_linear_decay = epsilon_linear_decay
+        self.epsilon_decay_ep_start = epsilon_decay_ep_start
+
+
     def reset(self, episode=0):
         super().reset(episode)
         self.episodes += 1
+        if self.per_episode_epsilon_decay and episode >= self.epsilon_decay_ep_start:
+            self.decay_epsilon()
 
     def step(self, obs, done, is_training=True):
         self.steps += 1
 
         if self.action_space.is_action_done():
             current_state = self.state_space.update(obs)
             excluded_actions = self.action_space.get_excluded_actions(obs)
-            predicted_action_idx = self.model.choose_action(current_state,
-                                                            excluded_actions,
-                                                            is_training)
-            self.previous_action = predicted_action_idx
+            chosen_action_idx = self.choose_action(current_state, excluded_actions,
+                                                   is_training)
+            self.previous_action = chosen_action_idx
             self.previous_state = current_state
             selected_action = [self.action_space.get_action(self.previous_action, obs)]
 
@@ -42,3 +58,41 @@ def step(self, obs, done, is_training=True):
         # raise error.ActionError(
         #     'Invalid function structure. Function name: %s.' % selected_action[0])
         return selected_action
+
+    def choose_action(self, state, excluded_actions, is_training=True):
+        if is_training:
+            if np.random.rand() <= self.epsilon_greedy:
+                random_action = random.choice(self.action_space.get_actions())
+                # Removing excluded actions
+                while random_action in excluded_actions:
+                    random_action = random.choice(self.action_space.get_actions())
+                return random_action
+            else:
+                return self.model.predict(state, excluded_actions)
+        else:
+            return self.model.predict(state, excluded_actions)
+
+    def learn(self, obs, reward, done) -> None:
+        """
+        If it is not the very first step in an episode, this method will
+        call the model's learn method.
+        """
+        if self.previous_state is not None:
+            next_state = self.state_space.update(obs)
+            self.model.learn(self.previous_state, self.previous_action,
+                             reward, next_state, done)
+        if not self.per_episode_epsilon_decay:
+            self.decay_epsilon()
+
+    def decay_epsilon(self):
+        """
+        Implements the epsilon greedy strategy, effectivelly lowering the current
+        epsilon greedy value by multiplying it by the epsilon_decay_rate
+        (the higher the value, the less it lowers the epsilon_decay).
+        """
+        if self.epsilon_linear_decay:
+            if self.epsilon_greedy > self.epsilon_min:
+                self.epsilon_greedy -= (1 - self.epsilon_decay_rate)
+        else:
+            if self.epsilon_greedy > self.epsilon_min:
+                self.epsilon_greedy *= self.epsilon_decay_rate
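
Note: exploration now lives in the agent, so the epsilon parameters are passed to SC2Agent rather than to the model. A hedged usage sketch follows; the four collaborator objects are placeholders assumed to be built elsewhere, and only the keyword names come from the constructor above. With these defaults, epsilon decays multiplicatively after every learn() call (1.0 → 0.995 → 0.990025 → …); epsilon_linear_decay=True subtracts 1 − epsilon_decay_rate = 0.005 per decay step instead, and per_episode_epsilon_decay=True moves the decay into reset().

from urnai.sc2.agents.sc2_agent import SC2Agent

agent = SC2Agent(
    action_space=action_space,        # an ActionSpaceBase implementation
    state_builder=state_builder,      # a StateBase implementation
    model=model,                      # a ModelBase implementation, e.g. DQNPytorch
    reward_builder=reward_builder,    # a RewardBase implementation
    epsilon_start=1.0,                # start fully exploratory
    epsilon_min=0.01,                 # exploration floor
    epsilon_decay_rate=0.995,         # multiplicative decay factor
    per_episode_epsilon_decay=False,  # False: decay after every learn() call
    epsilon_linear_decay=False,       # True: subtract (1 - decay_rate) instead
    epsilon_decay_ep_start=0,         # episode at which per-episode decay starts
)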

urnai/trainers/trainer.py

Lines changed: 26 additions & 17 deletions
@@ -1,6 +1,7 @@
 import inspect
 import os
 import sys
+from datetime import datetime
 
 currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
 parentdir = os.path.dirname(currentdir)
@@ -11,16 +12,28 @@
 class Trainer:
     # TODO: Add an option to play every x episodes, instead of just training non-stop
 
-    def __init__(self, env, agent, max_training_episodes, max_playing_episodes,
-                 max_steps_training, max_steps_playing,
-                 ):
+    def __init__(
+            self, env, agent, max_training_episodes, max_playing_episodes,
+            max_steps_training, max_steps_playing, enable_save=True, save_every=10,
+            save_path= None, file_name=None,
+    ):
 
         self.env = env
         self.agent = agent
         self.max_training_episodes = max_training_episodes
         self.max_playing_episodes = max_playing_episodes
         self.max_steps_training = max_steps_training
         self.max_steps_playing = max_steps_playing
+        self.enable_save = enable_save
+        self.save_every = save_every
+        if save_path is None:
+            save_path = os.path.expanduser('~') + os.path.sep + 'urnai_saved_trainings'
+        if file_name is None:
+            file_name=str(datetime.now()).replace(' ','_').replace(':','_').replace('.',
+                                                                                    '_')
+        # self.full_save_path = save_path + os.path.sep + file_name
+        # TODO: Change this to a user defined path
+        self.full_save_path = "saves/"
 
     def train(self, reward_from_agent=True):
         self.training_loop(is_training=True, reward_from_agent=reward_from_agent)
@@ -92,24 +105,20 @@ def training_loop(self, is_training, reward_from_agent=True):
 
                 if done:
                     print("Episode: %d, Reward: %d" % (current_episodes, ep_reward))
-                    self.agent.model.save("saves/")
+                    # self.agent.model.save("saves/")
                     break
 
-            # if this is not a test (evaluation), saving is enabled and we are in a
-            # multiple of our save_every variable then we save the model and generate
-            # graphs
-            # TODO
-            # if is_training \
-            #     and self.enable_save \
-            #     and current_episodes > 0 \
-            #     and current_episodes % self.save_every == 0:
-            #     self.save(self.full_save_path)
+            # if the agent is training, saving is enabled and we are in a
+            # multiple of our save_every variable then we save the model
+            if is_training \
+                    and self.enable_save \
+                    and current_episodes > 0 \
+                    and current_episodes % self.save_every == 0:
+                self.agent.model.save(self.full_save_path)
 
 
         self.env.close()
 
         # Saving the model at the end of the training loop
-        # TODO
-        # if self.enable_save:
-        #     if is_training:
-        #         self.save(self.full_save_path)
+        if self.enable_save and is_training:
+            self.agent.model.save(self.full_save_path)
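
Note: a hedged usage sketch of the new saving options; env and agent are placeholders assumed to be built elsewhere. As the TODO above indicates, checkpoints currently always go to the hard-coded "saves/" directory regardless of save_path and file_name.

from urnai.trainers.trainer import Trainer

trainer = Trainer(
    env, agent,
    max_training_episodes=100,
    max_playing_episodes=10,
    max_steps_training=1000,
    max_steps_playing=1000,
    enable_save=True,  # save during training
    save_every=10,     # checkpoint whenever current_episodes % save_every == 0
)
trainer.train()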
