Skip to content

Commit b552bca

Browse files
authored
Mark DagRun as success when no teardown tasks are running (#49752)
* Mark DagRun as success when no teardown tasks are running * Change to unfinished teardown tasks, modified unit tests * Fix in unit tests * Fix pytest.xdist issues * Only select task_ids and convert to set, update unit tests
1 parent 7b389c3 commit b552bca

File tree

2 files changed

+63
-7
lines changed

2 files changed

+63
-7
lines changed

airflow-core/src/airflow/api/common/mark_tasks.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -215,17 +215,31 @@ def set_dag_run_state_to_success(
215215
if not run_id:
216216
raise ValueError(f"Invalid dag_run_id: {run_id}")
217217

218-
# Mark all task instances of the dag run to success - except for teardown as they need to complete work.
218+
# Mark all task instances of the dag run to success - except for unfinished teardown as they need to complete work.
219219
normal_tasks = [task for task in dag.tasks if not task.is_teardown]
220+
teardown_tasks = [task for task in dag.tasks if task.is_teardown]
221+
unfinished_teardown_task_ids = set(
222+
session.scalars(
223+
select(TaskInstance.task_id).where(
224+
TaskInstance.dag_id == dag.dag_id,
225+
TaskInstance.run_id == run_id,
226+
TaskInstance.task_id.in_([task.task_id for task in teardown_tasks]),
227+
or_(TaskInstance.state.is_(None), TaskInstance.state.in_(State.unfinished)),
228+
)
229+
).all()
230+
)
220231

221-
# Mark the dag run to success.
222-
if commit and len(normal_tasks) == len(dag.tasks):
232+
# Mark the dag run to success if there are no unfinished teardown tasks.
233+
if commit and len(unfinished_teardown_task_ids) == 0:
223234
_set_dag_run_state(dag.dag_id, run_id, DagRunState.SUCCESS, session)
224235

225-
for task in normal_tasks:
236+
tasks_to_mark_success = normal_tasks + [
237+
task for task in teardown_tasks if task.task_id not in unfinished_teardown_task_ids
238+
]
239+
for task in tasks_to_mark_success:
226240
task.dag = dag
227241
return set_state(
228-
tasks=normal_tasks,
242+
tasks=tasks_to_mark_success,
229243
run_id=run_id,
230244
state=TaskInstanceState.SUCCESS,
231245
commit=commit,

airflow-core/tests/unit/api/common/test_mark_tasks.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
from typing import TYPE_CHECKING
2020

2121
import pytest
22+
from sqlalchemy import select
2223

2324
from airflow.api.common.mark_tasks import set_dag_run_state_to_failed, set_dag_run_state_to_success
25+
from airflow.models.dagrun import DagRun
2426
from airflow.providers.standard.operators.empty import EmptyOperator
25-
from airflow.utils.state import TaskInstanceState
27+
from airflow.utils.state import DagRunState, State, TaskInstanceState
2628

2729
if TYPE_CHECKING:
2830
from airflow.models.taskinstance import TaskInstance
@@ -54,23 +56,63 @@ def test_set_dag_run_state_to_failed(dag_maker: DagMaker):
5456
assert "teardown" not in task_dict
5557

5658

57-
def test_set_dag_run_state_to_success(dag_maker: DagMaker):
59+
@pytest.mark.parametrize(
60+
"unfinished_state", sorted([state for state in State.unfinished if state is not None])
61+
)
62+
def test_set_dag_run_state_to_success_unfinished_teardown(dag_maker: DagMaker, unfinished_state):
5863
with dag_maker("TEST_DAG_1"):
5964
with EmptyOperator(task_id="teardown").as_teardown():
6065
EmptyOperator(task_id="running")
6166
EmptyOperator(task_id="pending")
67+
6268
dr = dag_maker.create_dagrun()
6369
for ti in dr.get_task_instances():
6470
if ti.task_id == "running":
6571
ti.set_state(TaskInstanceState.RUNNING)
72+
if ti.task_id == "teardown":
73+
ti.set_state(unfinished_state)
74+
6675
dag_maker.session.flush()
6776
assert dr.dag
77+
assert dr.state == DagRunState.RUNNING
6878

6979
updated_tis: list[TaskInstance] = set_dag_run_state_to_success(
7080
dag=dr.dag, run_id=dr.run_id, commit=True, session=dag_maker.session
7181
)
82+
run = dag_maker.session.scalar(select(DagRun).filter_by(dag_id=dr.dag_id, run_id=dr.run_id))
83+
assert run.state != DagRunState.SUCCESS
7284
assert len(updated_tis) == 2
7385
task_dict = {ti.task_id: ti for ti in updated_tis}
7486
assert task_dict["running"].state == TaskInstanceState.SUCCESS
7587
assert task_dict["pending"].state == TaskInstanceState.SUCCESS
7688
assert "teardown" not in task_dict
89+
90+
91+
@pytest.mark.parametrize("finished_state", sorted(list(State.finished)))
92+
def test_set_dag_run_state_to_success_finished_teardown(dag_maker: DagMaker, finished_state):
93+
with dag_maker("TEST_DAG_1"):
94+
with EmptyOperator(task_id="teardown").as_teardown():
95+
EmptyOperator(task_id="failed")
96+
dr = dag_maker.create_dagrun()
97+
for ti in dr.get_task_instances():
98+
if ti.task_id == "failed":
99+
ti.set_state(TaskInstanceState.FAILED)
100+
if ti.task_id == "teardown":
101+
ti.set_state(finished_state)
102+
dag_maker.session.flush()
103+
dr.set_state(DagRunState.FAILED)
104+
assert dr.dag
105+
106+
updated_tis: list[TaskInstance] = set_dag_run_state_to_success(
107+
dag=dr.dag, run_id=dr.run_id, commit=True, session=dag_maker.session
108+
)
109+
run = dag_maker.session.scalar(select(DagRun).filter_by(dag_id=dr.dag_id, run_id=dr.run_id))
110+
assert run.state == DagRunState.SUCCESS
111+
if finished_state == TaskInstanceState.SUCCESS:
112+
assert len(updated_tis) == 1
113+
else:
114+
assert len(updated_tis) == 2
115+
task_dict = {ti.task_id: ti for ti in updated_tis}
116+
assert task_dict["failed"].state == TaskInstanceState.SUCCESS
117+
if finished_state != TaskInstanceState.SUCCESS:
118+
assert task_dict["teardown"].state == TaskInstanceState.SUCCESS

0 commit comments

Comments
 (0)