Commit d56e8e5

Add pedal msgs to tokenizer (#18)
* add pedal msgs to tokenizer
* fix eos token
* format
* improve inference
1 parent 12d249b commit d56e8e5

File tree: 3 files changed, +158 −52 lines changed
amt/infer.py (+51 −26)

@@ -7,19 +7,22 @@
 
 from torch.multiprocessing import Queue
 from tqdm import tqdm
+from functools import wraps
+from torch.cuda import is_bf16_supported
 
 from amt.model import AmtEncoderDecoder
 from amt.tokenizer import AmtTokenizer
-from amt.audio import AudioTransform
+from amt.audio import AudioTransform, pad_or_trim
 from amt.data import get_wav_mid_segments
 
+
 MAX_SEQ_LEN = 4096
 LEN_MS = 30000
 STRIDE_FACTOR = 3
 CHUNK_LEN_MS = LEN_MS // STRIDE_FACTOR
-BEAM = 3
-ONSET_TOLERANCE = 50
-VEL_TOLERANCE = 50
+BEAM = 5
+ONSET_TOLERANCE = 61
+VEL_TOLERANCE = 100
 
 
 def _setup_logger():
@@ -105,10 +108,6 @@ def calculate_onset(
    return tokenizer.tok_to_id[("onset", new_onset)]
 
 
-from functools import wraps
-from torch.cuda import is_bf16_supported
-
-
 def optional_bf16_autocast(func):
     @wraps(func)
     def wrapper(*args, **kwargs):
@@ -145,7 +144,7 @@ def process_segments(
         tokenizer.trunc_seq(prefix, MAX_SEQ_LEN) for prefix in raw_prefixes
     ]
     seq = torch.stack([tokenizer.encode(prefix) for prefix in prefixes]).cuda()
-    eos_seen = [False for _ in prefixes]
+    end_idxs = [MAX_SEQ_LEN for _ in prefixes]
 
     kv_cache = model.get_empty_cache()
 
@@ -173,7 +172,7 @@ def process_segments(
         next_tok_ids = torch.argmax(logits[:, -1], dim=-1)
 
         for batch_idx in range(logits.shape[0]):
-            if eos_seen[batch_idx] is not False:
+            if idx > end_idxs[batch_idx]:
                 # End already seen, add pad token
                 tok_id = tokenizer.pad_id
             elif idx >= prefix_lens[batch_idx]:
@@ -192,20 +191,24 @@ def process_segments(
                 tok_id = tokenizer.tok_to_id[prefixes[batch_idx][idx]]
 
             seq[batch_idx, idx] = tok_id
-            if tokenizer.id_to_tok[tok_id] == tokenizer.eos_tok:
-                eos_seen[batch_idx] = idx
-
-        if all(eos_seen):
+            tok = tokenizer.id_to_tok[tok_id]
+            if tok == tokenizer.eos_tok:
+                end_idxs[batch_idx] = idx
+            elif (
+                type(tok) is tuple
+                and tok[0] == "onset"
+                and tok[1] >= LEN_MS - CHUNK_LEN_MS
+            ):
+                end_idxs[batch_idx] = idx - 2
+
+        if all(_idx <= idx for _idx in end_idxs):
             break
 
-    if not all(eos_seen):
+    if not all(_idx <= idx for _idx in end_idxs):
         logger.warning("Context length overflow when transcribing segment")
-        for _idx in range(seq.shape[0]):
-            if eos_seen[_idx] == False:
-                eos_seen[_idx] = MAX_SEQ_LEN
 
     results = [
-        tokenizer.decode(seq[_idx, : eos_seen[_idx] + 1])
+        tokenizer.decode(seq[_idx, : end_idxs[_idx] + 1])
         for _idx in range(seq.shape[0])
     ]
 
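For context on the bookkeeping above: end_idxs replaces the old eos_seen flags, and a sequence now also counts as finished once it emits an onset inside the final CHUNK_LEN_MS of the window, which the next overlapping segment re-transcribes; the idx - 2 steps back over the token group that crossed the boundary. A minimal sketch of the stop rule, assuming the EOS token literal is "<E>" (only "<S>" is visible in this diff):

LEN_MS = 30000
CHUNK_LEN_MS = LEN_MS // 3  # STRIDE_FACTOR = 3

def is_end_token(tok) -> bool:
    # Stop on EOS, or on any onset falling in the overlap region
    # that the next (overlapping) segment will cover again.
    if tok == "<E>":
        return True
    return (
        isinstance(tok, tuple)
        and tok[0] == "onset"
        and tok[1] >= LEN_MS - CHUNK_LEN_MS
    )

assert is_end_token(("onset", 25000))      # inside the overlap region
assert not is_end_token(("onset", 5000))   # safely inside the window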
@@ -218,7 +221,7 @@ def gpu_manager(
     model: AmtEncoderDecoder,
     batch_size: int,
 ):
-    # model.compile()
+    model.compile()
     logger = _setup_logger()
     audio_transform = AudioTransform().cuda()
     tokenizer = AmtTokenizer(return_tensors=True)
@@ -283,7 +286,7 @@ def _truncate_seq(
     except:
         return ["<S>"]
     else:
-        return res[: res.index(tokenizer.eos_tok)]
+        return res[: res.index(tokenizer.eos_tok)]  # Needs to change
 
 
 def process_file(
@@ -302,8 +305,15 @@ def process_file(
             audio_path=file_path, stride_factor=STRIDE_FACTOR
         )
     ]
-    seq = ["<S>"]
-    res = ["<S>"]
+
+    # Add additional (padded) final audio segment
+    _last_seg = audio_segments[-1]
+    audio_segments.append(
+        pad_or_trim(_last_seg[len(_last_seg) // STRIDE_FACTOR :])
+    )
+
+    seq = [tokenizer.bos_tok]
+    res = [tokenizer.bos_tok]
     for idx, audio_seg in enumerate(audio_segments):
         init_idx = len(seq)
 
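The extra padded segment appended above keeps the tail of the recording away from a window boundary. A rough sketch of what this does, assuming pad_or_trim behaves like Whisper's helper of the same name (zero-pad or truncate to a fixed sample count; the sample rate and window length here are illustrative, not taken from this diff):

import torch
import torch.nn.functional as F

N_SAMPLES = 16000 * 30  # assumed: 30 s windows at 16 kHz

def pad_or_trim(audio: torch.Tensor, length: int = N_SAMPLES) -> torch.Tensor:
    # Zero-pad (or cut) the waveform to exactly `length` samples.
    if audio.shape[-1] >= length:
        return audio[..., :length]
    return F.pad(audio, (0, length - audio.shape[-1]))

last_seg = torch.randn(N_SAMPLES)
# Drop the first 1/STRIDE_FACTOR of the last window (already covered by the
# previous stride) and re-pad, mirroring the hunk above.
extra_seg = pad_or_trim(last_seg[N_SAMPLES // 3 :])
assert extra_seg.shape == (N_SAMPLES,)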
@@ -318,15 +328,18 @@ def process_file(
         result_queue.put(gpu_result)
 
         res += _shift_onset(
-            seq[init_idx : seq.index(tokenizer.eos_tok)],
+            seq[init_idx:],
             idx * CHUNK_LEN_MS,
         )
 
         if idx == len(audio_segments) - 1:
             break
+        elif res[-1] == tokenizer.eos_tok:
+            logger.info(f"Exiting early")
+            break
         else:
-            seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS)
-            if len(seq) == 1:
+            seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS - CHUNK_LEN_MS)
+            if len(seq) <= 2:
                 logger.info(f"Exiting early")
                 return res
 
@@ -441,3 +454,15 @@ def batch_transcribe(
         p.join()
 
     gpu_manager_process.join()
+
+
+def sample_top_p(probs, p):
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > p
+    probs_sort[mask] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    next_token = torch.multinomial(probs_sort, num_samples=1)
+    next_token = torch.gather(probs_idx, -1, next_token)
+
+    return next_token

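The new sample_top_p helper is standard nucleus sampling, though nothing in this diff calls it yet (process_segments still takes the argmax). A sketch of how it could slot in, with made-up shapes and a hypothetical temperature:

import torch

logits = torch.randn(4, 512)                  # (batch, vocab), illustrative
probs = torch.softmax(logits / 0.95, dim=-1)  # temperature 0.95 is made up
next_tok_ids = sample_top_p(probs, p=0.9).squeeze(-1)  # -> shape (4,)
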
amt/tokenizer.py (+90 −25)

@@ -46,6 +46,7 @@ def __init__(self, return_tensors: bool = False):
         self.prev_tokens = [("prev", i) for i in range(128)]
         self.note_on_tokens = [("on", i) for i in range(128)]
         self.note_off_tokens = [("off", i) for i in range(128)]
+        self.pedal_tokens = [("pedal", 0), ("pedal", 1)]
         self.velocity_tokens = [("vel", i) for i in self.velocity_quantizations]
         self.onset_tokens = [
             ("onset", i) for i in self.onset_time_quantizations
@@ -56,6 +57,7 @@ def __init__(self, return_tensors: bool = False):
             + self.prev_tokens
             + self.note_on_tokens
             + self.note_off_tokens
+            + self.pedal_tokens
             + self.velocity_tokens
             + self.onset_tokens
         )
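A quick sanity check that the two new tokens land in the vocabulary; this assumes tok_to_id is built from the token list above (the lookup table is used elsewhere in this commit, so that much is grounded):

from amt.tokenizer import AmtTokenizer

tokenizer = AmtTokenizer()
assert ("pedal", 0) in tokenizer.tok_to_id
assert ("pedal", 1) in tokenizer.tok_to_id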
@@ -76,7 +78,10 @@ def _quantize_velocity(self, velocity: int):
         else:
             return velocity_quantized
 
-    # This method needs to be cleaned up completely, variables renamed
+    # TODO:
+    # - I need to make this method more robust, as it will have to handle
+    #   an arbitrary MIDI file
+    # - Decide whether to put pedal messages as prev tokens
     def _tokenize_midi_dict(
         self,
         midi_dict: MidiDict,
@@ -88,6 +93,12 @@ def _tokenize_midi_dict(
         ), "Invalid values for start_ms, end_ms"
 
         midi_dict.resolve_pedal()  # Important !!
+        pedal_intervals = midi_dict._build_pedal_intervals()
+        if len(pedal_intervals.keys()) > 1:
+            print("Warning: midi_dict has more than one pedal channel")
+        pedal_intervals = pedal_intervals[0]
+
+        last_msg_ms = -1
         on_off_notes = []
         prev_notes = []
         for msg in midi_dict.note_msgs:
@@ -109,6 +120,9 @@ def _tokenize_midi_dict(
                 ticks_per_beat=midi_dict.ticks_per_beat,
             )
 
+            if note_end_ms > last_msg_ms:
+                last_msg_ms = note_end_ms
+
             rel_note_start_ms_q = self._quantize_onset(note_start_ms - start_ms)
             rel_note_end_ms_q = self._quantize_onset(note_end_ms - start_ms)
             velocity_q = self._quantize_velocity(_velocity)
@@ -149,35 +163,70 @@ def _tokenize_midi_dict(
                     ("off", _pitch, rel_note_end_ms_q, None)
                 )
 
-        on_off_notes.sort(key=lambda x: (x[2], x[0] == "on"))
+        on_off_pedal = []
+        for pedal_on_tick, pedal_off_tick in pedal_intervals:
+            pedal_on_ms = get_duration_ms(
+                start_tick=0,
+                end_tick=pedal_on_tick,
+                tempo_msgs=midi_dict.tempo_msgs,
+                ticks_per_beat=midi_dict.ticks_per_beat,
+            )
+            pedal_off_ms = get_duration_ms(
+                start_tick=0,
+                end_tick=pedal_off_tick,
+                tempo_msgs=midi_dict.tempo_msgs,
+                ticks_per_beat=midi_dict.ticks_per_beat,
+            )
+
+            rel_on_ms_q = self._quantize_onset(pedal_on_ms - start_ms)
+            rel_off_ms_q = self._quantize_onset(pedal_off_ms - start_ms)
+
+            # On message
+            if pedal_on_ms <= start_ms or pedal_on_ms >= end_ms:
+                continue
+            else:
+                on_off_pedal.append(("pedal", 1, rel_on_ms_q, None))
+
+            # Off message
+            if pedal_off_ms <= start_ms or pedal_off_ms >= end_ms:
+                continue
+            else:
+                on_off_pedal.append(("pedal", 0, rel_off_ms_q, None))
+
+        on_off_combined = on_off_notes + on_off_pedal
+        on_off_combined.sort(
+            key=lambda x: (
+                x[2],
+                (0 if x[0] == "pedal" else 1 if x[0] == "off" else 2),
+            )
+        )
         random.shuffle(prev_notes)
 
         tokenized_seq = []
-        note_status = {}
-        for pitch in prev_notes:
-            note_status[pitch] = True
-        for note in on_off_notes:
-            _type, _pitch, _onset, _velocity = note
+        for tok in on_off_combined:
+            _type, _val, _onset, _velocity = tok
             if _type == "on":
-                if note_status.get(_pitch) == True:
-                    # Place holder - we can remove note_status logic now
-                    raise Exception
-
-                tokenized_seq.append(("on", _pitch))
+                tokenized_seq.append(("on", _val))
                 tokenized_seq.append(("onset", _onset))
                 tokenized_seq.append(("vel", _velocity))
-                note_status[_pitch] = True
             elif _type == "off":
-                if note_status.get(_pitch) == False:
-                    # Place holder - we can remove note_status logic now
-                    raise Exception
-                else:
-                    tokenized_seq.append(("off", _pitch))
+                tokenized_seq.append(("off", _val))
+                tokenized_seq.append(("onset", _onset))
+            elif _type == "pedal":
+                if _val == 0:
+                    tokenized_seq.append(("pedal", _val))
+                    tokenized_seq.append(("onset", _onset))
+                elif _val:
+                    tokenized_seq.append(("pedal", _val))
                     tokenized_seq.append(("onset", _onset))
-                note_status[_pitch] = False
 
         prefix = [("prev", p) for p in prev_notes]
-        return prefix + [self.bos_tok] + tokenized_seq + [self.eos_tok]
+
+        # Add eos_tok only if segment includes end of midi_dict
+        if last_msg_ms < end_ms:
+            return prefix + [self.bos_tok] + tokenized_seq + [self.eos_tok]
+        else:
+            return prefix + [self.bos_tok] + tokenized_seq
 
     def _detokenize_midi_dict(
         self,
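Putting the tokenization changes together, a hand-written example of what a tokenized segment containing a sustain-pedal press might look like. The values are illustrative, not produced by running the code; the sort key above orders pedal before off before on at equal onsets, and "<E>" is now appended only when the segment reaches the end of the MidiDict:

[
    ("prev", 60),                             # note held over from the previous segment
    "<S>",
    ("on", 64), ("onset", 100), ("vel", 45),  # note-on carries onset + velocity
    ("pedal", 1), ("onset", 1500),            # sustain pedal down
    ("off", 64), ("onset", 2000),             # note-off carries onset only
    ("pedal", 0), ("onset", 2500),            # sustain pedal up
    "<E>",                                    # only if last_msg_ms < end_ms
]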
@@ -243,16 +292,29 @@ def _detokenize_midi_dict(
                 print("Unexpected token order: 'prev' seen after '<S>'")
                 if DEBUG:
                     raise Exception
+            elif tok_1_type == "pedal":
+                # Pedal information is contained in note-off messages, so we
+                # don't need to manually process them
+                _pedal_data = tok_1_data
+                _tick = tok_2_data
+                note_msgs.append(
+                    {
+                        "type": "pedal",
+                        "data": _pedal_data,
+                        "tick": _tick,
+                        "channel": 0,
+                    }
+                )
             elif tok_1_type == "on":
                 if (tok_2_type, tok_3_type) != ("onset", "vel"):
-                    print("Unexpected token order")
+                    print("Unexpected token order:", tok_1, tok_2, tok_3)
                     if DEBUG:
                         raise Exception
                 else:
                     notes_to_close[tok_1_data] = (tok_2_data, tok_3_data)
             elif tok_1_type == "off":
                 if tok_2_type != "onset":
-                    print("Unexpected token order")
+                    print("Unexpected token order:", tok_1, tok_2, tok_3)
                     if DEBUG:
                         raise Exception
                 else:
@@ -336,9 +398,6 @@ def export_data_aug(self):
 
     def export_msg_mixup(self):
         def msg_mixup(src: list):
-            def round_to_base(n, base=150):
-                return base * round(n / base)
-
             # Process bos, eos, and pad tokens
             orig_len = len(src)
             seen_pad_tok = False
@@ -387,13 +446,19 @@ def round_to_base(n, base=150):
                 elif tok_1_type == "off":
                     _onset = tok_2_data
                     buffer[_onset]["off"].append((tok_1, tok_2))
+                elif tok_1_type == "pedal":
+                    _onset = tok_2_data
+                    buffer[_onset]["pedal"].append((tok_1, tok_2))
                 else:
                     pass
 
             # Shuffle order and re-append to result
             for k, v in sorted(buffer.items()):
                 random.shuffle(v["on"])
                 random.shuffle(v["off"])
+                for item in v["pedal"]:
+                    res.append(item[0])  # Pedal
+                    res.append(item[1])  # Onset
                 for item in v["off"]:
                     res.append(item[0])  # Pitch
                     res.append(item[1])  # Onset
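
For reference, msg_mixup buckets token pairs by quantized onset before re-emitting them; with the change above, pedal pairs keep their original order and come out ahead of the shuffled note-offs. A hypothetical bucket, assuming the nested-dict layout (e.g. a defaultdict of lists, which is not shown in this diff):

buffer = {
    1500: {
        "pedal": [(("pedal", 1), ("onset", 1500))],
        "off":   [(("off", 62), ("onset", 1500))],
        "on":    [(("on", 64), ("onset", 1500)), (("on", 67), ("onset", 1500))],
    }
}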
