Skip to content

Commit 1eef080

Browse files
authored
TerminateIllegalWrapper fix (#1206)
1 parent 9f441fe commit 1eef080

File tree

2 files changed

+73
-9
lines changed

2 files changed

+73
-9
lines changed

pettingzoo/utils/wrappers/terminate_illegal.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# pyright reportGeneralTypeIssues=false
21
from __future__ import annotations
32

43
from pettingzoo.utils.env import ActionType, AECEnv, AgentID, ObsType
@@ -20,6 +19,7 @@ def __init__(
2019
self._illegal_value = illegal_reward
2120
self._prev_obs = None
2221
self._prev_info = None
22+
self._terminated = False # terminated by an illegal move
2323

2424
def reset(self, seed: int | None = None, options: dict | None = None) -> None:
2525
self._terminated = False
@@ -42,7 +42,6 @@ def step(self, action: ActionType) -> None:
4242
if self._prev_obs is None:
4343
self.observe(self.agent_selection)
4444
if isinstance(self._prev_obs, dict):
45-
assert self._prev_obs is not None
4645
assert (
4746
"action_mask" in self._prev_obs
4847
), f"`action_mask` not found in dictionary observation: {self._prev_obs}. Action mask must either be in `observation['action_mask']` or `info['action_mask']` to use TerminateIllegalWrapper."
@@ -60,7 +59,7 @@ def step(self, action: ActionType) -> None:
6059
self.terminations[self.agent_selection]
6160
or self.truncations[self.agent_selection]
6261
):
63-
self._was_dead_step(action) # pyright: ignore[reportGeneralTypeIssues]
62+
self.env.unwrapped._was_dead_step(action)
6463
elif (
6564
not self.terminations[self.agent_selection]
6665
and not self.truncations[self.agent_selection]
@@ -70,12 +69,10 @@ def step(self, action: ActionType) -> None:
7069
self.env.unwrapped._cumulative_rewards[self.agent_selection] = 0
7170
self.env.unwrapped.terminations = {d: True for d in self.agents}
7271
self.env.unwrapped.truncations = {d: True for d in self.agents}
73-
self._prev_obs = None
74-
self._prev_info = None
7572
self.env.unwrapped.rewards = {d: 0 for d in self.truncations}
7673
self.env.unwrapped.rewards[current_agent] = float(self._illegal_value)
77-
self._accumulate_rewards()
78-
self._deads_step_first()
74+
self.env.unwrapped._accumulate_rewards()
75+
self.env.unwrapped._deads_step_first()
7976
self._terminated = True
8077
else:
8178
super().step(action)

test/wrapper_test.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,13 @@
33
import pytest
44

55
from pettingzoo.butterfly import pistonball_v6
6-
from pettingzoo.classic import texas_holdem_no_limit_v6
7-
from pettingzoo.utils.wrappers import MultiEpisodeEnv, MultiEpisodeParallelEnv
6+
from pettingzoo.classic import texas_holdem_no_limit_v6, tictactoe_v3
7+
from pettingzoo.utils.wrappers import (
8+
BaseWrapper,
9+
MultiEpisodeEnv,
10+
MultiEpisodeParallelEnv,
11+
TerminateIllegalWrapper,
12+
)
813

914

1015
@pytest.mark.parametrize(("num_episodes"), [1, 2, 3, 4, 5, 6])
@@ -67,3 +72,65 @@ def test_multi_episode_parallel_env_wrapper(num_episodes) -> None:
6772
assert (
6873
steps == num_episodes * 125
6974
), f"Expected to have 125 steps per episode, got {steps / num_episodes}."
75+
76+
77+
def _do_game(env: TerminateIllegalWrapper, seed: int) -> None:
78+
"""Run a single game with reproducible random moves."""
79+
assert isinstance(
80+
env, TerminateIllegalWrapper
81+
), "test_terminate_illegal must use TerminateIllegalWrapper"
82+
env.reset(seed)
83+
for agent in env.agents:
84+
# make the random moves reproducible
85+
env.action_space(agent).seed(seed)
86+
87+
for agent in env.agent_iter():
88+
_, _, termination, truncation, _ = env.last()
89+
90+
if termination or truncation:
91+
env.step(None)
92+
else:
93+
action = env.action_space(agent).sample()
94+
env.step(action)
95+
96+
97+
def test_terminate_illegal() -> None:
98+
"""Test for a problem with terminate illegal wrapper.
99+
100+
The problem is that env variables, including agent_selection, are set by
101+
calls from TerminateIllegalWrapper to env functions. However, they are
102+
called by the wrapper object, not the env so they are set in the wrapper
103+
object rather than the base env object. When the code later tries to run,
104+
the values get updated in the env code, but the wrapper pulls its own
105+
values that shadow them.
106+
107+
The test here confirms that is fixed.
108+
"""
109+
# not using env() because we need to ensure that the env is
110+
# wrapped by TerminateIllegalWrapper
111+
raw_env = tictactoe_v3.raw_env()
112+
env = TerminateIllegalWrapper(raw_env, illegal_reward=-1)
113+
114+
_do_game(env, 42)
115+
# bug is triggered by a corrupted state after a game is terminated
116+
# due to an illegal move. So we need to run the game twice to
117+
# see the effect.
118+
_do_game(env, 42)
119+
120+
# get a list of all the agent_selection values in the wrapper stack
121+
unwrapped = env
122+
agent_selections = []
123+
while unwrapped != env.unwrapped:
124+
# the actual value for this wrapper (or None if no value)
125+
agent_selections.append(unwrapped.__dict__.get("agent_selection", None))
126+
assert isinstance(unwrapped, BaseWrapper)
127+
unwrapped = unwrapped.env
128+
129+
# last one from the actual env
130+
agent_selections.append(unwrapped.__dict__.get("agent_selection", None))
131+
132+
# remove None from agent_selections
133+
agent_selections = [x for x in agent_selections if x is not None]
134+
135+
# all values must be the same, or else the wrapper and env are mismatched
136+
assert len(set(agent_selections)) == 1, "agent_selection mismatch"

0 commit comments

Comments
 (0)