[torch_xla] MVP correctness and convergence check for Llama 3.0 8B #90
I have the loss for 500 steps with the following configuration:
HF code snapshot link for generating data. (Screenshot of the loss curves omitted.) The concern is that TP doesn't reach the same loss that HF can provide. I will check the optimizer and the data representations in more detail to see if there are any differences.
That's a good catch.
Maybe we need to drop the last batch from the data loader? (The HF config had a …) If the batch size is too small for the last batch, maybe that could cause the loss spikes in the TP run.
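For reference, a minimal sketch of what dropping the ragged final batch looks like with a plain PyTorch `DataLoader`; the toy dataset and batch size are purely illustrative, not the actual training config:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset of 1000 samples, purely for illustration.
dataset = TensorDataset(torch.arange(1000).unsqueeze(-1))

# With drop_last=True the trailing partial batch (1000 % 256 = 232 samples)
# is discarded instead of producing a smaller batch at the end of each epoch.
loader = DataLoader(dataset, batch_size=256, shuffle=True, drop_last=True)

for (batch,) in loader:
    assert batch.shape[0] == 256  # every batch is full-sized
```

Whether HF or torchprime actually sets `drop_last` for this run would still need to be confirmed from the configs.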
Right, I think this can be different, but when I check the epoch number and step number, the HF run and the TP run match, which means they reset the dataloader iterator at the same time and start a new epoch. Below are the last few loss data entries from HF and TP: HF data link:
TP data link:
Let me double-check the dataloader config.
The loss function for TP, I assume, is:

```python
import torch
from torch.nn import CrossEntropyLoss


def cross_entropy_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int):
    """
    Computes cross entropy loss of `logits` against the ground truth `labels` during
    next token prediction.

    Useful as the loss function of an LLM in pretraining or supervised finetuning.
    """
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    shift_labels = shift_labels.to(shift_logits.device)
    return loss_fct(shift_logits, shift_labels)
```

Whereas the loss function for HF is the label-smoothed loss function; could you please verify whether we are using the same loss function?
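For comparison, here is a minimal sketch of label-smoothed cross entropy in its standard formulation; this is not the exact Hugging Face implementation, and the `epsilon` default below is just illustrative:

```python
import torch
import torch.nn.functional as F


def label_smoothed_cross_entropy(logits: torch.Tensor, labels: torch.Tensor,
                                 epsilon: float = 0.1, ignore_index: int = -100):
    """Mix the one-hot target with a uniform distribution over the vocabulary."""
    log_probs = F.log_softmax(logits, dim=-1)
    # Negative log-likelihood of the true token (clamp so ignored labels don't break gather).
    nll = -log_probs.gather(-1, labels.clamp(min=0).unsqueeze(-1)).squeeze(-1)
    # Uniform component: average negative log-prob over the whole vocab.
    smooth = -log_probs.mean(dim=-1)
    mask = labels.ne(ignore_index)
    nll, smooth = nll[mask], smooth[mask]
    return ((1.0 - epsilon) * nll + epsilon * smooth).mean()
```

Note that with a smoothing factor of 0 this reduces to plain cross entropy, so if `label_smoothing_factor` was left at 0 in the HF run, the two loss functions should agree.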
This convergence gap still exists as of the latest code. I'll now be looking into this. Areas to investigate:
**Appendix**

torchprime training command:
huggingface training command:
Re loss function: I ran the Hugging Face trainer and set a breakpoint in the
This indicates that … Further …
**Gradient accumulation**

The Hugging Face trainer logs (http://shortn/_X54IUN9Pxu) say
So Hugging Face is not using gradient accumulation.
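As a reminder of what gradient accumulation would change, here is a toy sketch; assuming the finding above corresponds to `gradient_accumulation_steps = 1`, the loop degenerates to a plain per-batch update, so this knob can likely be ruled out:

```python
import torch

# Hypothetical toy setup purely for illustration.
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(6)]

gradient_accumulation_steps = 1  # i.e. no accumulation

for step, (x, y) in enumerate(loader):
    loss = torch.nn.functional.mse_loss(model(x), y) / gradient_accumulation_steps
    loss.backward()
    # Only step the optimizer every `gradient_accumulation_steps` micro-batches;
    # with the value 1 this is just a normal per-batch update.
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```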
**Gradient clipping and optimizer differences**

The Hugging Face trainer does gradient clipping with its default max norm. Also, it seems to wrap the optimizer and group the params into two groups.

A detailed debugging session tracing through the HF optimizer code is recorded here: https://gist.github.com/tengyifei/44840cfa1c61273ad6565d421208bc17

My suspicion is that the lack of gradient clipping is causing training instability in torchprime.
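For what it's worth, the two-group pattern is usually a weight-decay / no-weight-decay split. Here is a sketch of that common pattern; the `no_decay` name matching, weight-decay values, and learning rate are illustrative and not verified against the HF code for this run:

```python
import torch

model = torch.nn.Linear(8, 8)  # stand-in model for illustration
no_decay = ("bias", "layernorm", "layer_norm")

# Parameters that should receive weight decay vs. those that typically should not
# (biases and normalization weights).
decay_params = [p for n, p in model.named_parameters()
                if not any(k in n.lower() for k in no_decay)]
no_decay_params = [p for n, p in model.named_parameters()
                   if any(k in n.lower() for k in no_decay)]

optimizer = torch.optim.AdamW(
    [
        {"params": decay_params, "weight_decay": 0.01},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=2e-5,
)
```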
Turns out clipping gradients by norm is the only change needed to get torchprime into parity with huggingface! http://tb/share/XcYCmwKvyzGgWMAzgZ69n
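For the record, the change amounts to clipping by global norm right before the optimizer step, roughly as in the sketch below; the toy model and the `max_norm` value are illustrative, not the exact torchprime diff:

```python
import torch

model = torch.nn.Linear(8, 1)            # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(4, 8), torch.randn(4, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
# The added line: rescale all gradients so their global L2 norm is at most max_norm.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
```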
This is EXCELLENT!!! Now we know the mystery! Do you know why they put the parameters into two groups?
Great work here. I wonder if the clipping and param grouping are necessary because AdaFactor is less granular than Adam, and so you need to "trim the peaks". AdaFactor is basically trying to do Adam with a rank-1 low-rank approximation, so some statistics will just not be as precise.
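For context, this is the factored second-moment estimate from the Adafactor paper (Shazeer & Stern, 2018) that the rank-1 approximation refers to; the notation follows the paper (epsilon regularizers omitted) and is not taken from this thread:

```latex
% Adafactor stores only the row sums R_t and column sums C_t of the running
% average of the squared gradient G_t^2 for an n x m weight matrix, then
% reconstructs the full second-moment estimate as a rank-1 outer product:
\[
R_t = \hat{\beta}_{2t} R_{t-1} + (1 - \hat{\beta}_{2t}) \,(G_t \odot G_t)\,\mathbf{1}_m,
\qquad
C_t = \hat{\beta}_{2t} C_{t-1} + (1 - \hat{\beta}_{2t}) \,\mathbf{1}_n^{\top}(G_t \odot G_t)
\]
\[
\hat{V}_t = \frac{R_t C_t}{\mathbf{1}_n^{\top} R_t},
\qquad
U_t = \frac{G_t}{\sqrt{\hat{V}_t}}
\]
% Adam, by contrast, keeps the full n x m matrix V_t elementwise, so its
% per-parameter scaling is strictly more precise.
```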
I don't. But if you inspect the hyper-params in each group, they're actually identical, so this is a no-op anyway.
ACK. Well, they're the defaults in Hugging Face, and I presume that config is optimized for finetuning models in general.
- Grab a log of a Hugging Face Llama 3.0 8B run. Example: Llama 3.0 8B with gbs 256 on 1 Trillium pod.
- Create a test of training Llama 3.0 8B with torchprime with the same configs.
- Use the same seed to initialize the model.
- Compare the loss against the HF reference at every iteration (see the sketch after this list).
- Compare the final output with that from HF.
- We should run this test at least in post-submit CI.
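A hypothetical sketch of the per-step loss comparison described above; the file names, tolerance, and helper function are illustrative, not existing torchprime APIs:

```python
import json

import torch


def compare_loss_curves(hf_loss_file: str, torchprime_loss_file: str,
                        rtol: float = 0.02) -> None:
    """Assert that the two runs track each other step by step within a tolerance."""
    hf_losses = torch.tensor(json.load(open(hf_loss_file)))
    tp_losses = torch.tensor(json.load(open(torchprime_loss_file)))
    assert hf_losses.shape == tp_losses.shape, "runs logged a different number of steps"
    torch.testing.assert_close(tp_losses, hf_losses, rtol=rtol, atol=0.0)


# Example (hypothetical file paths):
# compare_loss_curves("hf_llama3_8b_losses.json", "torchprime_llama3_8b_losses.json")
```

The tolerance would need to be tuned empirically, since bitwise-identical losses across the two frameworks are unlikely even with the same seed.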