Skip to content

Commit

Permalink
[Go] Tidy (#1137)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Dec 28, 2023
1 parent da29e7e commit 0f6ba86
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 99 deletions.
122 changes: 53 additions & 69 deletions pgx/_src/games/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,70 +19,63 @@
from jax import Array
from jax import numpy as jnp

FALSE = jnp.bool_(False)
TRUE = jnp.bool_(True)


class GameState(NamedTuple):
size: Array = jnp.int32(19)
color: Array = jnp.int32(0) # 0 = black, 1 = white
# ids of representative stone id (smallest) in the connected stones
# positive for black, negative for white, and zero for empty.
chain_id_board: Array = jnp.zeros(19 * 19, dtype=jnp.int32)
board_history: Array = jnp.full((8, 19 * 19), 2, dtype=jnp.int32)
turn: Array = jnp.int32(0) # 0 = black's turn, 1 = white's turn
board_history: Array = jnp.full((8, 19 * 19), 2, dtype=jnp.int32) # mainly for obs
num_captured_stones: Array = jnp.zeros(2, dtype=jnp.int32) # [b, w]
consecutive_pass_count: Array = jnp.int32(0)
consecutive_pass_count: Array = jnp.int32(0) # two consecutive pass ends the game
ko: Array = jnp.int32(-1) # by SSK
komi: Array = jnp.float32(7.5)
is_psk: Array = FALSE
is_psk: Array = jnp.bool_(False)

@property
def size(self) -> int:
return int(jnp.sqrt(self.chain_id_board.shape[-1]).astype(jnp.int32).item())


class Game:
def __init__(self, size: int = 19, komi: float = 7.5):
def __init__(self, size: int = 19, komi: float = 7.5, history_length: int = 8):
self.size = size
self.komi = komi
self.history_length = history_length

def init(self) -> GameState:
return GameState(
size=jnp.int32(self.size),
chain_id_board=jnp.zeros(self.size**2, dtype=jnp.int32),
board_history=jnp.full((8, self.size**2), 2, dtype=jnp.int32),
komi=jnp.float32(self.komi),
)

def step(self, x: GameState, action: int) -> GameState:
x = x._replace(ko=jnp.int32(-1))

def step(self, state: GameState, action: Array) -> GameState:
state = state._replace(ko=jnp.int32(-1))
# update state
x = jax.lax.cond(
state = jax.lax.cond(
(action < self.size * self.size),
lambda: _not_pass_move(x, action, self.size),
lambda: _pass_move(x),
lambda: _apply_action(state, action, self.size),
lambda: _apply_pass(state),
)

# increment turns
x = x._replace(turn=(x.turn + 1) % 2)

state = state._replace(color=(state.color + 1) % 2)
# update board history
board_history = jnp.roll(x.board_history, self.size**2)
board_history = board_history.at[0].set(jnp.clip(x.chain_id_board, -1, 1).astype(jnp.int32))
x = x._replace(board_history=board_history)

board_history = jnp.roll(state.board_history, self.size**2)
board_history = board_history.at[0].set(jnp.clip(state.chain_id_board, -1, 1).astype(jnp.int32))
state = state._replace(board_history=board_history)
# check PSK
x = x._replace(is_psk=_check_PSK(x))
state = state._replace(is_psk=_check_PSK(state))
return state

return x

def observe(self, x: GameState, my_turn, history_length):
my_color = jnp.int32([1, -1])[my_turn]
def observe(self, state: GameState, color: Array):
my_color_sign = jnp.int32([1, -1])[color]

@jax.vmap
def _make(i):
color = jnp.int32([1, -1])[i % 2] * my_color
return x.board_history[i // 2] == color
c = jnp.int32([1, -1])[i % 2] * my_color_sign
return state.board_history[i // 2] == c

log = _make(jnp.arange(history_length * 2))
color = jnp.full_like(log[0], my_turn) # black=0, white=1
log = _make(jnp.arange(self.history_length * 2))
color = jnp.full_like(log[0], color) # black=0, white=1

return jnp.vstack([log, color]).transpose().reshape((self.size, self.size, -1))

Expand All @@ -95,9 +88,7 @@ def legal_action_mask(self, state: GameState) -> Array:
num_pseudo, idx_sum, idx_squared_sum = _count(state, self.size)

chain_ix = jnp.abs(state.chain_id_board) - 1
# fmt: off
in_atari = (idx_sum[chain_ix] ** 2) == idx_squared_sum[chain_ix] * num_pseudo[chain_ix]
# fmt: on
has_liberty = (state.chain_id_board * my_color > 0) & ~in_atari
kills_opp = (state.chain_id_board * opp_color > 0) & in_atari

Expand All @@ -116,36 +107,36 @@ def is_neighbor_ok(xy):
legal_action_mask = jax.lax.cond(
(state.ko == -1),
lambda: legal_action_mask,
lambda: legal_action_mask.at[state.ko].set(FALSE),
lambda: legal_action_mask.at[state.ko].set(False),
)
return jnp.append(legal_action_mask, TRUE) # pass is always legal
return jnp.append(legal_action_mask, True) # pass is always legal

def is_terminal(self, x: GameState):
two_consecutive_pass = x.consecutive_pass_count >= 2
return two_consecutive_pass | x.is_psk
def is_terminal(self, state: GameState):
two_consecutive_pass = state.consecutive_pass_count >= 2
return two_consecutive_pass | state.is_psk

def terminal_values(self, x: GameState):
score = _count_point(x, self.size)
def terminal_values(self, state: GameState):
score = _count_point(state, self.size)
reward_bw = jax.lax.select(
score[0] - x.komi > score[1],
score[0] - self.komi > score[1],
jnp.array([1, -1], dtype=jnp.float32),
jnp.array([-1, 1], dtype=jnp.float32),
)
to_play = x.turn
reward_bw = jax.lax.select(x.is_psk, jnp.float32([-1, -1]).at[to_play].set(1.0), reward_bw)
to_play = state.color
reward_bw = jax.lax.select(state.is_psk, jnp.float32([-1, -1]).at[to_play].set(1.0), reward_bw)
return reward_bw


def _pass_move(state: GameState) -> GameState:
def _apply_pass(state: GameState) -> GameState:
return state._replace(consecutive_pass_count=state.consecutive_pass_count + 1)


def _not_pass_move(state: GameState, action, size) -> GameState:
def _apply_action(state: GameState, action, size) -> GameState:
state = state._replace(consecutive_pass_count=jnp.int32(0))
xy = action
num_captured_stones_before = state.num_captured_stones[state.turn]
num_captured_stones_before = state.num_captured_stones[state.color]

ko_may_occur = _ko_may_occur(state, xy)
ko_may_occur = _ko_may_occur(state, xy, size)

# Remove killed stones
adj_xy = _neighbour(xy, size)
Expand All @@ -172,13 +163,11 @@ def _not_pass_move(state: GameState, action, size) -> GameState:
state = jax.lax.fori_loop(0, 4, lambda i, s: _merge_around_xy(i, s, xy, size), state)

# Check Ko
# fmt: off
state = jax.lax.cond(
state.num_captured_stones[state.turn] - num_captured_stones_before == 1,
state.num_captured_stones[state.color] - num_captured_stones_before == 1,
lambda: state,
lambda: state._replace(ko=jnp.int32(-1))
lambda: state._replace(ko=jnp.int32(-1)),
)
# fmt: on

return state

Expand Down Expand Up @@ -231,7 +220,7 @@ def _remove_stones(state: GameState, rm_chain_id, rm_stone_xy, ko_may_occur) ->
)
return state._replace(
chain_id_board=chain_id_board,
num_captured_stones=state.num_captured_stones.at[state.turn].add(num_captured_stones),
num_captured_stones=state.num_captured_stones.at[state.color].add(num_captured_stones),
ko=ko,
)

Expand All @@ -247,11 +236,11 @@ def _count(state: GameState, size):
def _count_neighbor(xy):
neighbors = _neighbour(xy, size)
on_board = neighbors != -1
# fmt: off
return (jnp.where(on_board, is_empty[neighbors], ZERO).sum(),
jnp.where(on_board, idx_sum[neighbors], ZERO).sum(),
jnp.where(on_board, idx_squared_sum[neighbors], ZERO).sum())
# fmt: on
return (
jnp.where(on_board, is_empty[neighbors], ZERO).sum(),
jnp.where(on_board, idx_sum[neighbors], ZERO).sum(),
jnp.where(on_board, idx_squared_sum[neighbors], ZERO).sum(),
)

idx = jnp.arange(size**2)
num_pseudo, idx_sum, idx_squared_sum = _count_neighbor(idx)
Expand All @@ -272,15 +261,14 @@ def _idx_squared_sum(x):


def _my_color(state: GameState):
return jnp.int32([1, -1])[state.turn]
return jnp.int32([1, -1])[state.color]


def _opponent_color(state: GameState):
return jnp.int32([-1, 1])[state.turn]
return jnp.int32([-1, 1])[state.color]


def _ko_may_occur(state: GameState, xy: int) -> Array:
size = state.size
def _ko_may_occur(state: GameState, xy: int, size: int) -> Array:
x = xy // size
y = xy % size
oob = jnp.bool_([x - 1 < 0, x + 1 >= size, y - 1 < 0, y + 1 >= size])
Expand Down Expand Up @@ -329,10 +317,8 @@ def _check_PSK(state: GameState):
Anyway, we believe it's effect is very small as PSK rarely happens, especially in 19x19 board.
"""
# fmt: off
not_passed = state.consecutive_pass_count == 0
is_psk = not_passed & (jnp.abs(state.board_history[0] - state.board_history[1:]).sum(axis=1) == 0).any()
# fmt: on
return is_psk


Expand Down Expand Up @@ -363,8 +349,6 @@ def fill_opp(x):
mask = is_opp_neighbours(b)
return jnp.where(mask, -1, b), mask.any()

# fmt off
b, _ = jax.lax.while_loop(lambda x: x[1], fill_opp, (board, TRUE))
# fmt on
b, _ = jax.lax.while_loop(lambda x: x[1], fill_opp, (board, True))

return (b == 0).sum()
16 changes: 2 additions & 14 deletions pgx/_src/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,20 +360,8 @@ def _set_config_by_state(self, _state: State): # noqa: C901
from pgx._src.dwg.go import _make_go_dwg

self.config["GRID_SIZE"] = 25
try:
self.config["BOARD_WIDTH"] = int(
_state._x.size[0] # type:ignore
)
self.config["BOARD_HEIGHT"] = int(
_state._x.size[0] # type:ignore
)
except IndexError:
self.config["BOARD_WIDTH"] = int(
_state._x.size # type: ignore
) # type:ignore
self.config["BOARD_HEIGHT"] = int(
_state._x.size # type: ignore
) # type:ignore
self.config["BOARD_WIDTH"] = _state._x.size # type: ignore
self.config["BOARD_HEIGHT"] = _state._x.size # type: ignore
self._make_dwg_group = _make_go_dwg # type:ignore
if (self.config["COLOR_THEME"] is None and self.config["COLOR_THEME"] == "dark") or self.config[
"COLOR_THEME"
Expand Down
24 changes: 9 additions & 15 deletions pgx/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,7 @@ class State(core.State):

@property
def env_id(self) -> core.EnvId:
try:
size = int(self._x.size.item())
except TypeError:
size = int(self._x.size[0].item())
return f"go_{size}x{size}" # type: ignore
return f"go_{self._x.size}x{self._x.size}" # type: ignore

@staticmethod
def _from_sgf(sgf: str):
Expand All @@ -58,16 +54,14 @@ def __init__(
):
super().__init__()
assert isinstance(size, int)
self.size = size
self.komi = komi
self.history_length = history_length
self.max_termination_steps = self.size * self.size * 2
self._game = go.Game(size=size, komi=komi)
self.max_termination_steps = size * size * 2
self._game = go.Game(size=size, komi=komi, history_length=history_length)

def _init(self, key: PRNGKey) -> State:
current_player = jnp.int32(jax.random.bernoulli(key))
size = self._game.size
return State( # type:ignore
legal_action_mask=jnp.ones(self.size**2 + 1, dtype=jnp.bool_),
legal_action_mask=jnp.ones(size**2 + 1, dtype=jnp.bool_),
current_player=current_player,
_x=self._game.init(),
)
Expand All @@ -91,7 +85,7 @@ def _step(self, state: core.State, action: Array, key) -> State:
# fmt: on
assert isinstance(state, State)
reward_bw = self._game.terminal_values(state._x)
should_flip = state.current_player == state._x.turn
should_flip = state.current_player == state._x.color
rewards = jax.lax.select(should_flip, reward_bw, jnp.flip(reward_bw))
rewards = jax.lax.select(state.terminated, rewards, jnp.zeros_like(rewards))
return state.replace(rewards=rewards) # type:ignore
Expand Down Expand Up @@ -130,10 +124,10 @@ def _observe(self, state: core.State, player_id: Array) -> Array:
assert isinstance(state, State)
my_turn = jax.lax.select(
player_id == state.current_player,
state._x.turn,
1 - state._x.turn,
state._x.color,
1 - state._x.color,
)
return self._game.observe(state._x, my_turn, self.history_length)
return self._game.observe(state._x, my_turn)

@property
def id(self) -> core.EnvId:
Expand Down
28 changes: 27 additions & 1 deletion tests/test_go.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def test_observe():
)
# fmt: on
assert state.current_player == 1
assert state._x.turn % 2 == 0 # black turn
assert state._x.color % 2 == 0 # black turn
obs = observe(state, 0) # white
assert obs.shape == (5, 5, 17)
assert (obs[:, :, 0] == (curr_board == -1)).all()
Expand Down Expand Up @@ -1192,6 +1192,32 @@ def test_max_step_termination():
assert not (state.rewards == jnp.float32([0, 0])).all() # should not tie


def test_env_id():
env = Go(size=9)
init_fn = jax.jit(env.init)
state = init_fn(jax.random.PRNGKey(0))
assert state.env_id == "go_9x9"
init_fn = jax.jit(jax.vmap(env.init))
state = init_fn(jax.random.split(jax.random.PRNGKey(0)))
assert state.env_id == "go_9x9"

env = Go(size=19)
init_fn = jax.jit(env.init)
state = init_fn(jax.random.PRNGKey(0))
assert state.env_id == "go_19x19"
init_fn = jax.jit(jax.vmap(env.init))
state = init_fn(jax.random.split(jax.random.PRNGKey(0)))
assert state.env_id == "go_19x19"

env = Go(size=5)
init_fn = jax.jit(env.init)
state = init_fn(jax.random.PRNGKey(0))
assert state.env_id == "go_5x5"
init_fn = jax.jit(jax.vmap(env.init))
state = init_fn(jax.random.split(jax.random.PRNGKey(0)))
assert state.env_id == "go_5x5"


def test_api():
import pgx
env = pgx.make("go_9x9")
Expand Down

0 comments on commit 0f6ba86

Please sign in to comment.