Skip to content

Commit b843410

Browse files
committed
Add more tests for backbone ckpt
1 parent 915ccb0 commit b843410

File tree

1 file changed

+120
-7
lines changed

1 file changed

+120
-7
lines changed

tests/inference/test_predictors.py

Lines changed: 120 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -216,11 +216,32 @@ def test_topdown_predictor(
216216
)
217217

218218
assert np.all(np.abs(head_layer_ckpt - model_weights) < 1e-6)
219-
print(
220-
f"centered instance model: ",
221-
predictor.inference_model.instance_peaks.torch_model.model,
219+
220+
# load only backbone and head ckpt as None - centered instance
221+
predictor = Predictor.from_model_paths(
222+
[minimal_instance_ckpt],
223+
backbone_ckpt_path=Path(minimal_instance_ckpt) / "best.ckpt",
224+
head_ckpt_path=None,
225+
peak_threshold=0.03,
226+
max_instances=6,
227+
preprocess_config=OmegaConf.create(preprocess_config),
228+
)
229+
230+
ckpt = torch.load(Path(minimal_instance_ckpt) / "best.ckpt")
231+
backbone_ckpt = ckpt["state_dict"][
232+
"model.backbone.enc.encoder_stack.0.blocks.0.weight"
233+
][0, 0, :].numpy()
234+
235+
model_weights = (
236+
next(predictor.inference_model.instance_peaks.torch_model.model.parameters())[
237+
0, 0, :
238+
]
239+
.detach()
240+
.numpy()
222241
)
223242

243+
assert np.all(np.abs(backbone_ckpt - model_weights) < 1e-6)
244+
224245
# check loading diff head ckpt for centroid
225246
preprocess_config = {
226247
"is_rgb": False,
@@ -238,11 +259,32 @@ def test_topdown_predictor(
238259
preprocess_config=OmegaConf.create(preprocess_config),
239260
)
240261

241-
print(
242-
f"centroid model: ", predictor.inference_model.centroid_crop.torch_model.model
262+
ckpt = torch.load(Path(minimal_instance_ckpt) / "best.ckpt")
263+
backbone_ckpt = ckpt["state_dict"][
264+
"model.backbone.enc.encoder_stack.0.blocks.0.weight"
265+
][0, 0, :].numpy()
266+
267+
model_weights = (
268+
next(predictor.inference_model.centroid_crop.torch_model.model.parameters())[
269+
0, 0, :
270+
]
271+
.detach()
272+
.numpy()
273+
)
274+
275+
assert np.all(np.abs(backbone_ckpt - model_weights) < 1e-6)
276+
277+
# load only backbone and head ckpt as None - centroid
278+
predictor = Predictor.from_model_paths(
279+
[minimal_instance_centroid_ckpt],
280+
backbone_ckpt_path=Path(minimal_instance_centroid_ckpt) / "best.ckpt",
281+
head_ckpt_path=None,
282+
peak_threshold=0.03,
283+
max_instances=6,
284+
preprocess_config=OmegaConf.create(preprocess_config),
243285
)
244286

245-
ckpt = torch.load(Path(minimal_instance_ckpt) / "best.ckpt")
287+
ckpt = torch.load(Path(minimal_instance_centroid_ckpt) / "best.ckpt")
246288
backbone_ckpt = ckpt["state_dict"][
247289
"model.backbone.enc.encoder_stack.0.blocks.0.weight"
248290
][0, 0, :].numpy()
@@ -261,7 +303,6 @@ def test_topdown_predictor(
261303
def test_single_instance_predictor(
262304
minimal_instance,
263305
minimal_instance_ckpt,
264-
minimal_instance_centroid_ckpt,
265306
minimal_instance_bottomup_ckpt,
266307
):
267308
"""Test SingleInstancePredictor module."""
@@ -453,6 +494,55 @@ def test_single_instance_predictor(
453494
# save the original config back
454495
OmegaConf.save(_config, f"{minimal_instance_ckpt}/training_config.yaml")
455496

497+
_config = OmegaConf.load(f"{minimal_instance_ckpt}/training_config.yaml")
498+
499+
config = _config.copy()
500+
501+
try:
502+
head_config = config.model_config.head_configs.centered_instance
503+
del config.model_config.head_configs.centered_instance
504+
OmegaConf.update(
505+
config, "model_config.head_configs.single_instance", head_config
506+
)
507+
del config.model_config.head_configs.single_instance.confmaps.anchor_part
508+
OmegaConf.update(config, "data_config.preprocessing.scale", 0.9)
509+
510+
OmegaConf.save(config, f"{minimal_instance_ckpt}/training_config.yaml")
511+
512+
# check loading diff head ckpt
513+
preprocess_config = {
514+
"is_rgb": False,
515+
"crop_hw": None,
516+
"max_width": None,
517+
"max_height": None,
518+
}
519+
520+
predictor = Predictor.from_model_paths(
521+
[minimal_instance_bottomup_ckpt],
522+
backbone_ckpt_path=Path(minimal_instance_ckpt) / "best.ckpt",
523+
head_ckpt_path=None,
524+
peak_threshold=0.03,
525+
max_instances=6,
526+
preprocess_config=OmegaConf.create(preprocess_config),
527+
)
528+
529+
ckpt = torch.load(Path(minimal_instance_ckpt) / "best.ckpt")
530+
backbone_ckpt = ckpt["state_dict"][
531+
"model.backbone.enc.encoder_stack.0.blocks.0.weight"
532+
][0, 0, :].numpy()
533+
534+
model_weights = (
535+
next(predictor.inference_model.torch_model.model.parameters())[0, 0, :]
536+
.detach()
537+
.numpy()
538+
)
539+
540+
assert np.all(np.abs(backbone_ckpt - model_weights) < 1e-6)
541+
542+
finally:
543+
# save the original config back
544+
OmegaConf.save(_config, f"{minimal_instance_ckpt}/training_config.yaml")
545+
456546

457547
def test_bottomup_predictor(
458548
minimal_instance, minimal_instance_bottomup_ckpt, minimal_instance_ckpt
@@ -611,3 +701,26 @@ def test_bottomup_predictor(
611701
print(model_weights)
612702

613703
assert np.all(np.abs(head_layer_ckpt - model_weights) < 1e-6)
704+
705+
# load only backbone and head ckpt as None
706+
predictor = Predictor.from_model_paths(
707+
[minimal_instance_bottomup_ckpt],
708+
backbone_ckpt_path=Path(minimal_instance_ckpt) / "best.ckpt",
709+
head_ckpt_path=None,
710+
peak_threshold=0.03,
711+
max_instances=6,
712+
preprocess_config=OmegaConf.create(preprocess_config),
713+
)
714+
715+
ckpt = torch.load(Path(minimal_instance_ckpt) / "best.ckpt")
716+
backbone_ckpt = ckpt["state_dict"][
717+
"model.backbone.enc.encoder_stack.0.blocks.0.weight"
718+
][0, 0, :].numpy()
719+
720+
model_weights = (
721+
next(predictor.inference_model.torch_model.model.parameters())[0, 0, :]
722+
.detach()
723+
.numpy()
724+
)
725+
726+
assert np.all(np.abs(backbone_ckpt - model_weights) < 1e-6)

0 commit comments

Comments
 (0)