Skip to content

Commit

Permalink
[ConnectFour] Separate game specific methods (#1152)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Jan 8, 2024
1 parent 1f18c3c commit 8420220
Showing 1 changed file with 51 additions and 29 deletions.
80 changes: 51 additions & 29 deletions pgx/connect_four.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class GameState:
# [28, 29, 30, 31, 32, 33, 34],
# [35, 36, 37, 38, 39, 40, 41]]
_board: Array = -jnp.ones(42, jnp.int32) # -1 (empty), 0, 1
_blank_row: Array = jnp.full(7, 5)
winner: Array = jnp.int32(-1)


@dataclass
Expand Down Expand Up @@ -115,47 +115,69 @@ def _init(rng: PRNGKey) -> State:
return State(current_player=current_player) # type:ignore


def _step_game_state(state: GameState, action: Array) -> GameState:
board2d = state._board.reshape(6, 7)
num_filled = (board2d[:, action] >= 0).sum()
board2d = board2d.at[5 - num_filled, action].set(state._turn)
won = _win_check(board2d.flatten(), state._turn)
winner = jax.lax.select(won, state._turn, -1)
return state.replace( # type: ignore
_turn=1 - state._turn,
_board=board2d.flatten(),
winner=winner,
)


def _step(state: State, action: Array) -> State:
board = state._x._board
row = state._x._blank_row[action]
blank_row = state._x._blank_row.at[action].set(row - 1)
board = board.at[_to_idx(row, action)].set(state._x._turn)
won = _win_check(board, state._x._turn)
reward = jax.lax.cond(
won,
lambda: jnp.float32([-1, -1]).at[state.current_player].set(1),
lambda: jnp.zeros(2, jnp.float32),
x = _step_game_state(state._x, action)
state = state.replace( # type: ignore
current_player=1 - state.current_player,
_x=x,
)
terminated = is_terminal(state._x)
rewards = returns(state._x)
should_flip = state.current_player != state._x._turn
rewards = jax.lax.select(should_flip, jnp.flip(rewards), rewards)
rewards = jax.lax.select(terminated, rewards, jnp.zeros(2, jnp.float32))
return state.replace( # type: ignore
current_player=1 - state.current_player,
legal_action_mask=blank_row >= 0,
terminated=won | jnp.all(blank_row == -1),
rewards=reward,
_x=state._x.replace( # type: ignore
_turn=1 - state._x._turn,
_board=board,
_blank_row=blank_row,
),
legal_action_mask=legal_action_mask(state._x),
rewards=rewards,
terminated=terminated,
)


def returns(state: GameState) -> Array:
return jax.lax.cond(
state.winner >= 0,
lambda: jnp.float32([-1, -1]).at[state.winner].set(1),
lambda: jnp.zeros(2, jnp.float32),
)


def _to_idx(row, col):
return row * 7 + col
def is_terminal(state: GameState) -> Array:
board2d = state._board.reshape(6, 7)
return (state.winner >= 0) | jnp.all((board2d >= 0).sum(axis=0) == 6)


def legal_action_mask(state: GameState) -> Array:
board2d = state._board.reshape(6, 7)
return (board2d >= 0).sum(axis=0) < 6


def _win_check(board, turn) -> Array:
return ((board[IDX] == turn).all(axis=1)).any()


def _observe(state: State, player_id: Array) -> Array:
turns = jnp.int32([state._x._turn, 1 - state._x._turn])
turns = jax.lax.cond(
player_id == state.current_player,
lambda: turns,
lambda: jnp.flip(turns),
)
def _observe_game_state(state: GameState, color: Array) -> Array:
turns = jax.lax.select(color == 0, jnp.int32([0, 1]), jnp.int32([1, 0]))

def make(turn):
return state._x._board.reshape(6, 7) == turn
return state._board.reshape(6, 7) == turn

return jnp.stack(jax.vmap(make)(turns), -1)


def _observe(state: State, player_id: Array) -> Array:
curr_color = state._x._turn
my_color = jax.lax.select(player_id == state.current_player, curr_color, 1 - curr_color)
return _observe_game_state(state._x, my_color)

0 comments on commit 8420220

Please sign in to comment.