From 1e7b2bacd5f821498b32a0314380ba59ca06e728 Mon Sep 17 00:00:00 2001 From: RickFqt Date: Mon, 30 Sep 2024 08:25:41 -0300 Subject: [PATCH 1/4] feat: Added custom environment from stablebaselines3 --- urnai/rewards/reward_base.py | 16 ++++++ urnai/sc2/environments/sc2environment.py | 2 +- .../environments/stablebaselines3/__init__.py | 0 .../stablebaselines3/custom_env.py | 51 +++++++++++++++++++ 4 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 urnai/rewards/reward_base.py create mode 100644 urnai/sc2/environments/stablebaselines3/__init__.py create mode 100644 urnai/sc2/environments/stablebaselines3/custom_env.py diff --git a/urnai/rewards/reward_base.py b/urnai/rewards/reward_base.py new file mode 100644 index 00000000..7d6b0518 --- /dev/null +++ b/urnai/rewards/reward_base.py @@ -0,0 +1,16 @@ +from abc import ABC, abstractmethod + + +class RewardBase(ABC): + """ + Every Agent needs to own an instance of this base class in order to calculate + its rewards. So every time we want to create a new agent, + we should either use an existing RewardBase implementation or create a new one. + """ + + @abstractmethod + def get_reward(self, obs: list[list], default_reward: int, terminated: bool, + truncated: bool) -> int: ... + + @abstractmethod + def reset(self) -> None: ... diff --git a/urnai/sc2/environments/sc2environment.py b/urnai/sc2/environments/sc2environment.py index bcaf4f74..367ffcad 100644 --- a/urnai/sc2/environments/sc2environment.py +++ b/urnai/sc2/environments/sc2environment.py @@ -113,6 +113,6 @@ def _parse_timestep( terminated = any(o.player_result for o in self.env_instance._obs) current_steps = self.env_instance._episode_steps limit_steps = self.env_instance._episode_length - truncated = current_steps >= limit_steps + truncated = bool(current_steps >= limit_steps) return obs, reward, terminated, truncated diff --git a/urnai/sc2/environments/stablebaselines3/__init__.py b/urnai/sc2/environments/stablebaselines3/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/urnai/sc2/environments/stablebaselines3/custom_env.py b/urnai/sc2/environments/stablebaselines3/custom_env.py new file mode 100644 index 00000000..30218c28 --- /dev/null +++ b/urnai/sc2/environments/stablebaselines3/custom_env.py @@ -0,0 +1,51 @@ +import gymnasium as gym +from gymnasium import spaces + +from urnai.actions.action_space_base import ActionSpaceBase +from urnai.environments.environment_base import EnvironmentBase +from urnai.rewards.reward_base import RewardBase +from urnai.states.state_base import StateBase + + +class CustomEnv(gym.Env): + """Custom Environment that follows gym interface.""" + + metadata = {"render_modes": ["human"], "render_fps": 30} + + def __init__(self, env: EnvironmentBase, state: StateBase, + urnai_action_space: ActionSpaceBase, reward: RewardBase, + observation_space: spaces.Space, action_space: spaces.Space): + super().__init__() + + self._env = env + self._state = state + self._action_space = urnai_action_space + self._reward = reward + self._obs = None + # SB3 spaces + self.action_space = action_space + self.observation_space = observation_space + + def step(self, action): + action = self._action_space.get_action(action, self._obs) + + obs, reward, terminated, truncated = self._env.step(action) + + self._obs = obs[0] + obs = self._state.update(self._obs) + reward = self._reward.get_reward(self._obs, reward[0], terminated, truncated) + info = {} + return obs, reward, terminated, truncated, info + + def reset(self, seed=None, options=None): + obs = self._env.reset() + self._obs = obs[0] + obs = self._state.update(self._obs) + info = {} + return obs, info + + def render(self): + pass + + def close(self): + self._env.close() \ No newline at end of file From 515251d99d407b11b20b36ca54eceee274ad5285 Mon Sep 17 00:00:00 2001 From: RickFqt Date: Mon, 30 Sep 2024 17:36:21 -0300 Subject: [PATCH 2/4] chore: Added typehints --- urnai/rewards/reward_base.py | 10 +++++++-- .../stablebaselines3/custom_env.py | 22 +++++++++++++------ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/urnai/rewards/reward_base.py b/urnai/rewards/reward_base.py index 7d6b0518..34b125b1 100644 --- a/urnai/rewards/reward_base.py +++ b/urnai/rewards/reward_base.py @@ -9,8 +9,14 @@ class RewardBase(ABC): """ @abstractmethod - def get_reward(self, obs: list[list], default_reward: int, terminated: bool, - truncated: bool) -> int: ... + def get( + self, + obs: list[list], + default_reward: int, + terminated: bool, + truncated: bool + ) -> int: + raise NotImplementedError(...) @abstractmethod def reset(self) -> None: ... diff --git a/urnai/sc2/environments/stablebaselines3/custom_env.py b/urnai/sc2/environments/stablebaselines3/custom_env.py index 30218c28..2e3c5368 100644 --- a/urnai/sc2/environments/stablebaselines3/custom_env.py +++ b/urnai/sc2/environments/stablebaselines3/custom_env.py @@ -1,5 +1,9 @@ +from typing import Union + import gymnasium as gym +import numpy as np from gymnasium import spaces +from stable_baselines3.common.type_aliases import GymResetReturn, GymStepReturn from urnai.actions.action_space_base import ActionSpaceBase from urnai.environments.environment_base import EnvironmentBase @@ -22,30 +26,34 @@ def __init__(self, env: EnvironmentBase, state: StateBase, self._action_space = urnai_action_space self._reward = reward self._obs = None - # SB3 spaces + # space variables, used internally by the gymnasium library self.action_space = action_space self.observation_space = observation_space - def step(self, action): + def step( + self, action: Union[int, np.ndarray] + ) -> GymStepReturn: action = self._action_space.get_action(action, self._obs) obs, reward, terminated, truncated = self._env.step(action) self._obs = obs[0] obs = self._state.update(self._obs) - reward = self._reward.get_reward(self._obs, reward[0], terminated, truncated) + reward = self._reward.get(self._obs, reward[0], terminated, truncated) info = {} return obs, reward, terminated, truncated, info - def reset(self, seed=None, options=None): + def reset( + self, seed: int = None, options: dict = None + ) -> GymResetReturn: obs = self._env.reset() self._obs = obs[0] obs = self._state.update(self._obs) info = {} return obs, info - def render(self): - pass + def render(self, mode: str) -> None: + raise NotImplementedError(...) - def close(self): + def close(self) -> None: self._env.close() \ No newline at end of file From cc97280bab2013b428a130a625c2d5476d7bbc89 Mon Sep 17 00:00:00 2001 From: RickFqt Date: Fri, 4 Oct 2024 13:00:02 -0300 Subject: [PATCH 3/4] test: Added unit tests for base reward --- tests/units/rewards/test_reward_base.py | 37 +++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 tests/units/rewards/test_reward_base.py diff --git a/tests/units/rewards/test_reward_base.py b/tests/units/rewards/test_reward_base.py new file mode 100644 index 00000000..b77487f9 --- /dev/null +++ b/tests/units/rewards/test_reward_base.py @@ -0,0 +1,37 @@ +import unittest +from abc import ABCMeta + +from urnai.rewards.reward_base import RewardBase + + +class TestRewardBase(unittest.TestCase): + + def test_reset_method(self): + # GIVEN + RewardBase.__abstractmethods__ = set() + + class FakeReward(RewardBase): + def __init__(self): + super().__init__() + + reward = FakeReward() + + # WHEN + reset_return = reward.reset() + + # THEN + assert isinstance(RewardBase, ABCMeta) + assert reset_return is None + + def test_not_implemented_get_method(self): + # GIVEN + RewardBase.__abstractmethods__ = set() + + class FakeReward(RewardBase): + def __init__(self): + super().__init__() + + reward = FakeReward() + + # WHEN / THEN + self.assertRaises(NotImplementedError, reward.get, [[]], 0, False, False) From 6edabcde2a343ef75dbac9fa79e2013df442237f Mon Sep 17 00:00:00 2001 From: RickFqt Date: Fri, 4 Oct 2024 17:26:49 -0300 Subject: [PATCH 4/4] fix: Fixed error messages --- urnai/rewards/reward_base.py | 3 ++- urnai/sc2/environments/stablebaselines3/custom_env.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/urnai/rewards/reward_base.py b/urnai/rewards/reward_base.py index 34b125b1..2ecf107d 100644 --- a/urnai/rewards/reward_base.py +++ b/urnai/rewards/reward_base.py @@ -16,7 +16,8 @@ def get( terminated: bool, truncated: bool ) -> int: - raise NotImplementedError(...) + raise NotImplementedError("Get method not implemented. You should implement " + + "it in your RewardBase subclass.") @abstractmethod def reset(self) -> None: ... diff --git a/urnai/sc2/environments/stablebaselines3/custom_env.py b/urnai/sc2/environments/stablebaselines3/custom_env.py index 2e3c5368..4a527f24 100644 --- a/urnai/sc2/environments/stablebaselines3/custom_env.py +++ b/urnai/sc2/environments/stablebaselines3/custom_env.py @@ -53,7 +53,8 @@ def reset( return obs, info def render(self, mode: str) -> None: - raise NotImplementedError(...) + raise NotImplementedError("Render method not implemented. If necessary, you " + + "should implement it in your CustomEnv subclass.") def close(self) -> None: self._env.close() \ No newline at end of file