Restart training with new data, mid-epoch #436

Open
@schopra8

🚀 Feature

Right now, you can change the datasets used in a CombinedStreamingDataset and resume training on epoch boundaries. It would be great if you could resume training with new datasets mid-epoch.

Motivation

When we're doing curriculum learning, we don't know the right number of steps or epochs to train. If we reach a sufficient validation loss, we kill the training and resume training with a new group of datasets (i.e. adjust the curriculum). Accordingly, we often have to kill training mid-epoch and restart with new datasets.

Pitch

If you are training with M datasets and kill training K steps into epoch N, change the underlying datasets, and resume training from a checkpoint that was saved mid-epoch, the trainer should jump to epoch N + 1 with the new datasets, the old optimizer state, and the correct global batch index.

Alternatives

Right now, we have two workarounds:

  • Delete the loops part of the last saved mid-epoch checkpoint before resuming training with different datasets. This isn't a great solution, because training resumes at epoch 0, which throws off any learning rate schedulers we have.
  • Copy the loops part of the last checkpoint saved at an epoch boundary into the more recent checkpoint that was saved mid-epoch, before resuming training with different datasets. This approximately mitigates the learning rate scheduler issue, but it isn't the cleanest solution and is a pain to do manually every time.
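
For reference, the two workarounds can be sketched roughly as follows. This is a minimal illustration, not our exact script: it assumes the checkpoint has already been loaded into a plain dict (for example with torch.load) and that the loop state sits under a top-level "loops" key. The helper names drop_loops and transplant_loops are hypothetical.

```python
import copy
from typing import Any, Dict

Checkpoint = Dict[str, Any]

# Assumption: the loop progress is stored under a top-level "loops" key.
LOOPS_KEY = "loops"


def drop_loops(mid_epoch_ckpt: Checkpoint) -> Checkpoint:
    """Workaround 1: remove the loop state so the resume ignores the old epoch position.

    Side effect described above: training restarts at epoch 0.
    """
    patched = dict(mid_epoch_ckpt)  # shallow copy, the input dict is left untouched
    patched.pop(LOOPS_KEY, None)
    return patched


def transplant_loops(epoch_end_ckpt: Checkpoint, mid_epoch_ckpt: Checkpoint) -> Checkpoint:
    """Workaround 2: keep the newer weights and optimizer state, but take the
    loop state from the last checkpoint saved at an epoch boundary."""
    if LOOPS_KEY not in epoch_end_ckpt:
        raise KeyError(f"epoch-boundary checkpoint has no '{LOOPS_KEY}' entry")
    patched = dict(mid_epoch_ckpt)
    patched[LOOPS_KEY] = copy.deepcopy(epoch_end_ckpt[LOOPS_KEY])
    return patched
```

The patched dict would then be written back (for example with torch.save) and passed as the checkpoint to resume from. Both helpers return a new dict rather than editing in place, so the original checkpoint files can be kept as a fallback.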

Additional context
