55import torch .multiprocessing as multiprocessing
66
77from torch .multiprocessing import Queue
8- from torch .cuda import device_count , is_available
98from tqdm import tqdm
109
1110from amt .model import AmtEncoderDecoder
1211from amt .tokenizer import AmtTokenizer
1312from amt .audio import AudioTransform
1413from amt .data import get_wav_mid_segments
15- from amt .config import load_config
16- from aria .data .midi import MidiDict
1714
1815MAX_SEQ_LEN = 4096
1916LEN_MS = 30000
2421VEL_TOLERANCE = 50
2522
2623
24+ # TODO: Profile and fix GPU utilization
25+
26+
2727def calculate_vel (
28- logits : torch .tensor ,
28+ logits : torch .Tensor ,
2929 init_vel : int ,
3030 tokenizer : AmtTokenizer = AmtTokenizer (),
3131):
@@ -51,13 +51,13 @@ def calculate_vel(
5151
5252 vels = torch .tensor (vels ).to (probs .device )
5353 new_vel = torch .sum (vels * probs ) / torch .sum (probs )
54- new_vel = round (new_vel .item () / 10 ) * 10
54+ new_vel = round (new_vel .item () / 5 ) * 5
5555
5656 return tokenizer .tok_to_id [("vel" , new_vel )]
5757
5858
5959def calculate_onset (
60- logits : torch .tensor ,
60+ logits : torch .Tensor ,
6161 init_onset : int ,
6262 tokenizer : AmtTokenizer = AmtTokenizer (),
6363):
@@ -88,6 +88,7 @@ def calculate_onset(
8888 return tokenizer .tok_to_id [("onset" , new_onset )]
8989
9090
91+ @torch .autocast ("cuda" , dtype = torch .bfloat16 )
9192def process_segments (
9293 tasks : list ,
9394 model : AmtEncoderDecoder ,
@@ -111,14 +112,14 @@ def process_segments(
111112
112113 kv_cache = model .get_empty_cache ()
113114
114- # for idx in (
115- # pbar := tqdm(
116- # range(min_prefix_len, MAX_SEQ_LEN - 1),
117- # total=MAX_SEQ_LEN - (min_prefix_len + 1),
118- # leave=False,
119- # )
120- # ):
121- for idx in range (min_prefix_len , MAX_SEQ_LEN - 1 ):
115+ for idx in (
116+ pbar := tqdm (
117+ range (min_prefix_len , MAX_SEQ_LEN - 1 ),
118+ total = MAX_SEQ_LEN - (min_prefix_len + 1 ),
119+ leave = False ,
120+ )
121+ ):
122+ # for idx in range(min_prefix_len, MAX_SEQ_LEN - 1):
122123 if idx == min_prefix_len :
123124 logits = model .decoder (
124125 xa = audio_features ,
@@ -160,6 +161,12 @@ def process_segments(
160161 if all (eos_seen ):
161162 break
162163
164+ if not all (eos_seen ):
165+ print ("WARNING: OVERFLOW" )
166+ for _idx in range (seq .shape [0 ]):
167+ if eos_seen [_idx ] == False :
168+ eos_seen [_idx ] = MAX_SEQ_LEN
169+
163170 results = [
164171 tokenizer .decode (seq [_idx , : eos_seen [_idx ] + 1 ])
165172 for _idx in range (seq .shape [0 ])
@@ -174,9 +181,7 @@ def gpu_manager(
174181 model : AmtEncoderDecoder ,
175182 batch_size : int ,
176183):
177- model .cuda ()
178- model .eval ()
179- model .compile ()
184+ # model.compile()
180185 audio_transform = AudioTransform ().cuda ()
181186 tokenizer = AmtTokenizer (return_tensors = True )
182187 process_pid = multiprocessing .current_process ().pid
@@ -277,9 +282,6 @@ def process_file(
277282 seq [init_idx : seq .index (tokenizer .eos_tok )],
278283 idx * CHUNK_LEN_MS ,
279284 )
280- print (
281- f"{ process_pid } : Finished { idx + 1 } /{ len (audio_segments )} audio segments"
282- )
283285
284286 if idx == len (audio_segments ) - 1 :
285287 break
@@ -310,8 +312,8 @@ def _get_save_path(_file_path: str):
310312 save_path = os .path .join (
311313 save_dir , os .path .splitext (input_rel_path )[0 ] + ".mid"
312314 )
313- if not os .path .exists (os .path .dirname (save_path )):
314- os .makedirs (os .path .dirname (save_path ))
315+ if not os .path .isdir (os .path .dirname (save_path )):
316+ os .makedirs (os .path .dirname (save_path ), exist_ok = True )
315317
316318 return save_path
317319
@@ -361,20 +363,15 @@ def batch_transcribe(
361363 if gpu_id is not None :
362364 os .environ ["CUDA_VISIBLE_DEVICES" ] = str (gpu_id )
363365
364- model .to ("cuda" )
366+ model .cuda ()
367+ model .eval ()
365368 file_queue = Queue ()
366369 for file_path in file_paths :
367370 file_queue .put (file_path )
368371
369372 gpu_task_queue = Queue ()
370373 result_queue = Queue ()
371374
372- gpu_manager_process = multiprocessing .Process (
373- target = gpu_manager ,
374- args = (gpu_task_queue , result_queue , model , batch_size ),
375- )
376- gpu_manager_process .start ()
377-
378375 worker_processes = [
379376 multiprocessing .Process (
380377 target = worker ,
@@ -391,6 +388,13 @@ def batch_transcribe(
391388 for p in worker_processes :
392389 p .start ()
393390
391+ time .sleep (10 )
392+ gpu_manager_process = multiprocessing .Process (
393+ target = gpu_manager ,
394+ args = (gpu_task_queue , result_queue , model , batch_size ),
395+ )
396+ gpu_manager_process .start ()
397+
394398 for p in worker_processes :
395399 p .join ()
396400
0 commit comments