Skip to content

Commit f65ca57

Browse files
committed
Allow asynchronous callbacks for AsyncRetrying parameters
1 parent 014b8e6 commit f65ca57

File tree

3 files changed

+181
-5
lines changed

3 files changed

+181
-5
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
---
2+
features:
3+
- Allow asynchronous callbacks for `before`, `after`, `retry_error_callback`, `wait`, and `before_sleep` parameters.

tenacity/_asyncio.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,15 @@
1919
import sys
2020
import typing
2121
from asyncio import sleep
22+
from inspect import iscoroutinefunction
2223

2324
from tenacity import AttemptManager
2425
from tenacity import BaseRetrying
2526
from tenacity import DoAttempt
2627
from tenacity import DoSleep
28+
from tenacity import RetryAction
2729
from tenacity import RetryCallState
30+
from tenacity import TryAgain
2831

2932
WrappedFn = typing.TypeVar("WrappedFn", bound=typing.Callable)
3033
_RetValT = typing.TypeVar("_RetValT")
@@ -45,7 +48,7 @@ async def __call__( # type: ignore # Change signature from supertype
4548

4649
retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
4750
while True:
48-
do = self.iter(retry_state=retry_state)
51+
do = await self.iter(retry_state=retry_state)
4952
if isinstance(do, DoAttempt):
5053
try:
5154
result = await fn(*args, **kwargs)
@@ -66,7 +69,7 @@ def __aiter__(self) -> "AsyncRetrying":
6669

6770
async def __anext__(self) -> typing.Union[AttemptManager, typing.Any]:
6871
while True:
69-
do = self.iter(retry_state=self._retry_state)
72+
do = await self.iter(retry_state=self._retry_state)
7073
if do is None:
7174
raise StopAsyncIteration
7275
elif isinstance(do, DoAttempt):
@@ -90,3 +93,48 @@ async def async_wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
9093
async_wrapped.retry_with = fn.retry_with
9194

9295
return async_wrapped
96+
97+
@staticmethod
98+
async def handle_custom_function(
99+
func: typing.Union[typing.Callable, typing.Awaitable], retry_state: RetryCallState
100+
) -> typing.Any:
101+
if iscoroutinefunction(func):
102+
return await func(retry_state)
103+
return func(retry_state)
104+
105+
async def iter(self, retry_state: "RetryCallState") -> typing.Union[DoAttempt, DoSleep, typing.Any]: # noqa
106+
fut = retry_state.outcome
107+
if fut is None:
108+
if self.before is not None:
109+
await self.handle_custom_function(self.before, retry_state)
110+
return DoAttempt()
111+
112+
is_explicit_retry = retry_state.outcome.failed and isinstance(retry_state.outcome.exception(), TryAgain)
113+
if not (is_explicit_retry or self.retry(retry_state=retry_state)):
114+
return fut.result()
115+
116+
if self.after is not None:
117+
await self.handle_custom_function(self.after, retry_state)
118+
119+
self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start
120+
if self.stop(retry_state=retry_state):
121+
if self.retry_error_callback:
122+
return await self.handle_custom_function(self.retry_error_callback, retry_state)
123+
retry_exc = self.retry_error_cls(fut)
124+
if self.reraise:
125+
raise retry_exc.reraise()
126+
raise retry_exc from fut.exception()
127+
128+
if self.wait:
129+
_sleep = await self.handle_custom_function(self.wait, retry_state=retry_state)
130+
else:
131+
_sleep = 0.0
132+
retry_state.next_action = RetryAction(_sleep)
133+
retry_state.idle_for += _sleep
134+
self.statistics["idle_for"] += _sleep
135+
self.statistics["attempt_number"] += 1
136+
137+
if self.before_sleep is not None:
138+
await self.handle_custom_function(self.before_sleep, retry_state)
139+
140+
return DoSleep(_sleep)

tests/test_asyncio.py

Lines changed: 128 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,16 @@
1515

1616
import asyncio
1717
import inspect
18+
import logging
1819
import unittest
1920
from functools import wraps
2021

21-
from tenacity import AsyncRetrying, RetryError
22+
from tenacity import AsyncRetrying, Future, RetryCallState, RetryError
2223
from tenacity import _asyncio as tasyncio
23-
from tenacity import retry, stop_after_attempt
24+
from tenacity import before_sleep_log, retry, retry_if_result, stop_after_attempt
2425
from tenacity.wait import wait_fixed
2526

26-
from .test_tenacity import NoIOErrorAfterCount, current_time_ms
27+
from .test_tenacity import CapturingHandler, NoIOErrorAfterCount, NoneReturnUntilAfterCount, current_time_ms
2728

2829

2930
def asynctest(callable_):
@@ -86,6 +87,31 @@ def test_retry_attributes(self):
8687
assert hasattr(_retryable_coroutine, "retry")
8788
assert hasattr(_retryable_coroutine, "retry_with")
8889

90+
@asynctest
91+
async def test_async_retry_error_callback_handler(self):
92+
num_attempts = 3
93+
self.attempt_counter = 0
94+
95+
async def _retry_error_callback_handler(retry_state: RetryCallState):
96+
_retry_error_callback_handler.called_times += 1
97+
return retry_state.outcome
98+
99+
_retry_error_callback_handler.called_times = 0
100+
101+
@retry(
102+
stop=stop_after_attempt(num_attempts),
103+
retry_error_callback=_retry_error_callback_handler,
104+
)
105+
async def _foobar():
106+
self.attempt_counter += 1
107+
raise Exception("This exception should not be raised")
108+
109+
result = await _foobar()
110+
111+
self.assertEqual(_retry_error_callback_handler.called_times, 1)
112+
self.assertEqual(num_attempts, self.attempt_counter)
113+
self.assertIsInstance(result, Future)
114+
89115
@asynctest
90116
async def test_attempt_number_is_correct_for_interleaved_coroutines(self):
91117

@@ -157,5 +183,104 @@ async def test_sleeps(self):
157183
self.assertLess(t, 1.1)
158184

159185

186+
class TestAsyncBeforeAfterAttempts(unittest.TestCase):
187+
_attempt_number = 0
188+
189+
@asynctest
190+
async def test_before_attempts(self):
191+
TestAsyncBeforeAfterAttempts._attempt_number = 0
192+
193+
async def _before(retry_state):
194+
TestAsyncBeforeAfterAttempts._attempt_number = retry_state.attempt_number
195+
196+
@retry(
197+
wait=wait_fixed(1),
198+
stop=stop_after_attempt(1),
199+
before=_before,
200+
)
201+
async def _test_before():
202+
pass
203+
204+
await _test_before()
205+
206+
self.assertTrue(TestAsyncBeforeAfterAttempts._attempt_number == 1)
207+
208+
@asynctest
209+
async def test_after_attempts(self):
210+
TestAsyncBeforeAfterAttempts._attempt_number = 0
211+
212+
async def _after(retry_state):
213+
TestAsyncBeforeAfterAttempts._attempt_number = retry_state.attempt_number
214+
215+
@retry(
216+
wait=wait_fixed(0.1),
217+
stop=stop_after_attempt(3),
218+
after=_after,
219+
)
220+
async def _test_after():
221+
if TestAsyncBeforeAfterAttempts._attempt_number < 2:
222+
raise Exception("testing after_attempts handler")
223+
else:
224+
pass
225+
226+
await _test_after()
227+
228+
self.assertTrue(TestAsyncBeforeAfterAttempts._attempt_number == 2)
229+
230+
@asynctest
231+
async def test_before_sleep(self):
232+
async def _before_sleep(retry_state):
233+
self.assertGreater(retry_state.next_action.sleep, 0)
234+
_before_sleep.attempt_number = retry_state.attempt_number
235+
236+
_before_sleep.attempt_number = 0
237+
238+
@retry(
239+
wait=wait_fixed(0.01),
240+
stop=stop_after_attempt(3),
241+
before_sleep=_before_sleep,
242+
)
243+
async def _test_before_sleep():
244+
if _before_sleep.attempt_number < 2:
245+
raise Exception("testing before_sleep_attempts handler")
246+
247+
await _test_before_sleep()
248+
self.assertEqual(_before_sleep.attempt_number, 2)
249+
250+
async def _test_before_sleep_log_returns(self, exc_info):
251+
thing = NoneReturnUntilAfterCount(2)
252+
logger = logging.getLogger(self.id())
253+
logger.propagate = False
254+
logger.setLevel(logging.INFO)
255+
handler = CapturingHandler()
256+
logger.addHandler(handler)
257+
try:
258+
_before_sleep = before_sleep_log(logger, logging.INFO, exc_info=exc_info)
259+
_retry = retry_if_result(lambda result: result is None)
260+
retrying = AsyncRetrying(
261+
wait=wait_fixed(0.01),
262+
stop=stop_after_attempt(3),
263+
retry=_retry,
264+
before_sleep=_before_sleep,
265+
)
266+
await retrying(_async_function, thing)
267+
finally:
268+
logger.removeHandler(handler)
269+
270+
etalon_re = r"^Retrying .* in 0\.01 seconds as it returned None\.$"
271+
self.assertEqual(len(handler.records), 2)
272+
fmt = logging.Formatter().format
273+
self.assertRegex(fmt(handler.records[0]), etalon_re)
274+
self.assertRegex(fmt(handler.records[1]), etalon_re)
275+
276+
@asynctest
277+
async def test_before_sleep_log_returns_without_exc_info(self):
278+
await self._test_before_sleep_log_returns(exc_info=False)
279+
280+
@asynctest
281+
async def test_before_sleep_log_returns_with_exc_info(self):
282+
await self._test_before_sleep_log_returns(exc_info=True)
283+
284+
160285
if __name__ == "__main__":
161286
unittest.main()

0 commit comments

Comments
 (0)