Skip to content

Commit cc53ee1

Browse files
authored
Custom environment from Stable Baselines3 (#103)
* feat: Added custom environment from Stable Baselines3 * chore: Added type hints * test: Added unit tests for base reward * fix: Fixed error messages
1 parent cc69f3e commit cc53ee1

File tree

5 files changed

+121
-1
lines changed

5 files changed

+121
-1
lines changed
+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import unittest
2+
from abc import ABCMeta
3+
4+
from urnai.rewards.reward_base import RewardBase
5+
6+
7+
class TestRewardBase(unittest.TestCase):
    """Unit tests for the RewardBase abstract class."""

    def setUp(self):
        # Clearing __abstractmethods__ lets the tests instantiate subclasses
        # of the ABC without implementing every abstract method.  The original
        # code mutated the shared class and never restored it, leaking the
        # change into any test that runs afterwards (test pollution).  Save
        # the original frozenset and register a cleanup to put it back.
        self._original_abstractmethods = RewardBase.__abstractmethods__
        RewardBase.__abstractmethods__ = frozenset()
        self.addCleanup(self._restore_abstractmethods)

    def _restore_abstractmethods(self):
        # Undo the setUp mutation of the shared RewardBase class.
        RewardBase.__abstractmethods__ = self._original_abstractmethods

    def test_reset_method(self):
        # GIVEN
        class FakeReward(RewardBase):
            def __init__(self):
                super().__init__()

        reward = FakeReward()

        # WHEN
        reset_return = reward.reset()

        # THEN
        self.assertIsInstance(RewardBase, ABCMeta)
        self.assertIsNone(reset_return)

    def test_not_implemented_get_method(self):
        # GIVEN
        class FakeReward(RewardBase):
            def __init__(self):
                super().__init__()

        reward = FakeReward()

        # WHEN / THEN
        self.assertRaises(NotImplementedError, reward.get, [[]], 0, False, False)

urnai/rewards/reward_base.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from abc import ABC, abstractmethod
2+
3+
4+
class RewardBase(ABC):
    """Abstract contract for reward calculation.

    Every Agent needs to own an instance of this base class in order to
    calculate its rewards. So every time we want to create a new agent,
    we should either use an existing RewardBase implementation or create
    a new one.
    """

    @abstractmethod
    def get(
        self,
        obs: list[list],
        default_reward: int,
        terminated: bool,
        truncated: bool
    ) -> int:
        # Subclasses must override this with their actual reward computation;
        # reaching this body means the override is missing.
        message = ("Get method not implemented. You should implement "
                   "it in your RewardBase subclass.")
        raise NotImplementedError(message)

    @abstractmethod
    def reset(self) -> None:
        # Intentionally a no-op: subclasses reset any accumulated reward
        # state at the start of a new episode.
        ...

urnai/sc2/environments/sc2environment.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,6 @@ def _parse_timestep(
113113
terminated = any(o.player_result for o in self.env_instance._obs)
114114
current_steps = self.env_instance._episode_steps
115115
limit_steps = self.env_instance._episode_length
116-
truncated = current_steps >= limit_steps
116+
truncated = bool(current_steps >= limit_steps)
117117

118118
return obs, reward, terminated, truncated

urnai/sc2/environments/stablebaselines3/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from typing import Union
2+
3+
import gymnasium as gym
4+
import numpy as np
5+
from gymnasium import spaces
6+
from stable_baselines3.common.type_aliases import GymResetReturn, GymStepReturn
7+
8+
from urnai.actions.action_space_base import ActionSpaceBase
9+
from urnai.environments.environment_base import EnvironmentBase
10+
from urnai.rewards.reward_base import RewardBase
11+
from urnai.states.state_base import StateBase
12+
13+
14+
class CustomEnv(gym.Env):
    """Custom Environment that follows gym interface.

    Bridges a URNAI ``EnvironmentBase`` (together with its state,
    action-space and reward collaborators) to the ``gym.Env`` API so it
    can be consumed by Stable Baselines3 algorithms.
    """

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self, env: EnvironmentBase, state: StateBase,
                 urnai_action_space: ActionSpaceBase, reward: RewardBase,
                 observation_space: spaces.Space, action_space: spaces.Space):
        super().__init__()

        # URNAI collaborators that drive the actual simulation.
        self._env = env
        self._state = state
        self._action_space = urnai_action_space
        self._reward = reward
        # Raw observation from the most recent env call; None until reset().
        self._obs = None
        # Space attributes read internally by the gymnasium library.
        self.action_space = action_space
        self.observation_space = observation_space

    def step(
        self, action: Union[int, np.ndarray]
    ) -> GymStepReturn:
        """Apply one action and return (obs, reward, terminated, truncated, info)."""
        # Translate the gym-level action into a URNAI env action.
        env_action = self._action_space.get_action(action, self._obs)

        raw_obs, raw_reward, terminated, truncated = self._env.step(env_action)

        # Keep the first raw observation around for the next translation.
        self._obs = raw_obs[0]
        processed_obs = self._state.update(self._obs)
        final_reward = self._reward.get(
            self._obs, raw_reward[0], terminated, truncated)
        return processed_obs, final_reward, terminated, truncated, {}

    def reset(
        self, seed: int = None, options: dict = None
    ) -> GymResetReturn:
        """Restart the wrapped environment and return (obs, info).

        NOTE(review): ``seed`` and ``options`` are accepted for gymnasium
        API compatibility but are not forwarded to the wrapped env or to
        ``super().reset()`` — confirm this is intended.
        """
        raw_obs = self._env.reset()
        self._obs = raw_obs[0]
        return self._state.update(self._obs), {}

    def render(self, mode: str) -> None:
        """Rendering is delegated to subclasses; always raises here."""
        raise NotImplementedError(
            "Render method not implemented. If necessary, you " +
            "should implement it in your CustomEnv subclass.")

    def close(self) -> None:
        """Release resources held by the wrapped environment."""
        self._env.close()

0 commit comments

Comments
 (0)