Skip to content

Commit fef5b64

Browse files
authored
More robust help injection in CLI (#303)
Currently `sleap-nn-train` does not take commandline overrides because of a bug with how hydra handles overrides, this PR solves it.
1 parent 66f2bd9 commit fef5b64

File tree

2 files changed

+139
-1
lines changed

2 files changed

+139
-1
lines changed

sleap_nn/train.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,27 @@ def outer(fn):
493493
def wrapper(*args, **kwargs):
494494
joined = " ".join(sys.argv[1:])
495495
if "hydra/help=" not in joined and "hydra.help." not in joined:
496-
sys.argv.insert(1, f"hydra/help={name}")
496+
if "--" in sys.argv:
497+
dashdash = sys.argv.index("--")
498+
else:
499+
dashdash = len(sys.argv)
500+
501+
beginning_of_overrides_section = -1
502+
for idx, arg in enumerate(sys.argv[1:dashdash], start=1):
503+
if (
504+
not arg.startswith("-")
505+
and idx < dashdash
506+
and ("=" in arg or arg.startswith("+"))
507+
):
508+
beginning_of_overrides_section = idx
509+
break
510+
if beginning_of_overrides_section != -1:
511+
sys.argv.insert(
512+
beginning_of_overrides_section, f"hydra/help={name}"
513+
)
514+
else:
515+
sys.argv.insert(dashdash, f"hydra/help={name}")
516+
# breakpoint()
497517
return fn(*args, **kwargs)
498518

499519
return wrapper

tests/test_train.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import subprocess
12
from omegaconf import DictConfig, OmegaConf
23
from pathlib import Path
34
import copy
@@ -638,3 +639,120 @@ def test_main(sample_cfg):
638639
invalid_cfg.model_config = None
639640
with pytest.raises(SystemExit):
640641
main(invalid_cfg)
642+
643+
644+
def test_main_cli(sample_cfg, tmp_path):
645+
# Test that train cli handles empty argument gracefully
646+
cmd = [
647+
"uv",
648+
"run",
649+
"sleap-nn-train",
650+
]
651+
result = subprocess.run(
652+
cmd,
653+
capture_output=True,
654+
text=True,
655+
)
656+
# Exit code should be 2
657+
assert result.returncode == 2
658+
assert "No model config found" in result.stdout # Should tell user what is wrong
659+
assert "--help" in result.stdout # should suggest using --help
660+
661+
cmd = [
662+
"uv",
663+
"run",
664+
"sleap-nn-train",
665+
"--help",
666+
]
667+
result = subprocess.run(
668+
cmd,
669+
capture_output=True,
670+
text=True,
671+
)
672+
# Exit code should be 0
673+
assert result.returncode == 0
674+
assert "Usage" in result.stdout # Should show usage information
675+
assert "sleap.ai" in result.stdout # should point user to read the documents
676+
677+
# Now to test overrides and defaults
678+
679+
sample_cfg.trainer_config.trainer_accelerator = (
680+
"cpu" if torch.mps.is_available() else "auto"
681+
)
682+
OmegaConf.save(sample_cfg, (Path(tmp_path) / "test_config.yaml").as_posix())
683+
684+
cmd = [
685+
"uv",
686+
"run",
687+
"sleap-nn-train",
688+
"--config-dir",
689+
f"{tmp_path}",
690+
"--config-name",
691+
"test_config",
692+
]
693+
result = subprocess.run(
694+
cmd,
695+
capture_output=True,
696+
text=True,
697+
)
698+
# Exit code should be 0
699+
assert result.returncode == 0
700+
# Try to parse the output back into the yaml, truncate the beginning (starts with "data_config")
701+
# Only keep stdout starting from "data_config"
702+
stripped_out = result.stdout[result.stdout.find("data_config") :].strip()
703+
stripped_out = stripped_out[: stripped_out.find(" | INFO") - 19]
704+
output = OmegaConf.create(stripped_out)
705+
assert output == sample_cfg
706+
707+
# config override should work
708+
sample_cfg.trainer_config.max_epochs = 2
709+
sample_cfg.data_config.preprocessing.scale = 1.2
710+
cmd = [
711+
"uv",
712+
"run",
713+
"sleap-nn-train",
714+
"--config-dir",
715+
f"{tmp_path}",
716+
"--config-name",
717+
"test_config",
718+
"trainer_config.max_epochs=2",
719+
"data_config.preprocessing.scale=1.2",
720+
]
721+
result = subprocess.run(
722+
cmd,
723+
capture_output=True,
724+
text=True,
725+
)
726+
# Exit code should be 0
727+
assert result.returncode == 0
728+
stripped_out = result.stdout[result.stdout.find("data_config") :].strip()
729+
stripped_out = stripped_out[: stripped_out.find(" | INFO") - 19]
730+
output = OmegaConf.create(stripped_out)
731+
assert output == sample_cfg
732+
733+
# Test CLI with '--' to separate config overrides from positional args
734+
cmd = [
735+
"uv",
736+
"run",
737+
"sleap-nn-train",
738+
"--config-dir",
739+
f"{tmp_path}",
740+
"--config-name",
741+
"test_config",
742+
"--",
743+
"trainer_config.max_epochs=3",
744+
"data_config.preprocessing.scale=1.5",
745+
]
746+
result = subprocess.run(
747+
cmd,
748+
capture_output=True,
749+
text=True,
750+
)
751+
# Exit code should be 0
752+
assert result.returncode == 0
753+
# Check that overrides are applied
754+
stripped_out = result.stdout[result.stdout.find("data_config") :].strip()
755+
stripped_out = stripped_out[: stripped_out.find(" | INFO") - 19]
756+
output = OmegaConf.create(stripped_out)
757+
assert output.trainer_config.max_epochs == 3
758+
assert output.data_config.preprocessing.scale == 1.5

0 commit comments

Comments
 (0)