diff --git a/models/tts/maskgct/README.md b/models/tts/maskgct/README.md
index b75a1ea6..db168313 100644
--- a/models/tts/maskgct/README.md
+++ b/models/tts/maskgct/README.md
@@ -27,6 +27,11 @@ If you encounter any issue when using MaskGCT, feel free to open an issue in thi
 
 ## Quickstart
 
+To run this model, follow the steps below:
+
+1. Clone the repository and install the environment.
+2. Run the Jupyter notebook / Gradio demo / inference script.
+
 ### Clone and Environment
 
 This parts, follow the steps below to clone the repository and install the environment.
@@ -81,33 +86,46 @@ sudo yum install espeak-ng
 
 # Please visit https://github.com/espeak-ng/espeak-ng/releases to download .msi installer
 ```
 
-It is recommended to use conda to configure the environment. You can use the following command to create and activate a new conda environment.
+Now, install the environment. It is recommended to use conda to configure it:
 
 ```bash
 conda create -n maskgct python=3.10
 conda activate maskgct
+
+pip install -r models/tts/maskgct/requirements.txt
 ```
 
-Then, install the python packages.
+If no error occurs, the environment is successfully installed. Now, you can choose your preferred way to run the model:
+
+- Run Gradio 🤗 Playground Locally
+- Jupyter Notebook
+- Inference Script
+
+### Run Gradio 🤗 Playground Locally
+
+You can run the following command to interact with the playground:
 
 ```bash
-pip install -r models/tts/maskgct/requirements.txt
+python -m models.tts.maskgct.gradio_demo
 ```
 
 ### Jupyter Notebook
 
 We provide a [Jupyter notebook](../../../models/tts/maskgct/maskgct_demo.ipynb) to show how to use MaskGCT to inference.
 
-After installing the environment, you can open this notebook and start running.
+You can open this notebook and start running.
+
+### Inference Script
 
-### Start from Scratch
+If you do not want to use the Jupyter notebook or Gradio, you can use the Python inference script.
 
-If you do not want to use Juptyer notebook, you can start from scratch. We provide the following steps to help you start from scratch.
+```bash
+python -m models.tts.maskgct.maskgct_inference
+```
 
-1. Download the pretrained model.
-2. Load the model and inference.
+Running this will automatically download the pretrained models from HuggingFace and start the inference process. The resulting audio is saved to `generated_audio.wav` by default; you can change this in the [script](../../../models/tts/maskgct/maskgct_inference.py).
 
-#### 1. Model download
+## Model Introduction
 
 We provide the following pretrained checkpoints:
 
@@ -119,7 +137,7 @@ We provide the following pretrained checkpoints:
 | [MaskGCT-T2S](https://huggingface.co/amphion/MaskGCT/tree/main/t2s_model) | Predicting semantic tokens with text and prompt semantic tokens. |
 | [MaskGCT-S2A](https://huggingface.co/amphion/MaskGCT/tree/main/s2a_model) | Predicts acoustic tokens conditioned on semantic tokens. |
 
-You can download all pretrained checkpoints from [HuggingFace](https://huggingface.co/amphion/MaskGCT/tree/main) or use huggingface api.
+You can download all pretrained checkpoints from [HuggingFace](https://huggingface.co/amphion/MaskGCT/tree/main) or use the `huggingface_hub` API.
 ```python
 from huggingface_hub import hf_hub_download
@@ -127,90 +145,17 @@ from huggingface_hub import hf_hub_download
 
 # download semantic codec ckpt
 semantic_code_ckpt = hf_hub_download("amphion/MaskGCT", filename="semantic_codec/model.safetensors")
-# download acoustic codec ckpt
-codec_encoder_ckpt = hf_hub_download("amphion/MaskGCT", filename="acoustic_codec/model.safetensors")
-codec_decoder_ckpt = hf_hub_download("amphion/MaskGCT", filename="acoustic_codec/model_1.safetensors")
-
-# download t2s model ckpt
-t2s_model_ckpt = hf_hub_download("amphion/MaskGCT", filename="t2s_model/model.safetensors")
-
-# download s2a model ckpt
-s2a_1layer_ckpt = hf_hub_download("amphion/MaskGCT", filename="s2a_model/s2a_model_1layer/model.safetensors")
-s2a_full_ckpt = hf_hub_download("amphion/MaskGCT", filename="s2a_model/s2a_model_full/model.safetensors")
+# download the other checkpoints (acoustic codec, t2s model, s2a models) in the same way
 ```
 
-#### 2. Basic Inference
-
-You can use the following code to generate speech from text and a prompt speech (the code is also provided in [inference.py](../../../models/tts/maskgct/maskgct_inference.py)).
+Running the Jupyter notebook, the Gradio demo, or the inference script will automatically download the pretrained models from HuggingFace.
 
-Run it with `python -m models.tts.maskgct.maskgct_inference`.
+If you have trouble connecting to HuggingFace, you can switch the endpoint to a mirror site:
 
-```python
-from models.tts.maskgct.maskgct_utils import *
-from huggingface_hub import hf_hub_download
-import safetensors
-import soundfile as sf
-
-if __name__ == "__main__":
-
-    # build model
-    device = torch.device("cuda:0")
-    cfg_path = "./models/tts/maskgct/config/maskgct.json"
-    cfg = load_config(cfg_path)
-    # 1. build semantic model (w2v-bert-2.0)
-    semantic_model, semantic_mean, semantic_std = build_semantic_model(device)
-    # 2. build semantic codec
-    semantic_codec = build_semantic_codec(cfg.model.semantic_codec, device)
-    # 3. build acoustic codec
-    codec_encoder, codec_decoder = build_acoustic_codec(cfg.model.acoustic_codec, device)
-    # 4. build t2s model
-    t2s_model = build_t2s_model(cfg.model.t2s_model, device)
-    # 5. build s2a model
-    s2a_model_1layer = build_s2a_model(cfg.model.s2a_model.s2a_1layer, device)
-    s2a_model_full = build_s2a_model(cfg.model.s2a_model.s2a_full, device)
-
-    # download checkpoint
-    # ...
-
-    # load semantic codec
-    safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)
-    # load acoustic codec
-    safetensors.torch.load_model(codec_encoder, codec_encoder_ckpt)
-    safetensors.torch.load_model(codec_decoder, codec_decoder_ckpt)
-    # load t2s model
-    safetensors.torch.load_model(t2s_model, t2s_model_ckpt)
-    # load s2a model
-    safetensors.torch.load_model(s2a_model_1layer, s2a_1layer_ckpt)
-    safetensors.torch.load_model(s2a_model_full, s2a_full_ckpt)
-
-    # inference
-    prompt_wav_path = "./models/tts/maskgct/wav/prompt.wav"
-    save_path = "[YOUR SAVE PATH]"
-    prompt_text = " We do not break. We never give in. We never back down."
-    target_text = "In this paper, we introduce MaskGCT, a fully non-autoregressive TTS model that eliminates the need for explicit alignment information between text and speech supervision."
-    # Specify the target duration (in seconds). If target_len = None, we use a simple rule to predict the target duration.
-    target_len = 18
-
-    maskgct_inference_pipeline = MaskGCT_Inference_Pipeline(
-        semantic_model,
-        semantic_codec,
-        codec_encoder,
-        codec_decoder,
-        t2s_model,
-        s2a_model_1layer,
-        s2a_model_full,
-        semantic_mean,
-        semantic_std,
-        device,
-    )
-
-    recovered_audio = maskgct_inference_pipeline.maskgct_inference(
-        prompt_wav_path, prompt_text, target_text, "en", "en", target_len=target_len
-    )
-    sf.write(save_path, recovered_audio, 24000)
+```bash
+export HF_ENDPOINT=https://hf-mirror.com
 ```
-
 ## Training Dataset
 
 We use the [Emilia](https://huggingface.co/datasets/amphion/Emilia-Dataset) dataset to train our models. Emilia is a multilingual and diverse in-the-wild speech dataset designed for large-scale speech generation. In this work, we use English and Chinese data from Emilia, each with 50K hours of speech (totaling 100K hours).
diff --git a/models/tts/maskgct/gradio_demo.py b/models/tts/maskgct/gradio_demo.py
new file mode 100644
index 00000000..fe0ae933
--- /dev/null
+++ b/models/tts/maskgct/gradio_demo.py
@@ -0,0 +1,433 @@
+# Copyright (c) 2024 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import spaces
+import accelerate
+import gradio as gr
+import torch
+import safetensors
+from huggingface_hub import hf_hub_download
+import soundfile as sf
+import os
+
+import numpy as np
+import librosa
+from models.codec.kmeans.repcodec_model import RepCodec
+from models.tts.maskgct.maskgct_s2a import MaskGCT_S2A
+from models.tts.maskgct.maskgct_t2s import MaskGCT_T2S
+from models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder
+from transformers import Wav2Vec2BertModel
+from utils.util import load_config
+from models.tts.maskgct.g2p.g2p_generation import g2p, chn_eng_g2p
+
+from transformers import SeamlessM4TFeatureExtractor
+import py3langid as langid
+import whisper
+
+print("Start loading: facebook/w2v-bert-2.0")
+processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+whisper_model = None
+output_file_name_idx = 0
+
+
+def detect_text_language(text):
+    return langid.classify(text)[0]
+
+
+def detect_speech_language(speech_file):
+    global whisper_model
+    if whisper_model is None:
+        whisper_model = whisper.load_model("turbo")
+    # load audio and pad/trim it to fit 30 seconds
+    audio = whisper.load_audio(speech_file)
+    audio = whisper.pad_or_trim(audio)
+
+    # make log-Mel spectrogram and move to the same device as the model
+    mel = whisper.log_mel_spectrogram(audio, n_mels=128).to(whisper_model.device)
+
+    # detect the spoken language
+    _, probs = whisper_model.detect_language(mel)
+    return max(probs, key=probs.get)
+
+
+@torch.no_grad()
+def get_prompt_text(speech_16k, language):
+    full_prompt_text = ""
+    shot_prompt_text = ""
+    short_prompt_end_ts = 0.0
+
+    global whisper_model
+    if whisper_model is None:
+        whisper_model = whisper.load_model("turbo")
+    asr_result = whisper_model.transcribe(speech_16k, language=language)
+    full_prompt_text = asr_result["text"]  # whisper asr result
+    # text = asr_result["segments"][0]["text"]  # whisperx asr result
+    shot_prompt_text = ""
+    short_prompt_end_ts = 0.0
+    for segment in asr_result["segments"]:
+        shot_prompt_text = shot_prompt_text + segment["text"]
+        short_prompt_end_ts = segment["end"]
+        if short_prompt_end_ts >= 4:
+            break
+    return full_prompt_text, shot_prompt_text, short_prompt_end_ts
+
+
+def g2p_(text, language):
+    if language in ["zh", "en"]:
"en"]: + return chn_eng_g2p(text) + else: + return g2p(text, sentence=None, language=language) + + +def build_t2s_model(cfg, device): + t2s_model = MaskGCT_T2S(cfg=cfg) + t2s_model.eval() + t2s_model.to(device) + return t2s_model + + +def build_s2a_model(cfg, device): + soundstorm_model = MaskGCT_S2A(cfg=cfg) + soundstorm_model.eval() + soundstorm_model.to(device) + return soundstorm_model + + +def build_semantic_model(device): + semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0") + semantic_model.eval() + semantic_model.to(device) + stat_mean_var = torch.load("./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt") + semantic_mean = stat_mean_var["mean"] + semantic_std = torch.sqrt(stat_mean_var["var"]) + semantic_mean = semantic_mean.to(device) + semantic_std = semantic_std.to(device) + return semantic_model, semantic_mean, semantic_std + + +def build_semantic_codec(cfg, device): + semantic_codec = RepCodec(cfg=cfg) + semantic_codec.eval() + semantic_codec.to(device) + return semantic_codec + + +def build_acoustic_codec(cfg, device): + codec_encoder = CodecEncoder(cfg=cfg.encoder) + codec_decoder = CodecDecoder(cfg=cfg.decoder) + codec_encoder.eval() + codec_decoder.eval() + codec_encoder.to(device) + codec_decoder.to(device) + return codec_encoder, codec_decoder + + +@torch.no_grad() +def extract_features(speech, processor): + inputs = processor(speech, sampling_rate=16000, return_tensors="pt") + input_features = inputs["input_features"][0] + attention_mask = inputs["attention_mask"][0] + return input_features, attention_mask + + +@torch.no_grad() +def extract_semantic_code(semantic_mean, semantic_std, input_features, attention_mask): + vq_emb = semantic_model( + input_features=input_features, + attention_mask=attention_mask, + output_hidden_states=True, + ) + feat = vq_emb.hidden_states[17] # (B, T, C) + feat = (feat - semantic_mean.to(feat)) / semantic_std.to(feat) + + semantic_code, rec_feat = semantic_codec.quantize(feat) # (B, T) + return semantic_code, rec_feat + + +@torch.no_grad() +def extract_acoustic_code(speech): + vq_emb = codec_encoder(speech.unsqueeze(1)) + _, vq, _, _, _ = codec_decoder.quantizer(vq_emb) + acoustic_code = vq.permute(1, 2, 0) + return acoustic_code + + +@torch.no_grad() +def text2semantic( + device, + prompt_speech, + prompt_text, + prompt_language, + target_text, + target_language, + target_len=None, + n_timesteps=50, + cfg=2.5, + rescale_cfg=0.75, +): + + prompt_phone_id = g2p_(prompt_text, prompt_language)[1] + + target_phone_id = g2p_(target_text, target_language)[1] + + if target_len < 0: + target_len = int( + (len(prompt_speech) * len(target_phone_id) / len(prompt_phone_id)) + / 16000 + * 50 + ) + else: + target_len = int(target_len * 50) + + prompt_phone_id = torch.tensor(prompt_phone_id, dtype=torch.long).to(device) + target_phone_id = torch.tensor(target_phone_id, dtype=torch.long).to(device) + + phone_id = torch.cat([prompt_phone_id, target_phone_id]) + + input_fetures, attention_mask = extract_features(prompt_speech, processor) + input_fetures = input_fetures.unsqueeze(0).to(device) + attention_mask = attention_mask.unsqueeze(0).to(device) + semantic_code, _ = extract_semantic_code( + semantic_mean, semantic_std, input_fetures, attention_mask + ) + + predict_semantic = t2s_model.reverse_diffusion( + semantic_code[:, :], + target_len, + phone_id.unsqueeze(0), + n_timesteps=n_timesteps, + cfg=cfg, + rescale_cfg=rescale_cfg, + ) + + combine_semantic_code = torch.cat([semantic_code[:, :], predict_semantic], dim=-1) + 
+    prompt_semantic_code = semantic_code
+
+    return combine_semantic_code, prompt_semantic_code
+
+
+@torch.no_grad()
+def semantic2acoustic(
+    device,
+    combine_semantic_code,
+    acoustic_code,
+    n_timesteps=[25, 10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+    cfg=2.5,
+    rescale_cfg=0.75,
+):
+
+    semantic_code = combine_semantic_code
+
+    cond = s2a_model_1layer.cond_emb(semantic_code)
+    prompt = acoustic_code[:, :, :]
+    predict_1layer = s2a_model_1layer.reverse_diffusion(
+        cond=cond,
+        prompt=prompt,
+        temp=1.5,
+        filter_thres=0.98,
+        n_timesteps=n_timesteps[:1],
+        cfg=cfg,
+        rescale_cfg=rescale_cfg,
+    )
+
+    cond = s2a_model_full.cond_emb(semantic_code)
+    prompt = acoustic_code[:, :, :]
+    predict_full = s2a_model_full.reverse_diffusion(
+        cond=cond,
+        prompt=prompt,
+        temp=1.5,
+        filter_thres=0.98,
+        n_timesteps=n_timesteps,
+        cfg=cfg,
+        rescale_cfg=rescale_cfg,
+        gt_code=predict_1layer,
+    )
+
+    vq_emb = codec_decoder.vq2emb(predict_full.permute(2, 0, 1), n_quantizers=12)
+    recovered_audio = codec_decoder(vq_emb)
+    prompt_vq_emb = codec_decoder.vq2emb(prompt.permute(2, 0, 1), n_quantizers=12)
+    recovered_prompt_audio = codec_decoder(prompt_vq_emb)
+    recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().numpy()
+    recovered_audio = recovered_audio[0][0].cpu().numpy()
+    combine_audio = np.concatenate([recovered_prompt_audio, recovered_audio])
+
+    return combine_audio, recovered_audio
+
+
+# Load the model and checkpoints
+def load_models():
+    cfg_path = "./models/tts/maskgct/config/maskgct.json"
+
+    cfg = load_config(cfg_path)
+    semantic_model, semantic_mean, semantic_std = build_semantic_model(device)
+    semantic_codec = build_semantic_codec(cfg.model.semantic_codec, device)
+    codec_encoder, codec_decoder = build_acoustic_codec(
+        cfg.model.acoustic_codec, device
+    )
+    t2s_model = build_t2s_model(cfg.model.t2s_model, device)
+    s2a_model_1layer = build_s2a_model(cfg.model.s2a_model.s2a_1layer, device)
+    s2a_model_full = build_s2a_model(cfg.model.s2a_model.s2a_full, device)
+    print("Models built successfully.")
+
+    # Download checkpoints
+    semantic_code_ckpt = hf_hub_download(
+        "amphion/MaskGCT", filename="semantic_codec/model.safetensors"
+    )
+    codec_encoder_ckpt = hf_hub_download(
+        "amphion/MaskGCT", filename="acoustic_codec/model.safetensors"
+    )
+    codec_decoder_ckpt = hf_hub_download(
+        "amphion/MaskGCT", filename="acoustic_codec/model_1.safetensors"
+    )
+    t2s_model_ckpt = hf_hub_download(
+        "amphion/MaskGCT", filename="t2s_model/model.safetensors"
+    )
+    s2a_1layer_ckpt = hf_hub_download(
+        "amphion/MaskGCT", filename="s2a_model/s2a_model_1layer/model.safetensors"
+    )
+    s2a_full_ckpt = hf_hub_download(
+        "amphion/MaskGCT", filename="s2a_model/s2a_model_full/model.safetensors"
+    )
+    print("Checkpoints downloaded successfully.")
+
+    safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)
+    safetensors.torch.load_model(codec_encoder, codec_encoder_ckpt)
+    safetensors.torch.load_model(codec_decoder, codec_decoder_ckpt)
+    # accelerate.load_checkpoint_and_dispatch(codec_encoder, codec_encoder_ckpt)
+    # accelerate.load_checkpoint_and_dispatch(codec_decoder, codec_decoder_ckpt)
+    safetensors.torch.load_model(t2s_model, t2s_model_ckpt)
+    safetensors.torch.load_model(s2a_model_1layer, s2a_1layer_ckpt)
+    safetensors.torch.load_model(s2a_model_full, s2a_full_ckpt)
+    print("Checkpoints loaded successfully.")
+
+    return (
+        semantic_model,
+        semantic_mean,
+        semantic_std,
+        semantic_codec,
+        codec_encoder,
+        codec_decoder,
+        t2s_model,
+        s2a_model_1layer,
+        s2a_model_full,
+    )
+
+
+@torch.no_grad()
+def maskgct_inference(
+    prompt_speech_path,
+    target_text,
+    target_len=None,
+    n_timesteps=25,
+    cfg=2.5,
+    rescale_cfg=0.75,
+    n_timesteps_s2a=[25, 10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+    cfg_s2a=2.5,
+    rescale_cfg_s2a=0.75,
+    device=torch.device("cuda:0"),
+):
+    speech_16k = librosa.load(prompt_speech_path, sr=16000)[0]
+    speech = librosa.load(prompt_speech_path, sr=24000)[0]
+    print("Audio loaded.")
+
+    prompt_language = detect_speech_language(prompt_speech_path)
+    full_prompt_text, short_prompt_text, shot_prompt_end_ts = get_prompt_text(
+        prompt_speech_path, prompt_language
+    )
+    # use the first 4+ seconds wav as the prompt in case the prompt wav is too long
+    speech = speech[0 : int(shot_prompt_end_ts * 24000)]
+    speech_16k = speech_16k[0 : int(shot_prompt_end_ts * 16000)]
+
+    target_language = detect_text_language(target_text)
+    combine_semantic_code, _ = text2semantic(
+        device,
+        speech_16k,
+        short_prompt_text,
+        prompt_language,
+        target_text,
+        target_language,
+        target_len,
+        n_timesteps,
+        cfg,
+        rescale_cfg,
+    )
+    acoustic_code = extract_acoustic_code(torch.tensor(speech).unsqueeze(0).to(device))
+    _, recovered_audio = semantic2acoustic(
+        device,
+        combine_semantic_code,
+        acoustic_code,
+        n_timesteps=n_timesteps_s2a,
+        cfg=cfg_s2a,
+        rescale_cfg=rescale_cfg_s2a,
+    )
+
+    return recovered_audio
+
+
+@spaces.GPU
+def inference(
+    prompt_wav,
+    target_text,
+    target_len,
+    n_timesteps,
+):
+    print(f"===== New task submitted =====")
+    global output_file_name_idx
+    save_path = f"./output/output_{output_file_name_idx}.wav"
+    os.makedirs("./output", exist_ok=True)
+    print(f"Start inference...")
+    recovered_audio = maskgct_inference(
+        prompt_wav,
+        target_text,
+        target_len=target_len,
+        n_timesteps=int(n_timesteps),
+        device=device,
+    )
+    print(f"Saved: {save_path}")
+    sf.write(save_path, recovered_audio, 24000)
+    output_file_name_idx = (output_file_name_idx + 1) % 100
+    return save_path
+
+
+# Load models once
+(
+    semantic_model,
+    semantic_mean,
+    semantic_std,
+    semantic_codec,
+    codec_encoder,
+    codec_decoder,
+    t2s_model,
+    s2a_model_1layer,
+    s2a_model_full,
+) = load_models()
+
+# Language list
+language_list = ["en", "zh", "ja", "ko", "fr", "de"]
+
+# Gradio interface
+iface = gr.Interface(
+    fn=inference,
+    inputs=[
+        gr.Audio(label="Upload Prompt Wav", type="filepath"),
+        gr.Textbox(label="Target Text"),
+        gr.Number(
+            label="Target Duration (in seconds), if the target duration is less than 0, the system will estimate a duration.",
+            value=-1,
+        ),  # Removed 'optional=True'
+        gr.Slider(
+            label="Number of Timesteps", minimum=15, maximum=100, value=25, step=1
+        ),
+    ],
+    outputs=gr.Audio(label="Generated Audio"),
+    title="MaskGCT TTS Demo",
+    description="""
+    [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2409.00750) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/maskgct) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/maskgct) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](https://github.com/open-mmlab/Amphion/tree/main/models/tts/maskgct)
+    """,
+)
+
+# Launch the interface
+iface.launch(allowed_paths=["./output"])
diff --git a/models/tts/maskgct/requirements.txt b/models/tts/maskgct/requirements.txt
index 9db71fb8..877151f5 100644
--- a/models/tts/maskgct/requirements.txt
+++ b/models/tts/maskgct/requirements.txt
@@ -1,15 +1,15 @@
 setuptools
 onnxruntime
 torch==2.0.1
-transformers===4.41.1
-tensorboard
-tensorboardX
-accelerate==0.31.0
+transformers==4.41.2
+accelerate==0.24.1
 unidecode
-numpy==1.23.5
+numpy==1.26.0
+scipy==1.12.0
 librosa
-encodecphonemizer
+encodec
+phonemizer
 g2p_en
 jieba
 cn2an
@@ -22,3 +22,7 @@ json5
 black==24.1.1
 ruamel.yaml
 tqdm
+
+spaces
+gradio
+openai-whisper