
Commit 79de877

ffelten and elliottower authored
Update wrappers to use __getattr__ instead of redefining attributes (Farama-Foundation#1140)
Co-authored-by: elliottower <[email protected]>
1 parent 7a67cde commit 79de877

7 files changed: +44 -154 lines changed
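
The gist of the change, before the per-file diffs below: wrappers previously copied attributes such as possible_agents, metadata, rewards, and agents from the wrapped environment in __init__, reset, and step, which left stale copies whenever the underlying environment changed them. They now forward attribute lookups to the wrapped environment on demand via __getattr__. A minimal sketch of that pattern (ToyEnv/ToyWrapper are illustrative names, not PettingZoo code):

# Minimal sketch (illustrative names, not PettingZoo code) of the delegation
# pattern this commit adopts: the wrapper stores only a reference to the
# wrapped env and forwards attribute lookups to it on demand.

class ToyEnv:
    def __init__(self):
        self.agents = ["player_0", "player_1"]
        self.rewards = {agent: 0 for agent in self.agents}


class ToyWrapper:
    def __init__(self, env):
        self.env = env  # no attributes are copied

    def __getattr__(self, name):
        # Only called when normal lookup fails, so wrapper-local attributes
        # such as self.env are unaffected.
        if name.startswith("_"):
            raise AttributeError(f"accessing private attribute '{name}' is prohibited")
        return getattr(self.env, name)


env = ToyEnv()
wrapped = ToyWrapper(env)
env.rewards["player_0"] = 1        # mutate the underlying environment
print(wrapped.rewards)             # {'player_0': 1, 'player_1': 0} -- always live

The trade-off is that attribute reads always reflect the wrapped environment's current state, at the cost of one extra getattr call per access.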

pettingzoo/atari/base_atari_env.py (+4 -3)

@@ -169,16 +169,17 @@ def __init__(
         self._screen = None
         self._seed(seed)
 
-    def _seed(self, seed=None):
-        if seed is None:
-            _, seed = seeding.np_random()
+    def _seed(self, seed):
+        self.np_random, seed = seeding.np_random(seed)
         self.ale.setInt(b"random_seed", seed)
         self.ale.loadROM(self.rom_path)
         self.ale.setMode(self.mode)
 
     def reset(self, seed=None, options=None):
         if seed is not None:
             self._seed(seed=seed)
+        else:
+            self.np_random, seed = seeding.np_random()
         self.ale.reset_game()
         self.agents = self.possible_agents[:]
         self.terminations = {agent: False for agent in self.possible_agents}
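
For reference, a small sketch of the gymnasium seeding helper used above, assuming the current gymnasium.utils.seeding API: np_random returns a numpy Generator together with the concrete seed that was used, which is why the rewritten _seed can both store self.np_random and hand an integer seed to ALE, while reset draws a fresh seed when none is supplied (the old _seed never stored a generator at all).

# Sketch of the seeding helper relied on above (assumes gymnasium's
# gymnasium.utils.seeding API).
from gymnasium.utils import seeding

rng, used_seed = seeding.np_random(42)   # explicit seed: used_seed == 42
sample = int(rng.integers(0, 10))        # rng is a numpy Generator, usable right away

rng2, auto_seed = seeding.np_random()    # seed=None: a fresh seed is drawn for you
print(used_seed, auto_seed, sample)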

pettingzoo/utils/wrappers/base.py (+6 -77)

@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import warnings
 from typing import Any
 
 import gymnasium.spaces

@@ -19,72 +18,12 @@ def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]):
         super().__init__()
         self.env = env
 
-        try:
-            self.possible_agents = self.env.possible_agents
-        except AttributeError:
-            pass
-
-        self.metadata = self.env.metadata
-
-        # we don't want these defined as we don't want them used before they are gotten
-        # self.agent_selection = self.env.agent_selection
-
-        # self.rewards = self.env.rewards
-        # self.dones = self.env.dones
-
-        # we don't want to care one way or the other whether environments have an infos or not before reset
-        try:
-            self.infos = self.env.infos
-        except AttributeError:
-            pass
-
-        # Not every environment has the .state_space attribute implemented
-        try:
-            self.state_space = (
-                self.env.state_space  # pyright: ignore[reportGeneralTypeIssues]
-            )
-        except AttributeError:
-            pass
-
     def __getattr__(self, name: str) -> Any:
         """Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
-        if name.startswith("_"):
+        if name.startswith("_") and name != "_cumulative_rewards":
             raise AttributeError(f"accessing private attribute '{name}' is prohibited")
         return getattr(self.env, name)
 
-    @property
-    def observation_spaces(self) -> dict[AgentID, gymnasium.spaces.Space]:
-        warnings.warn(
-            "The `observation_spaces` dictionary is deprecated. Use the `observation_space` function instead."
-        )
-        try:
-            return {
-                agent: self.observation_space(agent) for agent in self.possible_agents
-            }
-        except AttributeError as e:
-            raise AttributeError(
-                "The base environment does not have an `observation_spaces` dict attribute. Use the environment's `observation_space` method instead"
-            ) from e
-
-    @property
-    def action_spaces(self) -> dict[AgentID, gymnasium.spaces.Space]:
-        warnings.warn(
-            "The `action_spaces` dictionary is deprecated. Use the `action_space` function instead."
-        )
-        try:
-            return {agent: self.action_space(agent) for agent in self.possible_agents}
-        except AttributeError as e:
-            raise AttributeError(
-                "The base environment does not have an action_spaces dict attribute. Use the environment's `action_space` method instead"
-            ) from e
-
-    def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space:
-        return self.env.observation_space(agent)
-
-    def action_space(self, agent: AgentID) -> gymnasium.spaces.Space:
-        return self.env.action_space(agent)
-
     @property
     def unwrapped(self) -> AECEnv:
         return self.env.unwrapped

@@ -98,14 +37,6 @@ def render(self) -> None | np.ndarray | str | list:
     def reset(self, seed: int | None = None, options: dict | None = None):
         self.env.reset(seed=seed, options=options)
 
-        self.agent_selection = self.env.agent_selection
-        self.rewards = self.env.rewards
-        self.terminations = self.env.terminations
-        self.truncations = self.env.truncations
-        self.infos = self.env.infos
-        self.agents = self.env.agents
-        self._cumulative_rewards = self.env._cumulative_rewards
-
     def observe(self, agent: AgentID) -> ObsType | None:
         return self.env.observe(agent)
 

@@ -115,13 +46,11 @@ def state(self) -> np.ndarray:
     def step(self, action: ActionType) -> None:
         self.env.step(action)
 
-        self.agent_selection = self.env.agent_selection
-        self.rewards = self.env.rewards
-        self.terminations = self.env.terminations
-        self.truncations = self.env.truncations
-        self.infos = self.env.infos
-        self.agents = self.env.agents
-        self._cumulative_rewards = self.env._cumulative_rewards
+    def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space:
+        return self.env.observation_space(agent)
+
+    def action_space(self, agent: AgentID) -> gymnasium.spaces.Space:
+        return self.env.action_space(agent)
 
     def __str__(self) -> str:
         """Returns a name which looks like: "max_observation<space_invaders_v1>"."""

pettingzoo/utils/wrappers/base_parallel.py (+8 -50)

@@ -1,40 +1,26 @@
 from __future__ import annotations
 
-import warnings
-
 import gymnasium.spaces
 import numpy as np
-from gymnasium.utils import seeding
 
 from pettingzoo.utils.env import ActionType, AgentID, ObsType, ParallelEnv
 
 
 class BaseParallelWrapper(ParallelEnv[AgentID, ObsType, ActionType]):
     def __init__(self, env: ParallelEnv[AgentID, ObsType, ActionType]):
+        super().__init__()
         self.env = env
 
-        self.metadata = env.metadata
-        try:
-            self.possible_agents = env.possible_agents
-        except AttributeError:
-            pass
-
-        # Not every environment has the .state_space attribute implemented
-        try:
-            self.state_space = (
-                self.env.state_space  # pyright: ignore[reportGeneralTypeIssues]
-            )
-        except AttributeError:
-            pass
+    def __getattr__(self, name: str):
+        """Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
+        if name.startswith("_"):
+            raise AttributeError(f"accessing private attribute '{name}' is prohibited")
+        return getattr(self.env, name)
 
     def reset(
         self, seed: int | None = None, options: dict | None = None
     ) -> tuple[dict[AgentID, ObsType], dict[AgentID, dict]]:
-        self.np_random, _ = seeding.np_random(seed)
-
-        res, info = self.env.reset(seed=seed, options=options)
-        self.agents = self.env.agents
-        return res, info
+        return self.env.reset(seed=seed, options=options)
 
     def step(
         self, actions: dict[AgentID, ActionType]

@@ -45,9 +31,7 @@ def step(
         dict[AgentID, bool],
         dict[AgentID, dict],
     ]:
-        res = self.env.step(actions)
-        self.agents = self.env.agents
-        return res
+        return self.env.step(actions)
 
     def render(self) -> None | np.ndarray | str | list:
         return self.env.render()

@@ -62,32 +46,6 @@ def unwrapped(self) -> ParallelEnv:
     def state(self) -> np.ndarray:
         return self.env.state()
 
-    @property
-    def observation_spaces(self) -> dict[AgentID, gymnasium.spaces.Space]:
-        warnings.warn(
-            "The `observation_spaces` dictionary is deprecated. Use the `observation_space` function instead."
-        )
-        try:
-            return {
-                agent: self.observation_space(agent) for agent in self.possible_agents
-            }
-        except AttributeError as e:
-            raise AttributeError(
-                "The base environment does not have an `observation_spaces` dict attribute. Use the environments `observation_space` method instead"
-            ) from e
-
-    @property
-    def action_spaces(self) -> dict[AgentID, gymnasium.spaces.Space]:
-        warnings.warn(
-            "The `action_spaces` dictionary is deprecated. Use the `action_space` function instead."
-        )
-        try:
-            return {agent: self.action_space(agent) for agent in self.possible_agents}
-        except AttributeError as e:
-            raise AttributeError(
-                "The base environment does not have an action_spaces dict attribute. Use the environments `action_space` method instead"
-            ) from e
-
     def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space:
         return self.env.observation_space(agent)

pettingzoo/utils/wrappers/multi_episode_env.py (+2 -4)

@@ -59,22 +59,20 @@ def step(self, action: ActionType) -> None:
             None:
         """
         super().step(action)
-        if self.agents:
+        if self.env.agents:
            return
 
         # if we've crossed num_episodes, truncate all agents
         # and let the environment terminate normally
         if self._episodes_elapsed >= self._num_episodes:
-            self.truncations = {agent: True for agent in self.agents}
+            self.env.unwrapped.truncations = {agent: True for agent in self.env.agents}
            return
 
         # if no more agents and haven't had enough episodes,
         # increment the number of episodes and the seed for reset
         self._episodes_elapsed += 1
         self._seed = self._seed + 1 if self._seed else None
         super().reset(seed=self._seed, options=self._options)
-        self.truncations = {agent: False for agent in self.agents}
-        self.terminations = {agent: False for agent in self.agents}
 
     def __str__(self) -> str:
         """__str__.

pettingzoo/utils/wrappers/order_enforcing.py (+18 -13)

@@ -45,7 +45,10 @@ def __getattr__(self, value: str) -> Any:
         elif value == "render_mode" and hasattr(self.env, "render_mode"):
             return self.env.render_mode  # pyright: ignore[reportGeneralTypeIssues]
         elif value == "possible_agents":
-            EnvLogger.error_possible_agents_attribute_missing("possible_agents")
+            try:
+                return self.env.possible_agents
+            except AttributeError:
+                EnvLogger.error_possible_agents_attribute_missing("possible_agents")
         elif value == "observation_spaces":
             raise AttributeError(
                 "The base environment does not have an possible_agents attribute. Use the environments `observation_space` method instead"

@@ -58,20 +61,22 @@ def __getattr__(self, value: str) -> Any:
             raise AttributeError(
                 "agent_order has been removed from the API. Please consider using agent_iter instead."
             )
-        elif value in {
-            "rewards",
-            "terminations",
-            "truncations",
-            "infos",
-            "agent_selection",
-            "num_agents",
-            "agents",
-        }:
+        elif (
+            value
+            in {
+                "rewards",
+                "terminations",
+                "truncations",
+                "infos",
+                "agent_selection",
+                "num_agents",
+                "agents",
+            }
+            and not self._has_reset
+        ):
             raise AttributeError(f"{value} cannot be accessed before reset")
         else:
-            raise AttributeError(
-                f"'{type(self).__name__}' object has no attribute '{value}'"
-            )
+            return super().__getattr__(value)
 
     def render(self) -> None | np.ndarray | str | list:
         if not self._has_reset:
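
Expected access pattern after this change, as a hedged usage sketch (tictactoe_v3 is just an example of an AEC environment whose env() factory applies the order-enforcing wrapper): the listed attributes still raise before reset, but once _has_reset is true the lookup falls through to super().__getattr__, i.e. the BaseWrapper delegation, and returns the live value from the wrapped environment.

# Hedged usage sketch of the new fallthrough behavior; tictactoe_v3 is just an
# example of an AEC env that ships with OrderEnforcingWrapper applied.
from pettingzoo.classic import tictactoe_v3

env = tictactoe_v3.env()

try:
    env.rewards                 # before reset: still blocked by the wrapper
except AttributeError as err:
    print(err)                  # "rewards cannot be accessed before reset"

env.reset(seed=0)
print(env.rewards)              # after reset: falls through to BaseWrapper.__getattr__
                                # and returns the live dict from the wrapped env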

pettingzoo/utils/wrappers/terminate_illegal.py (+5 -5)

@@ -67,13 +67,13 @@ def step(self, action: ActionType) -> None:
             and not _prev_action_mask[action]
         ):
             EnvLogger.warn_on_illegal_move()
-            self._cumulative_rewards[self.agent_selection] = 0
-            self.terminations = {d: True for d in self.agents}
-            self.truncations = {d: True for d in self.agents}
+            self.env.unwrapped._cumulative_rewards[self.agent_selection] = 0
+            self.env.unwrapped.terminations = {d: True for d in self.agents}
+            self.env.unwrapped.truncations = {d: True for d in self.agents}
             self._prev_obs = None
             self._prev_info = None
-            self.rewards = {d: 0 for d in self.truncations}
-            self.rewards[current_agent] = float(self._illegal_value)
+            self.env.unwrapped.rewards = {d: 0 for d in self.truncations}
+            self.env.unwrapped.rewards[current_agent] = float(self._illegal_value)
             self._accumulate_rewards()
             self._deads_step_first()
             self._terminated = True

test/action_mask_test.py (+1 -2)

@@ -2,7 +2,7 @@
 
 import pytest
 
-from pettingzoo.test import api_test, seed_test
+from pettingzoo.test import seed_test
 from pettingzoo.test.example_envs import (
     generated_agents_env_action_mask_info_v0,
     generated_agents_env_action_mask_obs_v0,

@@ -20,7 +20,6 @@
 def test_action_mask(env_constructor: Type[AECEnv]):
     """Test that environments function deterministically in cases where action mask is in observation, or in info."""
     seed_test(env_constructor)
-    api_test(env_constructor())
 
     # Step through the environment according to example code given in AEC documentation (following action mask)
     env = env_constructor()
