77import h5py
88import numpy as np
99import torch
10- from typing import Dict , Tuple , Any , Optional , List
10+ from typing import Dict , Any , Optional
1111from pathlib import Path
12- from omegaconf import OmegaConf
1312import re
1413from loguru import logger
1514
@@ -181,18 +180,61 @@ def parse_keras_layer_name(layer_path: str) -> Dict[str, Any]:
181180 return info
182181
183182
183+ def filter_legacy_weights_by_component (
184+ legacy_weights : Dict [str , np .ndarray ], component : Optional [str ]
185+ ) -> Dict [str , np .ndarray ]:
186+ """Filter legacy weights based on component type.
187+
188+ Args:
189+ legacy_weights: Dictionary of legacy weights from load_keras_weights()
190+ component: Component type to filter for. One of:
191+ - "backbone": Keep only encoder/decoder weights (exclude heads)
192+ - "head": Keep only head layer weights
193+ - None: No filtering (keep all weights)
194+
195+ Returns:
196+ Filtered dictionary of legacy weights
197+ """
198+ if component is None :
199+ return legacy_weights
200+
201+ filtered = {}
202+ for path , weight in legacy_weights .items ():
203+ # Check if this is a head layer (contains "Head" in the path)
204+ is_head_layer = "Head" in path
205+
206+ if component == "backbone" and not is_head_layer :
207+ filtered [path ] = weight
208+ elif component == "head" and is_head_layer :
209+ filtered [path ] = weight
210+
211+ return filtered
212+
213+
184214def map_legacy_to_pytorch_layers (
185- legacy_weights : Dict [str , np .ndarray ], pytorch_model : torch .nn .Module
215+ legacy_weights : Dict [str , np .ndarray ],
216+ pytorch_model : torch .nn .Module ,
217+ component : Optional [str ] = None ,
186218) -> Dict [str , str ]:
187219 """Create mapping between legacy Keras layers and PyTorch model layers.
188220
189221 Args:
190222 legacy_weights: Dictionary of legacy weights from load_keras_weights()
191223 pytorch_model: PyTorch model instance to map to
224+ component: Optional component type for filtering weights before mapping.
225+ One of "backbone", "head", or None (no filtering).
192226
193227 Returns:
194228 Dictionary mapping legacy layer paths to PyTorch parameter names
195229 """
230+ # Filter weights based on component type
231+ filtered_weights = filter_legacy_weights_by_component (legacy_weights , component )
232+
233+ if component is not None :
234+ logger .info (
235+ f"Filtered legacy weights for { component } : "
236+ f"{ len (filtered_weights )} /{ len (legacy_weights )} weights"
237+ )
196238 mapping = {}
197239
198240 # Get all PyTorch parameters with their shapes
@@ -201,7 +243,7 @@ def map_legacy_to_pytorch_layers(
201243 pytorch_params [name ] = param .shape
202244
203245 # For each legacy weight, find the corresponding PyTorch parameter
204- for legacy_path , weight in legacy_weights .items ():
246+ for legacy_path , weight in filtered_weights .items ():
205247 # Extract the layer name from the legacy path
206248 # Legacy path format: "model_weights/stack0_enc0_conv0/stack0_enc0_conv0/kernel:0"
207249 clean_path = legacy_path .replace ("model_weights/" , "" )
@@ -220,8 +262,6 @@ def map_legacy_to_pytorch_layers(
220262 # This handles cases where Keras uses suffixes like _0, _1, etc.
221263 if "Head" in layer_name :
222264 # Remove trailing _N where N is a number
223- import re
224-
225265 layer_name_clean = re .sub (r"_\d+$" , "" , layer_name )
226266 else :
227267 layer_name_clean = layer_name
@@ -266,12 +306,17 @@ def map_legacy_to_pytorch_layers(
266306 if not mapping :
267307 logger .info (
268308 f"No mappings could be created between legacy weights and PyTorch model. "
269- f"Legacy weights: { len (legacy_weights )} , PyTorch parameters: { len (pytorch_params )} "
309+ f"Legacy weights: { len (filtered_weights )} , PyTorch parameters: { len (pytorch_params )} "
270310 )
271311 else :
272312 logger .info (
273- f"Successfully mapped { len (mapping )} /{ len (legacy_weights )} legacy weights to PyTorch parameters "
313+ f"Successfully mapped { len (mapping )} /{ len (pytorch_params )} PyTorch parameters from legacy weights "
274314 )
315+ unmatched_count = len (filtered_weights ) - len (mapping )
316+ if unmatched_count > 0 :
317+ logger .warning (
318+ f"({ unmatched_count } legacy weights did not match any parameters in this model component)"
319+ )
275320
276321 return mapping
277322
@@ -280,6 +325,7 @@ def load_legacy_model_weights(
280325 pytorch_model : torch .nn .Module ,
281326 h5_path : str ,
282327 mapping : Optional [Dict [str , str ]] = None ,
328+ component : Optional [str ] = None ,
283329) -> None :
284330 """Load legacy Keras weights into a PyTorch model.
285331
@@ -288,14 +334,20 @@ def load_legacy_model_weights(
288334 h5_path: Path to the legacy .h5 model file
289335 mapping: Optional manual mapping of layer names. If None,
290336 will attempt automatic mapping.
337+ component: Optional component type for filtering weights. One of:
338+ - "backbone": Only load encoder/decoder weights (exclude heads)
339+ - "head": Only load head layer weights
340+ - None: Load all weights (default, for full model loading)
291341 """
292342 # Load legacy weights
293343 legacy_weights = load_keras_weights (h5_path )
294344
295345 if mapping is None :
296346 # Attempt automatic mapping
297347 try :
298- mapping = map_legacy_to_pytorch_layers (legacy_weights , pytorch_model )
348+ mapping = map_legacy_to_pytorch_layers (
349+ legacy_weights , pytorch_model , component = component
350+ )
299351 except Exception as e :
300352 logger .error (f"Failed to create weight mappings: { e } " )
301353 return
0 commit comments