diff --git a/src/zenlib/util/handle_plural.py b/src/zenlib/util/handle_plural.py index 1971ccd..03b76f7 100644 --- a/src/zenlib/util/handle_plural.py +++ b/src/zenlib/util/handle_plural.py @@ -1,5 +1,5 @@ __author__ = "desultory" -__version__ = "2.1.0" +__version__ = "2.2.0" from collections.abc import KeysView, ValuesView @@ -11,7 +11,7 @@ def handle_plural(function, log_level=10): Logs using the logger attribute if it exists. """ - def wrapper(self, *args): + def wrapper(self, *args, **kwargs): def log(msg, level=log_level): if hasattr(self, "logger"): self.logger.log(level, msg) @@ -26,15 +26,15 @@ def log(msg, level=log_level): if isinstance(focus_arg, list) and not isinstance(focus_arg, str): log("Expanding list: %s" % focus_arg) for item in focus_arg: - function(self, *(other_args + (item,))) + function(self, *(other_args + (item,)), **kwargs) elif isinstance(focus_arg, ValuesView): log("Expanding dict values: %s" % focus_arg) for value in focus_arg: - function(self, *(other_args + (value,))) + function(self, *(other_args + (value,)), **kwargs) elif isinstance(focus_arg, KeysView): log("Expanding dict keys: %s" % focus_arg) for key in focus_arg: - function(self, *(other_args + (key,))) + function(self, *(other_args + (key,)), **kwargs) elif isinstance(focus_arg, dict): log("Expanding dict: %s" % focus_arg) for key, value in focus_arg.items(): @@ -47,9 +47,10 @@ def log(msg, level=log_level): value, ) ), + **kwargs, ) else: log(f"Arguments were not expanded: {args}", log_level - 5) - return function(self, *args) + return function(self, *args, **kwargs) return wrapper diff --git a/tests/test_handle_plural.py b/tests/test_handle_plural.py index ee33f7b..7dc5b5b 100644 --- a/tests/test_handle_plural.py +++ b/tests/test_handle_plural.py @@ -10,12 +10,36 @@ def _test_plural_ints(self, arg_a, iterated): self._test_data += iterated return iterated + @handle_plural + def _test_plural_ints_with_kwarg(self, arg_a, iterated, test_kwarg='asdf'): + self.assertEqual(test_kwarg, 'asdf') + if isinstance(iterated, int): + self._test_data += iterated + return iterated + + @handle_plural + def _test_plural_setting_kwarg(self, arg_a, iterated, test_kwarg='asdf'): + self.assertEqual(test_kwarg, 'test') + if isinstance(iterated, int): + self._test_data += iterated + return iterated + def test_list(self): self._test_data = 0 test_list = [1, 2, 3, 4] extra_arg = 'a' self.assertEqual(self._test_plural_ints(extra_arg, test_list), None) self.assertEqual(self._test_data, sum(test_list)) + self._test_data = 0 + self.assertEqual(self._test_plural_ints_with_kwarg(extra_arg, test_list), None) + self.assertEqual(self._test_data, sum(test_list)) + + def test_setting_kwarg(self): + self._test_data = 0 + test_list = [1, 2, 3, 4] + extra_arg = 'a' + self.assertEqual(self._test_plural_setting_kwarg(extra_arg, test_list, test_kwarg='test'), None) + self.assertEqual(self._test_data, sum(test_list)) def test_single(self): # non-iterables should allow returns self._test_data = 0