Skip to content

Commit bfa2c80

Browse files
authored
Support async actions (#437)
* Support async actions * Fixes after main rebase * Test is_coroutine_callable
1 parent b7e4883 commit bfa2c80

File tree

3 files changed

+97
-3
lines changed

3 files changed

+97
-3
lines changed

tenacity/_asyncio.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from tenacity import DoAttempt
2626
from tenacity import DoSleep
2727
from tenacity import RetryCallState
28+
from tenacity import _utils
2829

2930
WrappedFnReturnT = t.TypeVar("WrappedFnReturnT")
3031
WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]])
@@ -46,7 +47,7 @@ async def __call__( # type: ignore[override]
4647

4748
retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
4849
while True:
49-
do = self.iter(retry_state=retry_state)
50+
do = await self.iter(retry_state=retry_state)
5051
if isinstance(do, DoAttempt):
5152
try:
5253
result = await fn(*args, **kwargs)
@@ -60,6 +61,47 @@ async def __call__( # type: ignore[override]
6061
else:
6162
return do # type: ignore[no-any-return]
6263

64+
@classmethod
65+
def _wrap_action_func(cls, fn: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]:
66+
if _utils.is_coroutine_callable(fn):
67+
return fn
68+
69+
async def inner(*args: t.Any, **kwargs: t.Any) -> t.Any:
70+
return fn(*args, **kwargs)
71+
72+
return inner
73+
74+
def _add_action_func(self, fn: t.Callable[..., t.Any]) -> None:
75+
self.iter_state.actions.append(self._wrap_action_func(fn))
76+
77+
async def _run_retry(self, retry_state: "RetryCallState") -> None: # type: ignore[override]
78+
self.iter_state.retry_run_result = await self._wrap_action_func(self.retry)(
79+
retry_state
80+
)
81+
82+
async def _run_wait(self, retry_state: "RetryCallState") -> None: # type: ignore[override]
83+
if self.wait:
84+
sleep = await self._wrap_action_func(self.wait)(retry_state)
85+
else:
86+
sleep = 0.0
87+
88+
retry_state.upcoming_sleep = sleep
89+
90+
async def _run_stop(self, retry_state: "RetryCallState") -> None: # type: ignore[override]
91+
self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start
92+
self.iter_state.stop_run_result = await self._wrap_action_func(self.stop)(
93+
retry_state
94+
)
95+
96+
async def iter(
97+
self, retry_state: "RetryCallState"
98+
) -> t.Union[DoAttempt, DoSleep, t.Any]: # noqa: A003
99+
self._begin_iter(retry_state)
100+
result = None
101+
for action in self.iter_state.actions:
102+
result = await action(retry_state)
103+
return result
104+
63105
def __iter__(self) -> t.Generator[AttemptManager, None, None]:
64106
raise TypeError("AsyncRetrying object is not iterable")
65107

@@ -70,7 +112,7 @@ def __aiter__(self) -> "AsyncRetrying":
70112

71113
async def __anext__(self) -> AttemptManager:
72114
while True:
73-
do = self.iter(retry_state=self._retry_state)
115+
do = await self.iter(retry_state=self._retry_state)
74116
if do is None:
75117
raise StopAsyncIteration
76118
elif isinstance(do, DoAttempt):

tenacity/_utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16-
16+
import functools
17+
import inspect
1718
import sys
1819
import typing
1920
from datetime import timedelta
@@ -76,3 +77,13 @@ def to_seconds(time_unit: time_unit_type) -> float:
7677
return float(
7778
time_unit.total_seconds() if isinstance(time_unit, timedelta) else time_unit
7879
)
80+
81+
82+
def is_coroutine_callable(call: typing.Callable[..., typing.Any]) -> bool:
83+
if inspect.isclass(call):
84+
return False
85+
if inspect.iscoroutinefunction(call):
86+
return True
87+
partial_call = isinstance(call, functools.partial) and call.func
88+
dunder_call = partial_call or getattr(call, "__call__", None)
89+
return inspect.iscoroutinefunction(dunder_call)

tests/test_utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import functools
2+
3+
from tenacity import _utils
4+
5+
6+
def test_is_coroutine_callable() -> None:
7+
async def async_func() -> None:
8+
pass
9+
10+
def sync_func() -> None:
11+
pass
12+
13+
class AsyncClass:
14+
async def __call__(self) -> None:
15+
pass
16+
17+
class SyncClass:
18+
def __call__(self) -> None:
19+
pass
20+
21+
lambda_fn = lambda: None # noqa: E731
22+
23+
partial_async_func = functools.partial(async_func)
24+
partial_sync_func = functools.partial(sync_func)
25+
partial_async_class = functools.partial(AsyncClass().__call__)
26+
partial_sync_class = functools.partial(SyncClass().__call__)
27+
partial_lambda_fn = functools.partial(lambda_fn)
28+
29+
assert _utils.is_coroutine_callable(async_func) is True
30+
assert _utils.is_coroutine_callable(sync_func) is False
31+
assert _utils.is_coroutine_callable(AsyncClass) is False
32+
assert _utils.is_coroutine_callable(AsyncClass()) is True
33+
assert _utils.is_coroutine_callable(SyncClass) is False
34+
assert _utils.is_coroutine_callable(SyncClass()) is False
35+
assert _utils.is_coroutine_callable(lambda_fn) is False
36+
37+
assert _utils.is_coroutine_callable(partial_async_func) is True
38+
assert _utils.is_coroutine_callable(partial_sync_func) is False
39+
assert _utils.is_coroutine_callable(partial_async_class) is True
40+
assert _utils.is_coroutine_callable(partial_sync_class) is False
41+
assert _utils.is_coroutine_callable(partial_lambda_fn) is False

0 commit comments

Comments
 (0)