@@ -151,6 +151,8 @@ def __init__(
151
151
extra_sqlalchemy_type_to_strawberry_type_map : Optional [
152
152
Mapping [Type [TypeEngine ], Type [Any ]]
153
153
] = None ,
154
+ edge_type : Type = None ,
155
+ connection_type : Type = None ,
154
156
) -> None :
155
157
if model_to_type_name is None :
156
158
model_to_type_name = self ._default_model_to_type_name
@@ -172,6 +174,9 @@ def __init__(
172
174
self ._related_type_models = set ()
173
175
self ._related_interface_models = set ()
174
176
177
+ self .edge_type = edge_type
178
+ self .connection_type = connection_type
179
+
175
180
@staticmethod
176
181
def _default_model_to_type_name (model : Type [BaseModelType ]) -> str :
177
182
return model .__name__
@@ -211,6 +216,8 @@ def _edge_type_for(self, type_name: str) -> Type[Any]:
211
216
Get or create a corresponding Edge model for the given type
212
217
(to support future pagination)
213
218
"""
219
+ if self .edge_type is not None :
220
+ return self .edge_type
214
221
edge_name = f"{ type_name } Edge"
215
222
if edge_name not in self .edge_types :
216
223
self .edge_types [edge_name ] = edge_type = strawberry .type (
@@ -229,6 +236,8 @@ def _connection_type_for(self, type_name: str) -> Type[Any]:
229
236
Get or create a corresponding Connection model for the given type
230
237
(to support future pagination)
231
238
"""
239
+ if self .connection_type is not None :
240
+ return self .connection_type [ForwardRef (type_name )]
232
241
connection_name = f"{ type_name } Connection"
233
242
if connection_name not in self .connection_types :
234
243
self .connection_types [connection_name ] = connection_type = strawberry .type (
@@ -259,7 +268,7 @@ def _convert_column_to_strawberry_type(
259
268
corresponding strawberry type.
260
269
"""
261
270
if isinstance (column .type , Enum ):
262
- type_annotation = column .type .python_type
271
+ type_annotation = strawberry . enum ( column .type .python_type )
263
272
elif isinstance (column .type , ARRAY ):
264
273
item_type = self ._convert_column_to_strawberry_type (
265
274
Column (column .type .item_type , nullable = False )
@@ -404,7 +413,11 @@ async def resolve(self, info: Info):
404
413
else :
405
414
relationship_key = tuple (
406
415
[
407
- getattr (self , local .key )
416
+ [
417
+ getattr (self , k )
418
+ for k , column in self .__mapper__ .c .items ()
419
+ if local .key == column .key
420
+ ][0 ]
408
421
for local , _ in relationship .local_remote_pairs
409
422
]
410
423
)
@@ -539,6 +552,7 @@ def type(
539
552
model : Type [BaseModelType ],
540
553
make_interface = False ,
541
554
use_federation = False ,
555
+ ** kwargs ,
542
556
) -> Callable [[Type [object ]], Any ]:
543
557
"""
544
558
Decorate a type with this to register it as a strawberry type
@@ -560,128 +574,158 @@ class Employee:
560
574
```
561
575
"""
562
576
563
- def convert (type_ : Any ) -> Any :
564
- old_annotations = getattr (type_ , "__annotations__" , {})
565
- type_ .__annotations__ = {}
566
- mapper : Mapper = inspect (model )
567
- generated_field_keys = []
568
-
569
- excluded_keys = getattr (type_ , "__exclude__" , [])
570
-
571
- # if the type inherits from another mapped type, then it may have
572
- # generated resolvers. These will be treated by dataclasses as having
573
- # a default value, which will likely cause issues because of keys
574
- # that don't have default values. To fix this, we wrap them in
575
- # `strawberry.field()` (like when they were originally made), so
576
- # dataclasses will ignore them.
577
- # TODO: Potentially raise/fix this issue upstream
578
- for key in dir (type_ ):
579
- val = getattr (type_ , key )
580
- if getattr (val , _IS_GENERATED_RESOLVER_KEY , False ):
581
- setattr (type_ , key , strawberry .field (resolver = val ))
582
- generated_field_keys .append (key )
583
-
584
- self ._handle_columns (mapper , type_ , excluded_keys , generated_field_keys )
585
- for key , relationship in mapper .relationships .items ():
586
- relationship : RelationshipProperty
587
- if (
588
- key in excluded_keys
589
- or key in type_ .__annotations__
590
- or hasattr (type_ , key )
591
- ):
592
- continue
593
- strawberry_type = self ._convert_relationship_to_strawberry_type (
594
- relationship
595
- )
596
- self ._add_annotation (
597
- type_ ,
598
- key ,
577
+ def do_conversion (type_ ):
578
+ return self .convert (
579
+ type_ ,
580
+ model ,
581
+ make_interface ,
582
+ use_federation ,
583
+ )
584
+
585
+ return do_conversion
586
+
587
+ def convert (
588
+ self ,
589
+ type_ : Any ,
590
+ model : Type [BaseModelType ],
591
+ make_interface = False ,
592
+ use_federation = False ,
593
+ ) -> Any :
594
+ """
595
+ Do type conversion. Usually accessed using typical .type decorator. But
596
+ can also be used as standalone function.
597
+ """
598
+ old_annotations = getattr (type_ , "__annotations__" , {})
599
+ type_ .__annotations__ = {}
600
+ mapper : Mapper = inspect (model )
601
+ generated_field_keys = []
602
+
603
+ excluded_keys = getattr (type_ , "__exclude__" , [])
604
+
605
+ # if the type inherits from another mapped type, then it may have
606
+ # generated resolvers. These will be treated by dataclasses as having
607
+ # a default value, which will likely cause issues because of keys
608
+ # that don't have default values. To fix this, we wrap them in
609
+ # `strawberry.field()` (like when they were originally made), so
610
+ # dataclasses will ignore them.
611
+ # TODO: Potentially raise/fix this issue upstream
612
+ for key in dir (type_ ):
613
+ val = getattr (type_ , key )
614
+ if getattr (val , _IS_GENERATED_RESOLVER_KEY , False ):
615
+ setattr (type_ , key , strawberry .field (resolver = val ))
616
+ generated_field_keys .append (key )
617
+
618
+ self ._handle_columns (mapper , type_ , excluded_keys , generated_field_keys )
619
+ for key , relationship in mapper .relationships .items ():
620
+ relationship : RelationshipProperty
621
+ if (
622
+ key in excluded_keys
623
+ or key in type_ .__annotations__
624
+ or hasattr (type_ , key )
625
+ ):
626
+ continue
627
+ strawberry_type = self ._convert_relationship_to_strawberry_type (
628
+ relationship
629
+ )
630
+ self ._add_annotation (
631
+ type_ ,
632
+ key ,
633
+ strawberry_type ,
634
+ generated_field_keys ,
635
+ )
636
+ field = strawberry .field (
637
+ resolver = self .connection_resolver_for (relationship )
638
+ )
639
+ assert not field .init
640
+ setattr (
641
+ type_ ,
642
+ key ,
643
+ field ,
644
+ )
645
+ for key , descriptor in mapper .all_orm_descriptors .items ():
646
+ if (
647
+ key in excluded_keys
648
+ or key in type_ .__annotations__
649
+ or hasattr (type_ , key )
650
+ ):
651
+ continue
652
+ if key in mapper .columns or key in mapper .relationships :
653
+ continue
654
+ if key in model .__annotations__ :
655
+ annotation = eval (model .__annotations__ [key ])
656
+ for (
657
+ sqlalchemy_type ,
599
658
strawberry_type ,
600
- generated_field_keys ,
659
+ ) in self .sqlalchemy_type_to_strawberry_type_map .items ():
660
+ if isinstance (annotation , sqlalchemy_type ):
661
+ self ._add_annotation (
662
+ type_ , key , strawberry_type , generated_field_keys
663
+ )
664
+ break
665
+ elif isinstance (descriptor , AssociationProxy ):
666
+ strawberry_type = self ._get_association_proxy_annotation (
667
+ mapper , key , descriptor
601
668
)
669
+ if strawberry_type is SkipTypeSentinel :
670
+ continue
671
+ self ._add_annotation (type_ , key , strawberry_type , generated_field_keys )
602
672
field = strawberry .field (
603
- resolver = self .connection_resolver_for (relationship )
673
+ resolver = self .association_proxy_resolver_for (
674
+ mapper , descriptor , strawberry_type
675
+ )
604
676
)
605
677
assert not field .init
606
- setattr (
678
+ setattr (type_ , key , field )
679
+ elif isinstance (descriptor , hybrid_property ):
680
+ if (
681
+ not hasattr (descriptor , "__annotations__" )
682
+ or "return" not in descriptor .__annotations__
683
+ ):
684
+ raise HybridPropertyNotAnnotated (key )
685
+ annotation = descriptor .__annotations__ ["return" ]
686
+ if isinstance (annotation , str ):
687
+ try :
688
+ if "typing" in annotation :
689
+ # Try to evaluate from existing typing imports
690
+ annotation = annotation [7 :]
691
+ annotation = eval (annotation )
692
+ except NameError :
693
+ raise UnsupportedDescriptorType (key )
694
+ self ._add_annotation (
607
695
type_ ,
608
696
key ,
609
- field ,
697
+ annotation ,
698
+ generated_field_keys ,
610
699
)
611
- for key , descriptor in mapper .all_orm_descriptors .items ():
612
- if (
613
- key in excluded_keys
614
- or key in type_ .__annotations__
615
- or hasattr (type_ , key )
616
- ):
617
- continue
618
- if key in mapper .columns or key in mapper .relationships :
619
- continue
620
- if isinstance (descriptor , AssociationProxy ):
621
- strawberry_type = self ._get_association_proxy_annotation (
622
- mapper , key , descriptor
623
- )
624
- if strawberry_type is SkipTypeSentinel :
625
- continue
626
- self ._add_annotation (
627
- type_ , key , strawberry_type , generated_field_keys
628
- )
629
- field = strawberry .field (
630
- resolver = self .association_proxy_resolver_for (
631
- mapper , descriptor , strawberry_type
632
- )
633
- )
634
- assert not field .init
635
- setattr (type_ , key , field )
636
- elif isinstance (descriptor , hybrid_property ):
637
- if (
638
- not hasattr (descriptor , "__annotations__" )
639
- or "return" not in descriptor .__annotations__
640
- ):
641
- raise HybridPropertyNotAnnotated (key )
642
- annotation = descriptor .__annotations__ ["return" ]
643
- if isinstance (annotation , str ):
644
- try :
645
- if "typing" in annotation :
646
- # Try to evaluate from existing typing imports
647
- annotation = annotation [7 :]
648
- annotation = eval (annotation )
649
- except NameError :
650
- raise UnsupportedDescriptorType (key )
651
- self ._add_annotation (
652
- type_ ,
653
- key ,
654
- annotation ,
655
- generated_field_keys ,
656
- )
657
- else :
658
- raise UnsupportedDescriptorType (key )
700
+ else :
701
+ raise UnsupportedDescriptorType (key )
659
702
660
- # ignore inherited `is_type_of`
661
- if "is_type_of" not in type_ .__dict__ :
662
- type_ .is_type_of = (
663
- lambda obj , info : type (obj ) == model or type (obj ) == type_
664
- )
703
+ # ignore inherited `is_type_of`
704
+ if "is_type_of" not in type_ .__dict__ :
705
+ type_ .is_type_of = (
706
+ lambda obj , info : type (obj ) == model or type (obj ) == type_
707
+ )
665
708
666
- # need to make fields that are already in the type
667
- # (prior to mapping) appear *after* the mapped fields
668
- # because the pre-existing fields might have default values,
669
- # which will cause the mapped fields to fail
670
- # (because they may not have default values)
671
- type_ .__annotations__ .update (old_annotations )
672
-
673
- if make_interface :
674
- mapped_type = strawberry .interface (type_ )
675
- elif use_federation :
676
- mapped_type = strawberry .federation .type (type_ )
677
- else :
678
- mapped_type = strawberry .type (type_ )
679
- self .mapped_types [type_ .__name__ ] = mapped_type
680
- setattr (mapped_type , _GENERATED_FIELD_KEYS_KEY , generated_field_keys )
681
- setattr (mapped_type , _ORIGINAL_TYPE_KEY , type_ )
682
- return mapped_type
709
+ # need to make fields that are already in the type
710
+ # (prior to mapping) appear *after* the mapped fields
711
+ # because the pre-existing fields might have default values,
712
+ # which will cause the mapped fields to fail
713
+ # (because they may not have default values)
714
+ type_ .__annotations__ .update (old_annotations )
683
715
684
- return convert
716
+ if make_interface :
717
+ type_name = self .model_to_interface_name (type_ )
718
+ mapped_type = strawberry .interface (type_ , name = type_name )
719
+ else :
720
+ type_name = self .model_to_type_name (type_ )
721
+ if use_federation :
722
+ mapped_type = strawberry .federation .type (type_ , name = type_name )
723
+ else :
724
+ mapped_type = strawberry .type (type_ , name = type_name )
725
+ self .mapped_types [type_name ] = mapped_type
726
+ setattr (mapped_type , _GENERATED_FIELD_KEYS_KEY , generated_field_keys )
727
+ setattr (mapped_type , _ORIGINAL_TYPE_KEY , type_ )
728
+ return mapped_type
685
729
686
730
def interface (self , model : Type [BaseModelType ]) -> Callable [[Type [object ]], Any ]:
687
731
"""
0 commit comments