-
-
Notifications
You must be signed in to change notification settings - Fork 428
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Super suit migration example #1091
Draft
GiovanniGrotto
wants to merge
2
commits into
Farama-Foundation:master
Choose a base branch
from
GiovanniGrotto:SuperSuit_migration
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
from reward_lambda import reward_lambda_v0, AecRewardLambda | ||
from observation_lambda import observation_lambda_v0, AecObservationLambda | ||
from utils.basic_transforms import color_reduction | ||
from typing import Literal, Any | ||
from types import ModuleType | ||
from pettingzoo import AECEnv, ParallelEnv | ||
from gymnasium.spaces import Space | ||
import numpy as np | ||
|
||
|
||
def basic_obs_wrapper(env: AECEnv | ParallelEnv, module: ModuleType, param: Any) -> AecObservationLambda: | ||
""" | ||
Wrap an environment to modify its observation space and observations using a specified module and parameter. | ||
|
||
This function takes an environment, a module, and a parameter, and creates a new environment with an observation | ||
space and observations modified based on the provided module and parameter. | ||
|
||
Parameters: | ||
- env (Generic[AgentID, ObsType, ActionType]): The environment to be wrapped. | ||
- module: The module responsible for modifying the observation space and observations. | ||
- param: The parameter used to modify the observation space and observations. | ||
|
||
Returns: | ||
- AecObservationLambda: A wrapped environment that applies the observation space and observation modifications. #TODO fix this line | ||
|
||
Example: | ||
```python | ||
modified_env = basic_obs_wrapper(original_env, my_module, my_param) | ||
``` | ||
In the above example, `modified_env` is a new environment that has its observation space and observations modified | ||
according to the `my_module` and `my_param`. | ||
""" | ||
|
||
def change_space(space: Space): # Box? | ||
module.check_param(space, param) | ||
space = module.change_obs_space(space, param) | ||
return space | ||
|
||
def change_obs(obs: np.ndarray, obs_space: Space): # not sure about ndarray | ||
return module.change_observation(obs, obs_space, param) | ||
|
||
return observation_lambda_v0(env, change_obs, change_space) | ||
|
||
|
||
def color_reduction_v0(env: AECEnv | ParallelEnv, mode: Literal["full", "R", "G", "B"] = "full") -> AecObservationLambda: | ||
""" | ||
Wrap an environment to perform color reduction on its observations. | ||
|
||
This function takes an environment and an optional mode to specify the color reduction technique. It then creates | ||
a new environment that performs color reduction on the observations based on the specified mode. | ||
|
||
Parameters: | ||
- env (Generic[AgentID, ObsType, ActionType]): The environment to be wrapped. | ||
- mode (Union[str, color_reduction.COLOR_RED_LIST], optional): The color reduction mode to apply (default is "full"). | ||
Valid modes are defined in the color_reduction module. | ||
|
||
Returns: | ||
- AecObservationLambda: A wrapped environment that applies color reduction to its observations. #TODO fix this line | ||
|
||
Example: | ||
```python | ||
reduced_color_env = color_reduction_v0(original_env, mode="grayscale") | ||
``` | ||
In the above example, `reduced_color_env` is a new environment that performs grayscale color reduction on its | ||
observations. | ||
""" | ||
|
||
return basic_obs_wrapper(env, color_reduction, mode) | ||
|
||
|
||
def clip_reward_v0(env: AECEnv | ParallelEnv, lower_bound: float = -1, upper_bound: float = 1) -> AecRewardLambda: | ||
""" | ||
Clip rewards in an environment using the specified lower and upper bounds. | ||
|
||
This function applies a reward clipping transformation to an environment's rewards. It takes an environment and | ||
two optional bounds: `lower_bound` and `upper_bound`. Any reward in the environment that falls below the | ||
`lower_bound` will be set to `lower_bound`, and any reward that exceeds the `upper_bound` will be set to | ||
`upper_bound`. Rewards within the specified range are left unchanged. | ||
|
||
Parameters: | ||
- env (Generic[AgentID, ObsType, ActionType]): The environment on which to apply the reward clipping. | ||
- lower_bound (float, optional): The lower bound for clipping rewards (default is -1). | ||
- upper_bound (float, optional): The upper bound for clipping rewards (default is 1). | ||
|
||
Returns: | ||
- AecRewardLambda: A reward transformation function that applies the specified reward clipping when called. #TODO fix this line | ||
|
||
Example: | ||
```python | ||
clipped_env = clip_reward_v0(my_environment, lower_bound=-0.5, upper_bound=0.5) | ||
``` | ||
In the above example, the rewards in `my_environment` will be clipped to the range [-0.5, 0.5]. | ||
""" | ||
|
||
return reward_lambda_v0(env, lambda rew: max(min(rew, upper_bound), lower_bound)) |
140 changes: 140 additions & 0 deletions
140
pettingzoo/utils/wrappers/supersuit/observation_lambda.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import functools | ||
import numpy as np | ||
from gymnasium.spaces import Box, Discrete | ||
from utils.base_aec_wrapper import BaseWrapper | ||
from typing import Callable | ||
from pettingzoo import AECEnv, ParallelEnv | ||
from pettingzoo.utils.env import ActionType, AgentID | ||
|
||
|
||
class AecObservationLambda(BaseWrapper): | ||
""" | ||
A wrapper for AEC environments that allows the modification of observation spaces and observations. | ||
|
||
Args: | ||
env (AECEnv | ParallelEnv): The environment to be wrapped. | ||
change_observation_fn (Callable): A function that modifies observations. | ||
change_obs_space_fn (Callable, optional): A function that modifies observation spaces. Default is None. | ||
|
||
Raises: | ||
AssertionError: If `change_observation_fn` is not callable, or if `change_obs_space_fn` is provided and is not callable. | ||
|
||
Note: | ||
- The `change_observation_fn` should be a function that accepts observation data and optionally the observation space and agent ID as arguments and returns a modified observation. | ||
- The `change_obs_space_fn` should be a function that accepts an old observation space and optionally the agent ID as arguments and returns a modified observation space. | ||
|
||
Attributes: | ||
change_observation_fn (Callable): The function used to modify observations. | ||
change_obs_space_fn (Callable, optional): The function used to modify observation spaces. | ||
|
||
Methods: | ||
_modify_action(agent: str, action: Discrete) -> Discrete: | ||
Modify the action. | ||
|
||
_check_wrapper_params() -> None: | ||
Check wrapper parameters for consistency. | ||
|
||
observation_space(agent: str) -> Box: | ||
Get the modified observation space for a specific agent. | ||
|
||
_modify_observation(agent: str, observation: Box) -> Box: | ||
Modify the observation. | ||
|
||
""" | ||
def __init__(self, env: AECEnv | ParallelEnv, change_observation_fn: Callable, change_obs_space_fn: Callable = None): | ||
assert callable( | ||
change_observation_fn | ||
), "change_observation_fn needs to be a function. It is {}".format( | ||
change_observation_fn | ||
) | ||
assert change_obs_space_fn is None or callable( | ||
change_obs_space_fn | ||
), "change_obs_space_fn needs to be a function. It is {}".format( | ||
change_obs_space_fn | ||
) | ||
|
||
self.change_observation_fn = change_observation_fn | ||
self.change_obs_space_fn = change_obs_space_fn | ||
|
||
super().__init__(env) | ||
|
||
if hasattr(self, "possible_agents"): | ||
for agent in self.possible_agents: | ||
# call any validation logic in this function | ||
self.observation_space(agent) | ||
|
||
def _modify_action(self, agent: AgentID, action: ActionType) -> ActionType: | ||
""" | ||
Modify the action. | ||
|
||
Args: | ||
agent (str): The agent for which to modify the action. | ||
action (Discrete): The original action. | ||
|
||
Returns: | ||
Discrete: The modified action. | ||
""" | ||
return action | ||
|
||
def _check_wrapper_params(self) -> None: | ||
""" | ||
Check wrapper parameters for consistency. | ||
|
||
Raises: | ||
AssertionError: If the provided parameters are inconsistent. | ||
""" | ||
if self.change_obs_space_fn is None and hasattr(self, "possible_agents"): | ||
for agent in self.possible_agents: | ||
assert isinstance( | ||
self.observation_space(agent), Box | ||
), "the observation_lambda_wrapper only allows the change_obs_space_fn argument to be optional for Box observation spaces" | ||
|
||
@functools.lru_cache(maxsize=None) | ||
def observation_space(self, agent: AgentID) -> Box: | ||
""" | ||
Get the modified observation space for a specific agent. | ||
|
||
Args: | ||
agent (str): The agent for which to retrieve the observation space. | ||
|
||
Returns: | ||
Box: The modified observation space. | ||
""" | ||
if self.change_obs_space_fn is None: | ||
space = self.env.observation_space(agent) | ||
try: | ||
trans_low = self.change_observation_fn(space.low, space, agent) | ||
trans_high = self.change_observation_fn(space.high, space, agent) | ||
except TypeError: | ||
trans_low = self.change_observation_fn(space.low, space) | ||
trans_high = self.change_observation_fn(space.high, space) | ||
new_low = np.minimum(trans_low, trans_high) | ||
new_high = np.maximum(trans_low, trans_high) | ||
|
||
return Box(low=new_low, high=new_high, dtype=new_low.dtype) | ||
else: | ||
old_obs_space = self.env.observation_space(agent) | ||
try: | ||
return self.change_obs_space_fn(old_obs_space, agent) | ||
except TypeError: | ||
return self.change_obs_space_fn(old_obs_space) | ||
|
||
def _modify_observation(self, agent: AgentID, observation: Box) -> Box: | ||
""" | ||
Modify the observation. | ||
|
||
Args: | ||
agent (str): The agent for which to modify the observation. | ||
observation (Box): The original observation. | ||
|
||
Returns: | ||
Box: The modified observation. | ||
""" | ||
old_obs_space = self.env.observation_space(agent) | ||
try: | ||
return self.change_observation_fn(observation, old_obs_space, agent) | ||
except TypeError: | ||
return self.change_observation_fn(observation, old_obs_space) | ||
|
||
|
||
observation_lambda_v0 = AecObservationLambda |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
from utils.base_aec_wrapper import PettingzooWrap | ||
from utils.make_defaultdict import make_defaultdict | ||
from typing import Callable | ||
from pettingzoo import AECEnv, ParallelEnv | ||
from pettingzoo.utils.env import ActionType | ||
|
||
|
||
class AecRewardLambda(PettingzooWrap): | ||
""" | ||
A wrapper for AEC environments that allows the modification of rewards. | ||
|
||
Args: | ||
env (AECEnv | ParallelEnv): The environment to be wrapped. | ||
change_reward_fn (Callable): A function that modifies rewards. | ||
|
||
Raises: | ||
AssertionError: If `change_reward_fn` is not callable. | ||
|
||
Attributes: | ||
_change_reward_fn (Callable): The function used to modify rewards. | ||
|
||
Methods: | ||
reset(seed: int = None, options: dict = None) -> None: | ||
Reset the environment, applying the reward modification to initial rewards. | ||
|
||
step(action: ActionType) -> None: | ||
Take a step in the environment, applying the reward modification to the received rewards. | ||
|
||
""" | ||
def __init__(self, env: AECEnv | ParallelEnv, change_reward_fn: Callable): | ||
assert callable( | ||
change_reward_fn | ||
), f"change_reward_fn needs to be a function. It is {change_reward_fn}" | ||
self._change_reward_fn = change_reward_fn | ||
|
||
super().__init__(env) | ||
|
||
def _check_wrapper_params(self) -> None: | ||
""" | ||
Check wrapper parameters for consistency. | ||
|
||
This method is currently empty and does not perform any checks. | ||
""" | ||
pass | ||
|
||
def _modify_spaces(self) -> None: | ||
""" | ||
Modify the spaces of the wrapped environment. | ||
|
||
This method is currently empty and does not modify the spaces. | ||
""" | ||
pass | ||
|
||
def reset(self, seed: int = None, options: dict = None) -> None: | ||
""" | ||
Reset the environment, applying the reward modification to initial rewards. | ||
|
||
Args: | ||
seed (int, optional): A seed for environment randomization. Default is None. | ||
options (dict, optional): Additional options for environment initialization. Default is None. | ||
""" | ||
super().reset(seed=seed, options=options) | ||
self.rewards = { | ||
agent: self._change_reward_fn(reward) | ||
for agent, reward in self.rewards.items() | ||
} | ||
self.__cumulative_rewards = make_defaultdict({a: 0 for a in self.agents}) | ||
self._accumulate_rewards() | ||
|
||
def step(self, action: ActionType) -> None: | ||
""" | ||
Take a step in the environment, applying the reward modification to the received rewards. | ||
|
||
Args: | ||
action (ActionType): The action to be taken in the environment. | ||
""" | ||
agent = self.env.agent_selection | ||
super().step(action) | ||
self.rewards = { | ||
agent: self._change_reward_fn(reward) | ||
for agent, reward in self.rewards.items() | ||
} | ||
self.__cumulative_rewards[agent] = 0 | ||
self._cumulative_rewards = self.__cumulative_rewards | ||
self._accumulate_rewards() | ||
|
||
|
||
reward_lambda_v0 = AecRewardLambda | ||
""" example: | ||
reward_lambda_v0 = WrapperChooser( | ||
aec_wrapper=AecRewardLambda, par_wrapper=ParRewardLambda | ||
)""" |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from pettingzoo import AECEnv | ||
from pettingzoo.utils.agent_selector import agent_selector | ||
|
||
|
||
class DummyEnv(AECEnv): | ||
metadata = {"render_modes": ["human"], "is_parallelizable": True} | ||
|
||
def __init__(self, observations, observation_spaces, action_spaces): | ||
super().__init__() | ||
self._observations = observations | ||
self._observation_spaces = observation_spaces | ||
|
||
self.agents = sorted([x for x in observation_spaces.keys()]) | ||
self.possible_agents = self.agents[:] | ||
self._agent_selector = agent_selector(self.agents) | ||
self.agent_selection = self._agent_selector.reset() | ||
self._action_spaces = action_spaces | ||
|
||
self.steps = 0 | ||
|
||
def observation_space(self, agent): | ||
return self._observation_spaces[agent] | ||
|
||
def action_space(self, agent): | ||
return self._action_spaces[agent] | ||
|
||
def observe(self, agent): | ||
return self._observations[agent] | ||
|
||
def step(self, action, observe=True): | ||
if ( | ||
self.terminations[self.agent_selection] | ||
or self.truncations[self.agent_selection] | ||
): | ||
return self._was_dead_step(action) | ||
self._cumulative_rewards[self.agent_selection] = 0 | ||
self.agent_selection = self._agent_selector.next() | ||
self.steps += 1 | ||
if self.steps >= 5 * len(self.agents): | ||
self.truncations = {a: True for a in self.agents} | ||
|
||
self._accumulate_rewards() | ||
self._deads_step_first() | ||
|
||
def reset(self, seed=None, options=None): | ||
self.agents = self.possible_agents[:] | ||
self._agent_selector = agent_selector(self.agents) | ||
self.agent_selection = self._agent_selector.reset() | ||
self.rewards = {a: 1 for a in self.agents} | ||
self._cumulative_rewards = {a: 0 for a in self.agents} | ||
self.terminations = {a: False for a in self.agents} | ||
self.truncations = {a: False for a in self.agents} | ||
self.infos = {a: {} for a in self.agents} | ||
self.steps = 0 | ||
|
||
def close(self): | ||
pass |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you could have the docstrings be more similar to the syntax in
conversions.py
and other wrappers in PettingZoo that would be good. Also I don't know the wrapper super well but I thinkmy_module
andmy_param
may not be the correct names. In general using GPT to write this stuff is pretty risky because it could very well just be completely made up and I don't have the time to look through all of the specifics to ensure it's correct. If you can do so yourself and double check then that's great but I'm hesitant to include too much details because it could be incorrect. Look elsewhere throughout the repo to see how the formatting is done for the docstrings.And for the example format, see chess.py as we have an example using the
>>>
format which gets tested under the doctests with pytest.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok thanks for the advice, I'll fix the docstring format and write them from scratch.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree, definitely not use the output of CGPT directly for docstrings.