Skip to content

Commit daa654d

Browse files
Implemented Standalone Convert Function
1 parent 0aaa4d8 commit daa654d

File tree

2 files changed

+143
-112
lines changed

2 files changed

+143
-112
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ repos:
3434
args: [--config=setup.cfg]
3535

3636
- repo: https://github.com/pycqa/isort
37-
rev: 5.10.1
37+
rev: 5.12.0
3838
hooks:
3939
- id: isort
4040

src/strawberry_sqlalchemy_mapper/mapper.py

Lines changed: 142 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,7 @@ def type(
539539
model: Type[BaseModelType],
540540
make_interface=False,
541541
use_federation=False,
542+
**kwargs,
542543
) -> Callable[[Type[object]], Any]:
543544
"""
544545
Decorate a type with this to register it as a strawberry type
@@ -560,128 +561,158 @@ class Employee:
560561
```
561562
"""
562563

563-
def convert(type_: Any) -> Any:
564-
old_annotations = getattr(type_, "__annotations__", {})
565-
type_.__annotations__ = {}
566-
mapper: Mapper = inspect(model)
567-
generated_field_keys = []
568-
569-
excluded_keys = getattr(type_, "__exclude__", [])
570-
571-
# if the type inherits from another mapped type, then it may have
572-
# generated resolvers. These will be treated by dataclasses as having
573-
# a default value, which will likely cause issues because of keys
574-
# that don't have default values. To fix this, we wrap them in
575-
# `strawberry.field()` (like when they were originally made), so
576-
# dataclasses will ignore them.
577-
# TODO: Potentially raise/fix this issue upstream
578-
for key in dir(type_):
579-
val = getattr(type_, key)
580-
if getattr(val, _IS_GENERATED_RESOLVER_KEY, False):
581-
setattr(type_, key, strawberry.field(resolver=val))
582-
generated_field_keys.append(key)
583-
584-
self._handle_columns(mapper, type_, excluded_keys, generated_field_keys)
585-
for key, relationship in mapper.relationships.items():
586-
relationship: RelationshipProperty
587-
if (
588-
key in excluded_keys
589-
or key in type_.__annotations__
590-
or hasattr(type_, key)
591-
):
592-
continue
593-
strawberry_type = self._convert_relationship_to_strawberry_type(
594-
relationship
595-
)
596-
self._add_annotation(
597-
type_,
598-
key,
564+
def do_conversion(type_):
565+
return self.convert(
566+
type_,
567+
model,
568+
make_interface,
569+
use_federation,
570+
)
571+
572+
return do_conversion
573+
574+
def convert(
575+
self,
576+
type_: Any,
577+
model: Type[BaseModelType],
578+
make_interface=False,
579+
use_federation=False,
580+
) -> Any:
581+
"""
582+
Do type conversion. Usually accessed using typical .type decorator. But
583+
can also be used as standalone function.
584+
"""
585+
old_annotations = getattr(type_, "__annotations__", {})
586+
type_.__annotations__ = {}
587+
mapper: Mapper = inspect(model)
588+
generated_field_keys = []
589+
590+
excluded_keys = getattr(type_, "__exclude__", [])
591+
592+
# if the type inherits from another mapped type, then it may have
593+
# generated resolvers. These will be treated by dataclasses as having
594+
# a default value, which will likely cause issues because of keys
595+
# that don't have default values. To fix this, we wrap them in
596+
# `strawberry.field()` (like when they were originally made), so
597+
# dataclasses will ignore them.
598+
# TODO: Potentially raise/fix this issue upstream
599+
for key in dir(type_):
600+
val = getattr(type_, key)
601+
if getattr(val, _IS_GENERATED_RESOLVER_KEY, False):
602+
setattr(type_, key, strawberry.field(resolver=val))
603+
generated_field_keys.append(key)
604+
605+
self._handle_columns(mapper, type_, excluded_keys, generated_field_keys)
606+
for key, relationship in mapper.relationships.items():
607+
relationship: RelationshipProperty
608+
if (
609+
key in excluded_keys
610+
or key in type_.__annotations__
611+
or hasattr(type_, key)
612+
):
613+
continue
614+
strawberry_type = self._convert_relationship_to_strawberry_type(
615+
relationship
616+
)
617+
self._add_annotation(
618+
type_,
619+
key,
620+
strawberry_type,
621+
generated_field_keys,
622+
)
623+
field = strawberry.field(
624+
resolver=self.connection_resolver_for(relationship)
625+
)
626+
assert not field.init
627+
setattr(
628+
type_,
629+
key,
630+
field,
631+
)
632+
for key, descriptor in mapper.all_orm_descriptors.items():
633+
if (
634+
key in excluded_keys
635+
or key in type_.__annotations__
636+
or hasattr(type_, key)
637+
):
638+
continue
639+
if key in mapper.columns or key in mapper.relationships:
640+
continue
641+
if key in model.__annotations__:
642+
annotation = eval(model.__annotations__[key])
643+
for (
644+
sqlalchemy_type,
599645
strawberry_type,
600-
generated_field_keys,
646+
) in self.sqlalchemy_type_to_strawberry_type_map.items():
647+
if isinstance(annotation, sqlalchemy_type):
648+
self._add_annotation(
649+
type_, key, strawberry_type, generated_field_keys
650+
)
651+
break
652+
elif isinstance(descriptor, AssociationProxy):
653+
strawberry_type = self._get_association_proxy_annotation(
654+
mapper, key, descriptor
601655
)
656+
if strawberry_type is SkipTypeSentinel:
657+
continue
658+
self._add_annotation(type_, key, strawberry_type, generated_field_keys)
602659
field = strawberry.field(
603-
resolver=self.connection_resolver_for(relationship)
660+
resolver=self.association_proxy_resolver_for(
661+
mapper, descriptor, strawberry_type
662+
)
604663
)
605664
assert not field.init
606-
setattr(
665+
setattr(type_, key, field)
666+
elif isinstance(descriptor, hybrid_property):
667+
if (
668+
not hasattr(descriptor, "__annotations__")
669+
or "return" not in descriptor.__annotations__
670+
):
671+
raise HybridPropertyNotAnnotated(key)
672+
annotation = descriptor.__annotations__["return"]
673+
if isinstance(annotation, str):
674+
try:
675+
if "typing" in annotation:
676+
# Try to evaluate from existing typing imports
677+
annotation = annotation[7:]
678+
annotation = eval(annotation)
679+
except NameError:
680+
raise UnsupportedDescriptorType(key)
681+
self._add_annotation(
607682
type_,
608683
key,
609-
field,
684+
annotation,
685+
generated_field_keys,
610686
)
611-
for key, descriptor in mapper.all_orm_descriptors.items():
612-
if (
613-
key in excluded_keys
614-
or key in type_.__annotations__
615-
or hasattr(type_, key)
616-
):
617-
continue
618-
if key in mapper.columns or key in mapper.relationships:
619-
continue
620-
if isinstance(descriptor, AssociationProxy):
621-
strawberry_type = self._get_association_proxy_annotation(
622-
mapper, key, descriptor
623-
)
624-
if strawberry_type is SkipTypeSentinel:
625-
continue
626-
self._add_annotation(
627-
type_, key, strawberry_type, generated_field_keys
628-
)
629-
field = strawberry.field(
630-
resolver=self.association_proxy_resolver_for(
631-
mapper, descriptor, strawberry_type
632-
)
633-
)
634-
assert not field.init
635-
setattr(type_, key, field)
636-
elif isinstance(descriptor, hybrid_property):
637-
if (
638-
not hasattr(descriptor, "__annotations__")
639-
or "return" not in descriptor.__annotations__
640-
):
641-
raise HybridPropertyNotAnnotated(key)
642-
annotation = descriptor.__annotations__["return"]
643-
if isinstance(annotation, str):
644-
try:
645-
if "typing" in annotation:
646-
# Try to evaluate from existing typing imports
647-
annotation = annotation[7:]
648-
annotation = eval(annotation)
649-
except NameError:
650-
raise UnsupportedDescriptorType(key)
651-
self._add_annotation(
652-
type_,
653-
key,
654-
annotation,
655-
generated_field_keys,
656-
)
657-
else:
658-
raise UnsupportedDescriptorType(key)
687+
else:
688+
raise UnsupportedDescriptorType(key)
659689

660-
# ignore inherited `is_type_of`
661-
if "is_type_of" not in type_.__dict__:
662-
type_.is_type_of = (
663-
lambda obj, info: type(obj) == model or type(obj) == type_
664-
)
690+
# ignore inherited `is_type_of`
691+
if "is_type_of" not in type_.__dict__:
692+
type_.is_type_of = (
693+
lambda obj, info: type(obj) == model or type(obj) == type_
694+
)
665695

666-
# need to make fields that are already in the type
667-
# (prior to mapping) appear *after* the mapped fields
668-
# because the pre-existing fields might have default values,
669-
# which will cause the mapped fields to fail
670-
# (because they may not have default values)
671-
type_.__annotations__.update(old_annotations)
672-
673-
if make_interface:
674-
mapped_type = strawberry.interface(type_)
675-
elif use_federation:
676-
mapped_type = strawberry.federation.type(type_)
677-
else:
678-
mapped_type = strawberry.type(type_)
679-
self.mapped_types[type_.__name__] = mapped_type
680-
setattr(mapped_type, _GENERATED_FIELD_KEYS_KEY, generated_field_keys)
681-
setattr(mapped_type, _ORIGINAL_TYPE_KEY, type_)
682-
return mapped_type
696+
# need to make fields that are already in the type
697+
# (prior to mapping) appear *after* the mapped fields
698+
# because the pre-existing fields might have default values,
699+
# which will cause the mapped fields to fail
700+
# (because they may not have default values)
701+
type_.__annotations__.update(old_annotations)
683702

684-
return convert
703+
if make_interface:
704+
type_name = self.model_to_interface_name(type_)
705+
mapped_type = strawberry.interface(type_, name=type_name)
706+
else:
707+
type_name = self.model_to_type_name(type_)
708+
if use_federation:
709+
mapped_type = strawberry.federation.type(type_, name=type_name)
710+
else:
711+
mapped_type = strawberry.type(type_, name=type_name)
712+
self.mapped_types[type_name] = mapped_type
713+
setattr(mapped_type, _GENERATED_FIELD_KEYS_KEY, generated_field_keys)
714+
setattr(mapped_type, _ORIGINAL_TYPE_KEY, type_)
715+
return mapped_type
685716

686717
def interface(self, model: Type[BaseModelType]) -> Callable[[Type[object]], Any]:
687718
"""

0 commit comments

Comments
 (0)