
Commit 64b53fc

gitttt-1234 authored and committed

Minor fixes to mappers

1 parent f447d37 commit 64b53fc

6 files changed: +153 -213 lines changed

sleap_nn/config/data_config.py

Lines changed: 62 additions & 19 deletions

```diff
@@ -170,8 +170,8 @@ class DataConfig:
         train dataset and saved to the `training_config.yaml`
     """
 
-    train_labels_path: str = MISSING
-    val_labels_path: str = MISSING
+    train_labels_path: Optional[str] = None
+    val_labels_path: Optional[str] = None  # TODO: revisit MISSING!
     test_file_path: Optional[str] = None
     provider: str = "LabelsReader"
     user_instances_only: bool = True
@@ -201,20 +201,35 @@ def data_mapper(legacy_config: dict) -> DataConfig:
 
     return DataConfig(
         train_labels_path=legacy_config_data.get("labels", {}).get(
-            "training_labels", MISSING
+            "training_labels", None
         ),
         val_labels_path=legacy_config_data.get("labels", {}).get(
-            "validation_labels", MISSING
+            "validation_labels", None
         ),
         test_file_path=legacy_config_data.get("labels", {}).get("test_labels", None),
         preprocessing=PreprocessingConfig(
             is_rgb=legacy_config_data.get("preprocessing", {}).get("ensure_rgb", False),
-            max_height=legacy_config_data.get("preprocessing", {}).get("target_height"),
-            max_width=legacy_config_data.get("preprocessing", {}).get("target_width"),
+            max_height=legacy_config_data.get("preprocessing", {}).get(
+                "target_height", None
+            ),
+            max_width=legacy_config_data.get("preprocessing", {}).get(
+                "target_width", None
+            ),
             scale=legacy_config_data.get("preprocessing", {}).get("input_scaling", 1.0),
-            crop_hw=legacy_config_data.get("preprocessing", {}).get("crop_size"),
-            min_crop_size=legacy_config_data.get("preprocessing", {}).get(
-                "crop_size_detection_padding", 100
+            crop_hw=(
+                (
+                    legacy_config_data.get("instance_cropping", {}).get(
+                        "crop_size", None
+                    ),
+                    legacy_config_data.get("instance_cropping", {}).get(
+                        "crop_size", None
+                    ),
+                )
+                if legacy_config_data.get("instance_cropping", {}).get(
+                    "crop_size", None
+                )
+                is not None
+                else None
             ),
         ),
         augmentation_config=(
@@ -271,20 +286,48 @@ def data_mapper(legacy_config: dict) -> DataConfig:
                     ),
                 ),
                 geometric=GeometricConfig(
-                    rotation=legacy_config_optimization.get(
-                        "augmentation_config", {}
-                    ).get("rotation_max_angle", 180.0),
-                    scale=(
-                        legacy_config_optimization.get("augmentation_config", {}).get(
-                            "scale_min", None
-                        ),
+                    rotation=(
                         legacy_config_optimization.get("augmentation_config", {}).get(
-                            "scale_max", None
-                        ),
+                            "rotation_max_angle", 15.0
+                        )
+                        if legacy_config_optimization.get(
+                            "augmentation_config", {}
+                        ).get("rotate", True)
+                        else 0
+                    ),
+                    scale=(
+                        (
+                            legacy_config_optimization.get(
+                                "augmentation_config", {}
+                            ).get("scale_min", 0.9),
+                            legacy_config_optimization.get(
+                                "augmentation_config", {}
+                            ).get("scale_max", 1.1),
+                        )
+                        if legacy_config_optimization.get(
+                            "augmentation_config", {}
+                        ).get("scale", False)
+                        else (1.0, 1.0)
+                    ),
+                    affine_p=(
+                        1.0
+                        if any(
+                            [
+                                legacy_config_optimization.get(
+                                    "augmentation_config", {}
+                                ).get("rotate", True),
+                                legacy_config_optimization.get(
+                                    "augmentation_config", {}
+                                ).get("scale", False),
+                            ]
+                        )
+                        else 0.0
                     ),
                 ),
             )
         ),
         use_augmentations_train=True,
-        skeletons=legacy_config_data.get("labels", {}).get("skeletons", [{}])[0],
+        skeletons=legacy_config_data.get("labels", {}).get("skeletons", [{}])[
+            0
+        ],  # TODO
     )
```
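The new `crop_hw` mapping converts the legacy scalar `instance_cropping.crop_size` into an `(height, width)` tuple, and falls back to `None` when the legacy key is absent. A condensed sketch of that behavior (the helper name is illustrative, not part of the diff):

```python
from typing import Optional, Tuple

def map_crop_hw(legacy_config_data: dict) -> Optional[Tuple[int, int]]:
    """Illustrative helper mirroring the crop_hw logic in data_mapper."""
    crop_size = legacy_config_data.get("instance_cropping", {}).get("crop_size", None)
    # A scalar legacy crop size becomes a square (height, width) tuple.
    return (crop_size, crop_size) if crop_size is not None else None

assert map_crop_hw({"instance_cropping": {"crop_size": 160}}) == (160, 160)
assert map_crop_hw({}) is None  # no legacy key -> crop_hw stays None
```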

sleap_nn/config/model_config.py

Lines changed: 12 additions & 8 deletions

```diff
@@ -874,7 +874,7 @@ def model_mapper(legacy_config: dict) -> ModelConfig:
                 .get("unet", {})
                 .get("output_stride", 1),
             )
-            if legacy_config_model.get("backbone", {}).get("unet")
+            if legacy_config_model.get("backbone", {}).get("unet", None) is not None
             else None
         ),
     ),
@@ -885,7 +885,7 @@ def model_mapper(legacy_config: dict) -> ModelConfig:
                 confmaps=SingleInstanceConfMapsConfig(
                     part_names=legacy_config_model.get("heads", {})
                     .get("single_instance", {})
-                    .get("part_names"),
+                    .get("part_names", None),
                     sigma=legacy_config_model.get("heads", {})
                     .get("single_instance", {})
                     .get("sigma", 5.0),
@@ -895,15 +895,16 @@ def model_mapper(legacy_config: dict) -> ModelConfig:
                     )
                 )
             )
-            if legacy_config_model.get("heads", {}).get("single_instance")
+            if legacy_config_model.get("heads", {}).get("single_instance", None)
+            is not None
             else None
         ),
         centroid=(
             CentroidConfig(
                 confmaps=CentroidConfMapsConfig(
                     anchor_part=legacy_config_model.get("heads", {})
                     .get("centroid", {})
-                    .get("anchor_part"),
+                    .get("anchor_part", None),
                     sigma=legacy_config_model.get("heads", {})
                     .get("centroid", {})
                     .get("sigma", 5.0),
@@ -912,15 +913,16 @@ def model_mapper(legacy_config: dict) -> ModelConfig:
                     .get("output_stride", 1),
                 )
             )
-            if legacy_config_model.get("heads", {}).get("centroid")
+            if legacy_config_model.get("heads", {}).get("centroid", None)
+            is not None
             else None
         ),
         centered_instance=(
             CenteredInstanceConfig(
                 confmaps=CenteredInstanceConfMapsConfig(
                     anchor_part=legacy_config_model.get("heads", {})
                     .get("centered_instance", {})
-                    .get("anchor_part"),
+                    .get("anchor_part", None),
                     sigma=legacy_config_model.get("heads", {})
                     .get("centered_instance", {})
                     .get("sigma", 5.0),
@@ -932,7 +934,8 @@ def model_mapper(legacy_config: dict) -> ModelConfig:
                     .get("part_names", None),
                )
            )
-            if legacy_config_model.get("heads", {}).get("centered_instance")
+            if legacy_config_model.get("heads", {}).get("centered_instance", None)
+            is not None
            else None
        ),
        bottomup=(
@@ -974,7 +977,8 @@ def model_mapper(legacy_config: dict) -> ModelConfig:
                    .get("loss_weight", None),
                ),
            )
-            if legacy_config_model.get("heads", {}).get("multi_instance")
+            if legacy_config_model.get("heads", {}).get("multi_instance", None)
+            is not None
            else None
        ),
    ),
```
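The switch from a truthiness test to an explicit `is not None` check matters when a head section is present but empty: `{}` is falsy, so the old guard would drop such a head, while the new one keeps it. A quick standalone illustration of the difference:

```python
heads = {"centroid": {}}  # section present, but with no overrides

# Old guard: an empty dict is falsy, so the centroid head would be skipped.
old_keeps_head = bool(heads.get("centroid"))

# New guard: only a truly missing key maps to None, so the head is kept.
new_keeps_head = heads.get("centroid", None) is not None

assert old_keeps_head is False
assert new_keeps_head is True
```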

sleap_nn/config/trainer_config.py

Lines changed: 30 additions & 18 deletions

```diff
@@ -205,7 +205,7 @@ class TrainerConfig:
         default="Adam",
         validator=lambda inst, attr, val: TrainerConfig.validate_optimizer_name(val),
     )
-    optimizer: OptimizerConfig = OptimizerConfig()
+    optimizer: OptimizerConfig = field(factory=OptimizerConfig)
     lr_scheduler: Optional[LRSchedulerConfig] = None
     early_stopping: Optional[EarlyStoppingConfig] = None
 
@@ -249,29 +249,39 @@ def trainer_mapper(legacy_config: dict) -> TrainerConfig:
         train_data_loader=DataLoaderConfig(
             batch_size=legacy_config_optimization.get("batch_size", 1),
             shuffle=legacy_config_optimization.get("online_shuffling", False),
+            num_workers=1,
         ),
         val_data_loader=DataLoaderConfig(
-            batch_size=legacy_config_optimization.get("batch_size", 1),
-            shuffle=legacy_config_optimization.get("online_shuffling", False),
+            batch_size=legacy_config_optimization.get("batch_size", 1), num_workers=1
         ),
         model_ckpt=ModelCkptConfig(
-            save_last=legacy_config_outputs.get("save_outputs", False),
+            save_last=legacy_config_outputs.get("checkpointing", {}).get(
+                "latest_model", False
+            ),
         ),
         max_epochs=legacy_config_optimization.get("epochs", 10),
-        save_ckpt=legacy_config_optimization.get("checkpointing", {}).get(
-            "latest_model", False
-        ),
+        save_ckpt=True,
+        save_ckpt_path=legacy_config_outputs.get("runs_folder", None),
         optimizer_name=re.sub(
             r"^[a-z]",
             lambda x: x.group().upper(),
             legacy_config_optimization.get("optimizer", "adam"),
         ),
         optimizer=OptimizerConfig(
-            lr=legacy_config_optimization.get("initial_learning_rate", 1e-3),
+            lr=legacy_config_optimization.get("initial_learning_rate", 1e-4),
         ),
         lr_scheduler=(
             LRSchedulerConfig(
                 reduce_lr_on_plateau=ReduceLROnPlateauConfig(
+                    threshold=legacy_config_optimization.get(
+                        "learning_rate_schedule", {}
+                    ).get("plateau_min_delta", 1e-4),
+                    cooldown=legacy_config_optimization.get(
+                        "learning_rate_schedule", {}
+                    ).get("plateau_cooldown", 3),
+                    factor=legacy_config_optimization.get(
+                        "learning_rate_schedule", {}
+                    ).get("reduction_factor", 0.1),
                     patience=legacy_config_optimization.get(
                         "learning_rate_schedule", {}
                     ).get("plateau_patience", 10),
@@ -280,22 +290,24 @@ def trainer_mapper(legacy_config: dict) -> TrainerConfig:
                     ).get("min_learning_rate", 0.0),
                 )
             )
-            if legacy_config_optimization.get("learning_rate_schedule")
+            if legacy_config_optimization.get("learning_rate_schedule", {}).get(
+                "reduce_on_plateau", False
+            )
             else None
         ),
         early_stopping=(
             EarlyStoppingConfig(
                 stop_training_on_plateau=legacy_config_optimization.get(
-                    "learning_rate_schedule", {}
-                ).get("reduce_on_plateau", False),
-                min_delta=legacy_config_optimization.get(
-                    "learning_rate_schedule", {}
-                ).get("plateau_min_delta", 0.0),
-                patience=legacy_config_optimization.get(
-                    "learning_rate_schedule", {}
-                ).get("plateau_patience", 1),
+                    "early_stopping", {}
+                ).get("stop_training_on_plateau", False),
+                min_delta=legacy_config_optimization.get("early_stopping", {}).get(
+                    "plateau_min_delta", 0.0
+                ),
+                patience=legacy_config_optimization.get("early_stopping", {}).get(
+                    "plateau_patience", 1
+                ),
             )
-            if legacy_config_optimization.get("learning_rate_schedule")
+            if legacy_config_optimization.get("early_stopping")
             else None
         ),
     )
```
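The `field(factory=OptimizerConfig)` change fixes a classic attrs pitfall: a plain instance default is built once at class-definition time and shared by every `TrainerConfig`, so mutating one config's optimizer would leak into all the others. A minimal self-contained sketch of the difference, using simplified stand-in classes rather than the real ones:

```python
from attrs import define, field

@define
class Optimizer:
    lr: float = 1e-3

@define
class Shared:
    opt: Optimizer = Optimizer()  # one instance, shared by all Shared() objects

@define
class PerInstance:
    opt: Optimizer = field(factory=Optimizer)  # fresh instance per object

a, b = Shared(), Shared()
a.opt.lr = 5.0
assert b.opt.lr == 5.0  # mutation leaks across instances

c, d = PerInstance(), PerInstance()
c.opt.lr = 5.0
assert d.opt.lr == 1e-3  # each config owns its optimizer
```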

tests/config/test_data_config.py

Lines changed: 5 additions & 5 deletions

```diff
@@ -138,8 +138,6 @@ def test_data_mapper():
             "target_height": 256,
             "target_width": 256,
             "input_scaling": 0.5,
-            "crop_size": (100, 100),
-            "crop_size_detection_padding": 50,
         },
     },
     "optimization": {
@@ -157,8 +155,10 @@ def test_data_mapper():
             "brightness_max_val": 1.2,
             "brightness": 0.6,
             "rotation_max_angle": 90.0,
+            "rotation": True,
             "scale_min": 0.8,
             "scale_max": 1.2,
+            "scale": False,
         },
     },
 }
@@ -170,8 +170,8 @@ def test_data_mapper():
     assert config.preprocessing.max_height == 256
     assert config.preprocessing.max_width == 256
     assert config.preprocessing.scale == 0.5
-    assert config.preprocessing.crop_hw == (100, 100)
-    assert config.preprocessing.min_crop_size == 50
+    assert config.preprocessing.crop_hw is None
+    assert config.preprocessing.min_crop_size == 100
 
     # Test augmentation config
     assert config.use_augmentations_train is True
@@ -194,7 +194,7 @@ def test_data_mapper():
     # Test geometric config
     geometric = config.augmentation_config.geometric
     assert geometric.rotation == 90.0
-    assert geometric.scale == (0.8, 1.2)
+    assert geometric.scale == (1.0, 1.0)
 
     # Test skeletons
     assert config.skeletons == {"edges": [[0, 1], [1, 2]]}
```
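The updated assertion looks surprising at first: `scale_min`/`scale_max` are set to 0.8/1.2, yet the test now expects `geometric.scale == (1.0, 1.0)`. That is the new gating in `data_mapper` at work: the min/max range is only used when the legacy `"scale"` flag is true, and this test sets `"scale": False`. A condensed sketch of the gate (hypothetical helper, mirroring the mapper logic):

```python
def map_scale(aug: dict) -> tuple:
    """Return the scale range only when legacy scale augmentation is enabled."""
    if aug.get("scale", False):
        return (aug.get("scale_min", 0.9), aug.get("scale_max", 1.1))
    return (1.0, 1.0)  # identity range when the flag is off

assert map_scale({"scale_min": 0.8, "scale_max": 1.2, "scale": False}) == (1.0, 1.0)
assert map_scale({"scale_min": 0.8, "scale_max": 1.2, "scale": True}) == (0.8, 1.2)
```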

tests/config/test_trainer_config.py

Lines changed: 10 additions & 5 deletions

```diff
@@ -252,6 +252,11 @@ def test_trainer_mapper():
         "optimizer": "Adam",
         "initial_learning_rate": 0.001,
         "checkpointing": {},
+        "early_stopping": {
+            "stop_training_on_plateau": True,
+            "plateau_min_delta": 1e-06,
+            "plateau_patience": 10,
+        },
     },
     "outputs": {
         "save_outputs": True,
@@ -263,11 +268,11 @@ def test_trainer_mapper():
     # Assertions to check if the output matches expected values
     assert config.train_data_loader.batch_size == 32
     assert config.train_data_loader.shuffle is True
-    assert config.train_data_loader.num_workers == 0  # Default value
+    assert config.train_data_loader.num_workers == 1
     assert config.max_epochs == 20
     assert config.optimizer_name == "Adam"
     assert config.optimizer.lr == 0.001
-    assert config.model_ckpt.save_last is True
+    assert config.model_ckpt.save_last is False
 
     # Test for default values (unspecified by legacy config)
     assert config.trainer_devices == "auto"
@@ -276,7 +281,7 @@ def test_trainer_mapper():
     assert config.steps_per_epoch is None
     assert config.seed is None
     assert config.use_wandb is False
-    assert config.save_ckpt is False
+    assert config.save_ckpt is True
     assert config.save_ckpt_path is None
     assert config.resume_ckpt_path is None
     assert config.wandb.entity is None
@@ -291,6 +296,6 @@ def test_trainer_mapper():
     assert config.lr_scheduler.reduce_lr_on_plateau.patience == 5
     assert config.lr_scheduler.reduce_lr_on_plateau.min_lr == 0.0001
     assert config.early_stopping is not None
-    assert config.early_stopping.patience == 5
-    assert config.early_stopping.min_delta == 0.01
+    assert config.early_stopping.patience == 10
+    assert config.early_stopping.min_delta == 1e-6
     assert config.early_stopping.stop_training_on_plateau is True
```
