Skip to content

Commit 2fb8cc0

Browse files
Resolves #76 (2) (#94)
* Merge branch 'issue-76' of https://github.com/UFRN-URNAI/urnai-tools into issue-76
* refactor(issue-76): Adjusted to reviewer's suggestions
* feat(issue-74): Add method for choosing action in Agent class #74
* feat(issue-76): Remove pointless methods in Agent class #76
1 parent 9d80743 commit 2fb8cc0

File tree

2 files changed

+141
-0
lines changed

2 files changed

+141
-0
lines changed

tests/units/agents/test_agent_base.py

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import unittest
2+
from abc import ABCMeta
3+
4+
from urnai.actions.action_space_base import ActionSpaceBase
5+
from urnai.agents.agent_base import AgentBase
6+
from urnai.models.model_base import ModelBase
7+
from urnai.states.state_base import StateBase
8+
9+
10+
class FakeAgent(AgentBase):
    """Instantiable stand-in for AgentBase used by the tests below.

    It overrides nothing except the constructor, so calls to ``step`` and
    ``choose_action`` run the bodies inherited from AgentBase.
    """

    def __init__(self, action_space,
                 state_space: StateBase,
                 model: ModelBase,
                 reward):
        super().__init__(action_space, state_space, model, reward)


# Lift the abstract-method restriction on the fake only. Assigning to
# AgentBase.__abstractmethods__ instead would make the real base class
# instantiable for every other test running in the same process.
FakeAgent.__abstractmethods__ = frozenset()
18+
19+
class FakeState(StateBase):
    """Instantiable stand-in for StateBase used by the agent tests."""


# Clear the abstract-method set on the fake only, so StateBase itself keeps
# its abstract contract for the rest of the test session.
FakeState.__abstractmethods__ = frozenset()
21+
22+
class FakeModel(ModelBase):
    """Instantiable stand-in for ModelBase used by the agent tests."""

    def ep_reset(self, ep):
        """No-op episode reset; AgentBase.reset() calls this on its model."""
        return None


# Clear the abstract-method set on the fake only, so ModelBase itself keeps
# its abstract contract for the rest of the test session.
FakeModel.__abstractmethods__ = frozenset()
27+
28+
class FakeReward:
    """Minimal reward stub whose methods do nothing and return None."""

    def get_reward(self, obs, reward, done):
        """Ignore every argument; implicitly return None."""

    def reset(self):
        """Nothing to reset; implicitly return None."""
33+
34+
class FakeActionSpace(ActionSpaceBase):
    """Instantiable stand-in for ActionSpaceBase used by the agent tests."""


# Clear the abstract-method set on the fake only, so ActionSpaceBase itself
# keeps its abstract contract for the rest of the test session.
FakeActionSpace.__abstractmethods__ = frozenset()
36+
37+
class TestAgentBase(unittest.TestCase):
    """Unit tests for the behaviour AgentBase implements itself."""

    def test_abstract_methods(self):
        """The abstract hooks return None and AgentBase is an ABC."""
        # GIVEN
        agent = FakeAgent(None, None, None, None)

        # WHEN
        step_result = agent.step()
        chosen_action = agent.choose_action(FakeActionSpace())

        # THEN
        assert isinstance(AgentBase, ABCMeta)
        assert step_result is None
        assert chosen_action is None

    def test_reset(self):
        """reset() returns None and clears the previous action/state."""
        # GIVEN
        agent = FakeAgent(
            FakeActionSpace(),
            FakeState(),
            FakeModel(),
            FakeReward(),
        )

        # WHEN
        reset_result = agent.reset()

        # THEN
        assert reset_result is None
        assert agent.previous_action is None
        assert agent.previous_state is None

    def test_learn(self):
        """learn() returns None when no previous state has been recorded."""
        # GIVEN
        agent = FakeAgent(None, None, FakeModel(), None)

        # WHEN
        learn_result = agent.learn("obs", "reward", "done")

        # THEN
        assert learn_result is None

urnai/agents/agent_base.py

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from abc import ABC, abstractmethod
2+
3+
from urnai.actions.action_base import ActionBase
4+
from urnai.actions.action_space_base import ActionSpaceBase
5+
from urnai.models.model_base import ModelBase
6+
from urnai.states.state_base import StateBase
7+
8+
9+
class AgentBase(ABC):
    """
    Abstract base class for agents.

    Holds the action space, state space, model and reward objects, remembers
    the previous action and state, and implements the shared ``reset`` and
    ``learn`` logic. Subclasses must implement ``step`` and
    ``choose_action``.
    """

    def __init__(self, action_space: ActionSpaceBase,
                 state_space: StateBase,
                 model: ModelBase,
                 reward):
        """
        Store the collaborators and start with no previous action/state.
        """
        self.action_space = action_space
        self.state_space = state_space
        self.model = model
        self.reward = reward

        self.previous_action = None
        self.previous_state = None

        # NOTE(review): presumably the attribute names to skip when the
        # agent is saved/loaded -- nothing in this class reads it; confirm.
        self.attr_block_list = ['model', 'action_space',
                                'state_space', 'reward']

    @abstractmethod
    def step(self) -> None:
        """
        Method to be implemented by subclasses; the base body is empty.
        """
        ...

    @abstractmethod
    def choose_action(self, action_space: ActionSpaceBase) -> ActionBase:
        """
        Method that contains the agent's strategy for choosing actions
        """
        ...

    def reset(self, episode=0) -> None:
        """
        Resets some Agent class variables, such as previous_action
        and previous_state. Also, calls the respective reset methods
        for the action_space, model (ep_reset, with the given episode),
        reward and state_space.
        """
        self.previous_action = None
        self.previous_state = None
        self.action_space.reset()
        self.model.ep_reset(episode)
        self.reward.reset()
        self.state_space.reset()

    def learn(self, obs, reward, done) -> None:
        """
        If it is not the very first step in an episode, this method will
        call the model's learn method with the previous state, the previous
        action, the reward, the next state built from obs, and done.
        """
        if self.previous_state is None:
            # Very first step of the episode: nothing to learn from yet.
            return

        # AgentBase itself defines no update_state method, so calling it
        # unconditionally raised AttributeError. Keep honouring the hook when
        # a subclass provides one; otherwise ask the state space.
        state_hook = getattr(self, 'update_state', None)
        if callable(state_hook):
            next_state = state_hook(obs)
        else:
            # NOTE(review): assumes StateBase exposes update(obs) returning
            # the new state -- TODO confirm against urnai.states.state_base.
            next_state = self.state_space.update(obs)

        self.model.learn(self.previous_state, self.previous_action,
                         reward, next_state, done)

0 commit comments

Comments
 (0)