
Add MPS support #190


Open · wants to merge 1 commit into main

Conversation


@12v 12v commented Mar 11, 2025

Hello, this is an attempt to add MPS support.

There are two issues preventing MPS from working with the transformer backend:

torch.compile doesn't support MPS (see: here)

The fix for this is straightforward: add another condition check before using torch.compile.
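A minimal sketch of that kind of guard, assuming compilation is simply skipped on MPS (the helper name and flags here are illustrative, not the PR's actual code):

```python
import torch

def maybe_compile(fn, device: torch.device):
    # torch.compile did not support MPS at the time of this PR, so only
    # compile on device types where it is known to work and fall back to
    # the uncompiled function on MPS.
    if device.type == "mps":
        return fn
    return torch.compile(fn, dynamic=True)
```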

Grouped Query Attention (GQA) only works with CUDA (see: here)

This is more complex. Falling back to Multi-Headed Attention (MHA) on MPS requires the same number of heads for KV as for Q, but the pre-trained weights expect a smaller number of KV heads than Q heads. My current solution for this is to duplicate the KV heads and weights to match the number of Q heads.
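A minimal sketch of that duplication, assuming the checkpoint stores K/V projection weights for n_kv_heads < n_q_heads (the tensor layout and names are assumptions, not the PR's actual code):

```python
import torch

def expand_kv_heads(w_kv: torch.Tensor, n_q_heads: int, n_kv_heads: int, head_dim: int) -> torch.Tensor:
    """Repeat each KV head so K/V have as many heads as Q (MHA fallback for GQA weights)."""
    assert n_q_heads % n_kv_heads == 0
    repeat = n_q_heads // n_kv_heads
    # (n_kv_heads * head_dim, d_model) -> (n_kv_heads, head_dim, d_model)
    w = w_kv.view(n_kv_heads, head_dim, -1)
    # Duplicate each KV head `repeat` times: (n_q_heads, head_dim, d_model)
    w = w.repeat_interleave(repeat, dim=0)
    return w.reshape(n_q_heads * head_dim, -1)
```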

Aside from the extra code, the main downside of this approach is that the weights saved from a model on MPS can't be loaded again (whether on MPS or any other backend). Possible paths forward:

  1. Continue with this approach, and just log a warning or error if the model was trained on MPS
  2. When on MPS, copy and compress the weights internally back to the shape corresponding to the number of KV heads used by GQA so they can be saved and loaded (see the sketch after this list)
  3. ...
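For option 2, a rough sketch of the inverse step before saving, under the assumption that the duplicated heads can simply be merged back (exact while they are still identical, i.e. inference only; after training on MPS the duplicates diverge and averaging becomes lossy):

```python
import torch

def compress_kv_heads(w_expanded: torch.Tensor, n_q_heads: int, n_kv_heads: int, head_dim: int) -> torch.Tensor:
    """Merge duplicated KV heads back to the GQA checkpoint shape before saving."""
    repeat = n_q_heads // n_kv_heads
    # (n_q_heads * head_dim, d_model) -> (n_kv_heads, repeat, head_dim, d_model)
    w = w_expanded.view(n_kv_heads, repeat, head_dim, -1)
    # Average across the duplicates of each original KV head.
    return w.mean(dim=1).reshape(n_kv_heads * head_dim, -1)
```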

Additionally, an alternative to transforming the weights within the model is to instead transform the pre-trained weights outside of the model before loading them.
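For that external route, one possible shape of the loading code, reusing the hypothetical expand_kv_heads helper above; the checkpoint path and parameter key names are placeholders, not the repository's actual names:

```python
import torch

def load_for_mps(model: torch.nn.Module, path: str, n_q_heads: int, n_kv_heads: int, head_dim: int):
    state_dict = torch.load(path, map_location="cpu")
    for name, tensor in list(state_dict.items()):
        # Expand K/V projection weights to the MHA shape before they ever reach the model.
        if name.endswith("k_proj.weight") or name.endswith("v_proj.weight"):
            state_dict[name] = expand_kv_heads(tensor, n_q_heads, n_kv_heads, head_dim)
    model.load_state_dict(state_dict)
    return model
```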

Notes

The recording generated on MPS doesn't sound as good as the recording generated on CPU.

@12v 12v marked this pull request as ready for review March 12, 2025 18:40
@tjameswilliams

@12v have you tested this branch on an Apple device? Does it speed things up considerably? This is awesome, because inference is insanely slow on Mac (I am using an M4 Max and a few seconds of audio takes minutes).

I tried solving this myself, but continued to run into issues with tokenization.

@tjameswilliams

Actually I just cloned your branch and tested it. Still no support for the hybrid model, but that's OK; the transformer is massively faster on MPS.

@ReadyPlayerEmma

How much is the quality affected? Is there a way to get the behavior/quality to match the CPU case? What exactly is causing the quality loss?

@tjameswilliams

@ReadyPlayerEmma to be fair (and complete), the quality loss happens on CUDA too. My guess is simply that the floating-point precision is lower.

@Aedelon

Aedelon commented Apr 14, 2025

I just saw your PR. I made mine without being aware of your work.

In model.py, you can compile with the backend "aot_eager". This is also indicated in the link you shared: LINK

decode_one_token = torch.compile(
    decode_one_token, dynamic=True, backend="aot_eager", disable=cg or disable_torch_compile
)
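If aot_eager works on MPS, it could presumably also be combined with the device check from the PR description, selecting the compile backend per device rather than skipping compilation entirely (illustrative, not the actual change):

```python
# Hypothetical variation: choose a compile backend per device instead of disabling compilation on MPS.
backend = "aot_eager" if device.type == "mps" else "inductor"
decode_one_token = torch.compile(decode_one_token, dynamic=True, backend=backend)
```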

@Aedelon Aedelon mentioned this pull request Apr 15, 2025