Skip to content

Commit

Permalink
[TicTacToe] Extract games/tic_tac_toe.py (#1149)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Jan 8, 2024
1 parent 8903fbd commit 3277a9b
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 92 deletions.
2 changes: 1 addition & 1 deletion pgx/_src/dwg/tictactoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _make_tictactoe_dwg(dwg, state: TictactoeState, config):
)
)

for i, mark in enumerate(state._x._board):
for i, mark in enumerate(state._x.board):
x = i % BOARD_WIDTH
y = i // BOARD_HEIGHT
if mark == 0: # 先手
Expand Down
68 changes: 68 additions & 0 deletions pgx/_src/games/tic_tac_toe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2023 The Pgx Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp
from jax import Array


class GameState(NamedTuple):
color: Array = jnp.int32(0) # 0 = X, 1 = O
# 0 1 2
# 3 4 5
# 6 7 8
board: Array = -jnp.ones(9, jnp.int32) # -1 (empty), 0, 1
winner: Array = jnp.int32(-1)


class Game:
def init(self) -> GameState:
return GameState()

def step(self, state: GameState, action: Array) -> GameState:
board = state.board.at[action].set(state.color)
idx = jnp.int32([[0, 1, 2], [3, 4, 5], [6, 7, 8], [0, 3, 6], [1, 4, 7], [2, 5, 8], [0, 4, 8], [2, 4, 6]]) # type: ignore
won = (board[idx] == state.color).all(axis=1).any()
winner = jax.lax.select(won, state.color, -1)
return state._replace( # type: ignore
board=state.board.at[action].set(state.color),
color=(state.color + 1) % 2,
winner=winner,
)

def observe(self, state: GameState, color: Optional[Array] = None) -> Array:
if color is None:
color = state.color

@jax.vmap
def plane(i):
return (state.board == i).reshape((3, 3))

x = jax.lax.select(color == 0, jnp.int32([0, 1]), jnp.int32([1, 0]))
return jnp.stack(plane(x), -1)

def legal_action_mask(self, state: GameState) -> Array:
return state.board < 0

def is_terminal(self, state: GameState) -> Array:
return (state.winner >= 0) | jnp.all(state.board != -1)

def returns(self, state: GameState) -> Array:
return jax.lax.select(
state.winner >= 0,
jnp.float32([-1, -1]).at[state.winner].set(1),
jnp.zeros(2, jnp.float32),
)
101 changes: 22 additions & 79 deletions pgx/tic_tac_toe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,18 @@
import jax.numpy as jnp

import pgx.core as core
from pgx._src.games.tic_tac_toe import Game, GameState
from pgx._src.struct import dataclass
from pgx._src.types import Array, PRNGKey

FALSE = jnp.bool_(False)
TRUE = jnp.bool_(True)


@dataclass
class GameState:
_turn: Array = jnp.int32(0)
# 0 1 2
# 3 4 5
# 6 7 8
_board: Array = -jnp.ones(9, jnp.int32) # -1 (empty), 0, 1
winner: Array = jnp.int32(-1)


@dataclass
class State(core.State):
current_player: Array = jnp.int32(0)
observation: Array = jnp.zeros((3, 3, 2), dtype=jnp.bool_)
rewards: Array = jnp.float32([0.0, 0.0])
terminated: Array = FALSE
truncated: Array = FALSE
terminated: Array = jnp.bool_(False)
truncated: Array = jnp.bool_(False)
legal_action_mask: Array = jnp.ones(9, dtype=jnp.bool_)
_step_count: Array = jnp.int32(0)
_x: GameState = GameState()
Expand All @@ -52,18 +40,34 @@ def env_id(self) -> core.EnvId:
class TicTacToe(core.Env):
def __init__(self):
super().__init__()
self._game = Game()

def _init(self, key: PRNGKey) -> State:
return _init(key)
current_player = jnp.int32(jax.random.bernoulli(key))
x = self._game.init()
return State(current_player=current_player, _x=x) # type:ignore

def _step(self, state: core.State, action: Array, key) -> State:
del key
assert isinstance(state, State)
return _step(state, action)
x = self._game.step(state._x, action)
state = state.replace(current_player=(state.current_player + 1) % 2, _x=x) # type: ignore
legal_action_mask = self._game.legal_action_mask(x)
terminated = self._game.is_terminal(x)
rewards = self._game.returns(x)
assert isinstance(state, State)
should_flip = state.current_player == state._x.color
rewards = jax.lax.select(should_flip, rewards, jnp.flip(rewards))
rewards = jax.lax.select(terminated, rewards, jnp.zeros(2, jnp.float32))
return state.replace( # type: ignore
legal_action_mask=legal_action_mask, rewards=rewards, terminated=terminated
)

def _observe(self, state: core.State, player_id: Array) -> Array:
assert isinstance(state, State)
return _observe(state, player_id)
curr_color = state._x.color
my_color = jax.lax.select(player_id == state.current_player, curr_color, 1 - curr_color)
return self._game.observe(state._x, my_color)

@property
def id(self) -> core.EnvId:
Expand All @@ -76,64 +80,3 @@ def version(self) -> str:
@property
def num_players(self) -> int:
return 2


def _init(rng: PRNGKey) -> State:
current_player = jnp.int32(jax.random.bernoulli(rng))
return State(current_player=current_player) # type:ignore


def _step_game_state(state: GameState, action: Array) -> GameState:
board = state._board.at[action].set(state._turn)
idx = jnp.int32([[0, 1, 2], [3, 4, 5], [6, 7, 8], [0, 3, 6], [1, 4, 7], [2, 5, 8], [0, 4, 8], [2, 4, 6]]) # type: ignore
won = (board[idx] == state._turn).all(axis=1).any()
winner = jax.lax.select(won, state._turn, -1)
return state.replace( # type: ignore
_board=state._board.at[action].set(state._turn),
_turn=(state._turn + 1) % 2,
winner=winner,
)


def _step(state: State, action: Array) -> State:
x = _step_game_state(state._x, action)
state = state.replace(current_player=(state.current_player + 1) % 2, _x=x) # type: ignore
legal_action_mask = _legal_action_mask(x)
terminated = _is_terminal(x)
rewards = _returns(x)
should_flip = state.current_player == state._x._turn
rewards = jax.lax.select(should_flip, rewards, jnp.flip(rewards))
rewards = jax.lax.select(terminated, rewards, jnp.zeros(2, jnp.float32))
return state.replace(legal_action_mask=legal_action_mask, rewards=rewards, terminated=terminated) # type: ignore


def _legal_action_mask(state: GameState) -> Array:
return state._board < 0


def _is_terminal(state: GameState) -> Array:
return (state.winner >= 0) | jnp.all(state._board != -1)


def _returns(state: GameState) -> Array:
return jax.lax.select(
state.winner >= 0,
jnp.float32([-1, -1]).at[state.winner].set(1),
jnp.zeros(2, jnp.float32),
)


def _observe(state: State, player_id: Array) -> Array:
curr_color = state._x._turn
my_color = jax.lax.select(player_id == state.current_player, curr_color, 1 - curr_color)
return _observe_game_state(state._x, my_color)


def _observe_game_state(state: GameState, color: Array) -> Array:
@jax.vmap
def plane(i):
return (state._board == i).reshape((3, 3))

# flip if player_id is opposite
x = jax.lax.select(color == 0, jnp.int32([0, 1]), jnp.int32([1, 0]))
return jnp.stack(plane(x), -1)
24 changes: 12 additions & 12 deletions tests/test_tic_tac_toe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ def test_step():
key = jax.random.PRNGKey(1)
state = init(key=key)
assert state.current_player == 1
assert state._x._turn == 0
assert state._x.color == 0
assert jnp.all(
state.legal_action_mask
== jnp.array([1, 1, 1, 1, 1, 1, 1, 1, 1], jnp.bool_)
) # fmt: ignore
assert jnp.all(
state._x._board == jnp.int32([-1, -1, -1, -1, -1, -1, -1, -1, -1])
state._x.board == jnp.int32([-1, -1, -1, -1, -1, -1, -1, -1, -1])
)
assert not state.terminated
# -1 -1 -1
Expand All @@ -35,13 +35,13 @@ def test_step():
action = jnp.int32(4)
state = step(state, action)
assert state.current_player == 0
assert state._x._turn == 1
assert state._x.color == 1
assert jnp.all(
state.legal_action_mask
== jnp.array([1, 1, 1, 1, 0, 1, 1, 1, 1], jnp.bool_)
) # fmt: ignore
assert jnp.all(
state._x._board == jnp.int32([-1, -1, -1, -1, 0, -1, -1, -1, -1])
state._x.board == jnp.int32([-1, -1, -1, -1, 0, -1, -1, -1, -1])
)
assert jnp.all(state.rewards == 0) # fmt: ignore
assert not state.terminated
Expand All @@ -52,12 +52,12 @@ def test_step():
action = jnp.int32(0)
state = step(state, action)
assert state.current_player == 1
assert state._x._turn == 0
assert state._x.color == 0
assert jnp.all(
state.legal_action_mask
== jnp.array([0, 1, 1, 1, 0, 1, 1, 1, 1], jnp.bool_)
) # fmt: ignore
assert jnp.all(state._x._board == jnp.int32([1, -1, -1, -1, 0, -1, -1, -1, -1]))
assert jnp.all(state._x.board == jnp.int32([1, -1, -1, -1, 0, -1, -1, -1, -1]))
assert jnp.all(state.rewards == 0) # fmt: ignore
assert not state.terminated
# 1 -1 -1
Expand All @@ -67,12 +67,12 @@ def test_step():
action = jnp.int32(1)
state = step(state, action)
assert state.current_player == 0
assert state._x._turn == 1
assert state._x.color == 1
assert jnp.all(
state.legal_action_mask
== jnp.array([0, 0, 1, 1, 0, 1, 1, 1, 1], jnp.bool_)
) # fmt: ignore
assert jnp.all(state._x._board == jnp.int32([1, 0, -1, -1, 0, -1, -1, -1, -1]))
assert jnp.all(state._x.board == jnp.int32([1, 0, -1, -1, 0, -1, -1, -1, -1]))
assert jnp.all(state.rewards == 0) # fmt: ignore
assert not state.terminated
# 1 0 -1
Expand All @@ -82,12 +82,12 @@ def test_step():
action = jnp.int32(8)
state = step(state, action)
assert state.current_player == 1
assert state._x._turn == 0
assert state._x.color == 0
assert jnp.all(
state.legal_action_mask
== jnp.array([0, 0, 1, 1, 0, 1, 1, 1, 0], jnp.bool_)
) # fmt: ignore
assert jnp.all(state._x._board == jnp.int32([1, 0, -1, -1, 0, -1, -1, -1, 1]))
assert jnp.all(state._x.board == jnp.int32([1, 0, -1, -1, 0, -1, -1, -1, 1]))
assert jnp.all(state.rewards == 0) # fmt: ignore
assert not state.terminated
# 1 0 -1
Expand All @@ -97,12 +97,12 @@ def test_step():
action = jnp.int32(7)
state = step(state, action)
assert state.current_player == 0
assert state._x._turn == 1
assert state._x.color == 1
assert jnp.all(
state.legal_action_mask
== jnp.array([1, 1, 1, 1, 1, 1, 1, 1, 1], jnp.bool_)
) # fmt: ignore
assert jnp.all(state._x._board == jnp.int32([1, 0, -1, -1, 0, -1, -1, 0, 1]))
assert jnp.all(state._x.board == jnp.int32([1, 0, -1, -1, 0, -1, -1, 0, 1]))
assert jnp.all(state.rewards == jnp.int32([-1, 1])) # fmt: ignore
assert state.terminated
# 1 0 -1
Expand Down

0 comments on commit 3277a9b

Please sign in to comment.