Skip to content

Commit e09b987

Browse files
gitttt-1234claude
andcommitted
Update default configuration values for improved training
This commit updates several default configuration values to provide better out-of-box training behavior and align with best practices: **DataLoaderConfig Changes:** - batch_size: 1 → 4 - More efficient training with larger batch sizes - Better gradient estimates and faster convergence **TrainerConfig Changes:** - max_epochs: 10 → 100 - Allows more training time for better convergence - seed: 0 → None - No default seeding, allowing natural randomization **DataConfig Changes:** - use_augmentations_train: False → True - Enables data augmentation by default for better generalization - Removed conditional logic in data_mapper that auto-set use_augmentations_train - Simplifies behavior to always default to True **ModelConfig Changes:** - ClassMapConfig sigma: 15.0 → 5.0 - More precise class map generation for multi-class models - Consistent with confmaps sigma defaults **Documentation Updates:** - Updated all docstrings to reflect new default values - Ensures documentation accuracy across all config classes These changes provide better default training configurations while maintaining full backward compatibility through explicit config overrides. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent b3432ef commit e09b987

File tree

3 files changed

+10
-13
lines changed

3 files changed

+10
-13
lines changed

sleap_nn/config/data_config.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ class DataConfig:
165165
use_existing_imgs: (bool) Use existing train and val images/ chunks in the `cache_img_path` for `torch_dataset_cache_img_disk` frameworks. If `True`, the `cache_img_path` should have `train_imgs` and `val_imgs` dirs. *Default*: `False`.
166166
delete_cache_imgs_after_training: (bool) If `False`, the images (torch_dataset_cache_img_disk) are retained after training. Else, the files are deleted. *Default*: `True`.
167167
preprocessing: Configuration options related to data preprocessing.
168-
use_augmentations_train: (bool) True if the data augmentation should be applied to the training data, else False. *Default*: `False`.
168+
use_augmentations_train: (bool) True if the data augmentation should be applied to the training data, else False. *Default*: `True`.
169169
augmentation_config: Configurations related to augmentation. (only if `use_augmentations_train` is `True`)
170170
skeletons: skeleton configuration for the `.slp` file. This will be pulled from the train dataset and saved to the `training_config.yaml`
171171
"""
@@ -181,7 +181,7 @@ class DataConfig:
181181
use_existing_imgs: bool = False
182182
delete_cache_imgs_after_training: bool = True
183183
preprocessing: PreprocessingConfig = field(factory=PreprocessingConfig)
184-
use_augmentations_train: bool = False
184+
use_augmentations_train: bool = True
185185
augmentation_config: Optional[AugmentationConfig] = None
186186
skeletons: Optional[list] = None
187187

@@ -463,9 +463,6 @@ def data_mapper(legacy_config: dict) -> DataConfig:
463463
geometric=GeometricConfig(**geometric_args),
464464
)
465465

466-
data_cfg_args["use_augmentations_train"] = (
467-
True if any(intensity_args.values()) or any(geometric_args.values()) else False
468-
)
469466
data_cfg_args["skeletons"] = (
470467
skeletons_list
471468
if skeletons_list is not None and len(skeletons_list) > 0

sleap_nn/config/model_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,7 @@ class ClassMapConfig:
837837
"""
838838

839839
classes: Optional[List[str]] = None
840-
sigma: float = 15.0
840+
sigma: float = 5.0
841841
output_stride: int = 1
842842
loss_weight: Optional[float] = None
843843

sleap_nn/config/trainer_config.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ class DataLoaderConfig:
1717
"""Train DataLoaderConfig.
1818
1919
Attributes:
20-
batch_size: (int) Number of samples per batch or batch size for training/validation data. *Default*: `1`.
20+
batch_size: (int) Number of samples per batch or batch size for training/validation data. *Default*: `4`.
2121
shuffle: (bool) True to have the data reshuffled at every epoch. *Default*: `False`.
2222
num_workers: (int) Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. *Default*: `0`.
2323
"""
2424

25-
batch_size: int = 1
25+
batch_size: int = 4
2626
shuffle: bool = False
2727
num_workers: int = 0
2828

@@ -32,7 +32,7 @@ class TrainDataLoaderConfig(DataLoaderConfig):
3232
"""Train DataLoaderConfig.
3333
3434
Attributes:
35-
batch_size: (int) Number of samples per batch or batch size for training/validation data. *Default*: `1`.
35+
batch_size: (int) Number of samples per batch or batch size for training/validation data. *Default*: `4`.
3636
shuffle: (bool) True to have the data reshuffled at every epoch. *Default*: `True`.
3737
num_workers: (int) Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. *Default*: `0`.
3838
"""
@@ -45,7 +45,7 @@ class ValDataLoaderConfig(DataLoaderConfig):
4545
"""Validation DataLoaderConfig.
4646
4747
Attributes:
48-
batch_size: (int) Number of samples per batch or batch size for training/validation data. *Default*: `1`.
48+
batch_size: (int) Number of samples per batch or batch size for training/validation data. *Default*: `4`.
4949
shuffle: (bool) True to have the data reshuffled at every epoch. *Default*: `False`.
5050
num_workers: (int) Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. *Default*: `0`.
5151
"""
@@ -243,8 +243,8 @@ class TrainerConfig:
243243
train_steps_per_epoch: (int) Number of minibatches (steps) to train for in an epoch. If set to `None`, this is set to the number of batches in the training data or `min_train_steps_per_epoch`, whichever is largest. *Default*: `None`. **Note**: In a multi-gpu training setup, the effective steps during training would be the `trainer_steps_per_epoch` / `trainer_devices`.
244244
visualize_preds_during_training: (bool) If set to `True`, sample predictions (keypoints + confidence maps) are saved to `viz` folder in the ckpt dir and in wandb table. *Default*: `False`.
245245
keep_viz: (bool) If set to `True`, the `viz` folder will be kept after training. If `False`, the `viz` folder will be deleted after training. Only applies when `visualize_preds_during_training` is `True`. *Default*: `False`.
246-
max_epochs: (int) Maximum number of epochs to run. *Default*: `10`.
247-
seed: (int) Seed value for the current experiment. If None, no seeding is applied. *Default*: `0`.
246+
max_epochs: (int) Maximum number of epochs to run. *Default*: `100`.
247+
seed: (int) Seed value for the current experiment. If None, no seeding is applied. *Default*: `None`.
248248
use_wandb: (bool) True to enable wandb logging. *Default*: `False`.
249249
save_ckpt: (bool) True to enable checkpointing. *Default*: `False`.
250250
ckpt_dir: (str) Directory path where the `<run_name>` folder is created. If `None`, a new folder for the current run is created in the working dir. **Default**: `None`
@@ -274,7 +274,7 @@ class TrainerConfig:
274274
train_steps_per_epoch: Optional[int] = None
275275
visualize_preds_during_training: bool = False
276276
keep_viz: bool = False
277-
max_epochs: int = 10
277+
max_epochs: int = 100
278278
seed: Optional[int] = None
279279
use_wandb: bool = False
280280
save_ckpt: bool = False

0 commit comments

Comments
 (0)