Skip to content

Commit 09b7a48

Browse files
authored
Always pass failed state to retry handlers and support async (#14746)
1 parent e430486 commit 09b7a48

File tree

2 files changed

+100
-5
lines changed

2 files changed

+100
-5
lines changed

src/prefect/task_engine.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,7 @@ def state(self) -> State:
126126
raise ValueError("Task run is not set")
127127
return self.task_run.state
128128

129-
@property
130-
def can_retry(self) -> bool:
129+
def can_retry(self, exc: Exception) -> bool:
131130
retry_condition: Optional[
132131
Callable[[Task[P, Coroutine[Any, Any, R]], TaskRun, State], bool]
133132
] = self.task.retry_condition_fn
@@ -138,9 +137,19 @@ def can_retry(self) -> bool:
138137
f"Running `retry_condition_fn` check {retry_condition!r} for task"
139138
f" {self.task.name!r}"
140139
)
141-
return not retry_condition or retry_condition(
142-
self.task, self.task_run, self.state
140+
state = Failed(
141+
data=exc,
142+
message=f"Task run encountered unexpected exception: {repr(exc)}",
143143
)
144+
if inspect.iscoroutinefunction(retry_condition):
145+
should_retry = run_coro_as_sync(
146+
retry_condition(self.task, self.task_run, state)
147+
)
148+
elif inspect.isfunction(retry_condition):
149+
should_retry = retry_condition(self.task, self.task_run, state)
150+
else:
151+
should_retry = not retry_condition
152+
return should_retry
144153
except Exception:
145154
self.logger.error(
146155
(
@@ -418,7 +427,7 @@ def handle_retry(self, exc: Exception) -> bool:
418427
- If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
419428
- If the task has no retries left, or the retry condition is not met, return False.
420429
"""
421-
if self.retries < self.task.retries and self.can_retry:
430+
if self.retries < self.task.retries and self.can_retry(exc):
422431
if self.task.retry_delay_seconds:
423432
delay = (
424433
self.task.retry_delay_seconds[

tests/test_task_engine.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,92 @@ async def test_flow():
935935
"Completed",
936936
]
937937

938+
async def test_task_passes_failed_state_to_retry_fn(self):
939+
mock = MagicMock()
940+
exc = SyntaxError("oops")
941+
handler_mock = MagicMock()
942+
943+
async def handler(task, task_run, state):
944+
handler_mock()
945+
assert state.is_failed()
946+
try:
947+
await state.result()
948+
except SyntaxError:
949+
return True
950+
return False
951+
952+
@task(retries=3, retry_condition_fn=handler)
953+
async def flaky_function():
954+
mock()
955+
if mock.call_count == 2:
956+
return True
957+
raise exc
958+
959+
@flow
960+
async def test_flow():
961+
return await flaky_function(return_state=True)
962+
963+
task_run_state = await test_flow()
964+
task_run_id = task_run_state.state_details.task_run_id
965+
966+
assert task_run_state.is_completed()
967+
assert await task_run_state.result() is True
968+
assert mock.call_count == 2
969+
assert handler_mock.call_count == 1
970+
971+
states = await get_task_run_states(task_run_id)
972+
973+
state_names = [state.name for state in states]
974+
assert state_names == [
975+
"Pending",
976+
"Running",
977+
"Retrying",
978+
"Completed",
979+
]
980+
981+
async def test_task_passes_failed_state_to_retry_fn_sync(self):
982+
mock = MagicMock()
983+
exc = SyntaxError("oops")
984+
handler_mock = MagicMock()
985+
986+
def handler(task, task_run, state):
987+
handler_mock()
988+
assert state.is_failed()
989+
try:
990+
state.result()
991+
except SyntaxError:
992+
return True
993+
return False
994+
995+
@task(retries=3, retry_condition_fn=handler)
996+
def flaky_function():
997+
mock()
998+
if mock.call_count == 2:
999+
return True
1000+
raise exc
1001+
1002+
@flow
1003+
def test_flow():
1004+
return flaky_function(return_state=True)
1005+
1006+
task_run_state = test_flow()
1007+
task_run_id = task_run_state.state_details.task_run_id
1008+
1009+
assert task_run_state.is_completed()
1010+
assert await task_run_state.result() is True
1011+
assert mock.call_count == 2
1012+
assert handler_mock.call_count == 1
1013+
1014+
states = await get_task_run_states(task_run_id)
1015+
1016+
state_names = [state.name for state in states]
1017+
assert state_names == [
1018+
"Pending",
1019+
"Running",
1020+
"Retrying",
1021+
"Completed",
1022+
]
1023+
9381024
async def test_task_retries_receive_latest_task_run_in_context(self):
9391025
state_names: List[str] = []
9401026
run_counts = []

0 commit comments

Comments
 (0)