Merged
Commits (23)
53345e4: outline for loading old sleap config (gqcpm, Mar 18, 2025)
369dcff: initial mappers function (gqcpm, Mar 20, 2025)
5649f45: check geometric config and comment out all parameters not addressed i… (gqcpm, Mar 21, 2025)
cdc2c99: model_config map backbone, heads, and lint other files (gqcpm, Mar 22, 2025)
d5c871b: map train_loader, val_loader, ckpt, max_epochs, optimizer, lr_schedul… (gqcpm, Mar 27, 2025)
e642b41: test trainer_mapper function (gqcpm, Mar 28, 2025)
69a65d4: model_config tests and legacy config path adjustments (gqcpm, Mar 28, 2025)
3e86b38: add tests for data config and adjust paths to map legacy config (gqcpm, Mar 28, 2025)
ac9327c: test training_job_config with json file (gqcpm, Apr 1, 2025)
d4edbaa: linted (gqcpm, Apr 1, 2025)
70a5052: uniform_noise_max set to max 1, MISSING error on no train/val path (gqcpm, Apr 3, 2025)
4b82c4c: add different head types and adjust with tests (gqcpm, Apr 3, 2025)
e391eb2: output omegaconf, check for mandatory values, change default values (gqcpm, Apr 3, 2025)
4ab26ad: docstrings (gqcpm, Apr 4, 2025)
a198136: adjust skeletons parameter (gqcpm, Apr 4, 2025)
9293589: change docstring and headers (gqcpm, Apr 4, 2025)
a4ab4dd: rename file and function for bottomup_multiclass (gqcpm, Apr 4, 2025)
e91cc6d: move sleap configs for tests to folder. Update test path to access ne… (gqcpm, Apr 15, 2025)
ea32f9f: change load_sleap_config to be class method of TrainingJobConfig (gqcpm, Apr 15, 2025)
f447d37: remove missing parameter comments and put in separate doc (gqcpm, Apr 15, 2025)
64b53fc: Minor fixes to mappers (May 22, 2025)
e5d25a6: Merge branch 'main' into greg/map-old-sleap-config-files-to-new (gitttt-1234, May 22, 2025)
482c8cd: Fix tests (May 22, 2025)
sleap_nn/config/data_config.py (56 additions, 0 deletions)
@@ -185,3 +185,59 @@ class DataConfig:
use_augmentations_train: bool = False
augmentation_config: Optional[AugmentationConfig] = None
skeletons: Optional[dict] = None

def data_mapper(legacy_config: dict) -> DataConfig:
    """Map a legacy SLEAP training config dict to a `DataConfig`.

    Fields with no legacy equivalent keep their new-schema defaults and are
    documented separately: train_labels_path, val_labels_path, test_file_path,
    provider, user_instances_only, data_pipeline_fw, np_chunks_path,
    litdata_chunks_path, use_existing_chunks, chunk_size,
    delete_chunks_after_training, and use_augmentations_train.
    """
    preprocessing = legacy_config.get("data", {}).get("preprocessing", {})
    aug = legacy_config.get("optimization", {}).get("augmentation_config", {})

    return DataConfig(
        preprocessing=PreprocessingConfig(
            is_rgb=preprocessing.get("ensure_rgb", False),
            max_height=preprocessing.get("target_height"),
            max_width=preprocessing.get("target_width"),
            scale=preprocessing.get("input_scaling", 1.0),
            crop_hw=preprocessing.get("crop_size"),
            min_crop_size=preprocessing.get("crop_size_detection_padding", 100),
        ),
        augmentation_config=(
            AugmentationConfig(
                intensity=IntensityConfig(
                    uniform_noise_min=aug.get("uniform_noise_min_val", 0.0),
                    uniform_noise_max=aug.get("uniform_noise_max_val", 1.0),
                    uniform_noise_p=aug.get("uniform_noise", 1.0),
                    gaussian_noise_mean=aug.get("gaussian_noise_mean", 0.0),
                    gaussian_noise_std=aug.get("gaussian_noise_stddev", 1.0),
                    gaussian_noise_p=aug.get("gaussian_noise", 1.0),
                    contrast_min=aug.get("contrast_min_gamma", 0.5),
                    contrast_max=aug.get("contrast_max_gamma", 2.0),
                    contrast_p=aug.get("contrast", 1.0),
                    brightness=(
                        aug.get("brightness_min_val", 1.0),
                        aug.get("brightness_max_val", 1.0),
                    ),
                    brightness_p=aug.get("brightness", 1.0),
                ),
                geometric=GeometricConfig(
                    rotation=aug.get("rotation_max_angle", 180.0),
                    scale=(aug.get("scale_min"), aug.get("scale_max")),
                    # translate_*, affine_p, erase_*, and mixup_* have no
                    # legacy equivalent and keep their new-schema defaults.
                ),
            )
            if legacy_config.get("use_augmentations_train", False)
            else None
        ),
        skeletons=legacy_config.get("skeletons"),
    )
sleap_nn/config/model_config.py (46 additions, 0 deletions)
Expand Up @@ -832,3 +832,49 @@ def validate_pre_trained_weights(self, value):
message = "UNet does not support pre-trained weights."
logger.error(message)
raise ValueError(message)

def model_mapper(legacy_config: dict) -> ModelConfig:
    """Map a legacy SLEAP model config dict to a `ModelConfig`.

    pre_trained_weights, pretrained_backbone_weights, and
    pretrained_head_weights have no legacy equivalent, and total_params is
    calculated during training.
    """
    backbone = legacy_config.get("backbone", {})
    heads = legacy_config.get("heads", {})

    return ModelConfig(
        init_weights=legacy_config.get("init_weights", "default"),
        backbone_config=BackboneConfig(
            # convnext and swint backbones are not in the old config.
            unet=(
                UNetConfig(
                    # in_channels and kernel_size are not in the old config.
                    filters=backbone.get("filters", 32),
                    filters_rate=backbone.get("filters_rate", 1.5),
                    max_stride=backbone.get("max_stride", 16),
                    stem_stride=None,  # stem_stride not in legacy config.
                    middle_block=backbone.get("middle_block", True),
                    up_interpolate=backbone.get("up_interpolate", True),
                    stacks=backbone.get("stacks", 1),
                    convs_per_block=2,
                    output_stride=backbone.get("output_stride", 1),
                )
                if legacy_config.get("backbone_type") == "unet"
                else None
            ),
        ),
        head_configs=HeadConfig(
            # Other head types (centered_instance, bottomup) are not in the
            # old config.
            single_instance=(
                SingleInstanceConfig(
                    confmaps=SingleInstanceConfMapsConfig(
                        part_names=heads.get("part_names"),
                        sigma=heads.get("sigma", 5.0),
                        output_stride=heads.get("output_stride", 1),
                    )
                )
                if legacy_config.get("head_type") == "single_instance"
                else None
            ),
            centroid=(
                CentroidConfig(
                    confmaps=CentroidConfMapsConfig(
                        anchor_part=legacy_config.get("CentroidsHeadConfig", {}).get("anchor_part"),
                        sigma=legacy_config.get("CentroidsHeadConfig", {}).get("sigma"),
                        output_stride=legacy_config.get("CentroidsHeadConfig", {}).get("output_stride"),
                    )
                )
                if legacy_config.get("head_type") == "centroid"
                else None
            ),
        ),
    )
sleap_nn/config/training_job_config.py (14 additions, 0 deletions)
@@ -28,9 +28,13 @@
from typing import Text, Optional
from omegaconf import OmegaConf
import sleap_nn
import json
from sleap_nn.config.data_config import DataConfig
from sleap_nn.config.data_config import data_mapper
from sleap_nn.config.model_config import ModelConfig
from sleap_nn.config.model_config import model_mapper
from sleap_nn.config.trainer_config import TrainerConfig
from sleap_nn.config.trainer_config import trainer_mapper
from sleap_nn.config.utils import get_output_strides_from_heads


@@ -146,3 +150,13 @@ def load_config(filename: Text, load_training_config: bool = True) -> OmegaConf:
The parsed `OmegaConf`.
"""
return TrainingJobConfig.load_yaml(filename)

@classmethod
def load_sleap_config(cls, json_file_path: str) -> "TrainingJobConfig":
    """Load a legacy SLEAP training config JSON and map it to a `TrainingJobConfig`."""
    with open(json_file_path, "r") as f:
        old_config = json.load(f)

    data_config = data_mapper(old_config)
    model_config = model_mapper(old_config)
    trainer_config = trainer_mapper(old_config)

    return cls(
        data_config=data_config,
        model_config=model_config,
        trainer_config=trainer_config,
    )
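Typical call, assuming this lives on TrainingJobConfig as the commit history indicates; the path is a placeholder for a legacy SLEAP training_config.json.

# Sketch: one-call conversion of an old SLEAP training config.
config = TrainingJobConfig.load_sleap_config("legacy/training_config.json")
print(config.data_config.skeletons)
print(config.model_config.init_weights)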