@SpenserCai SpenserCai commented Jan 5, 2026

Summary

This PR adds complete support for CosyVoice3, a state-of-the-art multilingual zero-shot text-to-speech model from FunAudioLLM. CosyVoice3 supports multiple synthesis modes including zero-shot voice cloning, cross-lingual synthesis, and instruction-guided generation.

Features

Model Architecture

CosyVoice3 consists of four main components:

  1. Frontend - Text tokenization (Qwen2 tiktoken), mel spectrogram extraction, and speaker embedding extraction via ONNX models
  2. CosyVoice3LM - Qwen2-based autoregressive language model for speech token generation
  3. Flow Decoder - 22-layer Diffusion Transformer (DiT) with Conditional Flow Matching (CFM) for mel spectrogram generation
  4. HiFT Vocoder - Neural Source Filter (NSF) based vocoder with iSTFT for waveform synthesis

Supported Modes

  • Zero-Shot: Voice cloning from a reference audio sample
  • Cross-Lingual: Voice cloning across different languages
  • Instruct: Instruction-guided speech synthesis

Key Implementation Details

  • Full native Rust implementation of all model components
  • Support for both CPU and GPU (Metal/CUDA) inference
  • Optional ONNX integration for speaker embedding and speech tokenization
  • Kaldi-compatible fbank feature extraction
  • 24kHz audio output

Changes

candle-transformers

Added new cosyvoice module with the following structure:

candle-transformers/src/models/cosyvoice/
├── mod.rs                    # Module exports and documentation
├── config.rs                 # Configuration structures
├── activations.rs            # Snake, SnakeBeta, Mish, ELU activations
├── flow/
│   ├── dit.rs                # Diffusion Transformer (22 layers)
│   ├── embeddings.rs         # AdaLayerNorm, RoPE, Timestep embeddings
│   ├── flow_matching.rs      # Conditional Flow Matching with Euler ODE solver
│   └── pre_lookahead.rs      # PreLookaheadLayer for causal processing
├── frontend/
│   ├── audio.rs              # MelSpectrogram, KaldiFbank, resampling
│   └── onnx_models.rs        # ONNX-based feature extraction (optional)
├── llm/
│   └── cosyvoice3_lm.rs      # Qwen2-based speech token generation
└── vocoder/
    ├── f0_predictor.rs       # Causal F0 prediction
    ├── hift_generator.rs     # HiFi-GAN based generator
    ├── istft.rs              # Inverse STFT for waveform synthesis
    ├── source_module.rs      # Neural Source Filter (NSF)
    └── stft.rs               # STFT for source signal processing

Total: ~7,200 lines of new Rust code

candle-onnx

Enhanced ONNX operator support required for CosyVoice3's ONNX models (campplus.onnx, speech_tokenizer_v3.onnx):

Operator     Enhancement
AveragePool  Added 1D pooling support with ceil_mode
Conv         Added asymmetric stride/dilation support
Pad          Added constant mode with custom value
GRU          Full implementation with ONNX weight reordering
Elu          New operator implementation
Mod          New operator with fmod support
Round        New operator implementation
ReduceMean   Fixed negative axis handling
ReduceProd   New operator implementation

candle-examples

Added cosyvoice3 example with:

  • Complete CLI for text-to-speech synthesis
  • Weight conversion script (convert_weights.py)
  • Random noise extraction script (extract_rand_noise.py) for exact reproducibility
  • Support for pre-extracted features or native ONNX extraction

Usage

Basic Usage

# Download pre-converted weights
huggingface-cli download spensercai/CosyVoice3-0.5B-Candle --local-dir weights/CosyVoice3-0.5B-Candle

# With ONNX support (native prompt feature extraction)
cargo run --release --example cosyvoice3 --features="symphonia,onnx" -- \
    --text "Hello, this is a test." \
    --prompt-wav /path/to/prompt.wav \
    --model-dir weights/CosyVoice3-0.5B-Candle \
    --output output.wav

# With GPU acceleration (Metal)
cargo run --release --example cosyvoice3 --features="symphonia,onnx,metal" -- \
    --text "Hello, this is a test." \
    --prompt-wav /path/to/prompt.wav \
    --model-dir weights/CosyVoice3-0.5B-Candle \
    --output output.wav

# Without ONNX (using pre-extracted features)
cargo run --release --example cosyvoice3 --features="symphonia" -- \
    --text "Hello, this is a test." \
    --prompt-features /path/to/features.safetensors \
    --model-dir weights/CosyVoice3-0.5B-Candle \
    --output output.wav

Weight Conversion

If you prefer to convert weights manually from the original PyTorch model:

python candle-examples/examples/cosyvoice3/convert_weights.py \
    --input weights/Fun-CosyVoice3-0.5B-2512 \
    --output weights/CosyVoice3-0.5B-Candle

Random Noise Extraction (Optional)

For exact numerical reproducibility with the Python implementation, you can extract the pre-computed random noise:

python candle-examples/examples/cosyvoice3/extract_rand_noise.py \
    --output weights/CosyVoice3-0.5B-Candle/rand_noise.safetensors

Note: This file is optional and already included in the pre-converted weights on Hugging Face. Without it, the Candle implementation generates its own deterministic noise using a fallback algorithm. The generated audio will be equally valid but may differ slightly from the Python implementation's output.

Programmatic Usage

use candle_transformers::models::cosyvoice::{
    CosyVoice3LM, CausalMaskedDiffWithDiT, CausalHiFTGenerator,
    DiT, FlowConfig, HiFTConfig, SamplingConfig,
};

// Load model components
let llm = CosyVoice3LM::new(&llm_config, llm_vb)?;
let flow_decoder = CausalMaskedDiffWithDiT::new(...)?;
let vocoder = CausalHiFTGenerator::new(hift_config, hift_vb)?;

// Generate speech tokens
let speech_tokens = llm.inference(
    &text_tokens,
    &prompt_text_tokens,
    &prompt_speech_tokens,
    &sampling_config,
)?;

// Convert to mel spectrogram
let mel = flow_decoder.inference(
    &speech_tokens,
    &prompt_speech_tokens,
    &prompt_mel,
    &speaker_embedding,
    10, // n_timesteps
    false,
)?;

// Generate waveform
let waveform = vocoder.inference(&mel, true)?;

Model Weights

Pre-converted Weights (Recommended)

Pre-converted weights are available on Hugging Face:

spensercai/CosyVoice3-0.5B-Candle

# Download using huggingface-cli
huggingface-cli download spensercai/CosyVoice3-0.5B-Candle --local-dir weights/CosyVoice3-0.5B-Candle

Manual Conversion

Alternatively, convert from the original Fun-CosyVoice3-0.5B-2512 using the provided script.

Performance

Device                 RTF (Real-Time Factor)
Apple M1 Pro (Metal)   ~0.3-0.5x
CPU (x86_64)           ~2-4x

RTF < 1.0 means faster than real-time
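
For reference, RTF is the wall-clock synthesis time divided by the duration of the generated audio. A minimal sketch, assuming the 24 kHz output rate mentioned above; `real_time_factor` is an illustrative helper, not part of this PR's API:

```rust
use std::time::Duration;

/// Real-time factor: synthesis wall-clock time divided by audio duration.
/// Hypothetical helper for illustration only.
fn real_time_factor(elapsed: Duration, num_samples: usize, sample_rate: u32) -> f64 {
    let audio_seconds = num_samples as f64 / sample_rate as f64;
    elapsed.as_secs_f64() / audio_seconds
}
```

For example, producing 96,000 samples at 24 kHz (4 s of audio) in 2 s of compute gives an RTF of 0.5.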

Technical Notes

Kaldi Fbank Compatibility

The mel spectrogram extraction follows Kaldi's fbank implementation with:

  • Povey window function
  • Pre-emphasis filtering
  • DC offset removal
  • HTK-style mel scale

Flow Matching

Uses Conditional Flow Matching (CFM) with:

  • Cosine timestep scheduler
  • 10-step Euler ODE solver
  • Optional Classifier-Free Guidance (CFG)
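
A minimal sketch of how these three pieces fit together, on plain `Vec<f32>` rather than candle tensors. The cosine warp `t = 1 - cos(pi/2 * s)` and the guidance blend `(1 + w) * v_cond - w * v_uncond` are assumptions based on the upstream Python implementation, and all function names here are hypothetical.

```rust
use std::f32::consts::PI;

/// Cosine-warped time grid on [0, 1] with `n_steps` intervals:
/// t_i = 1 - cos(pi/2 * i / n_steps).
fn cosine_schedule(n_steps: usize) -> Vec<f32> {
    (0..=n_steps)
        .map(|i| 1.0 - (0.5 * PI * i as f32 / n_steps as f32).cos())
        .collect()
}

/// Classifier-free guidance: blend conditional and unconditional velocities.
fn cfg_blend(v_cond: &[f32], v_uncond: &[f32], cfg_rate: f32) -> Vec<f32> {
    v_cond
        .iter()
        .zip(v_uncond)
        .map(|(c, u)| (1.0 + cfg_rate) * c - cfg_rate * u)
        .collect()
}

/// Explicit Euler integration of dx/dt = velocity(x, t) over `t_span`.
fn euler_solve<F>(mut x: Vec<f32>, t_span: &[f32], velocity: F) -> Vec<f32>
where
    F: Fn(&[f32], f32) -> Vec<f32>,
{
    for w in t_span.windows(2) {
        let (t, dt) = (w[0], w[1] - w[0]);
        let v = velocity(&x, t);
        for (xi, vi) in x.iter_mut().zip(&v) {
            *xi += dt * vi;
        }
    }
    x
}
```

In the real model the `velocity` closure would run the DiT estimator (twice when CFG is enabled) and `x` would start from noise; with `cosine_schedule(10)` the loop performs the 10 Euler steps.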

HiFT Vocoder

Neural Source Filter based vocoder with:

  • F0 prediction from mel features
  • Harmonic + noise source generation
  • 16-point iSTFT for efficient waveform synthesis
  • Snake activation functions
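
For readers unfamiliar with Snake, it is the periodic activation `x + sin^2(alpha * x) / alpha`. The sketch below uses a single scalar `alpha` for simplicity; whether the vocoder uses per-channel learnable values is not shown here, and the function name is illustrative.

```rust
/// Snake activation: x + sin^2(alpha * x) / alpha, applied element-wise.
/// `alpha` must be non-zero.
fn snake(xs: &[f32], alpha: f32) -> Vec<f32> {
    xs.iter()
        .map(|&x| {
            let s = (alpha * x).sin();
            x + s * s / alpha
        })
        .collect()
}
```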

Dependencies

  • candle-core
  • candle-nn
  • candle-onnx (optional, for native feature extraction)
  • tokenizers (for Qwen2 tokenization)
  • symphonia (optional, for audio decoding)

Checklist

  • Model implementation complete
  • Weight conversion script
  • Example CLI
  • ONNX operator enhancements
  • Documentation
  • Unit tests

candle-onnx Enhancements (Detailed)

This PR also includes significant enhancements to candle-onnx to support the ONNX models used by CosyVoice3:

New Operators

GRU (Gated Recurrent Unit)

Full implementation of the ONNX GRU operator with:

  • Forward direction support
  • Proper weight reordering (ONNX uses zrh order, candle-nn uses rzn)
  • Separate input/recurrent biases
  • Optional initial hidden state
// ONNX GRU weight order: z(update), r(reset), h(hidden)
// candle-nn order: r(reset), z(update), n(hidden)
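
The reordering amounts to swapping the first two of the three stacked gate blocks. A minimal sketch on a flat row-major buffer, assuming a single direction; `reorder_zrh_to_rzn` is a hypothetical name, not the function used in this PR:

```rust
/// Reorder stacked GRU gate blocks from ONNX order [z, r, h] to [r, z, n].
/// `weights` holds three consecutive gate blocks of `block_len` values each.
/// For a row-major weight matrix, block_len = hidden_size * num_columns;
/// for a bias vector, block_len = hidden_size.
fn reorder_zrh_to_rzn(weights: &[f32], block_len: usize) -> Vec<f32> {
    assert_eq!(weights.len(), 3 * block_len, "expected three gate blocks");
    let (z, rest) = weights.split_at(block_len);
    let (r, h) = rest.split_at(block_len);
    // The hidden/candidate block keeps its position; only z and r swap.
    [r, z, h].concat()
}
```

The same helper would apply to the input weights, the recurrent weights, and each half of the concatenated bias tensor.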

Elu (Exponential Linear Unit)

// f(x) = x if x > 0, alpha * (exp(x) - 1) if x <= 0
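
A scalar sketch of the formula above (the function name is illustrative; the ONNX operator applies it element-wise to a tensor):

```rust
/// ELU: x for x > 0, alpha * (exp(x) - 1) otherwise.
fn elu(x: f32, alpha: f32) -> f32 {
    if x > 0.0 {
        x
    } else {
        // exp_m1 computes exp(x) - 1 with better accuracy near zero.
        alpha * x.exp_m1()
    }
}
```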

Mod (Modulo)

Supports both integer modulo (fmod=0, where the result takes the sign of the divisor) and C-style fmod (fmod=1, where the result takes the sign of the dividend).
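
The two modes differ only when the operands have opposite signs. A scalar sketch with hypothetical helper names; Rust's `%` already has truncated semantics, so only the fmod=0 case needs an adjustment:

```rust
/// fmod = 0: integer modulo; a non-zero result takes the sign of the divisor.
fn mod_floor(a: i64, b: i64) -> i64 {
    let r = a % b; // truncated remainder, sign follows the dividend
    if r != 0 && ((r < 0) != (b < 0)) {
        r + b
    } else {
        r
    }
}

/// fmod = 1: C-style fmod; the result takes the sign of the dividend.
fn mod_trunc(a: f64, b: f64) -> f64 {
    a % b
}
```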

Round

Rounds to the nearest integer, with halfway cases rounded to the nearest even integer (banker's rounding).
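
Note that Rust's `f32::round` rounds halfway cases away from zero, so it cannot be used directly. A scalar sketch with a hypothetical helper name (recent Rust releases also provide `f32::round_ties_even`):

```rust
/// Round to nearest, ties to even.
fn round_half_even(x: f32) -> f32 {
    if (x - x.trunc()).abs() == 0.5 {
        // Exactly halfway: x / 2 is never a tie itself, so rounding it
        // and doubling lands on the nearest even integer.
        2.0 * (x / 2.0).round()
    } else {
        x.round()
    }
}
```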

ReduceProd

Product reduction along specified axes with keepdims support.
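
To illustrate the `keepdims` behaviour, here is a sketch for a row-major 2D matrix using plain slices. It is not the candle-onnx implementation, and `reduce_prod_2d` is a hypothetical name:

```rust
/// Product-reduce a row-major [rows, cols] matrix along `axis` (0 or 1).
/// Returns the reduced values and the output shape. Assumes cols > 0.
fn reduce_prod_2d(
    data: &[f32],
    rows: usize,
    cols: usize,
    axis: usize,
    keepdims: bool,
) -> (Vec<f32>, Vec<usize>) {
    assert_eq!(data.len(), rows * cols);
    assert!(axis < 2);
    let out: Vec<f32> = if axis == 0 {
        // Multiply down each column.
        (0..cols)
            .map(|c| (0..rows).map(|r| data[r * cols + c]).product::<f32>())
            .collect()
    } else {
        // Multiply across each row.
        data.chunks(cols)
            .map(|row| row.iter().product::<f32>())
            .collect()
    };
    // keepdims retains the reduced axis with size 1.
    let shape = match (axis, keepdims) {
        (0, true) => vec![1, cols],
        (0, false) => vec![cols],
        (_, true) => vec![rows, 1],
        (_, false) => vec![rows],
    };
    (out, shape)
}
```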

Enhanced Operators

AveragePool

  • Added 1D pooling support (previously only 2D)
  • Added ceil_mode support for output size calculation
  • Handles edge cases where kernel > input size
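
The effect of `ceil_mode` on the output length can be sketched as below. This is an assumption-laden illustration: the rule that drops a last window starting entirely in the right padding follows PyTorch's convention, and returning `None` when the kernel exceeds the padded input is a choice made for this sketch, not necessarily how the PR handles that case.

```rust
/// Output length of a 1D pooling op with symmetric padding.
/// Returns None when the kernel is larger than the padded input.
/// Hypothetical helper; assumes stride > 0.
fn pool1d_out_len(
    input: usize,
    kernel: usize,
    stride: usize,
    pad: usize,
    ceil_mode: bool,
) -> Option<usize> {
    let padded = input + 2 * pad;
    let span = padded.checked_sub(kernel)?;
    let mut out = if ceil_mode {
        (span + stride - 1) / stride + 1
    } else {
        span / stride + 1
    };
    // In ceil mode, drop a last window that would start in the right padding.
    if ceil_mode && (out - 1) * stride >= input + pad {
        out -= 1;
    }
    Some(out)
}
```

With an input of length 10, kernel 3 and stride 2, floor mode yields 4 outputs while ceil mode yields 5, the last window being partial.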

Conv

  • Added asymmetric stride support (different H/W strides)
  • Added asymmetric dilation support
  • Uses index_select for subsampling when needed
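
The subsampling idea is that a strided convolution equals a stride-1 convolution whose output is kept only at indices 0, s, 2s, and so on. A 1D sketch on plain slices (the PR operates on candle tensors via `index_select`; the helper names here are hypothetical):

```rust
/// Valid 1D cross-correlation with stride 1. Assumes a non-empty kernel.
fn conv1d_stride1(x: &[f32], k: &[f32]) -> Vec<f32> {
    x.windows(k.len())
        .map(|w| w.iter().zip(k).map(|(a, b)| a * b).sum::<f32>())
        .collect()
}

/// Emulate a stride by keeping every `stride`-th output (0, s, 2s, ...).
fn subsample(y: &[f32], stride: usize) -> Vec<f32> {
    y.iter().step_by(stride).copied().collect()
}
```

For a 2D convolution with different H and W strides, the same selection would be applied independently along each spatial axis with its own step.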

Pad

  • Added constant mode with custom padding value
  • Previously only supported reflect mode
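
Constant mode simply surrounds the data with a fill value instead of mirrored samples. A 1D sketch with a hypothetical helper name:

```rust
/// Pad a 1D signal with `before` and `after` copies of `value`.
fn pad_constant(x: &[f32], before: usize, after: usize, value: f32) -> Vec<f32> {
    let mut out = Vec::with_capacity(before + x.len() + after);
    out.extend(std::iter::repeat(value).take(before));
    out.extend_from_slice(x);
    out.extend(std::iter::repeat(value).take(after));
    out
}
```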

ReduceMean

  • Fixed negative axis handling using normalize_axis
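
ONNX allows axes in the range [-rank, rank), where -1 refers to the last dimension. The sketch below shows the usual mapping; the name `normalize_axis` comes from this PR, but the signature and the `Option` return type are assumptions made for illustration.

```rust
/// Map a possibly negative axis into [0, rank); None if out of range.
fn normalize_axis(axis: i64, rank: usize) -> Option<usize> {
    let rank = rank as i64;
    let idx = if axis < 0 { axis + rank } else { axis };
    if (0..rank).contains(&idx) {
        Some(idx as usize)
    } else {
        None
    }
}
```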

This implementation was developed and tested against the official Python implementation to ensure numerical accuracy.

Development Notes

Verification Process

The implementation was verified against the official Python CosyVoice3 implementation through:

  1. Component-level testing: Each module (LLM, Flow, Vocoder) was tested independently against Python outputs
  2. Intermediate tensor comparison: Key intermediate tensors were compared to ensure numerical consistency
  3. End-to-end audio comparison: Generated audio was compared for quality and similarity

Key Implementation Challenges

  1. Kaldi Fbank Compatibility: Required careful implementation of Povey window, pre-emphasis, and mel filter bank to match torchaudio's kaldi_fbank

  2. ONNX GRU Weight Reordering: ONNX uses (z, r, h) gate order while candle-nn uses (r, z, n), requiring weight tensor reordering

  3. HiFT Weight Norm Fusion: The original PyTorch model uses weight_norm parametrization which needed to be fused during weight conversion

  4. Causal Convolution: Implemented proper causal padding for streaming-compatible inference
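
As an illustration of item 3, weight-norm fusion replaces the stored direction `v` and gain `g` with a single weight `w = g * v / ||v||`, with one norm per output channel. In this PR the fusion happens in the Python conversion script; the Rust sketch below only shows the arithmetic, and `fuse_weight_norm` is a hypothetical name.

```rust
/// Fuse weight normalization: w = g * v / ||v||, one L2 norm per output channel.
/// `v` is row-major [out_channels, fan_in]; `g` has one gain per output channel.
fn fuse_weight_norm(v: &[f32], g: &[f32]) -> Vec<f32> {
    assert!(!g.is_empty() && !v.is_empty() && v.len() % g.len() == 0);
    let fan_in = v.len() / g.len();
    v.chunks(fan_in)
        .zip(g)
        .flat_map(|(row, &gain)| {
            let norm = row.iter().map(|x| x * x).sum::<f32>().sqrt();
            row.iter().map(move |x| gain * x / norm)
        })
        .collect()
}
```

After fusion the inference code can load an ordinary convolution weight and needs no weight-norm logic of its own.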

Screenshots

Metal

[image]

CUDA

[two images]

@SpenserCai SpenserCai changed the title Cosyvoice support Cosyvoice3 support Jan 5, 2026
- Add CosyVoice3Frontend for extracting speech tokens, speaker embeddings,
  and mel spectrograms directly from audio
- Extend candle-onnx with new operators: Elu, Mod, Round, ReduceProd
- Add AvgPool1d support and fix ReduceSum negative axis handling
- Support asymmetric Conv2d strides and Pad constant mode
- Add optional 'onnx' feature to candle-transformers
…lar filter in the mel domain, matching `get_mel_banks` in `torchaudio`

2. Forward: change to frame-by-frame processing:
- Extract original frames
- DC offset removal (frame-by-frame)
- Pre-emphasis (frame-by-frame, using replicate padding)
- Povey window
- FFT
- Mel energy calculation
@SpenserCai SpenserCai changed the title Cosyvoice3 support Add CosyVoice3 Text-to-Speech Model Support Jan 5, 2026
@SpenserCai SpenserCai marked this pull request as ready for review January 5, 2026 12:55
@SpenserCai

All CosyVoice3 functionality has been ported with equivalent behavior, and the PR is now ready for CI checks and code review. @ivarflakstad
