import os
import time
import random
+import logging
import torch
import torch.multiprocessing as multiprocessing

VEL_TOLERANCE = 50


-# TODO: Profile and fix gpu util
+def _setup_logger():
+    logger = logging.getLogger(__name__)
+    for h in logger.handlers[:]:
+        logger.removeHandler(h)
+
+    logger.propagate = False
+    logger.setLevel(logging.DEBUG)
+    formatter = logging.Formatter(
+        "[%(asctime)s] %(process)d: [%(levelname)s] %(message)s",
+    )
+
+    ch = logging.StreamHandler()
+    ch.setLevel(logging.INFO)
+    ch.setFormatter(formatter)
+    logger.addHandler(ch)
+
+    return logging.getLogger(__name__)



def calculate_vel(
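A note on the `_setup_logger` helper added above: it removes any handlers already attached to the module logger and turns off propagation, so calling it again in a long-lived worker process will not produce duplicate log lines, and the `%(process)d` field stamps every record with the worker's PID. A minimal usage sketch (not part of this diff; the message text and sample output line are illustrative only):

    logger = _setup_logger()
    logger.info("Getting wav segments")
    # emits something like:
    # [2024-01-01 12:00:00,123] 41523: [INFO] Getting wav segments
    logger.debug("suppressed: the stream handler is set to INFO, not DEBUG")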
@@ -101,7 +118,7 @@ def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
        else:
            # Call the function with float16 if bfloat16 is not supported
-            with torch.autocast("cuda", dtype=torch.float16):
+            with torch.autocast("cuda", dtype=torch.float32):
                return func(*args, **kwargs)

    return wrapper
@@ -114,6 +131,7 @@ def process_segments(
    audio_transform: AudioTransform,
    tokenizer: AmtTokenizer,
):
+    logger = logging.getLogger(__name__)
    audio_segs = torch.stack(
        [audio_seg for (audio_seg, prefix), _ in tasks]
    ).cuda()
@@ -131,14 +149,14 @@ def process_segments(

    kv_cache = model.get_empty_cache()

-    for idx in (
-        pbar := tqdm(
-            range(min_prefix_len, MAX_SEQ_LEN - 1),
-            total=MAX_SEQ_LEN - (min_prefix_len + 1),
-            leave=False,
-        )
-    ):
-        # for idx in range(min_prefix_len, MAX_SEQ_LEN - 1):
+    # for idx in (
+    #     pbar := tqdm(
+    #         range(min_prefix_len, MAX_SEQ_LEN - 1),
+    #         total=MAX_SEQ_LEN - (min_prefix_len + 1),
+    #         leave=False,
+    #     )
+    # ):
+    for idx in range(min_prefix_len, MAX_SEQ_LEN - 1):
        if idx == min_prefix_len:
            logits = model.decoder(
                xa=audio_features,
@@ -181,7 +199,7 @@ def process_segments(
            break

    if not all(eos_seen):
-        print("WARNING: OVERFLOW")
+        logger.warning("Context length overflow when transcribing segment")
        for _idx in range(seq.shape[0]):
            if eos_seen[_idx] == False:
                eos_seen[_idx] = MAX_SEQ_LEN
@@ -201,19 +219,19 @@ def gpu_manager(
    batch_size: int,
):
    # model.compile()
+    logger = _setup_logger()
    audio_transform = AudioTransform().cuda()
    tokenizer = AmtTokenizer(return_tensors=True)
-    process_pid = multiprocessing.current_process().pid

    wait_for_batch = True
    batch = []
    while True:
        try:
            task, pid = gpu_task_queue.get(timeout=5)
        except:
-            print(f"{process_pid}: GPU task timeout")
+            logger.info(f"GPU task timeout")
            if len(batch) == 0:
-                print(f"{process_pid}: Finished GPU tasks")
+                logger.info(f"Finished GPU tasks")
                return
            else:
                wait_for_batch = False
@@ -274,8 +292,10 @@ def process_file(
    result_queue: Queue,
    tokenizer: AmtTokenizer = AmtTokenizer(),
):
-    process_pid = multiprocessing.current_process().pid
-    print(f"{process_pid}: Getting wav segments")
+    logger = logging.getLogger(__name__)
+    pid = multiprocessing.current_process().pid
+
+    logger.info(f"Getting wav segments")
    audio_segments = [
        f
        for f, _ in get_wav_mid_segments(
@@ -288,10 +308,10 @@ def process_file(
        init_idx = len(seq)

        # Add to gpu queue and wait for results
-        gpu_task_queue.put(((audio_seg, seq), process_pid))
+        gpu_task_queue.put(((audio_seg, seq), pid))
        while True:
            gpu_result = result_queue.get()
-            if gpu_result["pid"] == process_pid:
+            if gpu_result["pid"] == pid:
                seq = gpu_result["result"]
                break
            else:
@@ -307,7 +327,7 @@ def process_file(
        else:
            seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS)
            if len(seq) == 1:
-                print(f"{process_pid}: exiting early")
+                logger.info(f"Exiting early")
                return res

    return res
@@ -336,19 +356,19 @@ def _get_save_path(_file_path: str):

        return save_path

-    pid = multiprocessing.current_process().pid
+    logger = _setup_logger()
    tokenizer = AmtTokenizer()
    files_processed = 0
    while not file_queue.empty():
        file_path = file_queue.get()
        save_path = _get_save_path(file_path)
        if os.path.exists(save_path):
-            print(f"{pid}: {save_path} already exists, overwriting")
+            logger.info(f"{save_path} already exists, overwriting")

        try:
            res = process_file(file_path, gpu_task_queue, result_queue)
        except Exception as e:
-            print(f"{pid}: Failed to transcribe {file_path}")
+            logger.error(f"Failed to transcribe {file_path}")
            continue

        files_processed += 1
@@ -365,14 +385,14 @@ def _get_save_path(_file_path: str):
            mid = mid_dict.to_midi()
            mid.save(save_path)
        except Exception as e:
-            print(f"{pid}: Failed to detokenize with error {e}")
+            logger.error(f"Failed to detokenize with error {e}")
        else:
-            print(f"{pid}: Finished file {files_processed} - {file_path}")
-            print(f"{pid}: {file_queue.qsize()} file(s) remaining in queue")
+            logger.info(f"Finished file {files_processed} - {file_path}")
+            logger.info(f"{file_queue.qsize()} file(s) remaining in queue")


def batch_transcribe(
-    file_paths: list,
+    file_paths,  # Queue | list,
    model: AmtEncoderDecoder,
    save_dir: str,
    batch_size: int = 16,
@@ -384,9 +404,12 @@ def batch_transcribe(

    model.cuda()
    model.eval()
-    file_queue = Queue()
-    for file_path in file_paths:
-        file_queue.put(file_path)
+    if isinstance(file_paths, list):
+        file_queue = Queue()
+        for file_path in file_paths:
+            file_queue.put(file_path)
+    else:
+        file_queue = file_paths

    gpu_task_queue = Queue()
    result_queue = Queue()
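With the signature change to `batch_transcribe` above, callers can now hand in either a plain list of file paths or a queue that is already populated (for example by another producer process). A minimal sketch of the two call patterns, assuming a loaded `model` and that `Queue` is the same multiprocessing queue class used elsewhere in this file; the file names and `save_dir` are placeholders:

    # Option 1: a list, as before; batch_transcribe builds the queue itself.
    batch_transcribe(
        file_paths=["a.wav", "b.wav"],
        model=model,
        save_dir="transcriptions/",
    )

    # Option 2: a pre-filled Queue, used as-is by batch_transcribe.
    file_queue = Queue()
    file_queue.put("c.wav")
    batch_transcribe(
        file_paths=file_queue,
        model=model,
        save_dir="transcriptions/",
    )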