[Example] Extract loss input computation (#1107)
sotetsuk authored Nov 10, 2023
1 parent 5fcdf0e commit b410c6f
Showing 2 changed files with 29 additions and 23 deletions.
1 change: 1 addition & 0 deletions examples/alphazero/README.md
@@ -22,6 +22,7 @@ chess, shogi, and go through self-play"
 
 ## Change history
 
+- **[#1107](https://github.com/sotetsuk/pgx/pull/1107)** Extract `compute_loss_input` ([wandb report](https://api.wandb.ai/links/sotetsuk/979hmps8)).
 - **[#1106](https://github.com/sotetsuk/pgx/pull/1106)** Use `optax.softmax_cross_entropy` ([wandb report](https://api.wandb.ai/links/sotetsuk/8w0or84k)).
 - **[#1088](https://github.com/sotetsuk/pgx/pull/1088)** Adjust to API v2 ([wandb report](https://api.wandb.ai/links/sotetsuk/0g44pjsg)).
 - **[#1055](https://github.com/sotetsuk/pgx/pull/1055)** Use default Gumbel AlphaZero hyperparameters ([wandb report](https://api.wandb.ai/links/sotetsuk/o8752t54)).
51 changes: 28 additions & 23 deletions examples/alphazero/train.py
@@ -109,29 +109,20 @@ def recurrent_fn(model, rng_key: jnp.ndarray, action: jnp.ndarray, state: pgx.St
     return recurrent_fn_output, state


-class Sample(NamedTuple):
+class SelfplayOutput(NamedTuple):
     obs: jnp.ndarray
-    policy_tgt: jnp.ndarray
-    value_tgt: jnp.ndarray
-    mask: jnp.ndarray
+    reward: jnp.ndarray
+    terminated: jnp.ndarray
+    action_weights: jnp.ndarray
+    discount: jnp.ndarray


 @jax.pmap
-def selfplay(
-    model,
-    rng_key: jnp.ndarray,
-) -> Sample:
+def selfplay(model, rng_key: jnp.ndarray) -> SelfplayOutput:
     model_params, model_state = model
     batch_size = config.selfplay_batch_size // num_devices

-    class StepFnOutput(NamedTuple):
-        obs: jnp.ndarray
-        reward: jnp.ndarray
-        terminated: jnp.ndarray
-        action_weights: jnp.ndarray
-        discount: jnp.ndarray
-
-    def step_fn(state, key) -> StepFnOutput:
+    def step_fn(state, key) -> SelfplayOutput:
         key1, key2 = jax.random.split(key)
         observation = state.observation

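Note on `SelfplayOutput`: a `NamedTuple` of arrays is a JAX pytree, so `step_fn` can return one and `jax.lax.scan` will stack every field along a new leading axis of length `max_num_steps`. A minimal, self-contained sketch (toy shapes and a dummy step function, not the example's actual environment):

```python
# Sketch: lax.scan stacks each field of a NamedTuple output along a leading time axis.
# The batch size, observation size, and dummy_step below are illustrative only.
from typing import NamedTuple

import jax
import jax.numpy as jnp


class StepOut(NamedTuple):
    obs: jnp.ndarray     # (batch, obs_dim) for a single step
    reward: jnp.ndarray  # (batch,) for a single step


def dummy_step(carry, key):
    batch, obs_dim = 4, 3
    obs = jax.random.normal(key, (batch, obs_dim))
    reward = jnp.zeros((batch,))
    return carry + 1, StepOut(obs=obs, reward=reward)


max_num_steps = 8
keys = jax.random.split(jax.random.PRNGKey(0), max_num_steps)
_, data = jax.lax.scan(dummy_step, 0, keys)
print(data.obs.shape)     # (8, 4, 3) -> (max_num_steps, batch, obs_dim)
print(data.reward.shape)  # (8, 4)
```

This stacking is what lets `compute_loss_input` below treat axis 0 of `data.terminated` as the time axis.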
@@ -155,7 +146,7 @@ def step_fn(state, key) -> StepFnOutput:
         state = jax.vmap(auto_reset(env.step, env.init))(state, policy_output.action, keys)
         discount = -1.0 * jnp.ones_like(value)
         discount = jnp.where(state.terminated, 0.0, discount)
-        return state, StepFnOutput(
+        return state, SelfplayOutput(
             obs=observation,
             action_weights=policy_output.action_weights,
             reward=state.rewards[jnp.arange(state.rewards.shape[0]), actor],
@@ -170,6 +161,19 @@ def step_fn(state, key) -> StepFnOutput:
     key_seq = jax.random.split(rng_key, config.max_num_steps)
     _, data = jax.lax.scan(step_fn, state, key_seq)

+    return data
+
+
+class Sample(NamedTuple):
+    obs: jnp.ndarray
+    policy_tgt: jnp.ndarray
+    value_tgt: jnp.ndarray
+    mask: jnp.ndarray
+
+
+@jax.pmap
+def compute_loss_input(data: SelfplayOutput) -> Sample:
+    batch_size = config.selfplay_batch_size // num_devices
     # If episode is truncated, there is no value target
     # So when we compute value loss, we need to mask it
     value_mask = jnp.cumsum(data.terminated[::-1, :], axis=0)[::-1, :] >= 1
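The `value_mask` line above is the core of the extracted function. A small worked example (toy `terminated` array, not from a real run): the reversed cumulative sum keeps every step at or before the last termination in each column, so only the truncated tail, which has no value target, is masked out of the value loss.

```python
# Toy check of the value-mask trick, with shape (max_num_steps, batch).
import jax.numpy as jnp

terminated = jnp.array(
    [
        [0, 0],
        [1, 0],
        [0, 0],
        [0, 1],
        [0, 0],  # trailing steps with no later termination: truncated, no value target
    ],
    dtype=jnp.float32,
)

value_mask = jnp.cumsum(terminated[::-1, :], axis=0)[::-1, :] >= 1
print(value_mask.astype(jnp.int32))
# [[1 1]
#  [1 1]
#  [0 1]
#  [0 1]
#  [0 0]]
```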
@@ -195,16 +199,16 @@ def body_fn(carry, i):
     )


-def loss_fn(model_params, model_state, data: Sample):
+def loss_fn(model_params, model_state, samples: Sample):
     (logits, value), model_state = forward.apply(
-        model_params, model_state, data.obs, is_eval=False
+        model_params, model_state, samples.obs, is_eval=False
     )

-    policy_loss = optax.softmax_cross_entropy(logits, data.policy_tgt)
+    policy_loss = optax.softmax_cross_entropy(logits, samples.policy_tgt)
     policy_loss = jnp.mean(policy_loss)

-    value_loss = optax.l2_loss(value, data.value_tgt)
-    value_loss = jnp.mean(value_loss * data.mask) # mask if the episode is truncated
+    value_loss = optax.l2_loss(value, samples.value_tgt)
+    value_loss = jnp.mean(value_loss * samples.mask) # mask if the episode is truncated

     return policy_loss + value_loss, (model_state, policy_loss, value_loss)

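For reference, a toy illustration of how the renamed `samples` fields feed the two loss terms; the arrays below are made-up stand-ins for the network outputs and the `Sample` targets, not values from the example.

```python
# Toy version of the loss computed in loss_fn, with hand-made arrays.
import jax.numpy as jnp
import optax

logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.1, 0.1]])     # (batch, num_actions)
policy_tgt = jnp.array([[0.7, 0.2, 0.1], [0.4, 0.3, 0.3]])  # MCTS action weights
value = jnp.array([0.8, -0.2])                               # (batch,)
value_tgt = jnp.array([1.0, 0.0])
mask = jnp.array([1.0, 0.0])  # second entry comes from a truncated episode

policy_loss = jnp.mean(optax.softmax_cross_entropy(logits, policy_tgt))
value_loss = jnp.mean(optax.l2_loss(value, value_tgt) * mask)  # masked entry adds 0
total_loss = policy_loss + value_loss
```

`optax.l2_loss` returns `0.5 * (value - value_tgt) ** 2` elementwise, so multiplying by `mask` zeroes the contribution of truncated steps before averaging.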
@@ -323,7 +327,8 @@ def body_fn(val):
     # Selfplay
     rng_key, subkey = jax.random.split(rng_key)
     keys = jax.random.split(subkey, num_devices)
-    samples: Sample = selfplay(model, keys)
+    data: SelfplayOutput = selfplay(model, keys)
+    samples: Sample = compute_loss_input(data)

     # Shuffle samples and make minibatches
     samples = jax.device_get(samples) # (#devices, batch, max_num_steps, ...)
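Downstream, the two pmapped stages yield a `Sample` pytree with a leading device axis; the example then moves it to the host and shuffles it into minibatches (that code is collapsed above). A rough sketch of the pattern, with a hypothetical `make_minibatches` helper that is not part of train.py:

```python
# Hypothetical helper: flatten the leading (#devices, batch, max_num_steps) axes
# of the Sample pytree on the host, then yield shuffled minibatches.
import jax
import numpy as np


def make_minibatches(samples, minibatch_size, rng: np.random.Generator):
    flat = jax.tree_util.tree_map(
        lambda x: x.reshape((-1,) + x.shape[3:]), samples
    )
    n = jax.tree_util.tree_leaves(flat)[0].shape[0]
    perm = rng.permutation(n)
    for start in range(0, n - minibatch_size + 1, minibatch_size):
        idx = perm[start : start + minibatch_size]
        yield jax.tree_util.tree_map(lambda x: x[idx], flat)
```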
