Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(utils/decorators): rewrite remove task decorator to use cst #43383

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 38 additions & 36 deletions airflow/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,54 +18,56 @@
from __future__ import annotations

import sys
from collections import deque
from typing import Callable, TypeVar

import libcst as cst

T = TypeVar("T", bound=Callable)


class _TaskDecoratorRemover(cst.CSTTransformer):
    """
    CST transformer that strips task-related decorators from function definitions.

    Besides the task decorator named at construction time, ``setup``, ``teardown``,
    ``task.skip_if`` and ``task.run_if`` are always removed. Every other decorator
    is kept on the function.

    :param task_decorator_name: name of the task decorator to remove; any ``@``
        characters at either end are stripped (e.g. ``"@task.virtualenv"``)
    """

    def __init__(self, task_decorator_name: str) -> None:
        # Let the libcst base classes initialise their own state first.
        super().__init__()
        self.decorators_to_remove: set[str] = {
            "setup",
            "teardown",
            "task.skip_if",
            "task.run_if",
            task_decorator_name.strip("@"),
        }

    def _is_task_decorator(self, decorator_node: cst.Decorator) -> bool:
        """
        Tell whether ``decorator_node`` is one of the decorators to remove.

        Only bare names (``@setup``) and single-level dotted names
        (``@task.virtualenv``) are matched, with or without a call
        (``@task.virtualenv(...)``); anything else is reported as not matching.
        """
        expression = decorator_node.decorator
        # ``@name(...)`` is matched on the called expression, so unwrap calls
        # directly instead of fabricating an intermediate Decorator node.
        while isinstance(expression, cst.Call):
            expression = expression.func

        if isinstance(expression, cst.Name):
            return expression.value in self.decorators_to_remove
        if isinstance(expression, cst.Attribute) and isinstance(expression.value, cst.Name):
            dotted_name = f"{expression.value.value}.{expression.attr.value}"
            return dotted_name in self.decorators_to_remove
        return False

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Return the function definition without its task-related decorators."""
        kept_decorators = [dec for dec in updated_node.decorators if not self._is_task_decorator(dec)]
        # Nothing matched: hand back the node unchanged rather than building a copy.
        if len(kept_decorators) == len(updated_node.decorators):
            return updated_node
        return updated_node.with_changes(decorators=kept_decorators)


def remove_task_decorator(python_source: str, task_decorator_name: str) -> str:
    """
    Remove @task or similar decorators as well as @setup, @teardown, @task.skip_if and @task.run_if.

    The source is parsed into a concrete syntax tree, the matching decorators are
    dropped from the function definitions, and the tree is rendered back to source
    code. Decorators are therefore matched structurally rather than by raw text, so
    a decorator name that merely appears inside a comment is left alone.

    :param python_source: The python source code
    :param task_decorator_name: the decorator name
    :return: the source code with the matching decorators removed
    :raises libcst.ParserSyntaxError: if ``python_source`` is not valid Python
    """
    source_tree = cst.parse_module(python_source)
    modified_tree = source_tree.visit(_TaskDecoratorRemover(task_decorator_name))
    return modified_tree.code


class _autostacklevel_warn:
Expand Down
1 change: 1 addition & 0 deletions hatch_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@
"jinja2>=3.0.0",
"jsonschema>=4.18.0",
"lazy-object-proxy>=1.2.0",
"libcst >=1.1.0",
"linkify-it-py>=2.0.0",
"lockfile>=0.12.2",
"markdown-it-py>=2.1.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

from pathlib import Path
from textwrap import dedent
from unittest import mock

import pytest
Expand Down Expand Up @@ -191,26 +192,29 @@ def test_should_create_virtualenv_with_extra_packages_uv(self, mock_execute_in_s
["uv", "pip", "install", "--python", "/VENV/bin/python", "apache-beam[gcp]"]
)

@pytest.mark.parametrize(
    "decorators, expected_decorators",
    [
        (["@task.virtualenv"], []),
        (["@task.virtualenv()"], []),
        (['@task.virtualenv(serializer="dill")'], []),
        (["@foo", "@task.virtualenv", "@bar"], ["@foo", "@bar"]),
        (["@foo", "@task.virtualenv()", "@bar"], ["@foo", "@bar"]),
    ],
    ids=["without_parens", "parens", "with_args", "nested_without_parens", "nested_with_parens"],
)
def test_remove_task_decorator(self, decorators: list[str], expected_decorators: list[str]):
    """Only the task decorator is removed; other decorators and the body stay as written."""
    decorator = "\n".join(decorators)
    expected_decorator = "\n".join(expected_decorators)
    # The body carries a comment that mentions the decorator name: it must
    # survive untouched, since only real decorators are to be removed.
    SCRIPT = dedent(
        """
        def f():
            # @task.virtualenv
            import funcsigs
        """
    )
    py_source = decorator + SCRIPT
    # With no decorator left, the output starts directly at ``def f():``.
    expected_source = expected_decorator + SCRIPT if expected_decorator else SCRIPT.lstrip()

    res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
    assert res == expected_source
43 changes: 27 additions & 16 deletions tests/utils/test_preexisting_python_virtualenv_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,36 @@
# under the License.
from __future__ import annotations

from airflow.utils.decorators import remove_task_decorator
from textwrap import dedent

import pytest

class TestExternalPythonDecorator:
    """Checks that ``remove_task_decorator`` strips ``@task.external_python`` from a source."""

    @pytest.mark.parametrize(
        "decorators, expected_decorators",
        [
            (["@task.external_python"], []),
            (["@task.external_python()"], []),
            (['@task.external_python(serializer="dill")'], []),
            (["@foo", "@task.external_python", "@bar"], ["@foo", "@bar"]),
            (["@foo", "@task.external_python()", "@bar"], ["@foo", "@bar"]),
        ],
        ids=["without_parens", "parens", "with_args", "nested_without_parens", "nested_with_parens"],
    )
    def test_remove_task_decorator(self, decorators: list[str], expected_decorators: list[str]):
        """Only the task decorator is removed; other decorators and the body stay as written."""
        decorator = "\n".join(decorators)
        expected_decorator = "\n".join(expected_decorators)
        SCRIPT = dedent(
            """
            def f():
                import funcsigs
            """
        )
        py_source = decorator + SCRIPT
        # With no decorator left, the output starts directly at ``def f():``.
        expected_source = expected_decorator + SCRIPT if expected_decorator else SCRIPT.lstrip()

        res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
        assert res == expected_source