
Commit 6bdd4ec

Add kyutai stt (#38909)
* first draft
* cleaner version
* update tests + modeling
* add tests
* init
* update test_modeling_common
* fix tests
* csm Processor draft
* conversion update
* mimi cache padding convolutions draft
* mimi streaming updates
* update mimi padding cache test
* update cache padding mimi test
* make style mimi
* updates generate moshi asr
* moshi asr integration tests (single + batched)
* update tests
* update conversion script
* good default sliding window value
* update generate
* update test checkpoint
* nit
* fix mimi
* fix codec prefix
* revert
* revert
* update config
* update config
* unnecessary mimi input restriction
* remove delay in tokens
* remove _prepare_4d_causal_attention_mask_with_cache_position and _update_causal_mask
* test update
* modular update
* make style
* nit
* rename
* create codec model generation config at init
* remove delay
* max_new_tokens/length warning
* correct conv1 padding cache import for modular
* nit
* fix on encoder_past_key_values
* convert modular
* move frame_size to config
* move frame_size to config
* update test name
* handle first token is bos
* better handling of max_new_tokens
* fix
* fix batch size in test input prep
* update docstring
* convert modular
* make style
* make style
* add feature extractor
* correct modular convention name for feature_extraction file
* update conversion script
* doc processor
* update doc
* update init
* update model type
* fixes
* update tests
* fix
* make
* add doc
* nit
* fix
* doc
* auto mappings
* doc
* nit
* convert modular
* doc
* nit
* extend _keep_in_fp32_modules to enforce fp32
* renaming to stt
* doc update + test update
* doc fixes
* doc fix
* doc fix
* fix musicgen tests
* fix musicgen tests
* make style
* fix musicgen tests
* correct frame_rate config param for mimi
* update mimi test
* revert update mimi test
* enforce cpu test
* move cache init in cache class
* convert modular
* docstring update
* update model id
* feature_extractor -> feature_extraction (SEW)
* convert modular
* update model id
1 parent 08bf7f1 commit 6bdd4ec

23 files changed (+3999 / -199 lines changed)

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -843,6 +843,8 @@
       title: GraniteSpeech
     - local: model_doc/hubert
       title: Hubert
+    - local: model_doc/stt
+      title: Kyutai Speech-To-Text
     - local: model_doc/mctct
       title: MCTCT
     - local: model_doc/mimi
```

docs/source/en/model_doc/stt.md

Lines changed: 122 additions & 0 deletions
New file:

<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Kyutai Speech-To-Text

## Overview

Kyutai STT is a speech-to-text model architecture based on the [Mimi codec](https://huggingface.co/docs/transformers/en/model_doc/mimi), which encodes audio into discrete tokens in a streaming fashion, and a [Moshi-like](https://huggingface.co/docs/transformers/en/model_doc/moshi) autoregressive decoder. Kyutai has released two model checkpoints:
- [kyutai/stt-1b-en_fr](https://huggingface.co/kyutai/stt-1b-en_fr): a 1B-parameter model capable of transcribing both English and French
- [kyutai/stt-2.6b-en](https://huggingface.co/kyutai/stt-2.6b-en): a 2.6B-parameter model focused solely on English, optimized for maximum transcription accuracy

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/eustlb/documentation-images/resolve/main/kyutai_stt.png"/>
</div>

## Usage Tips

### Inference

```python
import torch
from datasets import load_dataset, Audio
from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration

# 1. load the model and the processor
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "kyutai/stt-2.6b-en"

processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)

# 2. load audio samples
ds = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
ds = ds.cast_column("audio", Audio(sampling_rate=24000))

# 3. prepare the model inputs
inputs = processor(
    ds[0]["audio"]["array"],
)
inputs.to(torch_device)

# 4. run inference
output_tokens = model.generate(**inputs)

# 5. decode the generated tokens
print(processor.batch_decode(output_tokens, skip_special_tokens=True))
```

### Batched Inference

```python
import torch
from datasets import load_dataset, Audio
from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration

# 1. load the model and the processor
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "kyutai/stt-2.6b-en"

processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)

# 2. load audio samples
ds = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
ds = ds.cast_column("audio", Audio(sampling_rate=24000))

# 3. prepare the model inputs
audio_arrays = [ds[i]["audio"]["array"] for i in range(4)]
inputs = processor(audio_arrays, return_tensors="pt", padding=True)
inputs = inputs.to(torch_device)

# 4. run inference
output_tokens = model.generate(**inputs)

# 5. decode the generated tokens
decoded_outputs = processor.batch_decode(output_tokens, skip_special_tokens=True)
for output in decoded_outputs:
    print(output)
```

This model was contributed by [Eustache Le Bihan](https://huggingface.co/eustlb).
The original code can be found [here](https://github.com/kyutai-labs/moshi).

## KyutaiSpeechToTextConfig

[[autodoc]] KyutaiSpeechToTextConfig

## KyutaiSpeechToTextProcessor

[[autodoc]] KyutaiSpeechToTextProcessor
    - __call__

## KyutaiSpeechToTextFeatureExtractor

[[autodoc]] KyutaiSpeechToTextFeatureExtractor

## KyutaiSpeechToTextForConditionalGeneration

[[autodoc]] KyutaiSpeechToTextForConditionalGeneration
    - forward
    - generate

## KyutaiSpeechToTextModel

[[autodoc]] KyutaiSpeechToTextModel
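
The examples in the new doc page read audio from a Hub dataset. As a usage note, here is a minimal sketch (not part of the committed file) for transcribing a local recording; it assumes `librosa` for loading and resampling, uses a placeholder path `audio.wav`, and follows the same processor/generate pattern shown above.

```python
import torch
import librosa
from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "kyutai/stt-2.6b-en"

processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)

# load a local file and resample it to the 24 kHz mono audio that Mimi expects
# ("audio.wav" is a placeholder path)
waveform, _ = librosa.load("audio.wav", sr=24000, mono=True)

inputs = processor(waveform)
inputs.to(torch_device)

output_tokens = model.generate(**inputs)
print(processor.batch_decode(output_tokens, skip_special_tokens=True)[0])
```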

src/transformers/modeling_utils.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -4658,8 +4658,11 @@ def from_pretrained(
         # The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
         # in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing
         # step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
+        # Update: to extend _keep_in_fp32_modules flag feature, it can also be used to force modules that should stay in fp32
         if model._keep_in_fp32_modules is not None and (
-            torch_dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
+            torch_dtype == torch.float16
+            or torch_dtype == torch.bfloat16
+            or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
         ):
             # We need to match exact layers, so we add either `.` on each side, or start/end of string
             keep_in_fp32_regex = re.compile(
```
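
This change means any architecture declaring `_keep_in_fp32_modules` now keeps those modules in float32 when the rest of the model is loaded in bfloat16, not only in float16. A minimal sketch of the expected behavior, using T5 (which lists `"wo"` in `_keep_in_fp32_modules`) purely as an illustration:

```python
import torch
from transformers import AutoModel

# T5 declares _keep_in_fp32_modules = ["wo"]; with this change the "wo" projection
# should stay in float32 even when the rest of the model is loaded in bfloat16.
model = AutoModel.from_pretrained("google-t5/t5-small", torch_dtype=torch.bfloat16)

print(model.encoder.block[0].layer[1].DenseReluDense.wo.weight.dtype)  # expected: torch.float32
print(model.encoder.block[0].layer[0].SelfAttention.q.weight.dtype)    # expected: torch.bfloat16
```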

src/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -285,6 +285,7 @@
     from .squeezebert import *
     from .stablelm import *
     from .starcoder2 import *
+    from .stt import *
     from .superglue import *
     from .superpoint import *
     from .swiftformer import *
```

src/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -322,6 +322,7 @@
         ("squeezebert", "SqueezeBertConfig"),
         ("stablelm", "StableLmConfig"),
         ("starcoder2", "Starcoder2Config"),
+        ("stt", "KyutaiSpeechToTextConfig"),
         ("superglue", "SuperGlueConfig"),
         ("superpoint", "SuperPointConfig"),
         ("swiftformer", "SwiftFormerConfig"),
@@ -707,6 +708,7 @@
         ("squeezebert", "SqueezeBERT"),
         ("stablelm", "StableLm"),
         ("starcoder2", "Starcoder2"),
+        ("stt", "KyutaiSpeechToText"),
         ("superglue", "SuperGlue"),
         ("superpoint", "SuperPoint"),
         ("swiftformer", "SwiftFormer"),
```

src/transformers/models/auto/feature_extraction_auto.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -91,6 +91,7 @@
         ("sew-d", "Wav2Vec2FeatureExtractor"),
         ("speech_to_text", "Speech2TextFeatureExtractor"),
         ("speecht5", "SpeechT5FeatureExtractor"),
+        ("stt", "KyutaiSpeechToTextFeatureExtractor"),
         ("swiftformer", "ViTFeatureExtractor"),
         ("swin", "ViTFeatureExtractor"),
         ("swinv2", "ViTFeatureExtractor"),
```

src/transformers/models/auto/modeling_auto.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -300,6 +300,7 @@
         ("squeezebert", "SqueezeBertModel"),
         ("stablelm", "StableLmModel"),
         ("starcoder2", "Starcoder2Model"),
+        ("stt", "KyutaiSpeechToTextModel"),
         ("superglue", "SuperGlueForKeypointMatching"),
         ("swiftformer", "SwiftFormerModel"),
         ("swin", "SwinModel"),
@@ -1055,6 +1056,7 @@
         ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
         ("speech_to_text", "Speech2TextForConditionalGeneration"),
         ("speecht5", "SpeechT5ForSpeechToText"),
+        ("stt", "KyutaiSpeechToTextForConditionalGeneration"),
         ("whisper", "WhisperForConditionalGeneration"),
     ]
 )
```

src/transformers/models/auto/processing_auto.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -116,6 +116,7 @@
         ("speech_to_text", "Speech2TextProcessor"),
         ("speech_to_text_2", "Speech2Text2Processor"),
         ("speecht5", "SpeechT5Processor"),
+        ("stt", "KyutaiSpeechToTextProcessor"),
         ("trocr", "TrOCRProcessor"),
         ("tvlt", "TvltProcessor"),
         ("tvp", "TvpProcessor"),
```

src/transformers/models/mimi/configuration_mimi.py

Lines changed: 46 additions & 4 deletions
```diff
@@ -38,8 +38,8 @@ class MimiConfig(PretrainedConfig):
     Args:
         sampling_rate (`int`, *optional*, defaults to 24000):
             The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
-        frame_rate (`float`, *optional*, defaults to 12.5):
-            Framerate of the model.
+        frame_rate (`float`, *optional*):
+            Should be computed from the other parameters, yet kept for backward compatibility.
         audio_channels (`int`, *optional*, defaults to 1):
             Number of channels in the audio data. Either 1 for mono or 2 for stereo.
         hidden_size (`int`, *optional*, defaults to 512):
@@ -111,6 +111,8 @@ class MimiConfig(PretrainedConfig):
         use_cache (`bool`, *optional*, defaults to `False`):
             Whether or not the model should return the last key/values attentions (not used by all models). Only
             relevant if `config.is_decoder=True`.
+        use_streaming (`bool`, *optional*, defaults to `False`):
+            Whether to use streaming mode. If `True`, the model encode method will return the padding cache that can be used in a subsequent call to the encode method.
         rope_theta (`float`, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings.
         sliding_window (`int`, *optional*, defaults to 250):
@@ -141,7 +143,7 @@ class MimiConfig(PretrainedConfig):
     def __init__(
         self,
         sampling_rate=24_000,
-        frame_rate=12.5,
+        frame_rate=None,
         audio_channels=1,
         hidden_size=512,
         num_filters=64,
@@ -172,6 +174,7 @@ def __init__(
         initializer_range=0.02,
         norm_eps=1e-5,
         use_cache=False,
+        use_streaming=False,
         rope_theta=10000.0,
         sliding_window=250,
         attention_dropout=0.0,
@@ -180,7 +183,6 @@ def __init__(
         **kwargs,
     ):
         self.sampling_rate = sampling_rate
-        self.frame_rate = frame_rate
         self.audio_channels = audio_channels
         self.hidden_size = hidden_size
         self.num_filters = num_filters
@@ -209,13 +211,22 @@ def __init__(
         self.initializer_range = initializer_range
         self.norm_eps = norm_eps
         self.use_cache = use_cache
+        self.use_streaming = use_streaming
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window
         self.attention_dropout = attention_dropout
         self.head_dim = head_dim or hidden_size // num_attention_heads
         self.layer_scale_initial_scale = layer_scale_initial_scale
         self.attention_bias = attention_bias
 
+        # Handle backward compatibility for frame_rate:
+        # If frame_rate is explicitly provided, use it (backward compatibility)
+        # Otherwise, compute it from other parameters (correctly)
+        if frame_rate is not None:
+            self._frame_rate = frame_rate
+        else:
+            self._frame_rate = None
+
         if num_semantic_quantizers >= self.num_quantizers:
             raise ValueError(
                 f"The number of semantic quantizers should be lower than the total number of quantizers {self.num_quantizers}, but is currently {num_semantic_quantizers}."
@@ -233,5 +244,36 @@ def num_codebooks(self) -> int:
         # alias to num_quantizers
         return self.num_quantizers
 
+    @property
+    def frame_size(self) -> int:
+        # 1. we need each encoder conv stride
+        # first conv
+        strides = [1]
+
+        # layer convs
+        for ratio in reversed(self.upsampling_ratios):
+            for j in range(self.num_residual_layers):
+                len_kernel_sizes = len(self.residual_kernel_size) if isinstance(self.residual_kernel_size, list) else 1
+                strides.extend([1] * (len_kernel_sizes + 1))
+                if self.use_conv_shortcut:  # skip connection
+                    strides.append(1)
+
+            strides.append(ratio)
+
+        # last conv
+        strides.append(1)
+
+        # downsampling layer
+        strides.append(2)
+
+        return math.prod(strides)
+
+    @property
+    def frame_rate(self) -> float:
+        # handle backward compatibility
+        if self._frame_rate is not None:
+            return self._frame_rate
+        return self.sampling_rate / self.frame_size
+
 
 __all__ = ["MimiConfig"]
```
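
A quick sketch of what the new `frame_size` and `frame_rate` properties should yield with default Mimi parameters (upsampling ratios of 8, 6, 5 and 4 plus the extra stride-2 downsampling layer give a cumulative stride of 1920 samples, i.e. 12.5 Hz at 24 kHz, matching the previously hard-coded default):

```python
from transformers import MimiConfig

config = MimiConfig()
print(config.frame_size)   # expected: 1920 (cumulative encoder stride in samples)
print(config.frame_rate)   # expected: 12.5 (24000 / 1920)

# passing frame_rate explicitly keeps the old behavior for existing configs
legacy = MimiConfig(frame_rate=12.5)
print(legacy.frame_rate)   # 12.5
```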
