Skip to content

Commit 51b1038

Browse files
committed
add testing for main_funcs
Signed-off-by: Zen <[email protected]>
1 parent 8ba6552 commit 51b1038

File tree

2 files changed

+139
-43
lines changed

2 files changed

+139
-43
lines changed

src/zenlib/util/main_funcs.py

Lines changed: 58 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,47 @@
22
Functions to help with the main()
33
"""
44

5-
__version__ = '1.1.1'
6-
__author__ = 'desultory'
5+
__version__ = "1.2.0"
6+
__author__ = "desultory"
77

88

9-
BASE_ARGS = [{'flags': ['-d', '--debug'], 'action': 'store_true', 'help': 'enable debug mode (level 10)'},
10-
{'flags': ['-dd', '--trace'], 'action': 'store_true', 'help': 'enable trace debug mode (level 5)'},
11-
{'flags': ['-v', '--version'], 'action': 'store_true', 'help': 'print the version and exit'},
12-
{'flags': ['--log-file'], 'type': str, 'help': 'set the path to the log file'},
13-
{'flags': ['--log-level'], 'type': str, 'help': 'set the log level'},
14-
{'flags': ['--log-time'], 'action': 'store_true', 'help': 'enable log timestamps'},
15-
{'flags': ['--no-log-color'], 'action': 'store_true', 'help': 'disable log color'}]
9+
def get_base_args():
10+
return [
11+
{"flags": ["-d", "--debug"], "action": "store_true", "help": "enable debug mode (level 10)"},
12+
{"flags": ["-dd", "--trace"], "action": "store_true", "help": "enable trace debug mode (level 5)"},
13+
{"flags": ["-v", "--version"], "action": "store_true", "help": "print the version and exit"},
14+
{"flags": ["--log-file"], "type": str, "help": "set the path to the log file"},
15+
{"flags": ["--log-level"], "type": str, "help": "set the log level"},
16+
{"flags": ["--log-time"], "action": "store_true", "help": "enable log timestamps"},
17+
{"flags": ["--no-log-color"], "action": "store_true", "help": "disable log color"},
18+
]
19+
20+
21+
def init_logger(name=None):
22+
"""Initialize the logger with a name"""
23+
from logging import getLogger
24+
25+
name = name or __name__
26+
return getLogger(name)
27+
28+
29+
def init_argparser(prog=None, description=None):
30+
"""Initialize an argparser with common options."""
31+
from argparse import ArgumentParser
32+
33+
argparser = ArgumentParser(prog=prog, description=description)
34+
return argparser
1635

1736

1837
def get_kwargs_from_args(args, logger=None, base_kwargs={}, drop_base=True):
19-
""" Get kwargs from argparser args.
20-
Drop base doesn't add init_argparser args. """
38+
"""Get kwargs from argparser args.
39+
Drop base doesn't add init_argparser args."""
2140
kwargs = base_kwargs.copy()
2241
if logger is not None:
23-
kwargs['logger'] = logger
42+
kwargs["logger"] = logger
2443

2544
for arg in vars(args):
26-
if drop_base and arg in ['debug', 'trace', 'version', 'log_file', 'log_level', 'log_time', 'no_log_color']:
45+
if drop_base and arg in ["debug", "trace", "version", "log_file", "log_level", "log_time", "no_log_color"]:
2746
continue
2847
value = getattr(args, arg)
2948

@@ -34,28 +53,17 @@ def get_kwargs_from_args(args, logger=None, base_kwargs={}, drop_base=True):
3453
return kwargs
3554

3655

37-
def init_logger(name=None):
38-
""" Initialize the logger with a name"""
39-
from logging import getLogger
40-
name = name or __name__
41-
return getLogger(name)
42-
43-
44-
def init_argparser(prog=None, description=None):
45-
""" Initialize an argparser with common options. """
46-
from argparse import ArgumentParser
47-
argparser = ArgumentParser(prog=prog, description=description)
48-
return argparser
49-
50-
5156
def process_args(argparser, logger=None):
52-
""" Process argparser args, optionally configuring a logger. """
57+
"""Process argparser args, optionally configuring a logger."""
5358
from logging import Formatter
59+
5460
from zenlib.logging import ColorLognameFormatter
61+
5562
args = argparser.parse_args()
5663
if args.version:
5764
package = argparser.prog
5865
from importlib.metadata import version
66+
5967
print(f"{package} {version(package)}")
6068
exit(0)
6169

@@ -70,54 +78,61 @@ def process_args(argparser, logger=None):
7078
log_level = 20
7179
logger.setLevel(log_level)
7280

73-
format_str = '%(asctime)s | ' if args.log_time else ''
81+
format_str = "%(asctime)s | " if args.log_time else ""
7482
if log_level < 20:
75-
format_str += '%(levelname)s | %(name)-42s | %(message)s'
83+
format_str += "%(levelname)s | %(name)-42s | %(message)s"
7684
else:
77-
format_str += '%(levelname)s | %(message)s'
85+
format_str += "%(levelname)s | %(message)s"
7886
formatter = ColorLognameFormatter(format_str) if not args.no_log_color else Formatter(format_str)
7987

8088
# Add the formatter to the first handler, or add a new handler
8189
for handler in logger.handlers:
8290
handler.setFormatter(formatter)
8391
break
8492
else:
85-
from logging import StreamHandler, FileHandler
93+
from logging import FileHandler, StreamHandler
94+
8695
handler = StreamHandler() if args.log_file is None else FileHandler(args.log_file)
8796
handler.setFormatter(formatter)
8897
logger.addHandler(handler)
8998

9099
return args
91100

92101

93-
def dump_args_for_autocomplete(args):
94-
""" Dump args for autocomplete """
102+
def dump_args_for_autocomplete(args, test=False):
103+
"""Dump args for autocomplete"""
104+
out_str = ""
95105
for arg in args:
96-
if arg.get('action') not in ['store_true', 'store_false']:
106+
if arg.get("action") not in ["store_true", "store_false"]:
97107
continue
98-
for flag in arg['flags']:
99-
print(f"{flag} {arg.get('help')}")
108+
for flag in arg["flags"]:
109+
out_str += f"{flag} {arg.get('help')}\n"
110+
if test:
111+
return out_str
112+
print(out_str)
100113
exit(0)
101114

102115

103116
def get_args_n_logger(package, description: str, arguments=[], drop_default=False):
104-
""" Takes a package name and description
117+
"""Takes a package name and description
105118
If arguments are passed, they are added to argparser.
106119
Returns the parsed args and logger.
107120
"""
108-
arguments = BASE_ARGS + arguments
121+
arguments = get_base_args() + arguments
109122
from sys import argv
110-
if '--dump_args' in argv:
123+
124+
if "--dump_args" in argv:
111125
dump_args_for_autocomplete(arguments)
112126

113127
from argparse import Namespace
128+
114129
argparser = init_argparser(prog=package, description=description)
115130
logger = init_logger(package)
116131

117132
for arg in arguments:
118-
dest = arg.pop('flags')
133+
dest = arg.pop("flags")
119134
if drop_default:
120-
arg['default'] = None
135+
arg["default"] = None
121136
argparser.add_argument(*dest, **arg)
122137

123138
args = process_args(argparser, logger=logger)
@@ -129,6 +144,6 @@ def get_args_n_logger(package, description: str, arguments=[], drop_default=Fals
129144

130145

131146
def get_kwargs(package, description: str, arguments=[], base_kwargs={}, drop_default=False, drop_base=True):
132-
""" Like get_args_n_logger, but only returns kwargs """
147+
"""Like get_args_n_logger, but only returns kwargs"""
133148
args, logger = get_args_n_logger(package, description, arguments, drop_default=drop_default)
134149
return get_kwargs_from_args(args, logger=logger, base_kwargs=base_kwargs, drop_base=drop_base)

tests/test_main_funcs.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from argparse import ArgumentParser, Namespace
2+
from logging import Logger
3+
from unittest import TestCase, expectedFailure, main
4+
5+
from zenlib.util import get_args_n_logger, get_kwargs, get_kwargs_from_args, init_argparser, init_logger
6+
from zenlib.util.main_funcs import dump_args_for_autocomplete, get_base_args
7+
8+
DEFAULT_ARGS = ["debug", "trace", "version", "log_time", "no_log_color"]
9+
10+
11+
def get_test_args():
12+
return [{"flags": ["arg1"], "action": "store"}, {"flags": ["arg2"], "action": "store"}]
13+
14+
15+
class TestMainFuncs(TestCase):
16+
def test_init_logger(self):
17+
self.assertIsInstance(init_logger(), Logger)
18+
19+
def test_init_argparser(self):
20+
self.assertIsInstance(init_argparser(), ArgumentParser)
21+
22+
def test_named_init_argparser(self):
23+
parser = init_argparser("test", "test description")
24+
self.assertEqual(parser.prog, "test")
25+
self.assertEqual(parser.description, "test description")
26+
27+
def test_get_args_n_logger(self):
28+
args, logger = get_args_n_logger("test", "test description", get_test_args())
29+
self.assertIsInstance(args, Namespace)
30+
self.assertIsInstance(logger, Logger)
31+
32+
self.assertEqual(args.arg1, "discover")
33+
self.assertEqual(args.arg2, "tests")
34+
35+
def test_get_args_n_logger_no_default(self):
36+
args, logger = get_args_n_logger("test", "test description", get_test_args(), drop_default=True)
37+
self.assertIsInstance(args, Namespace)
38+
self.assertIsInstance(logger, Logger)
39+
40+
for arg in DEFAULT_ARGS:
41+
self.assertFalse(hasattr(args, arg))
42+
43+
def test_get_kwargs_from_args(self):
44+
args, logger = get_args_n_logger("test", "test description", get_test_args())
45+
kwargs = get_kwargs_from_args(args, logger)
46+
self.assertIsInstance(kwargs, dict)
47+
self.assertEqual(kwargs["arg1"], "discover")
48+
self.assertEqual(kwargs["arg2"], "tests")
49+
self.assertEqual(kwargs["logger"], logger)
50+
51+
def test_not_drop_base(self):
52+
args, logger = get_args_n_logger("test", "test description", get_test_args())
53+
kwargs = get_kwargs_from_args(args, logger, drop_base=False)
54+
self.assertIsInstance(kwargs, dict)
55+
self.assertEqual(kwargs["arg1"], "discover")
56+
self.assertEqual(kwargs["arg2"], "tests")
57+
self.assertEqual(kwargs["logger"], logger)
58+
for arg in DEFAULT_ARGS:
59+
self.assertTrue(arg in kwargs)
60+
61+
def test_get_kwargs(self):
62+
kwargs = get_kwargs("test", "test description", get_test_args())
63+
self.assertIsInstance(kwargs, dict)
64+
self.assertEqual(kwargs["arg1"], "discover")
65+
self.assertEqual(kwargs["arg2"], "tests")
66+
self.assertTrue("logger" in kwargs)
67+
68+
@expectedFailure # This exits so should fail
69+
def test_dump_args_for_autocomplete(self):
70+
dump_args_for_autocomplete(get_test_args())
71+
72+
def test_dump_args_for_autocomplete_no_exit(self):
73+
self.assertEqual(dump_args_for_autocomplete(get_test_args(), test=True), "")
74+
self.assertEqual(
75+
dump_args_for_autocomplete(get_base_args(), test=True),
76+
"-d enable debug mode (level 10)\n--debug enable debug mode (level 10)\n-dd enable trace debug mode (level 5)\n--trace enable trace debug mode (level 5)\n-v print the version and exit\n--version print the version and exit\n--log-time enable log timestamps\n--no-log-color disable log color\n",
77+
)
78+
79+
80+
if __name__ == "__main__":
81+
main()

0 commit comments

Comments
 (0)