
Add DDP token averaging for equivalence with non-parallel training, similar to #34191 (#34242)

@sbwww

Feature request

Token averaging in gradient accumulation was fixed in #34191, but token averaging in DDP seems to have the same issue.


Expected behavior

Counting all the tokens that contribute to the loss in one optimizer step (across every GPU, gradient accumulation step, and microbatch), the token count is:

$$ntokens=\sum\limits_{GPUs} \sum\limits_{gas} \sum\limits_{microb} (label\neq-100)$$

I believe the loss should be averaged over all of these tokens at once to match non-parallel training.
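
Put another way (my reading of the expected behavior, writing $\ell_{token}$ for the per-token loss), one optimizer step should use the globally token-averaged loss:

$$loss_{step}=\frac{1}{ntokens}\sum\limits_{GPUs} \sum\limits_{gas} \sum\limits_{microb} \sum\limits_{label\neq-100} \ell_{token}$$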


Current issue

Prior to #34191, the loss/gradients were averaged over $\sum\limits_{GPUs}$, $\sum\limits_{gas}$, and $\sum\limits_{microb}$ separately. The num_items_in_batch introduced in #34191 corresponds to:

$$ntokens=\sum\limits_{gas} \sum\limits_{microb} (label\neq-100)$$

So the loss/gradients are now averaged over $\sum\limits_{GPUs}$ and $\left(\sum\limits_{gas}\sum\limits_{microb}\right)$ separately, which still does not seem equivalent to non-parallel training.
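
To make the gap concrete, here is a small made-up example with 2 GPUs and no gradient accumulation: rank 0 sees 10 loss-contributing tokens with summed token loss 5, and rank 1 sees 30 tokens with summed loss 30. Averaging per rank and then letting DDP average the gradients gives

$$\frac{1}{2}\left(\frac{5}{10}+\frac{30}{30}\right)=0.75$$

while non-parallel training over the same 40 tokens gives

$$\frac{5+30}{40}=0.875$$

The two only coincide when every rank happens to see the same number of tokens.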

Can we also incorporate $\sum\limits_{GPUs}$ when determining num_items_in_batch? Something like all_reduce(num_items_in_batch)?
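
A minimal sketch of what I have in mind, assuming plain torch.distributed is initialized (the helper name and the loss_sum/labels variables are placeholders, not the actual Trainer API):

```python
import torch
import torch.distributed as dist

def global_num_items_in_batch(labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    """Count loss-contributing tokens on this rank, then sum the counts across all DDP ranks."""
    num_items = (labels != ignore_index).sum()
    if dist.is_available() and dist.is_initialized():
        # With the NCCL backend the count tensor must live on the current CUDA device.
        dist.all_reduce(num_items, op=dist.ReduceOp.SUM)
    return num_items

# Inside a training step, with loss_sum the summed (not averaged) token loss on this rank:
#   num_items = global_num_items_in_batch(labels)
#   loss = loss_sum / num_items * dist.get_world_size()
#   loss.backward()
# The world_size factor cancels DDP's gradient averaging, so the effective update
# divides the total loss of all ranks by the global token count, as in non-parallel training.
```

Scaling the loss by world_size (rather than changing DDP's reduce op) keeps the change local to the loss computation.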

Motivation

DDP training does not seem fully equivalent to non-parallel training.

Related comments: #34191 (comment)

Your contribution

I found a fairseq implementation of this feature:

https://github.com/facebookresearch/fairseq/blob/018621f3cca02ca9de945dc082c3fb1a7f9f2deb/fairseq/trainer.py#L932-L949
