Skip to content

Commit 0792982

Browse files
committedApr 23, 2024
Fix truncated ES training
1 parent f0a2fdd commit 0792982

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed
 

‎meta/meta.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,20 @@ def create_lpg_train_state(rng, args):
3030
return ESTrainState(train_state, es_strategy, es_params, es_state)
3131

3232

33-
def make_lpg_train_step(args, rollout_manager):
33+
def make_lpg_train_step(args, level_sampler):
3434
lpg_hypers = LpgHyperparams.from_run_args(args)
3535
if args.use_es:
36+
# Train an agent entirely when using ES
37+
lpg_hypers = lpg_hypers.replace(num_agent_updates=level_sampler.max_lifetime)
3638
return partial(
3739
lpg_es_train_step,
38-
rollout_manager=rollout_manager,
40+
rollout_manager=level_sampler.rollout_manager,
3941
num_mini_batches=args.num_mini_batches,
4042
lpg_hypers=lpg_hypers,
4143
)
4244
return partial(
4345
lpg_meta_grad_train_step,
44-
rollout_manager=rollout_manager,
46+
rollout_manager=level_sampler.rollout_manager,
4547
num_mini_batches=args.num_mini_batches,
4648
gamma=args.gamma,
4749
gae_lambda=args.gae_lambda,

‎train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def _train_fn(rng):
2727
)
2828

2929
# --- TRAIN LOOP ---
30-
lpg_train_step_fn = make_lpg_train_step(args, level_sampler.rollout_manager)
30+
lpg_train_step_fn = make_lpg_train_step(args, level_sampler)
3131

3232
def _meta_train_loop(carry, _):
3333
rng, train_state, agent_states, value_critic_states, level_buffer = carry

0 commit comments

Comments
 (0)
Please sign in to comment.