Description
🚀 Feature
Right now, you can change the datasets used in a CombinedStreamingDataset
and resume training at epoch boundaries. It would be great if you could also resume training with new datasets mid-epoch.
Motivation
When we're doing curriculum learning, we don't know ahead of time the right number of steps or epochs to train for. Once validation loss reaches a sufficient level, we kill the run and resume training with a new group of datasets (i.e., we adjust the curriculum). As a result, we often have to kill training mid-epoch and restart with new datasets.
Pitch
If you are training with N datasets and kill training K steps into epoch E, change the underlying datasets, and resume from a checkpoint that was saved mid-epoch, the trainer should jump to epoch E + 1 with the new datasets, the old optimizer state, and the correct global batch index.
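A minimal sketch of the intended workflow, assuming litdata's `StreamingDataset` / `CombinedStreamingDataset` / `StreamingDataLoader` together with a standard Lightning `Trainer`; the dataset URIs, weights, checkpoint path, and `BoringModel` are placeholders, not part of this report:

```python
import torch
import lightning as L
from litdata import StreamingDataset, CombinedStreamingDataset, StreamingDataLoader


class BoringModel(L.LightningModule):
    """Stand-in for the real training module, which is omitted here."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# New curriculum stage: a different dataset mix than the run that produced the checkpoint.
new_mix = CombinedStreamingDataset(
    datasets=[
        StreamingDataset("s3://my-bucket/stage2-easy"),
        StreamingDataset("s3://my-bucket/stage2-hard"),
    ],
    weights=[0.3, 0.7],
)
train_loader = StreamingDataLoader(new_mix, batch_size=8, num_workers=4)

trainer = L.Trainer(max_epochs=100)
# Desired behavior: resuming from a mid-epoch checkpoint with a changed dataset mix
# should start the next epoch on the new datasets while keeping the optimizer state
# and the global batch index, instead of trying to restore the old dataloader position.
trainer.fit(BoringModel(), train_dataloaders=train_loader, ckpt_path="checkpoints/last.ckpt")
```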
Alternatives
Right now, we have two workarounds (a code sketch of both follows this list):
- Delete the `loops` section of the last mid-epoch checkpoint before resuming training with different datasets. This isn't a great solution, because training then resumes at epoch 0, which breaks any learning rate schedulers we have.
- Copy the `loops` section from the last epoch-boundary checkpoint into the more recent mid-epoch checkpoint before resuming training with different datasets. This approximately mitigates the learning rate scheduler issue, but it isn't the cleanest solution and is a pain to do manually every time.
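For reference, a hypothetical sketch of the checkpoint surgery both workarounds amount to, assuming the checkpoint is a plain dict with a top-level `loops` key (as in recent Lightning versions); the file paths are placeholders:

```python
import torch

# Mid-epoch checkpoint we actually want to resume from (with new datasets).
ckpt = torch.load("checkpoints/last.ckpt", map_location="cpu", weights_only=False)

# Workaround 1: drop the loop state entirely. Resuming then restarts at epoch 0,
# which breaks any epoch-based learning rate schedulers.
# ckpt.pop("loops", None)

# Workaround 2: overwrite the loop state with the one from the most recent
# epoch-boundary checkpoint, so the epoch counter (and hence the LR schedule)
# stays approximately correct while the mid-epoch dataloader position is discarded.
boundary_ckpt = torch.load("checkpoints/epoch-end.ckpt", map_location="cpu", weights_only=False)
ckpt["loops"] = boundary_ckpt["loops"]

torch.save(ckpt, "checkpoints/patched.ckpt")
```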