Skip to content

Commit 5d6f630

Browse files
authored
Small fix (#13)
* add more aug * add multi gpu inference * small fix * format
1 parent e49951c commit 5d6f630

File tree

3 files changed

+42
-31
lines changed

3 files changed

+42
-31
lines changed

amt/data.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@ def get_wav_mid_segments(
5656
# Create features
5757
total_samples = wav.shape[-1]
5858
res = []
59-
for idx in range(0, total_samples, num_samples // stride_factor):
59+
for idx in range(
60+
0,
61+
total_samples - (num_samples - (num_samples // stride_factor)),
62+
num_samples // stride_factor,
63+
):
6064
audio_feature = pad_or_trim(wav[idx:], length=num_samples)
6165
if midi_dict is not None:
6266
mid_feature = tokenizer._tokenize_midi_dict(

amt/infer.py

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,12 @@
55
import torch.multiprocessing as multiprocessing
66

77
from torch.multiprocessing import Queue
8-
from torch.cuda import device_count, is_available
98
from tqdm import tqdm
109

1110
from amt.model import AmtEncoderDecoder
1211
from amt.tokenizer import AmtTokenizer
1312
from amt.audio import AudioTransform
1413
from amt.data import get_wav_mid_segments
15-
from amt.config import load_config
16-
from aria.data.midi import MidiDict
1714

1815
MAX_SEQ_LEN = 4096
1916
LEN_MS = 30000
@@ -24,8 +21,11 @@
2421
VEL_TOLERANCE = 50
2522

2623

24+
# TODO: Profile and fix gpu util
25+
26+
2727
def calculate_vel(
28-
logits: torch.tensor,
28+
logits: torch.Tensor,
2929
init_vel: int,
3030
tokenizer: AmtTokenizer = AmtTokenizer(),
3131
):
@@ -51,13 +51,13 @@ def calculate_vel(
5151

5252
vels = torch.tensor(vels).to(probs.device)
5353
new_vel = torch.sum(vels * probs) / torch.sum(probs)
54-
new_vel = round(new_vel.item() / 10) * 10
54+
new_vel = round(new_vel.item() / 5) * 5
5555

5656
return tokenizer.tok_to_id[("vel", new_vel)]
5757

5858

5959
def calculate_onset(
60-
logits: torch.tensor,
60+
logits: torch.Tensor,
6161
init_onset: int,
6262
tokenizer: AmtTokenizer = AmtTokenizer(),
6363
):
@@ -88,6 +88,7 @@ def calculate_onset(
8888
return tokenizer.tok_to_id[("onset", new_onset)]
8989

9090

91+
@torch.autocast("cuda", dtype=torch.bfloat16)
9192
def process_segments(
9293
tasks: list,
9394
model: AmtEncoderDecoder,
@@ -111,14 +112,14 @@ def process_segments(
111112

112113
kv_cache = model.get_empty_cache()
113114

114-
# for idx in (
115-
# pbar := tqdm(
116-
# range(min_prefix_len, MAX_SEQ_LEN - 1),
117-
# total=MAX_SEQ_LEN - (min_prefix_len + 1),
118-
# leave=False,
119-
# )
120-
# ):
121-
for idx in range(min_prefix_len, MAX_SEQ_LEN - 1):
115+
for idx in (
116+
pbar := tqdm(
117+
range(min_prefix_len, MAX_SEQ_LEN - 1),
118+
total=MAX_SEQ_LEN - (min_prefix_len + 1),
119+
leave=False,
120+
)
121+
):
122+
# for idx in range(min_prefix_len, MAX_SEQ_LEN - 1):
122123
if idx == min_prefix_len:
123124
logits = model.decoder(
124125
xa=audio_features,
@@ -160,6 +161,12 @@ def process_segments(
160161
if all(eos_seen):
161162
break
162163

164+
if not all(eos_seen):
165+
print("WARNING: OVERFLOW")
166+
for _idx in range(seq.shape[0]):
167+
if eos_seen[_idx] == False:
168+
eos_seen[_idx] = MAX_SEQ_LEN
169+
163170
results = [
164171
tokenizer.decode(seq[_idx, : eos_seen[_idx] + 1])
165172
for _idx in range(seq.shape[0])
@@ -174,9 +181,7 @@ def gpu_manager(
174181
model: AmtEncoderDecoder,
175182
batch_size: int,
176183
):
177-
model.cuda()
178-
model.eval()
179-
model.compile()
184+
# model.compile()
180185
audio_transform = AudioTransform().cuda()
181186
tokenizer = AmtTokenizer(return_tensors=True)
182187
process_pid = multiprocessing.current_process().pid
@@ -277,9 +282,6 @@ def process_file(
277282
seq[init_idx : seq.index(tokenizer.eos_tok)],
278283
idx * CHUNK_LEN_MS,
279284
)
280-
print(
281-
f"{process_pid}: Finished {idx+1}/{len(audio_segments)} audio segments"
282-
)
283285

284286
if idx == len(audio_segments) - 1:
285287
break
@@ -310,8 +312,8 @@ def _get_save_path(_file_path: str):
310312
save_path = os.path.join(
311313
save_dir, os.path.splitext(input_rel_path)[0] + ".mid"
312314
)
313-
if not os.path.exists(os.path.dirname(save_path)):
314-
os.makedirs(os.path.dirname(save_path))
315+
if not os.path.isdir(os.path.dirname(save_path)):
316+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
315317

316318
return save_path
317319

@@ -361,20 +363,15 @@ def batch_transcribe(
361363
if gpu_id is not None:
362364
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
363365

364-
model.to("cuda")
366+
model.cuda()
367+
model.eval()
365368
file_queue = Queue()
366369
for file_path in file_paths:
367370
file_queue.put(file_path)
368371

369372
gpu_task_queue = Queue()
370373
result_queue = Queue()
371374

372-
gpu_manager_process = multiprocessing.Process(
373-
target=gpu_manager,
374-
args=(gpu_task_queue, result_queue, model, batch_size),
375-
)
376-
gpu_manager_process.start()
377-
378375
worker_processes = [
379376
multiprocessing.Process(
380377
target=worker,
@@ -391,6 +388,13 @@ def batch_transcribe(
391388
for p in worker_processes:
392389
p.start()
393390

391+
time.sleep(10)
392+
gpu_manager_process = multiprocessing.Process(
393+
target=gpu_manager,
394+
args=(gpu_task_queue, result_queue, model, batch_size),
395+
)
396+
gpu_manager_process.start()
397+
394398
for p in worker_processes:
395399
p.join()
396400

amt/run.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@ def transcribe(args):
119119
assert os.path.isfile(args.cp), "model checkpoint file not found"
120120
assert args.load_path or args.load_dir, "must give either load path or dir"
121121
if args.load_path:
122-
assert os.path.isfile(args.load_path), "audio file not found"
122+
assert os.path.isfile(
123+
args.load_path
124+
), f"audio file not found: {args.load_path}"
123125
trans_mode = "single"
124126
if args.load_dir:
125127
assert os.path.isdir(args.load_dir), "load directory doesn't exist"
@@ -201,6 +203,7 @@ def transcribe(args):
201203
model=model,
202204
save_dir=args.save_dir,
203205
batch_size=args.bs,
206+
input_dir=args.load_dir,
204207
)
205208

206209

0 commit comments

Comments
 (0)