Skip to content

Commit ec682d4

Browse files
committed
Experimental support for autofix
1 parent a3c1968 commit ec682d4

File tree

11 files changed

+304
-24
lines changed

11 files changed

+304
-24
lines changed

setup.cfg

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ dev =
5252
pdbpp
5353
tox>=3,<4
5454
black==20.8b1
55+
libcst>=0.3,<0.4
56+
codemod =
57+
libcst>=0.3,<0.4
5558

5659
[flake8]
5760
max-line-length = 110
@@ -64,4 +67,4 @@ force_grid_wrap = 0
6467
use_parentheses = True
6568
line_length = 88
6669
known_first_party = extra_checks
67-
known_third_party = django,pytest,rest_framework,setuptools
70+
known_third_party = django,libcst,pytest,rest_framework,setuptools

shell.nix

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,5 +49,13 @@ devshell.mkShell {
4949
name = "app.lint";
5050
command = "pre-commit run -a";
5151
}
52+
{
53+
help = "run main tox env";
54+
name = "app.tox";
55+
command = ''
56+
unset PYTHONPATH;
57+
tox -e 'py{38}-django{22,30,31,32,32-drf,-latest},flake8,black,isort,manifest,mypy,check'
58+
'';
59+
}
5260
];
5361
}

src/extra_checks/checks/base_checks.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,23 @@
1515
}
1616

1717

18+
class ExtraCheckMessage(django.core.checks.CheckMessage):
19+
def __init__(
20+
self,
21+
level: int,
22+
msg: str,
23+
*,
24+
id: str,
25+
hint: Optional[str] = None,
26+
obj: Any = None,
27+
file: Optional[str] = None,
28+
fix: Any = None,
29+
) -> None:
30+
super().__init__(level, msg, hint=hint, obj=obj, id=id)
31+
self._file = file
32+
self._fix = fix
33+
34+
1835
class BaseCheck(ABC):
1936
Id: CheckId
2037
settings_form_class: ClassVar[Type[forms.BaseCheckForm]] = forms.BaseCheckForm
@@ -46,10 +63,21 @@ def is_ignored(self, obj: Any) -> bool:
4663
return obj in self.ignore_objects or type(obj) in self.ignore_types
4764

4865
def message(
49-
self, message: str, hint: Optional[str] = None, obj: Any = None
66+
self,
67+
message: str,
68+
hint: Optional[str] = None,
69+
obj: Any = None,
70+
file: Optional[str] = None,
71+
fix: Any = None,
5072
) -> django.core.checks.CheckMessage:
51-
return MESSAGE_MAP[self.level](
52-
message + f" [{self.Id.value}]", hint=hint, obj=obj, id=self.Id.name
73+
return ExtraCheckMessage(
74+
self.level,
75+
message + f" [{self.Id.value}]",
76+
hint=hint,
77+
obj=obj,
78+
id=self.Id.name,
79+
file=file,
80+
fix=fix,
5381
)
5482

5583

src/extra_checks/checks/model_field_checks.py

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -285,23 +285,50 @@ def apply(
285285
field: models.fields.Field,
286286
ast: FieldASTProtocol,
287287
model: Type[models.Model],
288+
**kwargs: Any,
288289
) -> Iterator[django.core.checks.CheckMessage]:
289290
choices = field.flatchoices # type: ignore
290-
if choices:
291-
field_choices = [c[0] for c in choices]
292-
if field.blank and "" not in field_choices:
293-
field_choices.append("")
294-
in_name = f"{field.name}__in"
295-
for constraint in model._meta.constraints:
296-
if isinstance(constraint, models.CheckConstraint):
297-
conditions = dict(constraint.check.children)
298-
if in_name in conditions and set(field_choices) == set(
299-
conditions[in_name]
300-
):
301-
return
302-
check = f'models.Q({in_name}=[{", ".join([self._repr_choice(c) for c in field_choices])}])'
303-
yield self.message(
304-
"Field with choices must have companion CheckConstraint to enforce choices on database level.",
305-
hint=f'Add to Meta.constraints: `models.CheckConstraint(name="%(app_label)s_%(class)s_{field.name}_valid", check={check})`',
306-
obj=field,
291+
if not choices:
292+
return
293+
field_choices = [c[0] for c in choices]
294+
if field.blank and "" not in field_choices:
295+
field_choices.append("")
296+
in_name = f"{field.name}__in"
297+
name = f"%(app_label)s_%(class)s_{field.name}_choices_valid"
298+
replace = False
299+
for constraint in model._meta.constraints:
300+
if isinstance(constraint, models.CheckConstraint):
301+
if name == constraint.name:
302+
replace = True
303+
conditions = dict(constraint.check.children)
304+
if in_name in conditions and set(field_choices) == set(
305+
conditions[in_name]
306+
):
307+
return
308+
check = f'models.Q({in_name}=[{", ".join([self._repr_choice(c) for c in field_choices])}])'
309+
kwargs = {}
310+
try:
311+
import importlib
312+
313+
from extra_checks.fixes.fix_choices_constraint import (
314+
gen_fix_for_choices_constraint,
307315
)
316+
except ImportError:
317+
pass
318+
else:
319+
kwargs = {
320+
"file": importlib.import_module(model.__module__).__file__,
321+
"fix": gen_fix_for_choices_constraint(
322+
model.__name__,
323+
f"%(app_label)s_%(class)s_{field.name}_choices_valid",
324+
check=check,
325+
replace=replace,
326+
),
327+
}
328+
329+
yield self.message(
330+
"Field with choices must have companion CheckConstraint to enforce choices on database level.",
331+
hint=f'Add to Meta.constraints: `models.CheckConstraint(name="{name}", check={check})`',
332+
obj=field,
333+
**kwargs,
334+
)

src/extra_checks/fixes/__init__.py

Whitespace-only changes.
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import libcst as cst
2+
from libcst import matchers as m
3+
4+
5+
def gen_fix_for_choices_constraint(
6+
class_name: str, name: str, check: str, replace: bool = False
7+
) -> m.MatcherDecoratableTransformer:
8+
class Fixes(m.MatcherDecoratableTransformer):
9+
def __init__(self) -> None:
10+
self.is_constraint_updated = False
11+
super().__init__()
12+
13+
@m.call_if_inside(m.ClassDef(m.Name(class_name)))
14+
@m.leave(m.ClassDef(m.Name("Meta")))
15+
def leave_meta(
16+
self, node: cst.ClassDef, updated_node: cst.ClassDef
17+
) -> cst.ClassDef:
18+
if not self.is_constraint_updated and not replace:
19+
exp = cst.parse_statement(
20+
f'constraints = [models.CheckConstraint(name="{name}", check={check})]'
21+
)
22+
lines = updated_node.body.body
23+
return updated_node.with_deep_changes(
24+
updated_node.body, body=[*lines, exp]
25+
)
26+
self.is_constraint_updated = False
27+
return updated_node
28+
29+
if replace:
30+
31+
@m.call_if_inside(m.ClassDef(m.Name(class_name)))
32+
@m.call_if_inside(m.ClassDef(m.Name("Meta")))
33+
@m.call_if_inside(
34+
m.Assign(targets=[m.AssignTarget(target=m.Name("constraints"))])
35+
)
36+
@m.leave(
37+
m.Call(
38+
func=m.Attribute(attr=m.Name("CheckConstraint"))
39+
| m.Name("CheckConstraint")
40+
)
41+
)
42+
def fix_existing(self, node: cst.Call, updated_node: cst.Call) -> cst.Call:
43+
# TODO select either models.CheckConstraint or CheckConstraint
44+
if node.args[0].value.raw_value == name:
45+
return cst.parse_expression(
46+
f'models.CheckConstraint(name="{name}", check={check})'
47+
)
48+
return updated_node
49+
50+
else:
51+
52+
@m.call_if_inside(m.ClassDef(m.Name(class_name)))
53+
@m.call_if_inside(m.ClassDef(m.Name("Meta")))
54+
@m.leave(m.Assign(targets=[m.AssignTarget(target=m.Name("constraints"))]))
55+
def add_new(self, node: cst.Assign, updated_node: cst.Assign) -> cst.Assign:
56+
self.is_constraint_updated = True
57+
exp = cst.parse_expression(
58+
f'[models.CheckConstraint(name="{name}", check={check})]'
59+
)
60+
constraints = updated_node.value.elements
61+
return updated_node.with_deep_changes(
62+
updated_node.value, elements=[*constraints, exp.elements[0]]
63+
)
64+
65+
return Fixes()

src/extra_checks/management/__init__.py

Whitespace-only changes.

src/extra_checks/management/commands/__init__.py

Whitespace-only changes.
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from typing import TYPE_CHECKING, Any, Dict, List
2+
3+
from django.core import checks
4+
from django.core.management.base import SystemCheckError
5+
from django.core.management.commands.check import Command as BaseCommand
6+
7+
if TYPE_CHECKING:
8+
from extra_checks.checks.base_checks import ExtraCheckMessage
9+
10+
11+
class Command(BaseCommand):
12+
def __init__(self, *args: Any, **kwargs: Any) -> None:
13+
super().__init__(*args, **kwargs)
14+
self._errors: "List[ExtraCheckMessage]" = []
15+
16+
def add_arguments(self, parser: Any) -> None:
17+
super().add_arguments(parser)
18+
parser.add_argument(
19+
"--fix",
20+
action="store_true",
21+
help="Apply autofix if available.",
22+
)
23+
parser.add_argument(
24+
"--fix-black",
25+
action="store_true",
26+
help="Apply black on autofix result.",
27+
)
28+
29+
def _run_checks(self, **kwargs: Any) -> dict:
30+
errors = checks.run_checks(**kwargs)
31+
self._errors = errors
32+
return errors
33+
34+
def handle(self, *app_labels: Any, **options: Any) -> None:
35+
on_exit = None
36+
try:
37+
super().handle(*app_labels, **options)
38+
except SystemCheckError as exc:
39+
on_exit = exc
40+
if not options["fix"]:
41+
return
42+
import libcst as cst
43+
from libcst import matchers as m
44+
45+
files: Dict[str, List[m.MatcherDecoratableTransformer]] = {}
46+
for error in self._errors:
47+
fix = getattr(error, "_fix", None)
48+
file = getattr(error, "_file", None)
49+
if fix and file:
50+
files.setdefault(file, []).append(fix)
51+
for file, fixes in files.items():
52+
with open(file, "r") as f:
53+
source_text = f.read()
54+
tree = cst.parse_module(source_text)
55+
for fix in fixes:
56+
tree = tree.visit(fix)
57+
result_text = tree.code
58+
if source_text != result_text:
59+
if options["fix_black"]:
60+
import black
61+
62+
mode = black.FileMode()
63+
fast = False
64+
result_text = black.format_file_contents(
65+
src_contents=result_text, fast=fast, mode=mode
66+
)
67+
with open(file, "w") as f:
68+
f.write(result_text)
69+
if on_exit:
70+
raise on_exit

tests/test_autofix.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import libcst as cst
2+
3+
from extra_checks.fixes.fix_choices_constraint import gen_fix_for_choices_constraint
4+
5+
SOURCE = """
6+
class TestClass(Token):
7+
class Meta:
8+
constraints = [
9+
models.CheckConstraint(name="value_valid", check=models.Q(value__in=[1, 2]))
10+
]
11+
"""
12+
13+
14+
def test_fix_add():
15+
result = """
16+
class TestClass(Token):
17+
class Meta:
18+
constraints = [
19+
models.CheckConstraint(name="value_valid", check=models.Q(value__in=[1, 2])), models.CheckConstraint(name="another_valid", check=models.Q(value__in=[1, 2, 3]))
20+
]
21+
"""
22+
23+
source_tree = cst.parse_module(SOURCE)
24+
modefied_tree = source_tree.visit(
25+
gen_fix_for_choices_constraint(
26+
"TestClass",
27+
name="another_valid",
28+
check="models.Q(value__in=[1, 2, 3])",
29+
replace=False,
30+
)
31+
)
32+
assert modefied_tree.code == result
33+
34+
35+
def test_fix_replace():
36+
result = """
37+
class TestClass(Token):
38+
class Meta:
39+
constraints = [
40+
models.CheckConstraint(name="value_valid", check=models.Q(value__in=[1, 2, 3]))
41+
]
42+
"""
43+
44+
source_tree = cst.parse_module(SOURCE)
45+
modefied_tree = source_tree.visit(
46+
gen_fix_for_choices_constraint(
47+
"TestClass",
48+
name="value_valid",
49+
check="models.Q(value__in=[1, 2, 3])",
50+
replace=True,
51+
)
52+
)
53+
assert modefied_tree.code == result
54+
55+
56+
def test_fix_meta_add_constraints():
57+
source = """
58+
class TestClass(Token):
59+
class Meta:
60+
db_table = 'test_class'
61+
"""
62+
result = """
63+
class TestClass(Token):
64+
class Meta:
65+
db_table = 'test_class'
66+
constraints = [models.CheckConstraint(name="value_valid", check=models.Q(value__in=[1, 2, 3]))]
67+
"""
68+
69+
source_tree = cst.parse_module(source)
70+
modefied_tree = source_tree.visit(
71+
gen_fix_for_choices_constraint(
72+
"TestClass",
73+
name="value_valid",
74+
check="models.Q(value__in=[1, 2, 3])",
75+
replace=False,
76+
)
77+
)
78+
assert modefied_tree.code == result

0 commit comments

Comments
 (0)