feat: Apple Silicon support for Chipper Model #239
Comments
We are still integrating the new changes for Chipper, but I tried to see what could be done for MPS. After decoupling the encoder and the decoder, it seems that additional changes are needed on the decoder side. I get a warning indicating that some tensors need to be mapped from int64 to int32, which makes greedy decoding as slow as just using the CPU; with a beam search size of 3 it is even slower than the CPU. This seems to be an issue in the integration of MPS capabilities in PyTorch, even in the latest version. One option could be to modify the generator and test int32 where LongTensor is currently used, or to check PyTorch's support for LongTensor under MPS.
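For context, warnings of this kind typically come from int64 (LongTensor) operations inside the generation loop rather than from the model weights themselves. A minimal sketch that reproduces that class of warning on an Apple Silicon machine (the cumsum call mirrors how position ids are commonly built during generation; this is an illustration, not the Chipper code):

```python
import torch

# Pick the MPS device when the backend is available, otherwise fall back to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Token ids are int64 (LongTensor) by default, which is what triggers the
# "no support for int64, casting to int32"-style warnings on MPS.
input_ids = torch.randint(0, 50_000, (1, 32), device=device)
attention_mask = torch.ones_like(input_ids)

# cumsum over an int64 mask is one of the ops reported as unsupported/slow on MPS.
position_ids = attention_mask.cumsum(dim=-1) - 1
print(position_ids.dtype, position_ids.device)
```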
Looking at Apple forums (https://developer.apple.com/forums/thread/712317), int64 operations are supported by the GPU accelerator. Have you tried using the latest PyTorch nightly build? This issue was previously raised in PyTorch (pytorch/pytorch#96610) and also in other repos (Stability-AI/StableLM#61). Perhaps your PyTorch version doesn't have the LongTensor ops enabled on MPS? Could you please share the warning message you are getting?
If this was not the issue, how would you approach the first option proposed? Would it be possible to convert the input sequence to int32 before passing it to the decoder and then convert it back to int64 to avoid encountering bugs later? It would look something like this:
input_seq = torch.tensor(input_seq, dtype=torch.int32)
output_seq = decoder(input_seq)
output_seq = output_seq.type(torch.int64)
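For what it's worth, that cast-around-the-decoder idea could be wrapped like this; `decoder` and `input_seq` are placeholders rather than names from the Chipper code, and the final cast back to int64 only matters when the output actually contains token ids:

```python
import torch

def decode_with_int32(decoder: torch.nn.Module, input_seq: torch.Tensor) -> torch.Tensor:
    """Cast token ids to int32 for the MPS-side decoder call, then restore int64
    so downstream utilities that expect LongTensor keep working."""
    output_seq = decoder(input_seq.to(torch.int32))
    return output_seq.to(torch.int64)
```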
I did try the latest version. To make it work, the HF generation code (https://huggingface.co/docs/transformers/main_classes/text_generation) would need to be revised to convert the mentions of LongTensor to int32, or to a version that works efficiently on MPS. I tried converting the input ids to int32, but the warning comes from one of the methods in the HF generation code that has no relation to the type of the input ids. It should be tested with an MBARTDecoder, which probably works ok. If there is a setting in which the HF generation code with an MBARTDecoder works faster on MPS than on CPU, it should be possible to speed up Chipper with MPS.
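A quick way to check that precondition is to time the decoder forward pass on both devices. A rough sketch (not a Chipper benchmark; `decoder` is assumed to be something like an MBART decoder module that accepts `input_ids`):

```python
import time
import torch

def time_decoder(decoder: torch.nn.Module, device: str, steps: int = 50) -> float:
    """Time `steps` forward passes of the decoder on the given device."""
    decoder = decoder.to(device).eval()
    ids = torch.randint(0, 1000, (1, 64), device=device)
    with torch.no_grad():
        start = time.perf_counter()
        for _ in range(steps):
            decoder(input_ids=ids)
        if device == "mps":
            torch.mps.synchronize()  # wait for queued GPU work before stopping the clock
    return time.perf_counter() - start

# Example: print(time_decoder(decoder, "cpu"), time_decoder(decoder, "mps"))
```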
I can imagine that you are testing it on the new code for #232? What encoder and decoder architectures are you using in …? I saw that you are using …
As mentioned by @ajjimeno, the encoder is not available to MPS, but the decoder is the bottleneck and can be run through a CUDA or MPS backend for GPU acceleration. The MPS backend is supported by the PyTorch framework (see the PyTorch backend support docs). It would just be a matter of checking whether MPS is available, detaching the encoder and decoder when MPS is detected instead of running model.generate, and mapping the computational graph of the decoder onto the mps device (see the Hugging Face example on the MPS backend).
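A minimal sketch of that flow, with `model`, `pixel_values`, and `decoder_input_ids` as placeholders for the actual Chipper objects (the accessor methods follow the usual Hugging Face encoder-decoder API, which may differ from Chipper's internals):

```python
import torch

# Use MPS when it is available, otherwise stay on CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

encoder = model.get_encoder()              # encoder stays on CPU
decoder = model.get_decoder().to(device)   # only the decoder graph is mapped onto mps

with torch.no_grad():
    encoder_hidden = encoder(pixel_values).last_hidden_state
    decoder_hidden = decoder(
        input_ids=decoder_input_ids.to(device),
        encoder_hidden_states=encoder_hidden.to(device),
    ).last_hidden_state
```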