@@ -45,12 +45,25 @@ def load_model(self,
4545 self ._loading_stats ['new_implementation_success' ] += 1
4646 logger .info (f"Successfully loaded { model_type } model with new implementation" )
4747 return result
48- except (RuntimeError , ValueError ) as e :
48+ except (RuntimeError , ValueError , TypeError ) as e :
4949 logger .error (f"New implementation failed: { e } " )
50- return ModelLoadingResult .failure_result (
51- error_message = f"New implementation failed: { e } " ,
52- implementation = ImplementationVersion .NEW ,
53- )
50+ # Attempt legacy fallback using the original (pre-normalized) configuration
51+ try :
52+ fallback_result = self ._load_with_legacy_implementation (
53+ model_path = model_path ,
54+ original_config = config ,
55+ device = device ,
56+ original_error = str (e )
57+ )
58+ logger .warning ("Fell back to legacy Roformer implementation successfully" )
59+ return fallback_result
60+ except (RuntimeError , ValueError , TypeError ) as fallback_error :
61+ logger .error (f"Legacy implementation also failed: { fallback_error } " )
62+ self ._loading_stats ['total_failures' ] += 1
63+ return ModelLoadingResult .failure_result (
64+ error_message = f"New implementation failed: { e } ; Legacy fallback failed: { fallback_error } " ,
65+ implementation = ImplementationVersion .NEW ,
66+ )
5467
5568 def validate_configuration (self , config : Dict [str , Any ], model_type : str ) -> bool :
5669 try :
@@ -160,13 +173,67 @@ def _create_mel_band_roformer(self, config: Dict[str, Any]):
160173 }
161174 if 'sample_rate' in config :
162175 model_args ['sample_rate' ] = config ['sample_rate' ]
163- if 'fmin' in config :
164- model_args ['fmin' ] = config ['fmin' ]
165- if 'fmax' in config :
166- model_args ['fmax' ] = config ['fmax' ]
176+ # Optional parameters commonly present in legacy configs
177+ for optional_key in [
178+ 'mask_estimator_depth' ,
179+ 'stft_n_fft' ,
180+ 'stft_hop_length' ,
181+ 'stft_win_length' ,
182+ 'stft_normalized' ,
183+ 'stft_window_fn' ,
184+ 'multi_stft_resolution_loss_weight' ,
185+ 'multi_stft_resolutions_window_sizes' ,
186+ 'multi_stft_hop_size' ,
187+ 'multi_stft_normalized' ,
188+ 'multi_stft_window_fn' ,
189+ 'match_input_audio_length' ,
190+ ]:
191+ if optional_key in config :
192+ model_args [optional_key ] = config [optional_key ]
193+ # Note: fmin and fmax are defined in config classes but not accepted by current constructor
167194 logger .debug (f"Creating MelBandRoformer with args: { list (model_args .keys ())} " )
168195 return MelBandRoformer (** model_args )
169196
197+ def _load_with_legacy_implementation (self ,
198+ model_path : str ,
199+ original_config : Dict [str , Any ],
200+ device : str ,
201+ original_error : str ) -> ModelLoadingResult :
202+ """
203+ Attempt to load the model using the legacy direct-constructor path
204+ for maximum backward compatibility with existing checkpoints.
205+ """
206+ import torch
207+
208+ # Use nested 'model' section if present; otherwise assume flat
209+ model_cfg = original_config .get ('model' , original_config )
210+
211+ # Determine model type from config
212+ if 'num_bands' in model_cfg :
213+ from ..uvr_lib_v5 .roformer .mel_band_roformer import MelBandRoformer
214+ model = MelBandRoformer (** model_cfg )
215+ elif 'freqs_per_bands' in model_cfg :
216+ from ..uvr_lib_v5 .roformer .bs_roformer import BSRoformer
217+ model = BSRoformer (** model_cfg )
218+ else :
219+ raise ValueError ("Unknown Roformer model type in legacy configuration" )
220+
221+ # Load checkpoint as raw state dict (legacy behavior)
222+ try :
223+ checkpoint = torch .load (model_path , map_location = 'cpu' , weights_only = True )
224+ except TypeError :
225+ # For older torch versions without weights_only
226+ checkpoint = torch .load (model_path , map_location = 'cpu' )
227+
228+ model .load_state_dict (checkpoint )
229+ model .to (device ).eval ()
230+
231+ return ModelLoadingResult .fallback_success_result (
232+ model = model ,
233+ original_error = original_error ,
234+ config = original_config ,
235+ )
236+
170237 def get_loading_stats (self ) -> Dict [str , int ]:
171238 return self ._loading_stats .copy ()
172239
@@ -232,8 +299,7 @@ def get_default_configuration(self, model_type: str) -> Dict[str, Any]:
232299 'use_torch_checkpoint' : False ,
233300 'skip_connection' : False ,
234301 'sample_rate' : 44100 ,
235- 'fmin' : 0 ,
236- 'fmax' : None ,
302+ # Note: fmin and fmax are not implemented in MelBandRoformer constructor
237303 }
238304 else :
239305 raise ValueError (f"Unknown model type: { model_type } " )
0 commit comments