
Commit b82a9da

Add applause augmentation (#15)

* add data aug and clean
* fix reverb

1 parent e00b374 commit b82a9da
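At a glance: the commit adds an applause augmentation to AudioTransform (new assets/applause clips registered as module buffers, a new apply_applause method, and a probability knob per augmentation stage), makes reverb probabilistic rather than unconditional, retunes apply_reduction, and applies formatter-style cleanups to amt/evaluate.py, amt/infer.py, and amt/run.py. A rough sketch of how the changed class is driven after this commit; the batch size and ratio values are illustrative, not from the commit:

import torch
from amt.audio import AudioTransform

# The keyword arguments below are the knobs this commit introduces
transform = AudioTransform(noise_ratio=0.95, applause_ratio=0.01, reverb_ratio=0.95)
wav = torch.randn(4, transform.num_samples)  # dummy batch of audio chunks
log_mel = transform(wav, shift=0)  # augment the waveform, then compute log-mel features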

6 files changed: +163 -50


.gitignore (-4)

@@ -15,10 +15,6 @@
 *.xml
 *.html
 *.htm
-*.mid
-*.midi
-*.wav
-*.mp3

 .idea/

amt/audio.py (+62 -22)
@@ -182,6 +182,7 @@ def log_mel_spectrogram(
     return log_spec


+# Refactor default params are stored in config.json
 class AudioTransform(torch.nn.Module):
     def __init__(
         self,
@@ -190,10 +191,12 @@ def __init__(
         max_snr: int = 50,
         max_dist_gain: int = 25,
         min_dist_gain: int = 0,
-        # ratios for the reduction of the audio quality
-        distort_ratio: float = 0.2,
-        reduce_ratio: float = 0.2,
-        spec_aug_ratio: float = 0.2,
+        noise_ratio: float = 0.95,
+        reverb_ratio: float = 0.95,
+        applause_ratio: float = 0.01,  # CHANGE
+        distort_ratio: float = 0.15,
+        reduce_ratio: float = 0.01,
+        spec_aug_ratio: float = 0.25,
     ):
         super().__init__()
         self.tokenizer = AmtTokenizer()
@@ -208,9 +211,13 @@ def __init__(
         self.chunk_len = self.config["chunk_len"]
         self.num_samples = self.sample_rate * self.chunk_len

-        self.dist_ratio = distort_ratio
+        self.noise_ratio = noise_ratio
+        self.reverb_ratio = reverb_ratio
+        self.applause_ratio = applause_ratio
+        self.distort_ratio = distort_ratio
         self.reduce_ratio = reduce_ratio
         self.spec_aug_ratio = spec_aug_ratio
+        self.reduction_resample_rate = 6000  # Hardcoded?

         # Audio aug
         impulse_paths = self._get_paths(
@@ -219,6 +226,9 @@ def __init__(
         noise_paths = self._get_paths(
             os.path.join(os.path.dirname(__file__), "assets", "noise")
         )
+        applause_paths = self._get_paths(
+            os.path.join(os.path.dirname(__file__), "assets", "applause")
+        )

         # Register impulses and noises as buffers
         self.num_impulse = 0
@@ -231,6 +241,11 @@ def __init__(
             self.register_buffer(f"noise_{i}", noise)
             self.num_noise += 1

+        self.num_applause = 0
+        for i, applause in enumerate(self._get_noise(applause_paths)):
+            self.register_buffer(f"applause_{i}", applause)
+            self.num_applause += 1
+
         self.spec_transform = torchaudio.transforms.Spectrogram(
             n_fft=self.config["n_fft"],
             hop_length=self.config["hop_len"],
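The applause clips are registered the same way as the noise clips: each tensor becomes a named module buffer, so it follows the module across devices (and, as a side effect, into its state_dict). A standalone sketch of the pattern, with hypothetical names:

import torch

class ClipBank(torch.nn.Module):
    # Minimal version of the buffer-registration pattern used above
    def __init__(self, clips: list):
        super().__init__()
        self.num_applause = 0
        for i, clip in enumerate(clips):
            # register_buffer makes the tensor move with .to(device) / .cuda()
            self.register_buffer(f"applause_{i}", clip)
            self.num_applause += 1

    def get(self, i: int) -> torch.Tensor:
        # Named buffers are plain attributes, so getattr retrieves them,
        # just as apply_applause does below
        return getattr(self, f"applause_{i}")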
@@ -321,15 +336,37 @@ def apply_noise(self, wav: torch.tensor):

         return AF.add_noise(waveform=wav, noise=noise, snr=snr_dbs)

+    def apply_applause(self, wav: torch.tensor):
+        batch_size, _ = wav.shape
+
+        snr_dbs = torch.tensor(
+            [random.randint(1, self.min_snr) for _ in range(batch_size)]
+        ).to(wav.device)
+        applause_type = random.randint(5, self.num_applause - 1)
+
+        applause = getattr(self, f"applause_{applause_type}")
+
+        return AF.add_noise(waveform=wav, noise=applause, snr=snr_dbs)
+
     def apply_reduction(self, wav: torch.tensor):
         """
         Limit the high-band pass filter, the low-band pass filter and the sample rate
         Designed to mimic the effect of recording on a low-quality microphone or phone.
         """
-        wav = AF.highpass_biquad(wav, self.sample_rate, cutoff_freq=1200)
-        wav = AF.lowpass_biquad(wav, self.sample_rate, cutoff_freq=1400)
-        resample_rate = 6000
-        return AF.resample(wav, orig_freq=self.sample_rate, new_freq=resample_rate, lowpass_filter_width=3)
+        wav = AF.highpass_biquad(wav, self.sample_rate, cutoff_freq=300)
+        wav = AF.lowpass_biquad(wav, self.sample_rate, cutoff_freq=3400)
+        wav_downsampled = AF.resample(
+            wav,
+            orig_freq=self.sample_rate,
+            new_freq=self.reduction_resample_rate,
+            lowpass_filter_width=3,
+        )
+
+        return AF.resample(
+            wav_downsampled,
+            self.reduction_resample_rate,
+            self.sample_rate,
+        )

     def apply_distortion(self, wav: torch.tensor):
         gain = random.randint(self.min_dist_gain, self.max_dist_gain)
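Two things stand out in this hunk. apply_applause reuses the apply_noise recipe: draw a per-example SNR, fetch a stored clip, and mix it in with torchaudio.functional.add_noise. And the rewritten apply_reduction now band-limits to roughly telephone bandwidth (300-3400 Hz) and resamples back up, so the output keeps the original sample rate and, up to rounding, the original length; the old version stacked a 1200 Hz high-pass on a 1400 Hz low-pass, leaving only a ~200 Hz sliver of signal. A standalone sketch of the round trip, assuming a 16 kHz source rate (the repo's actual rate comes from config.json):

import torch
import torchaudio.functional as AF

SAMPLE_RATE = 16_000    # assumed for illustration; read from config.json in the repo
REDUCTION_RATE = 6_000  # matches self.reduction_resample_rate above

def telephone_quality(wav: torch.Tensor) -> torch.Tensor:
    # Keep only ~300-3400 Hz, the classic narrowband telephone passband
    wav = AF.highpass_biquad(wav, SAMPLE_RATE, cutoff_freq=300)
    wav = AF.lowpass_biquad(wav, SAMPLE_RATE, cutoff_freq=3400)
    # Down-then-up resampling preserves the artifacts of the 6 kHz bottleneck
    # while restoring the original sample rate
    down = AF.resample(wav, orig_freq=SAMPLE_RATE, new_freq=REDUCTION_RATE, lowpass_filter_width=3)
    return AF.resample(down, orig_freq=REDUCTION_RATE, new_freq=SAMPLE_RATE)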
@@ -363,20 +400,23 @@ def shift_spec(self, specs: torch.Tensor, shift: int):
         return shifted_specs

     def aug_wav(self, wav: torch.Tensor):
-        """
-        pipeline for audio augmentation:
-        1. apply noise
-        2. apply distortion (x% of the time)
-        3. apply reduction (x% of the time)
-        4. apply reverb
-        """
+        # Noise
+        if random.random() < self.noise_ratio:
+            wav = self.apply_noise(wav)
+        if random.random() < self.applause_ratio:
+            wav = self.apply_applause(wav)

-        wav = self.apply_noise(wav)
-        if random.random() < self.dist_ratio:
-            wav = self.apply_distortion(wav)
+        # Distortion
         if random.random() < self.reduce_ratio:
             wav = self.apply_reduction(wav)
-        return self.apply_reverb(wav)
+        elif random.random() < self.distort_ratio:
+            wav = self.apply_distortion(wav)
+
+        # Reverb
+        if random.random() < self.reverb_ratio:
+            return self.apply_reverb(wav)
+        else:
+            return wav

     def norm_mel(self, mel_spec: torch.Tensor):
         log_spec = torch.clamp(mel_spec, min=1e-10).log10()
399439
return log_spec
400440

401441
def forward(self, wav: torch.Tensor, shift: int = 0):
402-
# noise, distortion, reduction and reverb
442+
# Noise, distortion, and reverb
403443
wav = self.aug_wav(wav)
404444

405445
# Spec & pitch shift
406446
log_mel = self.log_mel(wav, shift)
407447

408-
# Spec aug in 20% of the cases
448+
# Spec aug
409449
if random.random() < self.spec_aug_ratio:
410450
log_mel = self.spec_aug(log_mel)
411451

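With the comments updated, forward now reads as a three-stage pipeline. Schematically (reconstructed from the hunks above; the trailing return is implied by the surrounding code rather than shown in this diff):

# Schematic of AudioTransform.forward after this commit
def forward(self, wav, shift=0):
    wav = self.aug_wav(wav)                    # waveform-level augmentations
    log_mel = self.log_mel(wav, shift)         # mel spectrogram + optional pitch shift
    if random.random() < self.spec_aug_ratio:  # 25% of batches by default
        log_mel = self.spec_aug(log_mel)
    return log_mel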
amt/evaluate.py (+36 -12)
@@ -6,6 +6,7 @@
 import json
 import os

+
 def midi_to_intervals_and_pitches(midi_file_path):
     """
     This function reads a MIDI file and extracts note intervals and pitches
@@ -55,18 +56,26 @@ def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0):
         if ref_fpath in ref_midi_files:
             est_ref_pairs.append((est_fpath, ref_fpath))
         if ref_fpath.replace(".mid", ".midi") in ref_midi_files:
-            est_ref_pairs.append((est_fpath, ref_fpath.replace(".mid", ".midi")))
+            est_ref_pairs.append(
+                (est_fpath, ref_fpath.replace(".mid", ".midi"))
+            )
         else:
-            print(f"Reference file not found for {est_fpath} (ref file: {ref_fpath})")
+            print(
+                f"Reference file not found for {est_fpath} (ref file: {ref_fpath})"
+            )

-    output_fhandle = open(output_stats_file, "w") if output_stats_file is not None else None
+    output_fhandle = (
+        open(output_stats_file, "w") if output_stats_file is not None else None
+    )

     for est_file, ref_file in tqdm(est_ref_pairs):
         ref_intervals, ref_pitches = midi_to_intervals_and_pitches(ref_file)
         est_intervals, est_pitches = midi_to_intervals_and_pitches(est_file)
         ref_pitches_hz = midi_to_hz(ref_pitches)
         est_pitches_hz = midi_to_hz(est_pitches, est_shift)
-        scores = mir_eval.transcription.evaluate(ref_intervals, ref_pitches_hz, est_intervals, est_pitches_hz)
+        scores = mir_eval.transcription.evaluate(
+            ref_intervals, ref_pitches_hz, est_intervals, est_pitches_hz
+        )
         if output_fhandle is not None:
             output_fhandle.write(json.dumps(scores))
             output_fhandle.write("\n")
@@ -76,30 +85,43 @@ def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0):

 if __name__ == "__main__":
     import argparse
+
     parser = argparse.ArgumentParser(usage="evaluate <command> [<args>]")
     parser.add_argument(
         "--est-dir",
         type=str,
-        help="Path to the directory containing either the transcribed MIDI files or WAV files to be transcribed."
+        help="Path to the directory containing either the transcribed MIDI files or WAV files to be transcribed.",
     )
     parser.add_argument(
         "--ref-dir",
         type=str,
-        help="Path to the directory containing the reference files (we'll use gold MIDI for mir_eval, WAV for dtw)."
+        help="Path to the directory containing the reference files (we'll use gold MIDI for mir_eval, WAV for dtw).",
     )
     parser.add_argument(
-        '--output-stats-file',
+        "--output-stats-file",
         default=None,
-        type=str, help="Path to the file to save the evaluation stats"
+        type=str,
+        help="Path to the file to save the evaluation stats",
     )

     # add mir_eval and dtw subparsers
     subparsers = parser.add_subparsers(help="sub-command help")
-    mir_eval_parse = subparsers.add_parser("run_mir_eval", help="Run standard mir_eval evaluation on MAESTRO test set.")
-    mir_eval_parse.add_argument('--shift', type=int, default=0, help="Shift to apply to the estimated pitches.")
+    mir_eval_parse = subparsers.add_parser(
+        "run_mir_eval",
+        help="Run standard mir_eval evaluation on MAESTRO test set.",
+    )
+    mir_eval_parse.add_argument(
+        "--shift",
+        type=int,
+        default=0,
+        help="Shift to apply to the estimated pitches.",
+    )

     # to come
-    dtw_eval_parse = subparsers.add_parser("run_dtw", help="Run dynamic time warping evaluation on a specified dataset.")
+    dtw_eval_parse = subparsers.add_parser(
+        "run_dtw",
+        help="Run dynamic time warping evaluation on a specified dataset.",
+    )

     args = parser.parse_args()
     if not hasattr(args, "command"):
@@ -112,6 +134,8 @@ def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0):
     # -> We expect that baseline methods will fall flat on these, while aria-amt will be OK.

     if args.command == "run_mir_eval":
-        evaluate_mir_eval(args.est_dir, args.ref_dir, args.output_stats_file, args.shift)
+        evaluate_mir_eval(
+            args.est_dir, args.ref_dir, args.output_stats_file, args.shift
+        )
     elif args.command == "run_dtw":
         pass
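The changes to this file are purely formatting; behavior is unchanged. For reference, a hypothetical direct call to the entry point, with directory names invented for illustration:

from amt.evaluate import evaluate_mir_eval

# Pairs estimated .mid files with reference .mid/.midi files and, when
# output_stats_file is given, writes one JSON dict of mir_eval scores per line
evaluate_mir_eval(
    est_dir="transcriptions/",
    ref_dir="maestro_test/",
    output_stats_file="stats.jsonl",
    est_shift=0,
)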

amt/infer.py (+4)
@@ -23,6 +23,7 @@
2323

2424
# TODO: Profile and fix gpu util
2525

26+
2627
def calculate_vel(
2728
logits: torch.Tensor,
2829
init_vel: int,
@@ -89,6 +90,8 @@ def calculate_onset(
8990

9091
from functools import wraps
9192
from torch.cuda import is_bf16_supported
93+
94+
9295
def optional_bf16_autocast(func):
9396
@wraps(func)
9497
def wrapper(*args, **kwargs):
@@ -100,6 +103,7 @@ def wrapper(*args, **kwargs):
100103
# Call the function with float16 if bfloat16 is not supported
101104
with torch.autocast("cuda", dtype=torch.float16):
102105
return func(*args, **kwargs)
106+
103107
return wrapper
104108

105109

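The hunks in this file only add blank lines, but the context around optional_bf16_autocast shows the pattern: autocast to bfloat16 when the GPU supports it, otherwise fall back to float16. A self-contained sketch; the bfloat16 branch is inferred from the function name and the fallback comment, since the diff shows only the float16 path:

from functools import wraps

import torch
from torch.cuda import is_bf16_supported

def optional_bf16_autocast(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        if is_bf16_supported():
            with torch.autocast("cuda", dtype=torch.bfloat16):
                return func(*args, **kwargs)
        # Call the function with float16 if bfloat16 is not supported
        with torch.autocast("cuda", dtype=torch.float16):
            return func(*args, **kwargs)

    return wrapper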
amt/run.py (+25 -12)
@@ -25,19 +25,25 @@ def _add_maestro_args(subparser):

 def _add_transcribe_args(subparser):
     subparser.add_argument("model_name", help="name of model config file")
-    subparser.add_argument('checkpoint_path', help="checkpoint path")
-    subparser.add_argument("-load_path", help="path to mp3/wav file", required=False)
+    subparser.add_argument("checkpoint_path", help="checkpoint path")
+    subparser.add_argument(
+        "-load_path", help="path to mp3/wav file", required=False
+    )
     subparser.add_argument(
         "-load_dir", help="dir containing mp3/wav files", required=False
     )
-    subparser.add_argument("-save_dir", help="dir to save midi files", required=True)
+    subparser.add_argument(
+        "-save_dir", help="dir to save midi files", required=True
+    )
     subparser.add_argument(
         "-multi_gpu", help="use all GPUs", action="store_true", default=False
     )
     subparser.add_argument("-bs", help="batch size", type=int, default=16)


-def build_maestro(maestro_dir, maestro_csv_file, train_file, val_file, test_file, num_procs):
+def build_maestro(
+    maestro_dir, maestro_csv_file, train_file, val_file, test_file, num_procs
+):
     from amt.data import AmtDataset

     assert os.path.isdir(maestro_dir), "MAESTRO directory not found"
@@ -101,9 +107,14 @@ def build_maestro(maestro_dir, maestro_csv_file, train_file, val_file, test_file


 def transcribe(
-    model_name, checkpoint_path, save_dir, load_path=None, load_dir=None,
-    batch_size=16, multi_gpu=False,
-    augment=None,
+    model_name,
+    checkpoint_path,
+    save_dir,
+    load_path=None,
+    load_dir=None,
+    batch_size=16,
+    multi_gpu=False,
+    augment=None,
 ):
     """
     Transcribe audio files to midi using the given model and checkpoint.
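The signature reflow makes the full argument list easy to scan. A hypothetical call, with paths and model name invented for illustration:

from amt.run import transcribe

transcribe(
    model_name="medium",             # name of a model config file
    checkpoint_path="ckpt/model.pt",
    save_dir="out_midi/",
    load_dir="recordings/",          # or load_path="song.mp3" for a single file
    batch_size=16,
    multi_gpu=False,
)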
@@ -139,9 +150,7 @@ def transcribe(
     assert os.path.isfile(checkpoint_path), "model checkpoint file not found"
     assert load_path or load_dir, "must give either load path or dir"
     if load_path:
-        assert os.path.isfile(
-            load_path
-        ), f"audio file not found: {load_path}"
+        assert os.path.isfile(load_path), f"audio file not found: {load_path}"
         trans_mode = "single"
     if load_dir:
         assert os.path.isdir(load_dir), "load directory doesn't exist"
@@ -232,8 +241,12 @@ def main():
     parser = argparse.ArgumentParser(usage="amt <command> [<args>]")
     subparsers = parser.add_subparsers(help="sub-command help")
     # add maestro and transcribe subparsers
-    subparser_maestro = subparsers.add_parser("maestro", help="Commands to build the maestro dataset.")
-    subparser_transcribe = subparsers.add_parser("transcribe", help="Commands to run transcription.")
+    subparser_maestro = subparsers.add_parser(
+        "maestro", help="Commands to build the maestro dataset."
+    )
+    subparser_transcribe = subparsers.add_parser(
+        "transcribe", help="Commands to run transcription."
+    )
     _add_maestro_args(subparser_maestro)
     _add_transcribe_args(subparser_transcribe)
