
Commit be10d4d

Granite speech - minor fixes to support training with the HF trainer (#38833)
* ensure the query is updated during training; avoids unused parameters that DDP does not like
* avoid a crash when `kwargs` contain `padding=True`; trainers often pass this argument automatically
* minor
* remove the mel_spec lazy init and rename it to mel_filters; this ensures save_pretrained will not crash when saving the processor during training (https://github.com/huggingface/transformers/blob/d5d007a1a0f0c11a726a54c8f00bd71825f84d02/src/transformers/feature_extraction_utils.py#L595)
* minor - most feature extractors have a `sampling_rate` property
1 parent e1e11b0 commit be10d4d

File tree: 3 files changed (+9, -24 lines)

src/transformers/models/granite_speech/feature_extraction_granite_speech.py

Lines changed: 5 additions & 22 deletions
@@ -50,15 +50,16 @@ def __init__(
         **kwargs,
     ):
         super().__init__(**kwargs)
+        self.sampling_rate = sampling_rate
         self.melspec_kwargs = {
             "sample_rate": sampling_rate,
             "n_fft": n_fft,
             "win_length": win_length,
             "hop_length": hop_length,
             "n_mels": n_mels,
         }
-        # Currently lazily initialized
-        self.melspec = None
+        requires_backends(self, ["torchaudio"])
+        self.mel_filters = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs)
         self.projector_window_size = projector_window_size
         self.projector_downsample_rate = projector_downsample_rate

@@ -91,34 +92,16 @@ def __call__(
         ).view(-1, 1)
         return BatchFeature(data=speech_inputs)
 
-    def _ensure_melspec_transform_is_initialized(self):
-        """
-        Ensures the mel spectrogram transform on this instance is initialized.
-
-        We do this for now since some logging explodes since the mel spectrogram
-        transform is not JSON serializable.
-        """
-        requires_backends(self, ["torchaudio"])
-
-        if self.melspec is None:
-            # TODO (@alex-jw-brooks / @eustlb) move this to common batch
-            # feature extraction in audio utils once they are written!
-            self.melspec = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs)
-
     def _extract_mel_spectrograms(self, audio: "torch.Tensor", device="cpu"):
         """
         Compute the Mel features to be passed to the conformer encoder.
         """
         requires_backends(self, ["torchaudio"])
-
-        # Initialize the mel spectrogram if isn't not already and
-        # move the melspec / audio to the computation device.
-        self._ensure_melspec_transform_is_initialized()
         if device is not None:
-            melspec = self.melspec.to(device)
+            melspec = self.mel_filters.to(device)
             audio = audio.to(device)
         else:
-            melspec = self.melspec
+            melspec = self.mel_filters
 
         bsz = audio.shape[0]
         with torch.no_grad():
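
A minimal sketch (not part of the diff) of why the eager, renamed attribute matters: the commit message links to `to_dict` in `feature_extraction_utils.py`, which at that revision drops the `mel_filters` key before JSON-serializing the feature extractor, so storing the non-serializable torchaudio transform under `mel_filters` instead of a lazily-built `melspec` lets `save_pretrained` run cleanly during training. The checkpoint id below is a placeholder.

from transformers import AutoProcessor

# Placeholder checkpoint id; any Granite Speech checkpoint should behave the same.
processor = AutoProcessor.from_pretrained("ibm-granite/granite-speech-3.3-2b")

# The torchaudio MelSpectrogram now lives under `mel_filters`, which to_dict()
# strips before serialization, so saving the processor mid-training no longer crashes.
processor.save_pretrained("./granite_speech_processor")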

src/transformers/models/granite_speech/modeling_granite_speech.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = hidden_states.view(batch_size * nblocks, self.window_size, dim)
 
         query_output = self.qformer(
-            query_embeds=self.query.data,
+            query_embeds=self.query,
             encoder_hidden_states=hidden_states,
             encoder_attention_mask=None,
             return_dict=True,
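
A short, self-contained illustration (not from the patch) of the DDP issue this one-line change fixes: `self.query.data` hands the q-former a plain tensor detached from autograd, so the `query` parameter never receives a gradient and `DistributedDataParallel` flags it as an unused parameter; passing the `nn.Parameter` itself keeps it in the graph.

import torch
import torch.nn as nn

query = nn.Parameter(torch.randn(4, 8))

detached = (query.data * 2).sum()   # .data returns a plain tensor, cut from the autograd graph
print(detached.requires_grad)       # False -> query.grad stays None, DDP reports an unused parameter

attached = (query * 2).sum()        # using the nn.Parameter keeps it in the graph
attached.backward()
print(query.grad is not None)       # True -> gradients flow to the parameter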

src/transformers/models/granite_speech/processing_granite_speech.py

Lines changed: 3 additions & 1 deletion
@@ -88,7 +88,9 @@ def __call__(
         else:
             audio_inputs = {}
 
-        text_inputs = self.tokenizer(prompt_strings, padding=True, **kwargs)
+        if "padding" not in kwargs:
+            kwargs["padding"] = True
+        text_inputs = self.tokenizer(prompt_strings, **kwargs)
         return BatchFeature(data={**text_inputs, **audio_inputs})
 
     def _get_validated_text(self, text: Union[str, list]) -> list[str]:
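
A hedged usage sketch of the processor change: the HF `Trainer` data-collation path often forwards `padding=True` automatically, which previously collided with the hard-coded `padding=True` and raised a duplicate-keyword `TypeError`. The checkpoint id, prompt format, and audio shape below are illustrative assumptions, not taken from the diff.

import torch
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("ibm-granite/granite-speech-3.3-2b")  # placeholder id

waveform = torch.randn(16000)  # assumed: one second of mono 16 kHz samples

batch = processor(
    text="<|audio|>can you transcribe this?",  # assumed audio-placeholder prompt format
    audio=waveform,
    padding=True,        # merged into kwargs instead of clashing with an internal default
    return_tensors="pt",
)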
