diff --git a/src/griffe_warnings_deprecated/extension.py b/src/griffe_warnings_deprecated/extension.py index 201f1ce..0295956 100644 --- a/src/griffe_warnings_deprecated/extension.py +++ b/src/griffe_warnings_deprecated/extension.py @@ -2,9 +2,10 @@ from __future__ import annotations +import ast from typing import Any -from griffe import Class, Docstring, DocstringSectionAdmonition, Extension, Function, get_logger +from griffe import Class, Docstring, DocstringSectionAdmonition, ExprCall, Extension, Function, get_logger logger = get_logger(__name__) self_namespace = "griffe_warnings_deprecated" @@ -15,8 +16,13 @@ def _deprecated(obj: Class | Function) -> str | None: for decorator in obj.decorators: - if decorator.callable_path in _decorators: - return str(decorator.value).split("(", 1)[1].rstrip(")").rsplit(",", 1)[0].lstrip("f")[1:-1] + if decorator.callable_path in _decorators and isinstance(decorator.value, ExprCall): + first_arg = decorator.value.arguments[0] + try: + return ast.literal_eval(first_arg) # type: ignore[arg-type] + except ValueError: + logger.debug("%s is not a static string", str(first_arg)) + return None return None diff --git a/tests/test_extension.py b/tests/test_extension.py index 39d38ea..c28ee6c 100644 --- a/tests/test_extension.py +++ b/tests/test_extension.py @@ -2,6 +2,7 @@ from __future__ import annotations +import logging from textwrap import dedent import pytest @@ -37,10 +38,6 @@ def hello(): ... def hello(): ... """, """ - @warnings.deprecated(f"message", category=DeprecationWarning) - def hello(): ... - """, - """ @warnings.deprecated("message", category=DeprecationWarning) def hello(): '''Summary.''' @@ -85,3 +82,23 @@ def test_extension(code: str) -> None: assert adm.title == "Deprecated" assert adm.value.kind == "danger" assert adm.value.contents == "message" + + +def test_extension_fstring(caplog: pytest.LogCaptureFixture) -> None: + """Test the extension with an f-string as the deprecation message.""" + code = dedent( + """ + import warnings + @warnings.deprecated(f"message") + def hello(): ... + """, + ) + with ( + caplog.at_level(logging.DEBUG), + temporary_visited_module(code, extensions=load_extensions(WarningsDeprecatedExtension)) as module, + ): + adm = module["hello"].docstring + + # Expect no deprecation message in the docstring. + assert adm is None + assert "f'message' is not a static string" in caplog.records[0].message