@@ -151,6 +151,8 @@ def __init__(
151
151
extra_sqlalchemy_type_to_strawberry_type_map : Optional [
152
152
Mapping [Type [TypeEngine ], Type [Any ]]
153
153
] = None ,
154
+ edge_type : Type = None ,
155
+ connection_type : Type = None ,
154
156
) -> None :
155
157
if model_to_type_name is None :
156
158
model_to_type_name = self ._default_model_to_type_name
@@ -172,6 +174,9 @@ def __init__(
172
174
self ._related_type_models = set ()
173
175
self ._related_interface_models = set ()
174
176
177
+ self .edge_type = edge_type
178
+ self .connection_type = connection_type
179
+
175
180
@staticmethod
176
181
def _default_model_to_type_name (model : Type [BaseModelType ]) -> str :
177
182
return model .__name__
@@ -211,6 +216,8 @@ def _edge_type_for(self, type_name: str) -> Type[Any]:
211
216
Get or create a corresponding Edge model for the given type
212
217
(to support future pagination)
213
218
"""
219
+ if self .edge_type is not None :
220
+ return self .edge_type
214
221
edge_name = f"{ type_name } Edge"
215
222
if edge_name not in self .edge_types :
216
223
self .edge_types [edge_name ] = edge_type = strawberry .type (
@@ -229,6 +236,8 @@ def _connection_type_for(self, type_name: str) -> Type[Any]:
229
236
Get or create a corresponding Connection model for the given type
230
237
(to support future pagination)
231
238
"""
239
+ if self .connection_type is not None :
240
+ return self .connection_type [ForwardRef (type_name )]
232
241
connection_name = f"{ type_name } Connection"
233
242
if connection_name not in self .connection_types :
234
243
self .connection_types [connection_name ] = connection_type = strawberry .type (
0 commit comments