Skip to content

Commit a9c7e35

Browse files
Multiple Updates
1 parent 0aaa4d8 commit a9c7e35

File tree

3 files changed

+167
-119
lines changed

3 files changed

+167
-119
lines changed

.pre-commit-config.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ exclude: '^docs/conf.py'
22

33
repos:
44
- repo: https://github.com/pre-commit/pre-commit-hooks
5-
rev: v4.3.0
5+
rev: v4.4.0
66
hooks:
77
- id: trailing-whitespace
88
- id: check-added-large-files
@@ -28,18 +28,18 @@ repos:
2828
# --remove-unused-variables,
2929
# ]
3030
- repo: https://github.com/hadialqattan/pycln
31-
rev: v2.1.1
31+
rev: v2.1.3
3232
hooks:
3333
- id: pycln
3434
args: [--config=setup.cfg]
3535

3636
- repo: https://github.com/pycqa/isort
37-
rev: 5.10.1
37+
rev: 5.12.0
3838
hooks:
3939
- id: isort
4040

4141
- repo: https://github.com/psf/black
42-
rev: 22.8.0
42+
rev: 23.3.0
4343
hooks:
4444
- id: black
4545
language_version: python3
@@ -52,7 +52,7 @@ repos:
5252
# additional_dependencies: [black]
5353

5454
- repo: https://github.com/PyCQA/flake8
55-
rev: 5.0.4
55+
rev: 6.0.0
5656
hooks:
5757
- id: flake8
5858
## You can add flake8 plugins via `additional_dependencies`:

src/strawberry_sqlalchemy_mapper/loader.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@ async def load_fn(keys: List[Tuple]) -> List[Any]:
3939
def group_by_remote_key(row: Any) -> Tuple:
4040
return tuple(
4141
[
42-
getattr(row, remote.key)
42+
[
43+
getattr(row, k)
44+
for k, column in row.__mapper__.c.items()
45+
if remote.key == column.key
46+
][0]
4347
for _, remote in relationship.local_remote_pairs
4448
]
4549
)

src/strawberry_sqlalchemy_mapper/mapper.py

Lines changed: 157 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ def __init__(
151151
extra_sqlalchemy_type_to_strawberry_type_map: Optional[
152152
Mapping[Type[TypeEngine], Type[Any]]
153153
] = None,
154+
edge_type: Type = None,
155+
connection_type: Type = None,
154156
) -> None:
155157
if model_to_type_name is None:
156158
model_to_type_name = self._default_model_to_type_name
@@ -172,6 +174,9 @@ def __init__(
172174
self._related_type_models = set()
173175
self._related_interface_models = set()
174176

177+
self.edge_type = edge_type
178+
self.connection_type = connection_type
179+
175180
@staticmethod
176181
def _default_model_to_type_name(model: Type[BaseModelType]) -> str:
177182
return model.__name__
@@ -211,6 +216,8 @@ def _edge_type_for(self, type_name: str) -> Type[Any]:
211216
Get or create a corresponding Edge model for the given type
212217
(to support future pagination)
213218
"""
219+
if self.edge_type is not None:
220+
return self.edge_type
214221
edge_name = f"{type_name}Edge"
215222
if edge_name not in self.edge_types:
216223
self.edge_types[edge_name] = edge_type = strawberry.type(
@@ -229,6 +236,8 @@ def _connection_type_for(self, type_name: str) -> Type[Any]:
229236
Get or create a corresponding Connection model for the given type
230237
(to support future pagination)
231238
"""
239+
if self.connection_type is not None:
240+
return self.connection_type[ForwardRef(type_name)]
232241
connection_name = f"{type_name}Connection"
233242
if connection_name not in self.connection_types:
234243
self.connection_types[connection_name] = connection_type = strawberry.type(
@@ -259,7 +268,7 @@ def _convert_column_to_strawberry_type(
259268
corresponding strawberry type.
260269
"""
261270
if isinstance(column.type, Enum):
262-
type_annotation = column.type.python_type
271+
type_annotation = strawberry.enum(column.type.python_type)
263272
elif isinstance(column.type, ARRAY):
264273
item_type = self._convert_column_to_strawberry_type(
265274
Column(column.type.item_type, nullable=False)
@@ -404,7 +413,11 @@ async def resolve(self, info: Info):
404413
else:
405414
relationship_key = tuple(
406415
[
407-
getattr(self, local.key)
416+
[
417+
getattr(self, k)
418+
for k, column in self.__mapper__.c.items()
419+
if local.key == column.key
420+
][0]
408421
for local, _ in relationship.local_remote_pairs
409422
]
410423
)
@@ -539,6 +552,7 @@ def type(
539552
model: Type[BaseModelType],
540553
make_interface=False,
541554
use_federation=False,
555+
**kwargs,
542556
) -> Callable[[Type[object]], Any]:
543557
"""
544558
Decorate a type with this to register it as a strawberry type
@@ -560,128 +574,158 @@ class Employee:
560574
```
561575
"""
562576

563-
def convert(type_: Any) -> Any:
564-
old_annotations = getattr(type_, "__annotations__", {})
565-
type_.__annotations__ = {}
566-
mapper: Mapper = inspect(model)
567-
generated_field_keys = []
568-
569-
excluded_keys = getattr(type_, "__exclude__", [])
570-
571-
# if the type inherits from another mapped type, then it may have
572-
# generated resolvers. These will be treated by dataclasses as having
573-
# a default value, which will likely cause issues because of keys
574-
# that don't have default values. To fix this, we wrap them in
575-
# `strawberry.field()` (like when they were originally made), so
576-
# dataclasses will ignore them.
577-
# TODO: Potentially raise/fix this issue upstream
578-
for key in dir(type_):
579-
val = getattr(type_, key)
580-
if getattr(val, _IS_GENERATED_RESOLVER_KEY, False):
581-
setattr(type_, key, strawberry.field(resolver=val))
582-
generated_field_keys.append(key)
583-
584-
self._handle_columns(mapper, type_, excluded_keys, generated_field_keys)
585-
for key, relationship in mapper.relationships.items():
586-
relationship: RelationshipProperty
587-
if (
588-
key in excluded_keys
589-
or key in type_.__annotations__
590-
or hasattr(type_, key)
591-
):
592-
continue
593-
strawberry_type = self._convert_relationship_to_strawberry_type(
594-
relationship
595-
)
596-
self._add_annotation(
597-
type_,
598-
key,
577+
def do_conversion(type_):
578+
return self.convert(
579+
type_,
580+
model,
581+
make_interface,
582+
use_federation,
583+
)
584+
585+
return do_conversion
586+
587+
def convert(
588+
self,
589+
type_: Any,
590+
model: Type[BaseModelType],
591+
make_interface=False,
592+
use_federation=False,
593+
) -> Any:
594+
"""
595+
Do type conversion. Usually accessed using typical .type decorator. But
596+
can also be used as standalone function.
597+
"""
598+
old_annotations = getattr(type_, "__annotations__", {})
599+
type_.__annotations__ = {}
600+
mapper: Mapper = inspect(model)
601+
generated_field_keys = []
602+
603+
excluded_keys = getattr(type_, "__exclude__", [])
604+
605+
# if the type inherits from another mapped type, then it may have
606+
# generated resolvers. These will be treated by dataclasses as having
607+
# a default value, which will likely cause issues because of keys
608+
# that don't have default values. To fix this, we wrap them in
609+
# `strawberry.field()` (like when they were originally made), so
610+
# dataclasses will ignore them.
611+
# TODO: Potentially raise/fix this issue upstream
612+
for key in dir(type_):
613+
val = getattr(type_, key)
614+
if getattr(val, _IS_GENERATED_RESOLVER_KEY, False):
615+
setattr(type_, key, strawberry.field(resolver=val))
616+
generated_field_keys.append(key)
617+
618+
self._handle_columns(mapper, type_, excluded_keys, generated_field_keys)
619+
for key, relationship in mapper.relationships.items():
620+
relationship: RelationshipProperty
621+
if (
622+
key in excluded_keys
623+
or key in type_.__annotations__
624+
or hasattr(type_, key)
625+
):
626+
continue
627+
strawberry_type = self._convert_relationship_to_strawberry_type(
628+
relationship
629+
)
630+
self._add_annotation(
631+
type_,
632+
key,
633+
strawberry_type,
634+
generated_field_keys,
635+
)
636+
field = strawberry.field(
637+
resolver=self.connection_resolver_for(relationship)
638+
)
639+
assert not field.init
640+
setattr(
641+
type_,
642+
key,
643+
field,
644+
)
645+
for key, descriptor in mapper.all_orm_descriptors.items():
646+
if (
647+
key in excluded_keys
648+
or key in type_.__annotations__
649+
or hasattr(type_, key)
650+
):
651+
continue
652+
if key in mapper.columns or key in mapper.relationships:
653+
continue
654+
if key in model.__annotations__:
655+
annotation = eval(model.__annotations__[key])
656+
for (
657+
sqlalchemy_type,
599658
strawberry_type,
600-
generated_field_keys,
659+
) in self.sqlalchemy_type_to_strawberry_type_map.items():
660+
if isinstance(annotation, sqlalchemy_type):
661+
self._add_annotation(
662+
type_, key, strawberry_type, generated_field_keys
663+
)
664+
break
665+
elif isinstance(descriptor, AssociationProxy):
666+
strawberry_type = self._get_association_proxy_annotation(
667+
mapper, key, descriptor
601668
)
669+
if strawberry_type is SkipTypeSentinel:
670+
continue
671+
self._add_annotation(type_, key, strawberry_type, generated_field_keys)
602672
field = strawberry.field(
603-
resolver=self.connection_resolver_for(relationship)
673+
resolver=self.association_proxy_resolver_for(
674+
mapper, descriptor, strawberry_type
675+
)
604676
)
605677
assert not field.init
606-
setattr(
678+
setattr(type_, key, field)
679+
elif isinstance(descriptor, hybrid_property):
680+
if (
681+
not hasattr(descriptor, "__annotations__")
682+
or "return" not in descriptor.__annotations__
683+
):
684+
raise HybridPropertyNotAnnotated(key)
685+
annotation = descriptor.__annotations__["return"]
686+
if isinstance(annotation, str):
687+
try:
688+
if "typing" in annotation:
689+
# Try to evaluate from existing typing imports
690+
annotation = annotation[7:]
691+
annotation = eval(annotation)
692+
except NameError:
693+
raise UnsupportedDescriptorType(key)
694+
self._add_annotation(
607695
type_,
608696
key,
609-
field,
697+
annotation,
698+
generated_field_keys,
610699
)
611-
for key, descriptor in mapper.all_orm_descriptors.items():
612-
if (
613-
key in excluded_keys
614-
or key in type_.__annotations__
615-
or hasattr(type_, key)
616-
):
617-
continue
618-
if key in mapper.columns or key in mapper.relationships:
619-
continue
620-
if isinstance(descriptor, AssociationProxy):
621-
strawberry_type = self._get_association_proxy_annotation(
622-
mapper, key, descriptor
623-
)
624-
if strawberry_type is SkipTypeSentinel:
625-
continue
626-
self._add_annotation(
627-
type_, key, strawberry_type, generated_field_keys
628-
)
629-
field = strawberry.field(
630-
resolver=self.association_proxy_resolver_for(
631-
mapper, descriptor, strawberry_type
632-
)
633-
)
634-
assert not field.init
635-
setattr(type_, key, field)
636-
elif isinstance(descriptor, hybrid_property):
637-
if (
638-
not hasattr(descriptor, "__annotations__")
639-
or "return" not in descriptor.__annotations__
640-
):
641-
raise HybridPropertyNotAnnotated(key)
642-
annotation = descriptor.__annotations__["return"]
643-
if isinstance(annotation, str):
644-
try:
645-
if "typing" in annotation:
646-
# Try to evaluate from existing typing imports
647-
annotation = annotation[7:]
648-
annotation = eval(annotation)
649-
except NameError:
650-
raise UnsupportedDescriptorType(key)
651-
self._add_annotation(
652-
type_,
653-
key,
654-
annotation,
655-
generated_field_keys,
656-
)
657-
else:
658-
raise UnsupportedDescriptorType(key)
700+
else:
701+
raise UnsupportedDescriptorType(key)
659702

660-
# ignore inherited `is_type_of`
661-
if "is_type_of" not in type_.__dict__:
662-
type_.is_type_of = (
663-
lambda obj, info: type(obj) == model or type(obj) == type_
664-
)
703+
# ignore inherited `is_type_of`
704+
if "is_type_of" not in type_.__dict__:
705+
type_.is_type_of = (
706+
lambda obj, info: type(obj) == model or type(obj) == type_
707+
)
665708

666-
# need to make fields that are already in the type
667-
# (prior to mapping) appear *after* the mapped fields
668-
# because the pre-existing fields might have default values,
669-
# which will cause the mapped fields to fail
670-
# (because they may not have default values)
671-
type_.__annotations__.update(old_annotations)
672-
673-
if make_interface:
674-
mapped_type = strawberry.interface(type_)
675-
elif use_federation:
676-
mapped_type = strawberry.federation.type(type_)
677-
else:
678-
mapped_type = strawberry.type(type_)
679-
self.mapped_types[type_.__name__] = mapped_type
680-
setattr(mapped_type, _GENERATED_FIELD_KEYS_KEY, generated_field_keys)
681-
setattr(mapped_type, _ORIGINAL_TYPE_KEY, type_)
682-
return mapped_type
709+
# need to make fields that are already in the type
710+
# (prior to mapping) appear *after* the mapped fields
711+
# because the pre-existing fields might have default values,
712+
# which will cause the mapped fields to fail
713+
# (because they may not have default values)
714+
type_.__annotations__.update(old_annotations)
683715

684-
return convert
716+
if make_interface:
717+
type_name = self.model_to_interface_name(type_)
718+
mapped_type = strawberry.interface(type_, name=type_name)
719+
else:
720+
type_name = self.model_to_type_name(type_)
721+
if use_federation:
722+
mapped_type = strawberry.federation.type(type_, name=type_name)
723+
else:
724+
mapped_type = strawberry.type(type_, name=type_name)
725+
self.mapped_types[type_name] = mapped_type
726+
setattr(mapped_type, _GENERATED_FIELD_KEYS_KEY, generated_field_keys)
727+
setattr(mapped_type, _ORIGINAL_TYPE_KEY, type_)
728+
return mapped_type
685729

686730
def interface(self, model: Type[BaseModelType]) -> Callable[[Type[object]], Any]:
687731
"""

0 commit comments

Comments
 (0)