Skip to content

Commit

Permalink
[ConnectFour] Extract games/connect_four.py (#1153)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Jan 8, 2024
1 parent 8420220 commit 3fbe27c
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 125 deletions.
2 changes: 1 addition & 1 deletion pgx/_src/dwg/connect_four.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _make_connect_four_dwg(dwg, state: ConnectFourState, config):
)

# stones
board = state._x._board
board = state._x.board
for xy, stone in enumerate(board):
if stone == -1:
continue
Expand Down
99 changes: 99 additions & 0 deletions pgx/_src/games/connect_four.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2023 The Pgx Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp
from jax import Array


class GameState(NamedTuple):
color: Array = jnp.int32(0)
# 6x7 board
# [[ 0, 1, 2, 3, 4, 5, 6],
# [ 7, 8, 9, 10, 11, 12, 13],
# [14, 15, 16, 17, 18, 19, 20],
# [21, 22, 23, 24, 25, 26, 27],
# [28, 29, 30, 31, 32, 33, 34],
# [35, 36, 37, 38, 39, 40, 41]]
board: Array = -jnp.ones(42, jnp.int32) # -1 (empty), 0, 1
winner: Array = jnp.int32(-1)


class Game:
def init(self) -> GameState:
return GameState()

def step(self, state: GameState, action: Array) -> GameState:
board2d = state.board.reshape(6, 7)
num_filled = (board2d[:, action] >= 0).sum()
board2d = board2d.at[5 - num_filled, action].set(state.color)
won = ((board2d.flatten()[IDX] == state.color).all(axis=1)).any()
winner = jax.lax.select(won, state.color, -1)
return state._replace( # type: ignore
color=1 - state.color,
board=board2d.flatten(),
winner=winner,
)

def observe(self, state: GameState, color: Optional[Array] = None) -> Array:
def make(turn):
return state.board.reshape(6, 7) == turn

turns = jax.lax.select(color == 0, jnp.int32([0, 1]), jnp.int32([1, 0]))
return jnp.stack(jax.vmap(make)(turns), -1)

def legal_action_mask(self, state: GameState) -> Array:
board2d = state.board.reshape(6, 7)
return (board2d >= 0).sum(axis=0) < 6

def is_terminal(self, state: GameState) -> Array:
board2d = state.board.reshape(6, 7)
return (state.winner >= 0) | jnp.all((board2d >= 0).sum(axis=0) == 6)

def returns(self, state: GameState) -> Array:
return jax.lax.select(
state.winner >= 0,
jnp.float32([-1, -1]).at[state.winner].set(1),
jnp.zeros(2, jnp.float32),
)


def _make_win_cache():
idx = []
# Vertical
for i in range(3):
for j in range(7):
a = i * 7 + j
idx.append([a, a + 7, a + 14, a + 21])
# Horizontal
for i in range(6):
for j in range(4):
a = i * 7 + j
idx.append([a, a + 1, a + 2, a + 3])

# Diagonal
for i in range(3):
for j in range(4):
a = i * 7 + j
idx.append([a, a + 8, a + 16, a + 24])
for i in range(3):
for j in range(3, 7):
a = i * 7 + j
idx.append([a, a + 6, a + 12, a + 18])
return jnp.int32(idx)


IDX = _make_win_cache()
149 changes: 26 additions & 123 deletions pgx/connect_four.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,18 @@
import jax.numpy as jnp

import pgx.core as core
from pgx._src.games.connect_four import Game, GameState
from pgx._src.struct import dataclass
from pgx._src.types import Array, PRNGKey

FALSE = jnp.bool_(False)
TRUE = jnp.bool_(True)


@dataclass
class GameState:
_turn: Array = jnp.int32(0)
# 6x7 board
# [[ 0, 1, 2, 3, 4, 5, 6],
# [ 7, 8, 9, 10, 11, 12, 13],
# [14, 15, 16, 17, 18, 19, 20],
# [21, 22, 23, 24, 25, 26, 27],
# [28, 29, 30, 31, 32, 33, 34],
# [35, 36, 37, 38, 39, 40, 41]]
_board: Array = -jnp.ones(42, jnp.int32) # -1 (empty), 0, 1
winner: Array = jnp.int32(-1)


@dataclass
class State(core.State):
current_player: Array = jnp.int32(0)
observation: Array = jnp.zeros((6, 7, 2), dtype=jnp.bool_)
rewards: Array = jnp.float32([0.0, 0.0])
terminated: Array = FALSE
truncated: Array = FALSE
terminated: Array = jnp.bool_(False)
truncated: Array = jnp.bool_(False)
legal_action_mask: Array = jnp.ones(7, dtype=jnp.bool_)
_step_count: Array = jnp.int32(0)
_x: GameState = GameState()
Expand All @@ -56,18 +40,38 @@ def env_id(self) -> core.EnvId:
class ConnectFour(core.Env):
def __init__(self):
super().__init__()
self._game = Game()

def _init(self, key: PRNGKey) -> State:
return _init(key)
current_player = jnp.int32(jax.random.bernoulli(key))
return State(current_player=current_player, _x=self._game.init()) # type:ignore

def _step(self, state: core.State, action: Array, key) -> State:
del key
assert isinstance(state, State)
return _step(state, action)
x = self._game.step(state._x, action)
state = state.replace( # type: ignore
current_player=1 - state.current_player,
_x=x,
)
assert isinstance(state, State)
legal_action_mask = self._game.legal_action_mask(state._x)
terminated = self._game.is_terminal(state._x)
rewards = self._game.returns(state._x)
should_flip = state.current_player != state._x.color
rewards = jax.lax.select(should_flip, jnp.flip(rewards), rewards)
rewards = jax.lax.select(terminated, rewards, jnp.zeros(2, jnp.float32))
return state.replace( # type: ignore
legal_action_mask=legal_action_mask,
rewards=rewards,
terminated=terminated,
)

def _observe(self, state: core.State, player_id: Array) -> Array:
assert isinstance(state, State)
return _observe(state, player_id)
curr_color = state._x.color
my_color = jax.lax.select(player_id == state.current_player, curr_color, 1 - curr_color)
return self._game.observe(state._x, my_color)

@property
def id(self) -> core.EnvId:
Expand All @@ -80,104 +84,3 @@ def version(self) -> str:
@property
def num_players(self) -> int:
return 2


def _make_win_cache():
idx = []
# Vertical
for i in range(3):
for j in range(7):
a = i * 7 + j
idx.append([a, a + 7, a + 14, a + 21])
# Horizontal
for i in range(6):
for j in range(4):
a = i * 7 + j
idx.append([a, a + 1, a + 2, a + 3])

# Diagonal
for i in range(3):
for j in range(4):
a = i * 7 + j
idx.append([a, a + 8, a + 16, a + 24])
for i in range(3):
for j in range(3, 7):
a = i * 7 + j
idx.append([a, a + 6, a + 12, a + 18])
return jnp.int32(idx)


IDX = _make_win_cache()


def _init(rng: PRNGKey) -> State:
current_player = jnp.int32(jax.random.bernoulli(rng))
return State(current_player=current_player) # type:ignore


def _step_game_state(state: GameState, action: Array) -> GameState:
board2d = state._board.reshape(6, 7)
num_filled = (board2d[:, action] >= 0).sum()
board2d = board2d.at[5 - num_filled, action].set(state._turn)
won = _win_check(board2d.flatten(), state._turn)
winner = jax.lax.select(won, state._turn, -1)
return state.replace( # type: ignore
_turn=1 - state._turn,
_board=board2d.flatten(),
winner=winner,
)


def _step(state: State, action: Array) -> State:
x = _step_game_state(state._x, action)
state = state.replace( # type: ignore
current_player=1 - state.current_player,
_x=x,
)
terminated = is_terminal(state._x)
rewards = returns(state._x)
should_flip = state.current_player != state._x._turn
rewards = jax.lax.select(should_flip, jnp.flip(rewards), rewards)
rewards = jax.lax.select(terminated, rewards, jnp.zeros(2, jnp.float32))
return state.replace( # type: ignore
legal_action_mask=legal_action_mask(state._x),
rewards=rewards,
terminated=terminated,
)


def returns(state: GameState) -> Array:
return jax.lax.cond(
state.winner >= 0,
lambda: jnp.float32([-1, -1]).at[state.winner].set(1),
lambda: jnp.zeros(2, jnp.float32),
)


def is_terminal(state: GameState) -> Array:
board2d = state._board.reshape(6, 7)
return (state.winner >= 0) | jnp.all((board2d >= 0).sum(axis=0) == 6)


def legal_action_mask(state: GameState) -> Array:
board2d = state._board.reshape(6, 7)
return (board2d >= 0).sum(axis=0) < 6


def _win_check(board, turn) -> Array:
return ((board[IDX] == turn).all(axis=1)).any()


def _observe_game_state(state: GameState, color: Array) -> Array:
turns = jax.lax.select(color == 0, jnp.int32([0, 1]), jnp.int32([1, 0]))

def make(turn):
return state._board.reshape(6, 7) == turn

return jnp.stack(jax.vmap(make)(turns), -1)


def _observe(state: State, player_id: Array) -> Array:
curr_color = state._x._turn
my_color = jax.lax.select(player_id == state.current_player, curr_color, 1 - curr_color)
return _observe_game_state(state._x, my_color)
2 changes: 1 addition & 1 deletion tests/test_connect_four.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_step():
@@.....
"""
# fmt: off
assert (state._x._board == jnp.array(
assert (state._x.board == jnp.array(
[1, 1, -1, -1, -1, -1, -1,
0, 0, -1, -1, -1, -1, -1,
1, 1, -1, -1, -1, -1, -1,
Expand Down

0 comments on commit 3fbe27c

Please sign in to comment.