Skip to content

Commit

Permalink
fixup! refactor(utils/decorators): rewrite remove task decorator to u…
Browse files Browse the repository at this point in the history
…se ast
  • Loading branch information
josix committed Feb 4, 2025
1 parent 272ee7a commit bc09e4b
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 28 deletions.
12 changes: 7 additions & 5 deletions airflow/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ def __init__(self, task_decorator_name: str) -> None:
def _is_task_decorator(self, decorator: cst.Decorator) -> bool:
if isinstance(decorator.decorator, cst.Name):
return decorator.decorator.value in self.decorators_to_remove
elif isinstance(decorator.decorator, cst.Attribute) and isinstance(decorator.decorator.value, cst.Name):
return (
f"{decorator.decorator.value.value}.{decorator.decorator.attr.value}"
in self.decorators_to_remove
)
elif isinstance(decorator.decorator, cst.Attribute) and isinstance(
decorator.decorator.value, cst.Name
):
return (
f"{decorator.decorator.value.value}.{decorator.decorator.attr.value}"
in self.decorators_to_remove
)
elif isinstance(decorator.decorator, cst.Call):
return self._is_task_decorator(cst.Decorator(decorator=decorator.decorator.func))
return False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

from pathlib import Path
from textwrap import dedent
from unittest import mock

import pytest
Expand Down Expand Up @@ -192,25 +193,77 @@ def test_should_create_virtualenv_with_extra_packages_uv(self, mock_execute_in_s
)

def test_remove_task_decorator(self):
py_source = '@task.virtualenv(serializer="dill")\ndef f():\n import funcsigs'
py_source = dedent(
"""
@task.virtualenv(serializer="dill")
def f():
import funcsigs
"""
)
expected_source = dedent(
"""
def f():
import funcsigs
"""
)
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
assert res == "def f():\n import funcsigs"
assert res == expected_source

def test_remove_decorator_no_parens(self):
py_source = "@task.virtualenv\ndef f():\n import funcsigs"
py_source = dedent(
"""
@task.virtualenv
def f():
import funcsigs
"""
)
expected_source = dedent(
"""
def f():
import funcsigs
"""
)
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
assert res == "def f():\n import funcsigs"
assert res == expected_source

def test_remove_decorator_including_comment(self):
py_source = "@task.virtualenv\ndef f():\n# @task.virtualenv\n import funcsigs"
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
assert res == "def f():\n# @task.virtualenv\n import funcsigs"
py_source = dedent(
"""
@task.virtualenv
def f():
# @task.virtualenv
import funcsigs
"""
)
expected_source = dedent(
"""
def f():
# @task.virtualenv
import funcsigs
"""
)

def test_remove_decorator_nested(self):
py_source = "@foo\n@task.virtualenv\n@bar\ndef f():\n import funcsigs"
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
assert res == "@foo\n@bar\ndef f():\n import funcsigs"
assert res == expected_source

py_source = "@foo\n@task.virtualenv()\n@bar\ndef f():\n import funcsigs"
@pytest.mark.parametrize("decorator", ["@task.virtualenv", "@task.virtualenv()"])
def test_remove_decorator_nested(self, decorator):
py_source = dedent(
f"""
@foo
{decorator}
@bar
def f():
import funcsigs
"""
)
expected_source = dedent(
"""
@foo
@bar
def f():
import funcsigs
"""
)
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
assert res == "@foo\n@bar\ndef f():\n import funcsigs"
assert res == expected_source
63 changes: 52 additions & 11 deletions tests/utils/test_preexisting_python_virtualenv_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,66 @@
# under the License.
from __future__ import annotations

from textwrap import dedent

import pytest

from airflow.utils.decorators import remove_task_decorator


class TestExternalPythonDecorator:
def test_remove_task_decorator(self):
py_source = '@task.external_python(serializer="dill")\ndef f():\n import funcsigs'
py_source = dedent(
"""
@task.external_python(serializer="dill")
def f():
import funcsigs
"""
)
expected_source = dedent(
"""
def f():
import funcsigs
"""
)
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
assert res == "def f():\n import funcsigs"
assert res == expected_source

def test_remove_decorator_no_parens(self):
py_source = "@task.external_python\ndef f():\n import funcsigs"
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
assert res == "def f():\n import funcsigs"

def test_remove_decorator_nested(self):
py_source = "@foo\n@task.external_python\n@bar\ndef f():\n import funcsigs"
py_source = dedent(
"""
@task.external_python
def f():
import funcsigs
"""
)
expected_source = dedent(
"""
def f():
import funcsigs
"""
)
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
assert res == "@foo\n@bar\ndef f():\n import funcsigs"
assert res == expected_source

py_source = "@foo\n@task.external_python()\n@bar\ndef f():\n import funcsigs"
@pytest.mark.parametrize("decorator", ["@task.external_python", "@task.external_python()"])
def test_remove_decorator_nested(self, decorator):
py_source = dedent(
f"""
@foo
{decorator}
@bar
def f():
import funcsigs
"""
)
expected_source = dedent(
"""
@foo
@bar
def f():
import funcsigs
"""
)
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
assert res == "@foo\n@bar\ndef f():\n import funcsigs"
assert res == expected_source

0 comments on commit bc09e4b

Please sign in to comment.