
Commit 1c9d666

Fix training and inference (#24)
* fix
* fix batched
* adjust
* transfer
* add cleanup
* working
* move_to_node
* fix config
* fix msg
* add synth dataset gen
* remote changes
* local changes
* add scripts
* fix audio params
1 parent 50f0b60 commit 1c9d666

21 files changed (+1064 −302 lines)

.gitignore (−1)

@@ -1,6 +1,5 @@
 # data files
 *.csv
-*.json
 *.xls
 *.xlsx
 *.pkl

amt/audio.py (+13 −8)

@@ -191,15 +191,15 @@ def __init__(
         max_snr: int = 50,
         max_dist_gain: int = 25,
         min_dist_gain: int = 0,
-        noise_ratio: float = 0.95,
-        reverb_ratio: float = 0.95,
+        noise_ratio: float = 0.75,
+        reverb_ratio: float = 0.75,
         applause_ratio: float = 0.01,
         bandpass_ratio: float = 0.15,
         distort_ratio: float = 0.15,
         reduce_ratio: float = 0.01,
-        detune_ratio: float = 0.1,
-        detune_max_shift: float = 0.15,
-        spec_aug_ratio: float = 0.5,
+        detune_ratio: float = 0.0,
+        detune_max_shift: float = 0.0,
+        spec_aug_ratio: float = 0.9,
     ):
         super().__init__()
         self.tokenizer = AmtTokenizer()
@@ -223,7 +223,10 @@ def __init__(
         self.detune_ratio = detune_ratio
         self.detune_max_shift = detune_max_shift
         self.spec_aug_ratio = spec_aug_ratio
-        self.reduction_resample_rate = 6000  # Hardcoded?
+
+        self.time_mask_param = 2500
+        self.freq_mask_param = 15
+        self.reduction_resample_rate = 6000

         # Audio aug
         impulse_paths = self._get_paths(
@@ -263,10 +266,10 @@ def __init__(
         )
         self.spec_aug = torch.nn.Sequential(
             torchaudio.transforms.FrequencyMasking(
-                freq_mask_param=15, iid_masks=True
+                freq_mask_param=self.freq_mask_param, iid_masks=True
             ),
             torchaudio.transforms.TimeMasking(
-                time_mask_param=1000, iid_masks=True
+                time_mask_param=self.time_mask_param, iid_masks=True
             ),
         )

@@ -281,6 +284,8 @@ def get_params(self):
             "detune_ratio": self.detune_ratio,
             "detune_max_shift": self.detune_max_shift,
             "spec_aug_ratio": self.spec_aug_ratio,
+            "time_mask_param": self.time_mask_param,
+            "freq_mask_param": self.freq_mask_param,
         }

     def _get_paths(self, dir_path):
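Note: this change shifts the default augmentation mix (noise and reverb drop from 0.95 to 0.75, detune is disabled, spec_aug_ratio rises to 0.9) and moves the masking parameters into named attributes so get_params() reports them. Below is a minimal standalone sketch of just the masking stage; the tensor shape (batch of 8, 1 channel, 80 mel bins, 3000 frames) is an illustrative assumption, not a value taken from this repo.

import torch
import torchaudio

# Illustrative shapes only; not values from this repository.
mel = torch.randn(8, 1, 80, 3000)

spec_aug = torch.nn.Sequential(
    torchaudio.transforms.FrequencyMasking(freq_mask_param=15, iid_masks=True),
    torchaudio.transforms.TimeMasking(time_mask_param=2500, iid_masks=True),
)

# Zeroes one random band of up to 15 mel bins and one random span of up to
# 2500 frames; with iid_masks=True each example in the batch gets its own mask.
augmented = spec_aug(mel)
print(augmented.shape)  # torch.Size([8, 1, 80, 3000])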

amt/data.py (+84 −8)

@@ -1,6 +1,8 @@
 import mmap
 import os
 import io
+import random
+import shlex
 import base64
 import shutil
 import orjson
@@ -16,14 +18,22 @@
 from amt.audio import pad_or_trim


-# Occasionally the worker util goes to 0 for some reason, debug this
+def _check_onset_threshold(seq: list, onset: int):
+    for tok_1, tok_2 in zip(seq, seq[1:]):
+        if isinstance(tok_1, tuple) and tok_1[0] in ("on", "off"):
+            _onset = tok_2[1]
+            if _onset > onset:
+                return True
+
+    return False


 def get_wav_mid_segments(
     audio_path: str,
     mid_path: str = "",
     return_json: bool = False,
     stride_factor: int | None = None,
+    pad_last=False,
 ):
     """This function yields tuples of matched log mel spectrograms and
     tokenized sequences (np.array, list). If it is given only an audio path
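The new _check_onset_threshold helper scans a tokenized sequence for a note-on or note-off token whose following onset token falls after the given millisecond threshold, so segments that go silent early can be skipped. A rough usage sketch; the exact tuple layout of the tokenizer's tokens is assumed here for illustration and is not shown in this diff.

from amt.data import _check_onset_threshold

# Hypothetical token sequences: each ("on"/"off", pitch) token is assumed to be
# followed by an ("onset", time_in_ms) token, matching the zip(seq, seq[1:]) scan.
active = [("on", 60), ("onset", 3100), ("off", 60), ("onset", 4200)]
silent_tail = [("on", 60), ("onset", 500), ("off", 60), ("onset", 900)]

print(_check_onset_threshold(active, 2500))       # True: a note event occurs after 2.5s
print(_check_onset_threshold(silent_tail, 2500))  # False: nothing happens after 2.5s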
@@ -61,10 +71,12 @@ def get_wav_mid_segments(

     # Create features
     total_samples = wav.shape[-1]
+    pad_factor = 2 if pad_last is True else 1
     res = []
     for idx in range(
         0,
-        total_samples - (num_samples - num_samples // stride_factor),
+        total_samples
+        - (num_samples - pad_factor * (num_samples // stride_factor)),
         num_samples // stride_factor,
     ):
         audio_feature = pad_or_trim(wav[idx:], length=num_samples)
@@ -75,6 +87,12 @@ def get_wav_mid_segments(
                 end_ms=(idx + num_samples) / samples_per_ms,
                 max_pedal_len_ms=10000,
            )
+
+            # Hardcoded to 2.5s
+            if _check_onset_threshold(mid_feature, 2500) is False:
+                print("No note messages after 2.5s - skipping")
+                continue
+
        else:
            mid_feature = []

@@ -86,6 +104,56 @@ def get_wav_mid_segments(
     return res


+def pianoteq_cmd_fn(mid_path: str, wav_path: str):
+    presets = [
+        "C. Bechstein",
+        "C. Bechstein Close Mic",
+        "C. Bechstein Under Lid",
+        "C. Bechstein 440",
+        "C. Bechstein Recording",
+        "C. Bechstein Werckmeister III",
+        "C. Bechstein Neidhardt III",
+        "C. Bechstein mesotonic",
+        "C. Bechstein well tempered",
+        "HB Steinway D Blues",
+        "HB Steinway D Pop",
+        "HB Steinway D New Age",
+        "HB Steinway D Prelude",
+        "HB Steinway D Felt I",
+        "HB Steinway D Felt II",
+        "HB Steinway Model D",
+        "HB Steinway D Classical Recording",
+        "HB Steinway D Jazz Recording",
+        "HB Steinway D Chamber Recording",
+        "HB Steinway D Studio Recording",
+        "HB Steinway D Intimate",
+        "HB Steinway D Cinematic",
+        "HB Steinway D Close Mic Classical",
+        "HB Steinway D Close Mic Jazz",
+        "HB Steinway D Player Wide",
+        "HB Steinway D Player Clean",
+        "HB Steinway D Trio",
+        "HB Steinway D Duo",
+        "HB Steinway D Cabaret",
+        "HB Steinway D Bright",
+        "HB Steinway D Hyper Bright",
+        "HB Steinway D Prepared",
+        "HB Steinway D Honky Tonk",
+    ]
+
+    preset = random.choice(presets)
+
+    # Safely quote the preset name, MIDI path, and WAV path
+    safe_preset = shlex.quote(preset)
+    safe_mid_path = shlex.quote(mid_path)
+    safe_wav_path = shlex.quote(wav_path)
+
+    # Construct the command
+    command = f"/home/mchorse/pianoteq/x86-64bit/Pianoteq\\ 8\\ STAGE --preset {safe_preset} --midi {safe_mid_path} --wav {safe_wav_path}"
+
+    return command
+
+
 def write_features(audio_path: str, mid_path: str, save_path: str):
     features = get_wav_mid_segments(
         audio_path=audio_path,
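pianoteq_cmd_fn picks a random Pianoteq preset and builds a shell command with shlex.quote, so preset names containing spaces survive shell parsing; the binary path is hardcoded to one machine in this diff. A minimal sketch of what the quoting produces; the MIDI and WAV paths below are placeholders.

import shlex

# Placeholder paths for illustration; the Pianoteq binary location is
# machine-specific, as in the diff above.
preset = "HB Steinway D Felt I"
mid_path = "/tmp/example.mid"
wav_path = "/tmp/example.wav"

command = (
    "/home/mchorse/pianoteq/x86-64bit/Pianoteq\\ 8\\ STAGE "
    f"--preset {shlex.quote(preset)} --midi {shlex.quote(mid_path)} --wav {shlex.quote(wav_path)}"
)
print(command)
# ... --preset 'HB Steinway D Felt I' --midi /tmp/example.mid --wav /tmp/example.wav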
@@ -121,7 +189,7 @@ def write_synth_features(cli_cmd_fn: Callable, mid_path: str, save_path: str):

     try:
         get_synth_audio(
-            cli_cmd=cli_cmd_fn, mid_path=mid_path, wav_path=audio_path_temp
+            cli_cmd_fn=cli_cmd_fn, mid_path=mid_path, wav_path=audio_path_temp
         )
     except:
         if os.path.isfile(audio_path_temp):
@@ -133,7 +201,11 @@ def write_synth_features(cli_cmd_fn: Callable, mid_path: str, save_path: str):
         mid_path=mid_path,
         return_json=False,
     )
-    os.remove(audio_path_temp)
+
+    if os.path.isfile(audio_path_temp):
+        os.remove(audio_path_temp)
+
+    print(f"Found {len(features)}")

     with open(save_path, mode="a") as file:
         for wav, seq in features:
@@ -174,7 +246,11 @@ def build_synth_worker_fn(

     while not load_path_queue.empty():
         mid_path = load_path_queue.get()
-        write_synth_features(cli_cmd, mid_path, worker_save_path)
+        try:
+            write_synth_features(cli_cmd, mid_path, worker_save_path)
+        except Exception as e:
+            print("Failed")
+            print(e)

     save_path_queue.put(worker_save_path)
@@ -239,7 +315,7 @@ def _format(tok):
             seq_len=self.config["max_seq_len"],
         )

-        return wav, self.tokenizer.encode(src), self.tokenizer.encode(tgt)
+        return wav, self.tokenizer.encode(src), self.tokenizer.encode(tgt), idx

     def _build_index(self):
         self.file_mmap.seek(0)
@@ -254,7 +330,7 @@ def _build_index(self):

         return index

-    def _save_index(self, index: list[int], save_path: str):
+    def _save_index(self, index: list, save_path: str):
         with open(save_path, "w") as file:
             for idx in index:
                 file.write(f"{idx}\n")
@@ -325,7 +401,7 @@ def build(
             ]
         else:
             # Build synthetic dataset
-            assert len(load_paths[0]) == 1, "Invalid load paths"
+            assert isinstance(load_paths[0], str), "Invalid load paths"
            print("Building synthetic dataset")
            worker_processes = [
                Process(
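With __getitem__ now returning the sample index alongside the waveform and the encoded source/target sequences, a training loop can trace a batch back to its position in the backing file. A rough unpacking sketch; the dataset class name and constructor arguments below are placeholders, not names taken from this diff.

from torch.utils.data import DataLoader

# `AmtDataset("train.jsonl")` is a placeholder; the actual class name and
# constructor signature are not shown in this diff.
dataset = AmtDataset("train.jsonl")
loader = DataLoader(dataset, batch_size=16, shuffle=True)

for wav, src, tgt, idx in loader:
    # idx identifies each sample's position in the dataset file, which is
    # useful for logging or skipping bad entries.
    ...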

amt/evaluate.py (+27 −51)

@@ -42,27 +42,32 @@ def midi_to_hz(note, shift=0):
     # return (a / 32) * (2 ** ((note - 9) / 12))


+def get_matched_files(est_dir: str, ref_dir: str):
+    # We assume that the files have the same path relative to their directory
+
+    res = []
+    est_paths = glob.glob(os.path.join(est_dir, "**/*.mid"), recursive=True)
+    print(f"found {len(est_paths)} est files")
+
+    for est_path in est_paths:
+        est_rel_path = os.path.relpath(est_path, est_dir)
+        ref_path = os.path.join(
+            ref_dir, os.path.splitext(est_rel_path)[0] + ".midi"
+        )
+        if os.path.isfile(ref_path):
+            res.append((est_path, ref_path))
+
+    print(f"found {len(res)} matched est-ref pairs")
+
+    return res
+
+
 def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0):
     """
     Evaluate the estimated pitches against the reference pitches using mir_eval.
     """
-    # Evaluate the estimated pitches against the reference pitches
-    ref_midi_files = glob.glob(f"{ref_dir}/*.mid*")
-    est_midi_files = glob.glob(f"{est_dir}/*.mid*")
-
-    est_ref_pairs = []
-    for est_fpath in est_midi_files:
-        ref_fpath = os.path.join(ref_dir, os.path.basename(est_fpath))
-        if ref_fpath in ref_midi_files:
-            est_ref_pairs.append((est_fpath, ref_fpath))
-        if ref_fpath.replace(".mid", ".midi") in ref_midi_files:
-            est_ref_pairs.append(
-                (est_fpath, ref_fpath.replace(".mid", ".midi"))
-            )
-        else:
-            print(
-                f"Reference file not found for {est_fpath} (ref file: {ref_fpath})"
-            )
+
+    est_ref_pairs = get_matched_files(est_dir, ref_dir)

     output_fhandle = (
         open(output_stats_file, "w") if output_stats_file is not None else None
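get_matched_files pairs estimated and reference MIDI files by their path relative to each directory (".mid" on the estimate side, ".midi" on the reference side), and the recursive glob means nested folders line up too. A small sketch of the matching it performs; the directory names below are placeholders.

from amt.evaluate import get_matched_files

# Placeholder directories for illustration. With a layout like
#   est/2023/piece.mid   and   ref/2023/piece.midi
# the relative path "2023/piece" lines up and the pair is returned.
pairs = get_matched_files("est", "ref")
for est_path, ref_path in pairs:
    print(est_path, "<->", ref_path)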
@@ -104,38 +109,9 @@ def evaluate_mir_eval(est_dir, ref_dir, output_stats_file=None, est_shift=0):
         help="Path to the file to save the evaluation stats",
     )

-    # add mir_eval and dtw subparsers
-    subparsers = parser.add_subparsers(help="sub-command help")
-    mir_eval_parse = subparsers.add_parser(
-        "run_mir_eval",
-        help="Run standard mir_eval evaluation on MAESTRO test set.",
-    )
-    mir_eval_parse.add_argument(
-        "--shift",
-        type=int,
-        default=0,
-        help="Shift to apply to the estimated pitches.",
-    )
-
-    # to come
-    dtw_eval_parse = subparsers.add_parser(
-        "run_dtw",
-        help="Run dynamic time warping evaluation on a specified dataset.",
-    )
-
     args = parser.parse_args()
-    if not hasattr(args, "command"):
-        parser.print_help()
-        print("Unrecognized command")
-        exit(1)
-
-    # todo: should we add an option to run transcription again every time we wish to evaluate?
-    # that way, we can run both tests with a range of different audio augmentations right here.
-    # -> We expect that baseline methods will fall flat on these, while aria-amt will be OK.
-
-    if args.command == "run_mir_eval":
-        evaluate_mir_eval(
-            args.est_dir, args.ref_dir, args.output_stats_file, args.shift
-        )
-    elif args.command == "run_dtw":
-        pass
+    evaluate_mir_eval(
+        args.est_dir,
+        args.ref_dir,
+        args.output_stats_file,
+    )
