Skip to content

Commit 41022d8

Browse files
gitttt-1234claude
andcommitted
Consolidate metrics saving to use SLEAP 1.4 format and eliminate code duplication
- Updated run_evaluation() to save metrics using SLEAP 1.4 format (single "metrics" key) - Refactored train.py to use run_evaluation() instead of duplicating evaluation code - Removed unused imports (numpy, sleap_io) from train.py - Updated test_evaluation.py to load metrics in SLEAP 1.4 format - Ensured load_metrics() function is compatible with the save format 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 3eb808a commit 41022d8

File tree

3 files changed

+24
-40
lines changed

3 files changed

+24
-40
lines changed

sleap_nn/evaluation.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -745,22 +745,8 @@ def run_evaluation(
745745
save_path = Path(save_metrics)
746746
save_path.parent.mkdir(parents=True, exist_ok=True)
747747

748-
# Convert metrics to numpy arrays for saving
749-
np.savez(
750-
save_path,
751-
mOKS=metrics["mOKS"]["mOKS"],
752-
mAP=metrics["voc_metrics"]["oks_voc.mAP"],
753-
mAR=metrics["voc_metrics"]["oks_voc.mAR"],
754-
avg_distance=metrics["distance_metrics"]["avg"],
755-
mPCK=metrics["pck_metrics"]["mPCK"],
756-
visibility_precision=metrics["visibility_metrics"]["precision"],
757-
visibility_recall=metrics["visibility_metrics"]["recall"],
758-
# Save full metrics dict as well
759-
voc_metrics=metrics["voc_metrics"],
760-
distance_metrics=metrics["distance_metrics"],
761-
pck_metrics=metrics["pck_metrics"],
762-
visibility_metrics=metrics["visibility_metrics"],
763-
)
748+
# Save metrics in SLEAP 1.4 format (single "metrics" key)
749+
np.savez_compressed(save_path, **{"metrics": metrics})
764750
logger.info(f"Metrics saved successfully to {save_path}")
765751

766752
return metrics

sleap_nn/train.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,19 @@
22

33
from loguru import logger
44
from pathlib import Path
5-
import numpy as np
65
from datetime import datetime
76
from time import time
87
from omegaconf import DictConfig, OmegaConf
98
from typing import Any, Dict, Optional, List, Tuple, Union
10-
import sleap_io as sio
119
from sleap_nn.config.training_job_config import TrainingJobConfig
1210
from sleap_nn.training.model_trainer import ModelTrainer
1311
from sleap_nn.predict import run_inference as predict
14-
from sleap_nn.evaluation import Evaluator
12+
from sleap_nn.evaluation import run_evaluation
1513
from sleap_nn.config.get_config import (
1614
get_trainer_config,
1715
get_model_config,
1816
get_data_config,
1917
)
20-
from typing import Any, Dict, Optional, List, Tuple, Union
2118

2219

2320
def run_training(config: DictConfig):
@@ -64,7 +61,16 @@ def run_training(config: DictConfig):
6461
data_paths["test"] = config.data_config.test_file_path
6562

6663
for d_name, path in data_paths.items():
67-
labels = sio.load_slp(path)
64+
pred_path = (
65+
Path(trainer.config.trainer_config.ckpt_dir)
66+
/ trainer.config.trainer_config.run_name
67+
/ f"pred_{d_name}.slp"
68+
)
69+
metrics_path = (
70+
Path(trainer.config.trainer_config.ckpt_dir)
71+
/ trainer.config.trainer_config.run_name
72+
/ f"{d_name}_pred_metrics.npz"
73+
)
6874

6975
pred_labels = predict(
7076
data_path=path,
@@ -75,9 +81,7 @@ def run_training(config: DictConfig):
7581
peak_threshold=0.2,
7682
make_labels=True,
7783
device=trainer.trainer.strategy.root_device,
78-
output_path=Path(trainer.config.trainer_config.ckpt_dir)
79-
/ trainer.config.trainer_config.run_name
80-
/ f"pred_{d_name}.slp",
84+
output_path=pred_path,
8185
ensure_rgb=config.data_config.preprocessing.ensure_rgb,
8286
ensure_grayscale=config.data_config.preprocessing.ensure_grayscale,
8387
)
@@ -88,27 +92,18 @@ def run_training(config: DictConfig):
8892
)
8993
continue # skip if there are no labeled frames
9094

91-
evaluator = Evaluator(
92-
ground_truth_instances=labels, predicted_instances=pred_labels
93-
)
94-
metrics = evaluator.evaluate()
95-
np.savez_compressed(
96-
(
97-
Path(trainer.config.trainer_config.ckpt_dir)
98-
/ trainer.config.trainer_config.run_name
99-
/ f"{d_name}_pred_metrics.npz"
100-
).as_posix(),
101-
**{"metrics": metrics},
95+
# Run evaluation and save metrics
96+
metrics = run_evaluation(
97+
ground_truth_path=path,
98+
predicted_path=pred_path.as_posix(),
99+
save_metrics=metrics_path.as_posix(),
102100
)
103101

104102
logger.info(f"---------Evaluation on `{d_name}` dataset---------")
105103
logger.info(f"OKS mAP: {metrics['voc_metrics']['oks_voc.mAP']}")
106104
logger.info(f"Average distance: {metrics['distance_metrics']['avg']}")
107105
logger.info(f"p90 dist: {metrics['distance_metrics']['p90']}")
108106
logger.info(f"p50 dist: {metrics['distance_metrics']['p50']}")
109-
logger.info(
110-
f"metrics saved to {Path(trainer.config.trainer_config.ckpt_dir) / trainer.config.trainer_config.run_name / (d_name + '_pred_metrics.npz')}"
111-
)
112107

113108

114109
def train(

tests/test_evaluation.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -526,13 +526,16 @@ def test_evaluator_main(
526526
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
527527
assert Path(f"{tmp_path}/metrics_test.npz").exists()
528528

529-
metrics = np.load(f"{tmp_path}/metrics_test.npz", allow_pickle=True)
529+
# Load metrics in SLEAP 1.4 format (single "metrics" key)
530+
metrics_npz = np.load(f"{tmp_path}/metrics_test.npz", allow_pickle=True)
531+
assert "metrics" in metrics_npz
532+
metrics = metrics_npz["metrics"].item()
530533
assert "voc_metrics" in metrics
531534
assert "mOKS" in metrics
532535
assert "distance_metrics" in metrics
533536
assert "pck_metrics" in metrics
534537
assert "visibility_metrics" in metrics
535-
voc_metrics = metrics["voc_metrics"].item()
538+
voc_metrics = metrics["voc_metrics"]
536539
assert "pck_voc.mAP" in voc_metrics
537540
assert "pck_voc.mAR" in voc_metrics
538541
assert "oks_voc.mAP" in voc_metrics

0 commit comments

Comments
 (0)