
Commit 6729638

Add meta-gradient LPG tricks (multiple critic updates, advantage normalization)
1 parent f89dcbd commit 6729638

4 files changed: 54 additions, 21 deletions

README.md

Lines changed: 27 additions & 9 deletions
@@ -57,22 +57,39 @@ echo [KEY] > setup/wandb_key
 
 # Running experiments
 Meta-training is executed with `python3.8 train.py`, with all arguments found in [`experiments/parse_args.py`](https://github.com/EmptyJackson/groove/blob/main/experiments/parse_args.py).
-* `--log --wandb_entity [entity] --wandb_project [project]` enables logging to WandB.
-* `--num_agents [agents]` sets the meta-training batch size.
-* `--num_mini_batches [mini_batches]` computes each update in sequential mini-batches, in order to execute large batches with little memory. *RECOMMENDED: lower this to the smallest value that fits in memory.*
-* `--debug` disables JIT compilation.
+| Argument | Description |
+| --- | --- |
+| `--env_mode [env_mode]` | Sets the environment mode (below). |
+| `--num_agents [agents]` | Sets the meta-training batch size. |
+| `--num_mini_batches [mini_batches]` | Computes each update in sequential mini-batches, in order to execute large batches with little memory. *RECOMMENDED: lower this to the smallest value that fits in memory.* |
+| `--debug` | Disables JIT compilation. |
+| `--log --wandb_entity [entity] --wandb_project [project]` | Enables logging to WandB. |
+
+
+### Grid-World environments
+
+| Environment mode | Description | Lifetime (# of updates) |
+| --- | --- | --- |
+| `tabular` | Five tabular levels from [LPG](https://arxiv.org/abs/2007.08794) | Variable |
+| `mazes` | Maze levels from [MiniMax](https://github.com/facebookresearch/minimax) | 2500 |
+| `all_shortlife` | Uniformly sampled levels | 250 |
+| `all_vrandlife` | Uniformly sampled levels | 10-250 (log-sampled) |
+
+
+### Examples
+| Experiment | Command |
+| --- | --- |
+| LPG (meta-gradient) | `python3.8 train.py --num_agents 512 --num_mini_batches 16 --log --wandb_entity [entity] --wandb_project [project]` |
+| GROOVE | LPG with `--score_function alg_regret` (algorithmic regret is computed every step due to end-to-end compilation, so currently very inefficient) |
+| TA-LPG | LPG with `--num_mini_batches 8 --use_es --lifetime_conditioning --lpg_learning_rate 0.01 --env_mode all_vrandlife` |
+
 
 ### Docker
 To execute CPU or GPU docker containers, run the relevant script (with the GPU index as the first argument for the GPU script).
 ```
 ./run_gpu.sh [GPU id] python3.8 train.py [args]
 ```
 
-### Examples
-* LPG: `python3.8 train.py --num_agents 512 --num_mini_batches 16 --log --wandb_entity [entity] --wandb_project [project]`
-* GROOVE: LPG with `--score_function alg_regret`
-* TA-LPG: LPG with `--num_mini_batches 8 --use_es --lifetime_conditioning --lpg_learning_rate 0.01`
-
 # Citation
 If you use this implementation in your work, please cite us with the following:
 ```
@@ -96,5 +113,6 @@ If you use this implementation in your work, please cite us with the following:
 
 # Coming soon
 
+* Speed up GROOVE by removing recomputation of algorithmic regret every step.
 * Meta-testing script for checkpointed models.
 * Alternative UED metrics (PVL, MaxMC).
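
The `--num_mini_batches` flag described in the argument table above amounts to accumulating the meta-update over sequential chunks of the agent batch, which is why lowering it to the smallest value that fits in memory is recommended. Below is a minimal sketch of that pattern in JAX; the names (`meta_loss`, `accumulated_grad`, `agent_batch`) and the stand-in loss are hypothetical and not the repository's actual implementation.

```python
# Sketch of sequential mini-batch gradient accumulation, assuming a hypothetical
# per-agent objective; not the repository's meta-update code.
import jax
import jax.numpy as jnp


def meta_loss(lpg_params, agent_batch):
    # Stand-in loss: any per-agent objective averaged over the mini-batch.
    return jnp.mean((agent_batch @ lpg_params) ** 2)


def accumulated_grad(lpg_params, agent_batch, num_mini_batches):
    # Split the agent batch along its leading axis and scan over the chunks,
    # so only one mini-batch is materialised in memory at a time.
    chunks = agent_batch.reshape(num_mini_batches, -1, *agent_batch.shape[1:])

    def _one_chunk(grad_sum, chunk):
        grad = jax.grad(meta_loss)(lpg_params, chunk)
        return jax.tree_util.tree_map(jnp.add, grad_sum, grad), None

    zero = jax.tree_util.tree_map(jnp.zeros_like, lpg_params)
    grad_sum, _ = jax.lax.scan(_one_chunk, zero, chunks)
    # Average over mini-batches to recover the full-batch gradient.
    return jax.tree_util.tree_map(lambda g: g / num_mini_batches, grad_sum)


params = jnp.ones(4)
batch = jnp.arange(512.0 * 4).reshape(512, 4)
print(accumulated_grad(params, batch, num_mini_batches=16).shape)  # (4,)
```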

agents/lpg_agent.py

Lines changed: 3 additions & 3 deletions
@@ -126,14 +126,14 @@ def _train_step(carry, _):
         metrics = LPGAgentMetrics(
             pi_l2, actor_entropy, critic_loss, y_l2, critic_entropy
         )
-        return (rng, agent_state), metrics
+        return (rng, agent_state), (rollout, metrics)
 
     # --- Perform K agent updates ---
-    carry_out, metrics = jax.lax.scan(
+    carry_out, (rollout, metrics) = jax.lax.scan(
         _train_step,
         (rng, agent_state),
         None,
         length=num_train_steps,
     )
     _, agent_state = carry_out
-    return agent_state, jax.tree_map(jnp.mean, metrics)
+    return agent_state, rollout, jax.tree_map(jnp.mean, metrics)
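
This change relies on a standard property of `jax.lax.scan`: each step's second return value is stacked along a new leading axis, so returning the rollout from `_train_step` makes all K training rollouts available to the caller at once. A self-contained sketch of that behaviour, with a hypothetical step function and stand-in data rather than the repository's agent code:

```python
# Minimal sketch of lax.scan stacking per-step outputs; the step function,
# shapes, and "rollout" here are hypothetical stand-ins.
import jax
import jax.numpy as jnp


def _step(carry, _):
    rng, state = carry
    rng, _rng = jax.random.split(rng)
    fake_rollout = jax.random.normal(_rng, (8,))   # stand-in for a rollout
    fake_metric = jnp.mean(fake_rollout)           # stand-in for metrics
    return (rng, state + 1), (fake_rollout, fake_metric)


carry_out, (rollouts, metrics) = jax.lax.scan(
    _step, (jax.random.PRNGKey(0), jnp.float32(0.0)), None, length=5
)
print(rollouts.shape, metrics.shape)  # (5, 8) and (5,): one entry per update
```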

experiments/parse_args.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ def parse_args(cmd_args=sys.argv[1:]):
         "--env_name", help="Environment name", type=str, default="GridWorld-v0"
     )
     parser.add_argument(
-        "--env_mode", help="Environment mode", type=str, default="all_vrandlife"
+        "--env_mode", help="Environment mode", type=str, default="all_shortlife"
     )
     parser.add_argument(
         "--env_workers",

meta/train.py

Lines changed: 23 additions & 8 deletions
@@ -39,7 +39,9 @@ def _train_agent(lpg_params, rng, agent_state, value_critic_state):
 
     # --- Perform K agent train steps ---
     rng, _rng = jax.random.split(rng)
-    agent_state, agent_metrics = agent_train_fn(_rng, _lpg_train_state, agent_state)
+    agent_state, rollouts, agent_metrics = agent_train_fn(
+        _rng, _lpg_train_state, agent_state
+    )
 
     # --- Rollout updated agent ---
     rng, _rng = jax.random.split(rng)
@@ -56,19 +58,32 @@ def _train_agent(lpg_params, rng, agent_state, value_critic_state):
     )
 
     # --- Update value function ---
-    def _compute_value_loss(critic_params):
+    def _compute_value_loss(critic_params, rollouts):
         value_critic_state.replace(params=critic_params)
         value_loss, adv = jax.vmap(
             compute_advantage, in_axes=(None, 0, None, None)
-        )(value_critic_state, eval_rollouts, gamma, gae_lambda)
+        )(value_critic_state, rollouts, gamma, gae_lambda)
         return value_loss.mean(), adv
 
-    (value_loss, adv), value_critic_grad = jax.value_and_grad(
-        _compute_value_loss, has_aux=True
-    )(value_critic_state.params)
-    value_critic_state = value_critic_state.apply_gradients(grads=value_critic_grad)
+    def _update_critic(value_critic_state, rollouts):
+        losses, value_critic_grad = jax.value_and_grad(
+            _compute_value_loss, has_aux=True
+        )(value_critic_state.params, rollouts)
+        return value_critic_state.apply_gradients(grads=value_critic_grad), losses
+
+    # Iteratively update on train rollouts
+    value_critic_state, _ = jax.lax.scan(
+        _update_critic, value_critic_state, rollouts
+    )
+    # Update critic on evaluation rollout
+    value_critic_state, (value_loss, adv) = _update_critic(
+        value_critic_state, eval_rollouts
+    )
 
     # --- Compute regularized LPG loss ---
+    # Normalize advantage across batch
+    adv = jnp.divide(jnp.subtract(adv, jnp.mean(adv)), jnp.std(adv) + 1e-8)
+
     def _compute_lpg_loss(rollout, adv):
         actor = agent_state.actor_state
         action_probs = actor.apply_fn({"params": actor.params}, rollout.obs)
@@ -157,7 +172,7 @@ def _compute_candidate_fitness(rng, candidate_params, agent_state):
     rng, _rng = jax.random.split(rng)
 
     # --- Train an agent using LPG with candidate parameters ---
-    agent_state, metrics = agent_train_fn(
+    agent_state, _, metrics = agent_train_fn(
         rng=_rng,
         lpg_train_state=candidate_train_state,
         agent_state=agent_state,
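
Taken together, the `meta/train.py` changes above implement the two tricks named in the commit message: the value critic is updated once per stacked training rollout via `jax.lax.scan` before a final update on the evaluation rollout, and the resulting advantages are normalized to zero mean and roughly unit variance before entering the LPG loss. A condensed sketch of both tricks, using hypothetical stand-ins (a linear critic and synthetic rollouts in place of `value_critic_state` and `compute_advantage`):

```python
# Sketch of "multiple critic updates + advantage normalization", assuming a
# toy linear critic and synthetic rollouts; not the repository's code.
import jax
import jax.numpy as jnp


def critic_loss(params, rollout):
    # Stand-in value loss: squared error of a linear critic against returns.
    obs, returns = rollout
    values = obs @ params
    adv = returns - values
    return jnp.mean(adv ** 2), adv


def update_critic(params, rollout, lr=1e-2):
    # One gradient step on a single rollout; also returns the advantages.
    (loss, adv), grad = jax.value_and_grad(critic_loss, has_aux=True)(params, rollout)
    return params - lr * grad, (loss, adv)


def critic_updates_then_normalized_adv(params, train_rollouts, eval_rollout):
    # Trick 1: one critic update per stacked training rollout (leading axis K),
    # followed by a final update on the evaluation rollout.
    params, _ = jax.lax.scan(update_critic, params, train_rollouts)
    params, (_, adv) = update_critic(params, eval_rollout)
    # Trick 2: normalize advantages before they feed the LPG loss.
    adv = (adv - jnp.mean(adv)) / (jnp.std(adv) + 1e-8)
    return params, adv


K, T, D = 4, 16, 8
obs = jax.random.normal(jax.random.PRNGKey(0), (K, T, D))
returns = jnp.ones((K, T))
params, adv = critic_updates_then_normalized_adv(
    jnp.zeros(D), (obs, returns), (obs[0], returns[0])
)
print(adv.shape)  # (16,), zero mean and approximately unit std
```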
