Feature description
Hi, what is the reason for doing the discriminator optimizer step in common.py
outside of the batch loop (see line 372 with self._disc_opt.step(),
or the snippet attached below, under the # do gradient step
line)? Was this done to improve the stability of discriminator training?
At the moment, the gradients are accumulated across all minibatches and the optimizer step is only taken after this aggregation, which might in some cases lead to more stable training, but in other cases to incredibly slow (if not impossible) training. Therefore, if there is no other reason or issue with this, I think it would be nice to have an option to take the gradient step inside the batch loop as well.
I noticed this when implementing our AIRL project with the update here. That update led to quite a significant speedup in training of the discriminator/reward net, so I think having such an option could be quite useful for other people and save them a lot of training time.
for batch in batch_iter:
    disc_logits = self.logits_expert_is_high(
        batch["state"],
        batch["action"],
        batch["next_state"],
        batch["done"],
        batch["log_policy_act_prob"],
    )
    loss = F.binary_cross_entropy_with_logits(
        disc_logits,
        batch["labels_expert_is_one"].float(),
    )
    # Renormalise the loss to be averaged over the whole
    # batch size instead of the minibatch size.
    assert len(batch["state"]) == 2 * self.demo_minibatch_size
    loss *= self.demo_minibatch_size / self.demo_batch_size
    loss.backward()

# do gradient step
self._disc_opt.step()
self._disc_step += 1
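
For concreteness, here is a rough sketch of what such an option could look like, relative to the snippet above. The disc_opt_per_minibatch flag is purely hypothetical and not part of the existing API; it only illustrates switching between the current accumulated update and a per-minibatch update.

for batch in batch_iter:
    disc_logits = self.logits_expert_is_high(
        batch["state"],
        batch["action"],
        batch["next_state"],
        batch["done"],
        batch["log_policy_act_prob"],
    )
    loss = F.binary_cross_entropy_with_logits(
        disc_logits,
        batch["labels_expert_is_one"].float(),
    )
    if self.disc_opt_per_minibatch:
        # Hypothetical option: step once per minibatch. The loss is already
        # averaged over this minibatch, so no renormalisation is needed.
        self._disc_opt.zero_grad()
        loss.backward()
        self._disc_opt.step()
    else:
        # Current behaviour: renormalise and accumulate gradients over all
        # minibatches, then step once after the loop.
        loss *= self.demo_minibatch_size / self.demo_batch_size
        loss.backward()

if not self.disc_opt_per_minibatch:
    # do gradient step
    self._disc_opt.step()
self._disc_step += 1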