Skip to content

Commit

Permalink
[ConnectFour] Separate game specific attributes (#1151)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Jan 8, 2024
1 parent ca1d28f commit 1f18c3c
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 21 deletions.
2 changes: 1 addition & 1 deletion pgx/_src/dwg/connect_four.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _make_connect_four_dwg(dwg, state: ConnectFourState, config):
)

# stones
-board = state._board
+board = state._x._board
for xy, stone in enumerate(board):
if stone == -1:
continue
Expand Down
44 changes: 25 additions & 19 deletions pgx/connect_four.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,7 @@


@dataclass
-class State(core.State):
-current_player: Array = jnp.int32(0)
-observation: Array = jnp.zeros((6, 7, 2), dtype=jnp.bool_)
-rewards: Array = jnp.float32([0.0, 0.0])
-terminated: Array = FALSE
-truncated: Array = FALSE
-legal_action_mask: Array = jnp.ones(7, dtype=jnp.bool_)
-_step_count: Array = jnp.int32(0)
-# --- Connect Four specific ---
+class GameState:
_turn: Array = jnp.int32(0)
# 6x7 board
# [[ 0, 1, 2, 3, 4, 5, 6],
Expand All @@ -44,6 +36,18 @@ class State(core.State):
_board: Array = -jnp.ones(42, jnp.int32) # -1 (empty), 0, 1
_blank_row: Array = jnp.full(7, 5)


+@dataclass
+class State(core.State):
+current_player: Array = jnp.int32(0)
+observation: Array = jnp.zeros((6, 7, 2), dtype=jnp.bool_)
+rewards: Array = jnp.float32([0.0, 0.0])
+terminated: Array = FALSE
+truncated: Array = FALSE
+legal_action_mask: Array = jnp.ones(7, dtype=jnp.bool_)
+_step_count: Array = jnp.int32(0)
+_x: GameState = GameState()

@property
def env_id(self) -> core.EnvId:
return "connect_four"
Expand Down Expand Up @@ -112,11 +116,11 @@ def _init(rng: PRNGKey) -> State:


def _step(state: State, action: Array) -> State:
-board = state._board
-row = state._blank_row[action]
-blank_row = state._blank_row.at[action].set(row - 1)
-board = board.at[_to_idx(row, action)].set(state._turn)
-won = _win_check(board, state._turn)
+board = state._x._board
+row = state._x._blank_row[action]
+blank_row = state._x._blank_row.at[action].set(row - 1)
+board = board.at[_to_idx(row, action)].set(state._x._turn)
+won = _win_check(board, state._x._turn)
reward = jax.lax.cond(
won,
lambda: jnp.float32([-1, -1]).at[state.current_player].set(1),
Expand All @@ -125,11 +129,13 @@ def _step(state: State, action: Array) -> State:
return state.replace( # type: ignore
current_player=1 - state.current_player,
legal_action_mask=blank_row >= 0,
-_turn=1 - state._turn,
-_board=board,
-_blank_row=blank_row,
terminated=won | jnp.all(blank_row == -1),
rewards=reward,
+_x=state._x.replace( # type: ignore
+_turn=1 - state._x._turn,
+_board=board,
+_blank_row=blank_row,
+),
)


Expand All @@ -142,14 +148,14 @@ def _win_check(board, turn) -> Array:


def _observe(state: State, player_id: Array) -> Array:
-turns = jnp.int32([state._turn, 1 - state._turn])
+turns = jnp.int32([state._x._turn, 1 - state._x._turn])
turns = jax.lax.cond(
player_id == state.current_player,
lambda: turns,
lambda: jnp.flip(turns),
)

def make(turn):
-return state._board.reshape(6, 7) == turn
+return state._x._board.reshape(6, 7) == turn

return jnp.stack(jax.vmap(make)(turns), -1)
2 changes: 1 addition & 1 deletion tests/test_connect_four.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_step():
@@.....
"""
# fmt: off
-assert (state._board == jnp.array(
+assert (state._x._board == jnp.array(
[1, 1, -1, -1, -1, -1, -1,
0, 0, -1, -1, -1, -1, -1,
1, 1, -1, -1, -1, -1, -1,
Expand Down

0 comments on commit 1f18c3c

Please sign in to comment.