Skip to content

Commit 5a904e6

Browse files
authored
Change overwrite logic (#58)
1 parent 4120f57 commit 5a904e6

File tree

1 file changed

+40
-25
lines changed

1 file changed

+40
-25
lines changed

amt/inference/transcribe.py

+40 -25
Original file line number | Diff line number | Diff line change
@@ -783,17 +783,17 @@ def get_save_path(
783783
file_path: str,
784784
input_dir: str | None,
785785
save_dir: str,
786-
idx: int | str = "",
786+
idx_str: int | str = "",
787787
):
788788
if input_dir is None:
789789
save_path = os.path.join(
790790
save_dir,
791-
os.path.splitext(os.path.basename(file_path))[0] + f"{idx}.mid",
791+
os.path.splitext(os.path.basename(file_path))[0] + f"{idx_str}.mid",
792792
)
793793
else:
794794
input_rel_path = os.path.relpath(file_path, input_dir)
795795
save_path = os.path.join(
796-
save_dir, os.path.splitext(input_rel_path)[0] + f"{idx}.mid"
796+
save_dir, os.path.splitext(input_rel_path)[0] + f"{idx_str}.mid"
797797
)
798798
if not os.path.isdir(os.path.dirname(save_path)):
799799
os.makedirs(os.path.dirname(save_path), exist_ok=True)
@@ -810,7 +810,7 @@ def process_file(
810810
save_dir: str,
811811
input_dir: str,
812812
logger: logging.Logger,
813-
segments: List[Tuple[int, int]] | None = None,
813+
segments: List[Tuple[int, Tuple[int, int]]] | None = None,
814814
):
815815
def _save_seq(_seq: List, _save_path: str):
816816
if os.path.exists(_save_path):
@@ -852,12 +852,17 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int):
852852

853853
pid = threading.get_ident()
854854
if segments is None:
855-
segments = [None]
855+
# process_file and get_wav_segments will interpret segment=None as
856+
# processing the entire file
857+
segments = [(None, None)]
856858

857859
if len(segments) == 0:
858860
logger.info(f"No segments to transcribe, skipping file: {file_path}")
859861

860-
for idx, segment in enumerate(segments):
862+
for idx, segment in segments:
863+
idx_str = f"_{idx}" if idx is not None else ""
864+
save_path = get_save_path(file_path, input_dir, save_dir, idx_str)
865+
861866
try:
862867
seq = transcribe_file(
863868
file_path,
@@ -876,15 +881,17 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int):
876881
logger.info(f"Removed {res_rmv_cnt} from result queue")
877882
continue
878883

879-
logger.info(f"Finished file: {file_path} (segment: {idx})")
884+
logger.info(
885+
f"Finished file: {file_path} (segment: {idx if idx is not None else 'full'})"
886+
)
880887
if len(seq) < 500:
881-
logger.info(f"Skipping seq - too short (segment {idx})")
888+
logger.info(
889+
f"Skipping seq - too short (segment {idx if idx is not None else 'full'})"
890+
)
882891
else:
883892
logger.debug(
884-
f"Saving seq of length {len(seq)} from file: {file_path} (segment: {idx})"
893+
f"Saving seq of length {len(seq)} from file: {file_path} (segment: {idx if idx is not None else 'full'})"
885894
)
886-
idx = f"_{idx}" if segment is not None else ""
887-
save_path = get_save_path(file_path, input_dir, save_dir, idx)
888895
_save_seq(seq, save_path)
889896

890897
logger.info(f"{file_queue.qsize()} file(s) remaining in queue")
@@ -997,20 +1004,28 @@ def batch_transcribe(
9971004
files_to_process, key=lambda x: os.path.getsize(x["path"]), reverse=True
9981005
)
9991006
for file_to_process in files_to_process:
1000-
# Only add to file_queue if transcription MIDI file doesn't exist
1001-
if (
1002-
os.path.isfile(
1007+
if "segments" in file_to_process:
1008+
# Process files with segments
1009+
unsaved_segments = []
1010+
for idx, segment in enumerate(file_to_process["segments"]):
1011+
segment_save_path = get_save_path(
1012+
file_to_process["path"],
1013+
input_dir,
1014+
save_dir,
1015+
idx_str=f"_{idx}",
1016+
)
1017+
if not os.path.isfile(segment_save_path):
1018+
unsaved_segments.append((idx, segment))
1019+
1020+
if unsaved_segments:
1021+
file_to_process["segments"] = unsaved_segments
1022+
file_queue.put(file_to_process)
1023+
else:
1024+
# Process files without segments (whole file)
1025+
if not os.path.isfile(
10031026
get_save_path(file_to_process["path"], input_dir, save_dir)
1004-
)
1005-
is False
1006-
) and os.path.isfile(
1007-
get_save_path(
1008-
file_to_process["path"], input_dir, save_dir, idx="_0"
1009-
)
1010-
) is False:
1011-
file_queue.put(file_to_process)
1012-
elif len(files_to_process) == 1:
1013-
file_queue.put(file_to_process)
1027+
):
1028+
file_queue.put(file_to_process)
10141029

10151030
logger.info(
10161031
f"Files to process: {file_queue.qsize()}/{len(files_to_process)}"
@@ -1026,7 +1041,7 @@ def batch_transcribe(
10261041
file_queue.qsize(),
10271042
)
10281043
num_processes_per_worker = min(
1029-
3 * (batch_size // num_workers), file_queue.qsize() // num_workers
1044+
5 * (batch_size // num_workers), file_queue.qsize() // num_workers
10301045
)
10311046

10321047
mp_manager = Manager()

0 commit comments

Comments
 (0)