Skip to content

Commit 0967d04

Browse files
Support multi-level nested create/update with model full_clean() (#659)
* Add support for nested creation/update in mutations. This also has the benefit of consistently calling `full_clean()` before creating related instances. This does remove the `get_or_create()` calls and instead uses `create` only. The expectation here is that `key_attr` could and should be used to indicate what field should be used as the unique identifier, and not something hard coded that could have unintended side effects when creating related instances that don't have unique constraints and expect new instances to always be created. * Formatting * First test (heavily based on one from an existing PR) * Update new test with m2m creation/use * Add test for nested creation when creating a new resource * Add test for full_clean being called when performing nested creation or resources * Remove unecessary `@transaction.atomic()` call * Add support for nested creation of ForeignKeys
1 parent 9e6d2bb commit 0967d04

File tree

5 files changed

+673
-42
lines changed

5 files changed

+673
-42
lines changed

strawberry_django/mutations/resolvers.py

+114-25
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import strawberry
1616
from django.db import models, transaction
1717
from django.db.models.base import Model
18+
from django.db.models.fields import Field
1819
from django.db.models.fields.related import ManyToManyField
1920
from django.db.models.fields.reverse_related import (
2021
ForeignObjectRel,
@@ -44,7 +45,11 @@
4445
)
4546

4647
if TYPE_CHECKING:
47-
from django.db.models.manager import ManyToManyRelatedManager, RelatedManager
48+
from django.db.models.manager import (
49+
BaseManager,
50+
ManyToManyRelatedManager,
51+
RelatedManager,
52+
)
4853
from strawberry.types.info import Info
4954

5055

@@ -88,6 +93,7 @@ def _parse_data(
8893
value: Any,
8994
*,
9095
key_attr: str | None = None,
96+
full_clean: bool | FullCleanOptions = True,
9197
):
9298
obj, data = _parse_pk(value, model, key_attr=key_attr)
9399
parsed_data = {}
@@ -97,10 +103,21 @@ def _parse_data(
97103
continue
98104

99105
if isinstance(v, ParsedObject):
100-
if v.pk is None:
101-
v = create(info, model, v.data or {}) # noqa: PLW2901
106+
if v.pk in {None, UNSET}:
107+
related_field = cast("Field", get_model_fields(model).get(k))
108+
related_model = related_field.related_model
109+
v = create( # noqa: PLW2901
110+
info,
111+
cast("type[Model]", related_model),
112+
v.data or {},
113+
key_attr=key_attr,
114+
full_clean=full_clean,
115+
exclude_m2m=[related_field.name],
116+
)
102117
elif isinstance(v.pk, models.Model) and v.data:
103-
v = update(info, v.pk, v.data, key_attr=key_attr) # noqa: PLW2901
118+
v = update( # noqa: PLW2901
119+
info, v.pk, v.data, key_attr=key_attr, full_clean=full_clean
120+
)
104121
else:
105122
v = v.pk # noqa: PLW2901
106123

@@ -222,6 +239,7 @@ def prepare_create_update(
222239
data: dict[str, Any],
223240
key_attr: str | None = None,
224241
full_clean: bool | FullCleanOptions = True,
242+
exclude_m2m: list[str] | None = None,
225243
) -> tuple[
226244
Model,
227245
dict[str, object],
@@ -237,6 +255,7 @@ def prepare_create_update(
237255
fields = get_model_fields(model)
238256
m2m: list[tuple[ManyToManyField | ForeignObjectRel, Any]] = []
239257
direct_field_values: dict[str, object] = {}
258+
exclude_m2m = exclude_m2m or []
240259

241260
if dataclasses.is_dataclass(data):
242261
data = vars(data)
@@ -256,6 +275,8 @@ def prepare_create_update(
256275
# (but only if the instance is already saved and we are updating it)
257276
value = False # noqa: PLW2901
258277
elif isinstance(field, (ManyToManyField, ForeignObjectRel)):
278+
if name in exclude_m2m:
279+
continue
259280
# m2m will be processed later
260281
m2m.append((field, value))
261282
direct_field_value = False
@@ -269,14 +290,19 @@ def prepare_create_update(
269290
cast("type[Model]", field.related_model),
270291
value,
271292
key_attr=key_attr,
293+
full_clean=full_clean,
272294
)
273295
if value is None and not value_data:
274296
value = None # noqa: PLW2901
275297

276298
# If foreign object is not found, then create it
277-
elif value is None:
278-
value = field.related_model._default_manager.create( # noqa: PLW2901
279-
**value_data,
299+
elif value in {None, UNSET}:
300+
value = create( # noqa: PLW2901
301+
info,
302+
field.related_model,
303+
value_data,
304+
key_attr=key_attr,
305+
full_clean=full_clean,
280306
)
281307

282308
# If foreign object does not need updating, then skip it
@@ -309,6 +335,7 @@ def create(
309335
key_attr: str | None = None,
310336
full_clean: bool | FullCleanOptions = True,
311337
pre_save_hook: Callable[[_M], None] | None = None,
338+
exclude_m2m: list[str] | None = None,
312339
) -> _M: ...
313340

314341

@@ -321,10 +348,10 @@ def create(
321348
key_attr: str | None = None,
322349
full_clean: bool | FullCleanOptions = True,
323350
pre_save_hook: Callable[[_M], None] | None = None,
351+
exclude_m2m: list[str] | None = None,
324352
) -> list[_M]: ...
325353

326354

327-
@transaction.atomic
328355
def create(
329356
info: Info,
330357
model: type[_M],
@@ -333,12 +360,43 @@ def create(
333360
key_attr: str | None = None,
334361
full_clean: bool | FullCleanOptions = True,
335362
pre_save_hook: Callable[[_M], None] | None = None,
363+
exclude_m2m: list[str] | None = None,
364+
) -> list[_M] | _M:
365+
return _create(
366+
info,
367+
model._default_manager,
368+
data,
369+
key_attr=key_attr,
370+
full_clean=full_clean,
371+
pre_save_hook=pre_save_hook,
372+
exclude_m2m=exclude_m2m,
373+
)
374+
375+
376+
@transaction.atomic
377+
def _create(
378+
info: Info,
379+
manager: BaseManager,
380+
data: dict[str, Any] | list[dict[str, Any]],
381+
*,
382+
key_attr: str | None = None,
383+
full_clean: bool | FullCleanOptions = True,
384+
pre_save_hook: Callable[[_M], None] | None = None,
385+
exclude_m2m: list[str] | None = None,
336386
) -> list[_M] | _M:
387+
model = manager.model
337388
# Before creating your instance, verify this is not a bulk create
338389
# if so, add them one by one. Otherwise, get to work.
339390
if isinstance(data, list):
340391
return [
341-
create(info, model, d, key_attr=key_attr, full_clean=full_clean)
392+
create(
393+
info,
394+
model,
395+
d,
396+
key_attr=key_attr,
397+
full_clean=full_clean,
398+
exclude_m2m=exclude_m2m,
399+
)
342400
for d in data
343401
]
344402

@@ -365,6 +423,7 @@ def create(
365423
data=data,
366424
full_clean=full_clean,
367425
key_attr=key_attr,
426+
exclude_m2m=exclude_m2m,
368427
)
369428

370429
# Creating the instance directly via create() without full-clean will
@@ -376,7 +435,7 @@ def create(
376435

377436
# Create the instance using the manager create method to respect
378437
# manager create overrides. This also ensures support for proxy-models.
379-
instance = model._default_manager.create(**create_kwargs)
438+
instance = manager.create(**create_kwargs)
380439

381440
for field, value in m2m:
382441
update_m2m(info, instance, field, value, key_attr)
@@ -393,6 +452,7 @@ def update(
393452
key_attr: str | None = None,
394453
full_clean: bool | FullCleanOptions = True,
395454
pre_save_hook: Callable[[_M], None] | None = None,
455+
exclude_m2m: list[str] | None = None,
396456
) -> _M: ...
397457

398458

@@ -405,6 +465,7 @@ def update(
405465
key_attr: str | None = None,
406466
full_clean: bool | FullCleanOptions = True,
407467
pre_save_hook: Callable[[_M], None] | None = None,
468+
exclude_m2m: list[str] | None = None,
408469
) -> list[_M]: ...
409470

410471

@@ -417,6 +478,7 @@ def update(
417478
key_attr: str | None = None,
418479
full_clean: bool | FullCleanOptions = True,
419480
pre_save_hook: Callable[[_M], None] | None = None,
481+
exclude_m2m: list[str] | None = None,
420482
) -> _M | list[_M]:
421483
# Unwrap lazy objects since they have a proxy __iter__ method that will make
422484
# them iterables even if the wrapped object isn't
@@ -433,6 +495,7 @@ def update(
433495
key_attr=key_attr,
434496
full_clean=full_clean,
435497
pre_save_hook=pre_save_hook,
498+
exclude_m2m=exclude_m2m,
436499
)
437500
for instance in instances
438501
]
@@ -443,6 +506,7 @@ def update(
443506
data=data,
444507
key_attr=key_attr,
445508
full_clean=full_clean,
509+
exclude_m2m=exclude_m2m,
446510
)
447511

448512
if pre_save_hook is not None:
@@ -554,15 +618,22 @@ def update_m2m(
554618
use_remove = True
555619
if isinstance(field, ManyToManyField):
556620
manager = cast("RelatedManager", getattr(instance, field.attname))
621+
reverse_field_name = field.remote_field.related_name # type: ignore
557622
else:
558623
assert isinstance(field, (ManyToManyRel, ManyToOneRel))
559624
accessor_name = field.get_accessor_name()
625+
reverse_field_name = field.field.name
560626
assert accessor_name
561627
manager = cast("RelatedManager", getattr(instance, accessor_name))
562628
if field.one_to_many:
563629
# remove if field is nullable, otherwise delete
564630
use_remove = field.remote_field.null is True
565631

632+
# Create a data dict containing the reference to the instance and exclude it from
633+
# nested m2m creation (to break circular references)
634+
ref_instance_data = {reverse_field_name: instance}
635+
exclude_m2m = [reverse_field_name]
636+
566637
to_add = []
567638
to_remove = []
568639
to_delete = []
@@ -581,7 +652,11 @@ def update_m2m(
581652
need_remove_cache = need_remove_cache or bool(values)
582653
for v in values:
583654
obj, data = _parse_data(
584-
info, cast("type[Model]", manager.model), v, key_attr=key_attr
655+
info,
656+
cast("type[Model]", manager.model),
657+
v,
658+
key_attr=key_attr,
659+
full_clean=full_clean,
585660
)
586661
if obj:
587662
data.pop(key_attr, None)
@@ -621,14 +696,17 @@ def update_m2m(
621696

622697
existing.discard(obj)
623698
else:
624-
if key_attr not in data: # we have a Input Type
625-
obj, _ = manager.get_or_create(**data)
626-
else:
627-
data.pop(key_attr)
628-
obj = manager.create(**data)
629-
630-
if full_clean:
631-
obj.full_clean(**full_clean_options)
699+
# If we've reached here, the key_attr should be UNSET or missing. So
700+
# let's remove it if it is there.
701+
data.pop(key_attr, None)
702+
obj = _create(
703+
info,
704+
manager,
705+
data | ref_instance_data,
706+
key_attr=key_attr,
707+
full_clean=full_clean,
708+
exclude_m2m=exclude_m2m,
709+
)
632710
existing.discard(obj)
633711

634712
for remaining in existing:
@@ -645,6 +723,7 @@ def update_m2m(
645723
cast("type[Model]", manager.model),
646724
v,
647725
key_attr=key_attr,
726+
full_clean=full_clean,
648727
)
649728
if obj and data:
650729
data.pop(key_attr, None)
@@ -656,18 +735,28 @@ def update_m2m(
656735
data.pop(key_attr, None)
657736
to_add.append(obj)
658737
elif data:
659-
if key_attr not in data:
660-
manager.get_or_create(**data)
661-
else:
662-
data.pop(key_attr)
663-
manager.create(**data)
738+
# If we've reached here, the key_attr should be UNSET or missing. So
739+
# let's remove it if it is there.
740+
data.pop(key_attr, None)
741+
_create(
742+
info,
743+
manager,
744+
data | ref_instance_data,
745+
key_attr=key_attr,
746+
full_clean=full_clean,
747+
exclude_m2m=exclude_m2m,
748+
)
664749
else:
665750
raise AssertionError
666751

667752
need_remove_cache = need_remove_cache or bool(value.remove)
668753
for v in value.remove or []:
669754
obj, data = _parse_data(
670-
info, cast("type[Model]", manager.model), v, key_attr=key_attr
755+
info,
756+
cast("type[Model]", manager.model),
757+
v,
758+
key_attr=key_attr,
759+
full_clean=full_clean,
671760
)
672761
data.pop(key_attr, None)
673762
assert not data

tests/projects/schema.py

+13
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,12 @@ class MilestoneIssueInput:
337337
name: strawberry.auto
338338

339339

340+
@strawberry_django.partial(Issue)
341+
class MilestoneIssueInputPartial:
342+
name: strawberry.auto
343+
tags: Optional[list[TagInputPartial]]
344+
345+
340346
@strawberry_django.partial(Project)
341347
class ProjectInputPartial(NodeInputPartial):
342348
name: strawberry.auto
@@ -353,6 +359,8 @@ class MilestoneInput:
353359
@strawberry_django.partial(Milestone)
354360
class MilestoneInputPartial(NodeInputPartial):
355361
name: strawberry.auto
362+
issues: Optional[list[MilestoneIssueInputPartial]]
363+
project: Optional[ProjectInputPartial]
356364

357365

358366
@strawberry.type
@@ -521,6 +529,11 @@ class Mutation:
521529
argument_name="input",
522530
key_attr="name",
523531
)
532+
create_project_with_milestones: ProjectType = mutations.create(
533+
ProjectInputPartial,
534+
handle_django_errors=True,
535+
argument_name="input",
536+
)
524537
update_project: ProjectType = mutations.update(
525538
ProjectInputPartial,
526539
handle_django_errors=True,

tests/projects/snapshots/schema.gql

+10
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ input CreateProjectInput {
9595

9696
union CreateProjectPayload = ProjectType | OperationInfo
9797

98+
union CreateProjectWithMilestonesPayload = ProjectType | OperationInfo
99+
98100
input CreateQuizInput {
99101
title: String!
100102
fullCleanOptions: Boolean! = false
@@ -365,12 +367,19 @@ input MilestoneInput {
365367
input MilestoneInputPartial {
366368
id: GlobalID
367369
name: String
370+
issues: [MilestoneIssueInputPartial!]
371+
project: ProjectInputPartial
368372
}
369373

370374
input MilestoneIssueInput {
371375
name: String!
372376
}
373377

378+
input MilestoneIssueInputPartial {
379+
name: String
380+
tags: [TagInputPartial!]
381+
}
382+
374383
input MilestoneOrder {
375384
name: Ordering
376385
project: ProjectOrder
@@ -433,6 +442,7 @@ type Mutation {
433442
updateIssueWithKeyAttr(input: IssueInputPartialWithoutId!): UpdateIssueWithKeyAttrPayload!
434443
deleteIssue(input: NodeInput!): DeleteIssuePayload!
435444
deleteIssueWithKeyAttr(input: MilestoneIssueInput!): DeleteIssueWithKeyAttrPayload!
445+
createProjectWithMilestones(input: ProjectInputPartial!): CreateProjectWithMilestonesPayload!
436446
updateProject(input: ProjectInputPartial!): UpdateProjectPayload!
437447
createMilestone(input: MilestoneInput!): CreateMilestonePayload!
438448
createProject(

0 commit comments

Comments
 (0)