Commit ae458ca

talmo and claude authored
Fix CSV logger not capturing learning_rate (#423)
## Summary

Fixes a regression introduced in PR #417 where the `learning_rate` column in `training_log.csv` was always empty. Also adds model-specific loss columns to the CSV for better parity with wandb logging.

Fixes #422

## Root Cause

PR #417 made several changes to metrics logging:

1. Removed the `LearningRateMonitor` callback (which logged as `lr-Adam`)
2. Added manual learning rate logging as `train/lr`

However, the `CSVLoggerCallback` was only looking for:

- `learning_rate` (direct key - never logged)
- `lr-*` pattern (LearningRateMonitor format - no longer used)

The new `train/lr` key was never checked, resulting in empty `learning_rate` values.

## Changes

### 1. Fix learning rate lookup (`sleap_nn/training/callbacks.py`)

The `CSVLoggerCallback` now checks for the learning rate in this order:

1. `learning_rate` (direct key)
2. `train/lr` (current format from lightning modules) ← **NEW**
3. `lr-*` pattern (legacy LearningRateMonitor format)

### 2. Add model-specific CSV columns (`sleap_nn/training/model_trainer.py`)

Added loss breakdown columns for different model types to match what's logged to wandb:

| Model Type | New CSV Columns |
|------------|-----------------|
| `bottomup` | `train/confmaps_loss`, `train/paf_loss`, `val/confmaps_loss`, `val/paf_loss` |
| `multi_class_bottomup` | `train/confmaps_loss`, `train/classmap_loss`, `train/class_accuracy`, `val/confmaps_loss`, `val/classmap_loss`, `val/class_accuracy` |
| `multi_class_topdown` | `train/confmaps_loss`, `train/classvector_loss`, `train/class_accuracy`, `val/confmaps_loss`, `val/classvector_loss`, `val/class_accuracy` |

### 3. Add test (`tests/training/test_callbacks.py`)

Added `test_on_validation_epoch_end_logs_train_lr_format` to verify that the new `train/lr` key lookup works correctly.

## Example Output

**Before (broken):**

```csv
epoch,train/loss,val/loss,learning_rate,train/time,val/time
0,,0.006371453870087862,,,
1,0.0006624094676226377,0.0002221532049588859,,32.815,6.364
```

**After (fixed):**

```csv
epoch,train/loss,val/loss,learning_rate,train/time,val/time
0,,0.006371453870087862,,,
1,0.0006624094676226377,0.0002221532049588859,0.0001,32.815,6.364
```

## API Changes

### CSV Column Additions

The `training_log.csv` file will now include additional columns depending on the model type. This is a non-breaking change - existing code that reads the CSV will continue to work, and the new columns provide additional information.

**Note:** The CSV column name remains `learning_rate` (not `train/lr`) for backward compatibility with existing analysis scripts.

## Design Decisions

1. **Backward compatible column name**: We kept `learning_rate` as the CSV column name rather than changing to `train/lr` to avoid breaking existing analysis pipelines that expect the old name (see the verification sketch below).
2. **Fallback chain for LR lookup**: The callback checks multiple key formats in order, maintaining compatibility with:
   - Direct `learning_rate` logging (if someone uses it)
   - New `train/lr` format (current)
   - Legacy `lr-*` format (LearningRateMonitor)
3. **Model-specific columns**: Rather than logging all possible columns for all models (which would result in many empty columns), we only add columns relevant to each model type.
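As a quick check of the behavior shown in the example output above, an analysis script can keep reading the `learning_rate` column unchanged. A minimal sketch using only the standard library (the log path is illustrative; in practice the file lives under the run's checkpoint directory, and epoch 0 cells may be empty as in the example):

```python
import csv
from pathlib import Path

# Illustrative path to a run's CSV log.
log_path = Path("training_log.csv")

with open(log_path, newline="") as f:
    for row in csv.DictReader(f):
        # With this fix, learning_rate is populated (e.g. "0.0001") from the
        # first logged training epoch onward instead of staying empty.
        print(row["epoch"], row["train/loss"], row["val/loss"], row["learning_rate"])
```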
## Test Plan

- [x] `pytest tests/training/test_callbacks.py::TestCSVLoggerCallbackFileOps` - Unit tests for the CSV logger
- [x] `pytest tests/training/test_model_trainer.py::test_model_trainer_centered_instance` - Integration test verifying `learning_rate` is logged correctly

---

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <[email protected]>
1 parent df72b51 commit ae458ca

File tree (3 files changed: +71 -2 lines changed)

- sleap_nn/training/callbacks.py
- sleap_nn/training/model_trainer.py
- tests/training/test_callbacks.py


sleap_nn/training/callbacks.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -85,10 +85,15 @@ def on_validation_epoch_end(self, trainer, pl_module):
             if key == "epoch":
                 log_data["epoch"] = trainer.current_epoch
             elif key == "learning_rate":
-                # Handle both direct logging and LearningRateMonitor format (lr-*)
+                # Handle multiple formats:
+                # 1. Direct "learning_rate" key
+                # 2. "train/lr" key (current format from lightning modules)
+                # 3. "lr-*" keys from LearningRateMonitor (legacy)
                 value = metrics.get(key, None)
                 if value is None:
-                    # Look for lr-* keys from LearningRateMonitor
+                    value = metrics.get("train/lr", None)
+                if value is None:
+                    # Look for lr-* keys from LearningRateMonitor (legacy)
                     for metric_key in metrics.keys():
                         if metric_key.startswith("lr-"):
                             value = metrics[metric_key]
```

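Since the hunk above is truncated mid-loop, here is the lookup order it implements as a standalone sketch (not the callback's exact code: `metrics` stands for `trainer.callback_metrics`, and the helper name is made up for illustration):

```python
def resolve_learning_rate(metrics: dict):
    """Return the learning rate value using the fallback chain from this fix."""
    # 1. Direct "learning_rate" key, if something logs it explicitly.
    value = metrics.get("learning_rate")
    if value is None:
        # 2. "train/lr", the current key logged by the lightning modules.
        value = metrics.get("train/lr")
    if value is None:
        # 3. Legacy "lr-*" keys written by LearningRateMonitor (e.g. "lr-Adam").
        for key in metrics:
            if key.startswith("lr-"):
                value = metrics[key]
                break  # simplified; the truncated hunk does not show how the loop exits
    return value
```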
sleap_nn/training/model_trainer.py

Lines changed: 32 additions & 0 deletions
```diff
@@ -849,6 +849,7 @@ def _setup_loggers_callbacks(self, viz_train_dataset, viz_val_dataset):
             "train/time",
             "val/time",
         ]
+        # Add model-specific keys for wandb parity
         if self.model_type in [
             "single_instance",
             "centered_instance",
@@ -857,6 +858,37 @@ def _setup_loggers_callbacks(self, viz_train_dataset, viz_val_dataset):
             csv_log_keys.extend(
                 [f"train/confmaps/{name}" for name in self.skeletons[0].node_names]
             )
+        if self.model_type == "bottomup":
+            csv_log_keys.extend(
+                [
+                    "train/confmaps_loss",
+                    "train/paf_loss",
+                    "val/confmaps_loss",
+                    "val/paf_loss",
+                ]
+            )
+        if self.model_type == "multi_class_bottomup":
+            csv_log_keys.extend(
+                [
+                    "train/confmaps_loss",
+                    "train/classmap_loss",
+                    "train/class_accuracy",
+                    "val/confmaps_loss",
+                    "val/classmap_loss",
+                    "val/class_accuracy",
+                ]
+            )
+        if self.model_type == "multi_class_topdown":
+            csv_log_keys.extend(
+                [
+                    "train/confmaps_loss",
+                    "train/classvector_loss",
+                    "train/class_accuracy",
+                    "val/confmaps_loss",
+                    "val/classvector_loss",
+                    "val/class_accuracy",
+                ]
+            )
         csv_logger = CSVLoggerCallback(
             filepath=Path(self.config.trainer_config.ckpt_dir)
             / self.config.trainer_config.run_name
```

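Assuming the CSV header follows the order in which `csv_log_keys` is assembled (the base keys from the example output in the description, then the model-specific extension), a `bottomup` run's `training_log.csv` header would look roughly like the sketch below; the `multi_class_*` variants follow the same pattern with their respective loss and accuracy columns:

```csv
epoch,train/loss,val/loss,learning_rate,train/time,val/time,train/confmaps_loss,train/paf_loss,val/confmaps_loss,val/paf_loss
```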
tests/training/test_callbacks.py

Lines changed: 32 additions & 0 deletions
```diff
@@ -622,6 +622,38 @@ def test_on_validation_epoch_end_logs_metrics(self):
             assert len(lines) == 2  # Header + data row
             assert "5" in lines[1]  # Epoch
 
+    def test_on_validation_epoch_end_logs_train_lr_format(self):
+        """Logs learning rate from train/lr key (current format)."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            filepath = Path(tmpdir) / "metrics.csv"
+            callback = CSVLoggerCallback(filepath=filepath)
+
+            mock_trainer = MagicMock()
+            mock_trainer.is_global_zero = True
+            mock_trainer.current_epoch = 3
+            mock_trainer.callback_metrics = {
+                "train_loss": torch.tensor(0.4),
+                "val_loss": torch.tensor(0.2),
+                "train/lr": torch.tensor(
+                    0.0005
+                ),  # Current format from lightning modules
+            }
+            mock_pl_module = MagicMock()
+
+            with patch("sleap_nn.training.callbacks.RANK", 0):
+                callback.on_validation_epoch_end(mock_trainer, mock_pl_module)
+
+            assert filepath.exists()
+
+            # Read and verify contents
+            import csv
+
+            with open(filepath) as f:
+                reader = csv.DictReader(f)
+                row = next(reader)
+                assert row["epoch"] == "3"
+                assert row["learning_rate"].startswith("0.0005")
+
     def test_on_validation_epoch_end_skips_if_not_global_zero(self):
         """Skips logging if not global rank zero."""
         with tempfile.TemporaryDirectory() as tmpdir:
```

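For completeness, a hypothetical companion test for the legacy `lr-*` path could follow the same pattern as the test above (the test name, epoch, and values here are illustrative, and the repository may already cover this case elsewhere):

```python
    def test_on_validation_epoch_end_logs_legacy_lr_monitor_format(self):
        """Hypothetical: logs learning rate from legacy lr-* keys (LearningRateMonitor)."""
        with tempfile.TemporaryDirectory() as tmpdir:
            filepath = Path(tmpdir) / "metrics.csv"
            callback = CSVLoggerCallback(filepath=filepath)

            mock_trainer = MagicMock()
            mock_trainer.is_global_zero = True
            mock_trainer.current_epoch = 1
            mock_trainer.callback_metrics = {
                "train_loss": torch.tensor(0.4),
                "val_loss": torch.tensor(0.2),
                "lr-Adam": torch.tensor(0.001),  # Legacy LearningRateMonitor key
            }
            mock_pl_module = MagicMock()

            with patch("sleap_nn.training.callbacks.RANK", 0):
                callback.on_validation_epoch_end(mock_trainer, mock_pl_module)

            import csv

            with open(filepath) as f:
                row = next(csv.DictReader(f))
            assert row["learning_rate"].startswith("0.001")
```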