-
Notifications
You must be signed in to change notification settings - Fork 870
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
TorchServe linux-aarch64 experimental support (#3071)
* Changes for building TorchServe on linux aarch64 * Changes for building TorchServe on linux aarch64 * Added an example for linux aarch64 * Doc update for linux aarch64 * Doc update for linux aarch64 * Doc update for linux aarch64 * removed torchtext for aarch64 * lint failure * lint failure * Build conda binaries * Build conda binaries * resolving merge conflicts * resolving merge conflicts * update documentation * review comments * Updated based on review comments --------- Co-authored-by: Ubuntu <[email protected]> Co-authored-by: Ubuntu <[email protected]>
- Loading branch information
1 parent
a69e561
commit 5c1682a
Showing
18 changed files
with
202 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# TorchServe on linux aarch64 - Experimental | ||
|
||
TorchServe has been tested to be working on linux aarch64 for some of the examples. | ||
- Tested this on Amazon Graviton 3 instance(m7g.4x.large) | ||
|
||
## Installation | ||
|
||
Currently installation from PyPi or installing from source works | ||
|
||
``` | ||
python ts_scripts/install_dependencies.py | ||
pip install torchserve torch-model-archiver torch-workflow-archiver | ||
``` | ||
|
||
## Optimizations | ||
|
||
You can also enable this optimizations for Graviton 3 to get an improved performance. More details can be found in this [blog](https://pytorch.org/blog/optimized-pytorch-w-graviton/) | ||
``` | ||
export DNNL_DEFAULT_FPMATH_MODE=BF16 | ||
export LRU_CACHE_CAPACITY=1024 | ||
``` | ||
|
||
## Example | ||
|
||
This [example](https://github.com/pytorch/serve/tree/master/examples/text_to_speech_synthesizer/SpeechT5) on Text to Speech synthesis was verified to be working on Graviton 3 | ||
|
||
## To Dos | ||
- CI | ||
- Regression tests |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Text to Speech synthesis with SpeechT5 | ||
|
||
This is an example showing text to speech synthesis using SpeechT5 model. This has been verified to work on (linux-aarch64) Graviton 3 instance | ||
|
||
While running this model on `linux-aarch64`, you can enable these optimizations | ||
|
||
``` | ||
export DNNL_DEFAULT_FPMATH_MODE=BF16 | ||
export LRU_CACHE_CAPACITY=1024 | ||
``` | ||
More details can be found in this [blog](https://pytorch.org/blog/optimized-pytorch-w-graviton/) | ||
|
||
|
||
## Pre-requisites | ||
``` | ||
chmod +x setup.sh | ||
./setup.sh | ||
``` | ||
|
||
## Download model | ||
|
||
This saves the model artifacts to `model_artifacts` directory | ||
``` | ||
huggingface-cli login | ||
python download_model.py | ||
``` | ||
|
||
## Create model archiver | ||
|
||
``` | ||
mkdir model_store | ||
torch-model-archiver --model-name SpeechT5-TTS --version 1.0 --handler text_to_speech_handler.py --config-file model-config.yaml --archive-format no-archive --export-path model_store -f | ||
mv model_artifacts/* model_store/SpeechT5-TTS/ | ||
``` | ||
|
||
## Start TorchServe | ||
|
||
``` | ||
torchserve --start --ncs --model-store model_store --models SpeechT5-TTS | ||
``` | ||
|
||
## Send Inference request | ||
|
||
``` | ||
curl http://127.0.0.1:8080/predictions/SpeechT5-TTS -T sample_input.txt -o speech.wav | ||
``` | ||
|
||
This generates an audio file `speech.wav` corresponding to the text in `sample_input.txt` |
14 changes: 14 additions & 0 deletions
14
examples/text_to_speech_synthesizer/SpeechT5/download_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from datasets import load_dataset | ||
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor | ||
|
||
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts") | ||
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts") | ||
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan") | ||
|
||
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation") | ||
|
||
model.save_pretrained(save_directory="model_artifacts/model") | ||
processor.save_pretrained(save_directory="model_artifacts/processor") | ||
vocoder.save_pretrained(save_directory="model_artifacts/vocoder") | ||
embeddings_dataset.save_to_disk("model_artifacts/speaker_embeddings") | ||
print("Save model artifacts to directory model_artifacts") |
8 changes: 8 additions & 0 deletions
8
examples/text_to_speech_synthesizer/SpeechT5/model-config.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
minWorkers: 1 | ||
maxWorkers: 1 | ||
handler: | ||
model: "model" | ||
vocoder: "vocoder" | ||
processor: "processor" | ||
speaker_embeddings: "speaker_embeddings" | ||
output_dir: "/tmp" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"I love San Francisco" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/bin/bash | ||
|
||
# Needed for soundfile | ||
sudo apt install libsndfile1 -y | ||
|
||
pip install --upgrade transformers sentencepiece datasets[audio] soundfile |
68 changes: 68 additions & 0 deletions
68
examples/text_to_speech_synthesizer/SpeechT5/text_to_speech_handler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import logging | ||
import os | ||
import uuid | ||
|
||
import soundfile as sf | ||
import torch | ||
from datasets import load_from_disk | ||
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor | ||
|
||
from ts.torch_handler.base_handler import BaseHandler | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class SpeechT5_TTS(BaseHandler): | ||
def __init__(self): | ||
self.model = None | ||
self.processor = None | ||
self.vocoder = None | ||
self.speaker_embeddings = None | ||
self.output_dir = "/tmp" | ||
|
||
def initialize(self, ctx): | ||
properties = ctx.system_properties | ||
model_dir = properties.get("model_dir") | ||
|
||
processor = ctx.model_yaml_config["handler"]["processor"] | ||
model = ctx.model_yaml_config["handler"]["model"] | ||
vocoder = ctx.model_yaml_config["handler"]["vocoder"] | ||
embeddings_dataset = ctx.model_yaml_config["handler"]["speaker_embeddings"] | ||
self.output_dir = ctx.model_yaml_config["handler"]["output_dir"] | ||
|
||
self.processor = SpeechT5Processor.from_pretrained(processor) | ||
self.model = SpeechT5ForTextToSpeech.from_pretrained(model) | ||
self.vocoder = SpeechT5HifiGan.from_pretrained(vocoder) | ||
|
||
# load xvector containing speaker's voice characteristics from a dataset | ||
embeddings_dataset = load_from_disk(embeddings_dataset) | ||
self.speaker_embeddings = torch.tensor( | ||
embeddings_dataset[7306]["xvector"] | ||
).unsqueeze(0) | ||
|
||
def preprocess(self, requests): | ||
assert len(requests) == 1, "This is currently supported with batch_size=1" | ||
req_data = requests[0] | ||
|
||
input_data = req_data.get("data") or req_data.get("body") | ||
|
||
if isinstance(input_data, (bytes, bytearray)): | ||
input_data = input_data.decode("utf-8") | ||
|
||
inputs = self.processor(text=input_data, return_tensors="pt") | ||
|
||
return inputs | ||
|
||
def inference(self, inputs): | ||
output = self.model.generate_speech( | ||
inputs["input_ids"], self.speaker_embeddings, vocoder=self.vocoder | ||
) | ||
return output | ||
|
||
def postprocess(self, inference_output): | ||
path = self.output_dir + "/{}.wav".format(uuid.uuid4().hex) | ||
sf.write(path, inference_output.numpy(), samplerate=16000) | ||
with open(path, "rb") as output: | ||
data = output.read() | ||
os.remove(path) | ||
return [data] |
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu | ||
--extra-index-url https://download.pytorch.org/whl/cpu | ||
-r torch_common.txt | ||
torch==2.2.1; sys_platform == 'linux' and platform_machine == 'aarch64' | ||
torchvision==0.17.1; sys_platform == 'linux' and platform_machine == 'aarch64' | ||
torchaudio==2.2.1; sys_platform == 'linux' and platform_machine == 'aarch64' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1216,6 +1216,10 @@ libomp | |
rpath | ||
venv | ||
TorchInductor | ||
Graviton | ||
aarch | ||
linux | ||
SpeechT | ||
Pytests | ||
deviceType | ||
XGBoost | ||
|