Hello, this is an attempt to add MPS support.
There are two issues preventing MPS from working with the transformer backend:
torch.compile doesn't support MPS (see: here)
The solution for this issue is straightforward: just add another condition check before calling torch.compile (a sketch of the guard is below).
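As a minimal sketch of the kind of guard I mean (the maybe_compile helper and its call site are illustrative, not this project's actual code):

```python
import torch

def maybe_compile(model: torch.nn.Module, device: torch.device) -> torch.nn.Module:
    # Skip torch.compile on MPS, where it isn't supported, and run eager instead.
    if device.type == "mps":
        return model
    return torch.compile(model)
```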
Grouped Query Attention (GQA) only works with CUDA (see: here)
This is more complex. Falling back to Multi-Head Attention (MHA) on MPS requires the same number of KV heads as Q heads, but the pre-trained weights were trained with fewer KV heads than Q heads. My current solution is to duplicate the KV heads and weights to match the number of Q heads (sketched below).
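A rough sketch of the duplication, assuming the usual GQA layout where each k_proj/v_proj weight has shape (n_kv_heads * head_dim, dim); the names here are mine, not the repo's:

```python
import torch

def expand_kv_weight(w: torch.Tensor, n_heads: int,
                     n_kv_heads: int, head_dim: int) -> torch.Tensor:
    # w: a k_proj or v_proj weight of shape (n_kv_heads * head_dim, dim).
    # Repeat each KV head's rows so the projection emits n_heads heads,
    # which is what plain MHA expects.
    n_rep = n_heads // n_kv_heads
    dim = w.shape[1]
    w = w.view(n_kv_heads, head_dim, dim)
    w = w.repeat_interleave(n_rep, dim=0)  # (n_heads, head_dim, dim)
    return w.reshape(n_heads * head_dim, dim)
```

The ordering matters: repeat_interleave duplicates each KV head consecutively, which matches how GQA shares one KV head across each consecutive group of Q heads.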
Aside from the extra code, the main downside of this approach is that weights saved from a model on MPS can't be loaded again (whether on MPS or any other backend), since the expanded KV projections no longer match the shapes the original checkpoint format expects.
A possible path forward: instead of transforming the weights within the model, transform the pre-trained weights outside the model before loading.
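A sketch of that alternative, reusing expand_kv_weight from the snippet above; the k_proj/v_proj key suffixes and file names are assumptions, not this repo's actual layout:

```python
import torch

def expand_state_dict_for_mha(state_dict: dict, n_heads: int,
                              n_kv_heads: int, head_dim: int) -> dict:
    # Expand only the KV projection weights; everything else passes through.
    out = {}
    for key, w in state_dict.items():
        if key.endswith(("k_proj.weight", "v_proj.weight")):
            out[key] = expand_kv_weight(w, n_heads, n_kv_heads, head_dim)
        else:
            out[key] = w
    return out

# Hypothetical usage: rewrite the checkpoint once, then define the model with
# n_kv_heads == n_heads and load it normally on any backend (file names and
# head counts below are example values).
# sd = torch.load("model.pt", map_location="cpu")
# torch.save(expand_state_dict_for_mha(sd, 32, 8, 64), "model_mha.pt")
```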
Notes
The recording generated on MPS doesn't sound as good as the one generated on CPU.