Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,4 @@ jobs:
venv/bin/pip install .
- name: Run unit tests
run: |
cd tests
../venv/bin/python -m unittest -v
venv/bin/python -m unittest discover -v tests
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "zenlib"
version = "2.3.0"
version = "2.3.1"
authors = [
{ name="Desultory", email="[email protected]" },
]
Expand Down
51 changes: 32 additions & 19 deletions src/zenlib/util/main_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Functions to help with the main()
"""

__version__ = "1.2.1"
__version__ = "1.3.0"
__author__ = "desultory"


Expand Down Expand Up @@ -53,14 +53,19 @@ def get_kwargs_from_args(args, logger=None, base_kwargs={}, drop_base=True):
return kwargs


def process_args(argparser, logger=None):
def process_args(argparser, logger=None, strict=False):
"""Process argparser args, optionally configuring a logger."""
from logging import Formatter

from zenlib.logging import ColorLognameFormatter

args = argparser.parse_args()
if args.version and argparser.prog != "zenlib_test":
if strict:
args = argparser.parse_args()
else:
args, unknown = argparser.parse_known_args()
args._unknown = unknown

if 'version' in args and args.version and argparser.prog != "zenlib_test":
package = argparser.prog
from importlib.metadata import version

Expand Down Expand Up @@ -96,6 +101,9 @@ def process_args(argparser, logger=None):
handler.setFormatter(formatter)
logger.addHandler(handler)

if '__unknown' in args and args.__unknown:
logger.warning(f"Unknown args: {args.__unknown}")

return args


Expand All @@ -113,37 +121,42 @@ def dump_args_for_autocomplete(args, test=False):
exit(0)


def get_args_n_logger(package, description: str, arguments=[], drop_default=False):
def get_args_n_logger(package, description: str, arguments=[], drop_default=False, strict=False):
"""Takes a package name and description
If arguments are passed, they are added to argparser.
Returns the parsed args and logger.
"""
arguments = get_base_args() + arguments
all_arguments = get_base_args() + arguments
from sys import argv

if "--dump_args" in argv:
dump_args_for_autocomplete(arguments)

from argparse import Namespace
dump_args_for_autocomplete(all_arguments)

from argparse import Namespace, ArgumentError
argparser = init_argparser(prog=package, description=description)
logger = init_logger(package)

for arg in arguments:
dest = arg.pop("flags")
if drop_default:
arg["default"] = None
argparser.add_argument(*dest, **arg)
def add_args(args, argparser):
for arg in args:
dest = arg.pop("flags")
if drop_default:
arg["default"] = None
try:
argparser.add_argument(*dest, **arg)
except ArgumentError:
pass

add_args(arguments, argparser) # Add custom args first, then base args
add_args(get_base_args(), argparser)

args = process_args(argparser, logger=logger)
args = process_args(argparser, logger=logger, strict=strict)

if drop_default:
if drop_default: # Remove defaults from args
args = Namespace(**{name: value for name, value in vars(args).items() if value != argparser.get_default(name)})

return args, logger


def get_kwargs(package, description: str, arguments=[], base_kwargs={}, drop_default=False, drop_base=True):
def get_kwargs(package, description: str, arguments=[], base_kwargs={}, drop_default=False, drop_base=True, strict=False):
"""Like get_args_n_logger, but only returns kwargs"""
args, logger = get_args_n_logger(package, description, arguments, drop_default=drop_default)
args, logger = get_args_n_logger(package, description, arguments, drop_default=drop_default, strict=strict)
return get_kwargs_from_args(args, logger=logger, base_kwargs=base_kwargs, drop_base=drop_base)
32 changes: 22 additions & 10 deletions tests/test_main_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@


def get_test_args():
return [{"flags": ["arg1"], "action": "store", "nargs": "?"}, {"flags": ["arg2"], "action": "store", "nargs": "?"}]
return [
{"flags": ["args"], "action": "store", "nargs": "*"},
]


class TestMainFuncs(TestCase):
Expand All @@ -26,32 +28,42 @@ def test_named_init_argparser(self):

def _check_for_test_args(self, args):
self.assertIsInstance(args, Namespace)
if arg1 := getattr(args, "arg1", None):
self.assertEqual(arg1, "discover")
if arg2 := getattr(args, "arg2", None):
self.assertEqual(arg2, "tests")

if 'discover' not in args.args:
self.fail("discover not found in args")

search_strs = ['tests', './tests']
search_locs = [args.args, args._unknown]

has_tests = False
for search_str in search_strs:
for search_loc in search_locs:
if search_str in search_loc:
has_tests = True
if not has_tests:
self.fail("tests not found in args")

def test_get_args_n_logger(self):
args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args())
args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args(), strict=False)
self.assertIsInstance(logger, Logger)
self._check_for_test_args(args)

def test_get_args_n_logger_no_default(self):
args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args(), drop_default=True)
args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args(), strict=False, drop_default=True)
self.assertIsInstance(logger, Logger)

for arg in DEFAULT_ARGS:
self.assertFalse(hasattr(args, arg))

def test_get_kwargs_from_args(self):
args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args())
args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args(), strict=False)
self._check_for_test_args(args)
kwargs = get_kwargs_from_args(args, logger)
self.assertIsInstance(kwargs, dict)
self.assertEqual(kwargs["logger"], logger)

def test_not_drop_base(self):
args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args())
args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args(), strict=False)
self._check_for_test_args(args)
kwargs = get_kwargs_from_args(args, logger, drop_base=False)
self.assertIsInstance(kwargs, dict)
Expand All @@ -60,7 +72,7 @@ def test_not_drop_base(self):
self.assertTrue(arg in kwargs)

def test_get_kwargs(self):
kwargs = get_kwargs("zenlib_test", "test description", get_test_args())
kwargs = get_kwargs("zenlib_test", "test description", get_test_args(), strict=False)
self.assertIsInstance(kwargs, dict)
self.assertTrue("logger" in kwargs)

Expand Down
Loading