Description
As mentioned by @ajjimeno, the encoder is not available on MPS, but the decoder is the bottleneck and can be run on a CUDA or MPS backend for GPU acceleration. The MPS backend is supported by the PyTorch framework (see the PyTorch backend support docs). The change would be to check whether MPS is available, split the encoder and decoder (instead of running `model.generate` end to end) when MPS is detected, and map the decoder's computational graph onto the `mps` device (see the Hugging Face example on the MPS backend).
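A minimal sketch of the device check and the encoder/decoder split described above. `ToyModel` and `pick_device` are hypothetical stand-ins for illustration, not the project's actual model; the point is keeping the encoder on CPU while mapping only the decoder to `cuda`/`mps`:

```python
import torch
import torch.nn as nn


def pick_device() -> torch.device:
    """Prefer CUDA, then MPS (Apple Silicon), falling back to CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


class ToyModel(nn.Module):
    """Toy encoder/decoder standing in for the real model (hypothetical)."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.encoder = nn.Linear(dim, dim)
        self.decoder = nn.Linear(dim, dim)


model = ToyModel()
device = pick_device()

# Encoder stays on CPU (assumed unsupported on MPS); only the decoder moves.
model.decoder.to(device)

x = torch.randn(2, 8)
hidden = model.encoder(x)                # runs on CPU
out = model.decoder(hidden.to(device))   # runs on CUDA/MPS when available
```

On a machine without CUDA or MPS this degrades gracefully to CPU, so the same code path works everywhere.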