From e8135b0facc42d7644b5193092d6dbdd97c66ea5 Mon Sep 17 00:00:00 2001 From: RickFqt Date: Tue, 27 Feb 2024 15:03:36 -0300 Subject: [PATCH 1/3] chore(issue-72): Create base class for the state representation --- urnai/states/__init__.py | 0 urnai/states/state_base.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 urnai/states/__init__.py create mode 100644 urnai/states/state_base.py diff --git a/urnai/states/__init__.py b/urnai/states/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/urnai/states/state_base.py b/urnai/states/state_base.py new file mode 100644 index 00000000..132d2b41 --- /dev/null +++ b/urnai/states/state_base.py @@ -0,0 +1,38 @@ +from abc import ABC, abstractmethod +from typing import Any, List + + +class StateBase(ABC): + """ + Every Agent needs to own an instance of this base class + in order to define its State. + So every time we want to create a new agent, + we should either use an existing State implementation or create a new one. + """ + + @abstractmethod + def update(self, obs) -> List[Any]: + """ + This method receives as a parameter an Observation and returns a State, which + is usually a list of features extracted from the Observation. The Agent uses + this State during training to receive a new action from its model and also to + make it learn, that's why this method should always return a list. + """ + pass + + @property + @abstractmethod + def get_state(self): + """Returns the State currently saved.""" + pass + + @property + @abstractmethod + def get_dimension(self): + """Returns the dimensions of the States returned by the update method.""" + pass + + @abstractmethod + def reset(self): + """Resets the State currently saved.""" + pass \ No newline at end of file From 3637b2fe2d3a2b45ca082fe195369d9154541109 Mon Sep 17 00:00:00 2001 From: RickFqt Date: Wed, 28 Feb 2024 07:11:23 -0300 Subject: [PATCH 2/3] chore(issue-72): Add unit test for State base class --- tests/units/states/__init__.py | 0 tests/units/states/test_state_base.py | 25 +++++++++++++++++++++++++ 2 files changed, 25 insertions(+) create mode 100644 tests/units/states/__init__.py create mode 100644 tests/units/states/test_state_base.py diff --git a/tests/units/states/__init__.py b/tests/units/states/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/units/states/test_state_base.py b/tests/units/states/test_state_base.py new file mode 100644 index 00000000..d669b07b --- /dev/null +++ b/tests/units/states/test_state_base.py @@ -0,0 +1,25 @@ +import unittest +from abc import ABCMeta + +from urnai.states.state_base import StateBase + + +class TestStateBase(unittest.TestCase): + + def test_abstract_methods(self): + StateBase.__abstractmethods__ = set() + + class FakeState(StateBase): + def __init__(self): + super().__init__() + + f = FakeState() + update_return = f.update("observation") + state = f.get_state + dimension = f.get_dimension + reset_return = f.reset() + assert isinstance(StateBase, ABCMeta) + assert update_return is None + assert state is None + assert dimension is None + assert reset_return is None From d7c62ccdb24a7e53ff79a20d036cd2825f281bf9 Mon Sep 17 00:00:00 2001 From: RickFqt Date: Fri, 1 Mar 2024 07:25:31 -0300 Subject: [PATCH 3/3] refactor(): Change variable and attribute names --- tests/units/states/test_state_base.py | 10 +++++----- urnai/states/state_base.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/units/states/test_state_base.py b/tests/units/states/test_state_base.py index d669b07b..c42b803d 100644 --- a/tests/units/states/test_state_base.py +++ b/tests/units/states/test_state_base.py @@ -13,11 +13,11 @@ class FakeState(StateBase): def __init__(self): super().__init__() - f = FakeState() - update_return = f.update("observation") - state = f.get_state - dimension = f.get_dimension - reset_return = f.reset() + fake_state = FakeState() + update_return = fake_state.update("observation") + state = fake_state.state + dimension = fake_state.dimension + reset_return = fake_state.reset() assert isinstance(StateBase, ABCMeta) assert update_return is None assert state is None diff --git a/urnai/states/state_base.py b/urnai/states/state_base.py index 132d2b41..8e0a9035 100644 --- a/urnai/states/state_base.py +++ b/urnai/states/state_base.py @@ -22,13 +22,13 @@ def update(self, obs) -> List[Any]: @property @abstractmethod - def get_state(self): + def state(self): """Returns the State currently saved.""" pass @property @abstractmethod - def get_dimension(self): + def dimension(self): """Returns the dimensions of the States returned by the update method.""" pass