@@ -182,6 +182,7 @@ def log_mel_spectrogram(
     return log_spec
 
 
+# Refactor: default params are stored in config.json
 class AudioTransform(torch.nn.Module):
     def __init__(
         self,
@@ -190,10 +191,12 @@ def __init__(
         max_snr: int = 50,
         max_dist_gain: int = 25,
         min_dist_gain: int = 0,
-        # ratios for the reduction of the audio quality
-        distort_ratio: float = 0.2,
-        reduce_ratio: float = 0.2,
-        spec_aug_ratio: float = 0.2,
+        noise_ratio: float = 0.95,
+        reverb_ratio: float = 0.95,
+        applause_ratio: float = 0.01,  # CHANGE
+        distort_ratio: float = 0.15,
+        reduce_ratio: float = 0.01,
+        spec_aug_ratio: float = 0.25,
     ):
         super().__init__()
         self.tokenizer = AmtTokenizer()
@@ -208,9 +211,13 @@ def __init__(
         self.chunk_len = self.config["chunk_len"]
         self.num_samples = self.sample_rate * self.chunk_len
 
-        self.dist_ratio = distort_ratio
+        self.noise_ratio = noise_ratio
+        self.reverb_ratio = reverb_ratio
+        self.applause_ratio = applause_ratio
+        self.distort_ratio = distort_ratio
         self.reduce_ratio = reduce_ratio
         self.spec_aug_ratio = spec_aug_ratio
+        self.reduction_resample_rate = 6000  # Hardcoded?
 
         # Audio aug
         impulse_paths = self._get_paths(
@@ -219,6 +226,9 @@ def __init__(
         noise_paths = self._get_paths(
             os.path.join(os.path.dirname(__file__), "assets", "noise")
         )
+        applause_paths = self._get_paths(
+            os.path.join(os.path.dirname(__file__), "assets", "applause")
+        )
 
         # Register impulses and noises as buffers
         self.num_impulse = 0
@@ -231,6 +241,11 @@ def __init__(
             self.register_buffer(f"noise_{i}", noise)
             self.num_noise += 1
 
+        self.num_applause = 0
+        for i, applause in enumerate(self._get_noise(applause_paths)):
+            self.register_buffer(f"applause_{i}", applause)
+            self.num_applause += 1
+
         self.spec_transform = torchaudio.transforms.Spectrogram(
             n_fft=self.config["n_fft"],
             hop_length=self.config["hop_len"],
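Side note on the buffer-registration pattern above: the impulse, noise, and applause clips are plain tensors rather than trainable parameters, and register_buffer keeps them attached to the module so they follow .to(device) and appear in the state dict. A minimal standalone sketch of the same pattern, with hypothetical names (ClipBank, clip_{i}) that are not taken from this diff:

import random
import torch

class ClipBank(torch.nn.Module):
    """Hypothetical stand-in illustrating the register_buffer pattern."""

    def __init__(self, clips: list):
        super().__init__()
        self.num_clips = 0
        for i, clip in enumerate(clips):
            # Buffers are not trained, but they move with .to(device)
            # and are saved in state_dict(), unlike plain attributes.
            self.register_buffer(f"clip_{i}", clip)
            self.num_clips += 1

    def random_clip(self) -> torch.Tensor:
        idx = random.randint(0, self.num_clips - 1)
        return getattr(self, f"clip_{idx}")

bank = ClipBank([torch.randn(16000) for _ in range(3)])
clip = bank.random_clip()  # picks one registered clip at random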
@@ -321,15 +336,37 @@ def apply_noise(self, wav: torch.tensor):
 
         return AF.add_noise(waveform=wav, noise=noise, snr=snr_dbs)
 
+    def apply_applause(self, wav: torch.tensor):
+        batch_size, _ = wav.shape
+
+        snr_dbs = torch.tensor(
+            [random.randint(1, self.min_snr) for _ in range(batch_size)]
+        ).to(wav.device)
+        applause_type = random.randint(5, self.num_applause - 1)
+
+        applause = getattr(self, f"applause_{applause_type}")
+
+        return AF.add_noise(waveform=wav, noise=applause, snr=snr_dbs)
+
     def apply_reduction(self, wav: torch.tensor):
         """
         Limit the high-band pass filter, the low-band pass filter and the sample rate
         Designed to mimic the effect of recording on a low-quality microphone or phone.
         """
-        wav = AF.highpass_biquad(wav, self.sample_rate, cutoff_freq=1200)
-        wav = AF.lowpass_biquad(wav, self.sample_rate, cutoff_freq=1400)
-        resample_rate = 6000
-        return AF.resample(wav, orig_freq=self.sample_rate, new_freq=resample_rate, lowpass_filter_width=3)
+        wav = AF.highpass_biquad(wav, self.sample_rate, cutoff_freq=300)
+        wav = AF.lowpass_biquad(wav, self.sample_rate, cutoff_freq=3400)
+        wav_downsampled = AF.resample(
+            wav,
+            orig_freq=self.sample_rate,
+            new_freq=self.reduction_resample_rate,
+            lowpass_filter_width=3,
+        )
+
+        return AF.resample(
+            wav_downsampled,
+            self.reduction_resample_rate,
+            self.sample_rate,
+        )
 
     def apply_distortion(self, wav: torch.tensor):
         gain = random.randint(self.min_dist_gain, self.max_dist_gain)
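The new apply_reduction chain band-limits the signal to roughly the 300-3400 Hz telephone band and then round-trips it through a lower sample rate. A standalone sketch of the same torchaudio.functional calls on a dummy signal; the 16 kHz rate is an assumption here, since the real value is read from config.json:

import torch
import torchaudio.functional as AF

sample_rate = 16000               # assumed; the project reads this from config.json
reduction_resample_rate = 6000    # matches the hardcoded value in the diff

wav = torch.randn(1, sample_rate * 2)  # 2 s of random noise as a stand-in signal
wav = AF.highpass_biquad(wav, sample_rate, cutoff_freq=300)
wav = AF.lowpass_biquad(wav, sample_rate, cutoff_freq=3400)
down = AF.resample(wav, orig_freq=sample_rate, new_freq=reduction_resample_rate, lowpass_filter_width=3)
restored = AF.resample(down, reduction_resample_rate, sample_rate)  # back at the original rate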
@@ -363,20 +400,23 @@ def shift_spec(self, specs: torch.Tensor, shift: int):
         return shifted_specs
 
     def aug_wav(self, wav: torch.Tensor):
-        """
-        pipeline for audio augmentation:
-        1. apply noise
-        2. apply distortion (x% of the time)
-        3. apply reduction (x% of the time)
-        4. apply reverb
-        """
+        # Noise
+        if random.random() < self.noise_ratio:
+            wav = self.apply_noise(wav)
+        if random.random() < self.applause_ratio:
+            wav = self.apply_applause(wav)
 
-        wav = self.apply_noise(wav)
-        if random.random() < self.dist_ratio:
-            wav = self.apply_distortion(wav)
+        # Distortion
         if random.random() < self.reduce_ratio:
             wav = self.apply_reduction(wav)
-        return self.apply_reverb(wav)
+        elif random.random() < self.distort_ratio:
+            wav = self.apply_distortion(wav)
+
+        # Reverb
+        if random.random() < self.reverb_ratio:
+            return self.apply_reverb(wav)
+        else:
+            return wav
 
     def norm_mel(self, mel_spec: torch.Tensor):
         log_spec = torch.clamp(mel_spec, min=1e-10).log10()
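One consequence of the rewritten aug_wav: reduction and distortion are now mutually exclusive (distortion is only tried when reduction is skipped), while noise, applause, and reverb are gated independently. A small worked example of the effective rates under the new defaults from this diff:

# Effective per-example rates with the defaults introduced above.
noise_ratio, reverb_ratio = 0.95, 0.95
applause_ratio, reduce_ratio, distort_ratio = 0.01, 0.01, 0.15

p_reduce = reduce_ratio                             # 0.01
p_distort = (1 - reduce_ratio) * distort_ratio      # 0.1485: only tried when reduction is skipped
p_untouched = (                                     # waveform passes through aug_wav unchanged
    (1 - noise_ratio)
    * (1 - applause_ratio)
    * (1 - reduce_ratio)
    * (1 - distort_ratio)
    * (1 - reverb_ratio)
)                                                   # ~0.0021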
@@ -399,13 +439,13 @@ def log_mel(self, wav: torch.Tensor, shift: int | None = None):
         return log_spec
 
     def forward(self, wav: torch.Tensor, shift: int = 0):
-        # noise, distortion, reduction and reverb
+        # Noise, distortion, and reverb
         wav = self.aug_wav(wav)
 
         # Spec & pitch shift
         log_mel = self.log_mel(wav, shift)
 
-        # Spec aug in 20% of the cases
+        # Spec aug
         if random.random() < self.spec_aug_ratio:
             log_mel = self.spec_aug(log_mel)
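For context, a hedged end-to-end usage sketch of the refactored class; the import path (amt.audio) and the batch shape are assumptions based on attribute names in this diff, not confirmed by it:

import torch
from amt.audio import AudioTransform  # assumed module path

transform = AudioTransform()  # ratio defaults from this diff; sizes come from config.json
wav_batch = torch.randn(4, transform.num_samples)  # (batch, samples) of placeholder audio

# forward(): aug_wav (noise/applause, reduction-or-distortion, reverb),
# then log-mel with optional pitch shift, then SpecAugment ~25% of the time.
log_mel = transform(wav_batch, shift=0)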