
Commit 8231975

Fix mel regression (#235)
* Fixed MelBandRoformer model loading regression
* Added mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956.ckpt to integration test
1 parent cec32b7 commit 8231975

11 files changed (+82, -12 lines)

audio_separator/separator/roformer/roformer_loader.py

Lines changed: 77 additions & 11 deletions
@@ -45,12 +45,25 @@ def load_model(self,
             self._loading_stats['new_implementation_success'] += 1
             logger.info(f"Successfully loaded {model_type} model with new implementation")
             return result
-        except (RuntimeError, ValueError) as e:
+        except (RuntimeError, ValueError, TypeError) as e:
             logger.error(f"New implementation failed: {e}")
-            return ModelLoadingResult.failure_result(
-                error_message=f"New implementation failed: {e}",
-                implementation=ImplementationVersion.NEW,
-            )
+            # Attempt legacy fallback using the original (pre-normalized) configuration
+            try:
+                fallback_result = self._load_with_legacy_implementation(
+                    model_path=model_path,
+                    original_config=config,
+                    device=device,
+                    original_error=str(e)
+                )
+                logger.warning("Fell back to legacy Roformer implementation successfully")
+                return fallback_result
+            except (RuntimeError, ValueError, TypeError) as fallback_error:
+                logger.error(f"Legacy implementation also failed: {fallback_error}")
+                self._loading_stats['total_failures'] += 1
+                return ModelLoadingResult.failure_result(
+                    error_message=f"New implementation failed: {e}; Legacy fallback failed: {fallback_error}",
+                    implementation=ImplementationVersion.NEW,
+                )
 
     def validate_configuration(self, config: Dict[str, Any], model_type: str) -> bool:
         try:
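
Both branches above hand control back through ModelLoadingResult, whose failure_result and fallback_success_result constructors are called in this diff but defined elsewhere in the package. As a minimal hypothetical sketch only, with field names inferred from the call sites rather than taken from the project's actual class, such a result type could look like this:

# Hypothetical sketch; the real ModelLoadingResult and ImplementationVersion live elsewhere in the package.
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional


class ImplementationVersion(Enum):
    NEW = "new"
    LEGACY = "legacy"


@dataclass
class ModelLoadingResult:
    success: bool
    model: Optional[Any] = None
    error_message: Optional[str] = None
    implementation: Optional[ImplementationVersion] = None
    original_error: Optional[str] = None
    config: Optional[Dict[str, Any]] = None

    @classmethod
    def failure_result(cls, error_message: str, implementation: ImplementationVersion) -> "ModelLoadingResult":
        # Mirrors the failure path taken in load_model above
        return cls(success=False, error_message=error_message, implementation=implementation)

    @classmethod
    def fallback_success_result(cls, model: Any, original_error: str, config: Dict[str, Any]) -> "ModelLoadingResult":
        # Mirrors the legacy-fallback success path added in this commit
        return cls(success=True, model=model, implementation=ImplementationVersion.LEGACY,
                   original_error=original_error, config=config)
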
@@ -160,13 +173,67 @@ def _create_mel_band_roformer(self, config: Dict[str, Any]):
         }
         if 'sample_rate' in config:
             model_args['sample_rate'] = config['sample_rate']
-        if 'fmin' in config:
-            model_args['fmin'] = config['fmin']
-        if 'fmax' in config:
-            model_args['fmax'] = config['fmax']
+        # Optional parameters commonly present in legacy configs
+        for optional_key in [
+            'mask_estimator_depth',
+            'stft_n_fft',
+            'stft_hop_length',
+            'stft_win_length',
+            'stft_normalized',
+            'stft_window_fn',
+            'multi_stft_resolution_loss_weight',
+            'multi_stft_resolutions_window_sizes',
+            'multi_stft_hop_size',
+            'multi_stft_normalized',
+            'multi_stft_window_fn',
+            'match_input_audio_length',
+        ]:
+            if optional_key in config:
+                model_args[optional_key] = config[optional_key]
+        # Note: fmin and fmax are defined in config classes but not accepted by current constructor
         logger.debug(f"Creating MelBandRoformer with args: {list(model_args.keys())}")
         return MelBandRoformer(**model_args)
 
+    def _load_with_legacy_implementation(self,
+                                         model_path: str,
+                                         original_config: Dict[str, Any],
+                                         device: str,
+                                         original_error: str) -> ModelLoadingResult:
+        """
+        Attempt to load the model using the legacy direct-constructor path
+        for maximum backward compatibility with existing checkpoints.
+        """
+        import torch
+
+        # Use nested 'model' section if present; otherwise assume flat
+        model_cfg = original_config.get('model', original_config)
+
+        # Determine model type from config
+        if 'num_bands' in model_cfg:
+            from ..uvr_lib_v5.roformer.mel_band_roformer import MelBandRoformer
+            model = MelBandRoformer(**model_cfg)
+        elif 'freqs_per_bands' in model_cfg:
+            from ..uvr_lib_v5.roformer.bs_roformer import BSRoformer
+            model = BSRoformer(**model_cfg)
+        else:
+            raise ValueError("Unknown Roformer model type in legacy configuration")
+
+        # Load checkpoint as raw state dict (legacy behavior)
+        try:
+            checkpoint = torch.load(model_path, map_location='cpu', weights_only=True)
+        except TypeError:
+            # For older torch versions without weights_only
+            checkpoint = torch.load(model_path, map_location='cpu')
+
+        model.load_state_dict(checkpoint)
+        model.to(device).eval()
+
+        return ModelLoadingResult.fallback_success_result(
+            model=model,
+            original_error=original_error,
+            config=original_config,
+        )
+
     def get_loading_stats(self) -> Dict[str, int]:
         return self._loading_stats.copy()
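
For orientation, the config dict that _create_mel_band_roformer and the legacy fallback receive might look like the sketch below. The keys mirror the ones handled above; every value is illustrative and not taken from any shipped model config.

# Hypothetical legacy-style MelBandRoformer config; values are illustrative only.
example_config = {
    "dim": 384,
    "depth": 6,
    "num_stems": 1,
    "num_bands": 60,  # presence of num_bands is what routes the legacy loader to MelBandRoformer
    "sample_rate": 44100,
    # Optional keys now forwarded verbatim by the loop above:
    "mask_estimator_depth": 2,
    "stft_n_fft": 2048,
    "stft_hop_length": 441,
    "stft_win_length": 2048,
    "stft_normalized": False,
    "multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
    "multi_stft_hop_size": 147,
    "match_input_audio_length": False,
    # fmin / fmax may still appear in older config files, but the current
    # MelBandRoformer constructor does not accept them, so the loader drops them.
}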

@@ -232,8 +299,7 @@ def get_default_configuration(self, model_type: str) -> Dict[str, Any]:
                 'use_torch_checkpoint': False,
                 'skip_connection': False,
                 'sample_rate': 44100,
-                'fmin': 0,
-                'fmax': None,
+                # Note: fmin and fmax are not implemented in MelBandRoformer constructor
             }
         else:
             raise ValueError(f"Unknown model type: {model_type}")
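
The commit message notes that mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956.ckpt was added to the integration tests to cover this path. A usage sketch in the same spirit, assuming audio-separator's documented Python API and a local input.wav, would be:

# Usage sketch; assumes the audio-separator Python API and an existing input file.
from audio_separator.separator import Separator

separator = Separator()
# Loading this checkpoint previously hit the MelBandRoformer regression;
# with this commit it either loads directly or falls back to the legacy path.
separator.load_model(model_filename="mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956.ckpt")
output_files = separator.separate("input.wav")
print(output_files)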

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "audio-separator"
-version = "0.38.0"
+version = "0.38.1"
 description = "Easy to use audio stem separation, using various models from UVR trained primarily by @Anjok07"
 authors = ["Andrew Beveridge <[email protected]>"]
 license = "MIT"
[2 binary files not shown; 6 additional asset files changed (729 KB, 52.5 KB, 741 KB, 46 KB, 0 Bytes, 0 Bytes); file contents not rendered]

0 commit comments