@@ -832,3 +832,154 @@ def validate_pre_trained_weights(self, value):
832832 message = "UNet does not support pre-trained weights."
833833 logger .error (message )
834834 raise ValueError (message )
835+
836+
837+ def model_mapper (legacy_config : dict ) -> ModelConfig :
838+ """Map the legacy model configuration to the new model configuration.
839+
840+ Args:
841+ legacy_config: A dictionary containing the legacy model configuration.
842+
843+ Returns:
844+ An instance of `ModelConfig` with the mapped configuration.
845+ """
846+ legacy_config_model = legacy_config .get ("model" , {})
847+ return ModelConfig (
848+ backbone_config = BackboneConfig (
849+ unet = (
850+ UNetConfig (
851+ filters = legacy_config_model .get ("backbone" , {})
852+ .get ("unet" , {})
853+ .get ("filters" , 32 ),
854+ filters_rate = legacy_config_model .get ("backbone" , {})
855+ .get ("unet" , {})
856+ .get ("filters_rate" , 1.5 ),
857+ max_stride = legacy_config_model .get ("backbone" , {})
858+ .get ("unet" , {})
859+ .get ("max_stride" , 16 ),
860+ stem_stride = legacy_config_model .get ("backbone" , {})
861+ .get ("unet" , {})
862+ .get ("stem_stride" , 16 ),
863+ middle_block = legacy_config_model .get ("backbone" , {})
864+ .get ("unet" , {})
865+ .get ("middle_block" , True ),
866+ up_interpolate = legacy_config_model .get ("backbone" , {})
867+ .get ("unet" , {})
868+ .get ("up_interpolate" , True ),
869+ stacks = legacy_config_model .get ("backbone" , {})
870+ .get ("unet" , {})
871+ .get ("stacks" , 1 ),
872+ # convs_per_block=2,
873+ output_stride = legacy_config_model .get ("backbone" , {})
874+ .get ("unet" , {})
875+ .get ("output_stride" , 1 ),
876+ )
877+ if legacy_config_model .get ("backbone" , {}).get ("unet" , None ) is not None
878+ else None
879+ ),
880+ ),
881+ head_configs = HeadConfig (
882+ single_instance = (
883+ (
884+ SingleInstanceConfig (
885+ confmaps = SingleInstanceConfMapsConfig (
886+ part_names = legacy_config_model .get ("heads" , {})
887+ .get ("single_instance" , {})
888+ .get ("part_names" , None ),
889+ sigma = legacy_config_model .get ("heads" , {})
890+ .get ("single_instance" , {})
891+ .get ("sigma" , 5.0 ),
892+ output_stride = legacy_config_model .get ("heads" , {})
893+ .get ("single_instance" , {})
894+ .get ("output_stride" , 1 ),
895+ )
896+ )
897+ )
898+ if legacy_config_model .get ("heads" , {}).get ("single_instance" , None )
899+ is not None
900+ else None
901+ ),
902+ centroid = (
903+ CentroidConfig (
904+ confmaps = CentroidConfMapsConfig (
905+ anchor_part = legacy_config_model .get ("heads" , {})
906+ .get ("centroid" , {})
907+ .get ("anchor_part" , None ),
908+ sigma = legacy_config_model .get ("heads" , {})
909+ .get ("centroid" , {})
910+ .get ("sigma" , 5.0 ),
911+ output_stride = legacy_config_model .get ("heads" , {})
912+ .get ("centroid" , {})
913+ .get ("output_stride" , 1 ),
914+ )
915+ )
916+ if legacy_config_model .get ("heads" , {}).get ("centroid" , None )
917+ is not None
918+ else None
919+ ),
920+ centered_instance = (
921+ CenteredInstanceConfig (
922+ confmaps = CenteredInstanceConfMapsConfig (
923+ anchor_part = legacy_config_model .get ("heads" , {})
924+ .get ("centered_instance" , {})
925+ .get ("anchor_part" , None ),
926+ sigma = legacy_config_model .get ("heads" , {})
927+ .get ("centered_instance" , {})
928+ .get ("sigma" , 5.0 ),
929+ output_stride = legacy_config_model .get ("heads" , {})
930+ .get ("centered_instance" , {})
931+ .get ("output_stride" , 1 ),
932+ part_names = legacy_config_model .get ("heads" , {})
933+ .get ("centered_instance" , {})
934+ .get ("part_names" , None ),
935+ )
936+ )
937+ if legacy_config_model .get ("heads" , {}).get ("centered_instance" , None )
938+ is not None
939+ else None
940+ ),
941+ bottomup = (
942+ BottomUpConfig (
943+ confmaps = BottomUpConfMapsConfig (
944+ loss_weight = legacy_config_model .get ("heads" , {})
945+ .get ("multi_instance" , {})
946+ .get ("confmaps" , {})
947+ .get ("loss_weight" , None ),
948+ sigma = legacy_config_model .get ("heads" , {})
949+ .get ("multi_instance" , {})
950+ .get ("confmaps" , {})
951+ .get ("sigma" , 5.0 ),
952+ output_stride = legacy_config_model .get ("heads" , {})
953+ .get ("multi_instance" , {})
954+ .get ("confmaps" , {})
955+ .get ("output_stride" , 1 ),
956+ part_names = legacy_config_model .get ("heads" , {})
957+ .get ("multi_instance" , {})
958+ .get ("confmaps" , {})
959+ .get ("part_names" , None ),
960+ ),
961+ pafs = PAFConfig (
962+ edges = legacy_config_model .get ("heads" , {})
963+ .get ("multi_instance" , {})
964+ .get ("pafs" , {})
965+ .get ("edges" , None ),
966+ sigma = legacy_config_model .get ("heads" , {})
967+ .get ("multi_instance" , {})
968+ .get ("pafs" , {})
969+ .get ("sigma" , 15.0 ),
970+ output_stride = legacy_config_model .get ("heads" , {})
971+ .get ("multi_instance" , {})
972+ .get ("pafs" , {})
973+ .get ("output_stride" , 1 ),
974+ loss_weight = legacy_config_model .get ("heads" , {})
975+ .get ("multi_instance" , {})
976+ .get ("pafs" , {})
977+ .get ("loss_weight" , None ),
978+ ),
979+ )
980+ if legacy_config_model .get ("heads" , {}).get ("multi_instance" , None )
981+ is not None
982+ else None
983+ ),
984+ ),
985+ )
0 commit comments