-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Custom environment from Stable Baselines3 (#103)
* feat: Added custom environment from stablebaselines3 * chore: Added typehints * test: Added unit tests for base reward * fix: Fixed error messages
- Loading branch information
Showing
5 changed files
with
121 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import unittest | ||
from abc import ABCMeta | ||
|
||
from urnai.rewards.reward_base import RewardBase | ||
|
||
|
||
class TestRewardBase(unittest.TestCase):
    """Unit tests for the default behaviour provided by ``RewardBase``."""

    @staticmethod
    def _make_reward() -> "RewardBase":
        """Return a concrete reward whose methods defer to ``RewardBase``.

        ``RewardBase`` is abstract, so it cannot be instantiated directly.
        Instead of clearing ``RewardBase.__abstractmethods__`` (which mutates
        the real class for every other test running in the same process),
        the abstract methods are overridden with thin pass-throughs so the
        base implementations can still be exercised.
        """

        class FakeReward(RewardBase):
            def get(self, obs, default_reward, terminated, truncated):
                return super().get(obs, default_reward, terminated, truncated)

            def reset(self):
                return super().reset()

        return FakeReward()

    def test_reset_method(self):
        # GIVEN
        reward = self._make_reward()

        # WHEN
        reset_return = reward.reset()

        # THEN
        self.assertIsInstance(RewardBase, ABCMeta)
        self.assertIsNone(reset_return)

    def test_not_implemented_get_method(self):
        # GIVEN
        reward = self._make_reward()

        # WHEN / THEN
        self.assertRaises(NotImplementedError, reward.get, [[]], 0, False, False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class RewardBase(ABC):
    """
    Every Agent needs to own an instance of this base class in order to calculate
    its rewards. So every time we want to create a new agent,
    we should either use an existing RewardBase implementation or create a new one.
    """

    @abstractmethod
    def get(
        self,
        obs: list[list],
        default_reward: float,
        terminated: bool,
        truncated: bool,
    ) -> float:
        """Compute the reward for the current step.

        Subclasses must override this method; the base implementation only
        raises.

        Args:
            obs: Observation the reward is computed from.
            default_reward: Reward value supplied by the caller alongside the
                observation. Annotated as ``float`` so that both integer and
                floating-point rewards are accepted.
            terminated: Termination flag supplied by the caller.
            truncated: Truncation flag supplied by the caller.

        Returns:
            The reward value.

        Raises:
            NotImplementedError: Always, when the base implementation is
                reached (e.g. through ``super().get(...)``).
        """
        raise NotImplementedError(
            "Get method not implemented. You should implement "
            "it in your RewardBase subclass."
        )

    @abstractmethod
    def reset(self) -> None:
        """Reset the reward object.

        Subclasses must override this method. The base implementation does
        nothing and returns ``None``.
        """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from typing import Union | ||
|
||
import gymnasium as gym | ||
import numpy as np | ||
from gymnasium import spaces | ||
from stable_baselines3.common.type_aliases import GymResetReturn, GymStepReturn | ||
|
||
from urnai.actions.action_space_base import ActionSpaceBase | ||
from urnai.environments.environment_base import EnvironmentBase | ||
from urnai.rewards.reward_base import RewardBase | ||
from urnai.states.state_base import StateBase | ||
|
||
|
||
class CustomEnv(gym.Env):
    """Custom Environment that follows gym interface.

    Wraps an ``EnvironmentBase`` together with the state builder, action
    space wrapper and reward builder, and exposes them through the
    ``gym.Env`` methods (``step``, ``reset``, ``render``, ``close``).
    """

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self, env: EnvironmentBase, state: StateBase,
                 urnai_action_space: ActionSpaceBase, reward: RewardBase,
                 observation_space: spaces.Space, action_space: spaces.Space):
        """Store the wrapped components and the gymnasium spaces.

        Args:
            env: Underlying environment that is stepped, reset and closed.
            state: Object whose ``update`` turns a raw observation into the
                observation returned to the caller.
            urnai_action_space: Object whose ``get_action`` translates the
                incoming action before it is sent to ``env``.
            reward: Object whose ``get`` computes the reward for each step.
            observation_space: Gymnasium observation space of this env.
            action_space: Gymnasium action space of this env.
        """
        super().__init__()

        self._env = env
        self._state = state
        # Kept under a private name; distinct from the gymnasium
        # ``action_space`` attribute assigned below.
        self._action_space = urnai_action_space
        self._reward = reward
        # Last raw observation; stays None until reset() or step() runs.
        self._obs = None
        # space variables, used internally by the gymnasium library
        self.action_space = action_space
        self.observation_space = observation_space

    def step(
        self, action: Union[int, np.ndarray]
    ) -> GymStepReturn:
        """Advance the wrapped environment by one step.

        Args:
            action: Action in this env's gymnasium action space; it is
                translated through the URNAI action space before use.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where ``obs`` is
            produced by the state object, ``reward`` by the reward object
            and ``info`` is always an empty dict.
        """
        translated_action = self._action_space.get_action(action, self._obs)

        raw_obs, raw_reward, terminated, truncated = self._env.step(translated_action)

        # Only the first element of the environment's observation and
        # reward is used.
        self._obs = raw_obs[0]
        obs = self._state.update(self._obs)
        reward = self._reward.get(self._obs, raw_reward[0], terminated, truncated)
        return obs, reward, terminated, truncated, {}

    def reset(
        self, seed: Union[int, None] = None, options: Union[dict, None] = None
    ) -> GymResetReturn:
        """Reset the wrapped environment and return the first observation.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` so that
                ``self.np_random`` is seeded as the gymnasium API expects.
            options: Accepted for API compatibility; currently unused.

        Returns:
            ``(obs, info)`` where ``obs`` is produced by the state object
            and ``info`` is always an empty dict.
        """
        # Gymnasium requires custom environments to forward the seed to the
        # base class; without this call ``seed`` is silently ignored.
        super().reset(seed=seed)

        # NOTE(review): the reward object's reset() is not invoked here --
        # confirm whether the caller is expected to reset it per episode.
        raw_obs = self._env.reset()
        self._obs = raw_obs[0]
        obs = self._state.update(self._obs)
        return obs, {}

    def render(self, mode: str = "human") -> None:
        """Rendering is not provided by this class.

        ``mode`` has a default so that a bare ``render()`` call, as made by
        gymnasium, reaches the explanatory error below instead of failing
        with a ``TypeError`` about a missing argument.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Render method not implemented. If necessary, you "
            "should implement it in your CustomEnv subclass."
        )

    def close(self) -> None:
        """Close the wrapped environment."""
        self._env.close()