@@ -95,6 +95,10 @@ def __init__(self, config):
9595 # Check if model_data has a "training" key with "instruments" list
9696 self .primary_stem_name = None
9797 self .secondary_stem_name = None
98+
99+ # Audio bit depth tracking for preserving input quality
100+ self .input_bit_depth = None
101+ self .input_subtype = None
98102
99103 if "training" in self .model_data and "instruments" in self .model_data ["training" ]:
100104 instruments = self .model_data ["training" ]["instruments" ]
@@ -211,11 +215,40 @@ def prepare_mix(self, mix):
211215 # Check if the input is a file path (string) and needs to be loaded
212216 if not isinstance (mix , np .ndarray ):
213217 self .logger .debug (f"Loading audio from file: { mix } " )
218+
219+ # Get audio file info to capture bit depth before loading
220+ try :
221+ audio_info = sf .info (mix )
222+ self .input_subtype = audio_info .subtype
223+ self .logger .info (f"Input audio subtype: { self .input_subtype } " )
224+
225+ # Map subtype to bit depth
226+ if 'PCM_16' in self .input_subtype or self .input_subtype == 'PCM_S8' :
227+ self .input_bit_depth = 16
228+ elif 'PCM_24' in self .input_subtype :
229+ self .input_bit_depth = 24
230+ elif 'PCM_32' in self .input_subtype or 'FLOAT' in self .input_subtype or 'DOUBLE' in self .input_subtype :
231+ self .input_bit_depth = 32
232+ else :
233+ # Default to 16-bit for unknown formats
234+ self .input_bit_depth = 16
235+ self .logger .warning (f"Unknown audio subtype { self .input_subtype } , defaulting to 16-bit output" )
236+
237+ self .logger .info (f"Detected input bit depth: { self .input_bit_depth } -bit" )
238+ except Exception as e :
239+ self .logger .warning (f"Could not read audio file info, defaulting to 16-bit output: { e } " )
240+ self .input_bit_depth = 16
241+ self .input_subtype = 'PCM_16'
242+
214243 mix , sr = librosa .load (mix , mono = False , sr = self .sample_rate )
215244 self .logger .debug (f"Audio loaded. Sample rate: { sr } , Audio shape: { mix .shape } " )
216245 else :
217246 # Transpose the mix if it's already an ndarray (expected shape: [channels, samples])
218247 self .logger .debug ("Transposing the provided mix array." )
248+ # Default to 16-bit if numpy array provided directly
249+ if self .input_bit_depth is None :
250+ self .input_bit_depth = 16
251+ self .input_subtype = 'PCM_16'
219252 mix = mix .T
220253 self .logger .debug (f"Transposed mix shape: { mix .shape } " )
221254
@@ -278,10 +311,15 @@ def write_audio_pydub(self, stem_path: str, stem_source):
278311 self .logger .debug (f"Audio data shape before processing: { stem_source .shape } " )
279312 self .logger .debug (f"Data type before conversion: { stem_source .dtype } " )
280313
281- # Ensure the audio data is in the correct format (e.g., int16)
314+ # Determine bit depth for output (use input bit depth if available, otherwise default to 16)
315+ output_bit_depth = self .input_bit_depth if self .input_bit_depth is not None else 16
316+ self .logger .info (f"Writing output with { output_bit_depth } -bit depth" )
317+
318+ # For pydub, we always convert to int16 for the AudioSegment creation
319+ # Then let ffmpeg handle the conversion to the target bit depth during export
282320 if stem_source .dtype != np .int16 :
283321 stem_source = (stem_source * 32767 ).astype (np .int16 )
284- self .logger .debug ("Converted stem_source to int16." )
322+ self .logger .debug ("Converted stem_source to int16 for pydub processing ." )
285323
286324 # Correctly interleave stereo channels
287325 stem_source_interleaved = np .empty ((2 * stem_source .shape [0 ],), dtype = np .int16 )
@@ -290,9 +328,9 @@ def write_audio_pydub(self, stem_path: str, stem_source):
290328
291329 self .logger .debug (f"Interleaved audio data shape: { stem_source_interleaved .shape } " )
292330
293- # Create a pydub AudioSegment
331+ # Create a pydub AudioSegment (always from 16-bit data)
294332 try :
295- audio_segment = AudioSegment (stem_source_interleaved .tobytes (), frame_rate = self .sample_rate , sample_width = stem_source . dtype . itemsize , channels = 2 )
333+ audio_segment = AudioSegment (stem_source_interleaved .tobytes (), frame_rate = self .sample_rate , sample_width = 2 , channels = 2 )
296334 self .logger .debug ("Created AudioSegment successfully." )
297335 except (IOError , ValueError ) as e :
298336 self .logger .error (f"Specific error creating AudioSegment: { e } " )
@@ -312,8 +350,31 @@ def write_audio_pydub(self, stem_path: str, stem_source):
312350
313351 # Export using the determined format
314352 try :
315- audio_segment .export (stem_path , format = file_format , bitrate = bitrate )
316- self .logger .debug (f"Exported audio file successfully to { stem_path } " )
353+ # Pass codec parameters to ffmpeg to enforce bit depth for lossless formats
354+ export_params = {"format" : file_format }
355+
356+ if bitrate :
357+ export_params ["bitrate" ] = bitrate
358+
359+ # For lossless formats (WAV/FLAC), specify the codec parameters to enforce bit depth
360+ if file_format in ["wav" , "flac" ]:
361+ if output_bit_depth == 16 :
362+ export_params ["parameters" ] = ["-sample_fmt" , "s16" ]
363+ elif output_bit_depth == 24 :
364+ export_params ["parameters" ] = ["-sample_fmt" , "s32" ]
365+ # For 24-bit, we also need to specify the bit depth explicitly
366+ if file_format == "wav" :
367+ export_params ["codec" ] = "pcm_s24le"
368+ elif file_format == "flac" :
369+ # FLAC supports 24-bit natively, no special handling needed
370+ pass
371+ elif output_bit_depth == 32 :
372+ export_params ["parameters" ] = ["-sample_fmt" , "s32" ]
373+ if file_format == "wav" :
374+ export_params ["codec" ] = "pcm_s32le"
375+
376+ audio_segment .export (stem_path , ** export_params )
377+ self .logger .debug (f"Exported audio file successfully to { stem_path } with { output_bit_depth } -bit depth" )
317378 except (IOError , ValueError ) as e :
318379 self .logger .error (f"Error exporting audio file: { e } " )
319380
@@ -335,32 +396,47 @@ def write_audio_soundfile(self, stem_path: str, stem_source):
335396 os .makedirs (self .output_dir , exist_ok = True )
336397 stem_path = os .path .join (self .output_dir , stem_path )
337398
399+ # Determine the subtype based on the input audio's bit depth
400+ output_subtype = None
401+ if self .input_subtype :
402+ output_subtype = self .input_subtype
403+ self .logger .info (f"Using input subtype for output: { output_subtype } " )
404+ elif self .input_bit_depth :
405+ # Map bit depth to subtype
406+ if self .input_bit_depth == 16 :
407+ output_subtype = 'PCM_16'
408+ elif self .input_bit_depth == 24 :
409+ output_subtype = 'PCM_24'
410+ elif self .input_bit_depth == 32 :
411+ output_subtype = 'PCM_32'
412+ else :
413+ output_subtype = 'PCM_16' # Default fallback
414+ self .logger .info (f"Using output subtype based on bit depth: { output_subtype } " )
415+ else :
416+ # Default to PCM_16 if no bit depth info available
417+ output_subtype = 'PCM_16'
418+ self .logger .warning ("No bit depth info available, defaulting to PCM_16" )
419+
338420 # Correctly interleave stereo channels if needed
339421 if stem_source .shape [1 ] == 2 :
340422 # If the audio is already interleaved, ensure it's in the correct order
341423 # Check if the array is Fortran contiguous (column-major)
342424 if stem_source .flags ["F_CONTIGUOUS" ]:
343425 # Convert to C contiguous (row-major)
344426 stem_source = np .ascontiguousarray (stem_source )
345- # Otherwise, perform interleaving
346- else :
347- stereo_interleaved = np .empty ((2 * stem_source .shape [0 ],), dtype = np .int16 )
348- # Left channel
349- stereo_interleaved [0 ::2 ] = stem_source [:, 0 ]
350- # Right channel
351- stereo_interleaved [1 ::2 ] = stem_source [:, 1 ]
352- stem_source = stereo_interleaved
427+ # No need to manually interleave for soundfile - it handles multi-channel properly
428+ # Just ensure we don't have the wrong shape
353429
354- self .logger .debug (f"Interleaved audio data shape: { stem_source .shape } " )
430+ self .logger .debug (f"Audio data shape for soundfile : { stem_source .shape } " )
355431
356432 """
357433 Write audio using soundfile (for formats other than M4A).
358434 """
359- # Save audio using soundfile
435+ # Save audio using soundfile with the specified subtype
360436 try :
361- # Specify the subtype to define the sample width
362- sf .write (stem_path , stem_source , self .sample_rate )
363- self .logger .debug (f"Exported audio file successfully to { stem_path } " )
437+ # Specify the subtype to match input bit depth
438+ sf .write (stem_path , stem_source , self .sample_rate , subtype = output_subtype )
439+ self .logger .debug (f"Exported audio file successfully to { stem_path } with subtype { output_subtype } " )
364440 except Exception as e :
365441 self .logger .error (f"Error exporting audio file: { e } " )
366442
0 commit comments