
Gradient updates in AdversarialTrainer #866

@JurajZelman

Feature description

Hi, what is the reason for performing the discriminator optimizer step in common.py outside of the batch loop (see line 372 with self._disc_opt.step(), or the # do gradient step line in the snippet attached below)? Was this done to improve the stability of discriminator training?

At the moment, the gradients are accumulated over all minibatches and the optimizer step is performed only after this aggregation, which might in some cases lead to more stable training, but in other cases to incredibly slow (if not impossible) training. Therefore, if there is no other reason/issue with this, I think it would be nice to have an option to perform the gradient step inside the batch loop as well (see the sketch after the snippet below).

I noticed this while implementing our AIRL project with the update here. This update led to quite a significant speedup in training the discriminator/reward net, so I think having such an option might be quite useful for other people and could save them a lot of training time.

for batch in batch_iter:
    disc_logits = self.logits_expert_is_high(
        batch["state"],
        batch["action"],
        batch["next_state"],
        batch["done"],
        batch["log_policy_act_prob"],
    )
    loss = F.binary_cross_entropy_with_logits(
        disc_logits,
        batch["labels_expert_is_one"].float(),
    )

    # Renormalise the loss to be averaged over the whole
    # batch size instead of the minibatch size.
    assert len(batch["state"]) == 2 * self.demo_minibatch_size
    loss *= self.demo_minibatch_size / self.demo_batch_size
    loss.backward()

# do gradient step
self._disc_opt.step()
self._disc_step += 1
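
For illustration, here is a minimal sketch of how such an option could look, mirroring the snippet above. The flag name self.disc_opt_step_per_minibatch is purely hypothetical and not part of the current imitation API; when it is set, the optimizer steps (and gradients are zeroed) after every minibatch, otherwise the current accumulate-then-step behaviour is kept.

for batch in batch_iter:
    disc_logits = self.logits_expert_is_high(
        batch["state"],
        batch["action"],
        batch["next_state"],
        batch["done"],
        batch["log_policy_act_prob"],
    )
    loss = F.binary_cross_entropy_with_logits(
        disc_logits,
        batch["labels_expert_is_one"].float(),
    )

    if self.disc_opt_step_per_minibatch:  # hypothetical flag
        # Step on each minibatch immediately instead of accumulating.
        loss.backward()
        self._disc_opt.step()
        self._disc_opt.zero_grad()
    else:
        # Current behaviour: renormalise and accumulate gradients so the
        # single step below averages over the whole batch.
        loss *= self.demo_minibatch_size / self.demo_batch_size
        loss.backward()

if not self.disc_opt_step_per_minibatch:
    # do gradient step on the accumulated gradients
    self._disc_opt.step()
self._disc_step += 1

Note that the per-minibatch branch skips the loss renormalisation, since each step then uses the plain minibatch average rather than a contribution to the full-batch average.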
