@@ -216,11 +216,32 @@ def test_topdown_predictor(
     )
 
     assert np.all(np.abs(head_layer_ckpt - model_weights) < 1e-6)
-    print(
-        f"centered instance model: ",
-        predictor.inference_model.instance_peaks.torch_model.model,
+
+    # load only backbone and head ckpt as None - centered instance
+    predictor = Predictor.from_model_paths(
+        [minimal_instance_ckpt],
+        backbone_ckpt_path=Path(minimal_instance_ckpt) / "best.ckpt",
+        head_ckpt_path=None,
+        peak_threshold=0.03,
+        max_instances=6,
+        preprocess_config=OmegaConf.create(preprocess_config),
+    )
+
+    ckpt = torch.load(Path(minimal_instance_ckpt) / "best.ckpt")
+    backbone_ckpt = ckpt["state_dict"][
+        "model.backbone.enc.encoder_stack.0.blocks.0.weight"
+    ][0, 0, :].numpy()
+
+    model_weights = (
+        next(predictor.inference_model.instance_peaks.torch_model.model.parameters())[
+            0, 0, :
+        ]
+        .detach()
+        .numpy()
     )
 
+    assert np.all(np.abs(backbone_ckpt - model_weights) < 1e-6)
+
     # check loading diff head ckpt for centroid
     preprocess_config = {
         "is_rgb": False,
@@ -238,11 +259,32 @@ def test_topdown_predictor(
         preprocess_config=OmegaConf.create(preprocess_config),
     )
 
-    print(
-        f"centroid model: ", predictor.inference_model.centroid_crop.torch_model.model
+    ckpt = torch.load(Path(minimal_instance_ckpt) / "best.ckpt")
+    backbone_ckpt = ckpt["state_dict"][
+        "model.backbone.enc.encoder_stack.0.blocks.0.weight"
+    ][0, 0, :].numpy()
+
+    model_weights = (
+        next(predictor.inference_model.centroid_crop.torch_model.model.parameters())[
+            0, 0, :
+        ]
+        .detach()
+        .numpy()
+    )
+
+    assert np.all(np.abs(backbone_ckpt - model_weights) < 1e-6)
+
+    # load only backbone and head ckpt as None - centroid
+    predictor = Predictor.from_model_paths(
+        [minimal_instance_centroid_ckpt],
+        backbone_ckpt_path=Path(minimal_instance_centroid_ckpt) / "best.ckpt",
+        head_ckpt_path=None,
+        peak_threshold=0.03,
+        max_instances=6,
+        preprocess_config=OmegaConf.create(preprocess_config),
     )
 
-    ckpt = torch.load(Path(minimal_instance_ckpt) / "best.ckpt")
+    ckpt = torch.load(Path(minimal_instance_centroid_ckpt) / "best.ckpt")
     backbone_ckpt = ckpt["state_dict"][
         "model.backbone.enc.encoder_stack.0.blocks.0.weight"
     ][0, 0, :].numpy()
@@ -261,7 +303,6 @@ def test_topdown_predictor(
 def test_single_instance_predictor(
     minimal_instance,
     minimal_instance_ckpt,
-    minimal_instance_centroid_ckpt,
     minimal_instance_bottomup_ckpt,
 ):
     """Test SingleInstancePredictor module."""
@@ -453,6 +494,55 @@ def test_single_instance_predictor(
     # save the original config back
     OmegaConf.save(_config, f"{minimal_instance_ckpt}/training_config.yaml")
 
+    _config = OmegaConf.load(f"{minimal_instance_ckpt}/training_config.yaml")
+
+    config = _config.copy()
+
+    try:
+        head_config = config.model_config.head_configs.centered_instance
+        del config.model_config.head_configs.centered_instance
+        OmegaConf.update(
+            config, "model_config.head_configs.single_instance", head_config
+        )
+        del config.model_config.head_configs.single_instance.confmaps.anchor_part
+        OmegaConf.update(config, "data_config.preprocessing.scale", 0.9)
+
+        OmegaConf.save(config, f"{minimal_instance_ckpt}/training_config.yaml")
+
+        # check loading diff head ckpt
+        preprocess_config = {
+            "is_rgb": False,
+            "crop_hw": None,
+            "max_width": None,
+            "max_height": None,
+        }
+
+        predictor = Predictor.from_model_paths(
+            [minimal_instance_bottomup_ckpt],
+            backbone_ckpt_path=Path(minimal_instance_ckpt) / "best.ckpt",
+            head_ckpt_path=None,
+            peak_threshold=0.03,
+            max_instances=6,
+            preprocess_config=OmegaConf.create(preprocess_config),
+        )
+
+        ckpt = torch.load(Path(minimal_instance_ckpt) / "best.ckpt")
+        backbone_ckpt = ckpt["state_dict"][
+            "model.backbone.enc.encoder_stack.0.blocks.0.weight"
+        ][0, 0, :].numpy()
+
+        model_weights = (
+            next(predictor.inference_model.torch_model.model.parameters())[0, 0, :]
+            .detach()
+            .numpy()
+        )
+
+        assert np.all(np.abs(backbone_ckpt - model_weights) < 1e-6)
+
+    finally:
+        # save the original config back
+        OmegaConf.save(_config, f"{minimal_instance_ckpt}/training_config.yaml")
+
 
 def test_bottomup_predictor(
     minimal_instance, minimal_instance_bottomup_ckpt, minimal_instance_ckpt
@@ -611,3 +701,26 @@ def test_bottomup_predictor(
     print(model_weights)
 
     assert np.all(np.abs(head_layer_ckpt - model_weights) < 1e-6)
+
+    # load only backbone and head ckpt as None
+    predictor = Predictor.from_model_paths(
+        [minimal_instance_bottomup_ckpt],
+        backbone_ckpt_path=Path(minimal_instance_ckpt) / "best.ckpt",
+        head_ckpt_path=None,
+        peak_threshold=0.03,
+        max_instances=6,
+        preprocess_config=OmegaConf.create(preprocess_config),
+    )
+
+    ckpt = torch.load(Path(minimal_instance_ckpt) / "best.ckpt")
+    backbone_ckpt = ckpt["state_dict"][
+        "model.backbone.enc.encoder_stack.0.blocks.0.weight"
+    ][0, 0, :].numpy()
+
+    model_weights = (
+        next(predictor.inference_model.torch_model.model.parameters())[0, 0, :]
+        .detach()
+        .numpy()
+    )
+
+    assert np.all(np.abs(backbone_ckpt - model_weights) < 1e-6)