Skip to content

Commit 014b8e6

Browse files
authored
feat: Add retry_if_exception_cause_type (#362)
Add a new retry_base class called `retry_if_exception_cause_type` that checks that the cause of the raised exception is of a certain type. Co-authored-by: Guillaume RISBOURG <[email protected]>
1 parent 18d05a6 commit 014b8e6

File tree

4 files changed

+97
-0
lines changed

4 files changed

+97
-0
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
features:
3+
- |
4+
Add a new `retry_base` class called `retry_if_exception_cause_type` that
5+
checks, recursively, if any of the causes of the raised exception is of a certain type.

tenacity/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from .retry import retry_any # noqa
3434
from .retry import retry_if_exception # noqa
3535
from .retry import retry_if_exception_type # noqa
36+
from .retry import retry_if_exception_cause_type # noqa
3637
from .retry import retry_if_not_exception_type # noqa
3738
from .retry import retry_if_not_result # noqa
3839
from .retry import retry_if_result # noqa

tenacity/retry.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,33 @@ def __call__(self, retry_state: "RetryCallState") -> bool:
117117
return self.predicate(retry_state.outcome.exception())
118118

119119

120+
class retry_if_exception_cause_type(retry_base):
121+
"""Retries if any of the causes of the raised exception is of one or more types.
122+
123+
The check on the type of the cause of the exception is done recursively (until finding
124+
an exception in the chain that has no `__cause__`)
125+
"""
126+
127+
def __init__(
128+
self,
129+
exception_types: typing.Union[
130+
typing.Type[BaseException],
131+
typing.Tuple[typing.Type[BaseException], ...],
132+
] = Exception,
133+
) -> None:
134+
self.exception_cause_types = exception_types
135+
136+
def __call__(self, retry_state: "RetryCallState") -> bool:
137+
if retry_state.outcome.failed:
138+
exc = retry_state.outcome.exception()
139+
while exc is not None:
140+
if isinstance(exc.__cause__, self.exception_cause_types):
141+
return True
142+
exc = exc.__cause__
143+
144+
return False
145+
146+
120147
class retry_if_result(retry_base):
121148
"""Retries if the result verifies a predicate."""
122149

tests/test_tenacity.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,56 @@ def go(self):
676676
return True
677677

678678

679+
class NoNameErrorCauseAfterCount:
680+
"""Holds counter state for invoking a method several times in a row."""
681+
682+
def __init__(self, count):
683+
self.counter = 0
684+
self.count = count
685+
686+
def go2(self):
687+
raise NameError("Hi there, I'm a NameError")
688+
689+
def go(self):
690+
"""Raise an IOError with a NameError as cause until after count threshold has been crossed.
691+
692+
Then return True.
693+
"""
694+
if self.counter < self.count:
695+
self.counter += 1
696+
try:
697+
self.go2()
698+
except NameError as e:
699+
raise IOError() from e
700+
701+
return True
702+
703+
704+
class NoIOErrorCauseAfterCount:
705+
"""Holds counter state for invoking a method several times in a row."""
706+
707+
def __init__(self, count):
708+
self.counter = 0
709+
self.count = count
710+
711+
def go2(self):
712+
raise IOError("Hi there, I'm an IOError")
713+
714+
def go(self):
715+
"""Raise a NameError with an IOError as cause until after count threshold has been crossed.
716+
717+
Then return True.
718+
"""
719+
if self.counter < self.count:
720+
self.counter += 1
721+
try:
722+
self.go2()
723+
except IOError as e:
724+
raise NameError() from e
725+
726+
return True
727+
728+
679729
class NameErrorUntilCount:
680730
"""Holds counter state for invoking a method several times in a row."""
681731

@@ -783,6 +833,11 @@ def _retryable_test_with_stop(thing):
783833
return thing.go()
784834

785835

836+
@retry(retry=tenacity.retry_if_exception_cause_type(NameError))
837+
def _retryable_test_with_exception_cause_type(thing):
838+
return thing.go()
839+
840+
786841
@retry(retry=tenacity.retry_if_exception_type(IOError))
787842
def _retryable_test_with_exception_type_io(thing):
788843
return thing.go()
@@ -987,6 +1042,15 @@ def test_retry_if_not_exception_message_match(self):
9871042
s = _retryable_test_if_not_exception_message_message.retry.statistics
9881043
self.assertTrue(s["attempt_number"] == 1)
9891044

1045+
def test_retry_if_exception_cause_type(self):
1046+
self.assertTrue(_retryable_test_with_exception_cause_type(NoNameErrorCauseAfterCount(5)))
1047+
1048+
try:
1049+
_retryable_test_with_exception_cause_type(NoIOErrorCauseAfterCount(5))
1050+
self.fail("Expected exception without NameError as cause")
1051+
except NameError:
1052+
pass
1053+
9901054
def test_defaults(self):
9911055
self.assertTrue(_retryable_default(NoNameErrorAfterCount(5)))
9921056
self.assertTrue(_retryable_default_f(NoNameErrorAfterCount(5)))

0 commit comments

Comments
 (0)