Infinite streaming #208

Open · wants to merge 28 commits into main

Conversation


mrdrprofuroboros commented Apr 3, 2025

Closing #187 in favor of this

So basically this is a long version of seamless streaming; examples below:

Imagine we have something like this to say:

texts = [
    "Rain lashed against the attic window.",
    "Dust motes danced in the single moonbeam slicing the darkness.",
    "A floorboard creaked downstairs.",
    "She held her breath, listening.",
    "Silence answered, heavy and absolute.",
    "Slowly, she lifted the rusted latch on the old trunk.",
    "A faint scent of lavender and forgotten years drifted out.",
    "Inside, nestled on velvet, lay a single, tarnished silver key.",
]
# Happiness, Sadness, Disgust, Fear, Surprise, Anger, Other, Neutral
emotions = [
    [1.0, 0, 0, 0, 0, 0, 0, 0],
    [0, 1.0, 0, 0, 0, 0, 0, 0],
    [0, 0, 1.0, 0, 0, 0, 0, 0],
    [0, 0, 0, 1.0, 0, 0, 0, 0],
    [0, 0, 0, 0, 1.0, 0, 0, 0],
    [0, 0, 0, 0, 0, 1.0, 0, 0],
    [0, 0, 0, 0, 0, 0, 1.0, 0],
    [0, 0, 0, 0, 0, 0, 0, 1.0],
]
pitch_stds = range(80, 161, 10)

A naive approach before was something like this:

import torch
import torchaudio
from IPython.display import Audio, display

from zonos.conditioning import make_cond_dict

# `model` and `speaker` are assumed to be loaded earlier (Zonos model and speaker embedding).

all_audio_chunks = []
for text, emotion, pitch_std in zip(texts, emotions, pitch_stds):
    cond_dict = make_cond_dict(
        text=text + " ",
        language="en-us",
        speaker=speaker,
        emotion=emotion,
        pitch_std=pitch_std,
    )
    conditioning = model.prepare_conditioning(cond_dict)
    codes = model.generate(conditioning)
    wavs = model.autoencoder.decode(codes).cpu()
    all_audio_chunks.append(wavs[0])

# Concatenate all audio chunks along the time axis.
audio = torch.cat(all_audio_chunks, dim=-1)
out_sr = model.autoencoder.sampling_rate
torchaudio.save("stitching.wav", audio, out_sr)
display(Audio(data=audio, rate=44100))

The result is quite anticlimactic, something like
https://github.com/user-attachments/assets/b4756b0a-91d1-4ef5-87e3-6cb94b476946

  • clicks and artifacts at chunk borders
  • varying volume, energy, and prosody
  • high latency, can't stream in real time

This PR brings:

def generator():
    # You can stream from your LLM or another source here; just partition the text into
    # sentences with nltk or a rule-based tokenizer. See example here:
    # https://stackoverflow.com/a/31505798
    for text, emotion, pitch_std in zip(texts, emotions, pitch_stds):
        yield {
            "text": text,
            "speaker": speaker,
            "emotion": emotion,
            "pitch_std": pitch_std,
            "language": "en-us",
        }

# Define the chunk schedule: start with small chunks for faster initial output,
# then gradually increase to larger chunks for fewer cuts.
stream_generator = model.stream(
    cond_dicts_generator=generator(),
    chunk_schedule=[17, *range(9, 100)],  # optimal schedule for RTX3090
    chunk_overlap=2,  # tokens to overlap between chunks (affects crossfade)
)

# Accumulate audio chunks as they are generated.
audio_chunks = []
for i, audio_chunk in enumerate(stream_generator):
    audio_chunks.append(audio_chunk)

audio = torch.cat(audio_chunks, dim=-1).cpu()
out_sr = model.autoencoder.sampling_rate
torchaudio.save("streaming.wav", audio, out_sr)
display(Audio(data=audio, rate=out_sr))

Which gives:

streaming.mp4

Here are some stats for an RTX3090:

Starting streaming generation...
Chunk   1: elapsed   236ms | generated up to   317ms
Chunk   2: elapsed   315ms | generated up to   421ms
Chunk   3: elapsed   398ms | generated up to   537ms
Chunk   4: elapsed   488ms | generated up to   664ms
Chunk   5: elapsed   586ms | generated up to   803ms
...
Chunk  78: elapsed  9381ms | generated up to 13663ms
TTFB: 236ms, generation: 9.382ms, duration: 13.468ms, RTX: 1.44
Saved streaming audio to 'streaming.wav' (sampling rate: 44100 Hz).
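
The streaming example above assumes the generator already yields complete sentences; the comment in it points at nltk or a rule-based tokenizer for the splitting. A minimal sketch (not part of this PR) of turning a streamed LLM output into whole sentences with nltk could look like this:

import nltk
from nltk.tokenize import sent_tokenize

nltk.download("punkt", quiet=True)  # one-time tokenizer model download

def sentences(token_stream):
    """Buffer streamed text pieces and yield complete sentences as soon as they appear."""
    buffer = ""
    for piece in token_stream:
        buffer += piece
        parts = sent_tokenize(buffer)
        # Everything except the last part is a finished sentence.
        while len(parts) > 1:
            yield parts.pop(0)
        buffer = parts[0] if parts else ""
    if buffer.strip():
        yield buffer  # flush whatever is left at the end

Each yielded sentence could then be wrapped into a cond dict inside generator() above.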

This PR incorporates:

The main idea is super simple and straightforward: we split the text into sentences and use the FULL previous sentence's audio codes and text as a prefix for the next one. Thus we know exactly where to cut it and how to stitch it. Also, we decode tokens in small chunks as they appear during inference and apply a cosine cross-fade to stitch them together, eliminating most audible clicks.
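
For illustration, a cosine cross-fade over the overlapping samples of two consecutive chunks could be sketched like this (a minimal sketch, not the exact code in this PR; `overlap` here is the number of samples covered by `chunk_overlap` tokens after decoding):

import math
import torch

def cosine_crossfade(prev: torch.Tensor, nxt: torch.Tensor, overlap: int) -> torch.Tensor:
    """Stitch two audio chunks by cross-fading `overlap` samples with a cosine ramp."""
    t = torch.linspace(0, math.pi, overlap, device=prev.device)
    fade_in = (1 - torch.cos(t)) / 2   # ramps 0 -> 1
    fade_out = 1 - fade_in             # ramps 1 -> 0
    blended = prev[..., -overlap:] * fade_out + nxt[..., :overlap] * fade_in
    return torch.cat([prev[..., :-overlap], blended, nxt[..., overlap:]], dim=-1)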

Unfortunately it's not infinite, since it accumulates the error of the previous generations. It's still usable for generations of roughly under 30-40 seconds, but now they can be streamed with minimal latency.

PS: there's also the ability to load models from the filesystem (both Zonos weights and the embedder).

mrdrprofuroboros changed the title from Infinite streaming to long streaming · Apr 3, 2025
@coezbek
Contributor

coezbek commented Apr 3, 2025

@mrdrprofuroboros Very interesting! I will give this a spin!

Couple of questions:

  • I see that you use the emotion vector. Have you found a way to make good use of it? Even with cfg_scale = 4.0 or 5.0, it is hard for me to notice the 8 emotion dimensions (for instance, when comparing your sample sentences).
  • I am also still experiencing a lot of clicks at the end of generation. I think it is likely an artifact of training on data which doesn't end properly itself. I have experimented a lot with how to extend generation and reduce the likelihood of EOS, etc., but in the end I use a logarithmic fade-out. Is a one-token cross-fade really enough for you to make it work?
  • You mention that you see a lot of drift over time: have you tried not using only the previous sentence as a prefix, but always using the first sentence plus the previous sentence as the prefix? Maybe always just taking the first sentence would be okay as well. In my experiments the audio prefix has almost no impact on the generated audio anyway.

@darkacorn
Contributor

what we really also need to state is that we're just duct-taping till 1.5/2.

there is only so much we can get out of a beta release / they are working on it though..

@mrdrprofuroboros
Author

mrdrprofuroboros commented Apr 3, 2025

@coezbek

  • nope, it's just for demonstration. I wasn't able to get any meaningful control with it, unfortunately
  • a bigger fade-out is definitely a thing to add, I'll be looking into it in the next few days. I'm also playing with noisereduce
  • hmm, that is worth checking out, thx. My intuition was that the prosody may drift, and if you control emotions / pitch_std and other stuff, you'd want it to glue with the previous segment, not the first. But since they work rather badly in general, trading that for longer, more stable generations might be a thing to consider. I'll post what I find out regarding always using the first segment.

@darkacorn
hehe, at this point the whole industry is just duct-taping till ASI :D Anyway, you're rockstars for open-sourcing it, we're glad to build products with what is available. Curious to see what's there in the next release, but till then... :)

@mrdrprofuroboros
Author

okay, @coezbek's suggestion about using the first segment as a prefix eliminates the degradation and works way better, practically enabling infinite generation.
I've generated an independent-segment stitching (no prefix used) to compare -- the prefix actually helps reduce the divergence of segments quite a lot!

import torch
from IPython.display import Audio, display

from zonos.conditioning import make_cond_dict
from tqdm.auto import tqdm

texts = [
    "The old clock tower hadn't chimed in living memory.",
    "Its stone face, weathered and stained, watched over the perpetually drowsy town.",
    "Elara, however, felt a strange pull towards it.",
    "She often sketched its silhouette in her worn notebook.",
    "One moonless night, a faint, melodic hum vibrated through the cobblestones beneath her feet.",
    "It seemed to emanate from the silent tower.",
    "Driven by a curiosity stronger than fear, she crept towards the heavy oak door.",
    "Surprisingly, it swung open at her touch, revealing a spiral staircase choked with dust.",
    "The air inside was thick with the scent of ozone and something ancient.",
    "She ascended, each step echoing in the profound stillness.",
    "Higher and higher she climbed, the humming growing louder, resonating within her chest.",
    "Finally, she reached the belfry.",
    "Instead of bells, intricate crystalline structures pulsed with soft, blue light.",
    "They hung suspended, rotating slowly, emitting the enchanting melody.",
    "In the center hovered a sphere of swirling energy.",
    "As Elara approached, the humming intensified, the light brightening.",
    "Tendrils of energy reached out from the sphere, brushing against her fingertips.",
    "A flood of images poured into her mind: star charts, forgotten equations, galaxies blooming and dying.",
    "She wasn't just in a clock tower; she was inside a celestial resonator.",
    "It was a device left by travelers from a distant star, waiting for someone attuned to its frequency.",
    "Elara realized the tower hadn't been silent, just waiting.",
    "She raised her hands, not in fear, but in acceptance.",
    "The energy flowed into her, cool and invigorating.",
    "Suddenly, with a resonant *gong*, the tower chimed, a sound unheard for centuries.",
    "Its song wasn't marking time, but awakening possibilities across the cosmos."
]

prefixing = True
first_text = ""
first_codes = None
all_segments = []
whitespace = " "
torch.manual_seed(777)

for text in tqdm(texts):
    cond_dict = make_cond_dict(
        text=first_text + text + whitespace,
        language="en-us",
        speaker=speaker,
        pitch_std=120,
    )
    conditioning = model.prepare_conditioning(cond_dict)
    # Pass the first sentence's codes (if any) as an audio prefix for this generation.
    codes = model.generate(conditioning, first_codes, progress_bar=False)

    if prefixing:
        if first_codes is None:
            # Remember the first sentence's codes and text; they prefix every later sentence.
            first_codes = codes
            first_text = text + whitespace
        else:
            # Drop the prefix portion so only the newly generated sentence's audio remains.
            codes = codes[:, :, first_codes.shape[-1]:]

    wavs = model.autoencoder.decode(codes).cpu()
    all_segments.append(wavs[0])

audio = torch.cat(all_segments, dim=-1)
display(Audio(data=audio, rate=44100))

prefixing = False

no-prefix.mp4

prefixing = True

first-as-prefix.mp4

@mrdrprofuroboros
Author

pushed the updated version and also added the longer logarithmic fade-out at the ends of sentences. Here's an example of a 25-sentence generation:

streaming-first-as-prefix.mp4
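
For reference, a logarithmic fade-out applied to the tail of a sentence could be sketched like this (the fade length and the exact gain curve are assumptions for illustration, not the values used in this branch):

import torch

def log_fade_out(audio: torch.Tensor, fade_samples: int) -> torch.Tensor:
    """Scale the last `fade_samples` samples with a logarithmic gain curve going 1 -> ~0."""
    t = torch.linspace(1.0, 1e-3, fade_samples, device=audio.device)
    gain = torch.log10(1 + 9 * t)  # log10(10) = 1 at the start, ~0 at the end
    faded = audio.clone()
    faded[..., -fade_samples:] = faded[..., -fade_samples:] * gain
    return faded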

mrdrprofuroboros changed the title from long streaming to Infinite streaming · Apr 4, 2025
@mrdrprofuroboros
Author

mrdrprofuroboros commented Apr 4, 2025

okidoke, one more cool update: I reduced initial response latency to 135ms on an RTX3090

here's the thing - the model takes some time/tokens to warm up. We can preallocate open streams and feed some warm-up text into them, so that once we have real queries, we're ready to process them faster.
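
One way to sketch this with the model.stream API from this PR (the sentence_queue below is a hypothetical stand-in for whatever feeds real queries, and speaker is assumed to be a precomputed speaker embedding):

import queue

warmup_cond = {"text": "And I say OK", "speaker": speaker, "language": "en-us"}

def prewarmed_generator(sentence_queue: queue.Queue):
    """Yield a cheap warm-up sentence first, then real cond dicts as they arrive."""
    yield warmup_cond
    while True:
        cond = sentence_queue.get()  # blocks until the next real sentence is available
        if cond is None:             # sentinel: no more sentences
            return
        yield cond

# The stream can be opened ahead of time, so the warm-up tokens are already generated
# by the time the first real query is pushed into the queue:
# stream = model.stream(cond_dicts_generator=prewarmed_generator(q),
#                       chunk_schedule=[17, *range(9, 100)], chunk_overlap=2)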

here's what I came up with:

Starting streaming generation...

Yielding sentence 885ms: The old clock tower hadn't chimed in living memory.
Chunk   1: elapsed  1017ms | generated up to  1133ms 
...
Chunk  21: elapsed  3101ms | generated up to  4223ms 
Yielding sentence 3102ms: Its stone face, weathered and stained, watched over the perpetually drowsy town.
Chunk  22: elapsed  3282ms | generated up to  4304ms 
...
Chunk  47: elapsed  6500ms | generated up to  8866ms 
Yielding sentence 6500ms: Elara, however, felt a strange pull towards it.
Chunk  48: elapsed  6685ms | generated up to  9086ms 
...
Chunk  68: elapsed  8961ms | generated up to 12199ms 
Yielding sentence 8961ms: She often sketched its silhouette in her worn notebook.
Chunk  69: elapsed  9143ms | generated up to 12466ms 
...
Chunk  90: elapsed 11297ms | generated up to 15683ms 

TTFB: 1017ms, generation: 11.298ms, duration: 14.71ms, RTX: 1.3
Saved streaming audio to 'stream_sample.wav' (sampling rate: 44100 Hz).

So it took 885ms to warm up, but then from the point I got the first real sentence to the first response chunk of audio, it took only 1017 - 885 = 132ms.

Giving it a warmup of "And I say OK" doesn't change the prosody/style of the next sentence from what I saw. But at the same time it also gives a pretty natural full-stop pause without actually stopping for a long time. You can experiment with other warm-up prefills.

@darkacorn
Contributor

darkacorn commented Apr 4, 2025


the warmup is mostly torch.compile and building out the CUDA graphs / if you hop on Discord, communication is easier .. [coezbek] (oezi on Discord) is in there too

@mrdrprofuroboros
Author

*we continued the discussion in Discord, but just for the record - when I said "initial response latency" I actually meant TTFB.
The 1st response is indeed slower because of the torch.compile JIT; I'm speaking about the following ones.

@coezbek
Contributor

coezbek commented Apr 4, 2025

I made some minor improvements (fix warnings) in my branch https://github.com/coezbek/Zonos/tree/infinite-streaming

@mrdrprofuroboros
Author

pulled them in, thank you!
I've tested the full stop with a few examples on different seeds, and from what I saw it gave long pauses, so the actual generation started with too much silence. Might need more tests, but the colon works nicely, so I left it.

@SuperPauly

@gabrielclark3330 thoughts? Can this be merged?

@constan1

constan1 commented May 2, 2025

So I'm only getting a real-time factor of 0.42 on a Colab instance with an A100.

image

Am I missing something? How are you getting >1 RTX on a 4090?
