|
15 | 15 |
|
16 | 16 | import asyncio |
17 | 17 | import inspect |
| 18 | +import logging |
18 | 19 | import unittest |
19 | 20 | from functools import wraps |
20 | 21 |
|
21 | | -from tenacity import AsyncRetrying, RetryError |
| 22 | +from tenacity import AsyncRetrying, Future, RetryCallState, RetryError |
22 | 23 | from tenacity import _asyncio as tasyncio |
23 | | -from tenacity import retry, stop_after_attempt |
| 24 | +from tenacity import before_sleep_log, retry, retry_if_result, stop_after_attempt |
24 | 25 | from tenacity.wait import wait_fixed |
25 | 26 |
|
26 | | -from .test_tenacity import NoIOErrorAfterCount, current_time_ms |
| 27 | +from .test_tenacity import CapturingHandler, NoIOErrorAfterCount, NoneReturnUntilAfterCount, current_time_ms |
27 | 28 |
|
28 | 29 |
|
29 | 30 | def asynctest(callable_): |
@@ -86,6 +87,31 @@ def test_retry_attributes(self): |
86 | 87 | assert hasattr(_retryable_coroutine, "retry") |
87 | 88 | assert hasattr(_retryable_coroutine, "retry_with") |
88 | 89 |
|
| 90 | + @asynctest |
| 91 | + async def test_async_retry_error_callback_handler(self): |
| 92 | + num_attempts = 3 |
| 93 | + self.attempt_counter = 0 |
| 94 | + |
| 95 | + async def _retry_error_callback_handler(retry_state: RetryCallState): |
| 96 | + _retry_error_callback_handler.called_times += 1 |
| 97 | + return retry_state.outcome |
| 98 | + |
| 99 | + _retry_error_callback_handler.called_times = 0 |
| 100 | + |
| 101 | + @retry( |
| 102 | + stop=stop_after_attempt(num_attempts), |
| 103 | + retry_error_callback=_retry_error_callback_handler, |
| 104 | + ) |
| 105 | + async def _foobar(): |
| 106 | + self.attempt_counter += 1 |
| 107 | + raise Exception("This exception should not be raised") |
| 108 | + |
| 109 | + result = await _foobar() |
| 110 | + |
| 111 | + self.assertEqual(_retry_error_callback_handler.called_times, 1) |
| 112 | + self.assertEqual(num_attempts, self.attempt_counter) |
| 113 | + self.assertIsInstance(result, Future) |
| 114 | + |
89 | 115 | @asynctest |
90 | 116 | async def test_attempt_number_is_correct_for_interleaved_coroutines(self): |
91 | 117 |
|
@@ -157,5 +183,104 @@ async def test_sleeps(self): |
157 | 183 | self.assertLess(t, 1.1) |
158 | 184 |
|
159 | 185 |
|
| 186 | +class TestAsyncBeforeAfterAttempts(unittest.TestCase): |
| 187 | + _attempt_number = 0 |
| 188 | + |
| 189 | + @asynctest |
| 190 | + async def test_before_attempts(self): |
| 191 | + TestAsyncBeforeAfterAttempts._attempt_number = 0 |
| 192 | + |
| 193 | + async def _before(retry_state): |
| 194 | + TestAsyncBeforeAfterAttempts._attempt_number = retry_state.attempt_number |
| 195 | + |
| 196 | + @retry( |
| 197 | + wait=wait_fixed(1), |
| 198 | + stop=stop_after_attempt(1), |
| 199 | + before=_before, |
| 200 | + ) |
| 201 | + async def _test_before(): |
| 202 | + pass |
| 203 | + |
| 204 | + await _test_before() |
| 205 | + |
| 206 | + self.assertTrue(TestAsyncBeforeAfterAttempts._attempt_number == 1) |
| 207 | + |
| 208 | + @asynctest |
| 209 | + async def test_after_attempts(self): |
| 210 | + TestAsyncBeforeAfterAttempts._attempt_number = 0 |
| 211 | + |
| 212 | + async def _after(retry_state): |
| 213 | + TestAsyncBeforeAfterAttempts._attempt_number = retry_state.attempt_number |
| 214 | + |
| 215 | + @retry( |
| 216 | + wait=wait_fixed(0.1), |
| 217 | + stop=stop_after_attempt(3), |
| 218 | + after=_after, |
| 219 | + ) |
| 220 | + async def _test_after(): |
| 221 | + if TestAsyncBeforeAfterAttempts._attempt_number < 2: |
| 222 | + raise Exception("testing after_attempts handler") |
| 223 | + else: |
| 224 | + pass |
| 225 | + |
| 226 | + await _test_after() |
| 227 | + |
| 228 | + self.assertTrue(TestAsyncBeforeAfterAttempts._attempt_number == 2) |
| 229 | + |
| 230 | + @asynctest |
| 231 | + async def test_before_sleep(self): |
| 232 | + async def _before_sleep(retry_state): |
| 233 | + self.assertGreater(retry_state.next_action.sleep, 0) |
| 234 | + _before_sleep.attempt_number = retry_state.attempt_number |
| 235 | + |
| 236 | + _before_sleep.attempt_number = 0 |
| 237 | + |
| 238 | + @retry( |
| 239 | + wait=wait_fixed(0.01), |
| 240 | + stop=stop_after_attempt(3), |
| 241 | + before_sleep=_before_sleep, |
| 242 | + ) |
| 243 | + async def _test_before_sleep(): |
| 244 | + if _before_sleep.attempt_number < 2: |
| 245 | + raise Exception("testing before_sleep_attempts handler") |
| 246 | + |
| 247 | + await _test_before_sleep() |
| 248 | + self.assertEqual(_before_sleep.attempt_number, 2) |
| 249 | + |
| 250 | + async def _test_before_sleep_log_returns(self, exc_info): |
| 251 | + thing = NoneReturnUntilAfterCount(2) |
| 252 | + logger = logging.getLogger(self.id()) |
| 253 | + logger.propagate = False |
| 254 | + logger.setLevel(logging.INFO) |
| 255 | + handler = CapturingHandler() |
| 256 | + logger.addHandler(handler) |
| 257 | + try: |
| 258 | + _before_sleep = before_sleep_log(logger, logging.INFO, exc_info=exc_info) |
| 259 | + _retry = retry_if_result(lambda result: result is None) |
| 260 | + retrying = AsyncRetrying( |
| 261 | + wait=wait_fixed(0.01), |
| 262 | + stop=stop_after_attempt(3), |
| 263 | + retry=_retry, |
| 264 | + before_sleep=_before_sleep, |
| 265 | + ) |
| 266 | + await retrying(_async_function, thing) |
| 267 | + finally: |
| 268 | + logger.removeHandler(handler) |
| 269 | + |
| 270 | + etalon_re = r"^Retrying .* in 0\.01 seconds as it returned None\.$" |
| 271 | + self.assertEqual(len(handler.records), 2) |
| 272 | + fmt = logging.Formatter().format |
| 273 | + self.assertRegex(fmt(handler.records[0]), etalon_re) |
| 274 | + self.assertRegex(fmt(handler.records[1]), etalon_re) |
| 275 | + |
| 276 | + @asynctest |
| 277 | + async def test_before_sleep_log_returns_without_exc_info(self): |
| 278 | + await self._test_before_sleep_log_returns(exc_info=False) |
| 279 | + |
| 280 | + @asynctest |
| 281 | + async def test_before_sleep_log_returns_with_exc_info(self): |
| 282 | + await self._test_before_sleep_log_returns(exc_info=True) |
| 283 | + |
| 284 | + |
160 | 285 | if __name__ == "__main__": |
161 | 286 | unittest.main() |
0 commit comments