Map legacy SLEAP json configs to SLEAP-NN OmegaConf objects #162

Merged: 23 commits from greg/map-old-sleap-config-files-to-new into main, May 22, 2025.

Note: the diff below shows changes from 13 of the 23 commits.

Commits (23):
53345e4  outline for loading old sleap config (gqcpm, Mar 18, 2025)
369dcff  initial mappers function (gqcpm, Mar 20, 2025)
5649f45  check geometric config and comment out all parameters not addressed i… (gqcpm, Mar 21, 2025)
cdc2c99  model_config map backbone, heads, and lint other files (gqcpm, Mar 22, 2025)
d5c871b  map train_loader, val_loader, ckpt, max_epochs, optimizer, lr_schedul… (gqcpm, Mar 27, 2025)
e642b41  test trainer_mapper function (gqcpm, Mar 28, 2025)
69a65d4  model_config tests and legacy config path adjustments (gqcpm, Mar 28, 2025)
3e86b38  add tests for data config and adjust paths to map legacy config (gqcpm, Mar 28, 2025)
ac9327c  test training_job_config with json file (gqcpm, Apr 1, 2025)
d4edbaa  linted. (gqcpm, Apr 1, 2025)
70a5052  uniformnoisemax set to max 1, missingerror on no train/val path (gqcpm, Apr 3, 2025)
4b82c4c  add different head types and adjust with tests (gqcpm, Apr 3, 2025)
e391eb2  output omegaconf, check for mandatory values, change default values (gqcpm, Apr 3, 2025)
4ab26ad  docstrings (gqcpm, Apr 4, 2025)
a198136  adjust skeletons parameter (gqcpm, Apr 4, 2025)
9293589  change docstring and headers (gqcpm, Apr 4, 2025)
a4ab4dd  rename file and function for bottomup_multiclass (gqcpm, Apr 4, 2025)
e91cc6d  move sleap configs for tests to folder. Update test path to access ne… (gqcpm, Apr 15, 2025)
ea32f9f  change load_sleap_config to be class method of TrainingJobConfig (gqcpm, Apr 15, 2025)
f447d37  remove missing parameter comments and put in separate doc (gqcpm, Apr 15, 2025)
64b53fc  Minor fixes to mappers (May 22, 2025)
e5d25a6  Merge branch 'main' into greg/map-old-sleap-config-files-to-new (gitttt-1234, May 22, 2025)
482c8cd  Fix tests (May 22, 2025)
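
Per commits ac9327c and ea32f9f, the converter is exposed as a class method of `TrainingJobConfig` that reads a legacy json training config. A minimal usage sketch; the module path and exact call signature are assumptions, since neither appears in the diff below:

```python
# Sketch only: import path and signature are assumed from the commit messages,
# not shown in this diff.
from sleap_nn.config.training_job_config import TrainingJobConfig

# Parse a legacy SLEAP training_config.json into the new OmegaConf-backed
# configuration object.
config = TrainingJobConfig.load_sleap_config("legacy_training_config.json")
```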
129 changes: 127 additions & 2 deletions sleap_nn/config/data_config.py
@@ -5,7 +5,7 @@
"""

from attrs import define, field, validators
-from omegaconf import MISSING
+from omegaconf import MISSING, MissingMandatoryValue
from typing import Optional, Tuple, Any, List
from loguru import logger

@@ -184,4 +184,129 @@ class DataConfig:
    preprocessing: PreprocessingConfig = field(factory=PreprocessingConfig)
    use_augmentations_train: bool = False
    augmentation_config: Optional[AugmentationConfig] = None
-    skeletons: Optional[dict] = None
+    skeletons: Optional[list] = None


def data_mapper(legacy_config: dict) -> DataConfig:
    """Map a legacy SLEAP json config dict to a DataConfig object."""
    return DataConfig(
        train_labels_path=legacy_config.get("data", {})
        .get("labels", {})
        .get("training_labels", MISSING),
        val_labels_path=legacy_config.get("data", {})
        .get("labels", {})
        .get("validation_labels", MISSING),
        test_file_path=legacy_config.get("data", {})
        .get("labels", {})
        .get("test_labels", None),
        # provider=legacy_config.get("provider", "LabelsReader"),
        # user_instances_only=legacy_config.get("user_instances_only", True),
        # data_pipeline_fw=legacy_config.get("data_pipeline_fw", "torch_dataset"),
        # np_chunks_path=legacy_config.get("np_chunks_path"),
        # litdata_chunks_path=legacy_config.get("litdata_chunks_path"),
        # use_existing_chunks=legacy_config.get("use_existing_chunks", False),
        # chunk_size=int(legacy_config.get("chunk_size", 100)),
        # delete_chunks_after_training=legacy_config.get("delete_chunks_after_training", True),
        preprocessing=PreprocessingConfig(
            is_rgb=legacy_config.get("data", {})
            .get("preprocessing", {})
            .get("ensure_rgb", False),
            max_height=legacy_config.get("data", {})
            .get("preprocessing", {})
            .get("target_height"),
            max_width=legacy_config.get("data", {})
            .get("preprocessing", {})
            .get("target_width"),
            scale=legacy_config.get("data", {})
            .get("preprocessing", {})
            .get("input_scaling", 1.0),
            crop_hw=legacy_config.get("data", {})
            .get("preprocessing", {})
            .get("crop_size"),
            min_crop_size=legacy_config.get("data", {})
            .get("preprocessing", {})
            .get("crop_size_detection_padding", 100),
        ),
        # use_augmentations_train=legacy_config.get("use_augmentations_train", False),
        augmentation_config=(
            AugmentationConfig(
                intensity=IntensityConfig(
                    uniform_noise_min=legacy_config.get("optimization", {})
                    .get("augmentation_config", {})
                    .get("uniform_noise_min_val", 0.0),
                    # Clamp the legacy value to at most 1.0.
                    uniform_noise_max=min(
                        legacy_config.get("optimization", {})
                        .get("augmentation_config", {})
                        .get("uniform_noise_max_val", 1.0),
                        1.0,
                    ),
                    uniform_noise_p=float(
                        legacy_config.get("optimization", {})
                        .get("augmentation_config", {})
                        .get("uniform_noise", 1.0)
                    ),
                    gaussian_noise_mean=legacy_config.get("optimization", {})
                    .get("augmentation_config", {})
                    .get("gaussian_noise_mean", 0.0),
                    gaussian_noise_std=legacy_config.get("optimization", {})
                    .get("augmentation_config", {})
                    .get("gaussian_noise_stddev", 1.0),
                    gaussian_noise_p=float(
                        legacy_config.get("optimization", {})
                        .get("augmentation_config", {})
                        .get("gaussian_noise", 1.0)
                    ),
                    contrast_min=legacy_config.get("optimization", {})
                    .get("augmentation_config", {})
                    .get("contrast_min_gamma", 0.5),
                    contrast_max=legacy_config.get("optimization", {})
                    .get("augmentation_config", {})
                    .get("contrast_max_gamma", 2.0),
                    contrast_p=float(
                        legacy_config.get("optimization", {})
                        .get("augmentation_config", {})
                        .get("contrast", 1.0)
                    ),
                    brightness=(
                        legacy_config.get("optimization", {})
                        .get("augmentation_config", {})
                        .get("brightness_min_val", 1.0),
                        legacy_config.get("optimization", {})
                        .get("augmentation_config", {})
                        .get("brightness_max_val", 1.0),
                    ),
                    brightness_p=float(
                        legacy_config.get("optimization", {})
                        .get("augmentation_config", {})
                        .get("brightness", 1.0)
                    ),
                ),
                geometric=GeometricConfig(
                    rotation=legacy_config.get("optimization", {})
                    .get("augmentation_config", {})
                    .get("rotation_max_angle", 180.0),
                    scale=(
                        legacy_config.get("optimization", {})
                        .get("augmentation_config", {})
                        .get("scale_min", None),
                        legacy_config.get("optimization", {})
                        .get("augmentation_config", {})
                        .get("scale_max", None),
                    ),
                    # translate_width=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("translate_width", 0.2),
                    # translate_height=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("translate_height", 0.2),
                    # affine_p=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("affine_p", 0.0),
                    # erase_scale_min=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_scale_min", 0.0001),
                    # erase_scale_max=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_scale_max", 0.01),
                    # erase_ratio_min=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_ratio_min", 1.0),
                    # erase_ratio_max=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_ratio_max", 1.0),
                    # erase_p=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_p", 0.0),
                    # mixup_lambda=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("mixup_lambda", [0.01, 0.05]),
                    # mixup_p=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("mixup_p", 0.0),
                ),
            )
            # if legacy_config.get("use_augmentations_train", False)
            # else None
        ),
        use_augmentations_train=True,
        skeletons=legacy_config.get("data", {}).get("labels", {}).get("skeletons"),
    )
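
As a quick illustration of the mapping above, a minimal sketch with a hypothetical legacy config fragment (the keys mirror the `.get()` chains in `data_mapper`; anything omitted falls back to the defaults hard-coded there):

```python
from sleap_nn.config.data_config import data_mapper

# Hypothetical legacy SLEAP config fragment for illustration only.
legacy = {
    "data": {
        "labels": {
            "training_labels": "train.slp",
            "validation_labels": "val.slp",
        },
        "preprocessing": {"input_scaling": 0.5},
    },
    "optimization": {
        "augmentation_config": {"rotation_max_angle": 15.0},
    },
}

data_cfg = data_mapper(legacy)
print(data_cfg.train_labels_path)                       # train.slp
print(data_cfg.preprocessing.scale)                     # 0.5
print(data_cfg.augmentation_config.geometric.rotation)  # 15.0
```

Note that in this 13-commit view `use_augmentations_train` is hard-coded to `True`; the commented-out conditional suggests it is meant to be driven by the legacy `use_augmentations_train` flag eventually.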
176 changes: 176 additions & 0 deletions sleap_nn/config/model_config.py
@@ -832,3 +832,179 @@ def validate_pre_trained_weights(self, value):
            message = "UNet does not support pre-trained weights."
            logger.error(message)
            raise ValueError(message)


def model_mapper(legacy_config: dict) -> ModelConfig:
    """Map a legacy SLEAP json config dict to a ModelConfig object."""
    return ModelConfig(
        # init_weights=legacy_config.get("init_weights", "default"),
        # pre_trained_weights not in old config
        # pretrained_backbone_weights=legacy_config.get("PretrainedEncoderConfig")??  # I think it's different
        # pretrained_head_weights not in old config
        backbone_config=BackboneConfig(
            unet=(
                UNetConfig(
                    # in_channels=legacy_config.get("backbone", {}).get("in_channels", 1),
                    # kernel_size=legacy_config.get("backbone", {}).get("kernel_size", 3),
                    filters=legacy_config.get("model", {})
                    .get("backbone", {})
                    .get("unet", {})
                    .get("filters", 32),
                    filters_rate=legacy_config.get("model", {})
                    .get("backbone", {})
                    .get("unet", {})
                    .get("filters_rate", 1.5),
                    max_stride=legacy_config.get("model", {})
                    .get("backbone", {})
                    .get("unet", {})
                    .get("max_stride", 16),
                    stem_stride=legacy_config.get("model", {})
                    .get("backbone", {})
                    .get("unet", {})
                    .get("stem_stride", 16),
                    middle_block=legacy_config.get("model", {})
                    .get("backbone", {})
                    .get("unet", {})
                    .get("middle_block", True),
                    up_interpolate=legacy_config.get("model", {})
                    .get("backbone", {})
                    .get("unet", {})
                    .get("up_interpolate", True),
                    stacks=legacy_config.get("model", {})
                    .get("backbone", {})
                    .get("unet", {})
                    .get("stacks", 1),
                    # convs_per_block=2,
                    output_stride=legacy_config.get("model", {})
                    .get("backbone", {})
                    .get("unet", {})
                    .get("output_stride", 1),
                )
                if legacy_config.get("model", {}).get("backbone", {}).get("unet")
                else None
            ),
            # convnext not in old config
            # swint not in old config
        ),
        head_configs=HeadConfig(
            single_instance=(
                (
                    SingleInstanceConfig(
                        confmaps=SingleInstanceConfMapsConfig(
                            part_names=legacy_config.get("model", {})
                            .get("heads", {})
                            .get("single_instance", {})
                            .get("part_names"),
                            sigma=legacy_config.get("model", {})
                            .get("heads", {})
                            .get("single_instance", {})
                            .get("sigma", 5.0),
                            output_stride=legacy_config.get("model", {})
                            .get("heads", {})
                            .get("single_instance", {})
                            .get("output_stride", 1),
                        )
                    )
                )
                if legacy_config.get("model", {})
                .get("heads", {})
                .get("single_instance")
                else None
            ),
            centroid=(
                CentroidConfig(
                    confmaps=CentroidConfMapsConfig(
                        anchor_part=legacy_config.get("model", {})
                        .get("heads", {})
                        .get("centroid", {})
                        .get("anchor_part"),
                        sigma=legacy_config.get("model", {})
                        .get("heads", {})
                        .get("centroid", {})
                        .get("sigma", 5.0),
                        output_stride=legacy_config.get("model", {})
                        .get("heads", {})
                        .get("centroid", {})
                        .get("output_stride", 1),
                    )
                )
                if legacy_config.get("model", {}).get("heads", {}).get("centroid")
                else None
            ),
            centered_instance=(
                CenteredInstanceConfig(
                    confmaps=CenteredInstanceConfMapsConfig(
                        anchor_part=legacy_config.get("model", {})
                        .get("heads", {})
                        .get("centered_instance", {})
                        .get("anchor_part"),
                        sigma=legacy_config.get("model", {})
                        .get("heads", {})
                        .get("centered_instance", {})
                        .get("sigma", 5.0),
                        output_stride=legacy_config.get("model", {})
                        .get("heads", {})
                        .get("centered_instance", {})
                        .get("output_stride", 1),
                        part_names=legacy_config.get("model", {})
                        .get("heads", {})
                        .get("centered_instance", {})
                        .get("part_names", None),
                    )
                )
                if legacy_config.get("model", {})
                .get("heads", {})
                .get("centered_instance")
                else None
            ),
            bottomup=(
                BottomUpConfig(
                    confmaps=BottomUpConfMapsConfig(
                        loss_weight=legacy_config.get("model", {})
                        .get("heads", {})
                        .get("multi_instance", {})
                        .get("confmaps", {})
                        .get("loss_weight", None),
                        sigma=legacy_config.get("model", {})
                        .get("heads", {})
                        .get("multi_instance", {})
                        .get("confmaps", {})
                        .get("sigma", 5.0),
                        output_stride=legacy_config.get("model", {})
                        .get("heads", {})
                        .get("multi_instance", {})
                        .get("confmaps", {})
                        .get("output_stride", 1),
                        part_names=legacy_config.get("model", {})
                        .get("heads", {})
                        .get("multi_instance", {})
                        .get("confmaps", {})
                        .get("part_names", None),
                    ),
                    pafs=PAFConfig(
                        edges=legacy_config.get("model", {})
                        .get("heads", {})
                        .get("multi_instance", {})
                        .get("pafs", {})
                        .get("edges", None),
                        sigma=legacy_config.get("model", {})
                        .get("heads", {})
                        .get("multi_instance", {})
                        .get("pafs", {})
                        .get("sigma", 15.0),
                        output_stride=legacy_config.get("model", {})
                        .get("heads", {})
                        .get("multi_instance", {})
                        .get("pafs", {})
                        .get("output_stride", 1),
                        loss_weight=legacy_config.get("model", {})
                        .get("heads", {})
                        .get("multi_instance", {})
                        .get("pafs", {})
                        .get("loss_weight", None),
                    ),
                )
                if legacy_config.get("model", {}).get("heads", {}).get("multi_instance")
                else None
            ),
        ),
    )
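
The same pattern applies to the model section. A minimal sketch with a hypothetical legacy UNet backbone and centroid head; head types absent from the legacy config map to `None`:

```python
from sleap_nn.config.model_config import model_mapper

# Hypothetical legacy "model" section; keys mirror the .get() chains above.
legacy = {
    "model": {
        "backbone": {"unet": {"filters": 16, "max_stride": 32}},
        "heads": {"centroid": {"anchor_part": "thorax", "sigma": 2.5}},
    }
}

model_cfg = model_mapper(legacy)
assert model_cfg.backbone_config.unet.filters == 16
assert model_cfg.head_configs.centroid.confmaps.anchor_part == "thorax"
assert model_cfg.head_configs.bottomup is None  # multi_instance not in legacy
```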