Skip to content

Commit

Permalink
Custom environment from Stable Baselines3 (#103)
Browse files Browse the repository at this point in the history
* feat: Added custom environment from stablebaselines3

* chore: Added typehints

* test: Added unit tests for base reward

* fix: Fixed error messages
  • Loading branch information
RickFqt authored Oct 4, 2024
1 parent cc69f3e commit cc53ee1
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 1 deletion.
37 changes: 37 additions & 0 deletions tests/units/rewards/test_reward_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import unittest
from abc import ABCMeta

from urnai.rewards.reward_base import RewardBase


class TestRewardBase(unittest.TestCase):
    """Unit tests for the RewardBase abstract base class."""

    def test_reset_method(self):
        # GIVEN a concrete subclass, with the abstract methods cleared so
        # the base class behavior itself can be exercised
        RewardBase.__abstractmethods__ = set()

        class FakeReward(RewardBase):
            def __init__(self):
                super().__init__()

        fake_reward = FakeReward()

        # WHEN
        result = fake_reward.reset()

        # THEN RewardBase is an ABC and the base reset() yields nothing
        self.assertIsInstance(RewardBase, ABCMeta)
        self.assertIsNone(result)

    def test_not_implemented_get_method(self):
        # GIVEN a concrete subclass with the abstract methods cleared
        RewardBase.__abstractmethods__ = set()

        class FakeReward(RewardBase):
            def __init__(self):
                super().__init__()

        fake_reward = FakeReward()

        # WHEN / THEN the un-overridden get() must raise NotImplementedError
        with self.assertRaises(NotImplementedError):
            fake_reward.get([[]], 0, False, False)
23 changes: 23 additions & 0 deletions urnai/rewards/reward_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from abc import ABC, abstractmethod


class RewardBase(ABC):
    """Abstract interface for reward calculators.

    Every Agent needs to own an instance of this base class in order to
    calculate its rewards, so whenever a new agent is created we should
    either reuse an existing RewardBase implementation or define a new
    subclass.
    """

    @abstractmethod
    def get(
        self,
        obs: list[list],
        default_reward: int,
        terminated: bool,
        truncated: bool
    ) -> int:
        """Compute the reward for the current step; must be overridden."""
        message = ("Get method not implemented. You should implement "
                   "it in your RewardBase subclass.")
        raise NotImplementedError(message)

    @abstractmethod
    def reset(self) -> None:
        """Reset any internal episode state; the base version does nothing."""
2 changes: 1 addition & 1 deletion urnai/sc2/environments/sc2environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,6 @@ def _parse_timestep(
terminated = any(o.player_result for o in self.env_instance._obs)
current_steps = self.env_instance._episode_steps
limit_steps = self.env_instance._episode_length
truncated = current_steps >= limit_steps
truncated = bool(current_steps >= limit_steps)

return obs, reward, terminated, truncated
Empty file.
60 changes: 60 additions & 0 deletions urnai/sc2/environments/stablebaselines3/custom_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import Union

import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.type_aliases import GymResetReturn, GymStepReturn

from urnai.actions.action_space_base import ActionSpaceBase
from urnai.environments.environment_base import EnvironmentBase
from urnai.rewards.reward_base import RewardBase
from urnai.states.state_base import StateBase


class CustomEnv(gym.Env):
    """Custom Environment that follows gym interface.

    Adapts a URNAI ``EnvironmentBase`` (together with its state,
    action-space and reward objects) to the gymnasium ``Env`` API so it
    can be consumed by Stable Baselines3 algorithms.
    """

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self, env: EnvironmentBase, state: StateBase,
                 urnai_action_space: ActionSpaceBase, reward: RewardBase,
                 observation_space: spaces.Space, action_space: spaces.Space):
        super().__init__()

        self._env = env
        self._state = state
        self._action_space = urnai_action_space
        self._reward = reward
        self._obs = None
        # space variables, used internally by the gymnasium library
        self.action_space = action_space
        self.observation_space = observation_space

    def step(
        self, action: Union[int, np.ndarray]
    ) -> GymStepReturn:
        """Translate the gym action, advance the wrapped env one step.

        Returns the gymnasium 5-tuple
        ``(obs, reward, terminated, truncated, info)``.
        """
        action = self._action_space.get_action(action, self._obs)

        obs, reward, terminated, truncated = self._env.step(action)

        # obs[0] / reward[0]: the wrapped env returns per-player lists;
        # this adapter only exposes the first player's view.
        self._obs = obs[0]
        obs = self._state.update(self._obs)
        reward = self._reward.get(self._obs, reward[0], terminated, truncated)
        info = {}
        return obs, reward, terminated, truncated, info

    def reset(
        self, seed: Union[int, None] = None, options: Union[dict, None] = None
    ) -> GymResetReturn:
        """Reset the wrapped environment and return ``(obs, info)``.

        ``options`` is accepted for gymnasium API compatibility but is
        not forwarded to the wrapped environment.
        """
        # Required by the gymnasium Env contract: seeds self.np_random so
        # that `seed` is not silently ignored.
        super().reset(seed=seed)
        obs = self._env.reset()
        self._obs = obs[0]
        obs = self._state.update(self._obs)
        info = {}
        return obs, info

    def render(self, mode: str) -> None:
        raise NotImplementedError("Render method not implemented. If necessary, you " +
                                  "should implement it in your CustomEnv subclass.")

    def close(self) -> None:
        """Release resources held by the wrapped environment."""
        self._env.close()

0 comments on commit cc53ee1

Please sign in to comment.