Skip to content

Commit 5c98f88

Browse files
authored
Merge pull request #162 from talmolab/greg/map-old-sleap-config-files-to-new
Map legacy SLEAP `json` configs to SLEAP-NN `OmegaConf` objects
2 parents 5c3a38d + 482c8cd commit 5c98f88

17 files changed

+2401
-9
lines changed

sleap_nn/.DS_Store

6 KB
Binary file not shown.

sleap_nn/config/data_config.py

Lines changed: 148 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,8 @@ class DataConfig:
170170
train dataset and saved to the `training_config.yaml`
171171
"""
172172

173-
train_labels_path: str = MISSING
174-
val_labels_path: str = MISSING
173+
train_labels_path: Optional[str] = None
174+
val_labels_path: Optional[str] = None # TODO : revisit MISSING!
175175
test_file_path: Optional[str] = None
176176
provider: str = "LabelsReader"
177177
user_instances_only: bool = True
@@ -185,3 +185,149 @@ class DataConfig:
185185
use_augmentations_train: bool = False
186186
augmentation_config: Optional[AugmentationConfig] = None
187187
skeletons: Optional[dict] = None
188+
189+
190+
def data_mapper(legacy_config: dict) -> DataConfig:
191+
"""Maps the legacy data configuration to the new data configuration.
192+
193+
Args:
194+
legacy_config: A dictionary containing the legacy data configuration.
195+
196+
Returns:
197+
An instance of `DataConfig` with the mapped configuration.
198+
"""
199+
legacy_config_data = legacy_config.get("data", {})
200+
legacy_config_optimization = legacy_config.get("optimization", {})
201+
202+
return DataConfig(
203+
train_labels_path=legacy_config_data.get("labels", {}).get(
204+
"training_labels", None
205+
),
206+
val_labels_path=legacy_config_data.get("labels", {}).get(
207+
"validation_labels", None
208+
),
209+
test_file_path=legacy_config_data.get("labels", {}).get("test_labels", None),
210+
preprocessing=PreprocessingConfig(
211+
is_rgb=legacy_config_data.get("preprocessing", {}).get("ensure_rgb", False),
212+
max_height=legacy_config_data.get("preprocessing", {}).get(
213+
"target_height", None
214+
),
215+
max_width=legacy_config_data.get("preprocessing", {}).get(
216+
"target_width", None
217+
),
218+
scale=legacy_config_data.get("preprocessing", {}).get("input_scaling", 1.0),
219+
crop_hw=(
220+
(
221+
legacy_config_data.get("instance_cropping", {}).get(
222+
"crop_size", None
223+
),
224+
legacy_config_data.get("instance_cropping", {}).get(
225+
"crop_size", None
226+
),
227+
)
228+
if legacy_config_data.get("instance_cropping", {}).get(
229+
"crop_size", None
230+
)
231+
is not None
232+
else None
233+
),
234+
),
235+
augmentation_config=(
236+
AugmentationConfig(
237+
intensity=IntensityConfig(
238+
uniform_noise_min=legacy_config_optimization.get(
239+
"augmentation_config", {}
240+
).get("uniform_noise_min_val", 0.0),
241+
uniform_noise_max=min(
242+
legacy_config_optimization.get("augmentation_config", {}).get(
243+
"uniform_noise_max_val", 1.0
244+
),
245+
1.0,
246+
),
247+
uniform_noise_p=float(
248+
legacy_config_optimization.get("augmentation_config", {}).get(
249+
"uniform_noise", 1.0
250+
)
251+
),
252+
gaussian_noise_mean=legacy_config_optimization.get(
253+
"augmentation_config", {}
254+
).get("gaussian_noise_mean", 0.0),
255+
gaussian_noise_std=legacy_config_optimization.get(
256+
"augmentation_config", {}
257+
).get("gaussian_noise_stddev", 1.0),
258+
gaussian_noise_p=float(
259+
legacy_config_optimization.get("augmentation_config", {}).get(
260+
"gaussian_noise", 1.0
261+
)
262+
),
263+
contrast_min=legacy_config_optimization.get(
264+
"augmentation_config", {}
265+
).get("contrast_min_gamma", 0.5),
266+
contrast_max=legacy_config_optimization.get(
267+
"augmentation_config", {}
268+
).get("contrast_max_gamma", 2.0),
269+
contrast_p=float(
270+
legacy_config_optimization.get("augmentation_config", {}).get(
271+
"contrast", 1.0
272+
)
273+
),
274+
brightness=(
275+
legacy_config_optimization.get("augmentation_config", {}).get(
276+
"brightness_min_val", 1.0
277+
),
278+
legacy_config_optimization.get("augmentation_config", {}).get(
279+
"brightness_max_val", 1.0
280+
),
281+
),
282+
brightness_p=float(
283+
legacy_config_optimization.get("augmentation_config", {}).get(
284+
"brightness", 1.0
285+
)
286+
),
287+
),
288+
geometric=GeometricConfig(
289+
rotation=(
290+
legacy_config_optimization.get("augmentation_config", {}).get(
291+
"rotation_max_angle", 15.0
292+
)
293+
if legacy_config_optimization.get(
294+
"augmentation_config", {}
295+
).get("rotate", True)
296+
else 0
297+
),
298+
scale=(
299+
(
300+
legacy_config_optimization.get(
301+
"augmentation_config", {}
302+
).get("scale_min", 0.9),
303+
legacy_config_optimization.get(
304+
"augmentation_config", {}
305+
).get("scale_max", 1.1),
306+
)
307+
if legacy_config_optimization.get(
308+
"augmentation_config", {}
309+
).get("scale", False)
310+
else (1.0, 1.0)
311+
),
312+
affine_p=(
313+
1.0
314+
if any(
315+
[
316+
legacy_config_optimization.get(
317+
"augmentation_config", {}
318+
).get("rotate", True),
319+
legacy_config_optimization.get(
320+
"augmentation_config", {}
321+
).get("scale", False),
322+
]
323+
)
324+
else 0.0
325+
),
326+
),
327+
)
328+
),
329+
use_augmentations_train=True,
330+
skeletons=legacy_config_data.get("labels", {}).get("skeletons", [{}])[
331+
0
332+
], # TODO
333+
)

sleap_nn/config/model_config.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,3 +832,154 @@ def validate_pre_trained_weights(self, value):
832832
message = "UNet does not support pre-trained weights."
833833
logger.error(message)
834834
raise ValueError(message)
835+
836+
837+
def model_mapper(legacy_config: dict) -> ModelConfig:
838+
"""Map the legacy model configuration to the new model configuration.
839+
840+
Args:
841+
legacy_config: A dictionary containing the legacy model configuration.
842+
843+
Returns:
844+
An instance of `ModelConfig` with the mapped configuration.
845+
"""
846+
legacy_config_model = legacy_config.get("model", {})
847+
return ModelConfig(
848+
backbone_config=BackboneConfig(
849+
unet=(
850+
UNetConfig(
851+
filters=legacy_config_model.get("backbone", {})
852+
.get("unet", {})
853+
.get("filters", 32),
854+
filters_rate=legacy_config_model.get("backbone", {})
855+
.get("unet", {})
856+
.get("filters_rate", 1.5),
857+
max_stride=legacy_config_model.get("backbone", {})
858+
.get("unet", {})
859+
.get("max_stride", 16),
860+
stem_stride=legacy_config_model.get("backbone", {})
861+
.get("unet", {})
862+
.get("stem_stride", 16),
863+
middle_block=legacy_config_model.get("backbone", {})
864+
.get("unet", {})
865+
.get("middle_block", True),
866+
up_interpolate=legacy_config_model.get("backbone", {})
867+
.get("unet", {})
868+
.get("up_interpolate", True),
869+
stacks=legacy_config_model.get("backbone", {})
870+
.get("unet", {})
871+
.get("stacks", 1),
872+
# convs_per_block=2,
873+
output_stride=legacy_config_model.get("backbone", {})
874+
.get("unet", {})
875+
.get("output_stride", 1),
876+
)
877+
if legacy_config_model.get("backbone", {}).get("unet", None) is not None
878+
else None
879+
),
880+
),
881+
head_configs=HeadConfig(
882+
single_instance=(
883+
(
884+
SingleInstanceConfig(
885+
confmaps=SingleInstanceConfMapsConfig(
886+
part_names=legacy_config_model.get("heads", {})
887+
.get("single_instance", {})
888+
.get("part_names", None),
889+
sigma=legacy_config_model.get("heads", {})
890+
.get("single_instance", {})
891+
.get("sigma", 5.0),
892+
output_stride=legacy_config_model.get("heads", {})
893+
.get("single_instance", {})
894+
.get("output_stride", 1),
895+
)
896+
)
897+
)
898+
if legacy_config_model.get("heads", {}).get("single_instance", None)
899+
is not None
900+
else None
901+
),
902+
centroid=(
903+
CentroidConfig(
904+
confmaps=CentroidConfMapsConfig(
905+
anchor_part=legacy_config_model.get("heads", {})
906+
.get("centroid", {})
907+
.get("anchor_part", None),
908+
sigma=legacy_config_model.get("heads", {})
909+
.get("centroid", {})
910+
.get("sigma", 5.0),
911+
output_stride=legacy_config_model.get("heads", {})
912+
.get("centroid", {})
913+
.get("output_stride", 1),
914+
)
915+
)
916+
if legacy_config_model.get("heads", {}).get("centroid", None)
917+
is not None
918+
else None
919+
),
920+
centered_instance=(
921+
CenteredInstanceConfig(
922+
confmaps=CenteredInstanceConfMapsConfig(
923+
anchor_part=legacy_config_model.get("heads", {})
924+
.get("centered_instance", {})
925+
.get("anchor_part", None),
926+
sigma=legacy_config_model.get("heads", {})
927+
.get("centered_instance", {})
928+
.get("sigma", 5.0),
929+
output_stride=legacy_config_model.get("heads", {})
930+
.get("centered_instance", {})
931+
.get("output_stride", 1),
932+
part_names=legacy_config_model.get("heads", {})
933+
.get("centered_instance", {})
934+
.get("part_names", None),
935+
)
936+
)
937+
if legacy_config_model.get("heads", {}).get("centered_instance", None)
938+
is not None
939+
else None
940+
),
941+
bottomup=(
942+
BottomUpConfig(
943+
confmaps=BottomUpConfMapsConfig(
944+
loss_weight=legacy_config_model.get("heads", {})
945+
.get("multi_instance", {})
946+
.get("confmaps", {})
947+
.get("loss_weight", None),
948+
sigma=legacy_config_model.get("heads", {})
949+
.get("multi_instance", {})
950+
.get("confmaps", {})
951+
.get("sigma", 5.0),
952+
output_stride=legacy_config_model.get("heads", {})
953+
.get("multi_instance", {})
954+
.get("confmaps", {})
955+
.get("output_stride", 1),
956+
part_names=legacy_config_model.get("heads", {})
957+
.get("multi_instance", {})
958+
.get("confmaps", {})
959+
.get("part_names", None),
960+
),
961+
pafs=PAFConfig(
962+
edges=legacy_config_model.get("heads", {})
963+
.get("multi_instance", {})
964+
.get("pafs", {})
965+
.get("edges", None),
966+
sigma=legacy_config_model.get("heads", {})
967+
.get("multi_instance", {})
968+
.get("pafs", {})
969+
.get("sigma", 15.0),
970+
output_stride=legacy_config_model.get("heads", {})
971+
.get("multi_instance", {})
972+
.get("pafs", {})
973+
.get("output_stride", 1),
974+
loss_weight=legacy_config_model.get("heads", {})
975+
.get("multi_instance", {})
976+
.get("pafs", {})
977+
.get("loss_weight", None),
978+
),
979+
)
980+
if legacy_config_model.get("heads", {}).get("multi_instance", None)
981+
is not None
982+
else None
983+
),
984+
),
985+
)

0 commit comments

Comments
 (0)