Skip to content

Resolves #76 (2) #94

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions tests/units/agents/test_agent_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import unittest
from abc import ABCMeta

from urnai.actions.action_space_base import ActionSpaceBase
from urnai.agents.agent_base import AgentBase
from urnai.models.model_base import ModelBase
from urnai.states.state_base import StateBase


class FakeAgent(AgentBase):
AgentBase.__abstractmethods__ = set()

def __init__(self, action_space,
state_space : StateBase,
model : ModelBase,
reward_builder):
super().__init__(action_space, state_space, model, reward_builder)

class FakeState(StateBase):
StateBase.__abstractmethods__ = set()

class FakeModel(ModelBase):
ModelBase.__abstractmethods__ = set()

def ep_reset(self, ep):
return None

class FakeRewardBuilder:
def get_reward(self, obs, reward, done):
return None
def reset(self):
return None

class FakeActionSpace(ActionSpaceBase):
ActionSpaceBase.__abstractmethods__ = set()

class TestAgentBase(unittest.TestCase):

def test_abstract_methods(self):
# GIVEN
fake_agent = FakeAgent(None, None, None, None)

# WHEN
step_return = fake_agent.step()

# THEN
assert isinstance(AgentBase, ABCMeta)
assert step_return is None

def test_reset(self):
# GIVEN
fake_action_space = FakeActionSpace()
fake_state_space = FakeState()
fake_model = FakeModel()
fake_reward_builder = FakeRewardBuilder()
fake_agent = FakeAgent(
fake_action_space,
fake_state_space,
fake_model,
fake_reward_builder
)

# WHEN
reset_return = fake_agent.reset()

# THEN
assert reset_return is None
assert fake_agent.previous_action is None
assert fake_agent.previous_state is None

def test_learn(self):
# GIVEN
fake_model = FakeModel()
fake_agent = FakeAgent(None, None, fake_model, None)

# WHEN
learn_return = fake_agent.learn("obs", "reward", "done")

# THEN
assert learn_return is None

def test_update_state(self):
# GIVEN
fake_state_space = FakeState()
fake_agent = FakeAgent(None, fake_state_space, None, None)

# WHEN
update_state_return = fake_agent.update_state("obs")

# THEN
self.assertEqual(update_state_return, fake_state_space.update("obs"))

def get_reward(self):
# GIVEN
fake_reward_builder = FakeRewardBuilder()
fake_agent = FakeAgent(None, None, None, FakeRewardBuilder)

# WHEN
get_reward_return = fake_agent.get_reward("obs", "reward", "done")

# THEN
self.assertEqual(get_reward_return, fake_reward_builder.get_reward(
"obs", "reward", "done"))

def test_state_dim(self):
# GIVEN
fake_state_space = FakeState()

# WHEN
fake_agent = FakeAgent(None, fake_state_space, None, None)

# THEN
self.assertEqual(fake_agent.state_dim, fake_state_space.dimension)
70 changes: 70 additions & 0 deletions urnai/agents/agent_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from abc import ABC, abstractmethod

from urnai.actions.action_space_base import ActionSpaceBase
from urnai.models.model_base import ModelBase
from urnai.states.state_base import StateBase


class AgentBase(ABC):

def __init__(self, action_space : ActionSpaceBase,
state_space : StateBase,
model : ModelBase,
reward_builder):

self.action_space = action_space
self.state_space = state_space
self.model = model
self.reward_builder = reward_builder

self.previous_action = None
self.previous_state = None

self.attr_block_list = ['model', 'action_space',
'state_space', 'reward_builder']

@abstractmethod
def step(self) -> None:
...

def reset(self, episode=0) -> None:
"""
Resets some Agent class variables, such as previous_action
and previous_state. Also, calls the respective reset methods
for the action_wrapper and model.
"""
self.previous_action = None
self.previous_state = None
self.action_space.reset()
self.model.ep_reset(episode)
self.reward_builder.reset()
self.state_space.reset()

def learn(self, obs, reward, done) -> None:
"""
If it is not the very first step in an episode, this method will
call the model's learn method.
"""
if self.previous_state is not None:
next_state = self.update_state(obs)
self.model.learn(self.previous_state, self.previous_action,
reward, next_state, done)


def update_state(self, obs) -> list:
"""
Returns the state of the game environment
"""
return self.state_space.update(obs)

def get_reward(self, obs, reward, done) -> None:
"""
Calls the get_reward method from the reward_builder, effectivelly
returning the reward value.
"""
return self.reward_builder.get_reward(obs, reward, done)

@property
def state_dim(self) -> int:
"""Returns the dimensions of the state builder"""
return self.state_space.dimension