Skip to content

Commit 407861d

Browse files
talmoclaude
andauthored
Add --config flag for simpler train CLI + fix crop device mismatch (#429)
## Summary - **Simpler train CLI**: Add `--config` flag and positional argument support for `sleap-nn train` - **Bug fix**: Fix device mismatch in `crop_bboxes` during top-down inference - **Bug fix**: Fix inference progress ending at 99% instead of 100% in GUI mode ## New Train CLI Usage ```bash # Positional config path (new!) sleap-nn train path/to/config.yaml # With --config flag (new!) sleap-nn train --config path/to/config.yaml # With Hydra overrides sleap-nn train config.yaml trainer_config.max_epochs=100 # Legacy still works sleap-nn train --config-dir /path/to/dir --config-name myrun ``` Also adds `rich-click` for styled CLI help output. ## Bug Fixes ### Device mismatch in crop_bboxes Fixed `RuntimeError: indices should be either on cpu or on the same device as the indexed tensor` when bboxes tensor is on GPU but images are on CPU during top-down inference. ### Progress ends at 99% Fixed inference progress bar ending at 99% instead of 100% in GUI mode. The throttled progress reporting (~4Hz) was skipping the final update when the last batch completed within 0.25s of the previous report. ## Test plan - [x] `pytest tests/inference/test_peak_finding.py` - all pass - [x] `pytest tests/inference/test_topdown.py` - all pass - [x] `pytest tests/test_cli.py` - all pass - [ ] Manual test of new CLI patterns - [ ] Verify inference progress shows 100% in SLEAP GUI 🤖 Generated with [Claude Code](https://claude.ai/code) --------- Co-authored-by: Claude Opus 4.5 <[email protected]>
1 parent 2051b9f commit 407861d

File tree

6 files changed

+148
-47
lines changed

6 files changed

+148
-47
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ dependencies = [
4848
"jupyter",
4949
"jupyterlab",
5050
"pyzmq",
51+
"rich-click>=1.9.5",
5152
]
5253
dynamic = ["version", "readme"]
5354

sleap_nn/cli.py

Lines changed: 126 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
"""Unified CLI for SLEAP-NN using Click."""
1+
"""Unified CLI for SLEAP-NN using rich-click for styled output."""
22

3-
import click
3+
import rich_click as click
4+
from click import Command
45
from loguru import logger
56
from pathlib import Path
67
from omegaconf import OmegaConf, DictConfig
@@ -13,7 +14,36 @@
1314
from sleap_nn import __version__
1415
import hydra
1516
import sys
16-
from click import Command
17+
18+
# Rich-click configuration for styled help
19+
click.rich_click.TEXT_MARKUP = "markdown"
20+
click.rich_click.SHOW_ARGUMENTS = True
21+
click.rich_click.GROUP_ARGUMENTS_OPTIONS = True
22+
click.rich_click.STYLE_ERRORS_SUGGESTION = "magenta italic"
23+
click.rich_click.ERRORS_EPILOGUE = (
24+
"Try 'sleap-nn [COMMAND] --help' for more information."
25+
)
26+
27+
28+
def is_config_path(arg: str) -> bool:
29+
"""Check if an argument looks like a config file path.
30+
31+
Returns True if the arg ends with .yaml or .yml.
32+
"""
33+
return arg.endswith(".yaml") or arg.endswith(".yml")
34+
35+
36+
def split_config_path(config_path: str) -> tuple:
37+
"""Split a full config path into (config_dir, config_name).
38+
39+
Args:
40+
config_path: Full path to a config file.
41+
42+
Returns:
43+
Tuple of (config_dir, config_name) where config_dir is an absolute path.
44+
"""
45+
path = Path(config_path).resolve()
46+
return path.parent.as_posix(), path.name
1747

1848

1949
def print_version(ctx, param, value):
@@ -66,38 +96,77 @@ def cli():
6696

6797

6898
def show_training_help():
69-
"""Display training help information."""
70-
help_text = """
71-
sleap-nn train — Train SLEAP models from a config YAML file.
72-
73-
Usage:
74-
sleap-nn train --config-dir <dir> --config-name <name> [overrides]
75-
76-
Common overrides:
77-
trainer_config.max_epochs=100
78-
trainer_config.batch_size=32
79-
80-
Examples:
81-
Start new run:
82-
sleap-nn train --config-dir /path/to/config_dir/ --config-name myrun
83-
Resume 20 more epochs:
84-
sleap-nn train --config-dir /path/to/config_dir/ --config-name myrun \\
85-
trainer_config.resume_ckpt_path=<path/to/ckpt> \\
86-
trainer_config.max_epochs=20
87-
88-
Tips:
89-
- Use -m/--multirun for sweeps; outputs go under hydra.sweep.dir.
90-
- For Hydra flags and completion, use --hydra-help.
91-
92-
For a detailed list of all available config options, please refer to https://nn.sleap.ai/config/.
99+
"""Display training help information with rich formatting."""
100+
from rich.console import Console
101+
from rich.panel import Panel
102+
from rich.markdown import Markdown
103+
104+
console = Console()
105+
106+
help_md = """
107+
## Usage
108+
109+
```
110+
sleap-nn train <config.yaml> [overrides]
111+
sleap-nn train --config <path/to/config.yaml> [overrides]
112+
```
113+
114+
## Common Overrides
115+
116+
| Override | Description |
117+
|----------|-------------|
118+
| `trainer_config.max_epochs=100` | Set maximum training epochs |
119+
| `trainer_config.batch_size=32` | Set batch size |
120+
| `trainer_config.save_ckpt=true` | Enable checkpoint saving |
121+
122+
## Examples
123+
124+
**Start a new training run:**
125+
```bash
126+
sleap-nn train path/to/config.yaml
127+
sleap-nn train --config path/to/config.yaml
128+
```
129+
130+
**With overrides:**
131+
```bash
132+
sleap-nn train config.yaml trainer_config.max_epochs=100
133+
```
134+
135+
**Resume training:**
136+
```bash
137+
sleap-nn train config.yaml trainer_config.resume_ckpt_path=/path/to/ckpt
138+
```
139+
140+
**Legacy usage (still supported):**
141+
```bash
142+
sleap-nn train --config-dir /path/to/dir --config-name myrun
143+
```
144+
145+
## Tips
146+
147+
- Use `-m/--multirun` for sweeps; outputs go under `hydra.sweep.dir`
148+
- For Hydra flags and completion, use `--hydra-help`
149+
- Config documentation: https://nn.sleap.ai/config/
93150
"""
94-
click.echo(help_text)
151+
console.print(
152+
Panel(
153+
Markdown(help_md),
154+
title="[bold cyan]sleap-nn train[/bold cyan]",
155+
subtitle="Train SLEAP models from a config YAML file",
156+
border_style="cyan",
157+
)
158+
)
95159

96160

97161
@cli.command(cls=TrainCommand)
98-
@click.option("--config-name", "-c", type=str, help="Configuration file name")
99162
@click.option(
100-
"--config-dir", "-d", type=str, default=".", help="Configuration directory path"
163+
"--config",
164+
type=str,
165+
help="Path to configuration file (e.g., path/to/config.yaml)",
166+
)
167+
@click.option("--config-name", "-c", type=str, help="Configuration file name (legacy)")
168+
@click.option(
169+
"--config-dir", "-d", type=str, default=".", help="Configuration directory (legacy)"
101170
)
102171
@click.option(
103172
"--video-paths",
@@ -130,25 +199,43 @@ def show_training_help():
130199
'Example: --prefix-map "/old/server/path" "/new/local/path"',
131200
)
132201
@click.argument("overrides", nargs=-1, type=click.UNPROCESSED)
133-
def train(config_name, config_dir, video_paths, video_path_map, prefix_map, overrides):
202+
def train(
203+
config, config_name, config_dir, video_paths, video_path_map, prefix_map, overrides
204+
):
134205
"""Run training workflow with Hydra config overrides.
135206
136207
Examples:
137-
sleap-nn train --config-name myconfig --config-dir /path/to/config_dir/
208+
sleap-nn train path/to/config.yaml
209+
sleap-nn train --config path/to/config.yaml trainer_config.max_epochs=100
138210
sleap-nn train -c myconfig -d /path/to/config_dir/ trainer_config.max_epochs=100
139-
sleap-nn train -c myconfig -d /path/to/config_dir/ +experiment=new_model
140211
"""
141-
# Show help if no config name provided
142-
if not config_name:
212+
# Convert overrides to a mutable list
213+
overrides = list(overrides)
214+
215+
# Check if the first positional arg is a config path (not a Hydra override)
216+
config_from_positional = None
217+
if overrides and is_config_path(overrides[0]):
218+
config_from_positional = overrides.pop(0)
219+
220+
# Resolve config path with priority:
221+
# 1. Positional config path (e.g., sleap-nn train config.yaml)
222+
# 2. --config flag (e.g., sleap-nn train --config config.yaml)
223+
# 3. Legacy --config-dir/--config-name flags
224+
if config_from_positional:
225+
config_dir, config_name = split_config_path(config_from_positional)
226+
elif config:
227+
config_dir, config_name = split_config_path(config)
228+
elif config_name:
229+
config_dir = Path(config_dir).resolve().as_posix()
230+
else:
231+
# No config provided - show help
143232
show_training_help()
144233
return
145234

146-
# Initialize Hydra manually
147-
# resolve the path to the config directory (hydra expects absolute path)
148-
config_dir = Path(config_dir).resolve().as_posix()
235+
# Initialize Hydra manually (config_dir is already an absolute path)
149236
with hydra.initialize_config_dir(config_dir=config_dir, version_base=None):
150237
# Compose config with overrides
151-
cfg = hydra.compose(config_name=config_name, overrides=list(overrides))
238+
cfg = hydra.compose(config_name=config_name, overrides=overrides)
152239

153240
# Validate config
154241
if not hasattr(cfg, "model_config") or not cfg.model_config:

sleap_nn/inference/peak_finding.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,10 @@ def crop_bboxes(
7474
# Get crop centers from bboxes.
7575
# The bbox top-left is at index 0, with (x, y) coordinates.
7676
# We need the center of the crop (peak location), which is top-left + half_size.
77-
crop_x = (bboxes[:, 0, 0] + half_w).to(torch.long)
78-
crop_y = (bboxes[:, 0, 1] + half_h).to(torch.long)
77+
# Ensure bboxes are on the same device as images for index computation.
78+
bboxes_on_device = bboxes.to(device)
79+
crop_x = (bboxes_on_device[:, 0, 0] + half_w).to(torch.long)
80+
crop_y = (bboxes_on_device[:, 0, 1] + half_h).to(torch.long)
7981

8082
# Clamp indices to valid bounds to handle edge cases where centroids
8183
# might be at or beyond image boundaries.
@@ -86,7 +88,7 @@ def crop_bboxes(
8688
# Convert sample_inds to tensor if it's a list.
8789
if not isinstance(sample_inds, torch.Tensor):
8890
sample_inds = torch.tensor(sample_inds, device=device)
89-
sample_inds_long = sample_inds.to(torch.long)
91+
sample_inds_long = sample_inds.to(device=device, dtype=torch.long)
9092
crops = patches[sample_inds_long, :, crop_y, crop_x]
9193
# Shape: (n_crops, channels, height, width)
9294

sleap_nn/inference/predictors.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,16 @@ def _predict_generator_gui(
567567
print(json.dumps(progress_data), flush=True)
568568
last_report = time()
569569

570+
# Final progress emit to ensure 100% is shown
571+
elapsed = time() - start_time
572+
progress_data = {
573+
"n_processed": total_frames,
574+
"n_total": total_frames,
575+
"rate": round(frames_processed / elapsed, 1) if elapsed > 0 else 0,
576+
"eta": 0,
577+
}
578+
print(json.dumps(progress_data), flush=True)
579+
570580
def _predict_generator_rich(
571581
self, total_frames: int
572582
) -> Iterator[Dict[str, np.ndarray]]:

tests/test_cli.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def test_train_help(self):
5656
result = runner.invoke(cli, ["train", "--help"])
5757
assert result.exit_code == 0
5858
assert "sleap-nn train" in result.output
59-
assert "Usage:" in result.output
59+
assert "Usage" in result.output # Rich-click renders ## Usage as header
6060
assert "sleap.ai" in result.output
6161

6262
def test_train_no_config_shows_help(self):
@@ -65,7 +65,7 @@ def test_train_no_config_shows_help(self):
6565
result = runner.invoke(cli, ["train", "--config-dir", "."])
6666
assert result.exit_code == 0
6767
assert "sleap-nn train" in result.output
68-
assert "Usage:" in result.output
68+
assert "Usage" in result.output # Rich-click renders ## Usage as header
6969

7070

7171
class TestSystemCommand:
@@ -138,9 +138,8 @@ def test_show_training_help_output(self, capsys):
138138
show_training_help()
139139
captured = capsys.readouterr()
140140
assert "sleap-nn train" in captured.out
141-
assert "Usage:" in captured.out
142-
assert "--config-dir" in captured.out
143-
assert "--config-name" in captured.out
141+
assert "Usage" in captured.out # Rich-click renders ## Usage as header
142+
assert "config.yaml" in captured.out # New positional arg usage
144143
assert "sleap.ai" in captured.out
145144

146145

0 commit comments

Comments
 (0)