@@ -194,10 +194,11 @@ def __init__(
194
194
noise_ratio : float = 0.95 ,
195
195
reverb_ratio : float = 0.95 ,
196
196
applause_ratio : float = 0.01 ,
197
- bandpass_ratio : float = 0.1 ,
197
+ bandpass_ratio : float = 0.15 ,
198
198
distort_ratio : float = 0.15 ,
199
199
reduce_ratio : float = 0.01 ,
200
- codecs_ratio : float = 0.01 ,
200
+ detune_ratio : float = 0.1 ,
201
+ detune_max_shift : float = 0.15 ,
201
202
spec_aug_ratio : float = 0.5 ,
202
203
):
203
204
super ().__init__ ()
@@ -219,8 +220,9 @@ def __init__(
219
220
self .bandpass_ratio = bandpass_ratio
220
221
self .distort_ratio = distort_ratio
221
222
self .reduce_ratio = reduce_ratio
223
+ self .detune_ratio = detune_ratio
224
+ self .detune_max_shift = detune_max_shift
222
225
self .spec_aug_ratio = spec_aug_ratio
223
- self .codecs_ratio = codecs_ratio
224
226
self .reduction_resample_rate = 6000 # Hardcoded?
225
227
226
228
# Audio aug
@@ -268,6 +270,19 @@ def __init__(
268
270
),
269
271
)
270
272
273
def get_params(self):
    """Return the augmentation settings of this instance as a dict.

    Each key is the name of a configuration attribute and each value is
    that attribute's current value on ``self``. Keys are emitted in a
    fixed order.
    """
    # Attribute names double as the dictionary keys.
    setting_names = (
        "noise_ratio",
        "reverb_ratio",
        "applause_ratio",
        "bandpass_ratio",
        "distort_ratio",
        "reduce_ratio",
        "detune_ratio",
        "detune_max_shift",
        "spec_aug_ratio",
    )
    return {name: getattr(self, name) for name in setting_names}
285
+
271
286
def _get_paths (self , dir_path ):
272
287
os .makedirs (dir_path , exist_ok = True )
273
288
@@ -399,21 +414,7 @@ def distortion_aug_cpu(self, wav: torch.Tensor):
399
414
400
415
return wav
401
416
402
- def apply_codec (self , wav : torch .tensor ):
403
- """
404
- Apply different audio codecs to the audio.
405
- """
406
- format_encoder_pairs = [
407
- ("wav" , "pcm_mulaw" ),
408
- ("g722" , None ),
409
- ("ogg" , "vorbis" )
410
- ]
411
- for format , encoder in format_encoder_pairs :
412
- encoder = torchaudio .io .AudioEffector (format = format , encoder = encoder )
413
- if random .random () < self .codecs_ratio :
414
- wav = encoder .apply (wav , self .sample_rate )
415
-
416
- def shift_spec (self , specs : torch .Tensor , shift : int ):
417
+ def shift_spec (self , specs : torch .Tensor , shift : int | float ):
417
418
if shift == 0 :
418
419
return specs
419
420
@@ -438,9 +439,21 @@ def shift_spec(self, specs: torch.Tensor, shift: int):
438
439
439
440
return shifted_specs
440
441
442
def detune_spec(self, specs: torch.Tensor):
    """Randomly blend a spectrogram with a slightly shifted copy of itself.

    With probability ``self.detune_ratio`` a shift amount is drawn
    uniformly from ``[-self.detune_max_shift, self.detune_max_shift]``,
    ``self.shift_spec`` is applied with that amount, and the element-wise
    mean of the input and the shifted result is returned. Otherwise the
    input is returned untouched.

    NOTE(review): the unit of the shift is whatever ``shift_spec``
    expects -- confirm against that method.
    """
    # Guard clause: most calls skip detuning and hand the input back.
    if not random.random() < self.detune_ratio:
        return specs

    offset = random.uniform(-self.detune_max_shift, self.detune_max_shift)
    shifted = self.shift_spec(specs, shift=offset)

    # Equal-weight mix of the original and the shifted spectrogram.
    return (specs + shifted) / 2
452
+
441
453
def aug_wav (self , wav : torch .Tensor ):
442
454
# This function doesn't apply distortion. If distortion is desired it
443
- # should be run before hand on the cpu with distortion_aug_cpu.
455
+ # should be run beforehand on the cpu with distortion_aug_cpu. Note
456
+ # also that detuning is done to the spectrogram in log_mel, not the wav.
444
457
445
458
# Noise
446
459
if random .random () < self .noise_ratio :
@@ -468,10 +481,17 @@ def norm_mel(self, mel_spec: torch.Tensor):
468
481
469
482
return log_spec
470
483
471
- def log_mel (self , wav : torch .Tensor , shift : int | None = None ):
484
+ def log_mel (
485
+ self , wav : torch .Tensor , shift : int | None = None , detune : bool = False
486
+ ):
472
487
spec = self .spec_transform (wav )[..., :- 1 ]
473
- if shift and shift != 0 :
488
+
489
+ if shift is not None and shift != 0 :
474
490
spec = self .shift_spec (spec , shift )
491
+ elif detune is True :
492
+ # Don't detune and spec shift at the same time
493
+ spec = self .detune_spec (spec )
494
+
475
495
mel_spec = self .mel_transform (spec )
476
496
477
497
# Norm
@@ -483,8 +503,8 @@ def forward(self, wav: torch.Tensor, shift: int = 0):
483
503
# Noise, and reverb
484
504
wav = self .aug_wav (wav )
485
505
486
- # Spec & pitch shift
487
- log_mel = self .log_mel (wav , shift )
506
+ # Spec, detuning & pitch shift
507
+ log_mel = self .log_mel (wav , shift , detune = True )
488
508
489
509
# Spec aug
490
510
if random .random () < self .spec_aug_ratio :
0 commit comments