|
3 | 3 | import pytest |
4 | 4 |
|
5 | 5 | from pettingzoo.butterfly import pistonball_v6 |
6 | | -from pettingzoo.classic import texas_holdem_no_limit_v6 |
7 | | -from pettingzoo.utils.wrappers import MultiEpisodeEnv, MultiEpisodeParallelEnv |
| 6 | +from pettingzoo.classic import texas_holdem_no_limit_v6, tictactoe_v3 |
| 7 | +from pettingzoo.utils.wrappers import ( |
| 8 | + BaseWrapper, |
| 9 | + MultiEpisodeEnv, |
| 10 | + MultiEpisodeParallelEnv, |
| 11 | + TerminateIllegalWrapper, |
| 12 | +) |
8 | 13 |
|
9 | 14 |
|
10 | 15 | @pytest.mark.parametrize(("num_episodes"), [1, 2, 3, 4, 5, 6]) |
@@ -67,3 +72,65 @@ def test_multi_episode_parallel_env_wrapper(num_episodes) -> None: |
67 | 72 | assert ( |
68 | 73 | steps == num_episodes * 125 |
69 | 74 | ), f"Expected to have 125 steps per episode, got {steps / num_episodes}." |
| 75 | + |
| 76 | + |
def _do_game(env: TerminateIllegalWrapper, seed: int) -> None:
    """Play one full game on ``env`` with reproducible random moves.

    Args:
        env: The environment to play. It must be a TerminateIllegalWrapper
            instance; this is asserted up front.
        seed: Value used both to reset the environment and to seed every
            agent's action space.
    """
    assert isinstance(
        env, TerminateIllegalWrapper
    ), "test_terminate_illegal must use TerminateIllegalWrapper"
    env.reset(seed)

    # Seed each agent's action space so the sampled moves are reproducible.
    for name in env.agents:
        env.action_space(name).seed(seed)

    for current in env.agent_iter():
        _obs, _reward, terminated, truncated, _info = env.last()
        # Finished agents are stepped with None; everyone else plays a
        # randomly sampled action (sample() is called without a mask).
        move = None if (terminated or truncated) else env.action_space(current).sample()
        env.step(move)
| 95 | + |
| 96 | + |
def test_terminate_illegal() -> None:
    """Test for a problem with the terminate illegal wrapper.

    The problem is that env variables, including agent_selection, are set by
    calls from TerminateIllegalWrapper to env functions. However, they are
    called by the wrapper object, not the env, so they are set in the wrapper
    object rather than the base env object. When the code later tries to run,
    the values get updated in the env code, but the wrapper pulls its own
    values that shadow them.

    The test here confirms that is fixed.
    """
    # not using env() because we need to ensure that the env is
    # wrapped by TerminateIllegalWrapper
    raw_env = tictactoe_v3.raw_env()
    env = TerminateIllegalWrapper(raw_env, illegal_reward=-1)

    _do_game(env, 42)
    # The bug is triggered by a corrupted state left behind after a game is
    # terminated due to an illegal move, so the game must be run twice to
    # see the effect.
    _do_game(env, 42)

    # Collect the agent_selection value stored directly on each layer of the
    # wrapper stack. Reading __dict__ sees only what the layer itself holds
    # (None if it holds nothing), not a value forwarded from an inner env.
    base_env = env.unwrapped
    layer = env
    agent_selections = []
    # Identity comparison: we are looking for the very same base env object.
    while layer is not base_env:
        agent_selections.append(layer.__dict__.get("agent_selection"))
        assert isinstance(layer, BaseWrapper)
        layer = layer.env

    # last one comes from the actual (unwrapped) env
    agent_selections.append(layer.__dict__.get("agent_selection"))

    # drop layers that do not store agent_selection themselves
    agent_selections = [x for x in agent_selections if x is not None]

    # all values must be the same, or else the wrapper and env are mismatched
    assert len(set(agent_selections)) == 1, "agent_selection mismatch"
0 commit comments