@SamuelMarks @khatwanimohit
Even though Llama3 is listed as one of the supported models in llama_mistral_mixtral_orbax_to_hf.py, the script does not convert these models correctly to HF.
This is primarily due to different handling of Rotary Positional Embedding (RoPE) weight permutations in Llama3.
Llama 2/Mistral/Mixtral require a specific permutation of the query (Q) and key (K) projection weights when converting from MaxText to Hugging Face format, and the original script appears to perform this (via MaxText.max_utils.unpermute_from_match_maxtext_rope).
The Llama 3 family (3, 3.1, 3.2) does not require this permutation; its Q/K weights from MaxText are already in the order Hugging Face expects for RoPE. Running the old conversion script on Llama 3 completes without errors, but the converted models perform very poorly.
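For reference, a minimal sketch of the kind of Q/K reordering involved. It mirrors the permute helper in Hugging Face's own convert_llama_weights_to_hf.py (interleaved RoPE pairs → the half-split layout that HF's rotate_half expects); whether MaxText.max_utils.unpermute_from_match_maxtext_rope does exactly this is an assumption on my part.

```python
import torch

def interleaved_to_half_split(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    """Reorder Q/K projection rows from interleaved RoPE pairs (adjacent rows
    rotated together) to the half-split layout used by HF's rotate_half
    (row i rotated together with row i + head_dim // 2).

    Illustrative sketch only -- the exact MaxText helper may differ.
    """
    head_dim = dim1 // n_heads
    return (
        w.reshape(n_heads, head_dim // 2, 2, dim2)  # group each head's rows into RoPE pairs
         .transpose(1, 2)                           # move pair members into separate halves
         .reshape(dim1, dim2)
    )

# Llama 2 / Mistral / Mixtral: apply this reordering to q_proj / k_proj when exporting to HF.
# Llama 3.x: skip it -- the MaxText Q/K weights already match HF's expected order.
```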
I have a working script here. It might be a bit messy, and I have not tested it on Mistral/Mixtral/Llama 2, but the output of the converted Llama 3.1 checkpoint looks good, and comparing the converted model with the original checkpoint shows that they are identical.
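By "identical" I mean a tensor-by-tensor comparison roughly along these lines (the path and model ID below are placeholders, and this assumes the MaxText checkpoint was originally initialized from the official HF weights so the tensors can be compared directly):

```python
import torch
from transformers import AutoModelForCausalLM

converted = AutoModelForCausalLM.from_pretrained(
    "/path/to/converted_llama3.1_hf", torch_dtype=torch.bfloat16)  # placeholder path
reference = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B", torch_dtype=torch.bfloat16)         # placeholder model ID

ref_state = reference.state_dict()
for name, tensor in converted.state_dict().items():
    max_diff = (tensor.float() - ref_state[name].float()).abs().max().item()
    assert max_diff == 0.0, f"{name}: max abs diff {max_diff}"
print("All tensors match.")
```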
The old script was also hardcoded to float16; I changed this to bfloat16.
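A sketch of one way to do the cast when building the HF state dict from numpy/JAX arrays (not the exact code from my script; numpy has no native bfloat16, so this goes through float32):

```python
import numpy as np
import torch

def to_bf16_tensor(arr) -> torch.Tensor:
    # Cast via float32 first, then let torch downcast to bfloat16.
    return torch.from_numpy(np.asarray(arr, dtype=np.float32)).to(torch.bfloat16)
```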
Since my script adds some extra logic, it might be better not to build on top of the old script, so I have not opened a PR for this.