
Introducing experimental gradient accumulation API #8584

Open
rpsilva-aws wants to merge 3 commits into master from rpsilva_grad_acc_v2
Conversation

rpsilva-aws
Contributor

@rpsilva-aws rpsilva-aws commented Jan 16, 2025

In this PR, we introduce `experimental.gradient_accumulation`, which leverages XLA's `While` op to accumulate gradients.
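To make the baseline concrete, here is a minimal, framework-free sketch of the "traditional" gradient accumulation that the first training loop below uses: gradients from N micro-batches are summed (each scaled by 1/N) and the optimizer steps once. The function names (`grad`, `accumulated_step`) are illustrative, not part of the PR's API; with a mean loss and equal-sized micro-batches, the result matches a single full-batch step.

```python
def grad(w, xs, ys):
    # d/dw of mean((w*x - y)^2) over the batch
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n

def accumulated_step(w, micro_batches, lr=0.1):
    # Sum scaled micro-batch gradients, then apply one optimizer step.
    n = len(micro_batches)
    acc = 0.0
    for xs, ys in micro_batches:
        acc += grad(w, xs, ys) / n  # scale so the sum equals the full-batch mean
    return w - lr * acc  # single update after all micro-batches

xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w0 = 0.0

full = w0 - 0.1 * grad(w0, xs, ys)
micro = accumulated_step(w0, [(xs[:2], ys[:2]), (xs[2:], ys[2:])])
```

The equivalence of `full` and `micro` is why the two loss traces in the logs below are expected to match step for step.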

Training loop with traditional gradient accumulation
===> Preparing data..
Epoch 0 step 8 loss 1.1098170280456543
Epoch 0 step 16 loss 1.1719611883163452
Epoch 0 step 24 loss 3.453134536743164
Epoch 0 step 32 loss 2.518792152404785
Epoch 0 step 40 loss 6.67546272277832
Epoch 0 step 48 loss 4.609560012817383
Epoch 0 step 56 loss 5.953202247619629
Epoch 0 step 64 loss 1.325960636138916
Training loop with XLA's `While` gradient accumulation
===> Preparing data..
Epoch 0 step 8 loss 1.1098170280456543
Epoch 0 step 16 loss 1.1719611883163452
Epoch 0 step 24 loss 3.453134536743164
Epoch 0 step 32 loss 2.518792152404785
Epoch 0 step 40 loss 6.67546272277832
Epoch 0 step 48 loss 4.609560012817383
Epoch 0 step 56 loss 5.953202247619629
Epoch 0 step 64 loss 1.325960636138916
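The While-op formulation can be sketched the same way: the loop state carries a step counter and the accumulated gradient, the body consumes one micro-batch per iteration, and the loop runs until all micro-batches are used. The `while_loop` helper below is a plain-Python stand-in for XLA's `While`, and all names are illustrative, not the PR's actual API.

```python
def while_loop(cond, body, state):
    # Structural analogue of XLA's While: apply body while cond holds.
    while cond(*state):
        state = body(*state)
    return state

def grad_fn(w, xs, ys):
    # d/dw of mean((w*x - y)^2) over the batch
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n

def accumulate(w, micro_batches):
    # Carried state: (step index, accumulated gradient).
    n = len(micro_batches)

    def cond(step, acc):
        return step < n

    def body(step, acc):
        xs, ys = micro_batches[step]
        return step + 1, acc + grad_fn(w, xs, ys) / n

    _, acc = while_loop(cond, body, (0, 0.0))
    return acc
```

Expressing the accumulation as a single While op (rather than an unrolled Python loop) keeps the compiled XLA graph size independent of the number of micro-batches, which is the main motivation for this API.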

@rpsilva-aws rpsilva-aws marked this pull request as ready for review January 16, 2025 19:28
@rpsilva-aws
Contributor Author

@jeffhataws @tengyifei

@tengyifei
Collaborator

@rpsilva-aws do you plan on merging this into r2.6?

@rpsilva-aws
Contributor Author

@tengyifei Ideally, yes. It's perfectly fine for the 3-layer MLP, but we're seeing a small difference for the Llama runs (relative to a previous local patch set from just before some code cleanup), so we're quickly tracking down the cause.

@tengyifei
Collaborator

Okay, please aim to sort out all critical issues by Jan 21 if you're targeting 2.6, so that we can review and cherry-pick it by Jan 22. The 2.6 release is quickly drawing near, and I would like a few days to test all the builds.

@rpsilva-aws rpsilva-aws force-pushed the rpsilva_grad_acc_v2 branch 3 times, most recently from 08831d6 to 567ccb5 Compare January 21, 2025 23:25
@rpsilva-aws rpsilva-aws force-pushed the rpsilva_grad_acc_v2 branch 2 times, most recently from 4589eb2 to dfbef15 Compare January 22, 2025 01:10
4 participants