@@ -186,14 +186,18 @@ class AudioTransform(torch.nn.Module):
186
186
def __init__ (
187
187
self ,
188
188
reverb_factor : int = 1 ,
189
- min_snr : int = 10 ,
190
- max_snr : int = 40 ,
189
+ min_snr : int = 20 ,
190
+ max_snr : int = 50 ,
191
+ max_dist_gain : int = 25 ,
192
+ min_dist_gain : int = 0 ,
191
193
):
192
194
super ().__init__ ()
193
195
self .tokenizer = AmtTokenizer ()
194
196
self .reverb_factor = reverb_factor
195
197
self .min_snr = min_snr
196
198
self .max_snr = max_snr
199
+ self .max_dist_gain = max_dist_gain
200
+ self .min_dist_gain = min_dist_gain
197
201
198
202
self .config = load_config ()["audio" ]
199
203
self .sample_rate = self .config ["sample_rate" ]
@@ -230,10 +234,10 @@ def __init__(
230
234
)
231
235
self .spec_aug = torch .nn .Sequential (
232
236
torchaudio .transforms .FrequencyMasking (
233
- freq_mask_param = 15 , iid_masks = True
237
+ freq_mask_param = 10 , iid_masks = True
234
238
),
235
239
torchaudio .transforms .TimeMasking (
236
- time_mask_param = 500 , iid_masks = True
240
+ time_mask_param = 1000 , iid_masks = True
237
241
),
238
242
)
239
243
@@ -309,6 +313,12 @@ def apply_noise(self, wav: torch.tensor):
309
313
310
314
return AF .add_noise (waveform = wav , noise = noise , snr = snr_dbs )
311
315
316
def apply_distortion(self, wav: torch.Tensor):
    """Apply overdrive distortion to ``wav`` with randomly drawn settings.

    The gain is an integer drawn uniformly from
    ``[self.min_dist_gain, self.max_dist_gain]`` and the colour an integer
    drawn uniformly from ``[5, 95]`` (``random.randint`` includes both
    endpoints). A fresh pair is drawn on every call.

    Args:
        wav: Waveform tensor, passed unchanged to ``AF.overdrive``.
            NOTE(review): ``AF`` is presumably ``torchaudio.functional``;
            confirm the expected waveform shape and value range against
            its documentation.

    Returns:
        Whatever ``AF.overdrive`` returns for the drawn gain and colour.

    Raises:
        ValueError: Raised by ``random.randint`` if ``self.min_dist_gain``
            is greater than ``self.max_dist_gain``.
    """
    gain = random.randint(self.min_dist_gain, self.max_dist_gain)
    colour = random.randint(5, 95)

    return AF.overdrive(wav, gain=gain, colour=colour)
321
+
312
322
def shift_spec (self , specs : torch .Tensor , shift : int ):
313
323
if shift == 0 :
314
324
return specs
@@ -335,7 +345,13 @@ def shift_spec(self, specs: torch.Tensor, shift: int):
335
345
return shifted_specs
336
346
337
347
def aug_wav(self, wav: torch.Tensor):
    """Run the waveform augmentation chain: noise, optional distortion, reverb.

    Noise and reverb are always applied. Distortion is inserted between
    them only when the random draw is at most 0.20, i.e. in roughly 20% of
    calls.
    """
    # Draw before touching the audio so the random stream is consumed in
    # the same order regardless of which branch is taken.
    use_distortion = random.random() <= 0.20

    augmented = self.apply_noise(wav)
    if use_distortion:
        augmented = self.apply_distortion(augmented)

    return self.apply_reverb(augmented)
339
355
340
356
def norm_mel (self , mel_spec : torch .Tensor ):
341
357
log_spec = torch .clamp (mel_spec , min = 1e-10 ).log10 ()
@@ -364,8 +380,8 @@ def forward(self, wav: torch.Tensor, shift: int = 0):
364
380
# Spec & pitch shift
365
381
log_mel = self .log_mel (wav , shift )
366
382
367
- # Spec aug
368
- if random .random () > 0.2 :
383
+ # Spec aug is applied in ~80% of cases (whenever random.random() > 0.20)
384
+ if random .random () > 0.20 :
369
385
log_mel = self .spec_aug (log_mel )
370
386
371
387
return log_mel
0 commit comments