diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 514be0a..a6645a6 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -31,5 +31,4 @@ jobs: venv/bin/pip install . - name: Run unit tests run: | - cd tests - ../venv/bin/python -m unittest -v + venv/bin/python -m unittest discover -v tests diff --git a/pyproject.toml b/pyproject.toml index 6ed562c..accb7f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "zenlib" -version = "2.3.0" +version = "2.3.1" authors = [ { name="Desultory", email="dev@pyl.onl" }, ] diff --git a/src/zenlib/util/main_funcs.py b/src/zenlib/util/main_funcs.py index 309106d..b95860e 100644 --- a/src/zenlib/util/main_funcs.py +++ b/src/zenlib/util/main_funcs.py @@ -2,7 +2,7 @@ Functions to help with the main() """ -__version__ = "1.2.1" +__version__ = "1.3.0" __author__ = "desultory" @@ -53,14 +53,19 @@ def get_kwargs_from_args(args, logger=None, base_kwargs={}, drop_base=True): return kwargs -def process_args(argparser, logger=None): +def process_args(argparser, logger=None, strict=False): """Process argparser args, optionally configuring a logger.""" from logging import Formatter from zenlib.logging import ColorLognameFormatter - args = argparser.parse_args() - if args.version and argparser.prog != "zenlib_test": + if strict: + args = argparser.parse_args() + else: + args, unknown = argparser.parse_known_args() + args._unknown = unknown + + if 'version' in args and args.version and argparser.prog != "zenlib_test": package = argparser.prog from importlib.metadata import version @@ -96,6 +101,9 @@ def process_args(argparser, logger=None): handler.setFormatter(formatter) logger.addHandler(handler) + if '__unknown' in args and args.__unknown: + logger.warning(f"Unknown args: {args.__unknown}") + return args @@ -113,37 +121,42 @@ def dump_args_for_autocomplete(args, test=False): exit(0) -def get_args_n_logger(package, description: str, arguments=[], drop_default=False): +def get_args_n_logger(package, description: str, arguments=[], drop_default=False, strict=False): """Takes a package name and description If arguments are passed, they are added to argparser. Returns the parsed args and logger. """ - arguments = get_base_args() + arguments + all_arguments = get_base_args() + arguments from sys import argv - if "--dump_args" in argv: - dump_args_for_autocomplete(arguments) - - from argparse import Namespace + dump_args_for_autocomplete(all_arguments) + from argparse import Namespace, ArgumentError argparser = init_argparser(prog=package, description=description) logger = init_logger(package) - for arg in arguments: - dest = arg.pop("flags") - if drop_default: - arg["default"] = None - argparser.add_argument(*dest, **arg) + def add_args(args, argparser): + for arg in args: + dest = arg.pop("flags") + if drop_default: + arg["default"] = None + try: + argparser.add_argument(*dest, **arg) + except ArgumentError: + pass + + add_args(arguments, argparser) # Add custom args first, then base args + add_args(get_base_args(), argparser) - args = process_args(argparser, logger=logger) + args = process_args(argparser, logger=logger, strict=strict) - if drop_default: + if drop_default: # Remove defaults from args args = Namespace(**{name: value for name, value in vars(args).items() if value != argparser.get_default(name)}) return args, logger -def get_kwargs(package, description: str, arguments=[], base_kwargs={}, drop_default=False, drop_base=True): +def get_kwargs(package, description: str, arguments=[], base_kwargs={}, drop_default=False, drop_base=True, strict=False): """Like get_args_n_logger, but only returns kwargs""" - args, logger = get_args_n_logger(package, description, arguments, drop_default=drop_default) + args, logger = get_args_n_logger(package, description, arguments, drop_default=drop_default, strict=strict) return get_kwargs_from_args(args, logger=logger, base_kwargs=base_kwargs, drop_base=drop_base) diff --git a/tests/test_main_funcs.py b/tests/test_main_funcs.py index e1b6424..2785cb9 100644 --- a/tests/test_main_funcs.py +++ b/tests/test_main_funcs.py @@ -9,7 +9,9 @@ def get_test_args(): - return [{"flags": ["arg1"], "action": "store", "nargs": "?"}, {"flags": ["arg2"], "action": "store", "nargs": "?"}] + return [ + {"flags": ["args"], "action": "store", "nargs": "*"}, + ] class TestMainFuncs(TestCase): @@ -26,32 +28,42 @@ def test_named_init_argparser(self): def _check_for_test_args(self, args): self.assertIsInstance(args, Namespace) - if arg1 := getattr(args, "arg1", None): - self.assertEqual(arg1, "discover") - if arg2 := getattr(args, "arg2", None): - self.assertEqual(arg2, "tests") + + if 'discover' not in args.args: + self.fail("discover not found in args") + + search_strs = ['tests', './tests'] + search_locs = [args.args, args._unknown] + + has_tests = False + for search_str in search_strs: + for search_loc in search_locs: + if search_str in search_loc: + has_tests = True + if not has_tests: + self.fail("tests not found in args") def test_get_args_n_logger(self): - args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args()) + args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args(), strict=False) self.assertIsInstance(logger, Logger) self._check_for_test_args(args) def test_get_args_n_logger_no_default(self): - args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args(), drop_default=True) + args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args(), strict=False, drop_default=True) self.assertIsInstance(logger, Logger) for arg in DEFAULT_ARGS: self.assertFalse(hasattr(args, arg)) def test_get_kwargs_from_args(self): - args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args()) + args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args(), strict=False) self._check_for_test_args(args) kwargs = get_kwargs_from_args(args, logger) self.assertIsInstance(kwargs, dict) self.assertEqual(kwargs["logger"], logger) def test_not_drop_base(self): - args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args()) + args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args(), strict=False) self._check_for_test_args(args) kwargs = get_kwargs_from_args(args, logger, drop_base=False) self.assertIsInstance(kwargs, dict) @@ -60,7 +72,7 @@ def test_not_drop_base(self): self.assertTrue(arg in kwargs) def test_get_kwargs(self): - kwargs = get_kwargs("zenlib_test", "test description", get_test_args()) + kwargs = get_kwargs("zenlib_test", "test description", get_test_args(), strict=False) self.assertIsInstance(kwargs, dict) self.assertTrue("logger" in kwargs)