@@ -50,15 +50,16 @@ def __init__(
50
50
** kwargs ,
51
51
):
52
52
super ().__init__ (** kwargs )
53
+ self .sampling_rate = sampling_rate
53
54
self .melspec_kwargs = {
54
55
"sample_rate" : sampling_rate ,
55
56
"n_fft" : n_fft ,
56
57
"win_length" : win_length ,
57
58
"hop_length" : hop_length ,
58
59
"n_mels" : n_mels ,
59
60
}
60
- # Currently lazily initialized
61
- self .melspec = None
61
+ requires_backends ( self , [ "torchaudio" ])
62
+ self .mel_filters = torchaudio . transforms . MelSpectrogram ( ** self . melspec_kwargs )
62
63
self .projector_window_size = projector_window_size
63
64
self .projector_downsample_rate = projector_downsample_rate
64
65
@@ -91,34 +92,16 @@ def __call__(
91
92
).view (- 1 , 1 )
92
93
return BatchFeature (data = speech_inputs )
93
94
94
- def _ensure_melspec_transform_is_initialized (self ):
95
- """
96
- Ensures the mel spectrogram transform on this instance is initialized.
97
-
98
- We do this for now since some logging explodes since the mel spectrogram
99
- transform is not JSON serializable.
100
- """
101
- requires_backends (self , ["torchaudio" ])
102
-
103
- if self .melspec is None :
104
- # TODO (@alex-jw-brooks / @eustlb) move this to common batch
105
- # feature extraction in audio utils once they are written!
106
- self .melspec = torchaudio .transforms .MelSpectrogram (** self .melspec_kwargs )
107
-
108
95
def _extract_mel_spectrograms (self , audio : "torch.Tensor" , device = "cpu" ):
109
96
"""
110
97
Compute the Mel features to be passed to the conformer encoder.
111
98
"""
112
99
requires_backends (self , ["torchaudio" ])
113
-
114
- # Initialize the mel spectrogram if isn't not already and
115
- # move the melspec / audio to the computation device.
116
- self ._ensure_melspec_transform_is_initialized ()
117
100
if device is not None :
118
- melspec = self .melspec .to (device )
101
+ melspec = self .mel_filters .to (device )
119
102
audio = audio .to (device )
120
103
else :
121
- melspec = self .melspec
104
+ melspec = self .mel_filters
122
105
123
106
bsz = audio .shape [0 ]
124
107
with torch .no_grad ():
0 commit comments