
outline for loading old sleap config #162

Open · wants to merge 4 commits into base: main

Conversation

gqcpm (Contributor) commented Mar 18, 2025

Summary by CodeRabbit

  • New Features
    • Introduced support for JSON configuration files, enabling users to set up training configurations in JSON format alongside existing options for greater flexibility (a usage sketch follows below).
    • Added functions to map legacy configurations for data and model setups, streamlining the transition to the new configuration format.
    • Enhanced configuration handling with new mapping functions for data and model configurations.
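
As a usage sketch only (based on the load_sleap_config implementation quoted later in this review, which is a module-level function that takes the target class as cls; the file path here is illustrative and the attribute names follow the keyword arguments in its return statement):

from sleap_nn.config.training_job_config import TrainingJobConfig, load_sleap_config

# Map a legacy SLEAP JSON training config into the new structured config objects.
config = load_sleap_config(TrainingJobConfig, "path/to/training_config.json")
print(config.data_config, config.model_config, config.trainer_config)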

coderabbitai bot (Contributor) commented Mar 18, 2025

Walkthrough

This change introduces three new functions across different configuration modules: load_sleap_config in training_job_config.py, data_mapper in data_config.py, and model_mapper in model_config.py. The load_sleap_config function reads and processes a JSON configuration file, while the data_mapper and model_mapper functions convert legacy configuration dictionaries into structured configuration objects. These updates enhance the ability to handle JSON configurations alongside existing YAML formats.

Changes

  • sleap_nn/config/…/training_job_config.py: Added load_sleap_config function for loading and processing JSON configuration; introduced imports for json and the mapping functions.
  • sleap_nn/config/…/data_config.py: Added data_mapper function to convert a legacy configuration dictionary to a DataConfig instance.
  • sleap_nn/config/…/model_config.py: Added model_mapper function to convert a legacy configuration dictionary to a ModelConfig instance.

Poem

I hopped through lines of code with glee,
A JSON file beckoned, as clear as could be.
Mapping data, model, and trainer so neat,
My config now sings a rhythmic beat!
Hop on the branch, and enjoy this treat! 🐇✨


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4

🧹 Nitpick comments (3)
sleap_nn/config/training_job_config.py (3)

154-162: Missing docstring for new function

Unlike other functions in this file, load_sleap_config lacks a docstring describing its purpose, parameters, and return value.

Add a descriptive docstring:

def load_sleap_config(cls, json_file_path: str) -> TrainerConfig:
+    """Load a training job configuration from a legacy SLEAP JSON config file.
+
+    Arguments:
+        json_file_path: Path to a legacy SLEAP JSON configuration file.
+
+    Returns:
+        A TrainingJobConfig instance containing the mapped configuration data.
+    """
    with open(json_file_path, "r") as f:
        old_config = json.load(f)

158-161: Missing error handling for mapper functions

There's no error handling for potential exceptions in the mapper functions.

Add error handling for the mapper functions:

-    data_config = data_mapper(old_config)
-    model_config = model_mapper(old_config)
-    trainer_config = trainer_mapper(old_config)
+    try:
+        data_config = data_mapper(old_config)
+        model_config = model_mapper(old_config)
+        trainer_config = trainer_mapper(old_config)
+    except Exception as e:
+        raise ValueError(f"Failed to map configuration from {json_file_path}: {str(e)}")

162-162: Filename attribute not set in the created config

Unlike the YAML loading methods, this function doesn't set the filename attribute to the source file path.

Set the filename attribute in the returned config:

-    return cls(data_config=data_config, model_config=model_config, trainer_config=trainer_config)
+    return cls(
+        data_config=data_config, 
+        model_config=model_config, 
+        trainer_config=trainer_config,
+        filename=json_file_path
+    )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 34aa25d and 53345e4.

📒 Files selected for processing (1)
  • sleap_nn/config/training_job_config.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (4)
  • GitHub Check: Tests (macos-14, Python 3.9)
  • GitHub Check: Tests (windows-latest, Python 3.9)
  • GitHub Check: Tests (ubuntu-latest, Python 3.9)
  • GitHub Check: Lint

Comment on lines 155 to 157
    with open(json_file_path, "r") as f:
        old_config = json.load(f)


🛠️ Refactor suggestion

Missing error handling for file operations and JSON parsing

The function doesn't handle potential exceptions when opening the file or parsing the JSON content.

Add error handling for file operations and JSON parsing:

-    with open(json_file_path, "r") as f:
-        old_config = json.load(f)
+    try:
+        with open(json_file_path, "r") as f:
+            old_config = json.load(f)
+    except FileNotFoundError:
+        raise FileNotFoundError(f"Configuration file not found: {json_file_path}")
+    except json.JSONDecodeError:
+        raise ValueError(f"Invalid JSON in configuration file: {json_file_path}")

💡 Verification agent

🧩 Analysis chain

Missing validation for loaded configuration

The function doesn't apply the check_output_strides validation that's performed in the YAML loading methods.

Consider whether the check_output_strides validation should be applied to JSON loaded configurations as well. This validation ensures compatibility between backbone and head configurations.


🏁 Script executed:

#!/bin/bash
# Check if check_output_strides is applied in any other loaders
grep -rn "check_output_strides" --include="*.py" ./

Length of output: 379


Action Required: Apply Missing check_output_strides Validation in JSON Loading

It appears that the JSON configuration loading path in sleap_nn/config/training_job_config.py (culminating in the return statement at line 162) does not invoke the check_output_strides validation. In contrast, YAML loaders in this file call TrainingJobConfig.check_output_strides(config) (as seen on lines 99 and 117). For consistency and to ensure compatibility between backbone and head configurations, please consider applying the same validation for JSON-loaded configurations.

  • Evaluate adding a call to TrainingJobConfig.check_output_strides(config) right before the instance initialization in the JSON loader (a sketch follows below).
  • Confirm that this adjustment aligns with the intended configuration compatibility checks across all loaders.
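
A minimal sketch of that placement, assuming check_output_strides accepts the constructed config as in the YAML loaders (whether it mutates in place or returns a validated object should be confirmed against lines 99 and 117):

    config = cls(
        data_config=data_config,
        model_config=model_config,
        trainer_config=trainer_config,
    )
    # Assumed call pattern, mirroring the YAML loaders' validation of
    # backbone/head output stride compatibility.
    TrainingJobConfig.check_output_strides(config)
    return config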

Comment on lines 154 to 162
def load_sleap_config(cls, json_file_path: str) -> TrainerConfig:
    with open(json_file_path, "r") as f:
        old_config = json.load(f)

    data_config = data_mapper(old_config)
    model_config = model_mapper(old_config)
    trainer_config = trainer_mapper(old_config)

    return cls(data_config=data_config, model_config=model_config, trainer_config=trainer_config)

⚠️ Potential issue

Return type annotation doesn't match actual implementation

The function signature indicates it returns a TrainerConfig, but it actually returns an instance of cls which contains data_config, model_config, and trainer_config objects.

-def load_sleap_config(cls, json_file_path: str) -> TrainerConfig:
+def load_sleap_config(cls, json_file_path: str) -> "TrainingJobConfig":

@@ -146,3 +150,13 @@
The parsed `OmegaConf`.
"""
return TrainingJobConfig.load_yaml(filename)

def load_sleap_config(cls, json_file_path: str) -> TrainerConfig:

🛠️ Refactor suggestion

Function lacks classmethod decorator despite using cls parameter

This function appears to be intended as a class method since it takes cls as its first parameter, but it's missing the @classmethod decorator and is defined at module level rather than within the TrainingJobConfig class.

Consider moving this function into the TrainingJobConfig class and adding the @classmethod decorator:

-def load_sleap_config(cls, json_file_path: str) -> TrainerConfig:
+@classmethod
+def load_sleap_config(cls, json_file_path: str) -> "TrainingJobConfig":

Committable suggestion skipped: line range outside the PR's diff.

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 53345e4 and 369dcff.

📒 Files selected for processing (2)
  • sleap_nn/config/data_config.py (1 hunks)
  • sleap_nn/config/model_config.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (4)
  • GitHub Check: Tests (macos-14, Python 3.9)
  • GitHub Check: Tests (windows-latest, Python 3.9)
  • GitHub Check: Tests (ubuntu-latest, Python 3.9)
  • GitHub Check: Lint
🔇 Additional comments (1)
sleap_nn/config/data_config.py (1)

215-221:

✅ Verification successful

Verify the probability values for augmentation configuration

The probability values (uniform_noise_p, gaussian_noise_p, contrast_p) are set to default values of 1.0, but the base class IntensityConfig defines them with defaults of 0.0. Additionally, validation for these fields ensures they are between 0.0 and 1.0.

Verify whether the default values of 1.0 from the legacy configuration are correct. This seems inconsistent with the class definitions where these probabilities default to 0.0, suggesting they might not be applied by default.


🏁 Script executed:

#!/bin/bash
# Check if there's any precedent in the codebase for using 1.0 as default probability
rg -A 2 -B 2 "noise_p|contrast_p|brightness_p"

Length of output: 32418


Legacy augmentation probability defaults confirmed

After reviewing the legacy configuration across the codebase and corresponding tests, it appears that the default values of 1.0 for uniform_noise_p, gaussian_noise_p, and contrast_p in the legacy configuration are intentional. Although the base class (IntensityConfig) defaults these probabilities to 0.0 and enforces a [0.0, 1.0] range, the legacy branch explicitly overrides them to 1.0—as validated by tests in files like tests/test_train.py and tests/data/test_augmentation.py. To avoid future confusion, consider documenting this intentional behavior and noting the divergence from the new defaults.
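
One way to record that intent at the mapping site, as a sketch only (legacy_aug is a shorthand introduced here for the nested optimization/augmentation_config lookup, and the legacy key names are illustrative rather than taken from the actual SLEAP schema):

    legacy_aug = legacy_config.get("optimization", {}).get("augmentation_config", {})
    intensity = IntensityConfig(
        # Intentional divergence from the new defaults: legacy SLEAP configs apply these
        # augmentations whenever enabled, so the fallback probability is 1.0 rather than
        # the IntensityConfig default of 0.0.
        uniform_noise_p=legacy_aug.get("uniform_noise_p", 1.0),
        gaussian_noise_p=legacy_aug.get("gaussian_noise_p", 1.0),
        contrast_p=legacy_aug.get("contrast_p", 1.0),
    )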

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
sleap_nn/config/data_config.py (1)

189-243: ⚠️ Potential issue

Inconsistency in parameter handling

The function implementation has several issues that need to be addressed:

  1. Many parameters are commented out (lines 191-201) including critical path parameters like train_labels_path and val_labels_path which are marked as MISSING in the DataConfig class.
  2. Line 210 comments out use_augmentations_train, but it's used in the conditional at line 241.
  3. The structure of the legacy_config appears inconsistent with some values coming from data.preprocessing and others from optimization.augmentation_config.
def data_mapper(legacy_config: dict) -> DataConfig:
    return DataConfig(
-        # train_labels_path=legacy_config.get("train_labels_path", MISSING),
-        # val_labels_path=legacy_config.get("val_labels_path", MISSING),
+        train_labels_path=legacy_config.get("train_labels_path", MISSING),
+        val_labels_path=legacy_config.get("val_labels_path", MISSING),
        # test_file_path=legacy_config.get("test_file_path"),
        # provider=legacy_config.get("provider", "LabelsReader"),
        # user_instances_only=legacy_config.get("user_instances_only", True),
        # data_pipeline_fw=legacy_config.get("data_pipeline_fw", "torch_dataset"),
        # np_chunks_path=legacy_config.get("np_chunks_path"),
        # litdata_chunks_path=legacy_config.get("litdata_chunks_path"),
        # use_existing_chunks=legacy_config.get("use_existing_chunks", False),
        # chunk_size=int(legacy_config.get("chunk_size", 100)),
        # delete_chunks_after_training=legacy_config.get("delete_chunks_after_training", True),
        preprocessing=PreprocessingConfig(
            is_rgb=legacy_config.get("data", {}).get("preprocessing", {}).get("ensure_rgb", False),
            max_height=legacy_config.get("data", {}).get("preprocessing", {}).get("target_height"),
            max_width=legacy_config.get("data", {}).get("preprocessing", {}).get("target_width"),
            scale=legacy_config.get("data", {}).get("preprocessing", {}).get("input_scaling", 1.0),
            crop_hw=legacy_config.get("data", {}).get("preprocessing", {}).get("crop_size"),
            min_crop_size=legacy_config.get("data", {}).get("preprocessing", {}).get("crop_size_detection_padding", 100),
        ),
-        # use_augmentations_train=legacy_config.get("use_augmentations_train", False),
+        use_augmentations_train=legacy_config.get("use_augmentations_train", False),
        augmentation_config=AugmentationConfig(
            intensity=IntensityConfig(
                uniform_noise_min=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("uniform_noise_min_val", 0.0),
🧹 Nitpick comments (1)
sleap_nn/config/data_config.py (1)

189-190: Add docstring to data_mapper function

The data_mapper function lacks a docstring to explain its purpose, parameters, and return value. This is important for maintainability and understanding the code's intent.

+"""
+Convert a legacy SLEAP configuration dictionary to a DataConfig object.
+
+Args:
+    legacy_config: Dictionary containing configuration from older SLEAP version
+
+Returns:
+    DataConfig: A structured configuration object matching the legacy settings
+"""
def data_mapper(legacy_config: dict) -> DataConfig:
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 369dcff and 5649f45.

📒 Files selected for processing (1)
  • sleap_nn/config/data_config.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (4)
  • GitHub Check: Tests (macos-14, Python 3.9)
  • GitHub Check: Tests (windows-latest, Python 3.9)
  • GitHub Check: Tests (ubuntu-latest, Python 3.9)
  • GitHub Check: Lint
🔇 Additional comments (2)
sleap_nn/config/data_config.py (2)

222-224: Fixed brightness tuple creation

The brightness tuple is now correctly created using parentheses, fixing the syntax error from the previous review.
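
For reference, the corrected shape is roughly the following (the legacy key names and 1.0 fallbacks are illustrative; the point is that both bounds now sit inside one parenthesized tuple):

    brightness=(
        legacy_config.get("optimization", {}).get("augmentation_config", {}).get("brightness_min_val", 1.0),
        legacy_config.get("optimization", {}).get("augmentation_config", {}).get("brightness_max_val", 1.0),
    ),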


213-221:

❓ Verification inconclusive

Inconsistent default values for noise and contrast parameters

The default values for noise and contrast parameters don't match the class defaults. The class defaults are:

  • uniform_noise_p: 0.0 in class, 1.0 in mapper
  • gaussian_noise_p: 0.0 in class, 1.0 in mapper
  • contrast_p: 0.0 in class, 1.0 in mapper

Please verify if these different default values are intentional or if they should match the class defaults.


🏁 Script executed:

#!/bin/bash
# Check where these values are used elsewhere in the codebase
echo "Searching for uniform_noise, gaussian_noise, and contrast default usages:"
rg --type py "uniform_noise|gaussian_noise|contrast.*p" -A 2 -B 2

Length of output: 34468


Action Required: Verify Default Augmentation Parameter Values

The legacy configuration mapper in sleap_nn/config/data_config.py is currently defaulting the following parameters to 1.0:

  • uniform_noise_p
  • gaussian_noise_p
  • contrast_p

However, the corresponding defaults defined in the augmentation module (sleap_nn/data/augmentation.py) are 0.0 for these parameters. This discrepancy may lead to inconsistent augmentation behavior depending on whether values come from the legacy config or are used directly from the class defaults. Some tests (e.g., in tests/test_train.py and tests/data/test_augmentation.py) appear to expect a default of 1.0 when augmentations are explicitly enabled, while the augmentation module itself defines a “disabled” default of 0.0.

Please verify whether the legacy mapping defaults are intended to override the class defaults. If the intended behavior is to have augmentations off by default, update the mapper in sleap_nn/config/data_config.py to use 0.0 (matching the class defaults). Otherwise, if the legacy behavior (defaulting to 1.0) is deliberate, ensure that the discrepancies are well documented and that the tests reflect this intended behavior.

Comment on lines 226 to 240
            geometric=GeometricConfig(
                rotation=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("rotation_max_angle", 180.0),
                scale=(legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_min", None),
                       legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_max", None)),
                # translate_width=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("translate_width", 0.2),
                # translate_height=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("translate_height", 0.2),
                # affine_p=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("affine_p", 0.0),
                # erase_scale_min=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_scale_min", 0.0001),
                # erase_scale_max=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_scale_max", 0.01),
                # erase_ratio_min=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_ratio_min", 1.0),
                # erase_ratio_max=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_ratio_max", 1.0),
                # erase_p=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_p", 0.0),
                # mixup_lambda=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("mixup_lambda", [0.01, 0.05]),
                # mixup_p=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("mixup_p", 0.0),
            )

🛠️ Refactor suggestion

Potential issues with geometric configuration parameters

There are several issues with the geometric configuration section:

  1. The scale parameter uses None as potential values which might conflict with the validator in the GeometricConfig class
  2. Most parameters are commented out without a clear indication why
  3. The rotation parameter uses a different default (180.0) compared to the class default (15.0)
            geometric=GeometricConfig(
                rotation=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("rotation_max_angle", 180.0),
-                scale=(legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_min", None),
-                       legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_max", None)),
+                scale=(legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_min", 0.9),
+                       legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_max", 1.1)),
                # translate_width=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("translate_width", 0.2),
                # translate_height=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("translate_height", 0.2),

Consider adding a comment to explain why most of these parameters are commented out, or uncomment them if they should be included.


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

♻️ Duplicate comments (2)
sleap_nn/config/model_config.py (1)

906-917: ⚠️ Potential issue

Always creating centroid head violates one-of semantics.

Because centroid is set unconditionally, it will be active alongside any other head, contradicting the @oneof policy that only one head can be non-None. Wrap this in a condition to ensure that we only create the centroid head if head_type is set to "centroid".

-            centroid=CentroidConfig(
+            centroid=(
+                CentroidConfig(
+                    confmaps=CentroidConfMapsConfig(
+                        anchor_part=legacy_config.get("heads", {})
+                        .get("centroid", {})
+                        .get("anchor_part"),
+                        sigma=legacy_config.get("heads", {})
+                        .get("centroid", {})
+                        .get("sigma", 5.0),
+                        output_stride=legacy_config.get("heads", {})
+                        .get("centroid", {})
+                        .get("output_stride", 1),
+                    )
+                )
+                if legacy_config.get("head_type") == "centroid"
+                else None
+            ),
sleap_nn/config/data_config.py (1)

270-277: ⚠️ Potential issue

Scale set to (None, None) may break validation logic.

By default, GeometricConfig.scale expects numeric tuples (e.g., (0.9, 1.1)). Using (None, None) can trigger type errors or produce unintended outcomes. Provide valid numeric values if you want to enable scaling, or set this to None if scaling is not used.

-   scale=(
-       legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_min", None),
-       legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_max", None),
-   ),
+   scale=(
+       legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_min", 0.9),
+       legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_max", 1.1),
+   ),
🧹 Nitpick comments (2)
sleap_nn/config/model_config.py (1)

835-836: Remove or clarify leftover comment.

These lines appear to contain a partial or uncertain comment about pretrained_backbone_weights (e.g., `?? # i think its different`). This can be confusing to future maintainers. Consider removing or clarifying these comments to ensure the code is clean and unambiguous.

-        # pretrained_backbone_weights=legacy_config.get("PretrainedEncoderConfig")?? # i think its different
+        # pretrained_backbone_weights=legacy_config.get("PretrainedEncoderConfig")  # Clarify if needed or remove entirely
sleap_nn/config/data_config.py (1)

223-224: use_augmentations_train is commented out, causing mismatch.

The DataConfig instance sets use_augmentations_train to its default False, even if legacy_config indicates otherwise. Uncomment and map use_augmentations_train from legacy_config to keep the final config in sync.

-        # use_augmentations_train=legacy_config.get("use_augmentations_train", False),
+        use_augmentations_train=legacy_config.get("use_augmentations_train", False),
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5649f45 and cdc2c99.

📒 Files selected for processing (3)
  • sleap_nn/config/data_config.py (1 hunks)
  • sleap_nn/config/model_config.py (1 hunks)
  • sleap_nn/config/training_job_config.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • sleap_nn/config/training_job_config.py
⏰ Context from checks skipped due to timeout of 90000ms (4)
  • GitHub Check: Tests (macos-14, Python 3.9)
  • GitHub Check: Tests (windows-latest, Python 3.9)
  • GitHub Check: Tests (ubuntu-latest, Python 3.9)
  • GitHub Check: Lint
🔇 Additional comments (4)
sleap_nn/config/model_config.py (2)

837-842: Confirm usage of commented-out parameters.

In the returned ModelConfig, multiple parameters like init_weights and pre_trained_weights are commented out (lines 839-841). If these are intentionally omitted from legacy handling, consider removing them altogether. Otherwise, uncomment and map them from legacy_config to ensure backward compatibility.
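
If they are meant to be mapped, a hedged sketch of the uncommented form could look like this (the legacy key names and fallback values are guesses for illustration, not taken from the actual legacy schema):

        init_weights=legacy_config.get("model", {}).get("init_weights", "default"),
        pre_trained_weights=legacy_config.get("model", {}).get("pre_trained_weights", None),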


889-905: Conditional creation for single_instance head looks good.

The code correctly creates a SingleInstanceConfig only when legacy_config.get("head_type") == "single_instance". This aligns with the @oneof requirement that only one head should be active.
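
For context, the approved pattern has roughly this shape (paraphrased; the confmaps class and field names are assumed by analogy with the other heads rather than copied from the source):

            single_instance=(
                SingleInstanceConfig(
                    confmaps=SingleInstanceConfMapsConfig(
                        sigma=legacy_config.get("heads", {}).get("single_instance", {}).get("sigma", 5.0),
                        output_stride=legacy_config.get("heads", {}).get("single_instance", {}).get("output_stride", 1),
                        part_names=legacy_config.get("heads", {}).get("single_instance", {}).get("part_names", None),
                    )
                )
                if legacy_config.get("head_type") == "single_instance"
                else None
            ),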

sleap_nn/config/data_config.py (2)

190-202: Review whether commented-out fields should be restored or removed.

Several fields (e.g., train_labels_path, val_labels_path, etc.) are commented out, preventing them from being set in DataConfig. Verify if you need these fields for accurate data loading. If they are essential, uncomment and map them; otherwise, remove them to avoid confusion.


267-269: Rotation uses 180.0 by default, diverging from internal default of 15.0.

This large rotation range can significantly alter training behavior. If this is intentional, confirm it aligns with your training objectives. Otherwise, consider a smaller angle or the class default.
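
If the class default is preferred over legacy parity, the fix is a one-line change to the fallback (the rotation_max_angle key is taken from the quoted mapper code; 15.0 matches the GeometricConfig default mentioned above):

                rotation=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("rotation_max_angle", 15.0),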

Comment on lines +935 to +965
            bottomup=BottomUpConfig(
                confmaps=BottomUpConfMapsConfig(
                    loss_weight=legacy_config.get("heads", {})
                    .get("multi_instance", {})
                    .get("loss_weight", None),
                    sigma=legacy_config.get("heads", {})
                    .get("multi_instance", {})
                    .get("sigma", 5.0),
                    output_stride=legacy_config.get("heads", {})
                    .get("multi_instance", {})
                    .get("output_stride", 1),
                    part_names=legacy_config.get("heads", {})
                    .get("multi_instance", {})
                    .get("part_names", None),
                ),
                pafs=PAFConfig(
                    edges=legacy_config.get("heads", {})
                    .get("multi_instance", {})
                    .get("edges", None),
                    sigma=legacy_config.get("heads", {})
                    .get("multi_instance", {})
                    .get("sigma", 15.0),
                    output_stride=legacy_config.get("heads", {})
                    .get("multi_instance", {})
                    .get("output_stride", 1),
                    loss_weight=legacy_config.get("heads", {})
                    .get("multi_instance", {})
                    .get("loss_weight", None),
                ),
            ),
        ),

⚠️ Potential issue

bottomup head is always created, overriding the one-of restriction.

Like centroid and centered_instance, this head is being instantiated unconditionally. This breaks the @oneof constraint, which mandates that only a single head may be active at once. Wrap the creation of BottomUpConfig in a conditional check for "bottomup".

-            bottomup=BottomUpConfig(
+            bottomup=(
+                BottomUpConfig(
                     confmaps=BottomUpConfMapsConfig(
                         ...
                     ),
                     pafs=PAFConfig(
                         ...
                     ),
                 )
+                if legacy_config.get("head_type") == "bottomup"
+                else None
+            ),

Comment on lines +919 to +934
            centered_instance=CenteredInstanceConfig(
                confmaps=CentroidConfMapsConfig(
                    anchor_part=legacy_config.get("heads", {})
                    .get("centered_instance", {})
                    .get("anchor_part"),
                    sigma=legacy_config.get("heads", {})
                    .get("centered_instance", {})
                    .get("sigma", 5.0),
                    output_stride=legacy_config.get("heads", {})
                    .get("centered_instance", {})
                    .get("output_stride", 1),
                    part_names=legacy_config.get("heads", {})
                    .get("centered_instance", {})
                    .get("part_names", None),
                )
            ),

⚠️ Potential issue

CenteredInstanceConfig uses CentroidConfMapsConfig instead of CenteredInstanceConfMapsConfig.

This is likely an error, as the centered_instance head is mismatched with the CentroidConfMapsConfig. Use CenteredInstanceConfMapsConfig to align with the class definitions.

-            centered_instance=CenteredInstanceConfig(
-                confmaps=CentroidConfMapsConfig(
+            centered_instance=(
+                CenteredInstanceConfig(
+                    confmaps=CenteredInstanceConfMapsConfig(
                         anchor_part=legacy_config.get("heads", {})
                         .get("centered_instance", {})
                         .get("anchor_part"),
                         sigma=legacy_config.get("heads", {})
                         .get("centered_instance", {})
                         .get("sigma", 5.0),
                         output_stride=legacy_config.get("heads", {})
                         .get("centered_instance", {})
                         .get("output_stride", 1),
                         part_names=legacy_config.get("heads", {})
                         .get("centered_instance", {})
                         .get("part_names", None),
                     )
-                )
-            ),
+                )
+            )
