Skip to content

Commit 293a383

Browse files
authored
Run Task failure callbacks on DAG Processor when task is externally killed (apache#53058) (apache#53143)
Until apache#44354 is implemented, tasks killed externally or when supervisor process dies unexpectedly, users have no way of knowing this happened. This has been a blocker for Airflow 3.0 adoption for some: - apache#44354 - https://apache-airflow.slack.com/archives/C07813CNKA8/p1751057525231389 apache#44354 is more involved and we might not get to it for Airflow 3.1 -- so this is a good fix until then similar to how we run Dag Run callback. (cherry-picked from a5211f2)
1 parent 0822328 commit 293a383

File tree

8 files changed

+472
-124
lines changed

8 files changed

+472
-124
lines changed

airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ class TIRunContext(BaseModel):
302302
dag_run: DagRun
303303
"""DAG run information for the task instance."""
304304

305-
task_reschedule_count: Annotated[int, Field(default=0)]
305+
task_reschedule_count: int = 0
306306
"""How many times the task has been rescheduled."""
307307

308308
max_tries: int
@@ -328,7 +328,7 @@ class TIRunContext(BaseModel):
328328
xcom_keys_to_clear: Annotated[list[str], Field(default_factory=list)]
329329
"""List of Xcom keys that need to be cleared and purged on by the worker."""
330330

331-
should_retry: bool
331+
should_retry: bool = False
332332
"""If the ti encounters an error, whether it should enter retry or failed state."""
333333

334334

airflow-core/src/airflow/callbacks/callback_requests.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ class TaskCallbackRequest(BaseCallbackRequest):
6161
"""Simplified Task Instance representation"""
6262
task_callback_type: TaskInstanceState | None = None
6363
"""Whether on success, on failure, on retry"""
64+
context_from_server: ti_datamodel.TIRunContext | None = None
65+
"""Task execution context from the Server"""
6466
type: Literal["TaskCallbackRequest"] = "TaskCallbackRequest"
6567

6668
@property

airflow-core/src/airflow/dag_processing/processor.py

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
import contextlib
1920
import importlib
2021
import os
2122
import sys
2223
import traceback
24+
from collections.abc import Callable, Sequence
2325
from pathlib import Path
24-
from typing import TYPE_CHECKING, Annotated, BinaryIO, Callable, ClassVar, Literal, Union
26+
from typing import TYPE_CHECKING, Annotated, BinaryIO, ClassVar, Literal, Union
2527

2628
import attrs
2729
from pydantic import BaseModel, Field, TypeAdapter
@@ -44,9 +46,11 @@
4446
VariableResult,
4547
)
4648
from airflow.sdk.execution_time.supervisor import WatchedSubprocess
49+
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
4750
from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
4851
from airflow.stats import Stats
4952
from airflow.utils.file import iter_airflow_imports
53+
from airflow.utils.state import TaskInstanceState
5054

5155
if TYPE_CHECKING:
5256
from structlog.typing import FilteringBoundLogger
@@ -200,10 +204,7 @@ def _execute_callbacks(
200204
for request in callback_requests:
201205
log.debug("Processing Callback Request", request=request.to_json())
202206
if isinstance(request, TaskCallbackRequest):
203-
raise NotImplementedError(
204-
"Haven't coded Task callback yet - https://github.com/apache/airflow/issues/44354!"
205-
)
206-
# _execute_task_callbacks(dagbag, request)
207+
_execute_task_callbacks(dagbag, request, log)
207208
if isinstance(request, DagCallbackRequest):
208209
_execute_dag_callbacks(dagbag, request, log)
209210

@@ -237,6 +238,67 @@ def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log: Fil
237238
Stats.incr("dag.callback_exceptions", tags={"dag_id": request.dag_id})
238239

239240

241+
def _execute_task_callbacks(dagbag: DagBag, request: TaskCallbackRequest, log: FilteringBoundLogger) -> None:
242+
if not request.is_failure_callback:
243+
log.warning(
244+
"Task callback requested but is not a failure callback",
245+
dag_id=request.ti.dag_id,
246+
task_id=request.ti.task_id,
247+
run_id=request.ti.run_id,
248+
)
249+
return
250+
251+
dag = dagbag.dags[request.ti.dag_id]
252+
task = dag.get_task(request.ti.task_id)
253+
254+
if request.task_callback_type is TaskInstanceState.UP_FOR_RETRY:
255+
callbacks = task.on_retry_callback
256+
else:
257+
callbacks = task.on_failure_callback
258+
259+
if not callbacks:
260+
log.warning(
261+
"Callback requested but no callback found",
262+
dag_id=request.ti.dag_id,
263+
task_id=request.ti.task_id,
264+
run_id=request.ti.run_id,
265+
ti_id=request.ti.id,
266+
)
267+
return
268+
269+
callbacks = callbacks if isinstance(callbacks, Sequence) else [callbacks]
270+
ctx_from_server = request.context_from_server
271+
272+
if ctx_from_server is not None:
273+
runtime_ti = RuntimeTaskInstance.model_construct(
274+
**request.ti.model_dump(exclude_unset=True),
275+
task=task,
276+
_ti_context_from_server=ctx_from_server,
277+
max_tries=ctx_from_server.max_tries,
278+
)
279+
else:
280+
runtime_ti = RuntimeTaskInstance.model_construct(
281+
**request.ti.model_dump(exclude_unset=True),
282+
task=task,
283+
)
284+
context = runtime_ti.get_template_context()
285+
286+
def get_callback_representation(callback):
287+
with contextlib.suppress(AttributeError):
288+
return callback.__name__
289+
with contextlib.suppress(AttributeError):
290+
return callback.__class__.__name__
291+
return callback
292+
293+
for idx, callback in enumerate(callbacks):
294+
callback_repr = get_callback_representation(callback)
295+
log.info("Executing Task callback at index %d: %s", idx, callback_repr)
296+
try:
297+
callback(context)
298+
except Exception:
299+
log.exception("Error in callback at index %d: %s", idx, callback_repr)
300+
301+
240302
def in_process_api_server() -> InProcessExecutionAPI:
241303
from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
242304

airflow-core/src/airflow/jobs/scheduler_job_runner.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from sqlalchemy.sql import expression
3939

4040
from airflow import settings
41+
from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRunContext
4142
from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest
4243
from airflow.configuration import conf
4344
from airflow.dag_processing.bundles.base import BundleUsageTrackingManager
@@ -945,10 +946,16 @@ def process_executor_events(
945946
bundle_version=ti.dag_version.bundle_version,
946947
ti=ti,
947948
msg=msg,
949+
context_from_server=TIRunContext(
950+
dag_run=ti.dag_run,
951+
max_tries=ti.max_tries,
952+
variables=[],
953+
connections=[],
954+
xcom_keys_to_clear=[],
955+
),
948956
)
949957
executor.send_callback(request)
950-
else:
951-
ti.handle_failure(error=msg, session=session)
958+
ti.handle_failure(error=msg, session=session)
952959

953960
return len(event_buffer)
954961

@@ -2296,6 +2303,13 @@ def _purge_task_instances_without_heartbeats(
22962303
bundle_version=ti.dag_run.bundle_version,
22972304
ti=ti,
22982305
msg=str(task_instance_heartbeat_timeout_message_details),
2306+
context_from_server=TIRunContext(
2307+
dag_run=ti.dag_run,
2308+
max_tries=ti.max_tries,
2309+
variables=[],
2310+
connections=[],
2311+
xcom_keys_to_clear=[],
2312+
),
22992313
)
23002314
session.add(
23012315
Log(

airflow-core/tests/unit/callbacks/test_callback_requests.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from airflow.models.taskinstance import TaskInstance
2929
from airflow.providers.standard.operators.bash import BashOperator
3030
from airflow.utils import timezone
31-
from airflow.utils.state import State
31+
from airflow.utils.state import State, TaskInstanceState
3232

3333
pytestmark = pytest.mark.db_test
3434

@@ -85,3 +85,30 @@ def test_taskcallback_to_json_with_start_date_and_end_date(self, session, create
8585
json_str = input.to_json()
8686
result = TaskCallbackRequest.from_json(json_str)
8787
assert input == result
88+
89+
@pytest.mark.parametrize(
90+
"task_callback_type,expected_is_failure",
91+
[
92+
(None, True),
93+
(TaskInstanceState.FAILED, True),
94+
(TaskInstanceState.UP_FOR_RETRY, True),
95+
(TaskInstanceState.UPSTREAM_FAILED, True),
96+
(TaskInstanceState.SUCCESS, False),
97+
(TaskInstanceState.RUNNING, False),
98+
],
99+
)
100+
def test_is_failure_callback_property(
101+
self, task_callback_type, expected_is_failure, create_task_instance
102+
):
103+
"""Test is_failure_callback property with different task callback types"""
104+
ti = create_task_instance()
105+
106+
request = TaskCallbackRequest(
107+
filepath="filepath",
108+
ti=ti,
109+
bundle_name="testing",
110+
bundle_version=None,
111+
task_callback_type=task_callback_type,
112+
)
113+
114+
assert request.is_failure_callback == expected_is_failure

0 commit comments

Comments
 (0)