@@ -783,17 +783,17 @@ def get_save_path(
783
783
file_path : str ,
784
784
input_dir : str | None ,
785
785
save_dir : str ,
786
- idx : int | str = "" ,
786
+ idx_str : int | str = "" ,
787
787
):
788
788
if input_dir is None :
789
789
save_path = os .path .join (
790
790
save_dir ,
791
- os .path .splitext (os .path .basename (file_path ))[0 ] + f"{ idx } .mid" ,
791
+ os .path .splitext (os .path .basename (file_path ))[0 ] + f"{ idx_str } .mid" ,
792
792
)
793
793
else :
794
794
input_rel_path = os .path .relpath (file_path , input_dir )
795
795
save_path = os .path .join (
796
- save_dir , os .path .splitext (input_rel_path )[0 ] + f"{ idx } .mid"
796
+ save_dir , os .path .splitext (input_rel_path )[0 ] + f"{ idx_str } .mid"
797
797
)
798
798
if not os .path .isdir (os .path .dirname (save_path )):
799
799
os .makedirs (os .path .dirname (save_path ), exist_ok = True )
@@ -810,7 +810,7 @@ def process_file(
810
810
save_dir : str ,
811
811
input_dir : str ,
812
812
logger : logging .Logger ,
813
- segments : List [Tuple [int , int ]] | None = None ,
813
+ segments : List [Tuple [int , Tuple [ int , int ] ]] | None = None ,
814
814
):
815
815
def _save_seq (_seq : List , _save_path : str ):
816
816
if os .path .exists (_save_path ):
@@ -852,12 +852,17 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int):
852
852
853
853
pid = threading .get_ident ()
854
854
if segments is None :
855
- segments = [None ]
855
+ # process_file and get_wav_segments will interpret segment=None as
856
+ # processing the entire file
857
+ segments = [(None , None )]
856
858
857
859
if len (segments ) == 0 :
858
860
logger .info (f"No segments to transcribe, skipping file: { file_path } " )
859
861
860
- for idx , segment in enumerate (segments ):
862
+ for idx , segment in segments :
863
+ idx_str = f"_{ idx } " if idx is not None else ""
864
+ save_path = get_save_path (file_path , input_dir , save_dir , idx_str )
865
+
861
866
try :
862
867
seq = transcribe_file (
863
868
file_path ,
@@ -876,15 +881,17 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int):
876
881
logger .info (f"Removed { res_rmv_cnt } from result queue" )
877
882
continue
878
883
879
- logger .info (f"Finished file: { file_path } (segment: { idx } )" )
884
+ logger .info (
885
+ f"Finished file: { file_path } (segment: { idx if idx is not None else 'full' } )"
886
+ )
880
887
if len (seq ) < 500 :
881
- logger .info (f"Skipping seq - too short (segment { idx } )" )
888
+ logger .info (
889
+ f"Skipping seq - too short (segment { idx if idx is not None else 'full' } )"
890
+ )
882
891
else :
883
892
logger .debug (
884
- f"Saving seq of length { len (seq )} from file: { file_path } (segment: { idx } )"
893
+ f"Saving seq of length { len (seq )} from file: { file_path } (segment: { idx if idx is not None else 'full' } )"
885
894
)
886
- idx = f"_{ idx } " if segment is not None else ""
887
- save_path = get_save_path (file_path , input_dir , save_dir , idx )
888
895
_save_seq (seq , save_path )
889
896
890
897
logger .info (f"{ file_queue .qsize ()} file(s) remaining in queue" )
@@ -997,20 +1004,28 @@ def batch_transcribe(
997
1004
files_to_process , key = lambda x : os .path .getsize (x ["path" ]), reverse = True
998
1005
)
999
1006
for file_to_process in files_to_process :
1000
- # Only add to file_queue if transcription MIDI file doesn't exist
1001
- if (
1002
- os .path .isfile (
1007
+ if "segments" in file_to_process :
1008
+ # Process files with segments
1009
+ unsaved_segments = []
1010
+ for idx , segment in enumerate (file_to_process ["segments" ]):
1011
+ segment_save_path = get_save_path (
1012
+ file_to_process ["path" ],
1013
+ input_dir ,
1014
+ save_dir ,
1015
+ idx_str = f"_{ idx } " ,
1016
+ )
1017
+ if not os .path .isfile (segment_save_path ):
1018
+ unsaved_segments .append ((idx , segment ))
1019
+
1020
+ if unsaved_segments :
1021
+ file_to_process ["segments" ] = unsaved_segments
1022
+ file_queue .put (file_to_process )
1023
+ else :
1024
+ # Process files without segments (whole file)
1025
+ if not os .path .isfile (
1003
1026
get_save_path (file_to_process ["path" ], input_dir , save_dir )
1004
- )
1005
- is False
1006
- ) and os .path .isfile (
1007
- get_save_path (
1008
- file_to_process ["path" ], input_dir , save_dir , idx = "_0"
1009
- )
1010
- ) is False :
1011
- file_queue .put (file_to_process )
1012
- elif len (files_to_process ) == 1 :
1013
- file_queue .put (file_to_process )
1027
+ ):
1028
+ file_queue .put (file_to_process )
1014
1029
1015
1030
logger .info (
1016
1031
f"Files to process: { file_queue .qsize ()} /{ len (files_to_process )} "
@@ -1026,7 +1041,7 @@ def batch_transcribe(
1026
1041
file_queue .qsize (),
1027
1042
)
1028
1043
num_processes_per_worker = min (
1029
- 3 * (batch_size // num_workers ), file_queue .qsize () // num_workers
1044
+ 5 * (batch_size // num_workers ), file_queue .qsize () // num_workers
1030
1045
)
1031
1046
1032
1047
mp_manager = Manager ()
0 commit comments