Skip to content

Commit d6fea7f

Browse files
authored
Fix inference, pedal, and add EQ aug (#20)
* fix inference and add prev pedal token * add bandpass eq
1 parent d56e8e5 commit d6fea7f

File tree

4 files changed

+109
-63
lines changed

4 files changed

+109
-63
lines changed

amt/audio.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def __init__(
194194
noise_ratio: float = 0.95,
195195
reverb_ratio: float = 0.95,
196196
applause_ratio: float = 0.01,
197+
bandpass_ratio: float = 0.1,
197198
distort_ratio: float = 0.15,
198199
reduce_ratio: float = 0.01,
199200
spec_aug_ratio: float = 0.5,
@@ -214,6 +215,7 @@ def __init__(
214215
self.noise_ratio = noise_ratio
215216
self.reverb_ratio = reverb_ratio
216217
self.applause_ratio = applause_ratio
218+
self.bandpass_ratio = bandpass_ratio
217219
self.distort_ratio = distort_ratio
218220
self.reduce_ratio = reduce_ratio
219221
self.spec_aug_ratio = spec_aug_ratio
@@ -350,6 +352,14 @@ def apply_applause(self, wav: torch.tensor):
350352

351353
return AF.add_noise(waveform=wav, noise=applause, snr=snr_dbs)
352354

355+
def apply_bandpass(self, wav: torch.tensor):
356+
central_freq = random.randint(1000, 3500)
357+
Q = random.uniform(0.707, 1.41)
358+
359+
return torchaudio.functional.bandpass_biquad(
360+
wav, self.sample_rate, central_freq, Q
361+
)
362+
353363
def apply_reduction(self, wav: torch.tensor):
354364
"""
355365
Limit the high-band pass filter, the low-band pass filter and the sample rate
@@ -424,9 +434,13 @@ def aug_wav(self, wav: torch.Tensor):
424434

425435
# Reverb
426436
if random.random() < self.reverb_ratio:
427-
return self.apply_reverb(wav)
428-
else:
429-
return wav
437+
wav = self.apply_reverb(wav)
438+
439+
# EQ
440+
if random.random() < self.bandpass_ratio:
441+
wav = self.apply_bandpass(wav)
442+
443+
return wav
430444

431445
def norm_mel(self, mel_spec: torch.Tensor):
432446
log_spec = torch.clamp(mel_spec, min=1e-10).log10()

amt/infer.py

+49-42
Original file line numberDiff line numberDiff line change
@@ -283,10 +283,13 @@ def _truncate_seq(
283283
_mid_dict = tokenizer._detokenize_midi_dict(seq, LEN_MS)
284284
try:
285285
res = tokenizer._tokenize_midi_dict(_mid_dict, start_ms, end_ms - 1)
286-
except:
286+
except Exception:
287+
print("Truncate failed")
287288
return ["<S>"]
288289
else:
289-
return res[: res.index(tokenizer.eos_tok)] # Needs to change
290+
if res[-1] == tokenizer.eos_tok:
291+
res.pop()
292+
return res
290293

291294

292295
def process_file(
@@ -306,14 +309,9 @@ def process_file(
306309
)
307310
]
308311

309-
# Add additional (padded) final audio segment
310-
_last_seg = audio_segments[-1]
311-
audio_segments.append(
312-
pad_or_trim(_last_seg[len(_last_seg) // STRIDE_FACTOR :])
313-
)
314-
312+
res = []
315313
seq = [tokenizer.bos_tok]
316-
res = [tokenizer.bos_tok]
314+
concat_seq = [tokenizer.bos_tok]
317315
for idx, audio_seg in enumerate(audio_segments):
318316
init_idx = len(seq)
319317

@@ -327,21 +325,25 @@ def process_file(
327325
else:
328326
result_queue.put(gpu_result)
329327

330-
res += _shift_onset(
328+
concat_seq += _shift_onset(
331329
seq[init_idx:],
332330
idx * CHUNK_LEN_MS,
333331
)
334332

335333
if idx == len(audio_segments) - 1:
336-
break
337-
elif res[-1] == tokenizer.eos_tok:
338-
logger.info(f"Exiting early")
339-
break
334+
res.append(concat_seq)
335+
elif concat_seq[-1] == tokenizer.eos_tok:
336+
res.append(concat_seq)
337+
seq = [tokenizer.bos_tok]
338+
concat_seq = [tokenizer.bos_tok]
339+
logger.info(f"Finished segment - eos_tok seen")
340340
else:
341341
seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS - CHUNK_LEN_MS)
342-
if len(seq) <= 2:
343-
logger.info(f"Exiting early")
344-
return res
342+
if len(seq) == 1:
343+
res.append(concat_seq)
344+
seq = [tokenizer.bos_tok]
345+
concat_seq = [tokenizer.bos_tok]
346+
logger.info(f"Exiting early - silence")
345347

346348
return res
347349

@@ -353,16 +355,35 @@ def worker(
353355
save_dir: str,
354356
input_dir: str | None = None,
355357
):
356-
def _get_save_path(_file_path: str):
358+
def _save_seq(_seq: list, _save_path: str):
359+
if os.path.exists(_save_path):
360+
logger.info(f"Already exists {_save_path} - overwriting")
361+
362+
for tok in _seq[::-1]:
363+
if type(tok) is tuple and tok[0] == "onset":
364+
last_onset = tok[1]
365+
break
366+
367+
try:
368+
mid_dict = tokenizer._detokenize_midi_dict(
369+
tokenized_seq=_seq, len_ms=last_onset
370+
)
371+
mid = mid_dict.to_midi()
372+
mid.save(_save_path)
373+
except Exception as e:
374+
logger.error(f"Failed to save {_save_path}")
375+
376+
def _get_save_path(_file_path: str, _idx: int | str = ""):
357377
if input_dir is None:
358378
save_path = os.path.join(
359379
save_dir,
360-
os.path.splitext(os.path.basename(file_path))[0] + ".mid",
380+
os.path.splitext(os.path.basename(file_path))[0]
381+
+ f"{_idx}.mid",
361382
)
362383
else:
363384
input_rel_path = os.path.relpath(_file_path, input_dir)
364385
save_path = os.path.join(
365-
save_dir, os.path.splitext(input_rel_path)[0] + ".mid"
386+
save_dir, os.path.splitext(input_rel_path)[0] + f"{_idx}.mid"
366387
)
367388
if not os.path.isdir(os.path.dirname(save_path)):
368389
os.makedirs(os.path.dirname(save_path), exist_ok=True)
@@ -374,34 +395,20 @@ def _get_save_path(_file_path: str):
374395
files_processed = 0
375396
while not file_queue.empty():
376397
file_path = file_queue.get()
377-
save_path = _get_save_path(file_path)
378-
if os.path.exists(save_path):
379-
logger.info(f"{save_path} already exists, overwriting")
380398

381399
try:
382-
res = process_file(file_path, gpu_task_queue, result_queue)
400+
seqs = process_file(file_path, gpu_task_queue, result_queue)
383401
except Exception as e:
384-
logger.error(f"Failed to transcribe {file_path}")
402+
logger.error(f"Failed to process {file_path}")
385403
continue
386404

387-
files_processed += 1
388-
389-
for tok in res[::-1]:
390-
if type(tok) is tuple and tok[0] == "onset":
391-
last_onset = tok[1]
392-
break
405+
logger.info(f"Transcribed into {len(seqs)} segment(s)")
406+
for _idx, seq in enumerate(seqs):
407+
_save_seq(seq, _get_save_path(file_path, _idx))
393408

394-
try:
395-
mid_dict = tokenizer._detokenize_midi_dict(
396-
tokenized_seq=res, len_ms=last_onset
397-
)
398-
mid = mid_dict.to_midi()
399-
mid.save(save_path)
400-
except Exception as e:
401-
logger.error(f"Failed to detokenize with error {e}")
402-
else:
403-
logger.info(f"Finished file {files_processed} - {file_path}")
404-
logger.info(f"{file_queue.qsize()} file(s) remaining in queue")
409+
files_processed += 1
410+
logger.info(f"Finished file {files_processed} - {file_path}")
411+
logger.info(f"{file_queue.qsize()} file(s) remaining in queue")
405412

406413

407414
def batch_transcribe(

amt/tokenizer.py

+31-18
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(self, return_tensors: bool = False):
4646
self.prev_tokens = [("prev", i) for i in range(128)]
4747
self.note_on_tokens = [("on", i) for i in range(128)]
4848
self.note_off_tokens = [("off", i) for i in range(128)]
49-
self.pedal_tokens = [("pedal", 0), (("pedal", 1))]
49+
self.pedal_tokens = [("pedal", 0), ("pedal", 1), ("prev", "pedal")]
5050
self.velocity_tokens = [("vel", i) for i in self.velocity_quantizations]
5151
self.onset_tokens = [
5252
("onset", i) for i in self.onset_time_quantizations
@@ -81,7 +81,6 @@ def _quantize_velocity(self, velocity: int):
8181
# TODO:
8282
# - I need to make this method more robust, as it will have to handle
8383
# an arbitrary MIDI file
84-
# - Decide whether to put pedal messages as prev tokens
8584
def _tokenize_midi_dict(
8685
self,
8786
midi_dict: MidiDict,
@@ -96,11 +95,13 @@ def _tokenize_midi_dict(
9695
pedal_intervals = midi_dict._build_pedal_intervals()
9796
if len(pedal_intervals.keys()) > 1:
9897
print("Warning: midi_dict has more than one pedal channel")
98+
if len(midi_dict.instrument_msgs) > 1:
99+
print("Warning: midi_dict has more than one instrument msg")
99100
pedal_intervals = pedal_intervals[0]
100101

101102
last_msg_ms = -1
102103
on_off_notes = []
103-
prev_notes = []
104+
prev_toks = []
104105
for msg in midi_dict.note_msgs:
105106
_pitch = msg["data"]["pitch"]
106107
_velocity = msg["data"]["velocity"]
@@ -137,9 +138,9 @@ def _tokenize_midi_dict(
137138
if note_end_ms <= start_ms or note_start_ms >= end_ms: # Skip
138139
continue
139140
elif (
140-
note_start_ms < start_ms and _pitch not in prev_notes
141+
note_start_ms < start_ms and _pitch not in prev_toks
141142
): # Add to prev notes
142-
prev_notes.append(_pitch)
143+
prev_toks.append(_pitch)
143144
if note_end_ms < end_ms:
144145
on_off_notes.append(
145146
("off", _pitch, rel_note_end_ms_q, None)
@@ -182,8 +183,10 @@ def _tokenize_midi_dict(
182183
rel_off_ms_q = self._quantize_onset(pedal_off_ms - start_ms)
183184

184185
# On message
185-
if pedal_on_ms <= start_ms or pedal_on_ms >= end_ms:
186+
if pedal_off_ms <= start_ms or pedal_on_ms >= end_ms:
186187
continue
188+
elif pedal_on_ms < start_ms and pedal_off_ms >= start_ms:
189+
prev_toks.append("pedal")
187190
else:
188191
on_off_pedal.append(("pedal", 1, rel_on_ms_q, None))
189192

@@ -200,7 +203,7 @@ def _tokenize_midi_dict(
200203
(0 if x[0] == "pedal" else 1 if x[0] == "off" else 2),
201204
)
202205
)
203-
random.shuffle(prev_notes)
206+
random.shuffle(prev_toks)
204207

205208
tokenized_seq = []
206209
for tok in on_off_combined:
@@ -220,7 +223,7 @@ def _tokenize_midi_dict(
220223
tokenized_seq.append(("pedal", _val))
221224
tokenized_seq.append(("onset", _onset))
222225

223-
prefix = [("prev", p) for p in prev_notes]
226+
prefix = [("prev", p) for p in prev_toks]
224227

225228
# Add eos_tok only if segment includes end of midi_dict
226229
if last_msg_ms < end_ms:
@@ -271,7 +274,21 @@ def _detokenize_midi_dict(
271274
if DEBUG:
272275
raise Exception
273276

274-
notes_to_close[tok[1]] = (0, self.default_velocity)
277+
if tok[1] == "pedal":
278+
pedal_msgs.append(
279+
{
280+
"type": "pedal",
281+
"data": 1,
282+
"tick": 0,
283+
"channel": 0,
284+
}
285+
)
286+
elif isinstance(tok[1], int):
287+
notes_to_close[tok[1]] = (0, self.default_velocity)
288+
else:
289+
print(f"Invalid 'prev' token: {tok}")
290+
if DEBUG:
291+
raise Exception
275292
else:
276293
raise Exception(
277294
f"Invalid note sequence at position {idx}: {tok, tokenized_seq[:idx]}"
@@ -293,11 +310,9 @@ def _detokenize_midi_dict(
293310
if DEBUG:
294311
raise Exception
295312
elif tok_1_type == "pedal":
296-
# Pedal information contained in note-off messages, so we don't
297-
# need to manually processes them
298313
_pedal_data = tok_1_data
299314
_tick = tok_2_data
300-
note_msgs.append(
315+
pedal_msgs.append(
301316
{
302317
"type": "pedal",
303318
"data": _pedal_data,
@@ -454,13 +469,11 @@ def msg_mixup(src: list):
454469

455470
# Shuffle order and re-append to result
456471
for k, v in sorted(buffer.items()):
472+
off_pedal_combined = v["off"] + v["pedal"]
473+
random.shuffle(off_pedal_combined)
457474
random.shuffle(v["on"])
458-
random.shuffle(v["off"])
459-
for item in v["pedal"]:
460-
res.append(item[0]) # Pedal
461-
res.append(item[1]) # Onset
462-
for item in v["off"]:
463-
res.append(item[0]) # Pitch
475+
for item in off_pedal_combined:
476+
res.append(item[0]) # Off or pedal
464477
res.append(item[1]) # Onset
465478
for item in v["on"]:
466479
res.append(item[0]) # Pitch

tests/test_data.py

+12
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,18 @@ def test_distortion(self):
177177
res = audio_transform.apply_distortion(wav)
178178
torchaudio.save("tests/test_results/dist.wav", res, SAMPLE_RATE)
179179

180+
def test_bandpass(self):
181+
SAMPLE_RATE, CHUNK_LEN = 16000, 30
182+
audio_transform = AudioTransform()
183+
wav, sr = torchaudio.load("tests/test_data/147.wav")
184+
wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE).mean(
185+
0, keepdim=True
186+
)[:, : SAMPLE_RATE * CHUNK_LEN]
187+
188+
torchaudio.save("tests/test_results/orig.wav", wav, SAMPLE_RATE)
189+
res = audio_transform.apply_bandpass(wav)
190+
torchaudio.save("tests/test_results/bandpass.wav", res, SAMPLE_RATE)
191+
180192
def test_applause(self):
181193
SAMPLE_RATE, CHUNK_LEN = 16000, 30
182194
audio_transform = AudioTransform()

0 commit comments

Comments (0)