@@ -46,6 +46,7 @@ def __init__(self, return_tensors: bool = False):
46
46
self .prev_tokens = [("prev" , i ) for i in range (128 )]
47
47
self .note_on_tokens = [("on" , i ) for i in range (128 )]
48
48
self .note_off_tokens = [("off" , i ) for i in range (128 )]
49
+ self .pedal_tokens = [("pedal" , 0 ), (("pedal" , 1 ))]
49
50
self .velocity_tokens = [("vel" , i ) for i in self .velocity_quantizations ]
50
51
self .onset_tokens = [
51
52
("onset" , i ) for i in self .onset_time_quantizations
@@ -56,6 +57,7 @@ def __init__(self, return_tensors: bool = False):
56
57
+ self .prev_tokens
57
58
+ self .note_on_tokens
58
59
+ self .note_off_tokens
60
+ + self .pedal_tokens
59
61
+ self .velocity_tokens
60
62
+ self .onset_tokens
61
63
)
@@ -76,7 +78,10 @@ def _quantize_velocity(self, velocity: int):
76
78
else :
77
79
return velocity_quantized
78
80
79
- # This method needs to be cleaned up completely, variables renamed
81
+ # TODO:
82
+ # - I need to make this method more robust, as it will have to handle
83
+ # an arbitrary MIDI file
84
+ # - Decide whether to put pedal messages as prev tokens
80
85
def _tokenize_midi_dict (
81
86
self ,
82
87
midi_dict : MidiDict ,
@@ -88,6 +93,12 @@ def _tokenize_midi_dict(
88
93
), "Invalid values for start_ms, end_ms"
89
94
90
95
midi_dict .resolve_pedal () # Important !!
96
+ pedal_intervals = midi_dict ._build_pedal_intervals ()
97
+ if len (pedal_intervals .keys ()) > 1 :
98
+ print ("Warning: midi_dict has more than one pedal channel" )
99
+ pedal_intervals = pedal_intervals [0 ]
100
+
101
+ last_msg_ms = - 1
91
102
on_off_notes = []
92
103
prev_notes = []
93
104
for msg in midi_dict .note_msgs :
@@ -109,6 +120,9 @@ def _tokenize_midi_dict(
109
120
ticks_per_beat = midi_dict .ticks_per_beat ,
110
121
)
111
122
123
+ if note_end_ms > last_msg_ms :
124
+ last_msg_ms = note_end_ms
125
+
112
126
rel_note_start_ms_q = self ._quantize_onset (note_start_ms - start_ms )
113
127
rel_note_end_ms_q = self ._quantize_onset (note_end_ms - start_ms )
114
128
velocity_q = self ._quantize_velocity (_velocity )
@@ -149,35 +163,70 @@ def _tokenize_midi_dict(
149
163
("off" , _pitch , rel_note_end_ms_q , None )
150
164
)
151
165
152
- on_off_notes .sort (key = lambda x : (x [2 ], x [0 ] == "on" ))
166
+ on_off_pedal = []
167
+ for pedal_on_tick , pedal_off_tick in pedal_intervals :
168
+ pedal_on_ms = get_duration_ms (
169
+ start_tick = 0 ,
170
+ end_tick = pedal_on_tick ,
171
+ tempo_msgs = midi_dict .tempo_msgs ,
172
+ ticks_per_beat = midi_dict .ticks_per_beat ,
173
+ )
174
+ pedal_off_ms = get_duration_ms (
175
+ start_tick = 0 ,
176
+ end_tick = pedal_off_tick ,
177
+ tempo_msgs = midi_dict .tempo_msgs ,
178
+ ticks_per_beat = midi_dict .ticks_per_beat ,
179
+ )
180
+
181
+ rel_on_ms_q = self ._quantize_onset (pedal_on_ms - start_ms )
182
+ rel_off_ms_q = self ._quantize_onset (pedal_off_ms - start_ms )
183
+
184
+ # On message
185
+ if pedal_on_ms <= start_ms or pedal_on_ms >= end_ms :
186
+ continue
187
+ else :
188
+ on_off_pedal .append (("pedal" , 1 , rel_on_ms_q , None ))
189
+
190
+ # Off message
191
+ if pedal_off_ms <= start_ms or pedal_off_ms >= end_ms :
192
+ continue
193
+ else :
194
+ on_off_pedal .append (("pedal" , 0 , rel_off_ms_q , None ))
195
+
196
+ on_off_combined = on_off_notes + on_off_pedal
197
+ on_off_combined .sort (
198
+ key = lambda x : (
199
+ x [2 ],
200
+ (0 if x [0 ] == "pedal" else 1 if x [0 ] == "off" else 2 ),
201
+ )
202
+ )
153
203
random .shuffle (prev_notes )
154
204
155
205
tokenized_seq = []
156
- note_status = {}
157
- for pitch in prev_notes :
158
- note_status [pitch ] = True
159
- for note in on_off_notes :
160
- _type , _pitch , _onset , _velocity = note
206
+ for tok in on_off_combined :
207
+ _type , _val , _onset , _velocity = tok
161
208
if _type == "on" :
162
- if note_status .get (_pitch ) == True :
163
- # Place holder - we can remove note_status logic now
164
- raise Exception
165
-
166
- tokenized_seq .append (("on" , _pitch ))
209
+ tokenized_seq .append (("on" , _val ))
167
210
tokenized_seq .append (("onset" , _onset ))
168
211
tokenized_seq .append (("vel" , _velocity ))
169
- note_status [_pitch ] = True
170
212
elif _type == "off" :
171
- if note_status .get (_pitch ) == False :
172
- # Place holder - we can remove note_status logic now
173
- raise Exception
174
- else :
175
- tokenized_seq .append (("off" , _pitch ))
213
+ tokenized_seq .append (("off" , _val ))
214
+ tokenized_seq .append (("onset" , _onset ))
215
+ elif _type == "pedal" :
216
+ if _val == 0 :
217
+ tokenized_seq .append (("pedal" , _val ))
218
+ tokenized_seq .append (("onset" , _onset ))
219
+ elif _val :
220
+ tokenized_seq .append (("pedal" , _val ))
176
221
tokenized_seq .append (("onset" , _onset ))
177
- note_status [_pitch ] = False
178
222
179
223
prefix = [("prev" , p ) for p in prev_notes ]
180
- return prefix + [self .bos_tok ] + tokenized_seq + [self .eos_tok ]
224
+
225
+ # Add eos_tok only if segment includes end of midi_dict
226
+ if last_msg_ms < end_ms :
227
+ return prefix + [self .bos_tok ] + tokenized_seq + [self .eos_tok ]
228
+ else :
229
+ return prefix + [self .bos_tok ] + tokenized_seq
181
230
182
231
def _detokenize_midi_dict (
183
232
self ,
@@ -243,16 +292,29 @@ def _detokenize_midi_dict(
243
292
print ("Unexpected token order: 'prev' seen after '<S>'" )
244
293
if DEBUG :
245
294
raise Exception
295
+ elif tok_1_type == "pedal" :
296
+ # Pedal information contained in note-off messages, so we don't
297
+ # need to manually process them
298
+ _pedal_data = tok_1_data
299
+ _tick = tok_2_data
300
+ note_msgs .append (
301
+ {
302
+ "type" : "pedal" ,
303
+ "data" : _pedal_data ,
304
+ "tick" : _tick ,
305
+ "channel" : 0 ,
306
+ }
307
+ )
246
308
elif tok_1_type == "on" :
247
309
if (tok_2_type , tok_3_type ) != ("onset" , "vel" ):
248
- print ("Unexpected token order" )
310
+ print ("Unexpected token order:" , tok_1 , tok_2 , tok_3 )
249
311
if DEBUG :
250
312
raise Exception
251
313
else :
252
314
notes_to_close [tok_1_data ] = (tok_2_data , tok_3_data )
253
315
elif tok_1_type == "off" :
254
316
if tok_2_type != "onset" :
255
- print ("Unexpected token order" )
317
+ print ("Unexpected token order:" , tok_1 , tok_2 , tok_3 )
256
318
if DEBUG :
257
319
raise Exception
258
320
else :
@@ -336,9 +398,6 @@ def export_data_aug(self):
336
398
337
399
def export_msg_mixup (self ):
338
400
def msg_mixup (src : list ):
339
- def round_to_base (n , base = 150 ):
340
- return base * round (n / base )
341
-
342
401
# Process bos, eos, and pad tokens
343
402
orig_len = len (src )
344
403
seen_pad_tok = False
@@ -387,13 +446,19 @@ def round_to_base(n, base=150):
387
446
elif tok_1_type == "off" :
388
447
_onset = tok_2_data
389
448
buffer [_onset ]["off" ].append ((tok_1 , tok_2 ))
449
+ elif tok_1_type == "pedal" :
450
+ _onset = tok_2_data
451
+ buffer [_onset ]["pedal" ].append ((tok_1 , tok_2 ))
390
452
else :
391
453
pass
392
454
393
455
# Shuffle order and re-append to result
394
456
for k , v in sorted (buffer .items ()):
395
457
random .shuffle (v ["on" ])
396
458
random .shuffle (v ["off" ])
459
+ for item in v ["pedal" ]:
460
+ res .append (item [0 ]) # Pedal
461
+ res .append (item [1 ]) # Onset
397
462
for item in v ["off" ]:
398
463
res .append (item [0 ]) # Pitch
399
464
res .append (item [1 ]) # Onset
0 commit comments