Skip to content

Commit c42b215

Browse files
authored
Merge pull request #4 from desultory/dev
handle kwargs with handle_plural
2 parents 5cdc8e9 + 49bf7fa commit c42b215

File tree

2 files changed

+31
-6
lines changed

2 files changed

+31
-6
lines changed

src/zenlib/util/handle_plural.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
__author__ = "desultory"
2-
__version__ = "2.1.0"
2+
__version__ = "2.2.0"
33

44
from collections.abc import KeysView, ValuesView
55

@@ -11,7 +11,7 @@ def handle_plural(function, log_level=10):
1111
Logs using the logger attribute if it exists.
1212
"""
1313

14-
def wrapper(self, *args):
14+
def wrapper(self, *args, **kwargs):
1515
def log(msg, level=log_level):
1616
if hasattr(self, "logger"):
1717
self.logger.log(level, msg)
@@ -26,15 +26,15 @@ def log(msg, level=log_level):
2626
if isinstance(focus_arg, list) and not isinstance(focus_arg, str):
2727
log("Expanding list: %s" % focus_arg)
2828
for item in focus_arg:
29-
function(self, *(other_args + (item,)))
29+
function(self, *(other_args + (item,)), **kwargs)
3030
elif isinstance(focus_arg, ValuesView):
3131
log("Expanding dict values: %s" % focus_arg)
3232
for value in focus_arg:
33-
function(self, *(other_args + (value,)))
33+
function(self, *(other_args + (value,)), **kwargs)
3434
elif isinstance(focus_arg, KeysView):
3535
log("Expanding dict keys: %s" % focus_arg)
3636
for key in focus_arg:
37-
function(self, *(other_args + (key,)))
37+
function(self, *(other_args + (key,)), **kwargs)
3838
elif isinstance(focus_arg, dict):
3939
log("Expanding dict: %s" % focus_arg)
4040
for key, value in focus_arg.items():
@@ -47,9 +47,10 @@ def log(msg, level=log_level):
4747
value,
4848
)
4949
),
50+
**kwargs,
5051
)
5152
else:
5253
log(f"Arguments were not expanded: {args}", log_level - 5)
53-
return function(self, *args)
54+
return function(self, *args, **kwargs)
5455

5556
return wrapper

tests/test_handle_plural.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,36 @@ def _test_plural_ints(self, arg_a, iterated):
1010
self._test_data += iterated
1111
return iterated
1212

13+
@handle_plural
14+
def _test_plural_ints_with_kwarg(self, arg_a, iterated, test_kwarg='asdf'):
15+
self.assertEqual(test_kwarg, 'asdf')
16+
if isinstance(iterated, int):
17+
self._test_data += iterated
18+
return iterated
19+
20+
@handle_plural
21+
def _test_plural_setting_kwarg(self, arg_a, iterated, test_kwarg='asdf'):
22+
self.assertEqual(test_kwarg, 'test')
23+
if isinstance(iterated, int):
24+
self._test_data += iterated
25+
return iterated
26+
1327
def test_list(self):
1428
self._test_data = 0
1529
test_list = [1, 2, 3, 4]
1630
extra_arg = 'a'
1731
self.assertEqual(self._test_plural_ints(extra_arg, test_list), None)
1832
self.assertEqual(self._test_data, sum(test_list))
33+
self._test_data = 0
34+
self.assertEqual(self._test_plural_ints_with_kwarg(extra_arg, test_list), None)
35+
self.assertEqual(self._test_data, sum(test_list))
36+
37+
def test_setting_kwarg(self):
38+
self._test_data = 0
39+
test_list = [1, 2, 3, 4]
40+
extra_arg = 'a'
41+
self.assertEqual(self._test_plural_setting_kwarg(extra_arg, test_list, test_kwarg='test'), None)
42+
self.assertEqual(self._test_data, sum(test_list))
1943

2044
def test_single(self): # non-iterables should allow returns
2145
self._test_data = 0

0 commit comments

Comments
 (0)