Skip to content

Make clear_task_instances kwarg only and dag required #50058

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
Wants to merge 1 commit into base branch: main (choose a base branch) from the contributor's branch.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -719,10 +719,10 @@ def post_clear_task_instances(

if not dry_run:
clear_task_instances(
task_instances,
session,
dag,
DagRunState.QUEUED if reset_dag_runs else False,
tis=task_instances,
session=session,
dag=dag,
dag_run_state=DagRunState.QUEUED if reset_dag_runs else False,
)

return TaskInstanceCollectionResponse(
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def clear(
# definition code
assert isinstance(self.dag, SchedulerDAG)

clear_task_instances(results, session, dag=self.dag)
clear_task_instances(tis=results, session=session, dag=self.dag)
session.commit()
return count

Expand Down
4 changes: 2 additions & 2 deletions airflow-core/src/airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1524,8 +1524,8 @@ def clear(

if do_it:
clear_task_instances(
list(tis),
session,
tis=list(tis),
session=session,
dag=self,
dag_run_state=dag_run_state,
)
Expand Down
3 changes: 2 additions & 1 deletion airflow-core/src/airflow/models/taskinstance.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With dag being required, we should also get rid of the DagBag inside the function. This should speed things up quite a bit.

Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,10 @@ def _stop_remaining_tasks(*, task_instance: TaskInstance, task_teardown_map=None


def clear_task_instances(
*,
tis: list[TaskInstance],
session: Session,
dag: DAG | None = None,
dag: DAG,
dag_run_state: DagRunState | Literal[False] = DagRunState.QUEUED,
) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1744,7 +1744,8 @@ def test_should_respond_200_with_different_try_numbers(self, test_client, try_nu
def test_should_respond_200_with_mapped_task_at_different_try_numbers(
self, test_client, try_number, session
):
tis = self.create_task_instances(session, task_instances=[{"state": State.FAILED}])
dag_id = "example_python_operator"
tis = self.create_task_instances(session, dag_id=dag_id, task_instances=[{"state": State.FAILED}])
old_ti = tis[0]
for idx in (1, 2):
ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx)
Expand All @@ -1758,7 +1759,8 @@ def test_should_respond_200_with_mapped_task_at_different_try_numbers(
# Record the task instance history
from airflow.models.taskinstance import clear_task_instances

clear_task_instances(tis, session)
dag = self.dagbag.get_dag(dag_id)
clear_task_instances(tis=tis, dag=dag, session=session)
# Simulate the try_number increasing to new values in TI
for ti in tis:
if ti.map_index > 0:
Expand Down Expand Up @@ -2890,7 +2892,9 @@ def test_ti_in_retry_state_not_returned(self, test_client, session):
}

def test_mapped_task_should_respond_200(self, test_client, session):
tis = self.create_task_instances(session, task_instances=[{"state": State.FAILED}])
dag_id = "example_python_operator"
dag = self.dagbag.get_dag(dag_id)
tis = self.create_task_instances(session, dag_id=dag_id, task_instances=[{"state": State.FAILED}])
old_ti = tis[0]
for idx in (1, 2):
ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx)
Expand All @@ -2904,7 +2908,7 @@ def test_mapped_task_should_respond_200(self, test_client, session):
# Record the task instance history
from airflow.models.taskinstance import clear_task_instances

clear_task_instances(tis, session)
clear_task_instances(tis=tis, dag=dag, session=session)
# Simulate the try_number increasing to new values in TI
for ti in tis:
if ti.map_index > 0:
Expand Down
24 changes: 12 additions & 12 deletions airflow-core/tests/unit/models/test_cleartasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_clear_task_instances(self, dag_maker):
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
clear_task_instances(qry, session, dag=dag)
clear_task_instances(tis=qry, session=session, dag=dag)

ti0.refresh_from_db(session)
ti1.refresh_from_db(session)
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_clear_task_instances_external_executor_id(self, dag_maker):
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
clear_task_instances(qry, session, dag=dag)
clear_task_instances(tis=qry, session=session, dag=dag)

ti0.refresh_from_db()

Expand All @@ -142,7 +142,7 @@ def test_clear_task_instances_next_method(self, dag_maker, session):
session.add(ti0)
session.commit()

clear_task_instances([ti0], session, dag=dag)
clear_task_instances(tis=[ti0], session=session, dag=dag)

ti0.refresh_from_db()

Expand Down Expand Up @@ -184,7 +184,7 @@ def test_clear_task_instances_dr_state(self, state, last_scheduling, dag_maker):
# in the way that those two sort methods are equivalent
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
assert session.query(TaskInstanceHistory).count() == 0
clear_task_instances(qry, session, dag_run_state=state, dag=dag)
clear_task_instances(tis=qry, session=session, dag_run_state=state, dag=dag)
session.flush()
# 2 TIs were cleared so 2 history records should be created
assert session.query(TaskInstanceHistory).count() == 2
Expand Down Expand Up @@ -226,7 +226,7 @@ def test_clear_task_instances_on_running_dr(self, state, dag_maker):
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
clear_task_instances(qry, session, dag=dag)
clear_task_instances(tis=qry, session=session, dag=dag)
session.flush()

session.refresh(dr)
Expand Down Expand Up @@ -277,7 +277,7 @@ def test_clear_task_instances_on_finished_dr(self, state, last_scheduling, dag_m
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
clear_task_instances(qry, session, dag=dag)
clear_task_instances(tis=qry, session=session, dag=dag)
session.flush()

session.refresh(dr)
Expand Down Expand Up @@ -328,7 +328,7 @@ def test_clear_task_instances_without_task(self, dag_maker):
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
clear_task_instances(qry, session, dag=dag)
clear_task_instances(tis=qry, session=session, dag=dag)

# When no task is found, max_tries will be maximum of original max_tries or try_number.
ti0.refresh_from_db()
Expand Down Expand Up @@ -376,7 +376,7 @@ def test_clear_task_instances_without_dag(self, dag_maker):
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
clear_task_instances(qry, session)
clear_task_instances(tis=qry, session=session, dag=dag)

# When no DAG is found, max_tries will be maximum of original max_tries or try_number.
ti0.refresh_from_db()
Expand Down Expand Up @@ -426,7 +426,7 @@ def test_clear_task_instances_without_dag_param(self, dag_maker, session):
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
clear_task_instances(qry, session)
clear_task_instances(tis=qry, session=session, dag=dag)

ti0.refresh_from_db(session=session)
ti1.refresh_from_db(session=session)
Expand Down Expand Up @@ -486,7 +486,7 @@ def test_clear_task_instances_in_multiple_dags(self, dag_maker, session):
ti1.run(session=session)

qry = session.query(TI).filter(TI.dag_id.in_((dag0.dag_id, dag1.dag_id))).all()
clear_task_instances(qry, session, dag=dag0)
clear_task_instances(tis=qry, session=session, dag=dag0)

ti0.refresh_from_db(session=session)
ti1.refresh_from_db(session=session)
Expand Down Expand Up @@ -545,7 +545,7 @@ def count_task_reschedule(ti):
.order_by(TI.task_id)
.all()
)
clear_task_instances(qry, session, dag=dag)
clear_task_instances(tis=qry, session=session, dag=dag)
assert count_task_reschedule(ti0) == 0
assert count_task_reschedule(ti1) == 1

Expand Down Expand Up @@ -586,7 +586,7 @@ def test_task_instance_history_record(self, state, state_recorded, dag_maker):
session = dag_maker.session
session.flush()
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
clear_task_instances(qry, session, dag=dag)
clear_task_instances(tis=qry, session=session, dag=dag)
session.flush()

session.refresh(dr)
Expand Down
6 changes: 3 additions & 3 deletions airflow-core/tests/unit/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_clear_task_instances_for_backfill_running_dagrun(self, dag_maker, sessi
self.create_dag_run(dag, logical_date=now, is_backfill=True, state=state, session=session)

qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
clear_task_instances(qry, session)
clear_task_instances(tis=qry, session=session, dag=dag)
session.flush()
dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.logical_date == now).first()
assert dr0.state == state
Expand All @@ -159,8 +159,8 @@ def test_clear_task_instances_for_backfill_finished_dagrun(self, dag_maker, stat
EmptyOperator(task_id="backfill_task_0")
self.create_dag_run(dag, logical_date=now, is_backfill=True, state=state, session=session)

qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
clear_task_instances(qry, session)
tis = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
clear_task_instances(tis=tis, session=session, dag=dag)
session.flush()
dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.logical_date == now).first()
assert dr0.state == DagRunState.QUEUED
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def _clear_task_instances(
log.debug("task_ids %s to clear", str(task_ids))
dr: DagRun = _get_dagrun(dag, run_id, session=session)
tis_to_clear = [ti for ti in dr.get_task_instances() if ti.databricks_task_key in task_ids]
clear_task_instances(tis_to_clear, session)
clear_task_instances(tis=tis_to_clear, session=session, dag=dag)


def _repair_task(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,9 @@ def test_clear_skipped_downstream_task(self):
tis = dr.get_task_instances()
with create_session() as session:
clear_task_instances(
[ti for ti in tis if ti.task_id == "op1"], session=session, dag=short_circuit.dag
tis=[ti for ti in tis if ti.task_id == "op1"],
session=session,
dag=short_circuit.dag,
)
self.op1.run(start_date=self.default_date, end_date=self.default_date)
self.assert_expected_task_states(dr, expected_states)
Expand Down Expand Up @@ -1742,7 +1744,7 @@ def f():
tis = dr.get_task_instances()
children_tis = [ti for ti in tis if ti.task_id in branch_op.get_direct_relative_ids()]
with create_session() as session:
clear_task_instances(children_tis, session=session, dag=branch_op.dag)
clear_task_instances(tis=children_tis, session=session, dag=branch_op.dag)

# Run the cleared tasks again.
for task in branches:
Expand Down
Loading