Skip to content

Commit 13657ee

Browse files
committed
Allow asynchronous callbacks for AsyncRetrying parameters
1 parent 014b8e6 commit 13657ee

File tree

3 files changed

+181
-5
lines changed

3 files changed

+181
-5
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
---
2+
features:
3+
- Allow asynchronous callbacks for `before`, `after`, `retry_error_callback`, `wait`, and `before_sleep` parameters.

tenacity/_asyncio.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,15 @@
1919
import sys
2020
import typing
2121
from asyncio import sleep
22+
from inspect import iscoroutinefunction
2223

2324
from tenacity import AttemptManager
2425
from tenacity import BaseRetrying
2526
from tenacity import DoAttempt
2627
from tenacity import DoSleep
2728
from tenacity import RetryCallState
29+
from tenacity import RetryAction
30+
from tenacity import TryAgain
2831

2932
WrappedFn = typing.TypeVar("WrappedFn", bound=typing.Callable)
3033
_RetValT = typing.TypeVar("_RetValT")
@@ -45,7 +48,7 @@ async def __call__( # type: ignore # Change signature from supertype
4548

4649
retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
4750
while True:
48-
do = self.iter(retry_state=retry_state)
51+
do = await self.iter(retry_state=retry_state)
4952
if isinstance(do, DoAttempt):
5053
try:
5154
result = await fn(*args, **kwargs)
@@ -66,7 +69,7 @@ def __aiter__(self) -> "AsyncRetrying":
6669

6770
async def __anext__(self) -> typing.Union[AttemptManager, typing.Any]:
6871
while True:
69-
do = self.iter(retry_state=self._retry_state)
72+
do = await self.iter(retry_state=self._retry_state)
7073
if do is None:
7174
raise StopAsyncIteration
7275
elif isinstance(do, DoAttempt):
@@ -90,3 +93,48 @@ async def async_wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
9093
async_wrapped.retry_with = fn.retry_with
9194

9295
return async_wrapped
96+
97+
@staticmethod
98+
async def handle_custom_function(
99+
func: typing.Union[typing.Callable, typing.Awaitable], retry_state: RetryCallState
100+
) -> typing.Any:
101+
if iscoroutinefunction(func):
102+
return await func(retry_state)
103+
return func(retry_state)
104+
105+
async def iter(self, retry_state: "RetryCallState") -> typing.Union[DoAttempt, DoSleep, typing.Any]: # noqa
106+
fut = retry_state.outcome
107+
if fut is None:
108+
if self.before is not None:
109+
await self.handle_custom_function(self.before, retry_state)
110+
return DoAttempt()
111+
112+
is_explicit_retry = retry_state.outcome.failed and isinstance(retry_state.outcome.exception(), TryAgain)
113+
if not (is_explicit_retry or self.retry(retry_state=retry_state)):
114+
return fut.result()
115+
116+
if self.after is not None:
117+
await self.handle_custom_function(self.after, retry_state)
118+
119+
self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start
120+
if self.stop(retry_state=retry_state):
121+
if self.retry_error_callback:
122+
return await self.handle_custom_function(self.retry_error_callback, retry_state)
123+
retry_exc = self.retry_error_cls(fut)
124+
if self.reraise:
125+
raise retry_exc.reraise()
126+
raise retry_exc from fut.exception()
127+
128+
if self.wait:
129+
_sleep = await self.handle_custom_function(self.wait, retry_state=retry_state)
130+
else:
131+
_sleep = 0.0
132+
retry_state.next_action = RetryAction(_sleep)
133+
retry_state.idle_for += _sleep
134+
self.statistics["idle_for"] += _sleep
135+
self.statistics["attempt_number"] += 1
136+
137+
if self.before_sleep is not None:
138+
await self.handle_custom_function(self.before_sleep, retry_state)
139+
140+
return DoSleep(_sleep)

tests/test_asyncio.py

Lines changed: 128 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,17 @@
1414
# limitations under the License.
1515

1616
import asyncio
17+
import logging
1718
import inspect
1819
import unittest
1920
from functools import wraps
2021

21-
from tenacity import AsyncRetrying, RetryError
22+
from tenacity import AsyncRetrying, RetryError, RetryCallState, Future
2223
from tenacity import _asyncio as tasyncio
23-
from tenacity import retry, stop_after_attempt
24+
from tenacity import retry, stop_after_attempt, before_sleep_log, retry_if_result
2425
from tenacity.wait import wait_fixed
2526

26-
from .test_tenacity import NoIOErrorAfterCount, current_time_ms
27+
from .test_tenacity import CapturingHandler, NoIOErrorAfterCount, NoneReturnUntilAfterCount, current_time_ms
2728

2829

2930
def asynctest(callable_):
@@ -86,6 +87,31 @@ def test_retry_attributes(self):
8687
assert hasattr(_retryable_coroutine, "retry")
8788
assert hasattr(_retryable_coroutine, "retry_with")
8889

90+
@asynctest
91+
async def test_async_retry_error_callback_handler(self):
92+
num_attempts = 3
93+
self.attempt_counter = 0
94+
95+
async def _retry_error_callback_handler(retry_state: RetryCallState):
96+
_retry_error_callback_handler.called_times += 1
97+
return retry_state.outcome
98+
99+
_retry_error_callback_handler.called_times = 0
100+
101+
@retry(
102+
stop=stop_after_attempt(num_attempts),
103+
retry_error_callback=_retry_error_callback_handler,
104+
)
105+
async def _foobar():
106+
self.attempt_counter += 1
107+
raise Exception("This exception should not be raised")
108+
109+
result = await _foobar()
110+
111+
self.assertEqual(_retry_error_callback_handler.called_times, 1)
112+
self.assertEqual(num_attempts, self.attempt_counter)
113+
self.assertIsInstance(result, Future)
114+
89115
@asynctest
90116
async def test_attempt_number_is_correct_for_interleaved_coroutines(self):
91117

@@ -157,5 +183,104 @@ async def test_sleeps(self):
157183
self.assertLess(t, 1.1)
158184

159185

186+
class TestAsyncBeforeAfterAttempts(unittest.TestCase):
187+
_attempt_number = 0
188+
189+
@asynctest
190+
async def test_before_attempts(self):
191+
TestAsyncBeforeAfterAttempts._attempt_number = 0
192+
193+
async def _before(retry_state):
194+
TestAsyncBeforeAfterAttempts._attempt_number = retry_state.attempt_number
195+
196+
@retry(
197+
wait=wait_fixed(1),
198+
stop=stop_after_attempt(1),
199+
before=_before,
200+
)
201+
async def _test_before():
202+
pass
203+
204+
await _test_before()
205+
206+
self.assertTrue(TestAsyncBeforeAfterAttempts._attempt_number == 1)
207+
208+
@asynctest
209+
async def test_after_attempts(self):
210+
TestAsyncBeforeAfterAttempts._attempt_number = 0
211+
212+
async def _after(retry_state):
213+
TestAsyncBeforeAfterAttempts._attempt_number = retry_state.attempt_number
214+
215+
@retry(
216+
wait=wait_fixed(0.1),
217+
stop=stop_after_attempt(3),
218+
after=_after,
219+
)
220+
async def _test_after():
221+
if TestAsyncBeforeAfterAttempts._attempt_number < 2:
222+
raise Exception("testing after_attempts handler")
223+
else:
224+
pass
225+
226+
await _test_after()
227+
228+
self.assertTrue(TestAsyncBeforeAfterAttempts._attempt_number == 2)
229+
230+
@asynctest
231+
async def test_before_sleep(self):
232+
async def _before_sleep(retry_state):
233+
self.assertGreater(retry_state.next_action.sleep, 0)
234+
_before_sleep.attempt_number = retry_state.attempt_number
235+
236+
_before_sleep.attempt_number = 0
237+
238+
@retry(
239+
wait=wait_fixed(0.01),
240+
stop=stop_after_attempt(3),
241+
before_sleep=_before_sleep,
242+
)
243+
async def _test_before_sleep():
244+
if _before_sleep.attempt_number < 2:
245+
raise Exception("testing before_sleep_attempts handler")
246+
247+
await _test_before_sleep()
248+
self.assertEqual(_before_sleep.attempt_number, 2)
249+
250+
async def _test_before_sleep_log_returns(self, exc_info):
251+
thing = NoneReturnUntilAfterCount(2)
252+
logger = logging.getLogger(self.id())
253+
logger.propagate = False
254+
logger.setLevel(logging.INFO)
255+
handler = CapturingHandler()
256+
logger.addHandler(handler)
257+
try:
258+
_before_sleep = before_sleep_log(logger, logging.INFO, exc_info=exc_info)
259+
_retry = retry_if_result(lambda result: result is None)
260+
retrying = AsyncRetrying(
261+
wait=wait_fixed(0.01),
262+
stop=stop_after_attempt(3),
263+
retry=_retry,
264+
before_sleep=_before_sleep,
265+
)
266+
await retrying(_async_function, thing)
267+
finally:
268+
logger.removeHandler(handler)
269+
270+
etalon_re = r"^Retrying .* in 0\.01 seconds as it returned None\.$"
271+
self.assertEqual(len(handler.records), 2)
272+
fmt = logging.Formatter().format
273+
self.assertRegex(fmt(handler.records[0]), etalon_re)
274+
self.assertRegex(fmt(handler.records[1]), etalon_re)
275+
276+
@asynctest
277+
async def test_before_sleep_log_returns_without_exc_info(self):
278+
await self._test_before_sleep_log_returns(exc_info=False)
279+
280+
@asynctest
281+
async def test_before_sleep_log_returns_with_exc_info(self):
282+
await self._test_before_sleep_log_returns(exc_info=True)
283+
284+
160285
if __name__ == "__main__":
161286
unittest.main()

0 commit comments

Comments
 (0)