Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolves #76 (2) #94

Merged
merged 4 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions tests/units/agents/test_agent_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import unittest
from abc import ABCMeta

from urnai.actions.action_space_base import ActionSpaceBase
from urnai.agents.agent_base import AgentBase
from urnai.models.model_base import ModelBase
from urnai.states.state_base import StateBase


class FakeAgent(AgentBase):
    """Minimal AgentBase subclass used as a test double in this module."""

    # Executed while the class body runs, i.e. before the FakeAgent class
    # object is built: emptying the base's abstract-method set means
    # FakeAgent inherits no abstract methods and can be instantiated.
    AgentBase.__abstractmethods__ = set()

    def __init__(
        self,
        action_space,
        state_space: StateBase,
        model: ModelBase,
        reward,
    ):
        """Pass all four collaborators straight through to AgentBase."""
        super().__init__(action_space, state_space, model, reward)

class FakeState(StateBase):
    """Instantiable StateBase stand-in with no behaviour of its own."""

    # Empty the base's abstract-method set so this subclass can be
    # instantiated without implementing anything.
    StateBase.__abstractmethods__ = set()

class FakeModel(ModelBase):
    """Instantiable ModelBase stand-in whose episode reset does nothing."""

    # Empty the base's abstract-method set so this subclass can be
    # instantiated without implementing the abstract interface.
    ModelBase.__abstractmethods__ = set()

    def ep_reset(self, ep):
        """Ignore the episode number; implicitly return None."""

class FakeReward:
    """Reward stand-in: every method is a no-op that returns None."""

    def get_reward(self, obs, reward, done):
        """Ignore all arguments; implicitly return None."""

    def reset(self):
        """Do nothing; implicitly return None."""

class FakeActionSpace(ActionSpaceBase):
    """Instantiable ActionSpaceBase stand-in with no behaviour of its own."""

    # Empty the base's abstract-method set so this subclass can be
    # instantiated without implementing anything.
    ActionSpaceBase.__abstractmethods__ = set()

class TestAgentBase(unittest.TestCase):
    """Unit tests for the concrete and abstract members of AgentBase."""

    def test_abstract_methods(self):
        """The abstract method bodies are callable stubs returning None."""
        # GIVEN
        agent = FakeAgent(None, None, None, None)

        # WHEN
        step_result = agent.step()
        action_result = agent.choose_action(FakeActionSpace())

        # THEN
        self.assertIsInstance(AgentBase, ABCMeta)
        self.assertIsNone(step_result)
        self.assertIsNone(action_result)

    def test_reset(self):
        """reset() returns None and leaves the previous action/state as None."""
        # GIVEN
        agent = FakeAgent(
            FakeActionSpace(),
            FakeState(),
            FakeModel(),
            FakeReward(),
        )

        # WHEN
        result = agent.reset()

        # THEN
        self.assertIsNone(result)
        self.assertIsNone(agent.previous_action)
        self.assertIsNone(agent.previous_state)

    def test_learn(self):
        """learn() returns None on a freshly constructed agent."""
        # GIVEN
        agent = FakeAgent(None, None, FakeModel(), None)

        # WHEN
        result = agent.learn("obs", "reward", "done")

        # THEN
        self.assertIsNone(result)
59 changes: 59 additions & 0 deletions urnai/agents/agent_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from abc import ABC, abstractmethod

from urnai.actions.action_base import ActionBase
from urnai.actions.action_space_base import ActionSpaceBase
from urnai.models.model_base import ModelBase
from urnai.states.state_base import StateBase


class AgentBase(ABC):
    """
    Abstract base class for agents.

    An agent holds an action space, a state space, a model and a reward
    object, and remembers the previous action and state. Subclasses must
    implement `step` and `choose_action`.
    """

    def __init__(self, action_space: ActionSpaceBase,
                 state_space: StateBase,
                 model: ModelBase,
                 reward):
        """
        Store the four collaborators and clear the previous action/state.
        """
        self.action_space = action_space
        self.state_space = state_space
        self.model = model
        self.reward = reward

        self.previous_action = None
        self.previous_state = None

        # NOTE(review): not read anywhere in this class; presumably a list
        # of attribute names to skip when persisting the agent -- confirm
        # against the code that consumes it.
        self.attr_block_list = ['model', 'action_space',
                                'state_space', 'reward']

    @abstractmethod
    def step(self) -> None:
        """
        Method that advances the agent by one step; defined by subclasses.
        """
        ...

    @abstractmethod
    def choose_action(self, action_space: ActionSpaceBase) -> ActionBase:
        """
        Method that contains the agent's strategy for choosing actions
        """
        ...

    def reset(self, episode=0) -> None:
        """
        Resets the Agent's previous_action and previous_state to None.
        Also calls reset() on the action space, the reward and the state
        space, and ep_reset(episode) on the model.
        """
        self.previous_action = None
        self.previous_state = None
        self.action_space.reset()
        self.model.ep_reset(episode)
        self.reward.reset()
        self.state_space.reset()

    def learn(self, obs, reward, done) -> None:
        """
        If it is not the very first step in an episode (previous_state is
        set), this method will call the model's learn method with the
        previous state and action, the reward, the next state and done.
        """
        if self.previous_state is None:
            # Nothing to learn from yet: no previous state was recorded.
            return

        # NOTE(review): `update_state` is not defined on AgentBase here;
        # subclasses apparently have to provide it -- confirm.
        next_state = self.update_state(obs)
        self.model.learn(self.previous_state, self.previous_action,
                         reward, next_state, done)