forked from Farama-Foundation/chatarena
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpettingzoo_tictactoe.py
131 lines (105 loc) · 4.29 KB
/
pettingzoo_tictactoe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import re
from typing import List, Union
# from pettingzoo.classic import tictactoe_v3
from chatarena.environments.base import Environment, TimeStep, register_env
from ..message import Message, MessagePool
def action_string_to_action(action: str) -> int:
    """Translate a textual move into a board action index.

    The text must start with ``"X: (r, c)"`` or ``"O: (r, c)"``, where ``r``
    and ``c`` are single digits giving the 1-based row and column. Only the
    start of the string is matched, so anything after the closing parenthesis
    is ignored.

    Args:
        action: The move as written by a player.

    Returns:
        ``(row - 1) + 3 * (column - 1)`` for a well-formed move whose row and
        column both lie in 1..3, otherwise ``-1``.
    """
    parsed = re.match(r"(X|O): \((\d), (\d)\)", action)
    if parsed is None:
        return -1

    # Group 1 is the player's mark; only the coordinates matter for the index.
    row = int(parsed.group(2))
    column = int(parsed.group(3))
    if not (1 <= row <= 3 and 1 <= column <= 3):
        return -1

    # Shift to 0-based coordinates; consecutive indices run down a column.
    return (row - 1) + 3 * (column - 1)
@register_env
class PettingzooTicTacToe(Environment):
    """Chat-driven tic-tac-toe backed by PettingZoo's ``tictactoe_v3``.

    Players submit moves as text (for example ``"X: (1, 2)"``). Each move and
    each rendered board is recorded in a ``MessagePool``; the pool's messages
    are what players receive as observations.

    The game is handled as a two-player game: ``player_names[0]`` moves first
    and the turn alternates between index 0 and index 1.
    """

    type_name = "pettingzoo:tictactoe"

    def __init__(self, player_names: List[str], **kwargs):
        """Create the wrapped PettingZoo environment and reset it.

        Args:
            player_names: Names of the players; index 0 moves first.
            **kwargs: Forwarded to ``Environment.__init__``.

        Raises:
            ImportError: If ``pettingzoo.classic`` cannot be imported.
        """
        super().__init__(player_names=player_names, **kwargs)

        # The module-level import of ``tictactoe_v3`` is commented out, so the
        # name must be imported here; otherwise construction fails with a
        # NameError. Importing lazily keeps this module importable when
        # PettingZoo is not installed.
        try:
            from pettingzoo.classic import tictactoe_v3
        except ImportError as err:
            raise ImportError(
                "PettingzooTicTacToe requires PettingZoo's classic environments "
                "(pettingzoo.classic.tictactoe_v3)."
            ) from err

        self.env = tictactoe_v3.env()

        # The "state" of the environment is maintained by the message pool
        self.message_pool = MessagePool()
        self._terminal = False
        self.reset()

    def _player_rewards(self) -> dict:
        """Return the wrapped environment's latest rewards keyed by player name.

        ``env.last()`` only reports the reward of the agent that is about to
        act, which after a move is the opponent of the player who just moved.
        The per-agent ``rewards`` mapping is read instead. Players are paired
        with PettingZoo agents by position in their respective lists.
        """
        agent_rewards = self.env.rewards
        return {
            name: agent_rewards.get(agent, 0)
            for name, agent in zip(self.player_names, self.env.possible_agents)
        }

    def reset(self):
        """Restart the game and clear the message history.

        Returns:
            A ``TimeStep`` holding the (empty) message history, a per-player
            reward dict and the terminal flag reported by the wrapped
            environment.
        """
        self.env.reset()
        self.current_player = 0
        self.turn = 0
        self.message_pool.reset()

        _, _, terminal, _, _ = self.env.last()
        self._terminal = terminal
        # Use the same per-player dict shape that ``step`` returns.
        return TimeStep(
            observation=self.get_observation(),
            reward=self._player_rewards(),
            terminal=terminal,
        )

    def get_next_player(self) -> str:
        """Return the name of the player whose turn it is."""
        return self.player_names[self.current_player]

    def get_observation(self, player_name=None) -> List[Message]:
        """Return the messages that make up the observation.

        Args:
            player_name: When ``None``, every message in the pool is returned.
                Otherwise only the messages visible to that player are
                returned, queried with ``turn=self.turn + 1``.
        """
        if player_name is None:
            return self.message_pool.get_all_messages()
        return self.message_pool.get_visible_messages(
            player_name, turn=self.turn + 1
        )

    def _moderator_speak(self, text: str, visible_to: Union[str, List[str]] = "all"):
        """Append a message from the "Moderator" to the message pool.

        Args:
            text: Message content.
            visible_to: Recipient or recipients of the message; defaults to
                ``"all"``.
        """
        message = Message(
            agent_name="Moderator", content=text, turn=self.turn, visible_to=visible_to
        )
        self.message_pool.append_message(message)

    def is_terminal(self) -> bool:
        """Return the terminal flag recorded by the last ``reset``/``step``."""
        return self._terminal

    def step(self, player_name: str, action: str) -> TimeStep:
        """Play one move for ``player_name`` and advance the turn.

        Args:
            player_name: Must be the player returned by ``get_next_player``.
            action: Textual move, e.g. ``"X: (1, 2)"``.

        Returns:
            A ``TimeStep`` with the full message history, a reward dict keyed
            by player name and the terminal flag.

        Raises:
            AssertionError: If it is not ``player_name``'s turn.
            ValueError: If ``action`` cannot be parsed into a board position.
        """
        assert (
            player_name == self.get_next_player()
        ), f"Wrong player! It is {self.get_next_player()} turn."

        # Validate before mutating anything, so that a rejected action does
        # not leave a stray message in the pool.
        action_index = action_string_to_action(action)
        if action_index == -1:
            raise ValueError(f"Invalid action: {action}")

        self.message_pool.append_message(
            Message(agent_name=player_name, content=action, turn=self.turn)
        )

        self.env.step(action_index)
        obs_dict, _, terminal, _, _ = self.env.last()
        self._terminal = terminal  # Update the terminal state

        # Read the rewards before switching players; see _player_rewards for
        # why the reward returned by last() is not used.
        reward = self._player_rewards()

        self.current_player = 1 - self.current_player
        self.turn += 1
        self._moderator_speak("\n" + self.render_ansi(obs_dict["observation"]))

        return TimeStep(
            observation=self.get_observation(), reward=reward, terminal=terminal
        )

    def check_action(self, action: str, agent_name: str) -> bool:
        """Return whether ``action`` is a well-formed and currently legal move.

        Args:
            action: Textual move to validate.
            agent_name: Unused; validation does not depend on the agent.

        Returns:
            ``False`` when the text cannot be parsed or when the wrapped
            environment's action mask has 0 at the resulting index,
            ``True`` otherwise.
        """
        action_index = action_string_to_action(action)
        if action_index == -1:
            return False
        return self.env.last()[0]["action_mask"][action_index] != 0

    def render_ansi(self, observation):
        """Render a board observation as text, one line per row.

        Args:
            observation: Array-like supporting ``transpose(1, 0, 2)`` whose
                last axis is indexed by player index. Presumably the board
                planes from the wrapped environment's observation — TODO
                confirm the exact shape.

        Returns:
            A string where each row reads ``"| X | O | _ |"`` followed by a
            newline.
        """
        # NOTE(review): drawing plane ``self.current_player`` as X appears to
        # rely on the observation being relative to the agent about to act and
        # on ``current_player`` already pointing at that agent — confirm
        # against the PettingZoo observation layout.
        x_plane = self.current_player
        o_plane = 1 - self.current_player

        lines = []
        for row in observation.transpose(1, 0, 2):
            cells = []
            for cell in row:
                if cell[x_plane] == 1:
                    symbol = "X"
                elif cell[o_plane] == 1:
                    symbol = "O"
                else:
                    symbol = "_"
                cells.append(" " + symbol + " |")
            lines.append("|" + "".join(cells) + "\n")
        return "".join(lines)

    def print(self):
        """Print the current board to standard output."""
        obs_dict, _, _, _, _ = self.env.last()
        print(self.render_ansi(obs_dict["observation"]))