Support dynamic batch size #34406

@fzyzcjy

Feature request

Hi, thanks for the library! When training, I notice that if a micro-batch contains too few tokens, throughput is quite bad (i.e. the average time per token is large). However, I cannot simply increase the batch size, because the training data mixes long (e.g. 2000-token) and short (e.g. 500-token) sequences: a batch size that makes short sequences run fast makes long sequences OOM.

Therefore, I am proposing a dynamic (micro) batch size. For example, suppose batch_size=16. Before this proposal, we would use e.g. micro_batch_size=2 & grad_accum=8. After this proposal, a micro-batch of short sequences would hold 4 samples, while a micro-batch of long sequences would hold only 2. Once the accumulated micro-batches sum up to 16 samples, we compute the loss and consider the step done. A sketch of one possible implementation follows.
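To make this concrete, here is a minimal sketch of one way it could work (all names here are hypothetical, not existing transformers API): instead of a fixed sample count, cap each micro-batch by a padded-token budget, grouping samples of similar length together.

```python
# Hypothetical sketch: a length-aware batch sampler that packs samples
# into variable-size micro-batches under a fixed padded-token budget.
# An object like this can be passed to torch.utils.data.DataLoader via
# `batch_sampler=`, which accepts any iterable of index lists.

class TokenBudgetBatchSampler:
    """Yield micro-batches whose padded size (number of samples times the
    longest sequence in the micro-batch) stays under `max_tokens`."""

    def __init__(self, lengths, max_tokens):
        # `lengths[i]` is the token count of training sample i.
        self.batches = []
        # Sort by length so similarly sized sequences share a micro-batch;
        # a real implementation would shuffle within length buckets to
        # keep some randomness.
        order = sorted(range(len(lengths)), key=lambda i: lengths[i])
        batch, longest = [], 0
        for idx in order:
            longest = max(longest, lengths[idx])
            # Padded cost if we added this sample to the current micro-batch.
            if batch and (len(batch) + 1) * longest > max_tokens:
                self.batches.append(batch)
                batch, longest = [], lengths[idx]
            batch.append(idx)
        if batch:
            self.batches.append(batch)

    def __iter__(self):
        return iter(self.batches)

    def __len__(self):
        return len(self.batches)


if __name__ == "__main__":
    # Toy demo with the lengths from the example above: 500-token short
    # sequences and 2000-token long sequences, under a 4000-token budget.
    lengths = [500] * 8 + [2000] * 8
    for micro_batch in TokenBudgetBatchSampler(lengths, max_tokens=4000):
        padded = len(micro_batch) * max(lengths[i] for i in micro_batch)
        print(len(micro_batch), "samples,", padded, "padded tokens")
```

Under this scheme, gradient accumulation would count samples rather than micro-batches, stepping the optimizer once the accumulated micro-batches reach 16 samples.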

Motivation

(see above)

Your contribution

I am happy to submit a PR.
