Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,18 @@ jobs:
- "Django>=4.2,<5.0"
- "Django>=5.0,<5.1"
- "Django>=5.1,<5.2"
- "Django==5.2a1"
- "Django>=5.2,<6.0"
# - "https://github.com/django/django/archive/main.tar.gz"
include:
- drf: djangorestframework
python-version: "3.12"
django-version: "Django<5.2,>=5.0" # must be different from django-version
django-version: "Django<6.0,>=5.2" # must be different from django-version
exclude:
- django-version: "Django>=5.0,<5.1"
python-version: 3.9
- django-version: "Django>=5.1,<5.2"
python-version: 3.9
- django-version: "Django==5.2a1"
- django-version: "Django>=5.2,<6.0"
python-version: 3.9
# - django-version: "https://github.com/django/django/archive/main.tar.gz"
# python-version: 3.8
Expand Down
12 changes: 9 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.1
- repo: local
hooks:
- id: ruff
name: ruff
entry: ruff check --force-exclude --fix --exit-non-zero-on-fix
language: system
args: [--fix, --exit-non-zero-on-fix]
types_or: [python]
require_serial: true
- id: ruff-format
name: ruff-format
entry: ruff format --force-exclude
language: system
types: [python]
require_serial: true

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ dev = [
"pre-commit",
"ruff",
{include-group = "test"},
"pyright>=1.1.402",
"ty>=0.0.1a13",
]
test = [
"pytest",
Expand Down
19 changes: 10 additions & 9 deletions src/extra_checks/ast/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
)

from django.db import models
from django.db.models.fields import Field
from django.db.models.fields.related import RelatedField
from django.utils.functional import SimpleLazyObject

from extra_checks.check_id import CheckId
Expand Down Expand Up @@ -50,7 +52,7 @@ def _parse(self, predicate: Optional[Callable[[ast.AST], bool]] = None) -> None:
try:
for node in self._nodes:
if predicate and predicate(node):
self._meta = cast(ast.ClassDef, node)
self._meta = cast("ast.ClassDef", node)
break
if isinstance(node, ast.Assign):
self._assignment_nodes.append(node)
Expand Down Expand Up @@ -86,13 +88,14 @@ def _assignments(self) -> dict[str, ast.Assign]:
return result

@cached_property
def field_nodes(self) -> Iterable[tuple[models.fields.Field, "FieldAST"]]:
def field_nodes(self) -> Iterable[tuple[Field, "FieldAST"]]:
for field in self.model_cls._meta.get_fields(include_parents=False):
if isinstance(field, models.Field):
if isinstance(field, Field):
yield (
field,
cast(
FieldAST, SimpleLazyObject(partial(get_field_ast, self, field))
"FieldAST",
SimpleLazyObject(partial(get_field_ast, self, field)),
),
)

Expand All @@ -111,7 +114,7 @@ def is_disabled_by_comment(self, check_id: str) -> bool:
return check in self._source_provider.get_disabled_checks_for_line(1)


def get_field_ast(model_ast: ModelAST, field: models.Field) -> "FieldAST":
def get_field_ast(model_ast: ModelAST, field: Field) -> "FieldAST":
try:
return FieldAST(
model_ast._assignments[field.name], field, model_ast._source_provider
Expand Down Expand Up @@ -141,9 +144,7 @@ def get_call_first_args(self) -> str:


class FieldAST(DisableCommentProtocol, FieldASTProtocol):
def __init__(
self, node: ast.Assign, field: models.Field, source_provider: SourceProvider
):
def __init__(self, node: ast.Assign, field: Field, source_provider: SourceProvider):
self._node = node
self._field = field
self._source_provider = source_provider
Expand All @@ -166,7 +167,7 @@ def _verbose_name(self) -> Union[None, ast.Constant, ast.Call]:
result = getattr(self._kwargs.get("verbose_name"), "value", None)
if result:
return result
if isinstance(self._field, models.fields.related.RelatedField):
if isinstance(self._field, RelatedField):
return None
if self._args:
node = self._args[0]
Expand Down
4 changes: 2 additions & 2 deletions src/extra_checks/ast/protocols.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Iterable
from typing import Any, Optional, Protocol

from django.db import models
from django.db.models.fields import Field


class ArgASTProtocol(Protocol):
Expand All @@ -22,7 +22,7 @@ class ModelASTProtocol(Protocol):
@property
def field_nodes(
self,
) -> Iterable[tuple[models.fields.Field, "FieldASTDisableCommentProtocol"]]: ...
) -> Iterable[tuple[Field, "FieldASTDisableCommentProtocol"]]: ...

def has_meta_var(self, name: str) -> bool: ...

Expand Down
2 changes: 1 addition & 1 deletion src/extra_checks/check_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def find_check(cls, value: str) -> Optional["CheckId"]:
except ValueError:
pass
try:
return cast(CheckId, cls._member_map_[value])
return cast("CheckId", cls._member_map_[value])
except KeyError:
pass
return None
Expand Down
7 changes: 3 additions & 4 deletions src/extra_checks/checks/drf_serializer_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)

import django.core.checks
from django.apps import apps
from rest_framework.serializers import ModelSerializer, Serializer

from ..ast.protocols import DisableCommentProtocol
Expand Down Expand Up @@ -70,9 +71,7 @@ def _filter_app_serializers(
) -> Iterator[type[Serializer]]:
site_prefixes = set(site.PREFIXES)
if include_apps is not None:
app_paths = {
a.path for a in django.apps.apps.get_app_configs() if a.name in include_apps
}
app_paths = {a.path for a in apps.get_app_configs() if a.name in include_apps}
for s in serializers:
module = importlib.import_module(s.__module__)
if any(
Expand Down Expand Up @@ -104,7 +103,7 @@ def _get_serializers_to_check(
)
return (
serializer_classes,
cast(Iterator[type[ModelSerializer]], model_serializer_classes),
cast("Iterator[type[ModelSerializer]]", model_serializer_classes), # ty: ignore[redundant-cast]
)


Expand Down
6 changes: 3 additions & 3 deletions src/extra_checks/checks/model_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ def _get_models_to_check(
app_configs: Optional[list[Any]] = None,
include_apps: Optional[Iterable[str]] = None,
) -> Iterator[type[models.Model]]:
apps = django.apps.apps.get_app_configs() if app_configs is None else app_configs
apps_ = apps.get_app_configs() if app_configs is None else app_configs
if include_apps is not None:
for app in apps:
for app in apps_:
if app.name in include_apps:
yield from app.get_models()
return
for app in apps:
for app in apps_:
if not any(app.path.startswith(path) for path in set(site.PREFIXES)):
yield from app.get_models()

Expand Down
30 changes: 16 additions & 14 deletions src/extra_checks/checks/model_field_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import django.core.checks
from django import forms
from django.db import models
from django.db.models.fields import Field
from django.db.models.fields.related import RelatedField

from .. import CheckId
from ..ast import FieldASTProtocol, MissingASTError
Expand All @@ -19,7 +21,7 @@ class CheckModelField(BaseCheck):
@abstractmethod
def apply(
self,
field: models.fields.Field,
field: Field,
*,
ast: FieldASTProtocol,
model: type[models.Model],
Expand Down Expand Up @@ -56,7 +58,7 @@ class CheckFieldVerboseName(CheckModelField):
Id = CheckId.X050

def apply(
self, field: models.fields.Field, ast: FieldASTProtocol, **kwargs: Any
self, field: Field, ast: FieldASTProtocol, **kwargs: Any
) -> Iterator[django.core.checks.CheckMessage]:
if not ast.get_arg("verbose_name"):
yield self.message(
Expand All @@ -71,7 +73,7 @@ class CheckFieldVerboseNameGettext(GetTextMixin, CheckModelField):
Id = CheckId.X051

def apply(
self, field: models.fields.Field, ast: FieldASTProtocol, **kwargs: Any
self, field: Field, ast: FieldASTProtocol, **kwargs: Any
) -> Iterator[django.core.checks.CheckMessage]:
verbose_name = ast.get_arg("verbose_name")
if verbose_name and not (
Expand All @@ -98,7 +100,7 @@ def is_invalid(cls, value: object) -> bool:
)

def apply(
self, field: models.fields.Field, ast: FieldASTProtocol, **kwargs: Any
self, field: Field, ast: FieldASTProtocol, **kwargs: Any
) -> Iterator[django.core.checks.CheckMessage]:
verbose_name = ast.get_arg("verbose_name")
if verbose_name and (
Expand All @@ -119,7 +121,7 @@ class CheckFieldHelpTextGettext(GetTextMixin, CheckModelField):
Id = CheckId.X053

def apply(
self, field: models.fields.Field, ast: FieldASTProtocol, **kwargs: Any
self, field: Field, ast: FieldASTProtocol, **kwargs: Any
) -> Iterator[django.core.checks.CheckMessage]:
help_text = ast.get_arg("help_text")
if help_text and not (
Expand All @@ -137,7 +139,7 @@ class CheckFieldFileUploadTo(CheckModelField):
Id = CheckId.X054

def apply(
self, field: models.fields.Field, **kwargs: Any
self, field: Field, **kwargs: Any
) -> Iterator[django.core.checks.CheckMessage]:
if isinstance(field, models.FileField):
if not field.upload_to:
Expand All @@ -153,7 +155,7 @@ class CheckFieldTextNull(CheckModelField):
Id = CheckId.X055

def apply(
self, field: models.fields.Field, **kwargs: Any
self, field: Field, **kwargs: Any
) -> Iterator[django.core.checks.CheckMessage]:
if isinstance(field, (models.CharField, models.TextField)):
if field.null:
Expand All @@ -170,7 +172,7 @@ class CheckFieldNullFalse(CheckModelField):
Id = CheckId.X057

def apply(
self, field: models.fields.Field, ast: FieldASTProtocol, **kwargs: Any
self, field: Field, ast: FieldASTProtocol, **kwargs: Any
) -> Iterator[django.core.checks.CheckMessage]:
if field.null is False and ast.get_arg("null"):
yield self.message(
Expand Down Expand Up @@ -218,11 +220,11 @@ def get_fields_with_indexes_in_meta(

def apply(
self,
field: models.fields.Field,
field: Field,
ast: FieldASTProtocol,
model: type[models.Model],
) -> Iterator[django.core.checks.CheckMessage]:
if isinstance(field, models.fields.related.RelatedField):
if isinstance(field, RelatedField):
if field.many_to_one and not ast.get_arg("db_index"):
if self.when == "indexes":
if field.name in self.get_fields_with_indexes_in_meta(model):
Expand All @@ -245,11 +247,11 @@ class CheckFieldRelatedName(CheckModelField):

def apply(
self,
field: models.fields.Field,
field: Field,
ast: FieldASTProtocol,
model: type[models.Model],
) -> Iterator[django.core.checks.CheckMessage]:
if isinstance(field, models.fields.related.RelatedField):
if isinstance(field, RelatedField):
if not field.remote_field.related_name:
yield self.message(
"Related fields must set `related_name` explicitly.",
Expand All @@ -263,7 +265,7 @@ class CheckFieldDefaultNull(CheckModelField):
Id = CheckId.X059

def apply(
self, field: models.fields.Field, ast: FieldASTProtocol, **kwargs: Any
self, field: Field, ast: FieldASTProtocol, **kwargs: Any
) -> Iterator[django.core.checks.CheckMessage]:
if field.null and field.default is None and ast.get_arg("default"):
yield self.message(
Expand All @@ -285,7 +287,7 @@ def _repr_choice(value: Any) -> str:

def apply(
self,
field: models.fields.Field,
field: Field,
ast: FieldASTProtocol,
model: type[models.Model],
) -> Iterator[django.core.checks.CheckMessage]:
Expand Down
Loading