Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Query Batching Support #3755

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
68 changes: 68 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
Release type: minor

## Add GraphQL Query batching support

GraphQL query batching is now supported across all frameworks (sync and async).
To enable query batching, pass `batching_config={"enabled": True}` to `StrawberryConfig` in the schema configuration, as shown in the examples below.

This makes your GraphQL API compatible with batching features supported by various
client side libraries, such as [Apollo GraphQL](https://www.apollographql.com/docs/react/api/link/apollo-link-batch-http) and [Relay](https://github.com/relay-tools/react-relay-network-modern?tab=readme-ov-file#batching-several-requests-into-one).

Example (FastAPI):

```py
import strawberry

from fastapi import FastAPI
from strawberry.fastapi import GraphQLRouter
from strawberry.schema.config import StrawberryConfig


@strawberry.type
class Query:
@strawberry.field
def hello(self) -> str:
return "Hello World"


schema = strawberry.Schema(
Query, config=StrawberryConfig(batching_config={"enabled": True})
)

graphql_app = GraphQLRouter(schema)

app = FastAPI()
app.include_router(graphql_app, prefix="/graphql")
```

Example (Flask):
```py
import strawberry

from flask import Flask
from strawberry.flask.views import GraphQLView

app = Flask(__name__)


@strawberry.type
class Query:
@strawberry.field
def hello(self) -> str:
return "Hello World"


schema = strawberry.Schema(
Query, config=StrawberryConfig(batching_config={"enabled": True})
)

app.add_url_rule(
"/graphql/batch",
view_func=GraphQLView.as_view("graphql_view", schema=schema),
)

if __name__ == "__main__":
app.run()
```

Note: query batching is not supported for multipart subscriptions.
4 changes: 3 additions & 1 deletion strawberry/aiohttp/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ async def get_context(
return {"request": request, "response": response} # type: ignore

def create_response(
self, response_data: GraphQLHTTPResponse, sub_response: web.Response
self,
response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]],
sub_response: web.Response,
) -> web.Response:
sub_response.text = self.encode_json(response_data)
sub_response.content_type = "application/json"
Expand Down
4 changes: 3 additions & 1 deletion strawberry/asgi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,9 @@ async def render_graphql_ide(self, request: Request) -> Response:
return HTMLResponse(self.graphql_ide_html)

def create_response(
self, response_data: GraphQLHTTPResponse, sub_response: Response
self,
response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]],
sub_response: Response,
) -> Response:
response = Response(
self.encode_json(response_data),
Expand Down
4 changes: 3 additions & 1 deletion strawberry/chalice/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ def get_context(self, request: Request, response: TemporalResponse) -> Context:
return {"request": request, "response": response} # type: ignore

def create_response(
self, response_data: GraphQLHTTPResponse, sub_response: TemporalResponse
self,
response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]],
sub_response: TemporalResponse,
) -> Response:
status_code = 200

Expand Down
12 changes: 4 additions & 8 deletions strawberry/channels/handlers/http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,7 @@
import warnings
from functools import cached_property
from io import BytesIO
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
Union,
)
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing_extensions import TypeGuard, assert_never
from urllib.parse import parse_qs

Expand Down Expand Up @@ -186,7 +180,9 @@ def __init__(
super().__init__(**kwargs)

def create_response(
self, response_data: GraphQLHTTPResponse, sub_response: TemporalResponse
self,
response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]],
sub_response: TemporalResponse,
) -> ChannelsResponse:
return ChannelsResponse(
content=json.dumps(response_data).encode(),
Expand Down
2 changes: 1 addition & 1 deletion strawberry/django/apps.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from django.apps import AppConfig # pragma: no cover


class StrawberryConfig(AppConfig): # pragma: no cover
class StrawberryAppConfig(AppConfig): # pragma: no cover
name = "strawberry"
4 changes: 3 additions & 1 deletion strawberry/django/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ def __init__(
super().__init__(**kwargs)

def create_response(
self, response_data: GraphQLHTTPResponse, sub_response: HttpResponse
self,
response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]],
sub_response: HttpResponse,
) -> HttpResponseBase:
data = self.encode_json(response_data)

Expand Down
4 changes: 3 additions & 1 deletion strawberry/fastapi/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,9 @@ async def get_sub_response(self, request: Request) -> Response:
return self.temporal_response

def create_response(
self, response_data: GraphQLHTTPResponse, sub_response: Response
self,
response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]],
sub_response: Response,
) -> Response:
response = Response(
self.encode_json(response_data),
Expand Down
4 changes: 3 additions & 1 deletion strawberry/flask/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def __init__(
self.graphql_ide = graphql_ide

def create_response(
self, response_data: GraphQLHTTPResponse, sub_response: Response
self,
response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]],
sub_response: Response,
) -> Response:
sub_response.set_data(self.encode_json(response_data)) # type: ignore

Expand Down
124 changes: 101 additions & 23 deletions strawberry/http/async_base_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ async def get_root_value(

@abc.abstractmethod
def create_response(
self, response_data: GraphQLHTTPResponse, sub_response: SubResponse
self,
response_data: Union[GraphQLHTTPResponse, list[GraphQLHTTPResponse]],
sub_response: SubResponse,
) -> Response: ...

@abc.abstractmethod
Expand Down Expand Up @@ -178,8 +180,12 @@ async def create_websocket_response(
) -> WebSocketResponse: ...

async def execute_operation(
self, request: Request, context: Context, root_value: Optional[RootValue]
) -> Union[ExecutionResult, SubscriptionExecutionResult]:
self,
request: Request,
context: Context,
root_value: Optional[RootValue],
sub_response: SubResponse,
) -> Union[ExecutionResult, list[ExecutionResult], SubscriptionExecutionResult]:
request_adapter = self.request_adapter_class(request)

try:
Expand All @@ -197,6 +203,22 @@ async def execute_operation(

assert self.schema

if isinstance(request_data, list):
# batch GraphQL requests
tasks = [
aryaniyaps marked this conversation as resolved.
Show resolved Hide resolved
self.execute_single(
request=request,
request_adapter=request_adapter,
sub_response=sub_response,
context=context,
root_value=root_value,
request_data=data,
)
for data in request_data
]

return await asyncio.gather(*tasks)

if request_data.protocol == "multipart-subscription":
return await self.schema.subscribe(
request_data.query, # type: ignore
Expand All @@ -206,15 +228,49 @@ async def execute_operation(
operation_name=request_data.operation_name,
)

return await self.schema.execute(
request_data.query,
return await self.execute_single(
request=request,
request_adapter=request_adapter,
sub_response=sub_response,
context=context,
root_value=root_value,
variable_values=request_data.variables,
context_value=context,
operation_name=request_data.operation_name,
allowed_operation_types=allowed_operation_types,
request_data=request_data,
)

async def execute_single(
self,
request: Request,
request_adapter: AsyncHTTPRequestAdapter,
sub_response: SubResponse,
context: Context,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like context is shared between all operations?
This may cause unexpected side effects.
Although sometimes that may even be beneficial (for dataloaders etc).

There seems to be no way to isolate request, but it would be nice to have different context/root value per operation in my case.

Copy link
Contributor Author

@aryaniyaps aryaniyaps Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that the context should be shared, reusing dataloaders and other resources must be one of the main factors behind adopting query batching, if we want separate context, we can use separate requests with batching disabled, right?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, maybe my use case not that common, because I use context not only for dataloaders, but also for some application logic and some tracing.
Since sharing is simpler - separate context might be a separate feature behind config flag.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, maybe my use case not that common, because I use context not only for dataloaders, but also for some application logic and some tracing. Since sharing is simpler - separate context might be a separate feature behind config flag.

but Im curious @Object905 what benefits would batching bring if the context is not shared?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reducing request count mostly to avoid network round-trips.
For example one of my subgraphs is "external" and has ping of about 50ms. But even for "local" sub-graphs this might be noticeable.

Also there are other means to share global data (request itself, thread locals, etc.) besides context.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, I was trying to create a test case, replicating shopify's context driven translation.

I tried to create QUERY level schema directives like this but it wasn't working.

@strawberry.directive(locations=[DirectiveLocation.QUERY], name="in_context")
def in_context(language: str, info: strawberry.Info):
    # put the language in the context here
    print("in_context called")
    info.context.update({"language": language})

am I doing this correctly? I wasn't able to find any documentation/ test cases which were using this type of directive

Copy link
Member

@erikwrede erikwrede Jan 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let me tag @DoctorJohn for a second opinion on the FastAPI Issue, he's done quite a bit of web view work recently.

Regarding the shopify example, this is not something we have as a test case in strawberry. I was just referring to it as an example where multiple operations could cause clashes in context.

To be fair, multiple queries in a single GraphQL Document are also a valid request that could cause clashes. However, in this case, the implementation can detect it by parsing the document containing all queries. With the current batching implementation, all batch entries are isolated, i.e. don't know of each other and the possiblity to clash, but still share the same context.

In general, I still think this is a trade off worth making. As mentioned on Discord, I'll look further into the problem with directives and get back.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In case this is a blocker for too long, I'd be open to splitting this into 2 separate PRs as well. One that adds batching support, and the other adding an option to share or not share context.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But application logic that I have in context - just can't be shared, because it's dependent on inputs, session, query itself, etc.
It can be separated manually (custom context, that hashes current operation/variables and stores different context for each one) - but this will be flaky and much better supported by strawberry itself.

@Object905 can you show us some example of this? 😊

I'm leaning towards always have a new context for each operation

Copy link

@Object905 Object905 Feb 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I can't give the full code, but here is the gist of it.

class Query:
    @strawberry.field
    def location_terms(input: LocationInfoInput, info: StrawInfo) -> LocationTerms:
        input = info.context.reduced_location_input(input, fetch_terms=False)
        ...

This is an "entry point" into location api. It can't handle multiple addresses in a single query by design (due to context sharing, etc.).
reduced_location_input - might be quite heavy, because it goes through geocoding and a lot of db queries, so doing it twice in a single request is a no-no.

Also in context there are other "heavy" objects that depend on location and building them a second time during query is heavy perf hit.

So, full info about location and heavy stuff is collected once and stored in context, so its available at any point below.

While its possible to forward such context in other ways - given that LocationTerms is quite big - its much more convenient to just use context.

root_value: Optional[RootValue],
request_data: GraphQLRequestData,
) -> ExecutionResult:
allowed_operation_types = OperationType.from_http(request_adapter.method)

if not self.allow_queries_via_get and request_adapter.method == "GET":
allowed_operation_types = allowed_operation_types - {OperationType.QUERY}

assert self.schema

try:
result = await self.schema.execute(
request_data.query,
root_value=root_value,
variable_values=request_data.variables,
context_value=context,
operation_name=request_data.operation_name,
allowed_operation_types=allowed_operation_types,
)
except InvalidOperationTypeError as e:
raise HTTPException(
400, e.as_http_error_reason(request_adapter.method)
) from e
except MissingQueryError as e:
raise HTTPException(400, "No GraphQL query found in the request") from e

return result

async def parse_multipart(self, request: AsyncHTTPRequestAdapter) -> dict[str, str]:
try:
form_data = await request.get_form_data()
Expand Down Expand Up @@ -326,16 +382,12 @@ async def run(
return await self.render_graphql_ide(request)
raise HTTPException(404, "Not Found")

try:
result = await self.execute_operation(
request=request, context=context, root_value=root_value
)
except InvalidOperationTypeError as e:
raise HTTPException(
400, e.as_http_error_reason(request_adapter.method)
) from e
except MissingQueryError as e:
raise HTTPException(400, "No GraphQL query found in the request") from e
result = await self.execute_operation(
request=request,
context=context,
root_value=root_value,
sub_response=sub_response,
)

if isinstance(result, SubscriptionExecutionResult):
stream = self._get_stream(request, result)
Expand All @@ -350,10 +402,20 @@ async def run(
},
)

response_data = await self.process_result(request=request, result=result)
if isinstance(result, list):
response_data = []
for execution_result in result:
result = await self.process_result(
request=request, result=execution_result
)
if execution_result.errors:
self._handle_errors(execution_result.errors, result)
response_data.append(result)
else:
response_data = await self.process_result(request=request, result=result)

if result.errors:
self._handle_errors(result.errors, response_data)
if result.errors:
self._handle_errors(result.errors, response_data)

return self.create_response(
response_data=response_data, sub_response=sub_response
Expand Down Expand Up @@ -449,7 +511,7 @@ async def parse_multipart_subscriptions(

async def parse_http_body(
self, request: AsyncHTTPRequestAdapter
) -> GraphQLRequestData:
) -> Union[GraphQLRequestData, list[GraphQLRequestData]]:
headers = {key.lower(): value for key, value in request.headers.items()}
content_type, _ = parse_content_type(request.content_type or "")
accept = headers.get("accept", "")
Expand All @@ -468,13 +530,29 @@ async def parse_http_body(
else:
raise HTTPException(400, "Unsupported content type")

if isinstance(data, list):
await self.validate_batch_request(data, protocol=protocol)
return [
GraphQLRequestData(
query=item.get("query"),
variables=item.get("variables"),
operation_name=item.get("operationName"),
protocol=protocol,
)
for item in data
]
return GraphQLRequestData(
query=data.get("query"),
variables=data.get("variables"),
operation_name=data.get("operationName"),
protocol=protocol,
)

async def validate_batch_request(
    self, request_data: list[GraphQLRequestData], protocol: str
) -> None:
    """Validate a batched GraphQL request before execution.

    Async hook that delegates to the shared ``_validate_batch_request``
    helper (defined on the base view, not visible here — presumably it
    enforces the batching config, e.g. that batching is enabled and the
    batch size is within limits; TODO confirm against the base class).

    Args:
        request_data: The parsed list of GraphQL operations in the batch.
        protocol: The negotiated HTTP protocol for this request
            (e.g. ``"http"``); batching is rejected for unsupported
            protocols by the delegated validator.

    Raises:
        Whatever ``_validate_batch_request`` raises on an invalid batch
        (typically an HTTP 400-style error — confirm in the base view).
    """
    self._validate_batch_request(request_data=request_data, protocol=protocol)

async def process_result(
self, request: Request, result: ExecutionResult
) -> GraphQLHTTPResponse:
Expand Down
Loading
Loading