Skip to content

Commit be9640d

Browse files
committed
fix tests
1 parent 638bd31 commit be9640d

File tree

3 files changed

+14
-3
lines changed

3 files changed

+14
-3
lines changed

airflow-core/src/airflow/models/taskinstance.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def clear_task_instances(
305305
session.merge(ti)
306306

307307
if dag_run_state is not False and tis:
308+
from airflow.models.dag import DagModel
308309
from airflow.models.dagrun import DagRun # Avoid circular import
309310

310311
run_ids_by_dag_id = defaultdict(set)
@@ -326,8 +327,19 @@ def clear_task_instances(
326327
if dr.state in State.finished_dr_states:
327328
dr.state = dag_run_state
328329
dr.start_date = timezone.utcnow()
330+
if TYPE_CHECKING:
331+
assert dag # todo: change signature so this is required
329332
if not dag.disable_bundle_versioning:
330-
dr.bundle_version = dr.dag_model.bundle_version
333+
if dr.dag_model:
334+
bundle_version = dr.dag_model.bundle_version
335+
else:
336+
bundle_version = session.scalar(
337+
select(DagModel.bundle_version).where(
338+
DagModel.dag_id == dag.dag_id,
339+
)
340+
)
341+
if bundle_version is not None:
342+
dr.bundle_version = bundle_version
331343
if dag_run_state == DagRunState.QUEUED:
332344
dr.last_scheduling_decision = None
333345
dr.start_date = None

airflow-core/tests/unit/models/test_dag.py

-1
Original file line numberDiff line numberDiff line change
@@ -1426,7 +1426,6 @@ def test_clear_set_dagrun_state(self, dag_run_state):
14261426
dag_run_state=dag_run_state,
14271427
session=session,
14281428
)
1429-
14301429
dagruns = session.query(DagRun).filter(DagRun.dag_id == dag_id).all()
14311430

14321431
assert len(dagruns) == 1

airflow-core/tests/unit/models/test_dagrun.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def test_clear_task_instances_for_backfill_finished_dagrun(self, dag_maker, stat
160160
self.create_dag_run(dag, logical_date=now, is_backfill=True, state=state, session=session)
161161

162162
qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
163-
clear_task_instances(qry, session)
163+
clear_task_instances(qry, session, dag=dag)
164164
session.flush()
165165
dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.logical_date == now).first()
166166
assert dr0.state == DagRunState.QUEUED

0 commit comments

Comments
 (0)