3 changes: 3 additions & 0 deletions airflow-core/newsfragments/60268.improvement.rst
@@ -0,0 +1,3 @@
The ``PythonOperator`` parameter ``python_callable`` now also supports async callables.

``PythonOperator`` now supports async callables in Airflow 3.2, allowing users to run async def functions without manually managing an event loop.
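To make the newsfragment concrete, here is an illustrative snippet (not part of the diff; the DAG id, task id, and callable are invented): on Airflow 3.2+ an ``async def`` function can be passed straight to ``PythonOperator`` as ``python_callable`` and the operator drives the event loop itself.

import asyncio

from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk import DAG


async def fetch_greeting() -> str:
    # Stand-in for an awaitable call (HTTP request, async DB query, ...).
    await asyncio.sleep(1)
    return "hello"


with DAG(dag_id="async_python_example") as dag:
    # No asyncio.run() or manual event-loop management inside the callable.
    greet = PythonOperator(task_id="greet", python_callable=fetch_greeting)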
4 changes: 1 addition & 3 deletions airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -780,8 +780,6 @@ class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]):
factory=lambda: TypeAdapter(ToTriggerRunner), repr=False
)

_lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False)

def _read_frame(self):
from asgiref.sync import async_to_sync

@@ -816,7 +814,7 @@ async def asend(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None:
frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
bytes = frame.as_bytes()

async with self._lock:
async with self._async_lock:
self._async_writer.write(bytes)

return await self._aget_response(frame.id)
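The hunk above drops the decoder's own ``_lock`` field and wraps the write in ``self._async_lock``, which is defined outside this hunk, presumably so every coroutine writing to the stream contends on one shared lock. A minimal, self-contained sketch of that pattern (the ``FrameSender`` class below is hypothetical, not Airflow code):

import asyncio


class FrameSender:
    """Toy illustration: a single lock per writer so frames never interleave."""

    def __init__(self, writer: asyncio.StreamWriter) -> None:
        self._writer = writer
        self._async_lock = asyncio.Lock()

    async def send(self, payload: bytes) -> None:
        # Only one coroutine may write (and drain) at a time; the rest wait here.
        async with self._async_lock:
            self._writer.write(payload)
            await self._writer.drain()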
@@ -218,7 +218,7 @@ def test_should_respond_200(self, test_client, session):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -376,7 +376,7 @@ def test_should_respond_200_with_task_state_in_deferred(self, test_client, sessi
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -440,7 +440,7 @@ def test_should_respond_200_with_task_state_in_removed(self, test_client, sessio
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -496,7 +496,7 @@ def test_should_respond_200_task_instance_with_rendered(self, test_client, sessi
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -616,7 +616,7 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, test_client, se
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -1408,7 +1408,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
False,
"/dags/~/dagRuns/~/taskInstances",
{"dag_id_pattern": "example_python_operator"},
9, # Based on test failure - example_python_operator creates 9 task instances
14, # Based on test failure - example_python_operator creates 14 task instances
3,
id="test dag_id_pattern exact match",
),
@@ -1417,7 +1417,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
False,
"/dags/~/dagRuns/~/taskInstances",
{"dag_id_pattern": "example_%"},
17, # Based on test failure - both DAGs together create 17 task instances
22, # Based on test failure - both DAGs together create 22 task instances
3,
id="test dag_id_pattern wildcard prefix",
),
@@ -1931,8 +1931,8 @@ def test_should_respond_200_when_task_instance_properties_are_none(
[
pytest.param(
{"dag_ids": ["example_python_operator", "example_skip_dag"]},
17,
17,
22,
22,
id="with dag filter",
),
],
@@ -2041,7 +2041,7 @@ def test_should_respond_200_for_pagination(self, test_client, session):
assert len(response_batch2.json()["task_instances"]) > 0

# Match
ti_count = 9
ti_count = 10
assert response_batch1.json()["total_entries"] == response_batch2.json()["total_entries"] == ti_count
assert (num_entries_batch1 + num_entries_batch2) == ti_count
assert response_batch1 != response_batch2
@@ -2080,7 +2080,7 @@ def test_should_respond_200(self, test_client, session):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -2127,7 +2127,7 @@ def test_should_respond_200_with_different_try_numbers(self, test_client, try_nu
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -2205,7 +2205,7 @@ def test_should_respond_200_with_mapped_task_at_different_try_numbers(
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -2278,7 +2278,7 @@ def test_should_respond_200_with_task_state_in_deferred(self, test_client, sessi
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -2326,7 +2326,7 @@ def test_should_respond_200_with_task_state_in_removed(self, test_client, sessio
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -3243,7 +3243,7 @@ def test_should_respond_200_with_dag_run_id(
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -3615,7 +3615,7 @@ def test_should_respond_200(self, test_client, session):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -3653,7 +3653,7 @@ def test_should_respond_200(self, test_client, session):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -3801,7 +3801,7 @@ def test_ti_in_retry_state_not_returned(self, test_client, session):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -3885,7 +3885,7 @@ def test_mapped_task_should_respond_200(self, test_client, session):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -3923,7 +3923,7 @@ def test_mapped_task_should_respond_200(self, test_client, session):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4165,7 +4165,7 @@ def test_should_call_mocked_api(self, mock_set_ti_state, test_client, session):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4439,7 +4439,7 @@ def test_should_raise_422_for_invalid_task_instance_state(self, payload, expecte
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4573,7 +4573,7 @@ def test_update_mask_set_note_should_respond_200(
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4634,7 +4634,7 @@ def test_set_note_should_respond_200(self, test_client, session):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4713,7 +4713,7 @@ def test_set_note_should_respond_200_mapped_task_with_rtif(self, test_client, se
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4794,7 +4794,7 @@ def test_set_note_should_respond_200_mapped_task_summary_with_rtif(self, test_cl
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -4912,7 +4912,7 @@ def test_should_call_mocked_api(self, mock_set_ti_state, test_client, session):
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -5198,7 +5198,7 @@ def test_should_raise_422_for_invalid_task_instance_state(self, payload, expecte
"pid": 100,
"pool": "default_pool",
"pool_slots": 1,
"priority_weight": 9,
"priority_weight": 14,
"queue": "default_queue",
"queued_when": None,
"scheduled_when": None,
@@ -17,18 +17,74 @@

from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.providers.common.compat._compat_utils import create_module_getattr
from airflow.providers.common.compat.version_compat import (
    AIRFLOW_V_3_0_PLUS,
    AIRFLOW_V_3_1_PLUS,
    AIRFLOW_V_3_2_PLUS,
)

_IMPORT_MAP: dict[str, str | tuple[str, ...]] = {
    # Re-export from sdk (which handles Airflow 2.x/3.x fallbacks)
    "BaseOperator": "airflow.providers.common.compat.sdk",
    "BaseAsyncOperator": "airflow.providers.common.compat.sdk",
    "get_current_context": "airflow.providers.common.compat.sdk",
    "is_async_callable": "airflow.providers.common.compat.sdk",
    # Standard provider items with direct fallbacks
    "PythonOperator": ("airflow.providers.standard.operators.python", "airflow.operators.python"),
    "ShortCircuitOperator": ("airflow.providers.standard.operators.python", "airflow.operators.python"),
    "_SERIALIZERS": ("airflow.providers.standard.operators.python", "airflow.operators.python"),
}

if TYPE_CHECKING:
    from airflow.sdk.bases.decorator import is_async_callable
    from airflow.sdk.bases.operator import BaseAsyncOperator
elif AIRFLOW_V_3_2_PLUS:
    from airflow.sdk.bases.decorator import is_async_callable
    from airflow.sdk.bases.operator import BaseAsyncOperator
else:
    if AIRFLOW_V_3_0_PLUS:
        from airflow.sdk import BaseOperator
    else:
        from airflow.models import BaseOperator

    def is_async_callable(func) -> bool:
        """Detect if a callable is an async function."""
        import inspect
        from functools import partial

        while isinstance(func, partial):
            func = func.func
        return inspect.iscoroutinefunction(func)

    class BaseAsyncOperator(BaseOperator):
        """Stub for Airflow < 3.2 that raises a clear error."""

        @property
        def is_async(self) -> bool:
            return True

        if not AIRFLOW_V_3_1_PLUS:

            @property
            def xcom_push(self) -> bool:
                return self.do_xcom_push

            @xcom_push.setter
            def xcom_push(self, value: bool):
                self.do_xcom_push = value

        async def aexecute(self, context):
            raise NotImplementedError()

        def execute(self, context):
            raise RuntimeError(
                "Async operators require Airflow 3.2+. Upgrade Airflow or use a synchronous callable."
            )


__getattr__ = create_module_getattr(import_map=_IMPORT_MAP)

__all__ = sorted(_IMPORT_MAP.keys())
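The file above resolves its public names lazily through ``create_module_getattr`` from ``_compat_utils``, whose implementation is not part of this diff. A plausible sketch of such a helper (a PEP 562 module ``__getattr__`` that tries each candidate module in ``_IMPORT_MAP`` order) is shown below purely for orientation; the real helper may differ:

from __future__ import annotations

import importlib


def create_module_getattr(import_map: dict[str, str | tuple[str, ...]]):
    """Build a module-level __getattr__ that imports names on first access."""

    def __getattr__(name: str):
        try:
            sources = import_map[name]
        except KeyError:
            raise AttributeError(name) from None
        candidates = (sources,) if isinstance(sources, str) else sources
        # The first module that actually exposes the requested name wins.
        for module_path in candidates:
            try:
                return getattr(importlib.import_module(module_path), name)
            except (ImportError, AttributeError):
                continue
        raise AttributeError(name)

    return __getattr__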
@@ -34,11 +34,13 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:

AIRFLOW_V_3_0_PLUS: bool = get_base_airflow_version_tuple() >= (3, 0, 0)
AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)
AIRFLOW_V_3_2_PLUS: bool = get_base_airflow_version_tuple() >= (3, 2, 0)

# BaseOperator removed from version_compat to avoid circular imports
# Import it directly in files that need it instead

__all__ = [
"AIRFLOW_V_3_0_PLUS",
"AIRFLOW_V_3_1_PLUS",
"AIRFLOW_V_3_2_PLUS",
]
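``get_base_airflow_version_tuple`` sits just above this hunk and is not shown in the diff; provider ``version_compat`` modules typically derive it from ``airflow.__version__`` via ``packaging``. A sketch under that assumption:

from packaging.version import Version

from airflow import __version__ as airflow_version


def get_base_airflow_version_tuple() -> tuple[int, int, int]:
    v = Version(airflow_version)
    return v.major, v.minor, v.micro


# The flag added by this diff is then a plain tuple comparison:
AIRFLOW_V_3_2_PLUS: bool = get_base_airflow_version_tuple() >= (3, 2, 0)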