From f65ca57f919b7747c93c52faa9ca1c80528b0d50 Mon Sep 17 00:00:00 2001 From: Emin Mastizada Date: Fri, 23 Sep 2022 13:54:23 +0200 Subject: [PATCH] Allow asynchronous callbacks for AsyncRetrying parameters --- ...llow-async-callbacks-e75b66109e759c3b.yaml | 3 + tenacity/_asyncio.py | 52 ++++++- tests/test_asyncio.py | 131 +++++++++++++++++- 3 files changed, 181 insertions(+), 5 deletions(-) create mode 100644 releasenotes/notes/allow-async-callbacks-e75b66109e759c3b.yaml diff --git a/releasenotes/notes/allow-async-callbacks-e75b66109e759c3b.yaml b/releasenotes/notes/allow-async-callbacks-e75b66109e759c3b.yaml new file mode 100644 index 00000000..d906b758 --- /dev/null +++ b/releasenotes/notes/allow-async-callbacks-e75b66109e759c3b.yaml @@ -0,0 +1,3 @@ +--- +features: + - Allow asynchronous callbacks for `before`, `after`, `retry_error_callback`, `wait`, and `before_sleep` parameters. diff --git a/tenacity/_asyncio.py b/tenacity/_asyncio.py index 374ef206..52b23b64 100644 --- a/tenacity/_asyncio.py +++ b/tenacity/_asyncio.py @@ -19,12 +19,15 @@ import sys import typing from asyncio import sleep +from inspect import iscoroutinefunction from tenacity import AttemptManager from tenacity import BaseRetrying from tenacity import DoAttempt from tenacity import DoSleep +from tenacity import RetryAction from tenacity import RetryCallState +from tenacity import TryAgain WrappedFn = typing.TypeVar("WrappedFn", bound=typing.Callable) _RetValT = typing.TypeVar("_RetValT") @@ -45,7 +48,7 @@ async def __call__( # type: ignore # Change signature from supertype retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs) while True: - do = self.iter(retry_state=retry_state) + do = await self.iter(retry_state=retry_state) if isinstance(do, DoAttempt): try: result = await fn(*args, **kwargs) @@ -66,7 +69,7 @@ def __aiter__(self) -> "AsyncRetrying": async def __anext__(self) -> typing.Union[AttemptManager, typing.Any]: while True: - do = self.iter(retry_state=self._retry_state) + do = await self.iter(retry_state=self._retry_state) if do is None: raise StopAsyncIteration elif isinstance(do, DoAttempt): @@ -90,3 +93,48 @@ async def async_wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: async_wrapped.retry_with = fn.retry_with return async_wrapped + + @staticmethod + async def handle_custom_function( + func: typing.Union[typing.Callable, typing.Awaitable], retry_state: RetryCallState + ) -> typing.Any: + if iscoroutinefunction(func): + return await func(retry_state) + return func(retry_state) + + async def iter(self, retry_state: "RetryCallState") -> typing.Union[DoAttempt, DoSleep, typing.Any]: # noqa + fut = retry_state.outcome + if fut is None: + if self.before is not None: + await self.handle_custom_function(self.before, retry_state) + return DoAttempt() + + is_explicit_retry = retry_state.outcome.failed and isinstance(retry_state.outcome.exception(), TryAgain) + if not (is_explicit_retry or self.retry(retry_state=retry_state)): + return fut.result() + + if self.after is not None: + await self.handle_custom_function(self.after, retry_state) + + self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start + if self.stop(retry_state=retry_state): + if self.retry_error_callback: + return await self.handle_custom_function(self.retry_error_callback, retry_state) + retry_exc = self.retry_error_cls(fut) + if self.reraise: + raise retry_exc.reraise() + raise retry_exc from fut.exception() + + if self.wait: + _sleep = await self.handle_custom_function(self.wait, retry_state=retry_state) + else: + _sleep = 0.0 + retry_state.next_action = RetryAction(_sleep) + retry_state.idle_for += _sleep + self.statistics["idle_for"] += _sleep + self.statistics["attempt_number"] += 1 + + if self.before_sleep is not None: + await self.handle_custom_function(self.before_sleep, retry_state) + + return DoSleep(_sleep) diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index b370e29c..d781eb0a 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -15,15 +15,16 @@ import asyncio import inspect +import logging import unittest from functools import wraps -from tenacity import AsyncRetrying, RetryError +from tenacity import AsyncRetrying, Future, RetryCallState, RetryError from tenacity import _asyncio as tasyncio -from tenacity import retry, stop_after_attempt +from tenacity import before_sleep_log, retry, retry_if_result, stop_after_attempt from tenacity.wait import wait_fixed -from .test_tenacity import NoIOErrorAfterCount, current_time_ms +from .test_tenacity import CapturingHandler, NoIOErrorAfterCount, NoneReturnUntilAfterCount, current_time_ms def asynctest(callable_): @@ -86,6 +87,31 @@ def test_retry_attributes(self): assert hasattr(_retryable_coroutine, "retry") assert hasattr(_retryable_coroutine, "retry_with") + @asynctest + async def test_async_retry_error_callback_handler(self): + num_attempts = 3 + self.attempt_counter = 0 + + async def _retry_error_callback_handler(retry_state: RetryCallState): + _retry_error_callback_handler.called_times += 1 + return retry_state.outcome + + _retry_error_callback_handler.called_times = 0 + + @retry( + stop=stop_after_attempt(num_attempts), + retry_error_callback=_retry_error_callback_handler, + ) + async def _foobar(): + self.attempt_counter += 1 + raise Exception("This exception should not be raised") + + result = await _foobar() + + self.assertEqual(_retry_error_callback_handler.called_times, 1) + self.assertEqual(num_attempts, self.attempt_counter) + self.assertIsInstance(result, Future) + @asynctest async def test_attempt_number_is_correct_for_interleaved_coroutines(self): @@ -157,5 +183,104 @@ async def test_sleeps(self): self.assertLess(t, 1.1) +class TestAsyncBeforeAfterAttempts(unittest.TestCase): + _attempt_number = 0 + + @asynctest + async def test_before_attempts(self): + TestAsyncBeforeAfterAttempts._attempt_number = 0 + + async def _before(retry_state): + TestAsyncBeforeAfterAttempts._attempt_number = retry_state.attempt_number + + @retry( + wait=wait_fixed(1), + stop=stop_after_attempt(1), + before=_before, + ) + async def _test_before(): + pass + + await _test_before() + + self.assertTrue(TestAsyncBeforeAfterAttempts._attempt_number == 1) + + @asynctest + async def test_after_attempts(self): + TestAsyncBeforeAfterAttempts._attempt_number = 0 + + async def _after(retry_state): + TestAsyncBeforeAfterAttempts._attempt_number = retry_state.attempt_number + + @retry( + wait=wait_fixed(0.1), + stop=stop_after_attempt(3), + after=_after, + ) + async def _test_after(): + if TestAsyncBeforeAfterAttempts._attempt_number < 2: + raise Exception("testing after_attempts handler") + else: + pass + + await _test_after() + + self.assertTrue(TestAsyncBeforeAfterAttempts._attempt_number == 2) + + @asynctest + async def test_before_sleep(self): + async def _before_sleep(retry_state): + self.assertGreater(retry_state.next_action.sleep, 0) + _before_sleep.attempt_number = retry_state.attempt_number + + _before_sleep.attempt_number = 0 + + @retry( + wait=wait_fixed(0.01), + stop=stop_after_attempt(3), + before_sleep=_before_sleep, + ) + async def _test_before_sleep(): + if _before_sleep.attempt_number < 2: + raise Exception("testing before_sleep_attempts handler") + + await _test_before_sleep() + self.assertEqual(_before_sleep.attempt_number, 2) + + async def _test_before_sleep_log_returns(self, exc_info): + thing = NoneReturnUntilAfterCount(2) + logger = logging.getLogger(self.id()) + logger.propagate = False + logger.setLevel(logging.INFO) + handler = CapturingHandler() + logger.addHandler(handler) + try: + _before_sleep = before_sleep_log(logger, logging.INFO, exc_info=exc_info) + _retry = retry_if_result(lambda result: result is None) + retrying = AsyncRetrying( + wait=wait_fixed(0.01), + stop=stop_after_attempt(3), + retry=_retry, + before_sleep=_before_sleep, + ) + await retrying(_async_function, thing) + finally: + logger.removeHandler(handler) + + etalon_re = r"^Retrying .* in 0\.01 seconds as it returned None\.$" + self.assertEqual(len(handler.records), 2) + fmt = logging.Formatter().format + self.assertRegex(fmt(handler.records[0]), etalon_re) + self.assertRegex(fmt(handler.records[1]), etalon_re) + + @asynctest + async def test_before_sleep_log_returns_without_exc_info(self): + await self._test_before_sleep_log_returns(exc_info=False) + + @asynctest + async def test_before_sleep_log_returns_with_exc_info(self): + await self._test_before_sleep_log_returns(exc_info=True) + + if __name__ == "__main__": unittest.main()