From 84202206dbd54c47fe6d2d92b11431fa0c0f580b Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Mon, 8 Jan 2024 21:05:43 +0900 Subject: [PATCH] [ConnectFour] Separate game specific methods (#1152) --- pgx/connect_four.py | 80 +++++++++++++++++++++++++++++---------------- 1 file changed, 51 insertions(+), 29 deletions(-) diff --git a/pgx/connect_four.py b/pgx/connect_four.py index 244cb4eb5..81938fc45 100644 --- a/pgx/connect_four.py +++ b/pgx/connect_four.py @@ -34,7 +34,7 @@ class GameState: # [28, 29, 30, 31, 32, 33, 34], # [35, 36, 37, 38, 39, 40, 41]] _board: Array = -jnp.ones(42, jnp.int32) # -1 (empty), 0, 1 - _blank_row: Array = jnp.full(7, 5) + winner: Array = jnp.int32(-1) @dataclass @@ -115,47 +115,69 @@ def _init(rng: PRNGKey) -> State: return State(current_player=current_player) # type:ignore +def _step_game_state(state: GameState, action: Array) -> GameState: + board2d = state._board.reshape(6, 7) + num_filled = (board2d[:, action] >= 0).sum() + board2d = board2d.at[5 - num_filled, action].set(state._turn) + won = _win_check(board2d.flatten(), state._turn) + winner = jax.lax.select(won, state._turn, -1) + return state.replace( # type: ignore + _turn=1 - state._turn, + _board=board2d.flatten(), + winner=winner, + ) + + def _step(state: State, action: Array) -> State: - board = state._x._board - row = state._x._blank_row[action] - blank_row = state._x._blank_row.at[action].set(row - 1) - board = board.at[_to_idx(row, action)].set(state._x._turn) - won = _win_check(board, state._x._turn) - reward = jax.lax.cond( - won, - lambda: jnp.float32([-1, -1]).at[state.current_player].set(1), - lambda: jnp.zeros(2, jnp.float32), + x = _step_game_state(state._x, action) + state = state.replace( # type: ignore + current_player=1 - state.current_player, + _x=x, ) + terminated = is_terminal(state._x) + rewards = returns(state._x) + should_flip = state.current_player != state._x._turn + rewards = jax.lax.select(should_flip, jnp.flip(rewards), rewards) + rewards = jax.lax.select(terminated, rewards, jnp.zeros(2, jnp.float32)) return state.replace( # type: ignore - current_player=1 - state.current_player, - legal_action_mask=blank_row >= 0, - terminated=won | jnp.all(blank_row == -1), - rewards=reward, - _x=state._x.replace( # type: ignore - _turn=1 - state._x._turn, - _board=board, - _blank_row=blank_row, - ), + legal_action_mask=legal_action_mask(state._x), + rewards=rewards, + terminated=terminated, + ) + + +def returns(state: GameState) -> Array: + return jax.lax.cond( + state.winner >= 0, + lambda: jnp.float32([-1, -1]).at[state.winner].set(1), + lambda: jnp.zeros(2, jnp.float32), ) -def _to_idx(row, col): - return row * 7 + col +def is_terminal(state: GameState) -> Array: + board2d = state._board.reshape(6, 7) + return (state.winner >= 0) | jnp.all((board2d >= 0).sum(axis=0) == 6) + + +def legal_action_mask(state: GameState) -> Array: + board2d = state._board.reshape(6, 7) + return (board2d >= 0).sum(axis=0) < 6 def _win_check(board, turn) -> Array: return ((board[IDX] == turn).all(axis=1)).any() -def _observe(state: State, player_id: Array) -> Array: - turns = jnp.int32([state._x._turn, 1 - state._x._turn]) - turns = jax.lax.cond( - player_id == state.current_player, - lambda: turns, - lambda: jnp.flip(turns), - ) +def _observe_game_state(state: GameState, color: Array) -> Array: + turns = jax.lax.select(color == 0, jnp.int32([0, 1]), jnp.int32([1, 0])) def make(turn): - return state._x._board.reshape(6, 7) == turn + return state._board.reshape(6, 7) == turn return jnp.stack(jax.vmap(make)(turns), -1) + + +def _observe(state: State, player_id: Array) -> Array: + curr_color = state._x._turn + my_color = jax.lax.select(player_id == state.current_player, curr_color, 1 - curr_color) + return _observe_game_state(state._x, my_color)