outline for loading old sleap config #162
base: main
Conversation
Walkthrough: This change introduces three new functions across different configuration modules to support loading legacy SLEAP configurations into the new config classes.
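For orientation, a minimal usage sketch of the loader under review; the import path and the module-level, `cls`-taking signature are taken from the diffs below and may change as the PR evolves:

```python
# Hypothetical usage of the legacy-config loader reviewed in this PR.
# Assumes load_sleap_config stays module-level in training_job_config and
# takes the config class as its first argument, as shown in the diffs below.
from sleap_nn.config.training_job_config import TrainingJobConfig, load_sleap_config

# Map an old SLEAP training_config.json into the new structured config.
config = load_sleap_config(TrainingJobConfig, "path/to/legacy_training_config.json")

# The returned object bundles the three mapped sub-configs.
print(config.data_config, config.model_config, config.trainer_config)
```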
Actionable comments posted: 4
🧹 Nitpick comments (3)
sleap_nn/config/training_job_config.py (3)
154-162: Missing docstring for new function

Unlike other functions in this file, `load_sleap_config` lacks a docstring describing its purpose, parameters, and return value. Add a descriptive docstring:

```diff
 def load_sleap_config(cls, json_file_path: str) -> TrainerConfig:
+    """Load a training job configuration from a legacy SLEAP JSON config file.
+
+    Arguments:
+        json_file_path: Path to a legacy SLEAP JSON configuration file.
+
+    Returns:
+        A TrainingJobConfig instance containing the mapped configuration data.
+    """
     with open(json_file_path, "r") as f:
         old_config = json.load(f)
```
158-161: Missing error handling for mapper functions

There's no error handling for potential exceptions in the mapper functions. Add error handling for the mapper functions:

```diff
-    data_config = data_mapper(old_config)
-    model_config = model_mapper(old_config)
-    trainer_config = trainer_mapper(old_config)
+    try:
+        data_config = data_mapper(old_config)
+        model_config = model_mapper(old_config)
+        trainer_config = trainer_mapper(old_config)
+    except Exception as e:
+        raise ValueError(f"Failed to map configuration from {json_file_path}: {str(e)}")
```
162-162: Filename attribute not set in the created config

Unlike the YAML loading methods, this function doesn't set the `filename` attribute to the source file path. Set the filename attribute in the returned config:

```diff
-    return cls(data_config=data_config, model_config=model_config, trainer_config=trainer_config)
+    return cls(
+        data_config=data_config,
+        model_config=model_config,
+        trainer_config=trainer_config,
+        filename=json_file_path
+    )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- sleap_nn/config/training_job_config.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: Tests (macos-14, Python 3.9)
- GitHub Check: Tests (windows-latest, Python 3.9)
- GitHub Check: Tests (ubuntu-latest, Python 3.9)
- GitHub Check: Lint
```python
    with open(json_file_path, "r") as f:
        old_config = json.load(f)
```
🛠️ Refactor suggestion
Missing error handling for file operations and JSON parsing
The function doesn't handle potential exceptions when opening the file or parsing the JSON content.
Add error handling for file operations and JSON parsing:
```diff
-    with open(json_file_path, "r") as f:
-        old_config = json.load(f)
+    try:
+        with open(json_file_path, "r") as f:
+            old_config = json.load(f)
+    except FileNotFoundError:
+        raise FileNotFoundError(f"Configuration file not found: {json_file_path}")
+    except json.JSONDecodeError:
+        raise ValueError(f"Invalid JSON in configuration file: {json_file_path}")
```
```python
    model_config = model_mapper(old_config)
    trainer_config = trainer_mapper(old_config)

    return cls(data_config=data_config, model_config=model_config, trainer_config=trainer_config)
```
💡 Verification agent
🧩 Analysis chain
Missing validation for loaded configuration

The function doesn't apply the `check_output_strides` validation that's performed in the YAML loading methods. Consider whether the `check_output_strides` validation should be applied to JSON-loaded configurations as well. This validation ensures compatibility between backbone and head configurations.
🏁 Script executed:

```bash
#!/bin/bash
# Check if check_output_strides is applied in any other loaders
grep -rn "check_output_strides" --include="*.py" ./
```

Length of output: 379
Action Required: Apply Missing `check_output_strides` Validation in JSON Loading

It appears that the JSON configuration loading path in sleap_nn/config/training_job_config.py (culminating in the return statement at line 162) does not invoke the `check_output_strides` validation. In contrast, YAML loaders in this file call `TrainingJobConfig.check_output_strides(config)` (as seen on lines 99 and 117). For consistency and to ensure compatibility between backbone and head configurations, please consider applying the same validation for JSON-loaded configurations.

- Evaluate adding a call to `TrainingJobConfig.check_output_strides(config)` right before the instance initialization in the JSON loader (see the sketch below).
- Confirm that this adjustment aligns with the intended configuration compatibility checks across all loaders.
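A sketch of one way to wire the validation into the JSON loader, mirroring the YAML path; whether `check_output_strides` mutates the config in place or returns a validated copy is an assumption to confirm against the existing YAML loaders:

```python
def load_sleap_config(cls, json_file_path: str) -> "TrainingJobConfig":
    """Sketch: legacy JSON loader with the same validation as the YAML loaders."""
    with open(json_file_path, "r") as f:
        old_config = json.load(f)

    data_config = data_mapper(old_config)
    model_config = model_mapper(old_config)
    trainer_config = trainer_mapper(old_config)

    config = cls(
        data_config=data_config,
        model_config=model_config,
        trainer_config=trainer_config,
    )
    # Mirror the YAML loaders (lines 99 and 117): validate backbone/head
    # output stride compatibility before handing the config back.
    cls.check_output_strides(config)
    return config
```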
```python
def load_sleap_config(cls, json_file_path: str) -> TrainerConfig:
    with open(json_file_path, "r") as f:
        old_config = json.load(f)

    data_config = data_mapper(old_config)
    model_config = model_mapper(old_config)
    trainer_config = trainer_mapper(old_config)

    return cls(data_config=data_config, model_config=model_config, trainer_config=trainer_config)
```
Return type annotation doesn't match actual implementation

The function signature indicates it returns a `TrainerConfig`, but it actually returns an instance of `cls`, which contains data_config, model_config, and trainer_config objects.

```diff
-def load_sleap_config(cls, json_file_path: str) -> TrainerConfig:
+def load_sleap_config(cls, json_file_path: str) -> "TrainingJobConfig":
```
```diff
@@ -146,3 +150,13 @@
         The parsed `OmegaConf`.
     """
     return TrainingJobConfig.load_yaml(filename)
+
+
+def load_sleap_config(cls, json_file_path: str) -> TrainerConfig:
```
🛠️ Refactor suggestion

Function lacks classmethod decorator despite using `cls` parameter

This function appears to be intended as a class method since it takes `cls` as its first parameter, but it's missing the `@classmethod` decorator and is defined at module level rather than within the `TrainingJobConfig` class. Consider moving this function into the `TrainingJobConfig` class and adding the `@classmethod` decorator:

```diff
-def load_sleap_config(cls, json_file_path: str) -> TrainerConfig:
+@classmethod
+def load_sleap_config(cls, json_file_path: str) -> "TrainingJobConfig":
```

Committable suggestion skipped: line range outside the PR's diff.
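A sketch of what the relocated classmethod could look like inside `TrainingJobConfig`; the mapper import paths are assumptions based on the files reviewed in this PR (trainer_mapper's module in particular is not confirmed):

```python
import json

# Assumed import locations: data_mapper and model_mapper are reviewed in
# data_config.py and model_config.py; trainer_mapper's module is a guess.
from sleap_nn.config.data_config import data_mapper
from sleap_nn.config.model_config import model_mapper
from sleap_nn.config.trainer_config import trainer_mapper


class TrainingJobConfig:
    # ... existing fields and YAML loaders ...

    @classmethod
    def load_sleap_config(cls, json_file_path: str) -> "TrainingJobConfig":
        """Load a legacy SLEAP JSON config and map it onto the new schema."""
        with open(json_file_path, "r") as f:
            old_config = json.load(f)

        return cls(
            data_config=data_mapper(old_config),
            model_config=model_mapper(old_config),
            trainer_config=trainer_mapper(old_config),
        )
```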
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- sleap_nn/config/data_config.py (1 hunks)
- sleap_nn/config/model_config.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: Tests (macos-14, Python 3.9)
- GitHub Check: Tests (windows-latest, Python 3.9)
- GitHub Check: Tests (ubuntu-latest, Python 3.9)
- GitHub Check: Lint
🔇 Additional comments (1)
sleap_nn/config/data_config.py (1)
215-221: ✅ Verification successful

Verify the probability values for augmentation configuration

The probability values (`uniform_noise_p`, `gaussian_noise_p`, `contrast_p`) are set to default values of 1.0, but the base class `IntensityConfig` defines them with defaults of 0.0. Additionally, validation for these fields ensures they are between 0.0 and 1.0. Verify whether the default values of 1.0 from the legacy configuration are correct. This seems inconsistent with the class definitions, where these probabilities default to 0.0, suggesting they might not be applied by default.

🏁 Script executed:

```bash
#!/bin/bash
# Check if there's any precedent in the codebase for using 1.0 as default probability
rg -A 2 -B 2 "noise_p|contrast_p|brightness_p"
```

Length of output: 32418
Legacy augmentation probability defaults confirmed

After reviewing the legacy configuration across the codebase and corresponding tests, it appears that the default values of 1.0 for `uniform_noise_p`, `gaussian_noise_p`, and `contrast_p` in the legacy configuration are intentional. Although the base class (`IntensityConfig`) defaults these probabilities to 0.0 and enforces a [0.0, 1.0] range, the legacy branch explicitly overrides them to 1.0, as validated by tests in files like tests/test_train.py and tests/data/test_augmentation.py. To avoid future confusion, consider documenting this intentional behavior and noting the divergence from the new defaults.
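If the divergence is kept, one way to document it is directly in the mapper; a sketch only, and the legacy key names used for the probability lookups here are placeholders, not verified against the actual legacy schema:

```python
# Sketch: make the intentional divergence from IntensityConfig defaults explicit.
aug = legacy_config.get("optimization", {}).get("augmentation_config", {})
intensity = IntensityConfig(
    # NOTE: legacy configs default these probabilities to 1.0 (enabled),
    # intentionally diverging from the IntensityConfig default of 0.0.
    # The key names below are placeholders for the real legacy keys.
    uniform_noise_p=aug.get("uniform_noise_p", 1.0),
    gaussian_noise_p=aug.get("gaussian_noise_p", 1.0),
    contrast_p=aug.get("contrast_p", 1.0),
)
```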
Actionable comments posted: 1
♻️ Duplicate comments (1)
sleap_nn/config/data_config.py (1)
189-243: ⚠️ Potential issue

Inconsistency in parameter handling

The function implementation has several issues that need to be addressed:

- Many parameters are commented out (lines 191-201), including critical path parameters like `train_labels_path` and `val_labels_path`, which are marked as `MISSING` in the DataConfig class.
- Line 210 comments out `use_augmentations_train`, but it's used in the conditional at line 241.
- The structure of the legacy_config appears inconsistent, with some values coming from `data.preprocessing` and others from `optimization.augmentation_config`.

```diff
 def data_mapper(legacy_config: dict) -> DataConfig:
     return DataConfig(
-        # train_labels_path=legacy_config.get("train_labels_path", MISSING),
-        # val_labels_path=legacy_config.get("val_labels_path", MISSING),
+        train_labels_path=legacy_config.get("train_labels_path", MISSING),
+        val_labels_path=legacy_config.get("val_labels_path", MISSING),
         # test_file_path=legacy_config.get("test_file_path"),
         # provider=legacy_config.get("provider", "LabelsReader"),
         # user_instances_only=legacy_config.get("user_instances_only", True),
         # data_pipeline_fw=legacy_config.get("data_pipeline_fw", "torch_dataset"),
         # np_chunks_path=legacy_config.get("np_chunks_path"),
         # litdata_chunks_path=legacy_config.get("litdata_chunks_path"),
         # use_existing_chunks=legacy_config.get("use_existing_chunks", False),
         # chunk_size=int(legacy_config.get("chunk_size", 100)),
         # delete_chunks_after_training=legacy_config.get("delete_chunks_after_training", True),
         preprocessing=PreprocessingConfig(
             is_rgb=legacy_config.get("data", {}).get("preprocessing", {}).get("ensure_rgb", False),
             max_height=legacy_config.get("data", {}).get("preprocessing", {}).get("target_height"),
             max_width=legacy_config.get("data", {}).get("preprocessing", {}).get("target_width"),
             scale=legacy_config.get("data", {}).get("preprocessing", {}).get("input_scaling", 1.0),
             crop_hw=legacy_config.get("data", {}).get("preprocessing", {}).get("crop_size"),
             min_crop_size=legacy_config.get("data", {}).get("preprocessing", {}).get("crop_size_detection_padding", 100),
         ),
-        # use_augmentations_train=legacy_config.get("use_augmentations_train", False),
+        use_augmentations_train=legacy_config.get("use_augmentations_train", False),
         augmentation_config=AugmentationConfig(
             intensity=IntensityConfig(
                 uniform_noise_min=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("uniform_noise_min_val", 0.0),
```
🧹 Nitpick comments (1)
sleap_nn/config/data_config.py (1)
189-190: Add docstring to data_mapper function

The `data_mapper` function lacks a docstring to explain its purpose, parameters, and return value. This is important for maintainability and understanding the code's intent.

```diff
 def data_mapper(legacy_config: dict) -> DataConfig:
+    """Convert a legacy SLEAP configuration dictionary to a DataConfig object.
+
+    Args:
+        legacy_config: Dictionary containing configuration from older SLEAP version.
+
+    Returns:
+        DataConfig: A structured configuration object matching the legacy settings.
+    """
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- sleap_nn/config/data_config.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: Tests (macos-14, Python 3.9)
- GitHub Check: Tests (windows-latest, Python 3.9)
- GitHub Check: Tests (ubuntu-latest, Python 3.9)
- GitHub Check: Lint
🔇 Additional comments (2)
sleap_nn/config/data_config.py (2)
222-224: Fixed brightness tuple creation

The brightness tuple is now correctly created using parentheses, fixing the syntax error from the previous review.
213-221: ❓ Verification inconclusive

Inconsistent default values for noise and contrast parameters

The default values for noise and contrast parameters don't match the class defaults. The class defaults are:

- `uniform_noise_p`: 0.0 in class, 1.0 in mapper
- `gaussian_noise_p`: 0.0 in class, 1.0 in mapper
- `contrast_p`: 0.0 in class, 1.0 in mapper

Please verify if these different default values are intentional or if they should match the class defaults.

🏁 Script executed:

```bash
#!/bin/bash
# Check where these values are used elsewhere in the codebase
echo "Searching for uniform_noise, gaussian_noise, and contrast default usages:"
rg --type py "uniform_noise|gaussian_noise|contrast.*p" -A 2 -B 2
```

Length of output: 34468
Action Required: Verify Default Augmentation Parameter Values

The legacy configuration mapper in sleap_nn/config/data_config.py is currently defaulting the following parameters to 1.0:

- `uniform_noise_p`
- `gaussian_noise_p`
- `contrast_p`

However, the corresponding defaults defined in the augmentation module (sleap_nn/data/augmentation.py) are 0.0 for these parameters. This discrepancy may lead to inconsistent augmentation behavior depending on whether values come from the legacy config or are used directly from the class defaults. Some tests (e.g., in tests/test_train.py and tests/data/test_augmentation.py) appear to expect a default of 1.0 when augmentations are explicitly enabled, while the augmentation module itself defines a "disabled" default of 0.0.

Please verify whether the legacy mapping defaults are intended to override the class defaults. If the intended behavior is to have augmentations off by default, update the mapper in sleap_nn/config/data_config.py to use 0.0 (matching the class defaults). Otherwise, if the legacy behavior (defaulting to 1.0) is deliberate, ensure that the discrepancies are well documented and that the tests reflect this intended behavior (a test sketch follows below).
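Whichever default is chosen, a small regression test can lock it in; a hypothetical pytest sketch (the minimal legacy dict and expected values are assumptions to adjust once the intent is settled):

```python
# Hypothetical regression test for the mapper's augmentation-probability defaults.
from sleap_nn.config.data_config import data_mapper


def test_legacy_mapper_augmentation_probability_defaults():
    # Minimal legacy config with no augmentation_config section at all.
    legacy_config = {"data": {}, "optimization": {}}

    cfg = data_mapper(legacy_config)
    intensity = cfg.augmentation_config.intensity

    # Swap 1.0 for 0.0 here if the class defaults are adopted instead.
    assert intensity.uniform_noise_p == 1.0
    assert intensity.gaussian_noise_p == 1.0
    assert intensity.contrast_p == 1.0
```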
sleap_nn/config/data_config.py (Outdated)
```python
geometric=GeometricConfig(
    rotation=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("rotation_max_angle", 180.0),
    scale=(legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_min", None),
           legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_max", None)),
    # translate_width=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("translate_width", 0.2),
    # translate_height=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("translate_height", 0.2),
    # affine_p=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("affine_p", 0.0),
    # erase_scale_min=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_scale_min", 0.0001),
    # erase_scale_max=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_scale_max", 0.01),
    # erase_ratio_min=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_ratio_min", 1.0),
    # erase_ratio_max=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_ratio_max", 1.0),
    # erase_p=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_p", 0.0),
    # mixup_lambda=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("mixup_lambda", [0.01, 0.05]),
    # mixup_p=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("mixup_p", 0.0),
)
```
🛠️ Refactor suggestion

Potential issues with geometric configuration parameters

There are several issues with the geometric configuration section:

- The scale parameter uses `None` as potential values, which might conflict with the validator in the `GeometricConfig` class.
- Most parameters are commented out without a clear indication why.
- The rotation parameter uses a different default (180.0) compared to the class default (15.0).

```diff
 geometric=GeometricConfig(
     rotation=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("rotation_max_angle", 180.0),
-    scale=(legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_min", None),
-           legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_max", None)),
+    scale=(legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_min", 0.9),
+           legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_max", 1.1)),
     # translate_width=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("translate_width", 0.2),
     # translate_height=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("translate_height", 0.2),
```

Consider adding a comment to explain why most of these parameters are commented out, or uncomment them if they should be included.
Actionable comments posted: 2
♻️ Duplicate comments (2)
sleap_nn/config/model_config.py (1)
906-917: ⚠️ Potential issue

Always creating centroid head violates one-of semantics.

Because `centroid` is set unconditionally, it will be active alongside any other head, contradicting the `@oneof` policy that only one head can be non-None. Wrap this in a condition to ensure that we only create the centroid head if `head_type` is set to `"centroid"`.

```diff
-        centroid=CentroidConfig(
+        centroid=(
+            CentroidConfig(
+                confmaps=CentroidConfMapsConfig(
+                    anchor_part=legacy_config.get("heads", {})
+                    .get("centroid", {})
+                    .get("anchor_part"),
+                    sigma=legacy_config.get("heads", {})
+                    .get("centroid", {})
+                    .get("sigma", 5.0),
+                    output_stride=legacy_config.get("heads", {})
+                    .get("centroid", {})
+                    .get("output_stride", 1),
+                )
+            )
+            if legacy_config.get("head_type") == "centroid"
+            else None
+        ),
```

sleap_nn/config/data_config.py (1)

270-277: ⚠️ Potential issue

Scale set to `(None, None)` may break validation logic.

By default, `GeometricConfig.scale` expects numeric tuples (e.g., `(0.9, 1.1)`). Using `(None, None)` can trigger type errors or produce unintended outcomes. Provide valid numeric values if you want to enable scaling, or set this to `None` if scaling is not used.

```diff
-                scale=(
-                    legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_min", None),
-                    legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_max", None),
-                ),
+                scale=(
+                    legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_min", 0.9),
+                    legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_max", 1.1),
+                ),
```
🧹 Nitpick comments (2)
sleap_nn/config/model_config.py (1)
835-836: Remove or clarify leftover comment.

These lines appear to contain a partial or uncertain comment about `pretrained_backbone_weights` (e.g., `?? # i think its different`). This can be confusing to future maintainers. Consider removing or clarifying these comments to ensure the code is clean and unambiguous.

```diff
-    # pretrained_backbone_weights=legacy_config.get("PretrainedEncoderConfig")?? # i think its different
+    # pretrained_backbone_weights=legacy_config.get("PretrainedEncoderConfig")  # Clarify if needed or remove entirely
```

sleap_nn/config/data_config.py (1)

223-224: `use_augmentations_train` is commented out, causing mismatch.

The `DataConfig` instance sets `use_augmentations_train` to its default `False`, even if `legacy_config` indicates otherwise. Uncomment and map `use_augmentations_train` from `legacy_config` to keep the final config in sync.

```diff
-        # use_augmentations_train=legacy_config.get("use_augmentations_train", False),
+        use_augmentations_train=legacy_config.get("use_augmentations_train", False),
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- sleap_nn/config/data_config.py (1 hunks)
- sleap_nn/config/model_config.py (1 hunks)
- sleap_nn/config/training_job_config.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- sleap_nn/config/training_job_config.py
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: Tests (macos-14, Python 3.9)
- GitHub Check: Tests (windows-latest, Python 3.9)
- GitHub Check: Tests (ubuntu-latest, Python 3.9)
- GitHub Check: Lint
🔇 Additional comments (4)
sleap_nn/config/model_config.py (2)
837-842: Confirm usage of commented-out parameters.

In the returned `ModelConfig`, multiple parameters like `init_weights` and `pre_trained_weights` are commented out (lines 839-841). If these are intentionally omitted from legacy handling, consider removing them altogether. Otherwise, uncomment and map them from `legacy_config` to ensure backward compatibility.

889-905: Conditional creation for single_instance head looks good.

The code correctly creates a `SingleInstanceConfig` only when `legacy_config.get("head_type") == "single_instance"`. This aligns with the `@oneof` requirement that only one head should be active.

sleap_nn/config/data_config.py (2)
190-202: Review whether commented-out fields should be restored or removed.

Several fields (e.g., `train_labels_path`, `val_labels_path`, etc.) are commented out, preventing them from being set in `DataConfig`. Verify if you need these fields for accurate data loading. If they are essential, uncomment and map them; otherwise, remove them to avoid confusion.

267-269: Rotation uses 180.0 by default, diverging from internal default of 15.0.

This large rotation range can significantly alter training behavior. If this is intentional, confirm it aligns with your training objectives. Otherwise, consider a smaller angle or the class default.
```python
            bottomup=BottomUpConfig(
                confmaps=BottomUpConfMapsConfig(
                    loss_weight=legacy_config.get("heads", {})
                    .get("multi_instance", {})
                    .get("loss_weight", None),
                    sigma=legacy_config.get("heads", {})
                    .get("multi_instance", {})
                    .get("sigma", 5.0),
                    output_stride=legacy_config.get("heads", {})
                    .get("multi_instance", {})
                    .get("output_stride", 1),
                    part_names=legacy_config.get("heads", {})
                    .get("multi_instance", {})
                    .get("part_names", None),
                ),
                pafs=PAFConfig(
                    edges=legacy_config.get("heads", {})
                    .get("multi_instance", {})
                    .get("edges", None),
                    sigma=legacy_config.get("heads", {})
                    .get("multi_instance", {})
                    .get("sigma", 15.0),
                    output_stride=legacy_config.get("heads", {})
                    .get("multi_instance", {})
                    .get("output_stride", 1),
                    loss_weight=legacy_config.get("heads", {})
                    .get("multi_instance", {})
                    .get("loss_weight", None),
                ),
            ),
        ),
```
`bottomup` head is always created, overriding the one-of restriction.

Like `centroid` and `centered_instance`, this head is being instantiated unconditionally. This breaks the `@oneof` constraint, which mandates that only a single head may be active at once. Wrap the creation of `BottomUpConfig` in a conditional check for `"bottomup"`.

```diff
-            bottomup=BottomUpConfig(
+            bottomup=(
+                BottomUpConfig(
                     confmaps=BottomUpConfMapsConfig(
                         ...
                     ),
                     pafs=PAFConfig(
                         ...
                     ),
                 )
+                if legacy_config.get("head_type") == "bottomup"
+                else None
+            ),
```
```python
            centered_instance=CenteredInstanceConfig(
                confmaps=CentroidConfMapsConfig(
                    anchor_part=legacy_config.get("heads", {})
                    .get("centered_instance", {})
                    .get("anchor_part"),
                    sigma=legacy_config.get("heads", {})
                    .get("centered_instance", {})
                    .get("sigma", 5.0),
                    output_stride=legacy_config.get("heads", {})
                    .get("centered_instance", {})
                    .get("output_stride", 1),
                    part_names=legacy_config.get("heads", {})
                    .get("centered_instance", {})
                    .get("part_names", None),
                )
            ),
```
`CenteredInstanceConfig` uses `CentroidConfMapsConfig` instead of `CenteredInstanceConfMapsConfig`.

This is likely an error, as the `centered_instance` head is mismatched with the `CentroidConfMapsConfig`. Use `CenteredInstanceConfMapsConfig` to align with the class definitions.

```diff
-            centered_instance=CenteredInstanceConfig(
-                confmaps=CentroidConfMapsConfig(
+            centered_instance=(
+                CenteredInstanceConfig(
+                    confmaps=CenteredInstanceConfMapsConfig(
                         anchor_part=legacy_config.get("heads", {})
                         .get("centered_instance", {})
                         .get("anchor_part"),
                         sigma=legacy_config.get("heads", {})
                         .get("centered_instance", {})
                         .get("sigma", 5.0),
                         output_stride=legacy_config.get("heads", {})
                         .get("centered_instance", {})
                         .get("output_stride", 1),
                         part_names=legacy_config.get("heads", {})
                         .get("centered_instance", {})
                         .get("part_names", None),
                     )
-            )
-            ),
+                )
+            )
```