From dd6722b676aadb2161f1abf6fea5a54cb498862d Mon Sep 17 00:00:00 2001 From: Aryan Iyappan Date: Wed, 15 Jan 2025 10:50:50 +0530 Subject: [PATCH 01/12] add batching for sync base view --- strawberry/aiohttp/views.py | 4 +- strawberry/asgi/__init__.py | 4 +- strawberry/chalice/views.py | 4 +- strawberry/channels/handlers/http_handler.py | 12 +- strawberry/django/views.py | 4 +- strawberry/fastapi/router.py | 4 +- strawberry/flask/views.py | 4 +- strawberry/http/async_base_view.py | 123 +++++++++++++++---- strawberry/http/sync_base_view.py | 122 ++++++++++++++---- strawberry/quart/views.py | 6 +- strawberry/sanic/views.py | 5 +- 11 files changed, 227 insertions(+), 65 deletions(-) diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index eee14dff5d..f1f836b526 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -210,7 +210,9 @@ async def get_context( return {"request": request, "response": response} # type: ignore def create_response( - self, response_data: GraphQLHTTPResponse, sub_response: web.Response + self, + response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + sub_response: web.Response, ) -> web.Response: sub_response.text = self.encode_json(response_data) sub_response.content_type = "application/json" diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index 1a1845d39f..f2fbe217c2 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -205,7 +205,9 @@ async def render_graphql_ide(self, request: Request) -> Response: return HTMLResponse(self.graphql_ide_html) def create_response( - self, response_data: GraphQLHTTPResponse, sub_response: Response + self, + response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + sub_response: Response, ) -> Response: response = Response( self.encode_json(response_data), diff --git a/strawberry/chalice/views.py b/strawberry/chalice/views.py index 9d5c424402..cb3aa709fd 100644 --- a/strawberry/chalice/views.py +++ b/strawberry/chalice/views.py @@ -114,7 +114,9 @@ def get_context(self, request: Request, response: TemporalResponse) -> Context: return {"request": request, "response": response} # type: ignore def create_response( - self, response_data: GraphQLHTTPResponse, sub_response: TemporalResponse + self, + response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + sub_response: TemporalResponse, ) -> Response: status_code = 200 diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index 7281f53cdf..a00702fdc3 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -5,13 +5,7 @@ import warnings from functools import cached_property from io import BytesIO -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Optional, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, Optional, Union from typing_extensions import TypeGuard, assert_never from urllib.parse import parse_qs @@ -186,7 +180,9 @@ def __init__( super().__init__(**kwargs) def create_response( - self, response_data: GraphQLHTTPResponse, sub_response: TemporalResponse + self, + response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + sub_response: TemporalResponse, ) -> ChannelsResponse: return ChannelsResponse( content=json.dumps(response_data).encode(), diff --git a/strawberry/django/views.py b/strawberry/django/views.py index c7b01f6d59..4245cc5906 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -163,7 +163,9 @@ def __init__( super().__init__(**kwargs) def create_response( - self, response_data: GraphQLHTTPResponse, sub_response: HttpResponse + self, + response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + sub_response: HttpResponse, ) -> HttpResponseBase: data = self.encode_json(response_data) diff --git a/strawberry/fastapi/router.py b/strawberry/fastapi/router.py index f43c9ecb0c..a90f7aa0ce 100644 --- a/strawberry/fastapi/router.py +++ b/strawberry/fastapi/router.py @@ -274,7 +274,9 @@ async def get_sub_response(self, request: Request) -> Response: return self.temporal_response def create_response( - self, response_data: GraphQLHTTPResponse, sub_response: Response + self, + response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + sub_response: Response, ) -> Response: response = Response( self.encode_json(response_data), diff --git a/strawberry/flask/views.py b/strawberry/flask/views.py index f33a3e44e7..e09ed51f57 100644 --- a/strawberry/flask/views.py +++ b/strawberry/flask/views.py @@ -91,7 +91,9 @@ def __init__( self.graphql_ide = graphql_ide def create_response( - self, response_data: GraphQLHTTPResponse, sub_response: Response + self, + response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + sub_response: Response, ) -> Response: sub_response.set_data(self.encode_json(response_data)) # type: ignore diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index b73eb55181..91d9a7cb8f 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -126,6 +126,8 @@ class AsyncBaseHTTPView( BaseGraphQLWSHandler[Context, RootValue] ) + batch: bool + @property @abc.abstractmethod def allow_queries_via_get(self) -> bool: ... @@ -147,7 +149,9 @@ async def get_root_value( @abc.abstractmethod def create_response( - self, response_data: GraphQLHTTPResponse, sub_response: SubResponse + self, + response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + sub_response: SubResponse, ) -> Response: ... @abc.abstractmethod @@ -178,8 +182,12 @@ async def create_websocket_response( ) -> WebSocketResponse: ... async def execute_operation( - self, request: Request, context: Context, root_value: Optional[RootValue] - ) -> Union[ExecutionResult, SubscriptionExecutionResult]: + self, + request: Request, + context: Context, + root_value: Optional[RootValue], + sub_response: SubResponse, + ) -> Union[ExecutionResult, list[ExecutionResult], SubscriptionExecutionResult]: request_adapter = self.request_adapter_class(request) try: @@ -197,6 +205,22 @@ async def execute_operation( assert self.schema + if isinstance(request_data, list): + # batch GraphQL requests + tasks = [ + self.execute_single( + request=request, + request_adapter=request_adapter, + sub_response=sub_response, + context=context, + root_value=root_value, + request_data=data, + ) + for data in request_data + ] + + return await asyncio.gather(*tasks) + if request_data.protocol == "multipart-subscription": return await self.schema.subscribe( request_data.query, # type: ignore @@ -206,15 +230,49 @@ async def execute_operation( operation_name=request_data.operation_name, ) - return await self.schema.execute( - request_data.query, + return await self.execute_single( + request=request, + request_adapter=request_adapter, + sub_response=sub_response, + context=context, root_value=root_value, - variable_values=request_data.variables, - context_value=context, - operation_name=request_data.operation_name, - allowed_operation_types=allowed_operation_types, + request_data=request_data, ) + async def execute_single( + self, + request: Request, + request_adapter: AsyncHTTPRequestAdapter, + sub_response: SubResponse, + context: Context, + root_value: Optional[RootValue], + request_data: GraphQLRequestData, + ) -> ExecutionResult: + allowed_operation_types = OperationType.from_http(request_adapter.method) + + if not self.allow_queries_via_get and request_adapter.method == "GET": + allowed_operation_types = allowed_operation_types - {OperationType.QUERY} + + assert self.schema + + try: + result = await self.schema.execute( + request_data.query, + root_value=root_value, + variable_values=request_data.variables, + context_value=context, + operation_name=request_data.operation_name, + allowed_operation_types=allowed_operation_types, + ) + except InvalidOperationTypeError as e: + raise HTTPException( + 400, e.as_http_error_reason(request_adapter.method) + ) from e + except MissingQueryError as e: + raise HTTPException(400, "No GraphQL query found in the request") from e + + return result + async def parse_multipart(self, request: AsyncHTTPRequestAdapter) -> dict[str, str]: try: form_data = await request.get_form_data() @@ -326,16 +384,12 @@ async def run( return await self.render_graphql_ide(request) raise HTTPException(404, "Not Found") - try: - result = await self.execute_operation( - request=request, context=context, root_value=root_value - ) - except InvalidOperationTypeError as e: - raise HTTPException( - 400, e.as_http_error_reason(request_adapter.method) - ) from e - except MissingQueryError as e: - raise HTTPException(400, "No GraphQL query found in the request") from e + result = await self.execute_operation( + request=request, + context=context, + root_value=root_value, + sub_response=sub_response, + ) if isinstance(result, SubscriptionExecutionResult): stream = self._get_stream(request, result) @@ -350,10 +404,20 @@ async def run( }, ) - response_data = await self.process_result(request=request, result=result) + if isinstance(result, list): + response_data = [] + for execution_result in result: + result = await self.process_result( + request=request, result=execution_result + ) + if execution_result.errors: + self._handle_errors(execution_result.errors, result) + response_data.append(result) + else: + response_data = await self.process_result(request=request, result=result) - if result.errors: - self._handle_errors(result.errors, response_data) + if result.errors: + self._handle_errors(result.errors, response_data) return self.create_response( response_data=response_data, sub_response=sub_response @@ -449,7 +513,7 @@ async def parse_multipart_subscriptions( async def parse_http_body( self, request: AsyncHTTPRequestAdapter - ) -> GraphQLRequestData: + ) -> Union[GraphQLRequestData, list[GraphQLRequestData]]: headers = {key.lower(): value for key, value in request.headers.items()} content_type, _ = parse_content_type(request.content_type or "") accept = headers.get("accept", "") @@ -468,6 +532,19 @@ async def parse_http_body( else: raise HTTPException(400, "Unsupported content type") + if isinstance(data, list): + if protocol == "multipart-subscription" or not self.batch: + # note: multipart-subscriptions are not supported in batch requests + raise HTTPException(400, "Batch requests are not supported") + return [ + GraphQLRequestData( + query=item.get("query"), + variables=item.get("variables"), + operation_name=item.get("operationName"), + protocol=protocol, + ) + for item in data + ] return GraphQLRequestData( query=data.get("query"), variables=data.get("variables"), diff --git a/strawberry/http/sync_base_view.py b/strawberry/http/sync_base_view.py index 149d4b50e6..b5b8f4e455 100644 --- a/strawberry/http/sync_base_view.py +++ b/strawberry/http/sync_base_view.py @@ -72,6 +72,8 @@ class SyncBaseHTTPView( graphql_ide: Optional[GraphQL_IDE] request_adapter_class: Callable[[Request], SyncHTTPRequestAdapter] + batch: bool = False + # Methods that need to be implemented by individual frameworks @property @@ -89,15 +91,21 @@ def get_root_value(self, request: Request) -> Optional[RootValue]: ... @abc.abstractmethod def create_response( - self, response_data: GraphQLHTTPResponse, sub_response: SubResponse + self, + response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + sub_response: SubResponse, ) -> Response: ... @abc.abstractmethod def render_graphql_ide(self, request: Request) -> Response: ... def execute_operation( - self, request: Request, context: Context, root_value: Optional[RootValue] - ) -> ExecutionResult: + self, + request: Request, + context: Context, + root_value: Optional[RootValue], + sub_response: SubResponse, + ) -> Union[ExecutionResult, list[ExecutionResult]]: request_adapter = self.request_adapter_class(request) try: @@ -115,15 +123,63 @@ def execute_operation( assert self.schema - return self.schema.execute_sync( - request_data.query, + if isinstance(request_data, list): + # batch GraphQL requests + return [ + self.execute_single( + request=request, + request_adapter=request_adapter, + sub_response=sub_response, + context=context, + root_value=root_value, + request_data=data, + ) + for data in request_data + ] + + return self.execute_single( + request=request, + request_adapter=request_adapter, + sub_response=sub_response, + context=context, root_value=root_value, - variable_values=request_data.variables, - context_value=context, - operation_name=request_data.operation_name, - allowed_operation_types=allowed_operation_types, + request_data=request_data, ) + def execute_single( + self, + request: Request, + request_adapter: SyncHTTPRequestAdapter, + sub_response: SubResponse, + context: Context, + root_value: Optional[RootValue], + request_data: GraphQLRequestData, + ) -> ExecutionResult: + allowed_operation_types = OperationType.from_http(request_adapter.method) + + if not self.allow_queries_via_get and request_adapter.method == "GET": + allowed_operation_types = allowed_operation_types - {OperationType.QUERY} + + assert self.schema + + try: + result = self.schema.execute_sync( + request_data.query, + root_value=root_value, + variable_values=request_data.variables, + context_value=context, + operation_name=request_data.operation_name, + allowed_operation_types=allowed_operation_types, + ) + except InvalidOperationTypeError as e: + raise HTTPException( + 400, e.as_http_error_reason(request_adapter.method) + ) from e + except MissingQueryError as e: + raise HTTPException(400, "No GraphQL query found in the request") from e + + return result + def parse_multipart(self, request: SyncHTTPRequestAdapter) -> dict[str, str]: operations = self.parse_json(request.post_data.get("operations", "{}")) files_map = self.parse_json(request.post_data.get("map", "{}")) @@ -133,7 +189,9 @@ def parse_multipart(self, request: SyncHTTPRequestAdapter) -> dict[str, str]: except KeyError as e: raise HTTPException(400, "File(s) missing in form data") from e - def parse_http_body(self, request: SyncHTTPRequestAdapter) -> GraphQLRequestData: + def parse_http_body( + self, request: SyncHTTPRequestAdapter + ) -> Union[GraphQLRequestData, list[GraphQLRequestData]]: content_type, params = parse_content_type(request.content_type or "") if request.method == "GET": @@ -150,6 +208,18 @@ def parse_http_body(self, request: SyncHTTPRequestAdapter) -> GraphQLRequestData else: raise HTTPException(400, "Unsupported content type") + if isinstance(data, list): + if not self.batch: + # note: multipart-subscriptions are not supported in batch requests + raise HTTPException(400, "Batch requests are not supported") + return [ + GraphQLRequestData( + query=item.get("query"), + variables=item.get("variables"), + operation_name=item.get("operationName"), + ) + for item in data + ] return GraphQLRequestData( query=data.get("query"), variables=data.get("variables"), @@ -187,23 +257,25 @@ def run( assert context - try: - result = self.execute_operation( - request=request, - context=context, - root_value=root_value, - ) - except InvalidOperationTypeError as e: - raise HTTPException( - 400, e.as_http_error_reason(request_adapter.method) - ) from e - except MissingQueryError as e: - raise HTTPException(400, "No GraphQL query found in the request") from e + result = self.execute_operation( + request=request, + context=context, + root_value=root_value, + sub_response=sub_response, + ) - response_data = self.process_result(request=request, result=result) + if isinstance(result, list): + response_data = [] + for execution_result in result: + result = self.process_result(request=request, result=execution_result) + if execution_result.errors: + self._handle_errors(execution_result.errors, result) + response_data.append(result) + else: + response_data = self.process_result(request=request, result=result) - if result.errors: - self._handle_errors(result.errors, response_data) + if result.errors: + self._handle_errors(result.errors, response_data) return self.create_response( response_data=response_data, sub_response=sub_response diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index 3a1ff28058..ac856ddc8f 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -1,6 +1,6 @@ import warnings from collections.abc import AsyncGenerator, Mapping -from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast +from typing import TYPE_CHECKING, Callable, ClassVar, Optional, Union, cast from typing_extensions import TypeGuard from quart import Request, Response, request @@ -82,7 +82,9 @@ async def render_graphql_ide(self, request: Request) -> Response: return Response(self.graphql_ide_html) def create_response( - self, response_data: "GraphQLHTTPResponse", sub_response: Response + self, + response_data: Union["GraphQLHTTPResponse", list["GraphQLHTTPResponse"]], + sub_response: Response, ) -> Response: sub_response.set_data(self.encode_json(response_data)) diff --git a/strawberry/sanic/views.py b/strawberry/sanic/views.py index 9044df2b0a..1f7888201c 100644 --- a/strawberry/sanic/views.py +++ b/strawberry/sanic/views.py @@ -7,6 +7,7 @@ Any, Callable, Optional, + Union, cast, ) from typing_extensions import TypeGuard @@ -158,7 +159,9 @@ async def get_sub_response(self, request: Request) -> TemporalResponse: return TemporalResponse() def create_response( - self, response_data: GraphQLHTTPResponse, sub_response: TemporalResponse + self, + response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + sub_response: TemporalResponse, ) -> HTTPResponse: status_code = sub_response.status_code From a0c3d29cbc1063eea0a14bfc9613bed150fbcc10 Mon Sep 17 00:00:00 2001 From: Aryan Iyappan Date: Wed, 15 Jan 2025 11:01:16 +0530 Subject: [PATCH 02/12] pass batch config from views --- strawberry/aiohttp/views.py | 2 ++ strawberry/asgi/__init__.py | 2 ++ strawberry/chalice/views.py | 2 ++ strawberry/channels/handlers/http_handler.py | 4 ++++ strawberry/django/views.py | 2 ++ strawberry/fastapi/router.py | 1 + strawberry/flask/views.py | 2 ++ strawberry/quart/views.py | 1 + strawberry/sanic/views.py | 3 +++ 9 files changed, 19 insertions(+) diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index f1f836b526..0d94a6e1fd 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -150,6 +150,7 @@ def __init__( ), connection_init_wait_timeout: timedelta = timedelta(minutes=1), multipart_uploads_enabled: bool = False, + batch: bool = False, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get @@ -159,6 +160,7 @@ def __init__( self.subscription_protocols = subscription_protocols self.connection_init_wait_timeout = connection_init_wait_timeout self.multipart_uploads_enabled = multipart_uploads_enabled + self.batch = batch if graphiql is not None: warnings.warn( diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index f2fbe217c2..407d880ca4 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -145,6 +145,7 @@ def __init__( ), connection_init_wait_timeout: timedelta = timedelta(minutes=1), multipart_uploads_enabled: bool = False, + batch: bool = False, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get @@ -154,6 +155,7 @@ def __init__( self.protocols = subscription_protocols self.connection_init_wait_timeout = connection_init_wait_timeout self.multipart_uploads_enabled = multipart_uploads_enabled + self.batch = batch if graphiql is not None: warnings.warn( diff --git a/strawberry/chalice/views.py b/strawberry/chalice/views.py index cb3aa709fd..5c758d6e9b 100644 --- a/strawberry/chalice/views.py +++ b/strawberry/chalice/views.py @@ -63,9 +63,11 @@ def __init__( graphiql: Optional[bool] = None, graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, + batch: bool = False, ) -> None: self.allow_queries_via_get = allow_queries_via_get self.schema = schema + self.batch = batch if graphiql is not None: warnings.warn( "The `graphiql` argument is deprecated in favor of `graphql_ide`", diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index a00702fdc3..31b5846778 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -161,11 +161,13 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, multipart_uploads_enabled: bool = False, + batch: bool = False, **kwargs: Any, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get self.multipart_uploads_enabled = multipart_uploads_enabled + self.batch = batch if graphiql is not None: warnings.warn( @@ -252,6 +254,7 @@ class GraphQLHTTPConsumer( ``` """ + batch: bool = False allow_queries_via_get: bool = True request_adapter_class = ChannelsRequestAdapter @@ -325,6 +328,7 @@ class SyncGraphQLHTTPConsumer( synchronous and not asynchronous). """ + batch: bool = False allow_queries_via_get: bool = True request_adapter_class = SyncChannelsRequestAdapter diff --git a/strawberry/django/views.py b/strawberry/django/views.py index 4245cc5906..b27742a137 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -217,6 +217,7 @@ class GraphQLView( allow_queries_via_get = True schema: BaseSchema = None # type: ignore request_adapter_class = DjangoHTTPRequestAdapter + batch: bool = False def get_root_value(self, request: HttpRequest) -> Optional[RootValue]: return None @@ -265,6 +266,7 @@ class AsyncGraphQLView( allow_queries_via_get = True schema: BaseSchema = None # type: ignore request_adapter_class = AsyncDjangoHTTPRequestAdapter + batch: bool = False @classonlymethod # pyright: ignore[reportIncompatibleMethodOverride] def as_view(cls, **initkwargs: Any) -> Callable[..., HttpResponse]: # noqa: N805 diff --git a/strawberry/fastapi/router.py b/strawberry/fastapi/router.py index a90f7aa0ce..0c7abdd698 100644 --- a/strawberry/fastapi/router.py +++ b/strawberry/fastapi/router.py @@ -59,6 +59,7 @@ class GraphQLRouter( allow_queries_via_get = True request_adapter_class = ASGIRequestAdapter websocket_adapter_class = ASGIWebSocketAdapter + batch: bool = False @staticmethod async def __get_root_value() -> None: diff --git a/strawberry/flask/views.py b/strawberry/flask/views.py index e09ed51f57..36068e83c9 100644 --- a/strawberry/flask/views.py +++ b/strawberry/flask/views.py @@ -108,6 +108,7 @@ class GraphQLView( methods: ClassVar[list[str]] = ["GET", "POST"] allow_queries_via_get: bool = True request_adapter_class = FlaskHTTPRequestAdapter + batch: bool = False def get_context(self, request: Request, response: Response) -> Context: return {"request": request, "response": response} # type: ignore @@ -171,6 +172,7 @@ class AsyncGraphQLView( methods: ClassVar[list[str]] = ["GET", "POST"] allow_queries_via_get: bool = True request_adapter_class = AsyncFlaskHTTPRequestAdapter + batch: bool = False async def get_context(self, request: Request, response: Response) -> Context: return {"request": request, "response": response} # type: ignore diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index ac856ddc8f..e52736743f 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -55,6 +55,7 @@ class GraphQLView( methods: ClassVar[list[str]] = ["GET", "POST"] allow_queries_via_get: bool = True request_adapter_class = QuartHTTPRequestAdapter + batch: bool = False def __init__( self, diff --git a/strawberry/sanic/views.py b/strawberry/sanic/views.py index 1f7888201c..643b23a253 100644 --- a/strawberry/sanic/views.py +++ b/strawberry/sanic/views.py @@ -101,6 +101,7 @@ class GraphQLView( allow_queries_via_get = True request_adapter_class = SanicHTTPRequestAdapter + batch: bool = False def __init__( self, @@ -111,12 +112,14 @@ def __init__( json_encoder: Optional[type[json.JSONEncoder]] = None, json_dumps_params: Optional[dict[str, Any]] = None, multipart_uploads_enabled: bool = False, + batch: bool = False, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get self.json_encoder = json_encoder self.json_dumps_params = json_dumps_params self.multipart_uploads_enabled = multipart_uploads_enabled + self.batch = batch if self.json_encoder is not None: # pragma: no cover warnings.warn( From c8c0d7c8facf61d8b46d7beec875f65aaeed9ac2 Mon Sep 17 00:00:00 2001 From: Aryan Iyappan Date: Wed, 15 Jan 2025 11:34:58 +0530 Subject: [PATCH 03/12] pass basic batching tests --- strawberry/asgi/__init__.py | 1 + strawberry/fastapi/router.py | 2 ++ strawberry/flask/views.py | 2 ++ strawberry/http/async_base_view.py | 8 ++++-- strawberry/http/sync_base_view.py | 3 +-- strawberry/litestar/controller.py | 7 +++++- strawberry/quart/views.py | 2 ++ tests/http/clients/aiohttp.py | 2 ++ tests/http/clients/asgi.py | 2 ++ tests/http/clients/async_django.py | 1 + tests/http/clients/async_flask.py | 2 ++ tests/http/clients/base.py | 3 ++- tests/http/clients/chalice.py | 2 ++ tests/http/clients/channels.py | 4 +++ tests/http/clients/django.py | 3 +++ tests/http/clients/fastapi.py | 2 ++ tests/http/clients/flask.py | 2 ++ tests/http/clients/litestar.py | 2 ++ tests/http/clients/quart.py | 2 ++ tests/http/clients/sanic.py | 2 ++ tests/http/test_query_batching.py | 39 ++++++++++++++++++++++++++++++ 21 files changed, 87 insertions(+), 6 deletions(-) create mode 100644 tests/http/test_query_batching.py diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index 407d880ca4..35ceed082f 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -129,6 +129,7 @@ class GraphQL( allow_queries_via_get = True request_adapter_class = ASGIRequestAdapter websocket_adapter_class = ASGIWebSocketAdapter + batch: bool = False def __init__( self, diff --git a/strawberry/fastapi/router.py b/strawberry/fastapi/router.py index 0c7abdd698..5dd16b738f 100644 --- a/strawberry/fastapi/router.py +++ b/strawberry/fastapi/router.py @@ -152,6 +152,7 @@ def __init__( generate_unique_id ), multipart_uploads_enabled: bool = False, + batch: bool = False, **kwargs: Any, ) -> None: super().__init__( @@ -187,6 +188,7 @@ def __init__( self.protocols = subscription_protocols self.connection_init_wait_timeout = connection_init_wait_timeout self.multipart_uploads_enabled = multipart_uploads_enabled + self.batch = batch if graphiql is not None: warnings.warn( diff --git a/strawberry/flask/views.py b/strawberry/flask/views.py index 36068e83c9..e6074c979e 100644 --- a/strawberry/flask/views.py +++ b/strawberry/flask/views.py @@ -74,11 +74,13 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, multipart_uploads_enabled: bool = False, + batch: bool = False, ) -> None: self.schema = schema self.graphiql = graphiql self.allow_queries_via_get = allow_queries_via_get self.multipart_uploads_enabled = multipart_uploads_enabled + self.batch = batch if graphiql is not None: warnings.warn( diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 91d9a7cb8f..5f15da7262 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -533,9 +533,13 @@ async def parse_http_body( raise HTTPException(400, "Unsupported content type") if isinstance(data, list): - if protocol == "multipart-subscription" or not self.batch: + if protocol == "multipart-subscription": # note: multipart-subscriptions are not supported in batch requests - raise HTTPException(400, "Batch requests are not supported") + raise HTTPException( + 400, "Batching is not supported for multipart subscriptions" + ) + if not self.batch: + raise HTTPException(400, "Batching is not enabled") return [ GraphQLRequestData( query=item.get("query"), diff --git a/strawberry/http/sync_base_view.py b/strawberry/http/sync_base_view.py index b5b8f4e455..c8ecaed81d 100644 --- a/strawberry/http/sync_base_view.py +++ b/strawberry/http/sync_base_view.py @@ -210,8 +210,7 @@ def parse_http_body( if isinstance(data, list): if not self.batch: - # note: multipart-subscriptions are not supported in batch requests - raise HTTPException(400, "Batch requests are not supported") + raise HTTPException(400, "Batching is not enabled") return [ GraphQLRequestData( query=item.get("query"), diff --git a/strawberry/litestar/controller.py b/strawberry/litestar/controller.py index 1edf2d5a89..a383597492 100644 --- a/strawberry/litestar/controller.py +++ b/strawberry/litestar/controller.py @@ -259,6 +259,7 @@ class GraphQLController( ) keep_alive: bool = False keep_alive_interval: float = 1 + batch: bool = False def is_websocket_request( self, request: Union[Request, WebSocket] @@ -302,7 +303,9 @@ async def render_graphql_ide( return Response(self.graphql_ide_html, media_type=MediaType.HTML) def create_response( - self, response_data: GraphQLHTTPResponse, sub_response: Response[bytes] + self, + response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]], + sub_response: Response[bytes], ) -> Response[bytes]: response = Response( self.encode_json(response_data).encode(), @@ -417,6 +420,7 @@ def make_graphql_controller( ), connection_init_wait_timeout: timedelta = timedelta(minutes=1), multipart_uploads_enabled: bool = False, + batch: bool = False, ) -> type[GraphQLController]: # sourcery skip: move-assign if context_getter is None: custom_context_getter_ = _none_custom_context_getter @@ -464,6 +468,7 @@ class _GraphQLController(GraphQLController): _GraphQLController.allow_queries_via_get = allow_queries_via_get_ _GraphQLController.graphql_ide = graphql_ide_ _GraphQLController.multipart_uploads_enabled = multipart_uploads_enabled + _GraphQLController.batch = batch return _GraphQLController diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index e52736743f..2136b9a481 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -64,10 +64,12 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, multipart_uploads_enabled: bool = False, + batch: bool = False, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get self.multipart_uploads_enabled = multipart_uploads_enabled + self.batch = batch if graphiql is not None: warnings.warn( diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index 07f588c349..e991af95f9 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -64,6 +64,7 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, + batch: bool = False, ): view = GraphQLView( schema=schema, @@ -72,6 +73,7 @@ def __init__( allow_queries_via_get=allow_queries_via_get, keep_alive=False, multipart_uploads_enabled=multipart_uploads_enabled, + batch=batch, ) view.result_override = result_override diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index a354dcb935..96a511ff8d 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -66,6 +66,7 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, + batch: bool = False, ): view = GraphQLView( schema, @@ -74,6 +75,7 @@ def __init__( allow_queries_via_get=allow_queries_via_get, keep_alive=False, multipart_uploads_enabled=multipart_uploads_enabled, + batch=batch, ) view.result_override = result_override diff --git a/tests/http/clients/async_django.py b/tests/http/clients/async_django.py index 870d27f6ed..c1165472c9 100644 --- a/tests/http/clients/async_django.py +++ b/tests/http/clients/async_django.py @@ -47,6 +47,7 @@ async def _do_request(self, request: HttpRequest) -> Response: allow_queries_via_get=self.allow_queries_via_get, result_override=self.result_override, multipart_uploads_enabled=self.multipart_uploads_enabled, + batch=self.batch, ) try: diff --git a/tests/http/clients/async_flask.py b/tests/http/clients/async_flask.py index 1ad8cb0356..c8aef354a5 100644 --- a/tests/http/clients/async_flask.py +++ b/tests/http/clients/async_flask.py @@ -53,6 +53,7 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, + batch: bool = False, ): self.app = Flask(__name__) self.app.debug = True @@ -65,6 +66,7 @@ def __init__( allow_queries_via_get=allow_queries_via_get, result_override=result_override, multipart_uploads_enabled=multipart_uploads_enabled, + batch=batch, ) self.app.add_url_rule( diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index 786ec4f8bd..60453a1fc8 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -102,6 +102,7 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, + batch: bool = False, ): ... @abc.abstractmethod @@ -135,7 +136,7 @@ async def post( self, url: str, data: Optional[bytes] = None, - json: Optional[JSON] = None, + json: Optional[Union[JSON, list[JSON]]] = None, headers: Optional[dict[str, str]] = None, ) -> Response: ... diff --git a/tests/http/clients/chalice.py b/tests/http/clients/chalice.py index 3ea6189326..f1d9ca8b8b 100644 --- a/tests/http/clients/chalice.py +++ b/tests/http/clients/chalice.py @@ -51,6 +51,7 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, + batch: bool = False, ): self.app = Chalice(app_name="TheStackBadger") @@ -59,6 +60,7 @@ def __init__( graphiql=graphiql, graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, + batch=batch, ) view.result_override = result_override diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index 4a89324a70..0ce2832f2e 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -142,6 +142,7 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, + batch: bool = False, ): self.ws_app = DebuggableGraphQLWSConsumer.as_asgi( schema=schema, @@ -155,6 +156,7 @@ def __init__( allow_queries_via_get=allow_queries_via_get, result_override=result_override, multipart_uploads_enabled=multipart_uploads_enabled, + batch=batch, ) def create_app(self, **kwargs: Any) -> None: @@ -266,6 +268,7 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, + batch: bool = False, ): self.http_app = DebuggableSyncGraphQLHTTPConsumer.as_asgi( schema=schema, @@ -274,6 +277,7 @@ def __init__( allow_queries_via_get=allow_queries_via_get, result_override=result_override, multipart_uploads_enabled=multipart_uploads_enabled, + batch=batch, ) diff --git a/tests/http/clients/django.py b/tests/http/clients/django.py index 1a2301ad07..8bbd9966fd 100644 --- a/tests/http/clients/django.py +++ b/tests/http/clients/django.py @@ -51,12 +51,14 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, + batch: bool = False, ): self.graphiql = graphiql self.graphql_ide = graphql_ide self.allow_queries_via_get = allow_queries_via_get self.result_override = result_override self.multipart_uploads_enabled = multipart_uploads_enabled + self.batch = batch def _get_header_name(self, key: str) -> str: return f"HTTP_{key.upper().replace('-', '_')}" @@ -80,6 +82,7 @@ async def _do_request(self, request: HttpRequest) -> Response: allow_queries_via_get=self.allow_queries_via_get, result_override=self.result_override, multipart_uploads_enabled=self.multipart_uploads_enabled, + batch=self.batch, ) try: diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index c5a8da97da..edfd6630e8 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -76,6 +76,7 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, + batch: bool = False, ): self.app = FastAPI() @@ -88,6 +89,7 @@ def __init__( allow_queries_via_get=allow_queries_via_get, keep_alive=False, multipart_uploads_enabled=multipart_uploads_enabled, + batch=batch, ) graphql_app.result_override = result_override self.app.include_router(graphql_app, prefix="/graphql") diff --git a/tests/http/clients/flask.py b/tests/http/clients/flask.py index 644f30095c..090599008d 100644 --- a/tests/http/clients/flask.py +++ b/tests/http/clients/flask.py @@ -62,6 +62,7 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, + batch: bool = False, ): self.app = Flask(__name__) self.app.debug = True @@ -74,6 +75,7 @@ def __init__( allow_queries_via_get=allow_queries_via_get, result_override=result_override, multipart_uploads_enabled=multipart_uploads_enabled, + batch=batch, ) self.app.add_url_rule( diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index 48c4f0703d..9afcb63474 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -51,6 +51,7 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, + batch: bool = False, ): self.create_app( graphiql=graphiql, @@ -58,6 +59,7 @@ def __init__( allow_queries_via_get=allow_queries_via_get, result_override=result_override, multipart_uploads_enabled=multipart_uploads_enabled, + batch=batch, ) def create_app(self, result_override: ResultOverrideFunction = None, **kwargs: Any): diff --git a/tests/http/clients/quart.py b/tests/http/clients/quart.py index 1711e58b45..7853358b49 100644 --- a/tests/http/clients/quart.py +++ b/tests/http/clients/quart.py @@ -55,6 +55,7 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, + batch: bool = False, ): self.app = Quart(__name__) self.app.debug = True @@ -67,6 +68,7 @@ def __init__( allow_queries_via_get=allow_queries_via_get, result_override=result_override, multipart_uploads_enabled=multipart_uploads_enabled, + batch=batch, ) self.app.add_url_rule( diff --git a/tests/http/clients/sanic.py b/tests/http/clients/sanic.py index f43b324afe..05d5f4e904 100644 --- a/tests/http/clients/sanic.py +++ b/tests/http/clients/sanic.py @@ -54,6 +54,7 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, + batch: bool = False, ): self.app = Sanic( f"test_{int(randint(0, 1000))}", # noqa: S311 @@ -65,6 +66,7 @@ def __init__( allow_queries_via_get=allow_queries_via_get, result_override=result_override, multipart_uploads_enabled=multipart_uploads_enabled, + batch=batch, ) self.app.add_route( view, diff --git a/tests/http/test_query_batching.py b/tests/http/test_query_batching.py new file mode 100644 index 0000000000..6128e8441e --- /dev/null +++ b/tests/http/test_query_batching.py @@ -0,0 +1,39 @@ +from .clients.base import HttpClient + + +async def test_batch_graphql_query(http_client_class: type[HttpClient]): + http_client = http_client_class(batch=True) + + response = await http_client.post( + url="/graphql", + json=[ + {"query": "{ hello }"}, + {"query": "{ hello }"}, + ], + headers={"content-type": "application/json"}, + ) + + assert response.status_code == 200 + assert response.json == [ + {"data": {"hello": "Hello world"}, "extensions": {"example": "example"}}, + {"data": {"hello": "Hello world"}, "extensions": {"example": "example"}}, + ] + + +async def test_returns_error_when_batching_is_disabled( + http_client_class: type[HttpClient], +): + http_client = http_client_class(batch=False) + + response = await http_client.post( + url="/graphql", + json=[ + {"query": "{ hello }"}, + {"query": "{ hello }"}, + ], + headers={"content-type": "application/json"}, + ) + + assert response.status_code == 400 + + assert "Batching is not enabled" in response.text From 88f1961cda831c12928de5deed773d3ed9a875b6 Mon Sep 17 00:00:00 2001 From: Aryan Iyappan Date: Wed, 15 Jan 2025 11:46:22 +0530 Subject: [PATCH 04/12] add test case: ensure batching is not supported for multipart subscriptions --- tests/http/test_query_batching.py | 75 +++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/tests/http/test_query_batching.py b/tests/http/test_query_batching.py index 6128e8441e..06f311d808 100644 --- a/tests/http/test_query_batching.py +++ b/tests/http/test_query_batching.py @@ -1,6 +1,61 @@ +import contextlib + +import pytest + from .clients.base import HttpClient +@pytest.fixture +def multipart_subscriptions_batch_http_client( + http_client_class: type[HttpClient], +) -> HttpClient: + with contextlib.suppress(ImportError): + import django + + if django.VERSION < (4, 2): + pytest.skip(reason="Django < 4.2 doesn't async streaming responses") + + from .clients.django import DjangoHttpClient + + if http_client_class is DjangoHttpClient: + pytest.skip( + reason="(sync) DjangoHttpClient doesn't support multipart subscriptions" + ) + + with contextlib.suppress(ImportError): + from .clients.channels import SyncChannelsHttpClient + + # TODO: why do we have a sync channels client? + if http_client_class is SyncChannelsHttpClient: + pytest.skip( + reason="SyncChannelsHttpClient doesn't support multipart subscriptions" + ) + + with contextlib.suppress(ImportError): + from .clients.async_flask import AsyncFlaskHttpClient + from .clients.flask import FlaskHttpClient + + if http_client_class is FlaskHttpClient: + pytest.skip( + reason="FlaskHttpClient doesn't support multipart subscriptions" + ) + + if http_client_class is AsyncFlaskHttpClient: + pytest.xfail( + reason="AsyncFlaskHttpClient doesn't support multipart subscriptions" + ) + + with contextlib.suppress(ImportError): + from .clients.chalice import ChaliceHttpClient + + if http_client_class is ChaliceHttpClient: + pytest.skip( + reason="ChaliceHttpClient doesn't support multipart subscriptions" + ) + + return http_client_class(batch=True) + + async def test_batch_graphql_query(http_client_class: type[HttpClient]): http_client = http_client_class(batch=True) @@ -37,3 +92,23 @@ async def test_returns_error_when_batching_is_disabled( assert response.status_code == 400 assert "Batching is not enabled" in response.text + + +async def test_returns_error_for_multipart_subscriptions( + multipart_subscriptions_batch_http_client: HttpClient, +): + response = await multipart_subscriptions_batch_http_client.post( + url="/graphql", + json=[ + {"query": 'subscription { echo(message: "Hello world", delay: 0.2) }'}, + {"query": 'subscription { echo(message: "Hello world", delay: 0.2) }'}, + ], + headers={ + "content-type": "application/json", + "accept": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + }, + ) + + assert response.status_code == 400 + + assert "Batching is not supported for multipart subscriptions" in response.text From 287f615fa4a378c7e97ca44d9f3edee2089a1cee Mon Sep 17 00:00:00 2001 From: Aryan Iyappan Date: Wed, 15 Jan 2025 11:54:12 +0530 Subject: [PATCH 05/12] add RELEASE.md --- RELEASE.md | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 RELEASE.md diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..75b6c22b32 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,53 @@ +Release type: minor + +## Add GraphQL Query batching support + +GraphQL query batching is now supported across all frameworks (sync and async) +To enable query batching, set the `batch` parameter to True at the view level. + +This makes your GraphQL API compatible with batching features supported by various +client side libraries, such as [Apollo GraphQL](https://www.apollographql.com/docs/react/api/link/apollo-link-batch-http) and [Relay](https://github.com/relay-tools/react-relay-network-modern?tab=readme-ov-file#batching-several-requests-into-one). + +Example (FastAPI): + +```py +import strawberry + +from fastapi import FastAPI +from strawberry.fastapi import GraphQLRouter + + +@strawberry.type +class Query: + @strawberry.field + def hello(self) -> str: + return "Hello World" + + +schema = strawberry.Schema(Query) + +graphql_app = GraphQLRouter(schema, batch=True) + +app = FastAPI() +app.include_router(graphql_app, prefix="/graphql/batch") +``` + +Example (Flask): +```py +from flask import Flask +from strawberry.flask.views import GraphQLView + +from api.schema import schema + +app = Flask(__name__) + +app.add_url_rule( + "/graphql/batch", + view_func=GraphQLView.as_view("graphql_view", schema=schema, batch=True), +) + +if __name__ == "__main__": + app.run() +``` + +Note: Query Batching is not supported for multipart subscriptions From a15a4f7dcfbca618d22bbf5c9c0e654256bb09ee Mon Sep 17 00:00:00 2001 From: Aryan Iyappan Date: Wed, 15 Jan 2025 13:13:15 +0530 Subject: [PATCH 06/12] pass batch config via schema config --- strawberry/aiohttp/views.py | 2 -- strawberry/asgi/__init__.py | 3 --- strawberry/chalice/views.py | 2 -- strawberry/channels/handlers/http_handler.py | 4 ---- strawberry/django/apps.py | 2 +- strawberry/django/views.py | 2 -- strawberry/fastapi/router.py | 3 --- strawberry/flask/views.py | 4 ---- strawberry/http/async_base_view.py | 15 ++++++--------- strawberry/http/base.py | 20 ++++++++++++++++++++ strawberry/http/sync_base_view.py | 10 ++++++---- strawberry/litestar/controller.py | 3 --- strawberry/quart/views.py | 3 --- strawberry/sanic/views.py | 3 --- strawberry/schema/config.py | 12 +++++++++++- tests/http/clients/aiohttp.py | 11 ++++++----- tests/http/clients/asgi.py | 11 ++++++----- tests/http/clients/async_django.py | 5 ++--- tests/http/clients/async_flask.py | 8 ++++---- tests/http/clients/base.py | 7 ++++--- tests/http/clients/chalice.py | 8 ++++---- tests/http/clients/channels.py | 13 ++++++------- tests/http/clients/django.py | 10 +++++----- tests/http/clients/fastapi.py | 11 ++++++----- tests/http/clients/flask.py | 8 ++++---- tests/http/clients/litestar.py | 9 +++++---- tests/http/clients/quart.py | 8 ++++---- tests/http/clients/sanic.py | 8 ++++---- tests/http/test_query_batching.py | 14 +++++++++++--- tests/views/schema.py | 15 +++++++++------ 30 files changed, 124 insertions(+), 110 deletions(-) diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index 0d94a6e1fd..f1f836b526 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -150,7 +150,6 @@ def __init__( ), connection_init_wait_timeout: timedelta = timedelta(minutes=1), multipart_uploads_enabled: bool = False, - batch: bool = False, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get @@ -160,7 +159,6 @@ def __init__( self.subscription_protocols = subscription_protocols self.connection_init_wait_timeout = connection_init_wait_timeout self.multipart_uploads_enabled = multipart_uploads_enabled - self.batch = batch if graphiql is not None: warnings.warn( diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index 35ceed082f..f2fbe217c2 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -129,7 +129,6 @@ class GraphQL( allow_queries_via_get = True request_adapter_class = ASGIRequestAdapter websocket_adapter_class = ASGIWebSocketAdapter - batch: bool = False def __init__( self, @@ -146,7 +145,6 @@ def __init__( ), connection_init_wait_timeout: timedelta = timedelta(minutes=1), multipart_uploads_enabled: bool = False, - batch: bool = False, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get @@ -156,7 +154,6 @@ def __init__( self.protocols = subscription_protocols self.connection_init_wait_timeout = connection_init_wait_timeout self.multipart_uploads_enabled = multipart_uploads_enabled - self.batch = batch if graphiql is not None: warnings.warn( diff --git a/strawberry/chalice/views.py b/strawberry/chalice/views.py index 5c758d6e9b..cb3aa709fd 100644 --- a/strawberry/chalice/views.py +++ b/strawberry/chalice/views.py @@ -63,11 +63,9 @@ def __init__( graphiql: Optional[bool] = None, graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, - batch: bool = False, ) -> None: self.allow_queries_via_get = allow_queries_via_get self.schema = schema - self.batch = batch if graphiql is not None: warnings.warn( "The `graphiql` argument is deprecated in favor of `graphql_ide`", diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index 31b5846778..a00702fdc3 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -161,13 +161,11 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, multipart_uploads_enabled: bool = False, - batch: bool = False, **kwargs: Any, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get self.multipart_uploads_enabled = multipart_uploads_enabled - self.batch = batch if graphiql is not None: warnings.warn( @@ -254,7 +252,6 @@ class GraphQLHTTPConsumer( ``` """ - batch: bool = False allow_queries_via_get: bool = True request_adapter_class = ChannelsRequestAdapter @@ -328,7 +325,6 @@ class SyncGraphQLHTTPConsumer( synchronous and not asynchronous). """ - batch: bool = False allow_queries_via_get: bool = True request_adapter_class = SyncChannelsRequestAdapter diff --git a/strawberry/django/apps.py b/strawberry/django/apps.py index 89297bd63a..d46d9b5961 100644 --- a/strawberry/django/apps.py +++ b/strawberry/django/apps.py @@ -1,5 +1,5 @@ from django.apps import AppConfig # pragma: no cover -class StrawberryConfig(AppConfig): # pragma: no cover +class StrawberryAppConfig(AppConfig): # pragma: no cover name = "strawberry" diff --git a/strawberry/django/views.py b/strawberry/django/views.py index b27742a137..4245cc5906 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -217,7 +217,6 @@ class GraphQLView( allow_queries_via_get = True schema: BaseSchema = None # type: ignore request_adapter_class = DjangoHTTPRequestAdapter - batch: bool = False def get_root_value(self, request: HttpRequest) -> Optional[RootValue]: return None @@ -266,7 +265,6 @@ class AsyncGraphQLView( allow_queries_via_get = True schema: BaseSchema = None # type: ignore request_adapter_class = AsyncDjangoHTTPRequestAdapter - batch: bool = False @classonlymethod # pyright: ignore[reportIncompatibleMethodOverride] def as_view(cls, **initkwargs: Any) -> Callable[..., HttpResponse]: # noqa: N805 diff --git a/strawberry/fastapi/router.py b/strawberry/fastapi/router.py index 5dd16b738f..a90f7aa0ce 100644 --- a/strawberry/fastapi/router.py +++ b/strawberry/fastapi/router.py @@ -59,7 +59,6 @@ class GraphQLRouter( allow_queries_via_get = True request_adapter_class = ASGIRequestAdapter websocket_adapter_class = ASGIWebSocketAdapter - batch: bool = False @staticmethod async def __get_root_value() -> None: @@ -152,7 +151,6 @@ def __init__( generate_unique_id ), multipart_uploads_enabled: bool = False, - batch: bool = False, **kwargs: Any, ) -> None: super().__init__( @@ -188,7 +186,6 @@ def __init__( self.protocols = subscription_protocols self.connection_init_wait_timeout = connection_init_wait_timeout self.multipart_uploads_enabled = multipart_uploads_enabled - self.batch = batch if graphiql is not None: warnings.warn( diff --git a/strawberry/flask/views.py b/strawberry/flask/views.py index e6074c979e..e09ed51f57 100644 --- a/strawberry/flask/views.py +++ b/strawberry/flask/views.py @@ -74,13 +74,11 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, multipart_uploads_enabled: bool = False, - batch: bool = False, ) -> None: self.schema = schema self.graphiql = graphiql self.allow_queries_via_get = allow_queries_via_get self.multipart_uploads_enabled = multipart_uploads_enabled - self.batch = batch if graphiql is not None: warnings.warn( @@ -110,7 +108,6 @@ class GraphQLView( methods: ClassVar[list[str]] = ["GET", "POST"] allow_queries_via_get: bool = True request_adapter_class = FlaskHTTPRequestAdapter - batch: bool = False def get_context(self, request: Request, response: Response) -> Context: return {"request": request, "response": response} # type: ignore @@ -174,7 +171,6 @@ class AsyncGraphQLView( methods: ClassVar[list[str]] = ["GET", "POST"] allow_queries_via_get: bool = True request_adapter_class = AsyncFlaskHTTPRequestAdapter - batch: bool = False async def get_context(self, request: Request, response: Response) -> Context: return {"request": request, "response": response} # type: ignore diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 5f15da7262..d1f1d941da 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -126,8 +126,6 @@ class AsyncBaseHTTPView( BaseGraphQLWSHandler[Context, RootValue] ) - batch: bool - @property @abc.abstractmethod def allow_queries_via_get(self) -> bool: ... @@ -533,13 +531,7 @@ async def parse_http_body( raise HTTPException(400, "Unsupported content type") if isinstance(data, list): - if protocol == "multipart-subscription": - # note: multipart-subscriptions are not supported in batch requests - raise HTTPException( - 400, "Batching is not supported for multipart subscriptions" - ) - if not self.batch: - raise HTTPException(400, "Batching is not enabled") + await self.validate_batch_request(data, protocol=protocol) return [ GraphQLRequestData( query=item.get("query"), @@ -556,6 +548,11 @@ async def parse_http_body( protocol=protocol, ) + async def validate_batch_request( + self, request_data: list[GraphQLRequestData], protocol: str + ) -> None: + self._validate_batch_request(request_data=request_data, protocol=protocol) + async def process_result( self, request: Request, result: ExecutionResult ) -> GraphQLHTTPResponse: diff --git a/strawberry/http/base.py b/strawberry/http/base.py index 1cb1904888..b5cae15c63 100644 --- a/strawberry/http/base.py +++ b/strawberry/http/base.py @@ -3,8 +3,10 @@ from typing import Any, Generic, Optional, Union from typing_extensions import Protocol +from strawberry.http import GraphQLRequestData from strawberry.http.ides import GraphQL_IDE, get_graphql_ide_html from strawberry.http.types import HTTPMethod, QueryParams +from strawberry.schema.base import BaseSchema from .exceptions import HTTPException from .typevars import Request @@ -24,6 +26,7 @@ def headers(self) -> Mapping[str, str]: ... class BaseView(Generic[Request]): graphql_ide: Optional[GraphQL_IDE] multipart_uploads_enabled: bool = False + schema: BaseSchema def should_render_graphql_ide(self, request: BaseRequestProtocol) -> bool: return ( @@ -76,5 +79,22 @@ def _is_multipart_subscriptions( return params.get("subscriptionspec", "").startswith("1.0") + def _validate_batch_request( + self, request_data: list[GraphQLRequestData], protocol: str + ) -> None: + if self.schema.config.batching_config["enabled"] is False: + raise HTTPException(400, "Batching is not enabled") + + if protocol == "multipart-subscription": + # note: multipart-subscriptions are not supported in batch requests + raise HTTPException( + 400, "Batching is not supported for multipart subscriptions" + ) + + if len(request_data) > self.schema.config.batching_config.get( + "max_operations", 3 + ): + raise HTTPException(400, "Too many operations") + __all__ = ["BaseView"] diff --git a/strawberry/http/sync_base_view.py b/strawberry/http/sync_base_view.py index c8ecaed81d..532ca79b7e 100644 --- a/strawberry/http/sync_base_view.py +++ b/strawberry/http/sync_base_view.py @@ -72,8 +72,6 @@ class SyncBaseHTTPView( graphql_ide: Optional[GraphQL_IDE] request_adapter_class: Callable[[Request], SyncHTTPRequestAdapter] - batch: bool = False - # Methods that need to be implemented by individual frameworks @property @@ -209,8 +207,7 @@ def parse_http_body( raise HTTPException(400, "Unsupported content type") if isinstance(data, list): - if not self.batch: - raise HTTPException(400, "Batching is not enabled") + self.validate_batch_request(data, protocol="http") return [ GraphQLRequestData( query=item.get("query"), @@ -225,6 +222,11 @@ def parse_http_body( operation_name=data.get("operationName"), ) + def validate_batch_request( + self, request_data: list[GraphQLRequestData], protocol: str + ) -> None: + self._validate_batch_request(request_data=request_data, protocol=protocol) + def _handle_errors( self, errors: list[GraphQLError], response_data: GraphQLHTTPResponse ) -> None: diff --git a/strawberry/litestar/controller.py b/strawberry/litestar/controller.py index a383597492..693fb25fbf 100644 --- a/strawberry/litestar/controller.py +++ b/strawberry/litestar/controller.py @@ -259,7 +259,6 @@ class GraphQLController( ) keep_alive: bool = False keep_alive_interval: float = 1 - batch: bool = False def is_websocket_request( self, request: Union[Request, WebSocket] @@ -420,7 +419,6 @@ def make_graphql_controller( ), connection_init_wait_timeout: timedelta = timedelta(minutes=1), multipart_uploads_enabled: bool = False, - batch: bool = False, ) -> type[GraphQLController]: # sourcery skip: move-assign if context_getter is None: custom_context_getter_ = _none_custom_context_getter @@ -468,7 +466,6 @@ class _GraphQLController(GraphQLController): _GraphQLController.allow_queries_via_get = allow_queries_via_get_ _GraphQLController.graphql_ide = graphql_ide_ _GraphQLController.multipart_uploads_enabled = multipart_uploads_enabled - _GraphQLController.batch = batch return _GraphQLController diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index 2136b9a481..ac856ddc8f 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -55,7 +55,6 @@ class GraphQLView( methods: ClassVar[list[str]] = ["GET", "POST"] allow_queries_via_get: bool = True request_adapter_class = QuartHTTPRequestAdapter - batch: bool = False def __init__( self, @@ -64,12 +63,10 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, multipart_uploads_enabled: bool = False, - batch: bool = False, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get self.multipart_uploads_enabled = multipart_uploads_enabled - self.batch = batch if graphiql is not None: warnings.warn( diff --git a/strawberry/sanic/views.py b/strawberry/sanic/views.py index 643b23a253..1f7888201c 100644 --- a/strawberry/sanic/views.py +++ b/strawberry/sanic/views.py @@ -101,7 +101,6 @@ class GraphQLView( allow_queries_via_get = True request_adapter_class = SanicHTTPRequestAdapter - batch: bool = False def __init__( self, @@ -112,14 +111,12 @@ def __init__( json_encoder: Optional[type[json.JSONEncoder]] = None, json_dumps_params: Optional[dict[str, Any]] = None, multipart_uploads_enabled: bool = False, - batch: bool = False, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get self.json_encoder = json_encoder self.json_dumps_params = json_dumps_params self.multipart_uploads_enabled = multipart_uploads_enabled - self.batch = batch if self.json_encoder is not None: # pragma: no cover warnings.warn( diff --git a/strawberry/schema/config.py b/strawberry/schema/config.py index 230ae7dc10..4fe74d7f68 100644 --- a/strawberry/schema/config.py +++ b/strawberry/schema/config.py @@ -1,13 +1,19 @@ from __future__ import annotations from dataclasses import InitVar, dataclass, field -from typing import Any, Callable +from typing import Any, Callable, TypedDict +from typing_extensions import Required from strawberry.types.info import Info from .name_converter import NameConverter +class BatchingConfig(TypedDict, total=False): + enabled: Required[bool] + max_operations: int + + @dataclass class StrawberryConfig: auto_camel_case: InitVar[bool] = None # pyright: reportGeneralTypeIssues=false @@ -17,6 +23,8 @@ class StrawberryConfig: disable_field_suggestions: bool = False info_class: type[Info] = Info + batching_config: BatchingConfig = None # type: ignore + def __post_init__( self, auto_camel_case: bool, @@ -26,6 +34,8 @@ def __post_init__( if not issubclass(self.info_class, Info): raise TypeError("`info_class` must be a subclass of strawberry.Info") + if self.batching_config is None: # type: ignore + self.batching_config = {"enabled": False} __all__ = ["StrawberryConfig"] diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index e991af95f9..4ef0f709b6 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -14,9 +14,10 @@ from strawberry.aiohttp.views import GraphQLView as BaseGraphQLView from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE +from strawberry.schema.config import StrawberryConfig from strawberry.types import ExecutionResult from tests.http.context import get_context -from tests.views.schema import Query, schema +from tests.views.schema import Query, get_schema from tests.websockets.views import OnWSConnectMixin from .base import ( @@ -64,16 +65,16 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, - batch: bool = False, + schema_config: Optional[StrawberryConfig] = None, ): + self.schema = get_schema(schema_config) view = GraphQLView( - schema=schema, + schema=self.schema, graphiql=graphiql, graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, keep_alive=False, multipart_uploads_enabled=multipart_uploads_enabled, - batch=batch, ) view.result_override = result_override @@ -85,7 +86,7 @@ def __init__( ) def create_app(self, **kwargs: Any) -> None: - view = GraphQLView(schema=schema, **kwargs) + view = GraphQLView(schema=self.schema, **kwargs) self.app = web.Application() self.app.router.add_route( diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 96a511ff8d..282d4afaf6 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -15,9 +15,10 @@ from strawberry.asgi import GraphQL as BaseGraphQLView from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE +from strawberry.schema.config import StrawberryConfig from strawberry.types import ExecutionResult from tests.http.context import get_context -from tests.views.schema import Query, schema +from tests.views.schema import Query, get_schema from tests.websockets.views import OnWSConnectMixin from .base import ( @@ -66,23 +67,23 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, - batch: bool = False, + schema_config: Optional[StrawberryConfig] = None, ): + self.schema = get_schema(schema_config) view = GraphQLView( - schema, + self.schema, graphiql=graphiql, graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, keep_alive=False, multipart_uploads_enabled=multipart_uploads_enabled, - batch=batch, ) view.result_override = result_override self.client = TestClient(view) def create_app(self, **kwargs: Any) -> None: - view = GraphQLView(schema=schema, **kwargs) + view = GraphQLView(schema=self.schema, **kwargs) self.client = TestClient(view) async def _graphql_request( diff --git a/tests/http/clients/async_django.py b/tests/http/clients/async_django.py index c1165472c9..cb2274a0b0 100644 --- a/tests/http/clients/async_django.py +++ b/tests/http/clients/async_django.py @@ -9,7 +9,7 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.types import ExecutionResult from tests.http.context import get_context -from tests.views.schema import Query, schema +from tests.views.schema import Query, get_schema from .base import Response, ResultOverrideFunction from .django import DjangoHttpClient @@ -41,13 +41,12 @@ async def process_result( class AsyncDjangoHttpClient(DjangoHttpClient): async def _do_request(self, request: HttpRequest) -> Response: view = AsyncGraphQLView.as_view( - schema=schema, + schema=get_schema(self.schema_config), graphiql=self.graphiql, graphql_ide=self.graphql_ide, allow_queries_via_get=self.allow_queries_via_get, result_override=self.result_override, multipart_uploads_enabled=self.multipart_uploads_enabled, - batch=self.batch, ) try: diff --git a/tests/http/clients/async_flask.py b/tests/http/clients/async_flask.py index c8aef354a5..ba24476fe1 100644 --- a/tests/http/clients/async_flask.py +++ b/tests/http/clients/async_flask.py @@ -8,9 +8,10 @@ from strawberry.flask.views import AsyncGraphQLView as BaseAsyncGraphQLView from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE +from strawberry.schema.config import StrawberryConfig from strawberry.types import ExecutionResult from tests.http.context import get_context -from tests.views.schema import Query, schema +from tests.views.schema import Query, get_schema from .base import ResultOverrideFunction from .flask import FlaskHttpClient @@ -53,20 +54,19 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, - batch: bool = False, + schema_config: Optional[StrawberryConfig] = None, ): self.app = Flask(__name__) self.app.debug = True view = GraphQLView.as_view( "graphql_view", - schema=schema, + schema=get_schema(schema_config), graphiql=graphiql, graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, multipart_uploads_enabled=multipart_uploads_enabled, - batch=batch, ) self.app.add_url_rule( diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index 60453a1fc8..e5ce3ce5b8 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -11,6 +11,7 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE +from strawberry.schema.config import StrawberryConfig from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import ( BaseGraphQLTransportWSHandler, ) @@ -23,7 +24,7 @@ logger = logging.getLogger("strawberry.test.http_client") -JSON = dict[str, object] +JSON = Union[dict[str, "JSON"], list["JSON"], str, int, float, bool, None] ResultOverrideFunction = Optional[Callable[[ExecutionResult], GraphQLHTTPResponse]] @@ -102,7 +103,7 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, - batch: bool = False, + schema_config: Optional[StrawberryConfig] = None, ): ... @abc.abstractmethod @@ -136,7 +137,7 @@ async def post( self, url: str, data: Optional[bytes] = None, - json: Optional[Union[JSON, list[JSON]]] = None, + json: Optional[JSON] = None, headers: Optional[dict[str, str]] = None, ) -> Response: ... diff --git a/tests/http/clients/chalice.py b/tests/http/clients/chalice.py index f1d9ca8b8b..132d1ff3ff 100644 --- a/tests/http/clients/chalice.py +++ b/tests/http/clients/chalice.py @@ -13,9 +13,10 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.http.temporal_response import TemporalResponse +from strawberry.schema.config import StrawberryConfig from strawberry.types import ExecutionResult from tests.http.context import get_context -from tests.views.schema import Query, schema +from tests.views.schema import Query, get_schema from .base import JSON, HttpClient, Response, ResultOverrideFunction @@ -51,16 +52,15 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, - batch: bool = False, + schema_config: Optional[StrawberryConfig] = None, ): self.app = Chalice(app_name="TheStackBadger") view = GraphQLView( - schema=schema, + schema=get_schema(schema_config), graphiql=graphiql, graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, - batch=batch, ) view.result_override = result_override diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index 0ce2832f2e..2f9627684c 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -19,9 +19,10 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.http.temporal_response import TemporalResponse +from strawberry.schema.config import StrawberryConfig from strawberry.types import ExecutionResult from tests.http.context import get_context -from tests.views.schema import Query, schema +from tests.views.schema import Query, get_schema, schema from tests.websockets.views import OnWSConnectMixin from .base import ( @@ -142,7 +143,7 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, - batch: bool = False, + schema_config: Optional[StrawberryConfig] = None, ): self.ws_app = DebuggableGraphQLWSConsumer.as_asgi( schema=schema, @@ -150,13 +151,12 @@ def __init__( ) self.http_app = DebuggableGraphQLHTTPConsumer.as_asgi( - schema=schema, + schema=get_schema(schema_config), graphiql=graphiql, graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, multipart_uploads_enabled=multipart_uploads_enabled, - batch=batch, ) def create_app(self, **kwargs: Any) -> None: @@ -268,16 +268,15 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, - batch: bool = False, + schema_config: Optional[StrawberryConfig] = None, ): self.http_app = DebuggableSyncGraphQLHTTPConsumer.as_asgi( - schema=schema, + schema=get_schema(schema_config), graphiql=graphiql, graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, multipart_uploads_enabled=multipart_uploads_enabled, - batch=batch, ) diff --git a/tests/http/clients/django.py b/tests/http/clients/django.py index 8bbd9966fd..561466775f 100644 --- a/tests/http/clients/django.py +++ b/tests/http/clients/django.py @@ -13,9 +13,10 @@ from strawberry.django.views import GraphQLView as BaseGraphQLView from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE +from strawberry.schema.config import StrawberryConfig from strawberry.types import ExecutionResult from tests.http.context import get_context -from tests.views.schema import Query, schema +from tests.views.schema import Query, get_schema from .base import JSON, HttpClient, Response, ResultOverrideFunction @@ -51,14 +52,14 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, - batch: bool = False, + schema_config: Optional[StrawberryConfig] = None, ): + self.schema_config = schema_config self.graphiql = graphiql self.graphql_ide = graphql_ide self.allow_queries_via_get = allow_queries_via_get self.result_override = result_override self.multipart_uploads_enabled = multipart_uploads_enabled - self.batch = batch def _get_header_name(self, key: str) -> str: return f"HTTP_{key.upper().replace('-', '_')}" @@ -76,13 +77,12 @@ def _get_headers( async def _do_request(self, request: HttpRequest) -> Response: view = GraphQLView.as_view( - schema=schema, + schema=get_schema(self.schema_config), graphiql=self.graphiql, graphql_ide=self.graphql_ide, allow_queries_via_get=self.allow_queries_via_get, result_override=self.result_override, multipart_uploads_enabled=self.multipart_uploads_enabled, - batch=self.batch, ) try: diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index edfd6630e8..97c3dba57a 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -12,9 +12,10 @@ from strawberry.fastapi import GraphQLRouter as BaseGraphQLRouter from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE +from strawberry.schema.config import StrawberryConfig from strawberry.types import ExecutionResult from tests.http.context import get_context -from tests.views.schema import Query, schema +from tests.views.schema import Query, get_schema from tests.websockets.views import OnWSConnectMixin from .asgi import AsgiWebSocketClient @@ -76,12 +77,13 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, - batch: bool = False, + schema_config: Optional[StrawberryConfig] = None, ): self.app = FastAPI() + self.schema = get_schema(schema_config) graphql_app = GraphQLRouter( - schema, + self.schema, graphiql=graphiql, graphql_ide=graphql_ide, context_getter=fastapi_get_context, @@ -89,7 +91,6 @@ def __init__( allow_queries_via_get=allow_queries_via_get, keep_alive=False, multipart_uploads_enabled=multipart_uploads_enabled, - batch=batch, ) graphql_app.result_override = result_override self.app.include_router(graphql_app, prefix="/graphql") @@ -98,7 +99,7 @@ def __init__( def create_app(self, **kwargs: Any) -> None: self.app = FastAPI() - graphql_app = GraphQLRouter(schema=schema, **kwargs) + graphql_app = GraphQLRouter(schema=self.schema, **kwargs) self.app.include_router(graphql_app, prefix="/graphql") self.client = TestClient(self.app) diff --git a/tests/http/clients/flask.py b/tests/http/clients/flask.py index 090599008d..fce80e0353 100644 --- a/tests/http/clients/flask.py +++ b/tests/http/clients/flask.py @@ -15,9 +15,10 @@ from strawberry.flask.views import GraphQLView as BaseGraphQLView from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE +from strawberry.schema.config import StrawberryConfig from strawberry.types import ExecutionResult from tests.http.context import get_context -from tests.views.schema import Query, schema +from tests.views.schema import Query, get_schema from .base import JSON, HttpClient, Response, ResultOverrideFunction @@ -62,20 +63,19 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, - batch: bool = False, + schema_config: Optional[StrawberryConfig] = None, ): self.app = Flask(__name__) self.app.debug = True view = GraphQLView.as_view( "graphql_view", - schema=schema, + schema=get_schema(schema_config), graphiql=graphiql, graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, multipart_uploads_enabled=multipart_uploads_enabled, - batch=batch, ) self.app.add_url_rule( diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index 9afcb63474..b0e8faf4ab 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -14,9 +14,10 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.litestar import make_graphql_controller +from strawberry.schema.config import StrawberryConfig from strawberry.types import ExecutionResult from tests.http.context import get_context -from tests.views.schema import Query, schema +from tests.views.schema import Query, get_schema from tests.websockets.views import OnWSConnectMixin from .base import ( @@ -51,20 +52,20 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, - batch: bool = False, + schema_config: Optional[StrawberryConfig] = None, ): + self.schema = get_schema(schema_config) self.create_app( graphiql=graphiql, graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, multipart_uploads_enabled=multipart_uploads_enabled, - batch=batch, ) def create_app(self, result_override: ResultOverrideFunction = None, **kwargs: Any): BaseGraphQLController = make_graphql_controller( - schema=schema, + schema=self.schema, path="/graphql", context_getter=litestar_get_context, root_value_getter=get_root_value, diff --git a/tests/http/clients/quart.py b/tests/http/clients/quart.py index 7853358b49..a69e19a353 100644 --- a/tests/http/clients/quart.py +++ b/tests/http/clients/quart.py @@ -11,9 +11,10 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.quart.views import GraphQLView as BaseGraphQLView +from strawberry.schema.config import StrawberryConfig from strawberry.types import ExecutionResult from tests.http.context import get_context -from tests.views.schema import Query, schema +from tests.views.schema import Query, get_schema from .base import JSON, HttpClient, Response, ResultOverrideFunction @@ -55,20 +56,19 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, - batch: bool = False, + schema_config: Optional[StrawberryConfig] = None, ): self.app = Quart(__name__) self.app.debug = True view = GraphQLView.as_view( "graphql_view", - schema=schema, + schema=get_schema(schema_config), graphiql=graphiql, graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, multipart_uploads_enabled=multipart_uploads_enabled, - batch=batch, ) self.app.add_url_rule( diff --git a/tests/http/clients/sanic.py b/tests/http/clients/sanic.py index 05d5f4e904..c4ca5d776a 100644 --- a/tests/http/clients/sanic.py +++ b/tests/http/clients/sanic.py @@ -12,9 +12,10 @@ from strawberry.http.ides import GraphQL_IDE from strawberry.http.temporal_response import TemporalResponse from strawberry.sanic.views import GraphQLView as BaseGraphQLView +from strawberry.schema.config import StrawberryConfig from strawberry.types import ExecutionResult from tests.http.context import get_context -from tests.views.schema import Query, schema +from tests.views.schema import Query, get_schema from .base import JSON, HttpClient, Response, ResultOverrideFunction @@ -54,19 +55,18 @@ def __init__( allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, - batch: bool = False, + schema_config: Optional[StrawberryConfig] = None, ): self.app = Sanic( f"test_{int(randint(0, 1000))}", # noqa: S311 ) view = GraphQLView.as_view( - schema=schema, + schema=get_schema(schema_config), graphiql=graphiql, graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, multipart_uploads_enabled=multipart_uploads_enabled, - batch=batch, ) self.app.add_route( view, diff --git a/tests/http/test_query_batching.py b/tests/http/test_query_batching.py index 06f311d808..eec174b43a 100644 --- a/tests/http/test_query_batching.py +++ b/tests/http/test_query_batching.py @@ -2,6 +2,8 @@ import pytest +from strawberry.schema.config import StrawberryConfig + from .clients.base import HttpClient @@ -53,11 +55,15 @@ def multipart_subscriptions_batch_http_client( reason="ChaliceHttpClient doesn't support multipart subscriptions" ) - return http_client_class(batch=True) + return http_client_class( + schema_config=StrawberryConfig(batching_config={"enabled": True}) + ) async def test_batch_graphql_query(http_client_class: type[HttpClient]): - http_client = http_client_class(batch=True) + http_client = http_client_class( + schema_config=StrawberryConfig(batching_config={"enabled": True}) + ) response = await http_client.post( url="/graphql", @@ -78,7 +84,9 @@ async def test_batch_graphql_query(http_client_class: type[HttpClient]): async def test_returns_error_when_batching_is_disabled( http_client_class: type[HttpClient], ): - http_client = http_client_class(batch=False) + http_client = http_client_class( + schema_config=StrawberryConfig(batching_config={"enabled": False}) + ) response = await http_client.post( url="/graphql", diff --git a/tests/views/schema.py b/tests/views/schema.py index 7cd69e581f..26c0f1486b 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -10,6 +10,7 @@ from strawberry.extensions import SchemaExtension from strawberry.file_uploads import Upload from strawberry.permission import BasePermission +from strawberry.schema.config import StrawberryConfig from strawberry.subscriptions.protocols.graphql_transport_ws.types import PingMessage from strawberry.types import ExecutionContext @@ -275,9 +276,11 @@ def process_errors( return super().process_errors(errors, execution_context) -schema = Schema( - query=Query, - mutation=Mutation, - subscription=Subscription, - extensions=[MyExtension], -) +def get_schema(config: Optional[StrawberryConfig] = None) -> strawberry.Schema: + return Schema( + query=Query, + mutation=Mutation, + subscription=Subscription, + extensions=[MyExtension], + config=config, + ) From f592bec388183accc4cd5ab355ee8b9ef541c934 Mon Sep 17 00:00:00 2001 From: Aryan Iyappan Date: Wed, 15 Jan 2025 13:14:21 +0530 Subject: [PATCH 07/12] test too many operations --- tests/http/test_query_batching.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/http/test_query_batching.py b/tests/http/test_query_batching.py index eec174b43a..b35dd43ecb 100644 --- a/tests/http/test_query_batching.py +++ b/tests/http/test_query_batching.py @@ -120,3 +120,27 @@ async def test_returns_error_for_multipart_subscriptions( assert response.status_code == 400 assert "Batching is not supported for multipart subscriptions" in response.text + + +async def test_returns_error_when_trying_too_many_operations( + http_client_class: type[HttpClient], +): + http_client = http_client_class( + schema_config=StrawberryConfig( + batching_config={"enabled": True, "max_operations": 2} + ) + ) + + response = await http_client.post( + url="/graphql", + json=[ + {"query": "{ hello }"}, + {"query": "{ hello }"}, + {"query": "{ hello }"}, + ], + headers={"content-type": "application/json"}, + ) + + assert response.status_code == 400 + + assert "Too many operations" in response.text From 9ba0109fa16e03ad41b47cbcd2127cc2d11eeca5 Mon Sep 17 00:00:00 2001 From: Aryan Iyappan Date: Wed, 15 Jan 2025 13:17:41 +0530 Subject: [PATCH 08/12] update RELEASE.md --- RELEASE.md | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 75b6c22b32..08e2a82872 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -3,7 +3,7 @@ Release type: minor ## Add GraphQL Query batching support GraphQL query batching is now supported across all frameworks (sync and async) -To enable query batching, set the `batch` parameter to True at the view level. +To enable query batching, set `batching_config.enabled` to True in the schema configuration. This makes your GraphQL API compatible with batching features supported by various client side libraries, such as [Apollo GraphQL](https://www.apollographql.com/docs/react/api/link/apollo-link-batch-http) and [Relay](https://github.com/relay-tools/react-relay-network-modern?tab=readme-ov-file#batching-several-requests-into-one). @@ -15,6 +15,7 @@ import strawberry from fastapi import FastAPI from strawberry.fastapi import GraphQLRouter +from strawberry.schema.config import StrawberryConfig @strawberry.type @@ -24,26 +25,40 @@ class Query: return "Hello World" -schema = strawberry.Schema(Query) +schema = strawberry.Schema( + Query, config=StrawberryConfig(batching_config={"enabled": True}) +) -graphql_app = GraphQLRouter(schema, batch=True) +graphql_app = GraphQLRouter(schema) app = FastAPI() -app.include_router(graphql_app, prefix="/graphql/batch") +app.include_router(graphql_app, prefix="/graphql") ``` Example (Flask): ```py +import strawberry + from flask import Flask from strawberry.flask.views import GraphQLView -from api.schema import schema - app = Flask(__name__) + +@strawberry.type +class Query: + @strawberry.field + def hello(self) -> str: + return "Hello World" + + +schema = strawberry.Schema( + Query, config=StrawberryConfig(batching_config={"enabled": True}) +) + app.add_url_rule( "/graphql/batch", - view_func=GraphQLView.as_view("graphql_view", schema=schema, batch=True), + view_func=GraphQLView.as_view("graphql_view", schema=schema), ) if __name__ == "__main__": From f0d14ef28532c0a479615386a40306bd6d79362a Mon Sep 17 00:00:00 2001 From: Aryan Iyappan Date: Wed, 15 Jan 2025 13:22:51 +0530 Subject: [PATCH 09/12] ensure multipart subscriptions work as usual when single request is passed in (even though batching is enabled) --- tests/http/test_query_batching.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/http/test_query_batching.py b/tests/http/test_query_batching.py index b35dd43ecb..03ae659ee2 100644 --- a/tests/http/test_query_batching.py +++ b/tests/http/test_query_batching.py @@ -122,6 +122,23 @@ async def test_returns_error_for_multipart_subscriptions( assert "Batching is not supported for multipart subscriptions" in response.text +async def test_single_multipart_subscription_works_without_batching( + multipart_subscriptions_batch_http_client: HttpClient, +): + response = await multipart_subscriptions_batch_http_client.post( + url="/graphql", + json={"query": 'subscription { echo(message: "Hello world", delay: 0.2) }'}, + headers={ + "content-type": "application/json", + "accept": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + }, + ) + assert response.status_code == 200 + assert response.headers["content-type"].startswith( + "multipart/mixed;boundary=graphql" + ) + + async def test_returns_error_when_trying_too_many_operations( http_client_class: type[HttpClient], ): From 630882594d783ade340604ab17952b3cd9845a3c Mon Sep 17 00:00:00 2001 From: Aryan Iyappan Date: Wed, 15 Jan 2025 13:39:07 +0530 Subject: [PATCH 10/12] add relevant documentation --- docs/guides/query-batching.md | 122 ++++++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) create mode 100644 docs/guides/query-batching.md diff --git a/docs/guides/query-batching.md b/docs/guides/query-batching.md new file mode 100644 index 0000000000..996db7fe88 --- /dev/null +++ b/docs/guides/query-batching.md @@ -0,0 +1,122 @@ +--- +title: Query Batching +--- + +# Query Batching + +Query batching is a feature in Strawberry GraphQL that allows clients to send +multiple queries, mutations, or a combination of both in a single HTTP request. +This can help optimize network usage and improve performance for applications +that make frequent GraphQL requests. + +This document explains how to enable query batching, its configuration options, +and how to integrate it into your application with an example using FastAPI. + +--- + +## Enabling Query Batching + +To enable query batching in Strawberry, you need to configure the +`StrawberryConfig` when defining your GraphQL schema. The batching configuration +is provided as a dictionary with the key `enabled`: + +```python +from strawberry.schema.config import StrawberryConfig + +config = StrawberryConfig(batching_config={"enabled": True}) +``` + +When batching is enabled, the server can handle a list of operations +(queries/mutations) in a single request and return a list of responses. + +## Example Integration with FastAPI + +Query Batching is supported on all Strawberry GraphQL framework integrations. +Below is an example of how to enable query batching in a FastAPI application: + +```python +import strawberry +from fastapi import FastAPI +from strawberry.fastapi import GraphQLRouter +from strawberry.schema.config import StrawberryConfig + + +@strawberry.type +class Query: + @strawberry.field + def hello(self) -> str: + return "Hello World" + + +schema = strawberry.Schema( + Query, config=StrawberryConfig(batching_config={"enabled": True}) +) + +graphql_app = GraphQLRouter(schema) + +app = FastAPI() +app.include_router(graphql_app, prefix="/graphql") +``` + +### Running the Application + +1.Save the code in a file (e.g., app.py). 2. Start the FastAPI server: +`bash uvicorn app:app --reload ` 3.The GraphQL endpoint will be +available at http://127.0.0.1:8000/graphql. + +### Testing Query Batching + +You can test query batching by sending a single HTTP request with multiple +GraphQL operations. For example: + +#### Request + +```bash +curl -X POST -H "Content-Type: application/json" \ +-d '[{"query": "{ hello }"}, {"query": "{ hello }"}]' \ +http://127.0.0.1:8000/graphql +``` + +#### Response + +```json +[{ "data": { "hello": "Hello World" } }, { "data": { "hello": "Hello World" } }] +``` + +### Error Handling + +#### Batching Disabled + +If batching is not enabled in the server configuration and a batch request is +sent, the server will respond with a 400 status code and an error message: + +```json +{ + "error": "Batching is not enabled" +} +``` + +#### Too Many Operations + +If the number of operations in a batch exceeds the max_operations limit, the +server will return a 400 status code and an error message: + +```json +{ + "error": "Too many operations" +} +``` + +### Limitations + +#### Multipart Subscriptions: + +Query batching does not support multipart subscriptions. Attempting to batch +such operations will result in a 400 error with a relevant message. + +### Additional Notes + +Query batching is particularly useful for clients that need to perform multiple +operations simultaneously, reducing the overhead of multiple HTTP requests. +Ensure your client library supports query batching before enabling it on the +server. From eb6c9cb8e5ede04da484208ae5ea0d8086d0f639 Mon Sep 17 00:00:00 2001 From: Aryan Iyappan Date: Wed, 15 Jan 2025 13:43:15 +0530 Subject: [PATCH 11/12] add doc example of configuring max operations --- docs/guides/query-batching.md | 36 ++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/docs/guides/query-batching.md b/docs/guides/query-batching.md index 996db7fe88..b4c6a135b5 100644 --- a/docs/guides/query-batching.md +++ b/docs/guides/query-batching.md @@ -18,7 +18,12 @@ and how to integrate it into your application with an example using FastAPI. To enable query batching in Strawberry, you need to configure the `StrawberryConfig` when defining your GraphQL schema. The batching configuration -is provided as a dictionary with the key `enabled`: +is provided as a dictionary with the key `enabled`. + +You can also specify the maximum number of operations allowed in a batch request +using the `max_operations` key. + +### Basic Configuration ```python from strawberry.schema.config import StrawberryConfig @@ -26,6 +31,17 @@ from strawberry.schema.config import StrawberryConfig config = StrawberryConfig(batching_config={"enabled": True}) ``` +### Configuring Maximum Operations + +To set a limit on the number of operations in a batch request, use the +`max_operations` key: + +```python +from strawberry.schema.config import StrawberryConfig + +config = StrawberryConfig(batching_config={"enabled": True, "max_operations": 5}) +``` + When batching is enabled, the server can handle a list of operations (queries/mutations) in a single request and return a list of responses. @@ -49,7 +65,8 @@ class Query: schema = strawberry.Schema( - Query, config=StrawberryConfig(batching_config={"enabled": True}) + Query, + config=StrawberryConfig(batching_config={"enabled": True, "max_operations": 5}), ) graphql_app = GraphQLRouter(schema) @@ -60,9 +77,14 @@ app.include_router(graphql_app, prefix="/graphql") ### Running the Application -1.Save the code in a file (e.g., app.py). 2. Start the FastAPI server: -`bash uvicorn app:app --reload ` 3.The GraphQL endpoint will be -available at http://127.0.0.1:8000/graphql. +1. Save the code in a file (e.g., `app.py`). +2. Start the FastAPI server: + +```bash +uvicorn app:app --reload +``` + +3. The GraphQL endpoint will be available at `http://127.0.0.1:8000/graphql`. ### Testing Query Batching @@ -98,7 +120,7 @@ sent, the server will respond with a 400 status code and an error message: #### Too Many Operations -If the number of operations in a batch exceeds the max_operations limit, the +If the number of operations in a batch exceeds the `max_operations` limit, the server will return a 400 status code and an error message: ```json @@ -109,7 +131,7 @@ server will return a 400 status code and an error message: ### Limitations -#### Multipart Subscriptions: +#### Multipart Subscriptions Query batching does not support multipart subscriptions. Attempting to batch such operations will result in a 400 error with a relevant message. From a885b6763fa3fa0f431acfde52943804ec72441b Mon Sep 17 00:00:00 2001 From: Aryan Iyappan Date: Tue, 21 Jan 2025 15:25:35 +0530 Subject: [PATCH 12/12] make context sharing configurable --- strawberry/http/async_base_view.py | 36 +++++++++++++++++++++--------- strawberry/http/sync_base_view.py | 13 +++++++++++ strawberry/schema/config.py | 1 + 3 files changed, 39 insertions(+), 11 deletions(-) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index d1f1d941da..adedb6b970 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -205,17 +205,31 @@ async def execute_operation( if isinstance(request_data, list): # batch GraphQL requests - tasks = [ - self.execute_single( - request=request, - request_adapter=request_adapter, - sub_response=sub_response, - context=context, - root_value=root_value, - request_data=data, - ) - for data in request_data - ] + if not self.schema.config.batching_config["share_context"]: + tasks = [ + self.execute_single( + request=request, + request_adapter=request_adapter, + sub_response=sub_response, + # create a new context for each request data + context=await self.get_context(request, response=sub_response), + root_value=root_value, + request_data=data, + ) + for data in request_data + ] + else: + tasks = [ + self.execute_single( + request=request, + request_adapter=request_adapter, + sub_response=sub_response, + context=context, + root_value=root_value, + request_data=data, + ) + for data in request_data + ] return await asyncio.gather(*tasks) diff --git a/strawberry/http/sync_base_view.py b/strawberry/http/sync_base_view.py index 532ca79b7e..906b8fbe01 100644 --- a/strawberry/http/sync_base_view.py +++ b/strawberry/http/sync_base_view.py @@ -123,6 +123,19 @@ def execute_operation( if isinstance(request_data, list): # batch GraphQL requests + if not self.schema.config.batching_config["share_context"]: + return [ + self.execute_single( + request=request, + request_adapter=request_adapter, + sub_response=sub_response, + # create a new context for each request data + context=self.get_context(request, response=sub_response), + root_value=root_value, + request_data=data, + ) + for data in request_data + ] return [ self.execute_single( request=request, diff --git a/strawberry/schema/config.py b/strawberry/schema/config.py index 4fe74d7f68..c7e2720405 100644 --- a/strawberry/schema/config.py +++ b/strawberry/schema/config.py @@ -12,6 +12,7 @@ class BatchingConfig(TypedDict, total=False): enabled: Required[bool] max_operations: int + share_context: Required[bool] @dataclass