@@ -783,7 +783,12 @@ def from_trained_models(
783783 (Path (centroid_ckpt_path ) / "training_config.json" ).as_posix ()
784784 )
785785
786- # skeletons = get_skeleton_from_config(centroid_config.data_config.skeletons)
786+ skeletons_dict = {}
787+ for skl_idx , skl in enumerate (centroid_config .data_config .skeletons ):
788+ skel = SkeletonYAMLDecoder ().decode (dict (skl ))
789+ skel .name = skl .name
790+ skeletons_dict [skl_idx ] = [skel ]
791+ skeletons = skeletons_dict [output_head_skeleton_num ]
787792
788793 # check which backbone architecture
789794 for k , v in centroid_config .model_config .backbone_config .items ():
@@ -792,12 +797,6 @@ def from_trained_models(
792797 break
793798
794799 if not is_sleap_ckpt :
795- skeletons_dict = {}
796- for k in centroid_config .data_config .skeletons :
797- skeletons_dict [k ] = get_skeleton_from_config (
798- centroid_config .data_config .skeletons [k ]
799- )
800- skeletons = skeletons_dict [output_head_skeleton_num ]
801800 ckpt_path = (Path (centroid_ckpt_path ) / "best.ckpt" ).as_posix ()
802801 centroid_model = CentroidMultiHeadLightningModule .load_from_checkpoint (
803802 checkpoint_path = ckpt_path ,
@@ -906,6 +905,12 @@ def from_trained_models(
906905 )
907906
908907 # skeletons = get_skeleton_from_config(confmap_config.data_config.skeletons)
908+ skeletons_dict = {}
909+ for skl_idx , skl in enumerate (confmap_config .data_config .skeletons ):
910+ skel = SkeletonYAMLDecoder ().decode (dict (skl ))
911+ skel .name = skl .name
912+ skeletons_dict [skl_idx ] = [skel ]
913+ skeletons = skeletons_dict [output_head_skeleton_num ]
909914
910915 # check which backbone architecture
911916 for k , v in confmap_config .model_config .backbone_config .items ():
@@ -915,12 +920,6 @@ def from_trained_models(
915920
916921 if not is_sleap_ckpt :
917922 ckpt_path = (Path (confmap_ckpt_path ) / "best.ckpt" ).as_posix ()
918- skeletons_dict = {}
919- for skl_idx , skl in enumerate (confmap_config .data_config .skeletons ):
920- skel = SkeletonYAMLDecoder ().decode (dict (skl ))
921- skel .name = skl .name
922- skeletons_dict [skl_idx ] = [skel ]
923- skeletons = skeletons_dict [output_head_skeleton_num ]
924923 confmap_model = TopDownCenteredInstanceMultiHeadLightningModule .load_from_checkpoint (
925924 checkpoint_path = ckpt_path ,
926925 model_type = "centered_instance" ,
0 commit comments