Skip to content

Commit f829c3e

Browse files
dablakaxil
authored andcommitted
Update providers/common/compat/src/airflow/providers/common/compat/standard/operators.py
Co-authored-by: Kaxil Naik <[email protected]>
1 parent cdecb12 commit f829c3e

File tree

2 files changed

+14
-57
lines changed
  • providers

2 files changed

+14
-57
lines changed

providers/common/compat/src/airflow/providers/common/compat/standard/operators.py

Lines changed: 12 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -45,67 +45,22 @@
4545
from airflow.sdk.bases.decorator import is_async_callable
4646
from airflow.sdk.bases.operator import BaseAsyncOperator
4747
else:
48-
import inspect
49-
from collections.abc import Callable
50-
from contextlib import suppress
51-
from functools import partial
52-
5348
if AIRFLOW_V_3_0_PLUS:
5449
from airflow.sdk import BaseOperator
55-
from airflow.sdk.bases.decorator import _TaskDecorator
56-
from airflow.sdk.definitions.mappedoperator import OperatorPartial
5750
else:
58-
from airflow.decorators.base import _TaskDecorator
5951
from airflow.models import BaseOperator
60-
from airflow.models.mappedoperator import OperatorPartial
61-
62-
def unwrap_partial(fn: Callable) -> Callable:
63-
while isinstance(fn, partial):
64-
fn = fn.func
65-
return fn
66-
67-
def unwrap_callable(func):
68-
# Airflow-specific unwrap
69-
if isinstance(func, (_TaskDecorator, OperatorPartial)):
70-
func = getattr(func, "function", getattr(func, "_func", func))
71-
72-
# Unwrap functools.partial
73-
func = unwrap_partial(func)
74-
75-
# Unwrap @functools.wraps chains
76-
with suppress(Exception):
77-
func = inspect.unwrap(func)
78-
79-
return func
8052

81-
def is_async_callable(func):
82-
"""Detect if a callable (possibly wrapped) is an async function."""
83-
func = unwrap_callable(func)
53+
def is_async_callable(func) -> bool:
54+
"""Detect async callables. """
55+
import inspect
56+
from functools import partial
8457

85-
if not callable(func):
86-
return False
87-
88-
# Direct async function
89-
if inspect.iscoroutinefunction(func):
90-
return True
91-
92-
# Callable object with async __call__
93-
if not inspect.isfunction(func):
94-
call = type(func).__call__ # Bandit-safe
95-
with suppress(Exception):
96-
call = inspect.unwrap(call)
97-
if inspect.iscoroutinefunction(call):
98-
return True
99-
100-
return False
58+
while isinstance(func, partial):
59+
func = func.func
60+
return inspect.iscoroutinefunction(func)
10161

10262
class BaseAsyncOperator(BaseOperator):
103-
"""
104-
Base class for async-capable operators.
105-
106-
As opposed to deferred operators which are executed on the triggerer, async operators are executed
107-
on the worker.
108-
"""
63+
"""Stub for Airflow < 3.2 that raises a clear error."""
10964

11065
@property
11166
def is_async(self) -> bool:
@@ -122,12 +77,13 @@ def xcom_push(self, value: bool):
12277
self.do_xcom_push = value
12378

12479
async def aexecute(self, context):
125-
"""Async version of execute(). Subclasses should implement this."""
12680
raise NotImplementedError()
12781

12882
def execute(self, context):
129-
"""Run `aexecute()` inside an event loop."""
130-
raise NotImplementedError("Airflow 3.2+ is required to allow executing async operators!")
83+
raise RuntimeError(
84+
"Async operators require Airflow 3.2+. "
85+
"Upgrade Airflow or use a synchronous callable."
86+
)
13187

13288

13389
__getattr__ = create_module_getattr(import_map=_IMPORT_MAP)

providers/standard/tests/unit/standard/operators/test_python.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2480,7 +2480,8 @@ async def say_hello(name: str) -> str:
24802480
assert "Done. Returned value was: Hello world!" in caplog.messages
24812481
else:
24822482
with pytest.raises(
2483-
NotImplementedError, match=r"Airflow 3\.2\+ is required to allow executing async operators!"
2483+
NotImplementedError,
2484+
match=r"Async operators require Airflow 3\.2\+\. Upgrade Airflow or use a synchronous callable\.",
24842485
):
24852486
self.run_as_task(say_hello, op_kwargs={"name": "world"}, show_return_value_in_logs=True)
24862487

0 commit comments

Comments
 (0)