Skip to content

Commit

Permalink
Merge pull request #78 from UFRN-URNAI/issue-72
Browse files Browse the repository at this point in the history
Resolves #72
  • Loading branch information
alvarofpp authored Mar 9, 2024
2 parents 0dfb6fe + d7c62cc commit 6d1f718
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 0 deletions.
Empty file added tests/units/states/__init__.py
Empty file.
25 changes: 25 additions & 0 deletions tests/units/states/test_state_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import unittest
from abc import ABCMeta

from urnai.states.state_base import StateBase


class TestStateBase(unittest.TestCase):

def test_abstract_methods(self):
StateBase.__abstractmethods__ = set()

class FakeState(StateBase):
def __init__(self):
super().__init__()

fake_state = FakeState()
update_return = fake_state.update("observation")
state = fake_state.state
dimension = fake_state.dimension
reset_return = fake_state.reset()
assert isinstance(StateBase, ABCMeta)
assert update_return is None
assert state is None
assert dimension is None
assert reset_return is None
Empty file added urnai/states/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions urnai/states/state_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod
from typing import Any, List


class StateBase(ABC):
"""
Every Agent needs to own an instance of this base class
in order to define its State.
So every time we want to create a new agent,
we should either use an existing State implementation or create a new one.
"""

@abstractmethod
def update(self, obs) -> List[Any]:
"""
This method receives as a parameter an Observation and returns a State, which
is usually a list of features extracted from the Observation. The Agent uses
this State during training to receive a new action from its model and also to
make it learn, that's why this method should always return a list.
"""
pass

@property
@abstractmethod
def state(self):
"""Returns the State currently saved."""
pass

@property
@abstractmethod
def dimension(self):
"""Returns the dimensions of the States returned by the update method."""
pass

@abstractmethod
def reset(self):
"""Resets the State currently saved."""
pass

0 comments on commit 6d1f718

Please sign in to comment.