|
| 1 | +import subprocess |
1 | 2 | from omegaconf import DictConfig, OmegaConf |
2 | 3 | from pathlib import Path |
3 | 4 | import copy |
@@ -638,3 +639,120 @@ def test_main(sample_cfg): |
638 | 639 | invalid_cfg.model_config = None |
639 | 640 | with pytest.raises(SystemExit): |
640 | 641 | main(invalid_cfg) |
| 642 | + |
| 643 | + |
| 644 | +def test_main_cli(sample_cfg, tmp_path): |
| 645 | + # Test that train cli handles empty argument gracefully |
| 646 | + cmd = [ |
| 647 | + "uv", |
| 648 | + "run", |
| 649 | + "sleap-nn-train", |
| 650 | + ] |
| 651 | + result = subprocess.run( |
| 652 | + cmd, |
| 653 | + capture_output=True, |
| 654 | + text=True, |
| 655 | + ) |
| 656 | + # Exit code should be 2 |
| 657 | + assert result.returncode == 2 |
| 658 | + assert "No model config found" in result.stdout # Should tell user what is wrong |
| 659 | + assert "--help" in result.stdout # should suggest using --help |
| 660 | + |
| 661 | + cmd = [ |
| 662 | + "uv", |
| 663 | + "run", |
| 664 | + "sleap-nn-train", |
| 665 | + "--help", |
| 666 | + ] |
| 667 | + result = subprocess.run( |
| 668 | + cmd, |
| 669 | + capture_output=True, |
| 670 | + text=True, |
| 671 | + ) |
| 672 | + # Exit code should be 0 |
| 673 | + assert result.returncode == 0 |
| 674 | + assert "Usage" in result.stdout # Should show usage information |
| 675 | + assert "sleap.ai" in result.stdout # should point user to read the documents |
| 676 | + |
| 677 | + # Now to test overrides and defaults |
| 678 | + |
| 679 | + sample_cfg.trainer_config.trainer_accelerator = ( |
| 680 | + "cpu" if torch.mps.is_available() else "auto" |
| 681 | + ) |
| 682 | + OmegaConf.save(sample_cfg, (Path(tmp_path) / "test_config.yaml").as_posix()) |
| 683 | + |
| 684 | + cmd = [ |
| 685 | + "uv", |
| 686 | + "run", |
| 687 | + "sleap-nn-train", |
| 688 | + "--config-dir", |
| 689 | + f"{tmp_path}", |
| 690 | + "--config-name", |
| 691 | + "test_config", |
| 692 | + ] |
| 693 | + result = subprocess.run( |
| 694 | + cmd, |
| 695 | + capture_output=True, |
| 696 | + text=True, |
| 697 | + ) |
| 698 | + # Exit code should be 0 |
| 699 | + assert result.returncode == 0 |
| 700 | + # Try to parse the output back into the yaml, truncate the beginning (starts with "data_config") |
| 701 | + # Only keep stdout starting from "data_config" |
| 702 | + stripped_out = result.stdout[result.stdout.find("data_config") :].strip() |
| 703 | + stripped_out = stripped_out[: stripped_out.find(" | INFO") - 19] |
| 704 | + output = OmegaConf.create(stripped_out) |
| 705 | + assert output == sample_cfg |
| 706 | + |
| 707 | + # config override should work |
| 708 | + sample_cfg.trainer_config.max_epochs = 2 |
| 709 | + sample_cfg.data_config.preprocessing.scale = 1.2 |
| 710 | + cmd = [ |
| 711 | + "uv", |
| 712 | + "run", |
| 713 | + "sleap-nn-train", |
| 714 | + "--config-dir", |
| 715 | + f"{tmp_path}", |
| 716 | + "--config-name", |
| 717 | + "test_config", |
| 718 | + "trainer_config.max_epochs=2", |
| 719 | + "data_config.preprocessing.scale=1.2", |
| 720 | + ] |
| 721 | + result = subprocess.run( |
| 722 | + cmd, |
| 723 | + capture_output=True, |
| 724 | + text=True, |
| 725 | + ) |
| 726 | + # Exit code should be 0 |
| 727 | + assert result.returncode == 0 |
| 728 | + stripped_out = result.stdout[result.stdout.find("data_config") :].strip() |
| 729 | + stripped_out = stripped_out[: stripped_out.find(" | INFO") - 19] |
| 730 | + output = OmegaConf.create(stripped_out) |
| 731 | + assert output == sample_cfg |
| 732 | + |
| 733 | + # Test CLI with '--' to separate config overrides from positional args |
| 734 | + cmd = [ |
| 735 | + "uv", |
| 736 | + "run", |
| 737 | + "sleap-nn-train", |
| 738 | + "--config-dir", |
| 739 | + f"{tmp_path}", |
| 740 | + "--config-name", |
| 741 | + "test_config", |
| 742 | + "--", |
| 743 | + "trainer_config.max_epochs=3", |
| 744 | + "data_config.preprocessing.scale=1.5", |
| 745 | + ] |
| 746 | + result = subprocess.run( |
| 747 | + cmd, |
| 748 | + capture_output=True, |
| 749 | + text=True, |
| 750 | + ) |
| 751 | + # Exit code should be 0 |
| 752 | + assert result.returncode == 0 |
| 753 | + # Check that overrides are applied |
| 754 | + stripped_out = result.stdout[result.stdout.find("data_config") :].strip() |
| 755 | + stripped_out = stripped_out[: stripped_out.find(" | INFO") - 19] |
| 756 | + output = OmegaConf.create(stripped_out) |
| 757 | + assert output.trainer_config.max_epochs == 3 |
| 758 | + assert output.data_config.preprocessing.scale == 1.5 |
0 commit comments