Skip to content

Commit cc03bf6

Browse files
gitttt-1234claude
andcommitted
Add config generator with TUI for training configuration
Implements a config generator tool that helps users create sleap-nn training configurations from SLP files. Features include: - Python API with fluent builder pattern (ConfigGenerator class) - Auto-configuration based on dataset analysis - Smart pipeline/backbone/parameter recommendations - Memory estimation for GPU and CPU - Interactive TUI with 4 tabs (Data, Model, Training, Export) - CLI command: sleap-nn config <slp_file> Usage examples: # Quick auto-config from sleap_nn.config_generator import generate_config generate_config("labels.slp", "config.yaml") # Fluent API ConfigGenerator.from_slp("labels.slp").auto(view="top").batch_size(8).save("config.yaml") # CLI sleap-nn config labels.slp -o config.yaml sleap-nn config labels.slp --interactive Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 699a8e2 commit cc03bf6

File tree

12 files changed

+3019
-0
lines changed

12 files changed

+3019
-0
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ dependencies = [
4949
"jupyterlab",
5050
"pyzmq",
5151
"rich-click>=1.9.5",
52+
"textual>=0.40.0",
5253
]
5354
dynamic = ["version", "readme"]
5455

sleap_nn/cli.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,112 @@ def system():
738738
print_system_info()
739739

740740

741+
@cli.command()
742+
@click.argument("slp_path", type=click.Path(exists=True))
743+
@click.option("-o", "--output", type=str, help="Output config YAML path")
744+
@click.option(
745+
"-v",
746+
"--view",
747+
type=click.Choice(["side", "top"]),
748+
help="Camera view type (affects rotation augmentation)",
749+
)
750+
@click.option(
751+
"--pipeline",
752+
type=click.Choice(
753+
[
754+
"single_instance",
755+
"centroid",
756+
"centered_instance",
757+
"bottomup",
758+
"multi_class_bottomup",
759+
"multi_class_topdown",
760+
]
761+
),
762+
help="Override pipeline type",
763+
)
764+
@click.option("--backbone", type=str, help="Override backbone architecture")
765+
@click.option("--batch-size", type=int, help="Override batch size")
766+
@click.option(
767+
"-i", "--interactive", is_flag=True, help="Launch interactive TUI mode"
768+
)
769+
@click.option("--analyze-only", is_flag=True, help="Only show dataset analysis")
770+
@click.option("--show-yaml", is_flag=True, help="Print YAML to stdout")
771+
def config(
772+
slp_path, output, view, pipeline, backbone, batch_size, interactive, analyze_only, show_yaml
773+
):
774+
"""Generate training configuration from SLP file.
775+
776+
Analyzes your labeled data and generates an optimized training
777+
configuration with sensible defaults.
778+
779+
\b
780+
Examples:
781+
# Auto-generate config
782+
sleap-nn config labels.slp -o config.yaml
783+
784+
# Specify camera view for better augmentation defaults
785+
sleap-nn config labels.slp -o config.yaml --view top
786+
787+
# Launch interactive TUI
788+
sleap-nn config labels.slp --interactive
789+
790+
# Override specific parameters
791+
sleap-nn config labels.slp -o config.yaml --batch-size 8
792+
793+
# Just analyze the data
794+
sleap-nn config labels.slp --analyze-only
795+
"""
796+
from rich.console import Console
797+
798+
from sleap_nn.config_generator import ConfigGenerator, analyze_slp
799+
800+
console = Console()
801+
802+
if analyze_only:
803+
stats = analyze_slp(slp_path)
804+
console.print(str(stats))
805+
806+
# Also show recommendation
807+
from sleap_nn.config_generator import recommend_config
808+
809+
rec = recommend_config(stats)
810+
console.print("\n[bold]Recommendation:[/bold]")
811+
console.print(f" Pipeline: {rec.pipeline.recommended}")
812+
console.print(f" Reason: {rec.pipeline.reason}")
813+
if rec.pipeline.warnings:
814+
console.print("\n[yellow]Warnings:[/yellow]")
815+
for w in rec.pipeline.warnings:
816+
console.print(f" * {w}")
817+
return
818+
819+
if interactive:
820+
from sleap_nn.config_generator.tui import launch_tui
821+
822+
launch_tui(slp_path)
823+
return
824+
825+
# Generate config
826+
gen = ConfigGenerator.from_slp(slp_path).auto(view=view)
827+
828+
if pipeline:
829+
gen.pipeline(pipeline)
830+
if backbone:
831+
gen.backbone(backbone)
832+
if batch_size:
833+
gen.batch_size(batch_size)
834+
835+
# Print summary
836+
console.print(gen.summary())
837+
838+
if output:
839+
gen.save(output)
840+
console.print(f"\n[green]Config saved to: {output}[/green]")
841+
842+
if show_yaml or not output:
843+
console.print("\n[bold]YAML Configuration:[/bold]")
844+
console.print(gen.to_yaml())
845+
846+
741847
cli.add_command(export_command)
742848
cli.add_command(predict_command)
743849

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""Config generator for SLEAP-NN training configurations.
2+
3+
This module provides tools for automatically generating training configurations
4+
from SLP label files with sensible defaults based on data analysis.
5+
6+
Quick Start
7+
-----------
8+
One-liner to generate config::
9+
10+
from sleap_nn.config_generator import generate_config
11+
generate_config("labels.slp", "config.yaml")
12+
13+
With customization::
14+
15+
from sleap_nn.config_generator import ConfigGenerator
16+
config = (
17+
ConfigGenerator.from_slp("labels.slp")
18+
.auto(view="top")
19+
.batch_size(8)
20+
.save("config.yaml")
21+
)
22+
23+
Just analyze the data::
24+
25+
from sleap_nn.config_generator import analyze_slp
26+
stats = analyze_slp("labels.slp")
27+
print(stats)
28+
29+
Interactive TUI::
30+
31+
sleap-nn config labels.slp --interactive
32+
"""
33+
34+
from omegaconf import DictConfig
35+
36+
from sleap_nn.config_generator.analyzer import DatasetStats, ViewType, analyze_slp
37+
from sleap_nn.config_generator.generator import ConfigGenerator
38+
from sleap_nn.config_generator.memory import MemoryEstimate, estimate_memory
39+
from sleap_nn.config_generator.recommender import (
40+
BackboneType,
41+
ConfigRecommendation,
42+
PipelineRecommendation,
43+
PipelineType,
44+
recommend_config,
45+
recommend_pipeline,
46+
)
47+
48+
__all__ = [
49+
# Core classes
50+
"ConfigGenerator",
51+
"DatasetStats",
52+
"MemoryEstimate",
53+
"PipelineRecommendation",
54+
"ConfigRecommendation",
55+
# Type aliases
56+
"PipelineType",
57+
"BackboneType",
58+
"ViewType",
59+
# Functions
60+
"analyze_slp",
61+
"recommend_pipeline",
62+
"recommend_config",
63+
"estimate_memory",
64+
"generate_config",
65+
]
66+
67+
68+
def generate_config(
69+
slp_path: str,
70+
output_path: str = None,
71+
*,
72+
view: str = None,
73+
pipeline: str = None,
74+
backbone: str = None,
75+
batch_size: int = None,
76+
**kwargs,
77+
) -> DictConfig:
78+
"""Generate a training configuration from an SLP file.
79+
80+
This is a convenience function for quick config generation.
81+
For more control, use the ConfigGenerator class directly.
82+
83+
Args:
84+
slp_path: Path to the .slp label file.
85+
output_path: Optional path to save YAML config.
86+
view: Camera view type ("side" or "top") for augmentation defaults.
87+
pipeline: Override auto-detected pipeline type.
88+
backbone: Override auto-detected backbone.
89+
batch_size: Override auto-detected batch size.
90+
**kwargs: Additional parameters to override.
91+
92+
Returns:
93+
Training configuration as OmegaConf DictConfig.
94+
95+
Examples:
96+
Auto-generate and save::
97+
98+
generate_config("labels.slp", "config.yaml")
99+
100+
Auto-generate with view hint::
101+
102+
config = generate_config("labels.slp", view="top")
103+
104+
With overrides::
105+
106+
config = generate_config(
107+
"labels.slp",
108+
"config.yaml",
109+
pipeline="bottomup",
110+
batch_size=8
111+
)
112+
"""
113+
gen = ConfigGenerator.from_slp(slp_path).auto(view=view)
114+
115+
# Apply overrides
116+
if pipeline:
117+
gen.pipeline(pipeline)
118+
if backbone:
119+
gen.backbone(backbone)
120+
if batch_size:
121+
gen.batch_size(batch_size)
122+
123+
# Apply any additional kwargs
124+
for key, value in kwargs.items():
125+
if value is not None and hasattr(gen, key):
126+
getattr(gen, key)(value)
127+
128+
config = gen.build()
129+
130+
if output_path:
131+
gen.save(output_path)
132+
133+
return config

0 commit comments

Comments
 (0)