|
19 | 19 | from typing import TYPE_CHECKING |
20 | 20 |
|
21 | 21 | import pytest |
| 22 | +from sqlalchemy import select |
22 | 23 |
|
23 | 24 | from airflow.api.common.mark_tasks import set_dag_run_state_to_failed, set_dag_run_state_to_success |
| 25 | +from airflow.models.dagrun import DagRun |
24 | 26 | from airflow.providers.standard.operators.empty import EmptyOperator |
25 | | -from airflow.utils.state import TaskInstanceState |
| 27 | +from airflow.utils.state import DagRunState, State, TaskInstanceState |
26 | 28 |
|
27 | 29 | if TYPE_CHECKING: |
28 | 30 | from airflow.models.taskinstance import TaskInstance |
@@ -54,23 +56,63 @@ def test_set_dag_run_state_to_failed(dag_maker: DagMaker): |
54 | 56 | assert "teardown" not in task_dict |
55 | 57 |
|
56 | 58 |
|
57 | | -def test_set_dag_run_state_to_success(dag_maker: DagMaker): |
| 59 | +@pytest.mark.parametrize( |
| 60 | + "unfinished_state", sorted([state for state in State.unfinished if state is not None]) |
| 61 | +) |
| 62 | +def test_set_dag_run_state_to_success_unfinished_teardown(dag_maker: DagMaker, unfinished_state): |
58 | 63 | with dag_maker("TEST_DAG_1"): |
59 | 64 | with EmptyOperator(task_id="teardown").as_teardown(): |
60 | 65 | EmptyOperator(task_id="running") |
61 | 66 | EmptyOperator(task_id="pending") |
| 67 | + |
62 | 68 | dr = dag_maker.create_dagrun() |
63 | 69 | for ti in dr.get_task_instances(): |
64 | 70 | if ti.task_id == "running": |
65 | 71 | ti.set_state(TaskInstanceState.RUNNING) |
| 72 | + if ti.task_id == "teardown": |
| 73 | + ti.set_state(unfinished_state) |
| 74 | + |
66 | 75 | dag_maker.session.flush() |
67 | 76 | assert dr.dag |
| 77 | + assert dr.state == DagRunState.RUNNING |
68 | 78 |
|
69 | 79 | updated_tis: list[TaskInstance] = set_dag_run_state_to_success( |
70 | 80 | dag=dr.dag, run_id=dr.run_id, commit=True, session=dag_maker.session |
71 | 81 | ) |
| 82 | + run = dag_maker.session.scalar(select(DagRun).filter_by(dag_id=dr.dag_id, run_id=dr.run_id)) |
| 83 | + assert run.state != DagRunState.SUCCESS |
72 | 84 | assert len(updated_tis) == 2 |
73 | 85 | task_dict = {ti.task_id: ti for ti in updated_tis} |
74 | 86 | assert task_dict["running"].state == TaskInstanceState.SUCCESS |
75 | 87 | assert task_dict["pending"].state == TaskInstanceState.SUCCESS |
76 | 88 | assert "teardown" not in task_dict |
| 89 | + |
| 90 | + |
| 91 | +@pytest.mark.parametrize("finished_state", sorted(list(State.finished))) |
| 92 | +def test_set_dag_run_state_to_success_finished_teardown(dag_maker: DagMaker, finished_state): |
| 93 | + with dag_maker("TEST_DAG_1"): |
| 94 | + with EmptyOperator(task_id="teardown").as_teardown(): |
| 95 | + EmptyOperator(task_id="failed") |
| 96 | + dr = dag_maker.create_dagrun() |
| 97 | + for ti in dr.get_task_instances(): |
| 98 | + if ti.task_id == "failed": |
| 99 | + ti.set_state(TaskInstanceState.FAILED) |
| 100 | + if ti.task_id == "teardown": |
| 101 | + ti.set_state(finished_state) |
| 102 | + dag_maker.session.flush() |
| 103 | + dr.set_state(DagRunState.FAILED) |
| 104 | + assert dr.dag |
| 105 | + |
| 106 | + updated_tis: list[TaskInstance] = set_dag_run_state_to_success( |
| 107 | + dag=dr.dag, run_id=dr.run_id, commit=True, session=dag_maker.session |
| 108 | + ) |
| 109 | + run = dag_maker.session.scalar(select(DagRun).filter_by(dag_id=dr.dag_id, run_id=dr.run_id)) |
| 110 | + assert run.state == DagRunState.SUCCESS |
| 111 | + if finished_state == TaskInstanceState.SUCCESS: |
| 112 | + assert len(updated_tis) == 1 |
| 113 | + else: |
| 114 | + assert len(updated_tis) == 2 |
| 115 | + task_dict = {ti.task_id: ti for ti in updated_tis} |
| 116 | + assert task_dict["failed"].state == TaskInstanceState.SUCCESS |
| 117 | + if finished_state != TaskInstanceState.SUCCESS: |
| 118 | + assert task_dict["teardown"].state == TaskInstanceState.SUCCESS |
0 commit comments