Skip to content

Commit de65870

Browse files
Resolves #68 (#89)
* test(issue-68): Add unit tests for ActionSpaceBase Also modified ActionBase's tests to test action name. * refactor(issue-68): Remove name attribute from Action #89 * refactor(issue-68): Tests now in "given when then" format #89
1 parent f98f34b commit de65870

File tree

4 files changed

+125
-2
lines changed

4 files changed

+125
-2
lines changed

tests/units/actions/test_action_base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,17 @@ class TestActionBase(unittest.TestCase):
1313

1414
def test_abstract_methods(self):
1515

16+
# GIVEN
1617
fake_action = FakeAction()
18+
19+
# WHEN
1720
run_return = fake_action.run()
1821
check_return = fake_action.check("observation")
19-
is_complete_return = fake_action.is_complete()
22+
is_complete_return = fake_action.is_complete
23+
24+
# THEN
2025
assert fake_action.__id__ is None
2126
assert isinstance(ActionBase, ABCMeta)
2227
assert run_return is None
2328
assert check_return is None
2429
assert is_complete_return is None
25-
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import unittest
2+
from abc import ABCMeta
3+
from unittest.mock import MagicMock
4+
5+
from urnai.actions.action_space_base import ActionSpaceBase
6+
7+
8+
class FakeActionSpace(ActionSpaceBase):
9+
ActionSpaceBase.__abstractmethods__ = set()
10+
...
11+
12+
class TestActionSpaceBase(unittest.TestCase):
13+
14+
def test_abstract_methods(self):
15+
16+
# GIVEN
17+
fake_action_space = FakeActionSpace()
18+
19+
# WHEN
20+
is_action_done_return = fake_action_space.is_action_done()
21+
reset_return = fake_action_space.reset()
22+
get_actions_return = fake_action_space.get_actions()
23+
get_excluded_actions_return = fake_action_space.get_excluded_actions("obs")
24+
get_actions_return = fake_action_space.get_action(0, "obs")
25+
26+
# THEN
27+
assert isinstance(ActionSpaceBase, ABCMeta)
28+
assert is_action_done_return is None
29+
assert reset_return is None
30+
assert get_actions_return is None
31+
assert get_excluded_actions_return is None
32+
33+
def test_get_named_actions(self):
34+
35+
# GIVEN
36+
fake_action_space = FakeActionSpace()
37+
38+
# WHEN
39+
fake_action_space.get_named_actions = MagicMock(return_value=[])
40+
41+
# THEN
42+
self.assertEqual(fake_action_space.get_named_actions(), [])
43+
44+
def test_size(self):
45+
46+
# GIVEN
47+
fake_action_space = FakeActionSpace()
48+
49+
# WHEN
50+
fake_action_space.get_actions = MagicMock(return_value=[])
51+
52+
# THEN
53+
self.assertEqual(fake_action_space.size, 0)

tests/units/actions/test_chain_of_actions_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,14 @@ class TestChainOfActionsBase(unittest.TestCase):
1414

1515
def test_abstract_methods(self):
1616

17+
# GIVEN
1718
fake_chain_of_actions = FakeChainOfActions()
19+
20+
# WHEN
1821
get_action_return = fake_chain_of_actions.get_action(0)
1922
length_return = fake_chain_of_actions.length
23+
24+
# THEN
2025
assert isinstance(ChainOfActionsBase, ABCMeta)
2126
assert get_action_return is None
2227
assert (length_return == 0) is True

urnai/actions/action_space_base.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from abc import ABC, abstractmethod
2+
3+
4+
class ActionSpaceBase(ABC):
5+
"""
6+
ActionSpace works as an extra abstraction layer used by the agent
7+
to select actions. This means the agent doesn't select actions from
8+
action_set, but from ActionSpace.
9+
10+
This class is responsible of telling the agent which actions it can
11+
use and which ones are excluded from selection. It can also force the
12+
agent to use certain actions by combining them into multiple steps.
13+
"""
14+
15+
@abstractmethod
16+
def is_action_done(self) -> bool:
17+
"""
18+
Some agents must do multiple steps for a single action before they
19+
can choose another one. This method should implement the logic to
20+
tell whether the current action is done or not. That is, if all the
21+
steps for an action are complete.
22+
"""
23+
...
24+
25+
@abstractmethod
26+
def reset(self) -> None:
27+
"""
28+
Contains logic for resetting the action_space. This is used mostly
29+
for agents that require multiple steps for a single action.
30+
"""
31+
...
32+
33+
@abstractmethod
34+
def get_actions(self) -> list[int]:
35+
"""Returns all the actions that the agent can choose from."""
36+
...
37+
38+
@abstractmethod
39+
def get_excluded_actions(self, obs) -> list[int]:
40+
"""Returns a subset of actions that can't be chosen by the agent."""
41+
...
42+
43+
@abstractmethod
44+
def get_action(self, action_idx: int, obs):
45+
"""
46+
Receives an action index as a parameter and returns the corresponding
47+
action from the available actions. This method should return an action
48+
that can be used by the environment's step method.
49+
"""
50+
pass
51+
52+
def get_actions_id(self) -> list[str]:
53+
"""Returns the names of all the actions that the agent can choose from."""
54+
ids = []
55+
for action in self.get_actions():
56+
ids.append(action.__id__)
57+
return ids
58+
59+
@property
60+
def size(self) -> int:
61+
return len(self.get_actions())

0 commit comments

Comments
 (0)