Skip to content

Commit 39172c3

Browse files
committed
change listener API, add basic support for task instance listeners in TaskSDK, make OpenLineage provider support Airflow 3's listener interface
Signed-off-by: Maciej Obuchowski <[email protected]>
1 parent dafd166 commit 39172c3

File tree

44 files changed

+1931
-799
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+1931
-799
lines changed

airflow/api_fastapi/execution_api/datamodels/taskinstance.py

+4
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ class DagRun(StrictBaseModel):
222222
run_after: UtcDateTime
223223
start_date: UtcDateTime
224224
end_date: UtcDateTime | None
225+
clear_number: int
225226
run_type: DagRunType
226227
conf: Annotated[dict[str, Any], Field(default_factory=dict)]
227228
external_trigger: bool = False
@@ -233,6 +234,9 @@ class TIRunContext(BaseModel):
233234
dag_run: DagRun
234235
"""DAG run information for the task instance."""
235236

237+
task_reschedule_count: Annotated[int, Field(default=0)]
238+
"""How many times the task has been rescheduled."""
239+
236240
max_tries: int
237241
"""Maximum number of tries for the task instance (from DB)."""
238242

airflow/api_fastapi/execution_api/routes/task_instances.py

+31-5
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from fastapi import Body, HTTPException, status
2525
from pydantic import JsonValue
26-
from sqlalchemy import update
26+
from sqlalchemy import func, update
2727
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
2828
from sqlalchemy.sql import select
2929

@@ -79,14 +79,23 @@ def ti_run(
7979
ti_id_str = str(task_instance_id)
8080

8181
old = (
82-
select(TI.state, TI.dag_id, TI.run_id, TI.task_id, TI.map_index, TI.next_method, TI.max_tries)
82+
select(
83+
TI.state,
84+
TI.dag_id,
85+
TI.run_id,
86+
TI.task_id,
87+
TI.map_index,
88+
TI.next_method,
89+
TI.try_number,
90+
TI.max_tries,
91+
)
8392
.where(TI.id == ti_id_str)
8493
.with_for_update()
8594
)
8695
try:
87-
(previous_state, dag_id, run_id, task_id, map_index, next_method, max_tries) = session.execute(
88-
old
89-
).one()
96+
(previous_state, dag_id, run_id, task_id, map_index, next_method, try_number, max_tries) = (
97+
session.execute(old).one()
98+
)
9099
except NoResultFound:
91100
log.error("Task Instance %s not found", ti_id_str)
92101
raise HTTPException(
@@ -147,6 +156,7 @@ def ti_run(
147156
DR.run_after,
148157
DR.start_date,
149158
DR.end_date,
159+
DR.clear_number,
150160
DR.run_type,
151161
DR.conf,
152162
DR.logical_date,
@@ -170,8 +180,24 @@ def ti_run(
170180
session=session,
171181
)
172182

183+
task_reschedule_count = (
184+
session.query(
185+
func.count(TaskReschedule.id) # or any other primary key column
186+
)
187+
.filter(
188+
TaskReschedule.dag_id == dag_id,
189+
TaskReschedule.task_id == ti_id_str,
190+
TaskReschedule.run_id == run_id,
191+
# TaskReschedule.map_index == ti.map_index, # TODO: Handle mapped tasks
192+
TaskReschedule.try_number == try_number,
193+
)
194+
.scalar()
195+
or 0
196+
)
197+
173198
return TIRunContext(
174199
dag_run=DagRun.model_validate(dr, from_attributes=True),
200+
task_reschedule_count=task_reschedule_count,
175201
max_tries=max_tries,
176202
# TODO: Add variables and connections that are needed (and has perms) for the task
177203
variables=[],

airflow/example_dags/plugins/event_listener.py

+15-29
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@
2323

2424
if TYPE_CHECKING:
2525
from airflow.models.dagrun import DagRun
26-
from airflow.models.taskinstance import TaskInstance
26+
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
2727
from airflow.utils.state import TaskInstanceState
2828

2929

3030
# [START howto_listen_ti_running_task]
3131
@hookimpl
32-
def on_task_instance_running(previous_state: TaskInstanceState, task_instance: TaskInstance, session):
32+
def on_task_instance_running(previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance):
3333
"""
3434
This method is called when task state changes to RUNNING.
3535
Through callback, parameters like previous_task_state, task_instance object can be accessed.
@@ -39,14 +39,11 @@ def on_task_instance_running(previous_state: TaskInstanceState, task_instance: T
3939
print("Task instance is in running state")
4040
print(" Previous state of the Task instance:", previous_state)
4141

42-
state: TaskInstanceState = task_instance.state
4342
name: str = task_instance.task_id
44-
start_date = task_instance.start_date
4543

46-
dagrun = task_instance.dag_run
47-
dagrun_status = dagrun.state
44+
context = task_instance.get_template_context()
4845

49-
task = task_instance.task
46+
task = context["task"]
5047

5148
if TYPE_CHECKING:
5249
assert task
@@ -55,16 +52,16 @@ def on_task_instance_running(previous_state: TaskInstanceState, task_instance: T
5552
dag_name = None
5653
if dag:
5754
dag_name = dag.dag_id
58-
print(f"Current task name:{name} state:{state} start_date:{start_date}")
59-
print(f"Dag name:{dag_name} and current dag run status:{dagrun_status}")
55+
print(f"Current task name:{name}")
56+
print(f"Dag name:{dag_name}")
6057

6158

6259
# [END howto_listen_ti_running_task]
6360

6461

6562
# [START howto_listen_ti_success_task]
6663
@hookimpl
67-
def on_task_instance_success(previous_state: TaskInstanceState, task_instance: TaskInstance, session):
64+
def on_task_instance_success(previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance):
6865
"""
6966
This method is called when task state changes to SUCCESS.
7067
Through callback, parameters like previous_task_state, task_instance object can be accessed.
@@ -74,14 +71,10 @@ def on_task_instance_success(previous_state: TaskInstanceState, task_instance: T
7471
print("Task instance in success state")
7572
print(" Previous state of the Task instance:", previous_state)
7673

77-
dag_id = task_instance.dag_id
78-
hostname = task_instance.hostname
79-
operator = task_instance.operator
74+
context = task_instance.get_template_context()
75+
operator = context["task"]
8076

81-
dagrun = task_instance.dag_run
82-
queued_at = dagrun.queued_at
83-
print(f"Dag name:{dag_id} queued_at:{queued_at}")
84-
print(f"Task hostname:{hostname} operator:{operator}")
77+
print(f"Task operator:{operator}")
8578

8679

8780
# [END howto_listen_ti_success_task]
@@ -90,7 +83,7 @@ def on_task_instance_success(previous_state: TaskInstanceState, task_instance: T
9083
# [START howto_listen_ti_failure_task]
9184
@hookimpl
9285
def on_task_instance_failed(
93-
previous_state: TaskInstanceState, task_instance: TaskInstance, error: None | str | BaseException, session
86+
previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance, error: None | str | BaseException
9487
):
9588
"""
9689
This method is called when task state changes to FAILED.
@@ -100,21 +93,14 @@ def on_task_instance_failed(
10093
"""
10194
print("Task instance in failure state")
10295

103-
start_date = task_instance.start_date
104-
end_date = task_instance.end_date
105-
duration = task_instance.duration
106-
107-
dagrun = task_instance.dag_run
108-
109-
task = task_instance.task
96+
context = task_instance.get_template_context()
97+
task = context["task"]
11098

11199
if TYPE_CHECKING:
112100
assert task
113101

114-
dag = task.dag
115-
116-
print(f"Task start:{start_date} end:{end_date} duration:{duration}")
117-
print(f"Task:{task} dag:{dag} dagrun:{dagrun}")
102+
print("Task start")
103+
print(f"Task:{task}")
118104
if error:
119105
print(f"Failure caused by {error}")
120106

airflow/listeners/listener.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,13 @@ class ListenerManager:
4646
"""Manage listener registration and provides hook property for calling them."""
4747

4848
def __init__(self):
49-
from airflow.listeners.spec import asset, dagrun, importerrors, lifecycle, taskinstance
49+
from airflow.listeners.spec import (
50+
asset,
51+
dagrun,
52+
importerrors,
53+
lifecycle,
54+
taskinstance,
55+
)
5056

5157
self.pm = pluggy.PluginManager("airflow")
5258
self.pm.add_hookcall_monitoring(_before_hookcall, _after_hookcall)

airflow/listeners/spec/taskinstance.py

+4-11
Original file line numberDiff line numberDiff line change
@@ -22,33 +22,26 @@
2222
from pluggy import HookspecMarker
2323

2424
if TYPE_CHECKING:
25-
from sqlalchemy.orm.session import Session
26-
27-
from airflow.models.taskinstance import TaskInstance
25+
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
2826
from airflow.utils.state import TaskInstanceState
2927

3028
hookspec = HookspecMarker("airflow")
3129

3230

3331
@hookspec
34-
def on_task_instance_running(
35-
previous_state: TaskInstanceState | None, task_instance: TaskInstance, session: Session | None
36-
):
32+
def on_task_instance_running(previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance):
3733
"""Execute when task state changes to RUNNING. previous_state can be None."""
3834

3935

4036
@hookspec
41-
def on_task_instance_success(
42-
previous_state: TaskInstanceState | None, task_instance: TaskInstance, session: Session | None
43-
):
37+
def on_task_instance_success(previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance):
4438
"""Execute when task state changes to SUCCESS. previous_state can be None."""
4539

4640

4741
@hookspec
4842
def on_task_instance_failed(
4943
previous_state: TaskInstanceState | None,
50-
task_instance: TaskInstance,
44+
task_instance: RuntimeTaskInstance,
5145
error: None | str | BaseException,
52-
session: Session | None,
5346
):
5447
"""Execute when task state changes to FAIL. previous_state can be None."""

airflow/models/taskinstance.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ def _run_raw_task(
383383
TaskInstance.save_to_db(ti=ti, session=session)
384384
if ti.state == TaskInstanceState.SUCCESS:
385385
get_listener_manager().hook.on_task_instance_success(
386-
previous_state=TaskInstanceState.RUNNING, task_instance=ti, session=session
386+
previous_state=TaskInstanceState.RUNNING, task_instance=ti
387387
)
388388

389389
return None
@@ -1903,6 +1903,7 @@ def to_runtime_ti(self, context_from_server) -> RuntimeTaskInstanceProtocol:
19031903
max_tries=self.max_tries,
19041904
hostname=self.hostname,
19051905
_ti_context_from_server=context_from_server,
1906+
start_date=self.start_date,
19061907
)
19071908

19081909
return runtime_ti
@@ -2925,7 +2926,7 @@ def signal_handler(signum, frame):
29252926

29262927
# Run on_task_instance_running event
29272928
get_listener_manager().hook.on_task_instance_running(
2928-
previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session
2929+
previous_state=TaskInstanceState.QUEUED, task_instance=self
29292930
)
29302931

29312932
def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None:
@@ -3167,7 +3168,7 @@ def fetch_handle_failure_context(
31673168
callbacks = task.on_retry_callback if task else None
31683169

31693170
get_listener_manager().hook.on_task_instance_failed(
3170-
previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error, session=session
3171+
previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
31713172
)
31723173

31733174
return {

airflow/utils/context.py

+2
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,9 @@
9696
"prev_end_date_success",
9797
"reason",
9898
"run_id",
99+
"start_date",
99100
"task",
101+
"task_reschedule_count",
100102
"task_instance",
101103
"task_instance_key_str",
102104
"test_mode",

docs/apache-airflow/administration-and-deployment/listeners.rst

+11-9
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,7 @@ For example if you want to implement a listener that uses the ``error`` field in
165165
...
166166
167167
@hookimpl
168-
def on_task_instance_failed(
169-
self, previous_state, task_instance, error: None | str | BaseException, session
170-
):
168+
def on_task_instance_failed(self, previous_state, task_instance, error: None | str | BaseException):
171169
# Handle error case here
172170
pass
173171
@@ -177,15 +175,19 @@ For example if you want to implement a listener that uses the ``error`` field in
177175
...
178176
179177
@hookimpl
180-
def on_task_instance_failed(self, previous_state, task_instance, session):
178+
def on_task_instance_failed(self, previous_state, task_instance):
181179
# Handle no error case here
182180
pass
183181
184182
List of changes in the listener interfaces since 2.8.0 when they were introduced:
185183

186184

187-
+-----------------+-----------------------------+---------------------------------------+
188-
| Airflow Version | Affected method | Change |
189-
+=================+=============================+=======================================+
190-
| 2.10.0 | ``on_task_instance_failed`` | An error field added to the interface |
191-
+-----------------+-----------------------------+---------------------------------------+
185+
+-----------------+--------------------------------------------+-------------------------------------------------------------------------+
186+
| Airflow Version | Affected method | Change |
187+
+=================+============================================+=========================================================================+
188+
| 2.10.0 | ``on_task_instance_failed`` | An error field added to the interface |
189+
+-----------------+--------------------------------------------+-------------------------------------------------------------------------+
190+
| 3.0.0 | ``on_task_instance_running``, | ``session`` argument removed from task instance listeners, |
191+
| | ``on_task_instance_success``, | ``task_instance`` object is now an instance of ``RuntimeTaskInstance`` |
192+
| | ``on_task_instance_failed`` | |
193+
+-----------------+--------------------------------------------+-------------------------------------------------------------------------+

docs/apache-airflow/templates-ref.rst

+2
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,15 @@ Variable Type Description
6161
| ``None``
6262
``{{ prev_end_date_success }}`` `pendulum.DateTime`_ End date from prior successful :class:`~airflow.models.dagrun.DagRun` (if available).
6363
| ``None``
64+
``{{ start_date }}`` `pendulum.DateTime`_ Datetime at which the current task was started.
6465
``{{ inlets }}`` list List of inlets declared on the task.
6566
``{{ inlet_events }}`` dict[str, ...] Access past events of inlet assets. See :doc:`Assets <authoring-and-scheduling/datasets>`. Added in version 2.10.
6667
``{{ outlets }}`` list List of outlets declared on the task.
6768
``{{ outlet_events }}`` dict[str, ...] | Accessors to attach information to asset events that will be emitted by the current task.
6869
| See :doc:`Assets <authoring-and-scheduling/datasets>`. Added in version 2.10.
6970
``{{ dag }}`` DAG The currently running :class:`~airflow.models.dag.DAG`. You can read more about DAGs in :doc:`DAGs <core-concepts/dags>`.
7071
``{{ task }}`` BaseOperator | The currently running :class:`~airflow.models.baseoperator.BaseOperator`. You can read more about Tasks in :doc:`core-concepts/operators`
72+
``{{ task_reschedule_count }}`` int How many times the current task has been rescheduled. Relevant to ``mode="reschedule"`` sensors.
7173
``{{ macros }}`` | A reference to the macros package. See Macros_ below.
7274
``{{ task_instance }}`` TaskInstance The currently running :class:`~airflow.models.taskinstance.TaskInstance`.
7375
``{{ ti }}`` TaskInstance Same as ``{{ task_instance }}``.

providers/edge/tests/provider_tests/edge/cli/test_edge_command.py

+2
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
"pool_slots": 1,
5252
"queue": "default",
5353
"priority_weight": 1,
54+
"start_date": "2023-01-01T00:00:00+00:00",
55+
"map_index": -1,
5456
},
5557
"dag_rel_path": "mock.py",
5658
"log_path": "mock.log",

providers/edge/tests/provider_tests/edge/executors/test_edge_executor.py

+1
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ def test_queue_workload(self):
300300
pool_slots=1,
301301
queue="default",
302302
priority_weight=1,
303+
start_date=timezone.utcnow(),
303304
),
304305
dag_rel_path="mock.py",
305306
log_path="mock.log",

0 commit comments

Comments
 (0)