Skip to content

add llama more test #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 0 commits into from
Closed

add llama more test #26

wants to merge 0 commits into from

Conversation

zpcore
Copy link
Collaborator

@zpcore zpcore commented Jan 10, 2025

Resolve #6.

Note, there are two issues in torch_xla run, as in torchprime/experimental/torchax_models/test/test_llama.py

  1. Error between torch_xla and native pytorch CPU run is qute obvious, at level of ~ 0.1.
  2. scaled_dot_product is not been replaced with flash_attention kernel for torch_xla. This requires the code change to the model.

@zpcore zpcore requested review from bhavya01 and tengyifei January 10, 2025 21:12
Copy link
Collaborator

@tengyifei tengyifei left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like CPU checks failed due to missing torch_xla2. I'll fix.

freqs_cis=freqs_cis,
mask=torch.ones_like(self.input)).to("cpu").to(torch.bfloat16)
self.assertTrue(
torch.allclose(output.to("cpu"), self.native_output, atol=1e-1),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you file another issue on https://github.com/AI-Hypercomputer/torchprime to investigate the 1e-1 and link it here? I might take a look.

@zpcore zpcore force-pushed the piz/llamatest branch 2 times, most recently from 7fe6686 to 194e15e Compare January 11, 2025 01:44
@zpcore zpcore closed this Jan 11, 2025
@zpcore zpcore mentioned this pull request Jan 11, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add a CPU test to torchax Llama model
2 participants