[torch_xla] Fix NaN training loss #34


Merged

merged 1 commit into main on Jan 14, 2025

Conversation

@tengyifei (Collaborator) commented on Jan 14, 2025

Fixes #9

Fixed by switching to the Adafactor optimizer, which is what we have been using for model optimization over the past several months.
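
For reference, a minimal sketch of the optimizer swap, assuming the Hugging Face transformers Adafactor implementation; the helper name is illustrative, and the real change lives in torchprime/torch_xla_models/train.py:

  import torch
  from transformers import Adafactor

  def build_optimizer(model: torch.nn.Module, lr: float = 1e-3):
      # Previously torch.optim.AdamW, which produced NaN training losses
      # in this setup (see #9). Adafactor with a fixed learning rate is
      # what we have been training with.
      return Adafactor(
          model.parameters(),
          lr=lr,
          scale_parameter=False,  # use the fixed lr rather than Adafactor's own schedule
          relative_step=False,
          warmup_init=False,
      )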

Tested:

  XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 python3 \
      torchprime/torch_xla_models/train.py \
      torchprime/torch_xla_models/configs/run.json

  NUM_SLICES=1 TPU_TYPE=v6e-256 launcher/run_xpk.sh \
     torchprime/torch_xla_models/train.py \
     --dataset_name wikitext \
     --dataset_config_name 'wikitext-2-raw-v1' \
     --output_dir /tmp \
     --cache_dir /tmp \
     --global_batch_size 256 \
     --logging_steps 10 \
     --max_steps 15 \
     --profile_step 5 \
     --model_id 'meta-llama/Meta-Llama-3-8B' \
     --tokenizer_name 'meta-llama/Meta-Llama-3-8B' \
     --block_size 8192 \
     --fsdp full_shard \
     --fsdp_config torchprime/torch_xla_models/configs/fsdp_config.json

@tengyifei tengyifei requested a review from bhavya01 January 14, 2025 17:53
@bhavya01 (Collaborator) left a comment

I think the AdamW optimizer should also work well with these models; it's what MaxText uses.

Okay to submit this for now, but can we open a separate issue tracking the NaN training loss with the AdamW optimizer?
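
A minimal sketch of what that follow-up repro could look like, with placeholder model/dataloader names and assumed hyperparameters rather than torchprime's actual training loop:

  import torch

  def train_with_adamw(model: torch.nn.Module, dataloader, steps: int = 15):
      # Run the same short training job with AdamW and fail fast on the
      # first NaN loss, so the new issue has a concrete repro point.
      optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.0)
      for step, batch in zip(range(steps), dataloader):
          optimizer.zero_grad()
          loss = model(**batch).loss
          if torch.isnan(loss):
              raise RuntimeError(f"NaN training loss at step {step} with AdamW")
          loss.backward()
          optimizer.step()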

@tengyifei (Collaborator, Author)

SGTM

@tengyifei tengyifei merged commit 3aa04b2 into main Jan 14, 2025
6 checks passed
@tengyifei tengyifei deleted the yifeit/fix-llama-nan branch January 26, 2025 07:10
Development

Successfully merging this pull request may close these issues.

Fix the NaN training loss in Llama 3
2 participants