Add llama test #29


Merged
zpcore merged 3 commits into main from piz/llamatest on Jan 13, 2025

Conversation

zpcore (Collaborator) commented on Jan 11, 2025

Resolve #6.

Note: there are two issues with the torch_xla run, as seen in torchprime/experimental/torchax_models/test/test_llama.py:

  1. The error between the torch_xla run and the native PyTorch CPU run is quite obvious, at a level of ~0.1 (see the first sketch below).
  2. scaled_dot_product is not being replaced with the flash_attention kernel for torch_xla. This requires a code change to the model (see the second sketch below).
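
For reference, here is a minimal sketch of how the gap in item 1 can be measured. The max_abs_diff helper and the model/inputs placeholders are illustrative assumptions, not the actual code in test_llama.py:

```python
# Hypothetical helper: compare the same model on native CPU vs. an XLA device.
import torch
import torch_xla.core.xla_model as xm

@torch.no_grad()
def max_abs_diff(model, inputs):
    model = model.eval()
    cpu_out = model(*inputs)  # reference run on native PyTorch CPU

    device = xm.xla_device()
    xla_out = model.to(device)(*(t.to(device) for t in inputs))

    # Worst-case element-wise error; in this PR it comes out at roughly 0.1.
    return (cpu_out - xla_out.cpu()).abs().max().item()
```

And a hedged sketch of the kind of model change item 2 alludes to, assuming torch_xla's experimental Pallas flash attention kernel (the import path and signature may differ across torch_xla versions); the attention helper and use_flash switch are illustrative, not torchprime's actual model code:

```python
import torch.nn.functional as F
from torch_xla.experimental.custom_kernel import flash_attention

def attention(q, k, v, use_flash=False):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    if use_flash:
        # Route through torch_xla's Pallas flash attention kernel.
        return flash_attention(q, k, v, causal=True)
    # Fallback: stock PyTorch scaled dot-product attention.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```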

I kind of messed up PR #26 by rebasing to a commit, so I opened this new one instead...

zpcore (Collaborator, Author) commented on Jan 11, 2025

Created issue #30 to track the output numerical error.

zpcore requested a review from tengyifei on January 11, 2025 at 02:07
tengyifei (Collaborator) left a comment

LGTM but the PR needs to be formatted. Check out the contributing section: you can run yapf --recursive -i '*.py' torchprime launcher to format everything.

zpcore merged commit 356bea5 into main on Jan 13, 2025
6 checks passed
zpcore deleted the piz/llamatest branch on January 13, 2025 at 18:44

Successfully merging this pull request may close these issues.

Add a CPU test to torchax Llama model