Skip to content

Commit bf93b5c

Browse files
committed
Fix skeleton bug in inference
1 parent 7232bdd commit bf93b5c

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

sleap_nn/inference/predictors.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,12 @@ def from_trained_models(
783783
(Path(centroid_ckpt_path) / "training_config.json").as_posix()
784784
)
785785

786-
# skeletons = get_skeleton_from_config(centroid_config.data_config.skeletons)
786+
skeletons_dict = {}
787+
for skl_idx, skl in enumerate(centroid_config.data_config.skeletons):
788+
skel = SkeletonYAMLDecoder().decode(dict(skl))
789+
skel.name = skl.name
790+
skeletons_dict[skl_idx] = [skel]
791+
skeletons = skeletons_dict[output_head_skeleton_num]
787792

788793
# check which backbone architecture
789794
for k, v in centroid_config.model_config.backbone_config.items():
@@ -792,12 +797,6 @@ def from_trained_models(
792797
break
793798

794799
if not is_sleap_ckpt:
795-
skeletons_dict = {}
796-
for k in centroid_config.data_config.skeletons:
797-
skeletons_dict[k] = get_skeleton_from_config(
798-
centroid_config.data_config.skeletons[k]
799-
)
800-
skeletons = skeletons_dict[output_head_skeleton_num]
801800
ckpt_path = (Path(centroid_ckpt_path) / "best.ckpt").as_posix()
802801
centroid_model = CentroidMultiHeadLightningModule.load_from_checkpoint(
803802
checkpoint_path=ckpt_path,
@@ -906,6 +905,12 @@ def from_trained_models(
906905
)
907906

908907
# skeletons = get_skeleton_from_config(confmap_config.data_config.skeletons)
908+
skeletons_dict = {}
909+
for skl_idx, skl in enumerate(confmap_config.data_config.skeletons):
910+
skel = SkeletonYAMLDecoder().decode(dict(skl))
911+
skel.name = skl.name
912+
skeletons_dict[skl_idx] = [skel]
913+
skeletons = skeletons_dict[output_head_skeleton_num]
909914

910915
# check which backbone architecture
911916
for k, v in confmap_config.model_config.backbone_config.items():
@@ -915,12 +920,6 @@ def from_trained_models(
915920

916921
if not is_sleap_ckpt:
917922
ckpt_path = (Path(confmap_ckpt_path) / "best.ckpt").as_posix()
918-
skeletons_dict = {}
919-
for skl_idx, skl in enumerate(confmap_config.data_config.skeletons):
920-
skel = SkeletonYAMLDecoder().decode(dict(skl))
921-
skel.name = skl.name
922-
skeletons_dict[skl_idx] = [skel]
923-
skeletons = skeletons_dict[output_head_skeleton_num]
924923
confmap_model = TopDownCenteredInstanceMultiHeadLightningModule.load_from_checkpoint(
925924
checkpoint_path=ckpt_path,
926925
model_type="centered_instance",

0 commit comments

Comments
 (0)