|
2 | 2 | import site |
3 | 3 | from abc import abstractmethod |
4 | 4 | from typing import ( |
| 5 | + TYPE_CHECKING, |
5 | 6 | Any, |
6 | 7 | Iterable, |
7 | 8 | Iterator, |
|
17 | 18 | import django.core.checks |
18 | 19 | from rest_framework.serializers import ModelSerializer, Serializer |
19 | 20 |
|
20 | | -from .. import CheckId |
| 21 | +from ..ast.protocols import DisableCommentProtocol |
| 22 | +from ..ast.source_provider import SourceProvider |
| 23 | +from ..check_id import DRF_META_CHECKS_NAMES, CheckId |
21 | 24 | from ..forms import AttrsForm |
22 | 25 | from ..registry import ChecksConfig, registry |
23 | 26 | from .base_checks import BaseCheck |
24 | 27 |
|
| 28 | +if TYPE_CHECKING: |
| 29 | + cached_property = property |
| 30 | +else: |
| 31 | + from django.utils.functional import cached_property |
| 32 | + |
| 33 | + |
| 34 | +class DisableCommentProvider(DisableCommentProtocol): |
| 35 | + def __init__(self, serializer_class: Type[Serializer]): |
| 36 | + self.serializer_class = serializer_class |
| 37 | + |
| 38 | + @cached_property |
| 39 | + def _source_provider(self) -> SourceProvider: |
| 40 | + return SourceProvider(self.serializer_class) |
| 41 | + |
| 42 | + def is_disabled_by_comment(self, check_id: str) -> bool: |
| 43 | + check = CheckId.find_check(check_id) |
| 44 | + if check in DRF_META_CHECKS_NAMES: |
| 45 | + lines = (self._source_provider.source or "").splitlines() |
| 46 | + # find line starting with `class Meta` and lowest indent |
| 47 | + try: |
| 48 | + lineno, _ = sorted( |
| 49 | + [ |
| 50 | + (i, line) |
| 51 | + for i, line in enumerate(lines, 1) |
| 52 | + if line.strip().startswith(("class Meta(", "class Meta:")) |
| 53 | + ], |
| 54 | + key=lambda a: a[1].find("class Meta"), |
| 55 | + )[0] |
| 56 | + except StopIteration: |
| 57 | + return False |
| 58 | + return check in self._source_provider.get_disabled_checks_for_line(lineno) |
| 59 | + return check in self._source_provider.get_disabled_checks_for_line(1) |
| 60 | + |
25 | 61 |
|
26 | 62 | def _collect_serializers( |
27 | 63 | serializers: Iterable[Type[Serializer]], |
@@ -90,10 +126,10 @@ def check_drf_serializers( |
90 | 126 | s_classes, m_classes = _get_serializers_to_check(config.include_apps) |
91 | 127 | for s in s_classes: |
92 | 128 | for check in serializer_checks: |
93 | | - yield from check(s, None) |
| 129 | + yield from check(s, DisableCommentProvider(s)) |
94 | 130 | for s in m_classes: |
95 | 131 | for check in model_serializer_checks: |
96 | | - yield from check(s, None) |
| 132 | + yield from check(s, DisableCommentProvider(s)) |
97 | 133 |
|
98 | 134 |
|
99 | 135 | class CheckDRFSerializer(BaseCheck): |
|
0 commit comments