Skip to content

Commit 54f5cd4

Browse files
gitttt-1234 and claude committed
Fix confusing weight loading logging for legacy models
When loading backbone/head weights separately from a legacy .h5 file, the logging showed confusing counts and warnings about unmatched weights that were expected (e.g., head weights when loading backbone).

Changes:
- Add `filter_legacy_weights_by_component()` to pre-filter weights
- Add `component` parameter to `load_legacy_model_weights()` and `map_legacy_to_pytorch_layers()` ("backbone", "head", or None)
- Update training call sites to pass component type
- Warnings now only show for truly unexpected mismatches
- Clean up unused imports

Before: Loading backbone shows warnings for all head weights
After: Loading backbone only processes backbone weights, no spurious warnings

Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 699a8e2 commit 54f5cd4

File tree

2 files changed

+67
-11
lines changed

2 files changed

+67
-11
lines changed

sleap_nn/legacy_models.py

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
import h5py
88
import numpy as np
99
import torch
10-
from typing import Dict, Tuple, Any, Optional, List
10+
from typing import Dict, Any, Optional
1111
from pathlib import Path
12-
from omegaconf import OmegaConf
1312
import re
1413
from loguru import logger
1514

@@ -181,18 +180,61 @@ def parse_keras_layer_name(layer_path: str) -> Dict[str, Any]:
181180
return info
182181

183182

183+
def filter_legacy_weights_by_component(
    legacy_weights: Dict[str, np.ndarray], component: Optional[str]
) -> Dict[str, np.ndarray]:
    """Filter legacy weights based on component type.

    A weight is classified as a head weight when its path contains the
    substring "Head"; every other weight is classified as a backbone weight.

    Args:
        legacy_weights: Dictionary of legacy weights from load_keras_weights()
        component: Component type to filter for. One of:
            - "backbone": Keep only weights whose path does not contain "Head"
            - "head": Keep only weights whose path contains "Head"
            - None: No filtering (keep all weights)

    Returns:
        Filtered dictionary of legacy weights. When `component` is None the
        input dictionary itself is returned (not a copy).

    Raises:
        ValueError: If `component` is not "backbone", "head", or None.
    """
    if component is None:
        return legacy_weights

    # Reject unknown values explicitly: without this check a typo such as
    # "Backbone" would match neither branch and silently drop every weight.
    if component not in ("backbone", "head"):
        raise ValueError(
            f"Unknown component {component!r}; expected 'backbone', 'head', or None."
        )

    keep_head_layers = component == "head"
    return {
        path: weight
        for path, weight in legacy_weights.items()
        if ("Head" in path) == keep_head_layers
    }
212+
213+
184214
def map_legacy_to_pytorch_layers(
185-
legacy_weights: Dict[str, np.ndarray], pytorch_model: torch.nn.Module
215+
legacy_weights: Dict[str, np.ndarray],
216+
pytorch_model: torch.nn.Module,
217+
component: Optional[str] = None,
186218
) -> Dict[str, str]:
187219
"""Create mapping between legacy Keras layers and PyTorch model layers.
188220
189221
Args:
190222
legacy_weights: Dictionary of legacy weights from load_keras_weights()
191223
pytorch_model: PyTorch model instance to map to
224+
component: Optional component type for filtering weights before mapping.
225+
One of "backbone", "head", or None (no filtering).
192226
193227
Returns:
194228
Dictionary mapping legacy layer paths to PyTorch parameter names
195229
"""
230+
# Filter weights based on component type
231+
filtered_weights = filter_legacy_weights_by_component(legacy_weights, component)
232+
233+
if component is not None:
234+
logger.info(
235+
f"Filtered legacy weights for {component}: "
236+
f"{len(filtered_weights)}/{len(legacy_weights)} weights"
237+
)
196238
mapping = {}
197239

198240
# Get all PyTorch parameters with their shapes
@@ -201,7 +243,7 @@ def map_legacy_to_pytorch_layers(
201243
pytorch_params[name] = param.shape
202244

203245
# For each legacy weight, find the corresponding PyTorch parameter
204-
for legacy_path, weight in legacy_weights.items():
246+
for legacy_path, weight in filtered_weights.items():
205247
# Extract the layer name from the legacy path
206248
# Legacy path format: "model_weights/stack0_enc0_conv0/stack0_enc0_conv0/kernel:0"
207249
clean_path = legacy_path.replace("model_weights/", "")
@@ -220,8 +262,6 @@ def map_legacy_to_pytorch_layers(
220262
# This handles cases where Keras uses suffixes like _0, _1, etc.
221263
if "Head" in layer_name:
222264
# Remove trailing _N where N is a number
223-
import re
224-
225265
layer_name_clean = re.sub(r"_\d+$", "", layer_name)
226266
else:
227267
layer_name_clean = layer_name
@@ -266,12 +306,17 @@ def map_legacy_to_pytorch_layers(
266306
if not mapping:
267307
logger.info(
268308
f"No mappings could be created between legacy weights and PyTorch model. "
269-
f"Legacy weights: {len(legacy_weights)}, PyTorch parameters: {len(pytorch_params)}"
309+
f"Legacy weights: {len(filtered_weights)}, PyTorch parameters: {len(pytorch_params)}"
270310
)
271311
else:
272312
logger.info(
273-
f"Successfully mapped {len(mapping)}/{len(legacy_weights)} legacy weights to PyTorch parameters"
313+
f"Successfully mapped {len(mapping)}/{len(pytorch_params)} PyTorch parameters from legacy weights"
274314
)
315+
unmatched_count = len(filtered_weights) - len(mapping)
316+
if unmatched_count > 0:
317+
logger.warning(
318+
f"({unmatched_count} legacy weights did not match any parameters in this model component)"
319+
)
275320

276321
return mapping
277322

@@ -280,6 +325,7 @@ def load_legacy_model_weights(
280325
pytorch_model: torch.nn.Module,
281326
h5_path: str,
282327
mapping: Optional[Dict[str, str]] = None,
328+
component: Optional[str] = None,
283329
) -> None:
284330
"""Load legacy Keras weights into a PyTorch model.
285331
@@ -288,14 +334,20 @@ def load_legacy_model_weights(
288334
h5_path: Path to the legacy .h5 model file
289335
mapping: Optional manual mapping of layer names. If None,
290336
will attempt automatic mapping.
337+
component: Optional component type for filtering weights. One of:
338+
- "backbone": Only load encoder/decoder weights (exclude heads)
339+
- "head": Only load head layer weights
340+
- None: Load all weights (default, for full model loading)
291341
"""
292342
# Load legacy weights
293343
legacy_weights = load_keras_weights(h5_path)
294344

295345
if mapping is None:
296346
# Attempt automatic mapping
297347
try:
298-
mapping = map_legacy_to_pytorch_layers(legacy_weights, pytorch_model)
348+
mapping = map_legacy_to_pytorch_layers(
349+
legacy_weights, pytorch_model, component=component
350+
)
299351
except Exception as e:
300352
logger.error(f"Failed to create weight mappings: {e}")
301353
return

sleap_nn/training/lightning_modules.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,9 @@ def __init__(
229229
elif self.pretrained_backbone_weights.endswith(".h5"):
230230
# load from sleap model weights
231231
load_legacy_model_weights(
232-
self.model.backbone, self.pretrained_backbone_weights
232+
self.model.backbone,
233+
self.pretrained_backbone_weights,
234+
component="backbone",
233235
)
234236

235237
else:
@@ -258,7 +260,9 @@ def __init__(
258260
elif self.pretrained_head_weights.endswith(".h5"):
259261
# load from sleap model weights
260262
load_legacy_model_weights(
261-
self.model.head_layers, self.pretrained_head_weights
263+
self.model.head_layers,
264+
self.pretrained_head_weights,
265+
component="head",
262266
)
263267

264268
else:

0 commit comments

Comments
 (0)