Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Query Batching Support #3755

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
53 changes: 53 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
Release type: minor

## Add GraphQL Query batching support

GraphQL query batching is now supported across all frameworks (sync and async)
To enable query batching, set the `batch` parameter to True at the view level.

This makes your GraphQL API compatible with batching features supported by various
client side libraries, such as [Apollo GraphQL](https://www.apollographql.com/docs/react/api/link/apollo-link-batch-http) and [Relay](https://github.com/relay-tools/react-relay-network-modern?tab=readme-ov-file#batching-several-requests-into-one).

Example (FastAPI):

```py
import strawberry

from fastapi import FastAPI
from strawberry.fastapi import GraphQLRouter


@strawberry.type
class Query:
@strawberry.field
def hello(self) -> str:
return "Hello World"


schema = strawberry.Schema(Query)

graphql_app = GraphQLRouter(schema, batch=True)

app = FastAPI()
app.include_router(graphql_app, prefix="/graphql/batch")
```

Example (Flask):
```py
from flask import Flask
from strawberry.flask.views import GraphQLView

from api.schema import schema

app = Flask(__name__)

app.add_url_rule(
"/graphql/batch",
view_func=GraphQLView.as_view("graphql_view", schema=schema, batch=True),
)

if __name__ == "__main__":
app.run()
```

Note: Query Batching is not supported for multipart subscriptions
6 changes: 5 additions & 1 deletion strawberry/aiohttp/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(
),
connection_init_wait_timeout: timedelta = timedelta(minutes=1),
multipart_uploads_enabled: bool = False,
batch: bool = False,
) -> None:
self.schema = schema
self.allow_queries_via_get = allow_queries_via_get
Expand All @@ -159,6 +160,7 @@ def __init__(
self.subscription_protocols = subscription_protocols
self.connection_init_wait_timeout = connection_init_wait_timeout
self.multipart_uploads_enabled = multipart_uploads_enabled
self.batch = batch

if graphiql is not None:
warnings.warn(
Expand Down Expand Up @@ -210,7 +212,9 @@ async def get_context(
return {"request": request, "response": response} # type: ignore

def create_response(
self, response_data: GraphQLHTTPResponse, sub_response: web.Response
self,
response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]],
sub_response: web.Response,
) -> web.Response:
sub_response.text = self.encode_json(response_data)
sub_response.content_type = "application/json"
Expand Down
7 changes: 6 additions & 1 deletion strawberry/asgi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class GraphQL(
allow_queries_via_get = True
request_adapter_class = ASGIRequestAdapter
websocket_adapter_class = ASGIWebSocketAdapter
batch: bool = False

def __init__(
self,
Expand All @@ -145,6 +146,7 @@ def __init__(
),
connection_init_wait_timeout: timedelta = timedelta(minutes=1),
multipart_uploads_enabled: bool = False,
batch: bool = False,
) -> None:
self.schema = schema
self.allow_queries_via_get = allow_queries_via_get
Expand All @@ -154,6 +156,7 @@ def __init__(
self.protocols = subscription_protocols
self.connection_init_wait_timeout = connection_init_wait_timeout
self.multipart_uploads_enabled = multipart_uploads_enabled
self.batch = batch

if graphiql is not None:
warnings.warn(
Expand Down Expand Up @@ -205,7 +208,9 @@ async def render_graphql_ide(self, request: Request) -> Response:
return HTMLResponse(self.graphql_ide_html)

def create_response(
self, response_data: GraphQLHTTPResponse, sub_response: Response
self,
response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]],
sub_response: Response,
) -> Response:
response = Response(
self.encode_json(response_data),
Expand Down
6 changes: 5 additions & 1 deletion strawberry/chalice/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,11 @@ def __init__(
graphiql: Optional[bool] = None,
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
batch: bool = False,
) -> None:
self.allow_queries_via_get = allow_queries_via_get
self.schema = schema
self.batch = batch
if graphiql is not None:
warnings.warn(
"The `graphiql` argument is deprecated in favor of `graphql_ide`",
Expand Down Expand Up @@ -114,7 +116,9 @@ def get_context(self, request: Request, response: TemporalResponse) -> Context:
return {"request": request, "response": response} # type: ignore

def create_response(
self, response_data: GraphQLHTTPResponse, sub_response: TemporalResponse
self,
response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]],
sub_response: TemporalResponse,
) -> Response:
status_code = 200

Expand Down
16 changes: 8 additions & 8 deletions strawberry/channels/handlers/http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,7 @@
import warnings
from functools import cached_property
from io import BytesIO
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
Union,
)
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing_extensions import TypeGuard, assert_never
from urllib.parse import parse_qs

Expand Down Expand Up @@ -167,11 +161,13 @@ def __init__(
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
multipart_uploads_enabled: bool = False,
batch: bool = False,
**kwargs: Any,
) -> None:
self.schema = schema
self.allow_queries_via_get = allow_queries_via_get
self.multipart_uploads_enabled = multipart_uploads_enabled
self.batch = batch

if graphiql is not None:
warnings.warn(
Expand All @@ -186,7 +182,9 @@ def __init__(
super().__init__(**kwargs)

def create_response(
self, response_data: GraphQLHTTPResponse, sub_response: TemporalResponse
self,
response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]],
sub_response: TemporalResponse,
) -> ChannelsResponse:
return ChannelsResponse(
content=json.dumps(response_data).encode(),
Expand Down Expand Up @@ -256,6 +254,7 @@ class GraphQLHTTPConsumer(
```
"""

batch: bool = False
allow_queries_via_get: bool = True
request_adapter_class = ChannelsRequestAdapter

Expand Down Expand Up @@ -329,6 +328,7 @@ class SyncGraphQLHTTPConsumer(
synchronous and not asynchronous).
"""

batch: bool = False
allow_queries_via_get: bool = True
request_adapter_class = SyncChannelsRequestAdapter

Expand Down
6 changes: 5 additions & 1 deletion strawberry/django/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ def __init__(
super().__init__(**kwargs)

def create_response(
self, response_data: GraphQLHTTPResponse, sub_response: HttpResponse
self,
response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]],
sub_response: HttpResponse,
) -> HttpResponseBase:
data = self.encode_json(response_data)

Expand Down Expand Up @@ -215,6 +217,7 @@ class GraphQLView(
allow_queries_via_get = True
schema: BaseSchema = None # type: ignore
request_adapter_class = DjangoHTTPRequestAdapter
batch: bool = False

def get_root_value(self, request: HttpRequest) -> Optional[RootValue]:
return None
Expand Down Expand Up @@ -263,6 +266,7 @@ class AsyncGraphQLView(
allow_queries_via_get = True
schema: BaseSchema = None # type: ignore
request_adapter_class = AsyncDjangoHTTPRequestAdapter
batch: bool = False

@classonlymethod # pyright: ignore[reportIncompatibleMethodOverride]
def as_view(cls, **initkwargs: Any) -> Callable[..., HttpResponse]: # noqa: N805
Expand Down
7 changes: 6 additions & 1 deletion strawberry/fastapi/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class GraphQLRouter(
allow_queries_via_get = True
request_adapter_class = ASGIRequestAdapter
websocket_adapter_class = ASGIWebSocketAdapter
batch: bool = False

@staticmethod
async def __get_root_value() -> None:
Expand Down Expand Up @@ -151,6 +152,7 @@ def __init__(
generate_unique_id
),
multipart_uploads_enabled: bool = False,
batch: bool = False,
**kwargs: Any,
) -> None:
super().__init__(
Expand Down Expand Up @@ -186,6 +188,7 @@ def __init__(
self.protocols = subscription_protocols
self.connection_init_wait_timeout = connection_init_wait_timeout
self.multipart_uploads_enabled = multipart_uploads_enabled
self.batch = batch

if graphiql is not None:
warnings.warn(
Expand Down Expand Up @@ -274,7 +277,9 @@ async def get_sub_response(self, request: Request) -> Response:
return self.temporal_response

def create_response(
self, response_data: GraphQLHTTPResponse, sub_response: Response
self,
response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]],
sub_response: Response,
) -> Response:
response = Response(
self.encode_json(response_data),
Expand Down
8 changes: 7 additions & 1 deletion strawberry/flask/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,13 @@ def __init__(
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
multipart_uploads_enabled: bool = False,
batch: bool = False,
) -> None:
self.schema = schema
self.graphiql = graphiql
self.allow_queries_via_get = allow_queries_via_get
self.multipart_uploads_enabled = multipart_uploads_enabled
self.batch = batch

if graphiql is not None:
warnings.warn(
Expand All @@ -91,7 +93,9 @@ def __init__(
self.graphql_ide = graphql_ide

def create_response(
self, response_data: GraphQLHTTPResponse, sub_response: Response
self,
response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]],
sub_response: Response,
) -> Response:
sub_response.set_data(self.encode_json(response_data)) # type: ignore

Expand All @@ -106,6 +110,7 @@ class GraphQLView(
methods: ClassVar[list[str]] = ["GET", "POST"]
allow_queries_via_get: bool = True
request_adapter_class = FlaskHTTPRequestAdapter
batch: bool = False

def get_context(self, request: Request, response: Response) -> Context:
return {"request": request, "response": response} # type: ignore
Expand Down Expand Up @@ -169,6 +174,7 @@ class AsyncGraphQLView(
methods: ClassVar[list[str]] = ["GET", "POST"]
allow_queries_via_get: bool = True
request_adapter_class = AsyncFlaskHTTPRequestAdapter
batch: bool = False

async def get_context(self, request: Request, response: Response) -> Context:
return {"request": request, "response": response} # type: ignore
Expand Down
Loading
Loading