Skip to content

Commit a1a7717

Browse files
feat(wait): add wait_exception strategy
Add `wait_exception` which calls a user-defined predicate with the caught exception to determine the wait time. Raises `RuntimeError` if the outcome has no exception.
1 parent d6e57dd commit a1a7717

File tree

3 files changed

+56
-0
lines changed

3 files changed

+56
-0
lines changed

tenacity/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
# Import all built-in wait strategies for easier usage.
6060
from .wait import wait_chain # noqa
6161
from .wait import wait_combine # noqa
62+
from .wait import wait_exception # noqa
6263
from .wait import wait_exponential # noqa
6364
from .wait import wait_fixed # noqa
6465
from .wait import wait_incrementing # noqa
@@ -686,6 +687,7 @@ def wrap(f: WrappedFn) -> WrappedFn:
686687
"stop_when_event_set",
687688
"wait_chain",
688689
"wait_combine",
690+
"wait_exception",
689691
"wait_exponential",
690692
"wait_fixed",
691693
"wait_incrementing",

tenacity/wait.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,42 @@ def __call__(self, retry_state: "RetryCallState") -> float:
113113
return wait_func(retry_state=retry_state)
114114

115115

116+
class wait_exception(wait_base):
117+
"""Wait strategy that waits the amount of time returned by the predicate.
118+
119+
The predicate is passed the exception object. Based on the exception, the
120+
user can decide how much time to wait before retrying.
121+
122+
For example::
123+
124+
def http_error(exception: BaseException) -> float:
125+
if (
126+
isinstance(exception, requests.HTTPError)
127+
and exception.response.status_code == requests.codes.too_many_requests
128+
):
129+
return float(exception.response.headers.get("Retry-After", "1"))
130+
return 60.0
131+
132+
133+
@retry(
134+
stop=stop_after_attempt(3),
135+
wait=wait_exception(http_error),
136+
)
137+
def http_get_request(url: str) -> None:
138+
response = requests.get(url)
139+
response.raise_for_status()
140+
"""
141+
142+
def __init__(self, predicate: typing.Callable[[BaseException], float]) -> None:
143+
self.predicate = predicate
144+
145+
def __call__(self, retry_state: "RetryCallState") -> float:
146+
exception = retry_state.outcome.exception()
147+
if exception is None:
148+
raise RuntimeError("outcome failed but the exception is None")
149+
return self.predicate(exception)
150+
151+
116152
class wait_incrementing(wait_base):
117153
"""Wait an incremental amount of time after each attempt.
118154

tests/test_tenacity.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,24 @@ def test_wait_combine(self):
369369
self.assertLess(w, 8)
370370
self.assertGreaterEqual(w, 5)
371371

372+
def test_wait_exception(self):
373+
def predicate(exc):
374+
if isinstance(exc, ValueError):
375+
return 3.5
376+
return 10.0
377+
378+
r = Retrying(wait=tenacity.wait_exception(predicate))
379+
380+
fut1 = tenacity.Future.construct(1, ValueError(), True)
381+
self.assertEqual(r.wait(make_retry_state(1, 0, last_result=fut1)), 3.5)
382+
383+
fut2 = tenacity.Future.construct(1, KeyError(), True)
384+
self.assertEqual(r.wait(make_retry_state(1, 0, last_result=fut2)), 10.0)
385+
386+
fut3 = tenacity.Future.construct(1, None, False)
387+
with self.assertRaises(RuntimeError):
388+
r.wait(make_retry_state(1, 0, last_result=fut3))
389+
372390
def test_wait_double_sum(self):
373391
r = Retrying(wait=tenacity.wait_random(0, 3) + tenacity.wait_fixed(5))
374392
# Test it a few time since it's random

0 commit comments

Comments
 (0)