Skip to content

Commit e49951c

Browse files
authored
Add multi gpu batched inference (#12)
* add more aug * add multi gpu inference
1 parent f6f5fbb commit e49951c

11 files changed

+595
-285
lines changed

amt/audio.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -186,14 +186,18 @@ class AudioTransform(torch.nn.Module):
186186
def __init__(
187187
self,
188188
reverb_factor: int = 1,
189-
min_snr: int = 10,
190-
max_snr: int = 40,
189+
min_snr: int = 20,
190+
max_snr: int = 50,
191+
max_dist_gain: int = 25,
192+
min_dist_gain: int = 0,
191193
):
192194
super().__init__()
193195
self.tokenizer = AmtTokenizer()
194196
self.reverb_factor = reverb_factor
195197
self.min_snr = min_snr
196198
self.max_snr = max_snr
199+
self.max_dist_gain = max_dist_gain
200+
self.min_dist_gain = min_dist_gain
197201

198202
self.config = load_config()["audio"]
199203
self.sample_rate = self.config["sample_rate"]
@@ -230,10 +234,10 @@ def __init__(
230234
)
231235
self.spec_aug = torch.nn.Sequential(
232236
torchaudio.transforms.FrequencyMasking(
233-
freq_mask_param=15, iid_masks=True
237+
freq_mask_param=10, iid_masks=True
234238
),
235239
torchaudio.transforms.TimeMasking(
236-
time_mask_param=500, iid_masks=True
240+
time_mask_param=1000, iid_masks=True
237241
),
238242
)
239243

@@ -309,6 +313,12 @@ def apply_noise(self, wav: torch.tensor):
309313

310314
return AF.add_noise(waveform=wav, noise=noise, snr=snr_dbs)
311315

316+
def apply_distortion(self, wav: torch.Tensor) -> torch.Tensor:
    """Apply overdrive distortion with randomly drawn settings.

    The gain is drawn uniformly (both ends inclusive) from
    ``[self.min_dist_gain, self.max_dist_gain]`` and the colour from
    ``[5, 95]``; both are forwarded to ``AF.overdrive``.

    Args:
        wav: Waveform tensor to distort.

    Returns:
        Whatever ``AF.overdrive`` returns for ``wav`` with the drawn
        ``gain`` and ``colour``.
    """
    # NOTE(review): AF is presumably torchaudio.functional, whose overdrive
    # documents gain and colour on a 0-100 scale -- confirm the configured
    # min/max_dist_gain stay inside that range.
    gain = random.randint(self.min_dist_gain, self.max_dist_gain)
    colour = random.randint(5, 95)

    return AF.overdrive(wav, gain=gain, colour=colour)
321+
312322
def shift_spec(self, specs: torch.Tensor, shift: int):
313323
if shift == 0:
314324
return specs
@@ -335,7 +345,13 @@ def shift_spec(self, specs: torch.Tensor, shift: int):
335345
return shifted_specs
336346

337347
def aug_wav(self, wav: torch.Tensor):
    """Augment a waveform with noise, occasional distortion, then reverb.

    Noise is always applied first and reverb last. Distortion is inserted
    between the two when the random draw is at most 0.20, i.e. in roughly
    20% of calls.
    """
    # Draw before any augmentation runs so the random stream is consumed
    # in the same order whichever branch is taken.
    distort = random.random() <= 0.20

    augmented = self.apply_noise(wav)
    if distort:
        augmented = self.apply_distortion(augmented)

    return self.apply_reverb(augmented)
339355

340356
def norm_mel(self, mel_spec: torch.Tensor):
341357
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
@@ -364,8 +380,8 @@ def forward(self, wav: torch.Tensor, shift: int = 0):
364380
# Spec & pitch shift
365381
log_mel = self.log_mel(wav, shift)
366382

367-
# Spec aug
368-
if random.random() > 0.2:
383+
# Spec aug in 20% of cases
384+
if random.random() > 0.20:
369385
log_mel = self.spec_aug(log_mel)
370386

371387
return log_mel

amt/data.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,25 @@
1414

1515

1616
def get_wav_mid_segments(
17-
audio_path: str, mid_path: str = "", return_json: bool = False
17+
audio_path: str,
18+
mid_path: str = "",
19+
return_json: bool = False,
20+
stride_factor: int | None = None,
1821
):
1922
"""This function yields tuples of matched log mel spectrograms and
2023
tokenized sequences (np.array, list). If it is given only an audio path
2124
then it will return an empty list for the mid_feature
2225
"""
2326
tokenizer = AmtTokenizer()
2427
config = load_config()
25-
stride_factor = config["data"]["stride_factor"]
2628
sample_rate = config["audio"]["sample_rate"]
2729
chunk_len = config["audio"]["chunk_len"]
2830
num_samples = sample_rate * chunk_len
2931
samples_per_ms = sample_rate // 1000
3032

33+
if not stride_factor:
34+
stride_factor = config["data"]["stride_factor"]
35+
3136
if not os.path.isfile(audio_path):
3237
return None
3338

0 commit comments

Comments
 (0)