[torch_xla] MVP correctness and convergence check for Llama 3.0 8B #90

Closed
tengyifei opened this issue Feb 6, 2025 · 14 comments · Fixed by #286

@tengyifei (Collaborator)

  • Grab a log of a Hugging Face Llama 3.0 8B run. Example: Llama 3.0 8B with global batch size 256 on 1 Trillium pod.

    • Initialize the model with some fixed seed
    • Train the model for 50 steps
    • Record the loss at every step
    • At the end of the training, record the output of the model on some example tokens
    • Save this as a reference file
  • Create a test of training Llama 3.0 8B with torchprime with the same configs

  • Use the same seed to initialize the model

  • Compare the loss against the HF reference at every iteration

  • Compare the final output with that from HF

  • We should run this test at least in post-submit CI. (A rough sketch of the loss-comparison check follows below.)
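As a rough sketch (the reference file name, format, and tolerance below are hypothetical, not an existing torchprime API), the per-step loss comparison against the saved HF reference could look like this:

import json


def check_losses_against_reference(step_losses, reference_path="llama3_8b_hf_reference.json", rtol=0.01):
  """Compare per-step training losses against a saved HF reference run."""
  with open(reference_path) as f:
    reference = json.load(f)  # e.g. {"losses": [4.83, 4.71, ...]}
  ref_losses = reference["losses"]
  assert len(step_losses) == len(ref_losses), "step count mismatch"
  for step, (got, want) in enumerate(zip(step_losses, ref_losses)):
    # Bitwise-identical losses are not expected across stacks, so compare with a loose relative tolerance.
    assert abs(got - want) <= rtol * abs(want), f"loss mismatch at step {step}: {got} vs {want}"

The final-output check would work the same way: diff the logits (or generated tokens) for a fixed prompt against the values recorded in the reference file.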

@zpcore (Collaborator) commented Feb 12, 2025

I have the loss for 500 steps with the following:

Configuration:

Llama 3 8B, model FSDP sharding over 256
8K sequence length
training dataset: wikitext-103-raw-v1
lr: 5e-5
max_step: 500
minibatch on
Sampler: torch.utils.data.DistributedSampler

HF code snapshot link for generating the data.
Also, the HF and TP runs use the same code to initialize the model (e.g., in the HF link).

(Figure: training-loss curves for the HF and TP runs over 500 steps.)

The concern is that TP doesn't reach the same loss that HF reaches. I will check the optimizer and data preparation in more detail to see if there are any differences.

@tengyifei (Collaborator, Author)

That's a good catch

@tengyifei (Collaborator, Author)

Maybe we need to drop the last batch from the data loader? (HF config had a dataloader_drop_last: true feature, which says to discard the last potentially incomplete batch).

If the batch size is too small for the last batch, maybe that could cause the loss spikes in the TP run.
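For illustration, dropping the last incomplete batch in a plain PyTorch dataloader looks like this (a generic sketch, not the actual torchprime dataloader code; the dataset here is a stand-in):

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(1000))  # stand-in for the tokenized wikitext dataset
sampler = DistributedSampler(dataset, num_replicas=4, rank=0, shuffle=True, seed=42)
loader = DataLoader(
  dataset,
  batch_size=256,
  sampler=sampler,
  drop_last=True,  # mirrors HF's dataloader_drop_last: discard the final short batch
)

With drop_last=False, the last batch of an epoch can be smaller than the configured batch size, which changes the loss averaging on that step and could show up as spikes.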

@zpcore (Collaborator) commented Feb 12, 2025

Maybe we need to drop the last batch from the data loader? (HF config had a dataloader_drop_last: true feature, which says to discard the last potentially incomplete batch).

If the batch size is too small for the last batch, maybe that could cause the loss spikes in the TP run.

Right, I think this could differ, but when I check the epoch and step numbers, the HF run and the TP run match, which means they reset the dataloader iterator and start a new epoch at the same time. Below are the last few loss entries from HF and TP:

HF data link:

[INFO\|trainer.py:2295] 2025-02-12 21:41:02,193 >> Running Epoch: 37, Step 491, tr_loss_step: 4.584621429443359
[INFO\|trainer.py:2295] 2025-02-12 21:41:09,851 >> Running Epoch: 37, Step 492, tr_loss_step: 4.555438041687012
[INFO\|trainer.py:2295] 2025-02-12 21:41:17,509 >> Running Epoch: 37, Step 493, tr_loss_step: 4.5772833824157715
[INFO\|trainer.py:2295] 2025-02-12 21:41:25,165 >> Running Epoch: 38, Step 494, tr_loss_step: 4.5570387840271
[INFO\|trainer.py:2295] 2025-02-12 21:41:32,822 >> Running Epoch: 38, Step 495, tr_loss_step: 4.5764360427856445
[INFO\|trainer.py:2295] 2025-02-12 21:41:40,477 >> Running Epoch: 38, Step 496, tr_loss_step: 4.576854228973389
[INFO\|trainer.py:2295] 2025-02-12 21:41:48,135 >> Running Epoch: 38, Step 497, tr_loss_step: 4.583564758300781
[INFO\|trainer.py:2295] 2025-02-12 21:41:55,793 >> Running Epoch: 38, Step 498, tr_loss_step: 4.563490867614746

TP data link:

[2025-02-12 22:45:16,999][__main__][INFO] - Epoch: 37, Step: 491, loss: 5.2186, trace time: 970.14 ms
[2025-02-12 22:45:22,817][__main__][INFO] - Epoch: 37, Step: 492, loss: 5.2243, trace time: 961.56 ms
[2025-02-12 22:45:28,635][__main__][INFO] - Epoch: 37, Step: 493, loss: 5.2219, trace time: 968.68 ms
[2025-02-12 22:45:34,452][__main__][INFO] - Epoch: 38, Step: 494, loss: 5.2143, trace time: 955.37 ms
[2025-02-12 22:45:40,270][__main__][INFO] - Epoch: 38, Step: 495, loss: 5.2218, trace time: 970.47 ms
[2025-02-12 22:45:46,088][__main__][INFO] - Epoch: 38, Step: 496, loss: 5.2295, trace time: 979.35 ms
[2025-02-12 22:45:51,905][__main__][INFO] - Epoch: 38, Step: 497, loss: 5.2262, trace time: 965.77 ms
[2025-02-12 22:45:57,722][__main__][INFO] - Epoch: 38, Step: 498, loss: 5.2182, trace time: 972.73 ms

Let me double-check the dataloader config.

@IsNoobgrammer

The loss function for TP, I assume, is:

import torch
from torch.nn import CrossEntropyLoss


def cross_entropy_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int):
  """
  Computes cross entropy loss of `logits` against the ground truth `labels` during
  next token prediction.

  Useful as the loss function of an LLM in pretraining or supervised finetuning.
  """
  # Shift so that tokens < n predict n
  shift_logits = logits[..., :-1, :].contiguous()
  shift_labels = labels[..., 1:].contiguous()
  # Flatten the tokens
  loss_fct = CrossEntropyLoss()
  shift_logits = shift_logits.view(-1, vocab_size)
  shift_labels = shift_labels.view(-1)
  shift_labels = shift_labels.to(shift_logits.device)
  return loss_fct(shift_logits, shift_labels)

Whereas the loss function for HF is a label-smoothed loss function.

Could you please verify that we are using the same loss function?
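For reference, the difference comes down to whether a label-smoothing epsilon is applied on top of plain cross entropy. A generic sketch (the 0.1 epsilon is illustrative, not necessarily what HF uses here):

import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(8, 32000)          # (num_tokens, vocab_size)
labels = torch.randint(0, 32000, (8,))

plain_ce = CrossEntropyLoss()(logits, labels)
# Label smoothing spreads a small amount of probability mass over all classes,
# which raises the loss floor and slightly changes the gradients.
smoothed_ce = CrossEntropyLoss(label_smoothing=0.1)(logits, labels)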

tengyifei assigned tengyifei and unassigned zpcore on May 8, 2025
tengyifei changed the title from "[torch_xla] MVP correctness check for Llama 3.0 8B" to "[torch_xla] MVP correctness and convergence check for Llama 3.0 8B" on May 27, 2025
@tengyifei (Collaborator, Author) commented Jun 3, 2025

This convergence gap still exists as of the latest main branch of torchprime vs Pei's Hugging Face branch: http://tb/share/zgXh5d4v4P7oPoehDRhh4

I'll now be looking into this. Areas to investigate (covered in the follow-up comments below): the loss function, gradient accumulation, dataset iteration ordering, and gradient clipping / optimizer differences.

Appendix

torchprime training command:

tp run --name llama-3-8b-linear torchprime/torch_xla_models/train.py model=llama-3-8b ici_mesh.fsdp=256 profile_step=3 profile_duration=30000 max_steps=1000 logging_steps=1 global_batch_size=256 dataset_config_name=wikitext-103-raw-v1 run_name=tp-linear

huggingface training command:

tp run --use-hf torchprime/hf_models/train.py train_script.args.per_device_train_batch_size=256 +train_script.args.log_loss=true train_script.args.logging_strategy=steps +train_script.args.logging_steps=1 +train_script.args.logging_first_step=true +train_script.args.report_to=tensorboard train_script.args.max_steps=1000

@tengyifei (Collaborator, Author) commented Jun 4, 2025

Re: loss function. I ran the Hugging Face trainer and set a breakpoint in the compute_loss function. This is what I got:

> /workspaces/torchprime/local_transformers/src/transformers/trainer.py(3253)compute_loss()
-> if self.label_smoother is not None and "labels" in inputs:
(Pdb) print(self.label_smoother)
None
(Pdb) bt
  /workspace/local_transformers/examples/pytorch/language-modeling/run_clm.py(746)<module>()
-> main()
  /workspace/local_transformers/examples/pytorch/language-modeling/run_clm.py(694)main()
-> train_result = trainer.train(resume_from_checkpoint=checkpoint)
  /workspaces/torchprime/local_transformers/src/transformers/trainer.py(1928)train()
-> return inner_training_loop(
  /workspaces/torchprime/local_transformers/src/transformers/trainer.py(2281)_inner_training_loop()
-> tr_loss_step = self.training_step(model, inputs)
  /workspaces/torchprime/local_transformers/src/transformers/trainer.py(3233)training_step()
-> loss = self.compute_loss(model, inputs)
> /workspaces/torchprime/local_transformers/src/transformers/trainer.py(3253)compute_loss()
-> if self.label_smoother is not None and "labels" in inputs:
(Pdb)

This indicates that label_smoother defaults to None, so the convergence gap shows up even though neither side is using a label smoother.

Further pdb stepping shows that the loss is computed in https://github.com/pytorch-tpu/transformers/blob/02289f4fbe375d0e464985dafe53e72690e435b8/src/transformers/models/llama/modeling_llama.py#L1256, so it's just the regular CrossEntropyLoss.

@tengyifei (Collaborator, Author)

Gradient accumulation

The Hugging Face trainer logs (http://shortn/_X54IUN9Pxu) say:

ERROR 2025-06-04T23:40:00.652563158Z [resource.labels.containerName: jax-tpu] [INFO|trainer.py:2123] 2025-06-04 23:40:00,652 >> ***** Running training *****
ERROR 2025-06-04T23:40:00.652596078Z [resource.labels.containerName: jax-tpu] [INFO|trainer.py:2124] 2025-06-04 23:40:00,652 >> Num examples = 13,824
ERROR 2025-06-04T23:40:00.652599948Z [resource.labels.containerName: jax-tpu] [INFO|trainer.py:2125] 2025-06-04 23:40:00,652 >> Num Epochs = 19
ERROR 2025-06-04T23:40:00.652603548Z [resource.labels.containerName: jax-tpu] [INFO|trainer.py:2126] 2025-06-04 23:40:00,652 >> Instantaneous batch size per device = 256
ERROR 2025-06-04T23:40:00.652606888Z [resource.labels.containerName: jax-tpu] [INFO|trainer.py:2129] 2025-06-04 23:40:00,652 >> Total train batch size (w. parallel, distributed & accumulation) = 256
ERROR 2025-06-04T23:40:00.652614828Z [resource.labels.containerName: jax-tpu] [INFO|trainer.py:2130] 2025-06-04 23:40:00,652 >> Gradient Accumulation steps = 1
ERROR 2025-06-04T23:40:00.652617988Z [resource.labels.containerName: jax-tpu] [INFO|trainer.py:2131] 2025-06-04 23:40:00,652 >> Total optimization steps = 1,000
ERROR 2025-06-04T23:40:00.653276808Z [resource.labels.containerName: jax-tpu] [INFO|trainer.py:2132] 2025-06-04 23:40:00,653 >> Number of trainable parameters = 8,030,261,248

So Hugging Face is not using gradient accumulation.
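For completeness, a generic sketch of what gradient accumulation would look like if it were enabled; with accumulation_steps = 1 (as in the HF log above) it reduces to a plain per-batch update, so this cannot explain the gap. The toy model and data below are stand-ins:

import torch
from torch import nn

model = nn.Linear(16, 4)  # stand-in; the real run uses Llama 3 8B
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
loss_fn = nn.CrossEntropyLoss()
batches = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(4)]

accumulation_steps = 1  # per the HF log; > 1 would average gradients over several micro-batches

optimizer.zero_grad()
for step, (x, y) in enumerate(batches):
  loss = loss_fn(model(x), y) / accumulation_steps  # scale so accumulated gradients average correctly
  loss.backward()
  if (step + 1) % accumulation_steps == 0:
    optimizer.step()
    optimizer.zero_grad()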

@tengyifei (Collaborator, Author)

Dataset iteration ordering

I compared the HF branch [1] with the TP branch [2] in this notebook [3].

The inputs fed to the HF model and the TP model are identical for the first 198 steps, which means they're getting the same data. So dataset differences are now ruled out.
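For illustration, the per-step input check can be as simple as diffing recorded batches from both runs (the file names here are hypothetical; the actual comparison lives in the linked notebook):

import torch

hf_inputs = torch.load("hf_input_ids.pt")  # per-step input_ids tensors dumped from the HF run
tp_inputs = torch.load("tp_input_ids.pt")  # per-step input_ids tensors dumped from the TP run

for step, (hf_batch, tp_batch) in enumerate(zip(hf_inputs, tp_inputs)):
  assert torch.equal(hf_batch, tp_batch), f"inputs diverge at step {step}"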

@tengyifei (Collaborator, Author)

Gradient clipping and optimizer differences

The Hugging Face trainer does gradient clipping with a default max norm of 1.

It also seems to wrap the optimizer and split the params into two parameter groups.

A detailed debugging session tracing through HF optimizer code is recorded here: https://gist.github.com/tengyifei/44840cfa1c61273ad6565d421208bc17

My suspicion is that the lack of gradient clipping is causing training instability in torchprime.
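For reference, HF's Trainer clips by total gradient norm with max_grad_norm = 1.0 by default. A minimal sketch of the equivalent call, using a toy model as a stand-in (not the actual torchprime change):

import torch
from torch import nn

model = nn.Linear(16, 4)  # stand-in; the real model is Llama 3 8B
loss = model(torch.randn(8, 16)).sum()
loss.backward()
# Rescale gradients so their global L2 norm is at most 1.0, matching HF's default max_grad_norm.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# optimizer.step() would follow here.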

@tengyifei (Collaborator, Author)

It turns out that clipping gradients by norm is the only change needed to bring torchprime into parity with Hugging Face! http://tb/share/XcYCmwKvyzGgWMAzgZ69n

@zpcore (Collaborator) commented Jun 6, 2025

This is EXCELLENT!!! Mystery solved!

Do you know why they form parameters into two groups?

@yaoshiang (Collaborator)

Great work here. I wonder if the clipping and param grouping are necessary because AdaFactor is less granular than Adam, so you need to "trim the peaks". AdaFactor basically tries to do what Adam does, but with a rank-1 low-rank approximation of the second-moment statistics, so some of those statistics just won't be as precise.
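For intuition, here is a tiny sketch (my own illustration based on the Adafactor paper, not code from this repo) of the factored second-moment estimate: the full matrix of squared gradients V is summarized by its row and column sums and reconstructed as a rank-1 outer product.

import torch

V = torch.rand(4, 6) ** 2        # stand-in for per-element squared gradients of one weight matrix
R = V.sum(dim=1, keepdim=True)   # row sums, shape (4, 1)
C = V.sum(dim=0, keepdim=True)   # column sums, shape (1, 6)
V_hat = (R @ C) / R.sum()        # rank-1 reconstruction; exact only when V itself is rank-1
# Adam would track all 24 entries of V; Adafactor keeps only the 4 + 6 sums,
# so per-element second-moment detail is smoothed out ("less granular").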

@tengyifei (Collaborator, Author)

Do you know why they form parameters into two groups?

I don't. But if you inspect the hyper-params in each group, they're actually identical, so the grouping is a no-op anyway.

I wonder if the clipping and param grouping are necessary

ACK. Well, they're the defaults in Hugging Face, and I presume that config is tuned for fine-tuning models in general.
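For illustration, two parameter groups with identical hyper-parameters behave exactly like one group. A generic sketch (plain AdamW, not HF's actual grouping code; HF's Trainer typically separates weight-decayed params from bias/norm params, and with weight_decay = 0 the two groups end up identical):

import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
group_a = list(model[0].parameters())
group_b = list(model[1].parameters())

# Both groups carry the same lr and weight_decay, so splitting them changes nothing
# compared to passing model.parameters() as a single group.
optimizer = torch.optim.AdamW(
  [
    {"params": group_a, "lr": 5e-5, "weight_decay": 0.0},
    {"params": group_b, "lr": 5e-5, "weight_decay": 0.0},
  ]
)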
