Skip to content

Commit df7ae20

Browse files
authored
Fix ID models (#345)
This PR fixes a few minor issues with TopDown and BottomUp ID models. - The ID models dataset classes were re-computing the tracks from the labels file. However, they should just grab it from the head config `classes` parameter. - Fix shape mismatch issue with BottomUp ID models
1 parent de28b41 commit df7ae20

18 files changed

+116
-61
lines changed

docs/config.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -689,9 +689,9 @@ The trainer configuration section controls the training process, including data
689689
### Data Loader Settings
690690
- `train_data_loader`:
691691
- `batch_size`: (int) Number of samples per batch or batch size for training data. **Default**: `1`
692-
- `shuffle`: (bool) True to have the data reshuffled at every epoch. **Default**: `False`
692+
- `shuffle`: (bool) True to have the data reshuffled at every epoch. **Default**: `True`
693693
- `num_workers`: (int) Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. **Default**: `0`
694-
- `val_data_loader`: (Similar to `train_data_loader`)
694+
- `val_data_loader`: (Similar to `train_data_loader`, but `shuffle` is set to `False` by default)
695695

696696
**Example Data Loader configurations:**
697697

docs/sample_configs/config_bottomup_unet_large_rf.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ data_config:
4848
mixup_lambda_min: 0.01
4949
mixup_lambda_max: 0.05
5050
mixup_p: 0.0
51-
skeletons: []
51+
skeletons:
5252
model_config:
5353
init_weights: default
5454
pretrained_backbone_weights: null

docs/sample_configs/config_bottomup_unet_medium_rf.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ data_config:
4848
mixup_lambda_min: 0.01
4949
mixup_lambda_max: 0.05
5050
mixup_p: 0.0
51-
skeletons: []
51+
skeletons:
5252
model_config:
5353
init_weights: default
5454
pretrained_backbone_weights: null

docs/sample_configs/config_centroid_unet.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ data_config:
4848
mixup_lambda_min: 0.01
4949
mixup_lambda_max: 0.05
5050
mixup_p: 0.0
51-
skeletons: []
51+
skeletons:
5252
model_config:
5353
init_weights: default
5454
pretrained_backbone_weights: null

docs/sample_configs/config_single_instance_unet_large_rf.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ data_config:
4848
mixup_lambda_min: 0.01
4949
mixup_lambda_max: 0.05
5050
mixup_p: 0.0
51-
skeletons: []
51+
skeletons:
5252
model_config:
5353
init_weights: default
5454
pretrained_backbone_weights: null

docs/sample_configs/config_single_instance_unet_medium_rf.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ data_config:
4848
mixup_lambda_min: 0.01
4949
mixup_lambda_max: 0.05
5050
mixup_p: 0.0
51-
skeletons: []
51+
skeletons:
5252
model_config:
5353
init_weights: default
5454
pretrained_backbone_weights: null

docs/sample_configs/config_topdown_centered_instance_unet_large_rf.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ data_config:
4848
mixup_lambda_min: 0.01
4949
mixup_lambda_max: 0.05
5050
mixup_p: 0.0
51-
skeletons: []
51+
skeletons:
5252
model_config:
5353
init_weights: default
5454
pretrained_backbone_weights: null

docs/sample_configs/config_topdown_centered_instance_unet_medium_rf.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ data_config:
4848
mixup_lambda_min: 0.01
4949
mixup_lambda_max: 0.05
5050
mixup_p: 0.0
51-
skeletons: []
51+
skeletons:
5252
model_config:
5353
init_weights: default
5454
pretrained_backbone_weights: null

docs/sample_configs/config_topdown_multi_class_centered_instance_unet.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ model_config:
5959
kernel_size: 3
6060
filters: 16
6161
filters_rate: 1.5
62-
max_stride: 8
62+
max_stride: 16
6363
stem_stride: null
6464
middle_block: true
6565
up_interpolate: true

sleap_nn/config/data_config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,10 @@ def data_mapper(legacy_config: dict) -> DataConfig:
466466
data_cfg_args["use_augmentations_train"] = (
467467
True if any(intensity_args.values()) or any(geometric_args.values()) else False
468468
)
469-
data_cfg_args["skeletons"] = skeletons_list
469+
data_cfg_args["skeletons"] = (
470+
skeletons_list
471+
if skeletons_list is not None and len(skeletons_list) > 0
472+
else None
473+
)
470474

471475
return DataConfig(**data_cfg_args)

0 commit comments

Comments
 (0)