Comparing changes

base repository: coqui-ai/TTS    base: dev
head repository: alphacep/TTS    compare: dev
Able to merge. These branches can be automatically merged.
  • 1 commit
  • 2 files changed
  • 1 contributor

Commits on Apr 8, 2024

  1. Batch inference

    nshmyrev committed Apr 8, 2024
    SHA: 8aeac30
Showing with 98 additions and 2 deletions.
  1. +20 −2 TTS/tts/layers/xtts/gpt.py
  2. +78 −0 TTS/tts/models/xtts.py
22 changes: 20 additions & 2 deletions TTS/tts/layers/xtts/gpt.py
@@ -562,9 +562,11 @@ def compute_embeddings(
         self,
         cond_latents,
         text_inputs,
+        text_lens
     ):
         text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
         text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
+
         emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
         emb = torch.cat([cond_latents, emb], dim=1)
         self.gpt_inference.store_prefix_emb(emb)
@@ -578,21 +580,37 @@ def compute_embeddings(
             device=text_inputs.device,
         )
         gpt_inputs[:, -1] = self.start_audio_token
-        return gpt_inputs
+
+        attn_mask = torch.full(
+            (
+                gpt_inputs.shape[0],
+                gpt_inputs.shape[1],
+            ),
+            fill_value=1,
+            dtype=torch.long,
+            device=text_inputs.device,
+        )
+        if text_lens is not None:
+            for i in range(len(text_lens)):
+                attn_mask[i, cond_latents.shape[1] + text_lens[i] + 1 :] = 0
+
+        return gpt_inputs, attn_mask

     def generate(
         self,
         cond_latents,
         text_inputs,
+        text_lens=None,
         **hf_generate_kwargs,
     ):
-        gpt_inputs = self.compute_embeddings(cond_latents, text_inputs)
+        gpt_inputs, attn_mask = self.compute_embeddings(cond_latents, text_inputs, text_lens)
         gen = self.gpt_inference.generate(
             gpt_inputs,
             bos_token_id=self.start_audio_token,
             pad_token_id=self.stop_audio_token,
             eos_token_id=self.stop_audio_token,
             max_length=self.max_gen_mel_tokens + gpt_inputs.shape[-1],
+            attention_mask = attn_mask,
             **hf_generate_kwargs,
         )
         if "return_dict_in_generate" in hf_generate_kwargs:
78 changes: 78 additions & 0 deletions TTS/tts/models/xtts.py
@@ -582,6 +582,84 @@ def inference(
"speaker_embedding": speaker_embedding,
}

@torch.inference_mode()
def inference_batch(
self,
texts,
size,
language,
gpt_cond_latents,
speaker_embeddings,
# GPT inference
temperature=0.75,
length_penalty=1.0,
repetition_penalty=10.0,
top_k=50,
top_p=0.85,
do_sample=True,
num_beams=1,
speed=1.0,
**hf_generate_kwargs,
):
language = language.split("-")[0] # remove the country code
length_scale = 1.0 / max(speed, 0.05)
gpt_cond_latents = torch.stack(gpt_cond_latents).to(self.device)
speaker_embeddings = torch.stack(speaker_embeddings).unsqueeze(-1).to(self.device)
text_tokens = [torch.IntTensor(self.tokenizer.encode(text.strip().lower(), lang=language)) for text in texts]
max_size = max(token.size(0) for token in text_tokens)
padded_text_tokens = torch.IntTensor(size, max_size).zero_()
text_lens = torch.IntTensor(size)
for i, token in enumerate(text_tokens):
text_lens[i] = token.size(0)
padded_text_tokens[i, : text_lens[i] ] = token
padded_text_tokens = padded_text_tokens.to(self.device)
text_lens = text_lens.to(self.device)

assert (
padded_text_tokens.shape[-1] < self.args.gpt_max_text_tokens
), " ❗ XTTS can only generate text with a maximum of 400 tokens."

with torch.no_grad():
gpt_codes = self.gpt.generate(
cond_latents=gpt_cond_latents,
text_inputs=padded_text_tokens,
text_lens=text_lens,
input_tokens=None,
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
temperature=temperature,
length_penalty=length_penalty,
num_return_sequences=self.gpt_batch_size,
num_beams=num_beams,
repetition_penalty=repetition_penalty,
output_attentions=False,
**hf_generate_kwargs,
)
expected_output_len = torch.argmax((gpt_codes == self.gpt.stop_audio_token).to(dtype=torch.int), dim=1) * self.gpt.code_stride_len

gpt_latents = self.gpt(
padded_text_tokens,
text_lens,
gpt_codes,
expected_output_len,
cond_latents=gpt_cond_latents,
return_attentions=False,
return_latent=True,
)

if length_scale != 1.0:
gpt_latents = F.interpolate(
gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
).transpose(1, 2)

wavs = self.hifigan_decoder(gpt_latents, g=speaker_embeddings).cpu()

out_wavs = []
for i in range(size):
out_wavs.append(wavs[i, :, :expected_output_len[i].item() + 20 * self.gpt.code_stride_len]) # Extra to fit the last silence
return out_wavs

def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
"""Handle chunk formatting in streaming mode"""
wav_chunk = wav_gen[:-overlap_len]
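
As a usage note, the new `inference_batch` method synthesizes several texts in one GPT pass. The sketch below shows one way it might be called; it is a hedged example, not part of the pull request: the config path, checkpoint directory, and speaker wav are placeholders, the `squeeze` calls encode an assumption about the per-item shapes `inference_batch` expects (the method stacks the per-item latents and embeddings itself), and 24 kHz is the usual XTTS output rate.

# Hedged usage sketch for the new batch API (paths and shapes are assumptions).
import torchaudio
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("/path/to/xtts/config.json")  # placeholder config path
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)
model.cuda()

texts = [
    "Hello, this is the first sentence.",
    "And this is a second, somewhat longer sentence.",
]

# get_conditioning_latents returns batched tensors; inference_batch stacks
# per-item tensors itself, so the batch dimension is squeezed off here
# (an assumption about the expected per-item shapes).
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
    audio_path=["/path/to/speaker.wav"]  # placeholder reference audio
)
wavs = model.inference_batch(
    texts=texts,
    size=len(texts),
    language="en",
    gpt_cond_latents=[gpt_cond_latent.squeeze(0)] * len(texts),
    speaker_embeddings=[speaker_embedding.squeeze()] * len(texts),
)
for i, wav in enumerate(wavs):
    torchaudio.save(f"out_{i}.wav", wav, 24000)  # XTTS decodes at 24 kHz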