Skip to content

Commit acace23

Browse files
gitttt-1234claude
andcommitted
Fix TUI config generator reactive updates and head config format
- Add reactive YAML preview updates when form values change (pipeline, backbone, inputs, switches, selects) - Set default pipeline selection to Single Instance - Update head_configs to include all head types (matching expected format) - Add missing event handlers for Switch and Select widgets Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 07a7e4b commit acace23

File tree

2 files changed

+69
-44
lines changed

2 files changed

+69
-44
lines changed

sleap_nn/config_generator/generator.py

Lines changed: 49 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -517,69 +517,75 @@ def _build_model_config(self) -> dict:
517517
}
518518

519519
def _build_head_config(self) -> dict:
520-
"""Build head configuration based on pipeline type."""
520+
"""Build head configuration based on pipeline type.
521+
522+
Returns config with all head types, setting the active one's config
523+
and null for all others (matches expected sleap-nn config format).
524+
"""
521525
base_confmap = {
522526
"sigma": self._sigma,
523527
"output_stride": self._output_stride,
524528
}
525529

530+
# Initialize all head types to null
531+
head_configs = {
532+
"single_instance": None,
533+
"centroid": None,
534+
"centered_instance": None,
535+
"bottomup": None,
536+
"multi_class_bottomup": None,
537+
"multi_class_topdown": None,
538+
}
539+
526540
if self._pipeline == "single_instance":
527-
return {"single_instance": {"confmaps": base_confmap}}
541+
head_configs["single_instance"] = {"confmaps": base_confmap}
528542

529543
elif self._pipeline == "centroid":
530-
cfg = {"confmaps": base_confmap.copy()}
531-
if self._anchor_part:
532-
cfg["anchor_part"] = self._anchor_part
533-
return {"centroid": cfg}
544+
cfg = {"confmaps": {**base_confmap, "anchor_part": self._anchor_part}}
545+
head_configs["centroid"] = cfg
534546

535547
elif self._pipeline == "centered_instance":
536-
cfg = {"confmaps": base_confmap.copy()}
537-
if self._anchor_part:
538-
cfg["anchor_part"] = self._anchor_part
539-
return {"centered_instance": cfg}
548+
cfg = {"confmaps": {**base_confmap, "anchor_part": self._anchor_part}}
549+
head_configs["centered_instance"] = cfg
540550

541551
elif self._pipeline == "bottomup":
542-
return {
543-
"bottomup": {
544-
"confmaps": {
545-
**base_confmap,
546-
"loss_weight": 1.0,
547-
},
548-
"pafs": {
549-
"sigma": 15.0, # PAFs use larger sigma
550-
"output_stride": max(self._output_stride, 4),
551-
"loss_weight": 1.0,
552-
},
553-
}
552+
head_configs["bottomup"] = {
553+
"confmaps": {
554+
**base_confmap,
555+
"loss_weight": 1.0,
556+
},
557+
"pafs": {
558+
"sigma": 15.0, # PAFs use larger sigma
559+
"output_stride": max(self._output_stride, 4),
560+
"loss_weight": 1.0,
561+
},
554562
}
555563

556564
elif self._pipeline == "multi_class_bottomup":
557-
return {
558-
"multi_class_bottomup": {
559-
"confmaps": {**base_confmap, "loss_weight": 1.0},
560-
"pafs": {
561-
"sigma": 15.0,
562-
"output_stride": max(self._output_stride, 4),
563-
"loss_weight": 1.0,
564-
},
565-
"class_vectors": {
566-
"num_fc_layers": 1,
567-
"num_fc_units": 64,
568-
},
569-
}
565+
head_configs["multi_class_bottomup"] = {
566+
"confmaps": {**base_confmap, "loss_weight": 1.0},
567+
"pafs": {
568+
"sigma": 15.0,
569+
"output_stride": max(self._output_stride, 4),
570+
"loss_weight": 1.0,
571+
},
572+
"class_vectors": {
573+
"num_fc_layers": 1,
574+
"num_fc_units": 64,
575+
},
570576
}
571577

572578
elif self._pipeline == "multi_class_topdown":
573-
cfg = {"confmaps": base_confmap.copy()}
574-
if self._anchor_part:
575-
cfg["anchor_part"] = self._anchor_part
576-
cfg["class_vectors"] = {
577-
"num_fc_layers": 1,
578-
"num_fc_units": 64,
579+
cfg = {
580+
"confmaps": {**base_confmap, "anchor_part": self._anchor_part},
581+
"class_vectors": {
582+
"num_fc_layers": 1,
583+
"num_fc_units": 64,
584+
},
579585
}
580-
return {"multi_class_topdown": cfg}
586+
head_configs["multi_class_topdown"] = cfg
581587

582-
return {}
588+
return head_configs
583589

584590
def _build_trainer_config(self) -> dict:
585591
"""Build trainer configuration section."""

sleap_nn/config_generator/tui/app.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ def __init__(self, slp_path: str, **kwargs):
217217
218218
Args:
219219
slp_path: Path to the SLP file to analyze.
220+
**kwargs: Additional arguments passed to parent App class.
220221
"""
221222
super().__init__(**kwargs)
222223
self.slp_path = Path(slp_path)
@@ -280,7 +281,7 @@ def _compose_model_tab(self) -> ComposeResult:
280281
with VerticalScroll():
281282
yield Static("Pipeline Type", classes="section-title")
282283
with RadioSet(id="pipeline-select"):
283-
yield RadioButton("Single Instance", id="pipe-single")
284+
yield RadioButton("Single Instance", id="pipe-single", value=True)
284285
yield RadioButton("Top-Down: Centroid", id="pipe-centroid")
285286
yield RadioButton("Top-Down: Centered Instance", id="pipe-centered")
286287
yield RadioButton("Bottom-Up", id="pipe-bottomup")
@@ -733,13 +734,31 @@ def handle_radio_change(self, event: RadioSet.Changed) -> None:
733734
if event.radio_set.id == "backbone-select":
734735
self._update_memory_estimate()
735736

737+
# Update YAML preview when pipeline or backbone changes
738+
if event.radio_set.id in ["pipeline-select", "backbone-select"]:
739+
self._update_yaml_preview()
740+
736741
@on(Input.Changed)
737742
def handle_input_change(self, event: Input.Changed) -> None:
738743
"""Handle input field changes."""
739744
# Update memory estimate for relevant fields
740745
if event.input.id in ["batch-size-input", "sigma-input"]:
741746
self._update_memory_estimate()
742747

748+
# Update YAML preview for all input changes
749+
self._update_yaml_preview()
750+
751+
@on(Switch.Changed)
752+
def handle_switch_change(self, event: Switch.Changed) -> None:
753+
"""Handle switch changes."""
754+
self._update_yaml_preview()
755+
756+
@on(Select.Changed)
757+
def handle_select_change(self, event: Select.Changed) -> None:
758+
"""Handle select changes."""
759+
self._update_yaml_preview()
760+
self._update_memory_estimate()
761+
743762
def action_quit(self) -> None:
744763
"""Quit the application."""
745764
self.exit()

0 commit comments

Comments
 (0)