diff --git a/strawberry/http/__init__.py b/strawberry/http/__init__.py index 8e42a4d680..6c3a4b3dbd 100644 --- a/strawberry/http/__init__.py +++ b/strawberry/http/__init__.py @@ -1,9 +1,11 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union from typing_extensions import Literal, NotRequired, TypedDict +from strawberry.types import InitialIncrementalExecutionResult + if TYPE_CHECKING: from strawberry.types import ExecutionResult @@ -14,7 +16,9 @@ class GraphQLHTTPResponse(TypedDict, total=False): extensions: Optional[dict[str, object]] -def process_result(result: ExecutionResult) -> GraphQLHTTPResponse: +def process_result( + result: Union[ExecutionResult, InitialIncrementalExecutionResult], +) -> GraphQLHTTPResponse: data: GraphQLHTTPResponse = {"data": result.data} if result.errors: @@ -22,6 +26,10 @@ def process_result(result: ExecutionResult) -> GraphQLHTTPResponse: if result.extensions: data["extensions"] = result.extensions + if isinstance(result, InitialIncrementalExecutionResult): + data["hasNext"] = result.has_next + data["pending"] = result.pending + return data @@ -39,6 +47,7 @@ class IncrementalGraphQLHTTPResponse(TypedDict): incremental: list[GraphQLHTTPResponse] hasNext: bool extensions: NotRequired[dict[str, Any]] + completed: list[GraphQLHTTPResponse] __all__ = [ diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 568288ac0a..b813135003 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -578,14 +578,20 @@ async def process_result( result: Union[ExecutionResult, InitialIncrementalExecutionResult], ) -> GraphQLHTTPResponse: if isinstance(result, InitialIncrementalExecutionResult): - return { - "data": result.data, - "pending": [ - pending_result.formatted for pending_result in result.pending - ], - "hasNext": result.has_next, - "extensions": result.extensions, - } + # TODO: fix this mess + from strawberry.types import ( + InitialIncrementalExecutionResult as InitialIncrementalExecutionResultType, + ) + + # TODO: do this where we create ExecutionResult + # or maybe remove our wrappers and just GraphQL core's types + result = InitialIncrementalExecutionResultType( + data=result.data, + pending=[pending_result.formatted for pending_result in result.pending], + has_next=result.has_next, + extensions=result.extensions, + errors=result.errors, + ) result = await self.schema._handle_execution_result( context=self.schema.execution_context, diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index e81c68f9d4..1cd747dcd8 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -300,7 +300,7 @@ def _create_execution_context( provided_operation_name=operation_name, ) - # TODO: is this the right place to do this? + # TODO: is this the right place to do this? async def _handle_execution_result( self, context: ExecutionContext, @@ -319,6 +319,7 @@ async def _handle_execution_result( if isinstance(result, GraphQLExecutionResult): result = ExecutionResult(data=result.data, errors=result.errors) + # TODO: not correct when handling incremental results result.extensions = await extensions_runner.get_extensions_results(context) context.result = result # type: ignore # mypy failed to deduce correct type. diff --git a/strawberry/types/__init__.py b/strawberry/types/__init__.py index ba3b5158b9..8c6cc7ba2c 100644 --- a/strawberry/types/__init__.py +++ b/strawberry/types/__init__.py @@ -1,12 +1,17 @@ from .base import get_object_definition, has_object_definition -from .execution import ExecutionContext, ExecutionResult, SubscriptionExecutionResult +from .execution import ( + ExecutionContext, + ExecutionResult, + InitialIncrementalExecutionResult, + SubscriptionExecutionResult, +) from .info import Info __all__ = [ "ExecutionContext", "ExecutionResult", "Info", - "Info", + "InitialIncrementalExecutionResult", "SubscriptionExecutionResult", "get_object_definition", "has_object_definition", diff --git a/strawberry/types/execution.py b/strawberry/types/execution.py index d28440246a..72f4c2567e 100644 --- a/strawberry/types/execution.py +++ b/strawberry/types/execution.py @@ -92,6 +92,15 @@ class ExecutionResult: extensions: Optional[dict[str, Any]] = None +@dataclasses.dataclass +class InitialIncrementalExecutionResult: + data: Optional[dict[str, Any]] + errors: Optional[list[GraphQLError]] + pending: list[Any] + has_next: bool + extensions: Optional[dict[str, Any]] = None + + @dataclasses.dataclass class PreExecutionError(ExecutionResult): """Differentiate between a normal execution result and an immediate error. diff --git a/tests/http/incremental/test_defer.py b/tests/http/incremental/test_defer.py index 28547ec4ab..8e56306cc2 100644 --- a/tests/http/incremental/test_defer.py +++ b/tests/http/incremental/test_defer.py @@ -30,8 +30,7 @@ async def test_basic_defer(method: Literal["get", "post"], http_client: HttpClie "data": {"hero": {"id": "1"}}, "hasNext": True, "pending": [{"path": ["hero"], "id": "0"}], - # TODO: why is this None? - "extensions": None, + "extensions": {"example": "example"}, } subsequent = await stream.__anext__()