Skip to content

Commit 3015eff

Browse files
committed
Make clear_task_instances keyword-only and make dag required
Making dag required helps me in apache#50040. It's always passed in all production code, so we can make the change. The function should be thought of as private anyway. As for making it keyword-only: well, why not.
1 parent 59d592f commit 3015eff

File tree

9 files changed

+37
-30
lines changed

9 files changed

+37
-30
lines changed

airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -719,10 +719,10 @@ def post_clear_task_instances(
719719

720720
if not dry_run:
721721
clear_task_instances(
722-
task_instances,
723-
session,
724-
dag,
725-
DagRunState.QUEUED if reset_dag_runs else False,
722+
tis=task_instances,
723+
session=session,
724+
dag=dag,
725+
dag_run_state=DagRunState.QUEUED if reset_dag_runs else False,
726726
)
727727

728728
return TaskInstanceCollectionResponse(

airflow-core/src/airflow/models/baseoperator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def clear(
381381
# definition code
382382
assert isinstance(self.dag, SchedulerDAG)
383383

384-
clear_task_instances(results, session, dag=self.dag)
384+
clear_task_instances(tis=results, session=session, dag=self.dag)
385385
session.commit()
386386
return count
387387

airflow-core/src/airflow/models/dag.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1524,8 +1524,8 @@ def clear(
15241524

15251525
if do_it:
15261526
clear_task_instances(
1527-
list(tis),
1528-
session,
1527+
tis=list(tis),
1528+
session=session,
15291529
dag=self,
15301530
dag_run_state=dag_run_state,
15311531
)

airflow-core/src/airflow/models/taskinstance.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,10 @@ def _stop_remaining_tasks(*, task_instance: TaskInstance, task_teardown_map=None
253253

254254

255255
def clear_task_instances(
256+
*,
256257
tis: list[TaskInstance],
257258
session: Session,
258-
dag: DAG | None = None,
259+
dag: DAG,
259260
dag_run_state: DagRunState | Literal[False] = DagRunState.QUEUED,
260261
) -> None:
261262
"""

airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -1744,7 +1744,8 @@ def test_should_respond_200_with_different_try_numbers(self, test_client, try_nu
17441744
def test_should_respond_200_with_mapped_task_at_different_try_numbers(
17451745
self, test_client, try_number, session
17461746
):
1747-
tis = self.create_task_instances(session, task_instances=[{"state": State.FAILED}])
1747+
dag_id = "example_python_operator"
1748+
tis = self.create_task_instances(session, dag_id=dag_id, task_instances=[{"state": State.FAILED}])
17481749
old_ti = tis[0]
17491750
for idx in (1, 2):
17501751
ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx)
@@ -1758,7 +1759,8 @@ def test_should_respond_200_with_mapped_task_at_different_try_numbers(
17581759
# Record the task instance history
17591760
from airflow.models.taskinstance import clear_task_instances
17601761

1761-
clear_task_instances(tis, session)
1762+
dag = self.dagbag.get_dag(dag_id)
1763+
clear_task_instances(tis=tis, dag=dag, session=session)
17621764
# Simulate the try_number increasing to new values in TI
17631765
for ti in tis:
17641766
if ti.map_index > 0:
@@ -2890,7 +2892,9 @@ def test_ti_in_retry_state_not_returned(self, test_client, session):
28902892
}
28912893

28922894
def test_mapped_task_should_respond_200(self, test_client, session):
2893-
tis = self.create_task_instances(session, task_instances=[{"state": State.FAILED}])
2895+
dag_id = "example_python_operator"
2896+
dag = self.dagbag.get_dag(dag_id)
2897+
tis = self.create_task_instances(session, dag_id=dag_id, task_instances=[{"state": State.FAILED}])
28942898
old_ti = tis[0]
28952899
for idx in (1, 2):
28962900
ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx)
@@ -2904,7 +2908,7 @@ def test_mapped_task_should_respond_200(self, test_client, session):
29042908
# Record the task instance history
29052909
from airflow.models.taskinstance import clear_task_instances
29062910

2907-
clear_task_instances(tis, session)
2911+
clear_task_instances(tis=tis, dag=dag, session=session)
29082912
# Simulate the try_number increasing to new values in TI
29092913
for ti in tis:
29102914
if ti.map_index > 0:

airflow-core/tests/unit/models/test_cleartasks.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def test_clear_task_instances(self, dag_maker):
8585
# but it works for our case because we specifically constructed test DAGS
8686
# in the way that those two sort methods are equivalent
8787
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
88-
clear_task_instances(qry, session, dag=dag)
88+
clear_task_instances(tis=qry, session=session, dag=dag)
8989

9090
ti0.refresh_from_db(session)
9191
ti1.refresh_from_db(session)
@@ -119,7 +119,7 @@ def test_clear_task_instances_external_executor_id(self, dag_maker):
119119
# but it works for our case because we specifically constructed test DAGS
120120
# in the way that those two sort methods are equivalent
121121
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
122-
clear_task_instances(qry, session, dag=dag)
122+
clear_task_instances(tis=qry, session=session, dag=dag)
123123

124124
ti0.refresh_from_db()
125125

@@ -142,7 +142,7 @@ def test_clear_task_instances_next_method(self, dag_maker, session):
142142
session.add(ti0)
143143
session.commit()
144144

145-
clear_task_instances([ti0], session, dag=dag)
145+
clear_task_instances(tis=[ti0], session=session, dag=dag)
146146

147147
ti0.refresh_from_db()
148148

@@ -184,7 +184,7 @@ def test_clear_task_instances_dr_state(self, state, last_scheduling, dag_maker):
184184
# in the way that those two sort methods are equivalent
185185
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
186186
assert session.query(TaskInstanceHistory).count() == 0
187-
clear_task_instances(qry, session, dag_run_state=state, dag=dag)
187+
clear_task_instances(tis=qry, session=session, dag_run_state=state, dag=dag)
188188
session.flush()
189189
# 2 TIs were cleared so 2 history records should be created
190190
assert session.query(TaskInstanceHistory).count() == 2
@@ -226,7 +226,7 @@ def test_clear_task_instances_on_running_dr(self, state, dag_maker):
226226
# but it works for our case because we specifically constructed test DAGS
227227
# in the way that those two sort methods are equivalent
228228
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
229-
clear_task_instances(qry, session, dag=dag)
229+
clear_task_instances(tis=qry, session=session, dag=dag)
230230
session.flush()
231231

232232
session.refresh(dr)
@@ -277,7 +277,7 @@ def test_clear_task_instances_on_finished_dr(self, state, last_scheduling, dag_m
277277
# but it works for our case because we specifically constructed test DAGS
278278
# in the way that those two sort methods are equivalent
279279
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
280-
clear_task_instances(qry, session, dag=dag)
280+
clear_task_instances(tis=qry, session=session, dag=dag)
281281
session.flush()
282282

283283
session.refresh(dr)
@@ -328,7 +328,7 @@ def test_clear_task_instances_without_task(self, dag_maker):
328328
# but it works for our case because we specifically constructed test DAGS
329329
# in the way that those two sort methods are equivalent
330330
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
331-
clear_task_instances(qry, session, dag=dag)
331+
clear_task_instances(tis=qry, session=session, dag=dag)
332332

333333
# When no task is found, max_tries will be maximum of original max_tries or try_number.
334334
ti0.refresh_from_db()
@@ -376,7 +376,7 @@ def test_clear_task_instances_without_dag(self, dag_maker):
376376
# but it works for our case because we specifically constructed test DAGS
377377
# in the way that those two sort methods are equivalent
378378
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
379-
clear_task_instances(qry, session)
379+
clear_task_instances(tis=qry, session=session, dag=dag)
380380

381381
# When no DAG is found, max_tries will be maximum of original max_tries or try_number.
382382
ti0.refresh_from_db()
@@ -426,7 +426,7 @@ def test_clear_task_instances_without_dag_param(self, dag_maker, session):
426426
# but it works for our case because we specifically constructed test DAGS
427427
# in the way that those two sort methods are equivalent
428428
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
429-
clear_task_instances(qry, session)
429+
clear_task_instances(tis=qry, session=session, dag=dag)
430430

431431
ti0.refresh_from_db(session=session)
432432
ti1.refresh_from_db(session=session)
@@ -486,7 +486,7 @@ def test_clear_task_instances_in_multiple_dags(self, dag_maker, session):
486486
ti1.run(session=session)
487487

488488
qry = session.query(TI).filter(TI.dag_id.in_((dag0.dag_id, dag1.dag_id))).all()
489-
clear_task_instances(qry, session, dag=dag0)
489+
clear_task_instances(tis=qry, session=session, dag=dag0)
490490

491491
ti0.refresh_from_db(session=session)
492492
ti1.refresh_from_db(session=session)
@@ -545,7 +545,7 @@ def count_task_reschedule(ti):
545545
.order_by(TI.task_id)
546546
.all()
547547
)
548-
clear_task_instances(qry, session, dag=dag)
548+
clear_task_instances(tis=qry, session=session, dag=dag)
549549
assert count_task_reschedule(ti0) == 0
550550
assert count_task_reschedule(ti1) == 1
551551

@@ -586,7 +586,7 @@ def test_task_instance_history_record(self, state, state_recorded, dag_maker):
586586
session = dag_maker.session
587587
session.flush()
588588
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
589-
clear_task_instances(qry, session, dag=dag)
589+
clear_task_instances(tis=qry, session=session, dag=dag)
590590
session.flush()
591591

592592
session.refresh(dr)

airflow-core/tests/unit/models/test_dagrun.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def test_clear_task_instances_for_backfill_running_dagrun(self, dag_maker, sessi
145145
self.create_dag_run(dag, logical_date=now, is_backfill=True, state=state, session=session)
146146

147147
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
148-
clear_task_instances(qry, session)
148+
clear_task_instances(tis=qry, session=session, dag=dag)
149149
session.flush()
150150
dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.logical_date == now).first()
151151
assert dr0.state == state
@@ -159,8 +159,8 @@ def test_clear_task_instances_for_backfill_finished_dagrun(self, dag_maker, stat
159159
EmptyOperator(task_id="backfill_task_0")
160160
self.create_dag_run(dag, logical_date=now, is_backfill=True, state=state, session=session)
161161

162-
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
163-
clear_task_instances(qry, session)
162+
tis = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
163+
clear_task_instances(tis=tis, session=session, dag=dag)
164164
session.flush()
165165
dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.logical_date == now).first()
166166
assert dr0.state == DagRunState.QUEUED

providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def _clear_task_instances(
118118
log.debug("task_ids %s to clear", str(task_ids))
119119
dr: DagRun = _get_dagrun(dag, run_id, session=session)
120120
tis_to_clear = [ti for ti in dr.get_task_instances() if ti.databricks_task_key in task_ids]
121-
clear_task_instances(tis_to_clear, session)
121+
clear_task_instances(tis=tis_to_clear, session=session, dag=dag)
122122

123123

124124
def _repair_task(

providers/standard/tests/unit/standard/operators/test_python.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,9 @@ def test_clear_skipped_downstream_task(self):
787787
tis = dr.get_task_instances()
788788
with create_session() as session:
789789
clear_task_instances(
790-
[ti for ti in tis if ti.task_id == "op1"], session=session, dag=short_circuit.dag
790+
tis=[ti for ti in tis if ti.task_id == "op1"],
791+
session=session,
792+
dag=short_circuit.dag,
791793
)
792794
self.op1.run(start_date=self.default_date, end_date=self.default_date)
793795
self.assert_expected_task_states(dr, expected_states)
@@ -1742,7 +1744,7 @@ def f():
17421744
tis = dr.get_task_instances()
17431745
children_tis = [ti for ti in tis if ti.task_id in branch_op.get_direct_relative_ids()]
17441746
with create_session() as session:
1745-
clear_task_instances(children_tis, session=session, dag=branch_op.dag)
1747+
clear_task_instances(tis=children_tis, session=session, dag=branch_op.dag)
17461748

17471749
# Run the cleared tasks again.
17481750
for task in branches:

0 commit comments

Comments
 (0)