Skip to content

Commit 9f1e224

Browse files
committed
add strict toggle for arg parsing, fix argparsing tests
tests were broken if -v was in the middle of the test command like: python -m unittest discover -v tests Argparse doesn't like optional args being between positional ones. parse_known_args must be used to function under these circumstances Signed-off-by: Zen <[email protected]>
1 parent 0dd8151 commit 9f1e224

File tree

2 files changed

+54
-29
lines changed

2 files changed

+54
-29
lines changed

src/zenlib/util/main_funcs.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Functions to help with the main()
33
"""
44

5-
__version__ = "1.2.1"
5+
__version__ = "1.3.0"
66
__author__ = "desultory"
77

88

@@ -53,14 +53,19 @@ def get_kwargs_from_args(args, logger=None, base_kwargs={}, drop_base=True):
5353
return kwargs
5454

5555

56-
def process_args(argparser, logger=None):
56+
def process_args(argparser, logger=None, strict=False):
5757
"""Process argparser args, optionally configuring a logger."""
5858
from logging import Formatter
5959

6060
from zenlib.logging import ColorLognameFormatter
6161

62-
args = argparser.parse_args()
63-
if args.version and argparser.prog != "zenlib_test":
62+
if strict:
63+
args = argparser.parse_args()
64+
else:
65+
args, unknown = argparser.parse_known_args()
66+
args._unknown = unknown
67+
68+
if 'version' in args and args.version and argparser.prog != "zenlib_test":
6469
package = argparser.prog
6570
from importlib.metadata import version
6671

@@ -96,6 +101,9 @@ def process_args(argparser, logger=None):
96101
handler.setFormatter(formatter)
97102
logger.addHandler(handler)
98103

104+
if '__unknown' in args and args.__unknown:
105+
logger.warning(f"Unknown args: {args.__unknown}")
106+
99107
return args
100108

101109

@@ -113,37 +121,42 @@ def dump_args_for_autocomplete(args, test=False):
113121
exit(0)
114122

115123

116-
def get_args_n_logger(package, description: str, arguments=[], drop_default=False):
124+
def get_args_n_logger(package, description: str, arguments=[], drop_default=False, strict=False):
117125
"""Takes a package name and description
118126
If arguments are passed, they are added to argparser.
119127
Returns the parsed args and logger.
120128
"""
121-
arguments = get_base_args() + arguments
129+
all_arguments = get_base_args() + arguments
122130
from sys import argv
123-
124131
if "--dump_args" in argv:
125-
dump_args_for_autocomplete(arguments)
126-
127-
from argparse import Namespace
132+
dump_args_for_autocomplete(all_arguments)
128133

134+
from argparse import Namespace, ArgumentError
129135
argparser = init_argparser(prog=package, description=description)
130136
logger = init_logger(package)
131137

132-
for arg in arguments:
133-
dest = arg.pop("flags")
134-
if drop_default:
135-
arg["default"] = None
136-
argparser.add_argument(*dest, **arg)
138+
def add_args(args, argparser):
139+
for arg in args:
140+
dest = arg.pop("flags")
141+
if drop_default:
142+
arg["default"] = None
143+
try:
144+
argparser.add_argument(*dest, **arg)
145+
except ArgumentError:
146+
pass
147+
148+
add_args(arguments, argparser) # Add custom args first, then base args
149+
add_args(get_base_args(), argparser)
137150

138-
args = process_args(argparser, logger=logger)
151+
args = process_args(argparser, logger=logger, strict=strict)
139152

140-
if drop_default:
153+
if drop_default: # Remove defaults from args
141154
args = Namespace(**{name: value for name, value in vars(args).items() if value != argparser.get_default(name)})
142155

143156
return args, logger
144157

145158

146-
def get_kwargs(package, description: str, arguments=[], base_kwargs={}, drop_default=False, drop_base=True):
159+
def get_kwargs(package, description: str, arguments=[], base_kwargs={}, drop_default=False, drop_base=True, strict=False):
147160
"""Like get_args_n_logger, but only returns kwargs"""
148-
args, logger = get_args_n_logger(package, description, arguments, drop_default=drop_default)
161+
args, logger = get_args_n_logger(package, description, arguments, drop_default=drop_default, strict=strict)
149162
return get_kwargs_from_args(args, logger=logger, base_kwargs=base_kwargs, drop_base=drop_base)

tests/test_main_funcs.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99

1010

1111
def get_test_args():
12-
return [{"flags": ["arg1"], "action": "store", "nargs": "?"}, {"flags": ["arg2"], "action": "store", "nargs": "?"}]
12+
return [
13+
{"flags": ["args"], "action": "store", "nargs": "*"},
14+
]
1315

1416

1517
class TestMainFuncs(TestCase):
@@ -26,32 +28,42 @@ def test_named_init_argparser(self):
2628

2729
def _check_for_test_args(self, args):
2830
self.assertIsInstance(args, Namespace)
29-
if arg1 := getattr(args, "arg1", None):
30-
self.assertEqual(arg1, "discover")
31-
if arg2 := getattr(args, "arg2", None):
32-
self.assertEqual(arg2, "tests")
31+
32+
if 'discover' not in args.args:
33+
self.fail("discover not found in args")
34+
35+
search_strs = ['tests', './tests']
36+
search_locs = [args.args, args._unknown]
37+
38+
has_tests = False
39+
for search_str in search_strs:
40+
for search_loc in search_locs:
41+
if search_str in search_loc:
42+
has_tests = True
43+
if not has_tests:
44+
self.fail("tests not found in args")
3345

3446
def test_get_args_n_logger(self):
35-
args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args())
47+
args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args(), strict=False)
3648
self.assertIsInstance(logger, Logger)
3749
self._check_for_test_args(args)
3850

3951
def test_get_args_n_logger_no_default(self):
40-
args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args(), drop_default=True)
52+
args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args(), strict=False, drop_default=True)
4153
self.assertIsInstance(logger, Logger)
4254

4355
for arg in DEFAULT_ARGS:
4456
self.assertFalse(hasattr(args, arg))
4557

4658
def test_get_kwargs_from_args(self):
47-
args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args())
59+
args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args(), strict=False)
4860
self._check_for_test_args(args)
4961
kwargs = get_kwargs_from_args(args, logger)
5062
self.assertIsInstance(kwargs, dict)
5163
self.assertEqual(kwargs["logger"], logger)
5264

5365
def test_not_drop_base(self):
54-
args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args())
66+
args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args(), strict=False)
5567
self._check_for_test_args(args)
5668
kwargs = get_kwargs_from_args(args, logger, drop_base=False)
5769
self.assertIsInstance(kwargs, dict)
@@ -60,7 +72,7 @@ def test_not_drop_base(self):
6072
self.assertTrue(arg in kwargs)
6173

6274
def test_get_kwargs(self):
63-
kwargs = get_kwargs("zenlib_test", "test description", get_test_args())
75+
kwargs = get_kwargs("zenlib_test", "test description", get_test_args(), strict=False)
6476
self.assertIsInstance(kwargs, dict)
6577
self.assertTrue("logger" in kwargs)
6678

0 commit comments

Comments
 (0)