1
1
import mmap
2
2
import os
3
3
import io
4
+ import random
5
+ import shlex
4
6
import base64
5
7
import shutil
6
8
import orjson
16
18
from amt .audio import pad_or_trim
17
19
18
20
19
- # Occasionally the worker util goes to 0 for some reason, debug this
21
+ def _check_onset_threshold (seq : list , onset : int ):
22
+ for tok_1 , tok_2 in zip (seq , seq [1 :]):
23
+ if isinstance (tok_1 , tuple ) and tok_1 [0 ] in ("on" , "off" ):
24
+ _onset = tok_2 [1 ]
25
+ if _onset > onset :
26
+ return True
27
+
28
+ return False
20
29
21
30
22
31
def get_wav_mid_segments (
23
32
audio_path : str ,
24
33
mid_path : str = "" ,
25
34
return_json : bool = False ,
26
35
stride_factor : int | None = None ,
36
+ pad_last = False ,
27
37
):
28
38
"""This function yields tuples of matched log mel spectrograms and
29
39
tokenized sequences (np.array, list). If it is given only an audio path
@@ -61,10 +71,12 @@ def get_wav_mid_segments(
61
71
62
72
# Create features
63
73
total_samples = wav .shape [- 1 ]
74
+ pad_factor = 2 if pad_last is True else 1
64
75
res = []
65
76
for idx in range (
66
77
0 ,
67
- total_samples - (num_samples - num_samples // stride_factor ),
78
+ total_samples
79
+ - (num_samples - pad_factor * (num_samples // stride_factor )),
68
80
num_samples // stride_factor ,
69
81
):
70
82
audio_feature = pad_or_trim (wav [idx :], length = num_samples )
@@ -75,6 +87,12 @@ def get_wav_mid_segments(
75
87
end_ms = (idx + num_samples ) / samples_per_ms ,
76
88
max_pedal_len_ms = 10000 ,
77
89
)
90
+
91
+ # Hardcoded to 2.5s
92
+ if _check_onset_threshold (mid_feature , 2500 ) is False :
93
+ print ("No note messages after 2.5s - skipping" )
94
+ continue
95
+
78
96
else :
79
97
mid_feature = []
80
98
@@ -86,6 +104,56 @@ def get_wav_mid_segments(
86
104
return res
87
105
88
106
107
+ def pianoteq_cmd_fn (mid_path : str , wav_path : str ):
108
+ presets = [
109
+ "C. Bechstein" ,
110
+ "C. Bechstein Close Mic" ,
111
+ "C. Bechstein Under Lid" ,
112
+ "C. Bechstein 440" ,
113
+ "C. Bechstein Recording" ,
114
+ "C. Bechstein Werckmeister III" ,
115
+ "C. Bechstein Neidhardt III" ,
116
+ "C. Bechstein mesotonic" ,
117
+ "C. Bechstein well tempered" ,
118
+ "HB Steinway D Blues" ,
119
+ "HB Steinway D Pop" ,
120
+ "HB Steinway D New Age" ,
121
+ "HB Steinway D Prelude" ,
122
+ "HB Steinway D Felt I" ,
123
+ "HB Steinway D Felt II" ,
124
+ "HB Steinway Model D" ,
125
+ "HB Steinway D Classical Recording" ,
126
+ "HB Steinway D Jazz Recording" ,
127
+ "HB Steinway D Chamber Recording" ,
128
+ "HB Steinway D Studio Recording" ,
129
+ "HB Steinway D Intimate" ,
130
+ "HB Steinway D Cinematic" ,
131
+ "HB Steinway D Close Mic Classical" ,
132
+ "HB Steinway D Close Mic Jazz" ,
133
+ "HB Steinway D Player Wide" ,
134
+ "HB Steinway D Player Clean" ,
135
+ "HB Steinway D Trio" ,
136
+ "HB Steinway D Duo" ,
137
+ "HB Steinway D Cabaret" ,
138
+ "HB Steinway D Bright" ,
139
+ "HB Steinway D Hyper Bright" ,
140
+ "HB Steinway D Prepared" ,
141
+ "HB Steinway D Honky Tonk" ,
142
+ ]
143
+
144
+ preset = random .choice (presets )
145
+
146
+ # Safely quote the preset name, MIDI path, and WAV path
147
+ safe_preset = shlex .quote (preset )
148
+ safe_mid_path = shlex .quote (mid_path )
149
+ safe_wav_path = shlex .quote (wav_path )
150
+
151
+ # Construct the command
152
+ command = f"/home/mchorse/pianoteq/x86-64bit/Pianoteq\\ 8\\ STAGE --preset { safe_preset } --midi { safe_mid_path } --wav { safe_wav_path } "
153
+
154
+ return command
155
+
156
+
89
157
def write_features (audio_path : str , mid_path : str , save_path : str ):
90
158
features = get_wav_mid_segments (
91
159
audio_path = audio_path ,
@@ -121,7 +189,7 @@ def write_synth_features(cli_cmd_fn: Callable, mid_path: str, save_path: str):
121
189
122
190
try :
123
191
get_synth_audio (
124
- cli_cmd = cli_cmd_fn , mid_path = mid_path , wav_path = audio_path_temp
192
+ cli_cmd_fn = cli_cmd_fn , mid_path = mid_path , wav_path = audio_path_temp
125
193
)
126
194
except :
127
195
if os .path .isfile (audio_path_temp ):
@@ -133,7 +201,11 @@ def write_synth_features(cli_cmd_fn: Callable, mid_path: str, save_path: str):
133
201
mid_path = mid_path ,
134
202
return_json = False ,
135
203
)
136
- os .remove (audio_path_temp )
204
+
205
+ if os .path .isfile (audio_path_temp ):
206
+ os .remove (audio_path_temp )
207
+
208
+ print (f"Found { len (features )} " )
137
209
138
210
with open (save_path , mode = "a" ) as file :
139
211
for wav , seq in features :
@@ -174,7 +246,11 @@ def build_synth_worker_fn(
174
246
175
247
while not load_path_queue .empty ():
176
248
mid_path = load_path_queue .get ()
177
- write_synth_features (cli_cmd , mid_path , worker_save_path )
249
+ try :
250
+ write_synth_features (cli_cmd , mid_path , worker_save_path )
251
+ except Exception as e :
252
+ print ("Failed" )
253
+ print (e )
178
254
179
255
save_path_queue .put (worker_save_path )
180
256
@@ -239,7 +315,7 @@ def _format(tok):
239
315
seq_len = self .config ["max_seq_len" ],
240
316
)
241
317
242
- return wav , self .tokenizer .encode (src ), self .tokenizer .encode (tgt )
318
+ return wav , self .tokenizer .encode (src ), self .tokenizer .encode (tgt ), idx
243
319
244
320
def _build_index (self ):
245
321
self .file_mmap .seek (0 )
@@ -254,7 +330,7 @@ def _build_index(self):
254
330
255
331
return index
256
332
257
- def _save_index (self , index : list [ int ] , save_path : str ):
333
+ def _save_index (self , index : list , save_path : str ):
258
334
with open (save_path , "w" ) as file :
259
335
for idx in index :
260
336
file .write (f"{ idx } \n " )
@@ -325,7 +401,7 @@ def build(
325
401
]
326
402
else :
327
403
# Build synthetic dataset
328
- assert len (load_paths [0 ]) == 1 , "Invalid load paths"
404
+ assert isinstance (load_paths [0 ], str ) , "Invalid load paths"
329
405
print ("Building synthetic dataset" )
330
406
worker_processes = [
331
407
Process (
0 commit comments