
[torchax] Fix Llama 3.1 405B host memory space OOM #38

Merged 1 commit into main on Jan 15, 2025

Conversation

tengyifei (Collaborator)

This fixes #28. Currently, each graph uses >128 GiB of host RAM per TPU chip, which is not supported. The OOMing host array is `bf16[126, 2, 8192, 16384]`.

Based on the shape and https://pytorch.org/blog/high-performance-llama-2, I made an informed guess and annotated the decoder input with sharding constraints. That got rid of the OOM, and we now calculate an MFU of 28.65%.
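
For context, here is a minimal sketch of the kind of sharding-constraint annotation described above, written directly against JAX (which torchax executes on). The mesh layout, the axis names ("fsdp", "tensor"), and the partition spec are assumptions for illustration, not the exact annotation added in this PR:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical device mesh: a single "fsdp" axis spanning all chips plus a
# trivial "tensor" axis. The real change may use a different mesh shape.
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ("fsdp", "tensor"))

def decoder_layer(x):
    # Constrain the layout of the (batch, seq, hidden) activation so XLA keeps
    # it sharded across chips instead of materializing it fully replicated.
    # (Only takes effect inside a jit-compiled function.)
    x = jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, P("fsdp", None, "tensor")))
    # ... the actual decoder-layer computation would follow here ...
    return x
```

Per the description above, adding constraints of this kind on the decoder input was enough to eliminate the host-side OOM.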

Training output snippet:

```
program size: 0.316734 m chars
End compiling 41.968193726148456
Compile time: 41.968193726148456
Flops 447646288314368.0
GB accessed 433.546264576
0 loss 6272 bfloat16 step latency: 132.53140141186304
======
INPUT shape torch.Size([256, 8192])
1 loss 6272 bfloat16 step latency: 41.24516941001639
======
INPUT shape torch.Size([256, 8192])
2 loss 6272 bfloat16 step latency: 40.27528425701894
======
INPUT shape torch.Size([256, 8192])
3 loss 6272 bfloat16 step latency: 40.21413461607881
======
INPUT shape torch.Size([256, 8192])
4 loss 6272 bfloat16 step latency: 40.22760192491114
======
INPUT shape torch.Size([256, 8192])
5 loss 6272 bfloat16 step latency: 41.25447623594664
======
INPUT shape torch.Size([256, 8192])
6 loss 6272 bfloat16 step latency: 42.87271257000975
======
```

tengyifei requested a review from qihqi on January 15, 2025.
qihqi merged commit 94aafb4 into main on January 15, 2025.
6 checks passed.
tengyifei deleted the yifeit/torchax-405b-oom branch on January 26, 2025.
Successfully merging this pull request may close: Fix the torchax 405B host memory space OOM (#28)