Skip to content

Commit

Permalink
[Go] rename terminal_values to returns (#1140)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Dec 29, 2023
1 parent 0d704aa commit d181ff8
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pgx/_src/games/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def is_terminal(self, state: GameState):
two_consecutive_pass = state.consecutive_pass_count >= 2
return two_consecutive_pass | state.is_psk

def terminal_values(self, state: GameState):
def returns(self, state: GameState):
score = _count_point(state, self.size)
reward_bw = jax.lax.select(
score[0] - self.komi > score[1],
Expand Down
2 changes: 1 addition & 1 deletion pgx/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _step(self, state: core.State, action: Array, key) -> State:
state = state.replace(terminated=(state.terminated | _terminated)) # type:ignore
# fmt: on
assert isinstance(state, State)
reward_bw = self._game.terminal_values(state._x)
reward_bw = self._game.returns(state._x)
should_flip = state.current_player == state._x.color
rewards = jax.lax.select(should_flip, reward_bw, jnp.flip(reward_bw))
rewards = jax.lax.select(state.terminated, rewards, jnp.zeros_like(rewards))
Expand Down

0 comments on commit d181ff8

Please sign in to comment.