Skip to content

Commit

Permalink
feat: Added custom environment adapter for stable-baselines3
Browse files Browse the repository at this point in the history
  • Loading branch information
RickFqt committed Sep 30, 2024
1 parent cc69f3e commit 1e7b2ba
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 1 deletion.
16 changes: 16 additions & 0 deletions urnai/rewards/reward_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod


class RewardBase(ABC):
    """
    Every Agent needs to own an instance of this base class in order to calculate
    its rewards. So every time we want to create a new agent,
    we should either use an existing RewardBase implementation or create a new one.
    """

    @abstractmethod
    def get_reward(self, obs: list[list], default_reward: int, terminated: bool,
                   truncated: bool) -> int:
        """Return the reward for the current step.

        Args:
            obs: raw observation for the agent (nested list).
            default_reward: reward already provided by the underlying environment;
                implementations may pass it through, scale it, or replace it.
            terminated: whether the episode ended naturally this step.
            truncated: whether the episode was cut off (e.g. a step limit).

        Returns:
            The shaped reward for this step.

        NOTE(review): the annotations say ``int``, but RL rewards are commonly
        floats — confirm against concrete implementations before relying on it.
        """
        ...

    @abstractmethod
    def reset(self) -> None:
        """Clear any per-episode internal state.

        NOTE(review): presumably meant to be called once per episode when the
        environment resets — confirm with callers.
        """
        ...
2 changes: 1 addition & 1 deletion urnai/sc2/environments/sc2environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,6 @@ def _parse_timestep(
terminated = any(o.player_result for o in self.env_instance._obs)
current_steps = self.env_instance._episode_steps
limit_steps = self.env_instance._episode_length
truncated = current_steps >= limit_steps
truncated = bool(current_steps >= limit_steps)

return obs, reward, terminated, truncated
Empty file.
51 changes: 51 additions & 0 deletions urnai/sc2/environments/stablebaselines3/custom_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import gymnasium as gym
from gymnasium import spaces

from urnai.actions.action_space_base import ActionSpaceBase
from urnai.environments.environment_base import EnvironmentBase
from urnai.rewards.reward_base import RewardBase
from urnai.states.state_base import StateBase


class CustomEnv(gym.Env):
    """Custom environment that follows the gymnasium interface.

    Adapts a URNAI environment / state / action-space / reward pipeline to
    the ``gymnasium.Env`` API so it can be consumed by stable-baselines3.
    """

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self, env: EnvironmentBase, state: StateBase,
                 urnai_action_space: ActionSpaceBase, reward: RewardBase,
                 observation_space: spaces.Space, action_space: spaces.Space):
        super().__init__()

        self._env = env
        self._state = state
        self._action_space = urnai_action_space
        self._reward = reward
        self._obs = None
        # SB3 spaces (gymnasium.spaces), distinct from the URNAI action space
        self.action_space = action_space
        self.observation_space = observation_space

    def step(self, action):
        """Run one timestep of the wrapped environment.

        Translates the SB3 action into a URNAI action, steps the wrapped
        environment, and returns ``(obs, reward, terminated, truncated, info)``
        as required by the gymnasium API.
        """
        action = self._action_space.get_action(action, self._obs)

        obs, reward, terminated, truncated = self._env.step(action)

        # The wrapped env returns per-player sequences; index 0 is the agent.
        self._obs = obs[0]
        obs = self._state.update(self._obs)
        reward = self._reward.get_reward(self._obs, reward[0], terminated, truncated)
        info = {}
        return obs, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Reset the wrapped environment and return ``(obs, info)``.

        Calls ``gym.Env.reset`` first so ``self.np_random`` is (re)seeded —
        the gymnasium API requires this, and skipping it made ``seed``
        silently ineffective. Also resets the reward tracker so per-episode
        reward state does not leak across episodes.
        """
        super().reset(seed=seed, options=options)
        self._reward.reset()
        obs = self._env.reset()
        self._obs = obs[0]
        obs = self._state.update(self._obs)
        info = {}
        return obs, info

    def render(self):
        # Rendering is delegated to the wrapped environment elsewhere;
        # nothing to draw here despite the "human" entry in `metadata`.
        pass

    def close(self):
        """Release resources held by the wrapped environment."""
        self._env.close()

0 comments on commit 1e7b2ba

Please sign in to comment.