Skip to content

Commit 50f0b60

Browse files
authored
Update inference (#22)
* synth data - untested * adjust audio striding * add detune aug * add aug val to train * add b64 encoding to dataset * update dataset * fix * training changes * inference changes * format * format
1 parent 06abaff commit 50f0b60

15 files changed

+1690
-815
lines changed

amt/audio.py

+43-23
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,11 @@ def __init__(
194194
noise_ratio: float = 0.95,
195195
reverb_ratio: float = 0.95,
196196
applause_ratio: float = 0.01,
197-
bandpass_ratio: float = 0.1,
197+
bandpass_ratio: float = 0.15,
198198
distort_ratio: float = 0.15,
199199
reduce_ratio: float = 0.01,
200-
codecs_ratio: float = 0.01,
200+
detune_ratio: float = 0.1,
201+
detune_max_shift: float = 0.15,
201202
spec_aug_ratio: float = 0.5,
202203
):
203204
super().__init__()
@@ -219,8 +220,9 @@ def __init__(
219220
self.bandpass_ratio = bandpass_ratio
220221
self.distort_ratio = distort_ratio
221222
self.reduce_ratio = reduce_ratio
223+
self.detune_ratio = detune_ratio
224+
self.detune_max_shift = detune_max_shift
222225
self.spec_aug_ratio = spec_aug_ratio
223-
self.codecs_ratio = codecs_ratio
224226
self.reduction_resample_rate = 6000 # Hardcoded?
225227

226228
# Audio aug
@@ -268,6 +270,19 @@ def __init__(
268270
),
269271
)
270272

273+
def get_params(self):
274+
return {
275+
"noise_ratio": self.noise_ratio,
276+
"reverb_ratio": self.reverb_ratio,
277+
"applause_ratio": self.applause_ratio,
278+
"bandpass_ratio": self.bandpass_ratio,
279+
"distort_ratio": self.distort_ratio,
280+
"reduce_ratio": self.reduce_ratio,
281+
"detune_ratio": self.detune_ratio,
282+
"detune_max_shift": self.detune_max_shift,
283+
"spec_aug_ratio": self.spec_aug_ratio,
284+
}
285+
271286
def _get_paths(self, dir_path):
272287
os.makedirs(dir_path, exist_ok=True)
273288

@@ -399,21 +414,7 @@ def distortion_aug_cpu(self, wav: torch.Tensor):
399414

400415
return wav
401416

402-
def apply_codec(self, wav: torch.tensor):
403-
"""
404-
Apply different audio codecs to the audio.
405-
"""
406-
format_encoder_pairs = [
407-
("wav", "pcm_mulaw"),
408-
("g722", None),
409-
("ogg", "vorbis")
410-
]
411-
for format, encoder in format_encoder_pairs:
412-
encoder = torchaudio.io.AudioEffector(format=format, encoder=encoder)
413-
if random.random() < self.codecs_ratio:
414-
wav = encoder.apply(wav, self.sample_rate)
415-
416-
def shift_spec(self, specs: torch.Tensor, shift: int):
417+
def shift_spec(self, specs: torch.Tensor, shift: int | float):
417418
if shift == 0:
418419
return specs
419420

@@ -438,9 +439,21 @@ def shift_spec(self, specs: torch.Tensor, shift: int):
438439

439440
return shifted_specs
440441

442+
def detune_spec(self, specs: torch.Tensor):
443+
if random.random() < self.detune_ratio:
444+
detune_shift = random.uniform(
445+
-self.detune_max_shift, self.detune_max_shift
446+
)
447+
detuned_specs = self.shift_spec(specs, shift=detune_shift)
448+
449+
return (specs + detuned_specs) / 2
450+
else:
451+
return specs
452+
441453
def aug_wav(self, wav: torch.Tensor):
442454
# This function doesn't apply distortion. If distortion is desired it
443-
# should be run before hand on the cpu with distortion_aug_cpu.
455+
# should be run beforehand on the cpu with distortion_aug_cpu. Note
456+
# also that detuning is done to the spectrogram in log_mel, not the wav.
444457

445458
# Noise
446459
if random.random() < self.noise_ratio:
@@ -468,10 +481,17 @@ def norm_mel(self, mel_spec: torch.Tensor):
468481

469482
return log_spec
470483

471-
def log_mel(self, wav: torch.Tensor, shift: int | None = None):
484+
def log_mel(
485+
self, wav: torch.Tensor, shift: int | None = None, detune: bool = False
486+
):
472487
spec = self.spec_transform(wav)[..., :-1]
473-
if shift and shift != 0:
488+
489+
if shift is not None and shift != 0:
474490
spec = self.shift_spec(spec, shift)
491+
elif detune is True:
492+
# Don't detune and spec shift at the same time
493+
spec = self.detune_spec(spec)
494+
475495
mel_spec = self.mel_transform(spec)
476496

477497
# Norm
@@ -483,8 +503,8 @@ def forward(self, wav: torch.Tensor, shift: int = 0):
483503
# Noise, and reverb
484504
wav = self.aug_wav(wav)
485505

486-
# Spec & pitch shift
487-
log_mel = self.log_mel(wav, shift)
506+
# Spec, detuning & pitch shift
507+
log_mel = self.log_mel(wav, shift, detune=True)
488508

489509
# Spec aug
490510
if random.random() < self.spec_aug_ratio:

0 commit comments

Comments
 (0)