Skip to content

Commit d25a7fa

Browse files
authored
Fix distortion and dataset indexing (#16)
* working * format * fix distortion bottlekneck * format * adj
1 parent b82a9da commit d25a7fa

File tree

6 files changed

+148
-58
lines changed

6 files changed

+148
-58
lines changed

amt/audio.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,10 @@ def __init__(
193193
min_dist_gain: int = 0,
194194
noise_ratio: float = 0.95,
195195
reverb_ratio: float = 0.95,
196-
applause_ratio: float = 0.01, # CHANGE
196+
applause_ratio: float = 0.01,
197197
distort_ratio: float = 0.15,
198198
reduce_ratio: float = 0.01,
199-
spec_aug_ratio: float = 0.25,
199+
spec_aug_ratio: float = 0.5,
200200
):
201201
super().__init__()
202202
self.tokenizer = AmtTokenizer()
@@ -257,7 +257,7 @@ def __init__(
257257
)
258258
self.spec_aug = torch.nn.Sequential(
259259
torchaudio.transforms.FrequencyMasking(
260-
freq_mask_param=10, iid_masks=True
260+
freq_mask_param=15, iid_masks=True
261261
),
262262
torchaudio.transforms.TimeMasking(
263263
time_mask_param=1000, iid_masks=True
@@ -374,6 +374,17 @@ def apply_distortion(self, wav: torch.tensor):
374374

375375
return AF.overdrive(wav, gain=gain, colour=colour)
376376

377+
def distortion_aug_cpu(self, wav: torch.Tensor):
378+
# This function should run on the cpu (i.e. in the dataloader collate
379+
# function) in order to not be a bottlekneck
380+
381+
if random.random() < self.reduce_ratio:
382+
wav = self.apply_reduction(wav)
383+
if random.random() < self.distort_ratio:
384+
wav = self.apply_distortion(wav)
385+
386+
return wav
387+
377388
def shift_spec(self, specs: torch.Tensor, shift: int):
378389
if shift == 0:
379390
return specs
@@ -400,18 +411,15 @@ def shift_spec(self, specs: torch.Tensor, shift: int):
400411
return shifted_specs
401412

402413
def aug_wav(self, wav: torch.Tensor):
414+
# This function doesn't apply distortion. If distortion is desired it
415+
# should be run before hand on the cpu with distortion_aug_cpu.
416+
403417
# Noise
404418
if random.random() < self.noise_ratio:
405419
wav = self.apply_noise(wav)
406420
if random.random() < self.applause_ratio:
407421
wav = self.apply_applause(wav)
408422

409-
# Distortion
410-
if random.random() < self.reduce_ratio:
411-
wav = self.apply_reduction(wav)
412-
elif random.random() < self.distort_ratio:
413-
wav = self.apply_distortion(wav)
414-
415423
# Reverb
416424
if random.random() < self.reverb_ratio:
417425
return self.apply_reverb(wav)
@@ -439,7 +447,7 @@ def log_mel(self, wav: torch.Tensor, shift: int | None = None):
439447
return log_spec
440448

441449
def forward(self, wav: torch.Tensor, shift: int = 0):
442-
# Noise, distortion, and reverb
450+
# Noise, and reverb
443451
wav = self.aug_wav(wav)
444452

445453
# Spec & pitch shift

amt/data.py

+54-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,17 @@ def __init__(self, load_path: str):
113113
self.file_mmap = mmap.mmap(
114114
self.file_buff.fileno(), 0, access=mmap.ACCESS_READ
115115
)
116-
self.index = self._build_index()
116+
117+
index_path = AmtDataset._get_index_path(load_path=load_path)
118+
if os.path.isfile(index_path) is True:
119+
self.index = self._load_index(load_path=index_path)
120+
else:
121+
print("Calculating index...")
122+
self.index = self._build_index()
123+
print(
124+
f"Index of length {len(self.index)} calculated, saving to {index_path}"
125+
)
126+
self._save_index(index=self.index, save_path=index_path)
117127

118128
def close(self):
119129
if self.file_buff:
@@ -167,6 +177,21 @@ def _build_index(self):
167177

168178
return index
169179

180+
def _save_index(self, index: list[int], save_path: str):
181+
with open(save_path, "w") as file:
182+
for idx in index:
183+
file.write(f"{idx}\n")
184+
185+
def _load_index(self, load_path: str):
186+
with open(load_path, "r") as file:
187+
return [int(line.strip()) for line in file]
188+
189+
@staticmethod
190+
def _get_index_path(load_path: str):
191+
return (
192+
f"{load_path.rsplit('.', 1)[0]}_index.{load_path.rsplit('.', 1)[1]}"
193+
)
194+
170195
@classmethod
171196
def build(
172197
cls,
@@ -175,6 +200,12 @@ def build(
175200
num_processes: int = 1,
176201
):
177202
assert os.path.isfile(save_path) is False, f"{save_path} already exists"
203+
204+
index_path = AmtDataset._get_index_path(load_path=save_path)
205+
if os.path.isfile(index_path):
206+
print(f"Removing existing index file at {index_path}")
207+
os.remove(AmtDataset._get_index_path(load_path=save_path))
208+
178209
num_paths = len(matched_load_paths)
179210
with Pool(processes=num_processes) as pool:
180211
sharded_save_paths = []
@@ -202,3 +233,25 @@ def build(
202233
os.system(shell_cmd)
203234
for _path in sharded_save_paths:
204235
os.remove(_path)
236+
237+
# Create index by loading object
238+
AmtDataset(load_path=save_path)
239+
240+
def _build_index(self):
241+
self.file_mmap.seek(0)
242+
index = []
243+
pos = 0
244+
while True:
245+
pos_buff = pos
246+
247+
pos = self.file_mmap.find(b"\n", pos)
248+
if pos == -1:
249+
break
250+
pos = self.file_mmap.find(b"\n", pos + 1)
251+
if pos == -1:
252+
break
253+
254+
index.append(pos_buff)
255+
pos += 1
256+
257+
return index

amt/infer.py

+52-29
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import time
33
import random
4+
import logging
45
import torch
56
import torch.multiprocessing as multiprocessing
67

@@ -21,7 +22,23 @@
2122
VEL_TOLERANCE = 50
2223

2324

24-
# TODO: Profile and fix gpu util
25+
def _setup_logger():
26+
logger = logging.getLogger(__name__)
27+
for h in logger.handlers[:]:
28+
logger.removeHandler(h)
29+
30+
logger.propagate = False
31+
logger.setLevel(logging.DEBUG)
32+
formatter = logging.Formatter(
33+
"[%(asctime)s] %(process)d: [%(levelname)s] %(message)s",
34+
)
35+
36+
ch = logging.StreamHandler()
37+
ch.setLevel(logging.INFO)
38+
ch.setFormatter(formatter)
39+
logger.addHandler(ch)
40+
41+
return logging.getLogger(__name__)
2542

2643

2744
def calculate_vel(
@@ -101,7 +118,7 @@ def wrapper(*args, **kwargs):
101118
return func(*args, **kwargs)
102119
else:
103120
# Call the function with float16 if bfloat16 is not supported
104-
with torch.autocast("cuda", dtype=torch.float16):
121+
with torch.autocast("cuda", dtype=torch.float32):
105122
return func(*args, **kwargs)
106123

107124
return wrapper
@@ -114,6 +131,7 @@ def process_segments(
114131
audio_transform: AudioTransform,
115132
tokenizer: AmtTokenizer,
116133
):
134+
logger = logging.getLogger(__name__)
117135
audio_segs = torch.stack(
118136
[audio_seg for (audio_seg, prefix), _ in tasks]
119137
).cuda()
@@ -131,14 +149,14 @@ def process_segments(
131149

132150
kv_cache = model.get_empty_cache()
133151

134-
for idx in (
135-
pbar := tqdm(
136-
range(min_prefix_len, MAX_SEQ_LEN - 1),
137-
total=MAX_SEQ_LEN - (min_prefix_len + 1),
138-
leave=False,
139-
)
140-
):
141-
# for idx in range(min_prefix_len, MAX_SEQ_LEN - 1):
152+
# for idx in (
153+
# pbar := tqdm(
154+
# range(min_prefix_len, MAX_SEQ_LEN - 1),
155+
# total=MAX_SEQ_LEN - (min_prefix_len + 1),
156+
# leave=False,
157+
# )
158+
# ):
159+
for idx in range(min_prefix_len, MAX_SEQ_LEN - 1):
142160
if idx == min_prefix_len:
143161
logits = model.decoder(
144162
xa=audio_features,
@@ -181,7 +199,7 @@ def process_segments(
181199
break
182200

183201
if not all(eos_seen):
184-
print("WARNING: OVERFLOW")
202+
logger.warning("Context length overflow when transcribing segment")
185203
for _idx in range(seq.shape[0]):
186204
if eos_seen[_idx] == False:
187205
eos_seen[_idx] = MAX_SEQ_LEN
@@ -201,19 +219,19 @@ def gpu_manager(
201219
batch_size: int,
202220
):
203221
# model.compile()
222+
logger = _setup_logger()
204223
audio_transform = AudioTransform().cuda()
205224
tokenizer = AmtTokenizer(return_tensors=True)
206-
process_pid = multiprocessing.current_process().pid
207225

208226
wait_for_batch = True
209227
batch = []
210228
while True:
211229
try:
212230
task, pid = gpu_task_queue.get(timeout=5)
213231
except:
214-
print(f"{process_pid}: GPU task timeout")
232+
logger.info(f"GPU task timeout")
215233
if len(batch) == 0:
216-
print(f"{process_pid}: Finished GPU tasks")
234+
logger.info(f"Finished GPU tasks")
217235
return
218236
else:
219237
wait_for_batch = False
@@ -274,8 +292,10 @@ def process_file(
274292
result_queue: Queue,
275293
tokenizer: AmtTokenizer = AmtTokenizer(),
276294
):
277-
process_pid = multiprocessing.current_process().pid
278-
print(f"{process_pid}: Getting wav segments")
295+
logger = logging.getLogger(__name__)
296+
pid = multiprocessing.current_process().pid
297+
298+
logger.info(f"Getting wav segments")
279299
audio_segments = [
280300
f
281301
for f, _ in get_wav_mid_segments(
@@ -288,10 +308,10 @@ def process_file(
288308
init_idx = len(seq)
289309

290310
# Add to gpu queue and wait for results
291-
gpu_task_queue.put(((audio_seg, seq), process_pid))
311+
gpu_task_queue.put(((audio_seg, seq), pid))
292312
while True:
293313
gpu_result = result_queue.get()
294-
if gpu_result["pid"] == process_pid:
314+
if gpu_result["pid"] == pid:
295315
seq = gpu_result["result"]
296316
break
297317
else:
@@ -307,7 +327,7 @@ def process_file(
307327
else:
308328
seq = _truncate_seq(seq, CHUNK_LEN_MS, LEN_MS)
309329
if len(seq) == 1:
310-
print(f"{process_pid}: exiting early")
330+
logger.info(f"Exiting early")
311331
return res
312332

313333
return res
@@ -336,19 +356,19 @@ def _get_save_path(_file_path: str):
336356

337357
return save_path
338358

339-
pid = multiprocessing.current_process().pid
359+
logger = _setup_logger()
340360
tokenizer = AmtTokenizer()
341361
files_processed = 0
342362
while not file_queue.empty():
343363
file_path = file_queue.get()
344364
save_path = _get_save_path(file_path)
345365
if os.path.exists(save_path):
346-
print(f"{pid}: {save_path} already exists, overwriting")
366+
logger.info(f"{save_path} already exists, overwriting")
347367

348368
try:
349369
res = process_file(file_path, gpu_task_queue, result_queue)
350370
except Exception as e:
351-
print(f"{pid}: Failed to transcribe {file_path}")
371+
logger.error(f"Failed to transcribe {file_path}")
352372
continue
353373

354374
files_processed += 1
@@ -365,14 +385,14 @@ def _get_save_path(_file_path: str):
365385
mid = mid_dict.to_midi()
366386
mid.save(save_path)
367387
except Exception as e:
368-
print(f"{pid}: Failed to detokenize with error {e}")
388+
logger.error(f"Failed to detokenize with error {e}")
369389
else:
370-
print(f"{pid}: Finished file {files_processed} - {file_path}")
371-
print(f"{pid}: {file_queue.qsize()} file(s) remaining in queue")
390+
logger.info(f"Finished file {files_processed} - {file_path}")
391+
logger.info(f"{file_queue.qsize()} file(s) remaining in queue")
372392

373393

374394
def batch_transcribe(
375-
file_paths: list,
395+
file_paths, # Queue | list,
376396
model: AmtEncoderDecoder,
377397
save_dir: str,
378398
batch_size: int = 16,
@@ -384,9 +404,12 @@ def batch_transcribe(
384404

385405
model.cuda()
386406
model.eval()
387-
file_queue = Queue()
388-
for file_path in file_paths:
389-
file_queue.put(file_path)
407+
if isinstance(file_paths, list):
408+
file_queue = Queue()
409+
for file_path in file_paths:
410+
file_queue.put(file_path)
411+
else:
412+
file_queue = file_paths
390413

391414
gpu_task_queue = Queue()
392415
result_queue = Queue()

0 commit comments

Comments
 (0)