diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
index f46facc3450e7..30a973757ad83 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
@@ -719,10 +719,10 @@ def post_clear_task_instances(
 
     if not dry_run:
         clear_task_instances(
-            task_instances,
-            session,
-            dag,
-            DagRunState.QUEUED if reset_dag_runs else False,
+            tis=task_instances,
+            session=session,
+            dag=dag,
+            dag_run_state=DagRunState.QUEUED if reset_dag_runs else False,
         )
 
     return TaskInstanceCollectionResponse(
diff --git a/airflow-core/src/airflow/models/baseoperator.py b/airflow-core/src/airflow/models/baseoperator.py
index 9b57d3d5251e1..9433fad7da078 100644
--- a/airflow-core/src/airflow/models/baseoperator.py
+++ b/airflow-core/src/airflow/models/baseoperator.py
@@ -381,7 +381,7 @@ def clear(
         # definition code
         assert isinstance(self.dag, SchedulerDAG)
 
-        clear_task_instances(results, session, dag=self.dag)
+        clear_task_instances(tis=results, session=session, dag=self.dag)
         session.commit()
 
         return count
diff --git a/airflow-core/src/airflow/models/dag.py b/airflow-core/src/airflow/models/dag.py
index aa9ad78fbcce7..00a020f3315f4 100644
--- a/airflow-core/src/airflow/models/dag.py
+++ b/airflow-core/src/airflow/models/dag.py
@@ -1524,8 +1524,8 @@ def clear(
 
         if do_it:
             clear_task_instances(
-                list(tis),
-                session,
+                tis=list(tis),
+                session=session,
                 dag=self,
                 dag_run_state=dag_run_state,
             )
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index 52c1d7d5e8ea9..3ba47e2c67213 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -253,9 +253,10 @@ def _stop_remaining_tasks(*, task_instance: TaskInstance, task_teardown_map=None
 
 
 def clear_task_instances(
+    *,
     tis: list[TaskInstance],
     session: Session,
-    dag: DAG | None = None,
+    dag: DAG,
     dag_run_state: DagRunState | Literal[False] = DagRunState.QUEUED,
 ) -> None:
     """
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
index af2d23014061b..a3208b05e02ff 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -1744,7 +1744,8 @@ def test_should_respond_200_with_different_try_numbers(self, test_client, try_nu
     def test_should_respond_200_with_mapped_task_at_different_try_numbers(
         self, test_client, try_number, session
     ):
-        tis = self.create_task_instances(session, task_instances=[{"state": State.FAILED}])
+        dag_id = "example_python_operator"
+        tis = self.create_task_instances(session, dag_id=dag_id, task_instances=[{"state": State.FAILED}])
         old_ti = tis[0]
         for idx in (1, 2):
             ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx)
@@ -1758,7 +1759,8 @@ def test_should_respond_200_with_mapped_task_at_different_try_numbers(
         # Record the task instance history
         from airflow.models.taskinstance import clear_task_instances
 
-        clear_task_instances(tis, session)
+        dag = self.dagbag.get_dag(dag_id)
+        clear_task_instances(tis=tis, dag=dag, session=session)
         # Simulate the try_number increasing to new values in TI
         for ti in tis:
             if ti.map_index > 0:
@@ -2890,7 +2892,9 @@ def test_ti_in_retry_state_not_returned(self, test_client, session):
         }
 
     def test_mapped_task_should_respond_200(self, test_client, session):
-        tis = self.create_task_instances(session, task_instances=[{"state": State.FAILED}])
+        dag_id = "example_python_operator"
+        dag = self.dagbag.get_dag(dag_id)
+        tis = self.create_task_instances(session, dag_id=dag_id, task_instances=[{"state": State.FAILED}])
         old_ti = tis[0]
         for idx in (1, 2):
             ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx)
@@ -2904,7 +2908,7 @@ def test_mapped_task_should_respond_200(self, test_client, session):
         # Record the task instance history
         from airflow.models.taskinstance import clear_task_instances
 
-        clear_task_instances(tis, session)
+        clear_task_instances(tis=tis, dag=dag, session=session)
         # Simulate the try_number increasing to new values in TI
         for ti in tis:
             if ti.map_index > 0:
diff --git a/airflow-core/tests/unit/models/test_cleartasks.py b/airflow-core/tests/unit/models/test_cleartasks.py
index 9bcae8cceba51..76a8ba2bcdef8 100644
--- a/airflow-core/tests/unit/models/test_cleartasks.py
+++ b/airflow-core/tests/unit/models/test_cleartasks.py
@@ -85,7 +85,7 @@ def test_clear_task_instances(self, dag_maker):
         # but it works for our case because we specifically constructed test DAGS
         # in the way that those two sort methods are equivalent
         qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
-        clear_task_instances(qry, session, dag=dag)
+        clear_task_instances(tis=qry, session=session, dag=dag)
 
         ti0.refresh_from_db(session)
         ti1.refresh_from_db(session)
@@ -119,7 +119,7 @@ def test_clear_task_instances_external_executor_id(self, dag_maker):
         # but it works for our case because we specifically constructed test DAGS
         # in the way that those two sort methods are equivalent
         qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
-        clear_task_instances(qry, session, dag=dag)
+        clear_task_instances(tis=qry, session=session, dag=dag)
 
         ti0.refresh_from_db()
 
@@ -142,7 +142,7 @@ def test_clear_task_instances_next_method(self, dag_maker, session):
         session.add(ti0)
         session.commit()
 
-        clear_task_instances([ti0], session, dag=dag)
+        clear_task_instances(tis=[ti0], session=session, dag=dag)
 
         ti0.refresh_from_db()
 
@@ -184,7 +184,7 @@ def test_clear_task_instances_dr_state(self, state, last_scheduling, dag_maker):
         # in the way that those two sort methods are equivalent
         qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
         assert session.query(TaskInstanceHistory).count() == 0
-        clear_task_instances(qry, session, dag_run_state=state, dag=dag)
+        clear_task_instances(tis=qry, session=session, dag_run_state=state, dag=dag)
         session.flush()
         # 2 TIs were cleared so 2 history records should be created
         assert session.query(TaskInstanceHistory).count() == 2
@@ -226,7 +226,7 @@ def test_clear_task_instances_on_running_dr(self, state, dag_maker):
         # but it works for our case because we specifically constructed test DAGS
         # in the way that those two sort methods are equivalent
         qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
-        clear_task_instances(qry, session, dag=dag)
+        clear_task_instances(tis=qry, session=session, dag=dag)
         session.flush()
 
         session.refresh(dr)
@@ -277,7 +277,7 @@ def test_clear_task_instances_on_finished_dr(self, state, last_scheduling, dag_m
         # but it works for our case because we specifically constructed test DAGS
         # in the way that those two sort methods are equivalent
         qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
-        clear_task_instances(qry, session, dag=dag)
+        clear_task_instances(tis=qry, session=session, dag=dag)
         session.flush()
 
         session.refresh(dr)
@@ -328,7 +328,7 @@ def test_clear_task_instances_without_task(self, dag_maker):
         # but it works for our case because we specifically constructed test DAGS
         # in the way that those two sort methods are equivalent
         qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
-        clear_task_instances(qry, session, dag=dag)
+        clear_task_instances(tis=qry, session=session, dag=dag)
 
         # When no task is found, max_tries will be maximum of original max_tries or try_number.
         ti0.refresh_from_db()
@@ -376,7 +376,7 @@ def test_clear_task_instances_without_dag(self, dag_maker):
         # but it works for our case because we specifically constructed test DAGS
         # in the way that those two sort methods are equivalent
         qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
-        clear_task_instances(qry, session)
+        clear_task_instances(tis=qry, session=session, dag=dag)
 
         # When no DAG is found, max_tries will be maximum of original max_tries or try_number.
         ti0.refresh_from_db()
@@ -426,7 +426,7 @@ def test_clear_task_instances_without_dag_param(self, dag_maker, session):
         # but it works for our case because we specifically constructed test DAGS
         # in the way that those two sort methods are equivalent
         qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
-        clear_task_instances(qry, session)
+        clear_task_instances(tis=qry, session=session, dag=dag)
 
         ti0.refresh_from_db(session=session)
         ti1.refresh_from_db(session=session)
@@ -486,7 +486,7 @@ def test_clear_task_instances_in_multiple_dags(self, dag_maker, session):
         ti1.run(session=session)
 
         qry = session.query(TI).filter(TI.dag_id.in_((dag0.dag_id, dag1.dag_id))).all()
-        clear_task_instances(qry, session, dag=dag0)
+        clear_task_instances(tis=qry, session=session, dag=dag0)
 
         ti0.refresh_from_db(session=session)
         ti1.refresh_from_db(session=session)
@@ -545,7 +545,7 @@ def count_task_reschedule(ti):
             .order_by(TI.task_id)
             .all()
         )
-        clear_task_instances(qry, session, dag=dag)
+        clear_task_instances(tis=qry, session=session, dag=dag)
 
         assert count_task_reschedule(ti0) == 0
         assert count_task_reschedule(ti1) == 1
@@ -586,7 +586,7 @@ def test_task_instance_history_record(self, state, state_recorded, dag_maker):
         session = dag_maker.session
         session.flush()
         qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
-        clear_task_instances(qry, session, dag=dag)
+        clear_task_instances(tis=qry, session=session, dag=dag)
         session.flush()
 
         session.refresh(dr)
diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py
index 9d76ffabc1cea..56425ae14d7be 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -145,7 +145,7 @@ def test_clear_task_instances_for_backfill_running_dagrun(self, dag_maker, sessi
         self.create_dag_run(dag, logical_date=now, is_backfill=True, state=state, session=session)
 
         qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
-        clear_task_instances(qry, session)
+        clear_task_instances(tis=qry, session=session, dag=dag)
         session.flush()
         dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.logical_date == now).first()
         assert dr0.state == state
@@ -159,8 +159,8 @@ def test_clear_task_instances_for_backfill_finished_dagrun(self, dag_maker, stat
             EmptyOperator(task_id="backfill_task_0")
         self.create_dag_run(dag, logical_date=now, is_backfill=True, state=state, session=session)
 
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
-        clear_task_instances(qry, session)
+        tis = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
+        clear_task_instances(tis=tis, session=session, dag=dag)
         session.flush()
         dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.logical_date == now).first()
         assert dr0.state == DagRunState.QUEUED
diff --git a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py
index 16ea7a0b6113e..c8ffba9ba7788 100644
--- a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py
+++ b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py
@@ -118,7 +118,7 @@ def _clear_task_instances(
     log.debug("task_ids %s to clear", str(task_ids))
    dr: DagRun = _get_dagrun(dag, run_id, session=session)
     tis_to_clear = [ti for ti in dr.get_task_instances() if ti.databricks_task_key in task_ids]
-    clear_task_instances(tis_to_clear, session)
+    clear_task_instances(tis=tis_to_clear, session=session, dag=dag)
 
 
 def _repair_task(
diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py
index 45a1e395e1a16..516ad05d565a7 100644
--- a/providers/standard/tests/unit/standard/operators/test_python.py
+++ b/providers/standard/tests/unit/standard/operators/test_python.py
@@ -787,7 +787,9 @@ def test_clear_skipped_downstream_task(self):
         tis = dr.get_task_instances()
         with create_session() as session:
             clear_task_instances(
-                [ti for ti in tis if ti.task_id == "op1"], session=session, dag=short_circuit.dag
+                tis=[ti for ti in tis if ti.task_id == "op1"],
+                session=session,
+                dag=short_circuit.dag,
             )
         self.op1.run(start_date=self.default_date, end_date=self.default_date)
         self.assert_expected_task_states(dr, expected_states)
@@ -1742,7 +1744,7 @@ def f():
         tis = dr.get_task_instances()
         children_tis = [ti for ti in tis if ti.task_id in branch_op.get_direct_relative_ids()]
         with create_session() as session:
-            clear_task_instances(children_tis, session=session, dag=branch_op.dag)
+            clear_task_instances(tis=children_tis, session=session, dag=branch_op.dag)
 
         # Run the cleared tasks again.
         for task in branches:
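
Note for reviewers: after this patch, `clear_task_instances` accepts keyword-only arguments and `dag` is mandatory. A minimal usage sketch of the new call site follows; the names `dag`, `session`, and `tis` are placeholders for an existing DAG, an open SQLAlchemy session, and a list of TaskInstance rows obtained elsewhere, and are not defined in this patch:

    from airflow.models.taskinstance import clear_task_instances
    from airflow.utils.state import DagRunState

    # All arguments are keyword-only; the old positional form
    # clear_task_instances(tis, session) now raises TypeError.
    clear_task_instances(
        tis=tis,                           # task instances selected by the caller
        session=session,                   # active SQLAlchemy session
        dag=dag,                           # required: the DAG owning the task instances
        dag_run_state=DagRunState.QUEUED,  # or False to leave the DagRun state untouched
    )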