diff --git a/pgx/_src/games/go.py b/pgx/_src/games/go.py index 4335495e6..e3c6664fc 100644 --- a/pgx/_src/games/go.py +++ b/pgx/_src/games/go.py @@ -21,7 +21,7 @@ class GameState(NamedTuple): - color: Array = jnp.int32(0) # 0 = black, 1 = white + step_count: Array = jnp.int32(0) # ids of representative stone id (smallest) in the connected stones # positive for black, negative for white, and zero for empty. chain_id_board: Array = jnp.zeros(19 * 19, dtype=jnp.int32) @@ -31,6 +31,10 @@ class GameState(NamedTuple): ko: Array = jnp.int32(-1) # by SSK is_psk: Array = jnp.bool_(False) + @property + def color(self) -> Array: + return self.step_count % 2 + @property def size(self) -> int: return int(jnp.sqrt(self.chain_id_board.shape[-1]).astype(jnp.int32).item()) @@ -57,7 +61,7 @@ def step(self, state: GameState, action: Array) -> GameState: lambda: _apply_pass(state), ) # increment turns - state = state._replace(color=(state.color + 1) % 2) + state = state._replace(step_count=state.step_count + 1) # update board history board_history = jnp.roll(state.board_history, self.size**2) board_history = board_history.at[0].set(jnp.clip(state.chain_id_board, -1, 1).astype(jnp.int32))