diff --git a/.alexrc b/.alexrc index ea3756ce5a..0d9c4005d9 100644 --- a/.alexrc +++ b/.alexrc @@ -9,6 +9,7 @@ "executed", "executes", "execution", + "reject", "special", "primitive", "invalid", diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..2a0c1a0554 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,32 @@ +Release type: minor + +This release adds a new `on_ws_connect` method to all HTTP view integrations. +The method is called when a `graphql-transport-ws` or `graphql-ws` connection is +established and can be used to customize the connection acknowledgment behavior. + +This is particularly useful for authentication, authorization, and sending a +custom acknowledgment payload to clients when a connection is accepted. For +example: + +```python +class MyGraphQLView(GraphQLView): + async def on_ws_connect(self, context: Dict[str, object]): + connection_params = context["connection_params"] + + if not isinstance(connection_params, dict): + # Reject without a custom graphql-ws error payload + raise ConnectionRejectionError() + + if connection_params.get("password") != "secret: + # Reject with a custom graphql-ws error payload + raise ConnectionRejectionError({"reason": "Invalid password"}) + + if username := connection_params.get("username"): + # Accept with a custom acknowledgement payload + return {"message": f"Hello, {username}!"} + + # Accept without a acknowledgement payload + return await super().on_ws_connect(context) +``` + +Take a look at our documentation to learn more. diff --git a/docs/integrations/aiohttp.md b/docs/integrations/aiohttp.md index 64655d5fbf..2c7be1e725 100644 --- a/docs/integrations/aiohttp.md +++ b/docs/integrations/aiohttp.md @@ -53,6 +53,7 @@ methods: - `def decode_json(self, data: Union[str, bytes]) -> object` - `def encode_json(self, data: object) -> str` - `async def render_graphql_ide(self, request: Request) -> Response` +- `async def on_ws_connect(self, context: Context) -> Union[UnsetType, None, Dict[str, object]]` ### get_context @@ -199,3 +200,47 @@ class MyGraphQLView(GraphQLView): return Response(text=custom_html, content_type="text/html") ``` + +### on_ws_connect + +By overriding `on_ws_connect` you can customize the behavior when a `graphql-ws` +or `graphql-transport-ws` connection is established. This is particularly useful +for authentication and authorization. By default, all connections are accepted. + +To manually accept a connection, return `strawberry.UNSET` or a connection +acknowledgment payload. The acknowledgment payload will be sent to the client. + +Note that the legacy protocol does not support `None`/`null` acknowledgment +payloads, while the new protocol does. Our implementation will treat +`None`/`null` payloads the same as `strawberry.UNSET` in the context of the +legacy protocol. + +To reject a connection, raise a `ConnectionRejectionError`. You can optionally +provide a custom error payload that will be sent to the client when the legacy +GraphQL over WebSocket protocol is used. + +```python +from typing import Dict +from strawberry.exceptions import ConnectionRejectionError +from strawberry.aiohttp.views import GraphQLView + + +class MyGraphQLView(GraphQLView): + async def on_ws_connect(self, context: Dict[str, object]): + connection_params = context["connection_params"] + + if not isinstance(connection_params, dict): + # Reject without a custom graphql-ws error payload + raise ConnectionRejectionError() + + if connection_params.get("password") != "secret": + # Reject with a custom graphql-ws error payload + raise ConnectionRejectionError({"reason": "Invalid password"}) + + if username := connection_params.get("username"): + # Accept with a custom acknowledgment payload + return {"message": f"Hello, {username}!"} + + # Accept without a acknowledgment payload + return await super().on_ws_connect(context) +``` diff --git a/docs/integrations/asgi.md b/docs/integrations/asgi.md index 3cad706772..81b90b393c 100644 --- a/docs/integrations/asgi.md +++ b/docs/integrations/asgi.md @@ -53,6 +53,7 @@ methods: - `def decode_json(self, data: Union[str, bytes]) -> object` - `def encode_json(self, data: object) -> str` - `async def render_graphql_ide(self, request: Request) -> Response` +- `async def on_ws_connect(self, context: Context) -> Union[UnsetType, None, Dict[str, object]]` ### get_context @@ -241,3 +242,47 @@ class MyGraphQL(GraphQL): return HTMLResponse(custom_html) ``` + +### on_ws_connect + +By overriding `on_ws_connect` you can customize the behavior when a `graphql-ws` +or `graphql-transport-ws` connection is established. This is particularly useful +for authentication and authorization. By default, all connections are accepted. + +To manually accept a connection, return `strawberry.UNSET` or a connection +acknowledgment payload. The acknowledgment payload will be sent to the client. + +Note that the legacy protocol does not support `None`/`null` acknowledgment +payloads, while the new protocol does. Our implementation will treat +`None`/`null` payloads the same as `strawberry.UNSET` in the context of the +legacy protocol. + +To reject a connection, raise a `ConnectionRejectionError`. You can optionally +provide a custom error payload that will be sent to the client when the legacy +GraphQL over WebSocket protocol is used. + +```python +from typing import Dict +from strawberry.exceptions import ConnectionRejectionError +from strawberry.asgi import GraphQL + + +class MyGraphQL(GraphQL): + async def on_ws_connect(self, context: Dict[str, object]): + connection_params = context["connection_params"] + + if not isinstance(connection_params, dict): + # Reject without a custom graphql-ws error payload + raise ConnectionRejectionError() + + if connection_params.get("password") != "secret": + # Reject with a custom graphql-ws error payload + raise ConnectionRejectionError({"reason": "Invalid password"}) + + if username := connection_params.get("username"): + # Accept with a custom acknowledgment payload + return {"message": f"Hello, {username}!"} + + # Accept without a acknowledgment payload + return await super().on_ws_connect(context) +``` diff --git a/docs/integrations/channels.md b/docs/integrations/channels.md index 86c3041d10..6921fa7a5f 100644 --- a/docs/integrations/channels.md +++ b/docs/integrations/channels.md @@ -592,6 +592,51 @@ following methods: - `async def get_root_value(self, request: GraphQLWSConsumer) -> Optional[RootValue]` - `def decode_json(self, data: Union[str, bytes]) -> object` - `def encode_json(self, data: object) -> str` +- `async def on_ws_connect(self, context: Context) -> Union[UnsetType, None, Dict[str, object]]` + +### on_ws_connect + +By overriding `on_ws_connect` you can customize the behavior when a `graphql-ws` +or `graphql-transport-ws` connection is established. This is particularly useful +for authentication and authorization. By default, all connections are accepted. + +To manually accept a connection, return `strawberry.UNSET` or a connection +acknowledgment payload. The acknowledgment payload will be sent to the client. + +Note that the legacy protocol does not support `None`/`null` acknowledgment +payloads, while the new protocol does. Our implementation will treat +`None`/`null` payloads the same as `strawberry.UNSET` in the context of the +legacy protocol. + +To reject a connection, raise a `ConnectionRejectionError`. You can optionally +provide a custom error payload that will be sent to the client when the legacy +GraphQL over WebSocket protocol is used. + +```python +from typing import Dict +from strawberry.exceptions import ConnectionRejectionError +from strawberry.channels import GraphQLWSConsumer + + +class MyGraphQLWSConsumer(GraphQLWSConsumer): + async def on_ws_connect(self, context: Dict[str, object]): + connection_params = context["connection_params"] + + if not isinstance(connection_params, dict): + # Reject without a custom graphql-ws error payload + raise ConnectionRejectionError() + + if connection_params.get("password") != "secret": + # Reject with a custom graphql-ws error payload + raise ConnectionRejectionError({"reason": "Invalid password"}) + + if username := connection_params.get("username"): + # Accept with a custom acknowledgment payload + return {"message": f"Hello, {username}!"} + + # Accept without a acknowledgment payload + return await super().on_ws_connect(context) +``` ### Context diff --git a/docs/integrations/fastapi.md b/docs/integrations/fastapi.md index 1095dfca9f..6fbb7dbc2a 100644 --- a/docs/integrations/fastapi.md +++ b/docs/integrations/fastapi.md @@ -268,6 +268,7 @@ following methods: - `def decode_json(self, data: Union[str, bytes]) -> object` - `def encode_json(self, data: object) -> str` - `async def render_graphql_ide(self, request: Request) -> HTMLResponse` +- `async def on_ws_connect(self, context: Context) -> Union[UnsetType, None, Dict[str, object]]` ### process_result @@ -354,3 +355,47 @@ class MyGraphQLRouter(GraphQLRouter): return HTMLResponse(custom_html) ``` + +### on_ws_connect + +By overriding `on_ws_connect` you can customize the behavior when a `graphql-ws` +or `graphql-transport-ws` connection is established. This is particularly useful +for authentication and authorization. By default, all connections are accepted. + +To manually accept a connection, return `strawberry.UNSET` or a connection +acknowledgment payload. The acknowledgment payload will be sent to the client. + +Note that the legacy protocol does not support `None`/`null` acknowledgment +payloads, while the new protocol does. Our implementation will treat +`None`/`null` payloads the same as `strawberry.UNSET` in the context of the +legacy protocol. + +To reject a connection, raise a `ConnectionRejectionError`. You can optionally +provide a custom error payload that will be sent to the client when the legacy +GraphQL over WebSocket protocol is used. + +```python +from typing import Dict +from strawberry.exceptions import ConnectionRejectionError +from strawberry.fastapi import GraphQLRouter + + +class MyGraphQLRouter(GraphQLRouter): + async def on_ws_connect(self, context: Dict[str, object]): + connection_params = context["connection_params"] + + if not isinstance(connection_params, dict): + # Reject without a custom graphql-ws error payload + raise ConnectionRejectionError() + + if connection_params.get("password") != "secret": + # Reject with a custom graphql-ws error payload + raise ConnectionRejectionError({"reason": "Invalid password"}) + + if username := connection_params.get("username"): + # Accept with a custom acknowledgment payload + return {"message": f"Hello, {username}!"} + + # Accept without a acknowledgment payload + return await super().on_ws_connect(context) +``` diff --git a/docs/integrations/litestar.md b/docs/integrations/litestar.md index 002fb5cbee..77af89472a 100644 --- a/docs/integrations/litestar.md +++ b/docs/integrations/litestar.md @@ -327,6 +327,7 @@ extended by overriding any of the following methods: - `def decode_json(self, data: Union[str, bytes]) -> object` - `def encode_json(self, data: object) -> str` - `async def render_graphql_ide(self, request: Request) -> Response` +- `async def on_ws_connect(self, context: Context) -> Union[UnsetType, None, Dict[str, object]]` ### process_result @@ -476,3 +477,63 @@ class MyGraphQLController(GraphQLController): return Response(custom_html, media_type=MediaType.HTML) ``` + +### on_ws_connect + +By overriding `on_ws_connect` you can customize the behavior when a `graphql-ws` +or `graphql-transport-ws` connection is established. This is particularly useful +for authentication and authorization. By default, all connections are accepted. + +To manually accept a connection, return `strawberry.UNSET` or a connection +acknowledgment payload. The acknowledgment payload will be sent to the client. + +Note that the legacy protocol does not support `None`/`null` acknowledgment +payloads, while the new protocol does. Our implementation will treat +`None`/`null` payloads the same as `strawberry.UNSET` in the context of the +legacy protocol. + +To reject a connection, raise a `ConnectionRejectionError`. You can optionally +provide a custom error payload that will be sent to the client when the legacy +GraphQL over WebSocket protocol is used. + +```python +import strawberry +from typing import Dict +from strawberry.exceptions import ConnectionRejectionError +from strawberry.litestar import make_graphql_controller + + +@strawberry.type +class Query: + @strawberry.field + def hello(self) -> str: + return "world" + + +schema = strawberry.Schema(Query) + +GraphQLController = make_graphql_controller( + schema, + path="/graphql", +) + + +class MyGraphQLController(GraphQLController): + async def on_ws_connect(self, context: Dict[str, object]): + connection_params = context["connection_params"] + + if not isinstance(connection_params, dict): + # Reject without a custom graphql-ws error payload + raise ConnectionRejectionError() + + if connection_params.get("password") != "secret": + # Reject with a custom graphql-ws error payload + raise ConnectionRejectionError({"reason": "Invalid password"}) + + if username := connection_params.get("username"): + # Accept with a custom acknowledgment payload + return {"message": f"Hello, {username}!"} + + # Accept without a acknowledgment payload + return await super().on_ws_connect(context) +``` diff --git a/strawberry/exceptions/__init__.py b/strawberry/exceptions/__init__.py index ee331af721..d492c7e9bb 100644 --- a/strawberry/exceptions/__init__.py +++ b/strawberry/exceptions/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Optional, Set, Union +from typing import TYPE_CHECKING, Dict, Optional, Set, Union from graphql import GraphQLError @@ -157,6 +157,13 @@ class StrawberryGraphQLError(GraphQLError): """Use it when you want to override the graphql.GraphQLError in custom extensions.""" +class ConnectionRejectionError(Exception): + """Use it when you want to reject a WebSocket connection.""" + + def __init__(self, payload: Dict[str, object] = {}) -> None: + self.payload = payload + + __all__ = [ "StrawberryException", "UnableToFindExceptionSource", diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 57a307b2e8..3363851509 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -21,7 +21,6 @@ from graphql import GraphQLError -from strawberry import UNSET from strawberry.exceptions import MissingQueryError from strawberry.file_uploads.utils import replace_placeholders_with_files from strawberry.http import ( @@ -39,6 +38,7 @@ from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler from strawberry.types import ExecutionResult, SubscriptionExecutionResult from strawberry.types.graphql import OperationType +from strawberry.types.unset import UNSET, UnsetType from .base import BaseView from .exceptions import HTTPException @@ -279,6 +279,7 @@ async def run( if websocket_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL: await self.graphql_transport_ws_handler_class( + view=self, websocket=websocket, context=context, root_value=root_value, @@ -288,6 +289,7 @@ async def run( ).handle() elif websocket_subprotocol == GRAPHQL_WS_PROTOCOL: await self.graphql_ws_handler_class( + view=self, websocket=websocket, context=context, root_value=root_value, @@ -476,5 +478,10 @@ async def process_result( ) -> GraphQLHTTPResponse: return process_result(result) + async def on_ws_connect( + self, context: Context + ) -> Union[UnsetType, None, Dict[str, object]]: + return UNSET + __all__ = ["AsyncBaseHTTPView"] diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index cdda39595a..647ab7ab3c 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -14,6 +14,7 @@ from graphql import GraphQLError, GraphQLSyntaxError, parse +from strawberry.exceptions import ConnectionRejectionError from strawberry.http.exceptions import ( NonJsonMessageReceived, NonTextMessageReceived, @@ -31,13 +32,14 @@ from strawberry.types import ExecutionResult from strawberry.types.execution import PreExecutionError from strawberry.types.graphql import OperationType +from strawberry.types.unset import UnsetType from strawberry.utils.debug import pretty_print_graphql_operation from strawberry.utils.operation import get_operation_type if TYPE_CHECKING: from datetime import timedelta - from strawberry.http.async_base_view import AsyncWebSocketAdapter + from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncWebSocketAdapter from strawberry.schema import BaseSchema from strawberry.schema.subscribe import SubscriptionResult @@ -47,6 +49,7 @@ class BaseGraphQLTransportWSHandler: def __init__( self, + view: AsyncBaseHTTPView, websocket: AsyncWebSocketAdapter, context: object, root_value: object, @@ -54,6 +57,7 @@ def __init__( debug: bool, connection_init_wait_timeout: timedelta, ) -> None: + self.view = view self.websocket = websocket self.context = context self.root_value = root_value @@ -66,7 +70,6 @@ def __init__( self.connection_timed_out = False self.operations: Dict[str, Operation] = {} self.completed_tasks: List[asyncio.Task] = [] - self.connection_params: Optional[Dict[str, object]] = None async def handle(self) -> None: self.on_request_accepted() @@ -169,15 +172,31 @@ async def handle_connection_init(self, message: ConnectionInitMessage) -> None: ) return - self.connection_params = payload - if self.connection_init_received: reason = "Too many initialisation requests" await self.websocket.close(code=4429, reason=reason) return self.connection_init_received = True - await self.send_message({"type": "connection_ack"}) + + if isinstance(self.context, dict): + self.context["connection_params"] = payload + elif hasattr(self.context, "connection_params"): + self.context.connection_params = payload + + try: + connection_ack_payload = await self.view.on_ws_connect(self.context) + except ConnectionRejectionError: + await self.websocket.close(code=4403, reason="Forbidden") + return + + if isinstance(connection_ack_payload, UnsetType): + await self.send_message({"type": "connection_ack"}) + else: + await self.send_message( + {"type": "connection_ack", "payload": connection_ack_payload} + ) + self.connection_acknowledged = True async def handle_ping(self, message: PingMessage) -> None: @@ -219,11 +238,6 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: message["payload"].get("variables"), ) - if isinstance(self.context, dict): - self.context["connection_params"] = self.connection_params - elif hasattr(self.context, "connection_params"): - self.context.connection_params = self.connection_params - operation = Operation( self, message["id"], diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index 03cf11b71f..95536bc4fd 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -10,6 +10,7 @@ cast, ) +from strawberry.exceptions import ConnectionRejectionError from strawberry.http.exceptions import NonTextMessageReceived, WebSocketDisconnected from strawberry.subscriptions.protocols.graphql_ws.types import ( ConnectionInitMessage, @@ -20,16 +21,18 @@ StopMessage, ) from strawberry.types.execution import ExecutionResult, PreExecutionError +from strawberry.types.unset import UnsetType from strawberry.utils.debug import pretty_print_graphql_operation if TYPE_CHECKING: - from strawberry.http.async_base_view import AsyncWebSocketAdapter + from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncWebSocketAdapter from strawberry.schema import BaseSchema class BaseGraphQLWSHandler: def __init__( self, + view: AsyncBaseHTTPView, websocket: AsyncWebSocketAdapter, context: object, root_value: object, @@ -38,6 +41,7 @@ def __init__( keep_alive: bool, keep_alive_interval: Optional[float], ) -> None: + self.view = view self.websocket = websocket self.context = context self.root_value = root_value @@ -48,7 +52,6 @@ def __init__( self.keep_alive_task: Optional[asyncio.Task] = None self.subscriptions: Dict[str, AsyncGenerator] = {} self.tasks: Dict[str, asyncio.Task] = {} - self.connection_params: Optional[Dict[str, object]] = None async def handle(self) -> None: try: @@ -92,9 +95,27 @@ async def handle_connection_init(self, message: ConnectionInitMessage) -> None: await self.websocket.close(code=1000, reason="") return - self.connection_params = payload + if isinstance(self.context, dict): + self.context["connection_params"] = payload + elif hasattr(self.context, "connection_params"): + self.context.connection_params = payload - await self.send_message({"type": "connection_ack"}) + try: + connection_ack_payload = await self.view.on_ws_connect(self.context) + except ConnectionRejectionError as e: + await self.send_message({"type": "connection_error", "payload": e.payload}) + await self.websocket.close(code=1011, reason="") + return + + if ( + isinstance(connection_ack_payload, UnsetType) + or connection_ack_payload is None + ): + await self.send_message({"type": "connection_ack"}) + else: + await self.send_message( + {"type": "connection_ack", "payload": connection_ack_payload} + ) if self.keep_alive: keep_alive_handler = self.handle_keep_alive() @@ -112,11 +133,6 @@ async def handle_start(self, message: StartMessage) -> None: operation_name = payload.get("operationName") variables = payload.get("variables") - if isinstance(self.context, dict): - self.context["connection_params"] = self.connection_params - elif hasattr(self.context, "connection_params"): - self.context.connection_params = self.connection_params - if self.debug: pretty_print_graphql_operation(operation_name, query, variables) diff --git a/strawberry/subscriptions/protocols/graphql_ws/types.py b/strawberry/subscriptions/protocols/graphql_ws/types.py index 56aa81ab1b..d29a6209fb 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/types.py +++ b/strawberry/subscriptions/protocols/graphql_ws/types.py @@ -37,6 +37,7 @@ class ConnectionErrorMessage(TypedDict): class ConnectionAckMessage(TypedDict): type: Literal["connection_ack"] + payload: NotRequired[Dict[str, object]] class DataMessagePayload(TypedDict): diff --git a/tests/channels/test_testing.py b/tests/channels/test_testing.py index 99aa9dd6c8..f45f535214 100644 --- a/tests/channels/test_testing.py +++ b/tests/channels/test_testing.py @@ -52,4 +52,4 @@ async def test_graphql_error(communicator): async def test_simple_connection_params(communicator): async for res in communicator.subscribe(query="subscription { connectionParams }"): - assert res.data["connectionParams"] == "Hi" + assert res.data["connectionParams"]["strawberry"] == "Hi" diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index 89b0c718e8..12cb035937 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -15,6 +15,7 @@ from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult from tests.views.schema import Query, schema +from tests.websockets.views import OnWSConnectMixin from ..context import get_context from .base import ( @@ -29,14 +30,14 @@ ) -class GraphQLView(BaseGraphQLView): +class GraphQLView(OnWSConnectMixin, BaseGraphQLView[Dict[str, object], object]): result_override: ResultOverrideFunction = None graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler graphql_ws_handler_class = DebuggableGraphQLWSHandler async def get_context( self, request: web.Request, response: web.StreamResponse - ) -> object: + ) -> Dict[str, object]: context = await super().get_context(request, response) return get_context(context) diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 7910e02f73..7d9b86ea8e 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -16,6 +16,7 @@ from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult from tests.views.schema import Query, schema +from tests.websockets.views import OnWSConnectMixin from ..context import get_context from .base import ( @@ -30,7 +31,7 @@ ) -class GraphQLView(BaseGraphQLView): +class GraphQLView(OnWSConnectMixin, BaseGraphQLView[Dict[str, object], object]): result_override: ResultOverrideFunction = None graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler graphql_ws_handler_class = DebuggableGraphQLWSHandler @@ -41,8 +42,8 @@ async def get_root_value(self, request: Union[WebSocket, Request]) -> Query: async def get_context( self, request: Union[Request, WebSocket], - response: Optional[StarletteResponse] = None, - ) -> object: + response: Union[StarletteResponse, WebSocket], + ) -> Dict[str, object]: context = await super().get_context(request, response) return get_context(context) diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index bde2364128..802fc263cc 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -19,6 +19,7 @@ from strawberry.http.ides import GraphQL_IDE from strawberry.http.typevars import Context, RootValue from tests.views.schema import Query, schema +from tests.websockets.views import OnWSConnectMixin from ..context import get_context from .base import ( @@ -113,7 +114,7 @@ def process_result( return super().process_result(request, result) -class DebuggableGraphQLWSConsumer(GraphQLWSConsumer): +class DebuggableGraphQLWSConsumer(OnWSConnectMixin, GraphQLWSConsumer): graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler graphql_ws_handler_class = DebuggableGraphQLWSHandler diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index b1b80625fa..70eded4049 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -15,6 +15,7 @@ from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult from tests.views.schema import Query, schema +from tests.websockets.views import OnWSConnectMixin from ..context import get_context from .asgi import AsgiWebSocketClient @@ -54,7 +55,7 @@ async def get_root_value( return Query() -class GraphQLRouter(BaseGraphQLRouter[Any, Any]): +class GraphQLRouter(OnWSConnectMixin, BaseGraphQLRouter[Any, Any]): result_override: ResultOverrideFunction = None graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler graphql_ws_handler_class = DebuggableGraphQLWSHandler diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index 2548dc563c..dc948e9868 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -15,6 +15,7 @@ from strawberry.litestar import make_graphql_controller from strawberry.types import ExecutionResult from tests.views.schema import Query, schema +from tests.websockets.views import OnWSConnectMixin from ..context import get_context from .base import ( @@ -67,7 +68,7 @@ def create_app(self, result_override: ResultOverrideFunction = None, **kwargs: A **kwargs, ) - class GraphQLController(BaseGraphQLController): + class GraphQLController(OnWSConnectMixin, BaseGraphQLController): graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler graphql_ws_handler_class = DebuggableGraphQLWSHandler diff --git a/tests/views/schema.py b/tests/views/schema.py index ab959fbe01..cb5d9959ef 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -248,8 +248,8 @@ async def listener_with_confirmation( @strawberry.subscription async def connection_params( self, info: strawberry.Info - ) -> AsyncGenerator[str, None]: - yield info.context["connection_params"]["strawberry"] + ) -> AsyncGenerator[strawberry.scalars.JSON, None]: + yield info.context["connection_params"] @strawberry.subscription async def long_finalizer( diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 09e271f681..3a7b5849a3 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -247,6 +247,94 @@ async def test_too_many_initialisation_requests(ws: WebSocketClient): assert ws.close_reason == "Too many initialisation requests" +async def test_connections_are_accepted_by_default(ws_raw: WebSocketClient): + await ws_raw.send_message({"type": "connection_init"}) + connection_ack_message: ConnectionAckMessage = await ws_raw.receive_json() + assert connection_ack_message == {"type": "connection_ack"} + + await ws_raw.close() + assert ws_raw.closed + + +@pytest.mark.parametrize("payload", [None, {"token": "secret"}]) +async def test_setting_a_connection_ack_payload(ws_raw: WebSocketClient, payload): + await ws_raw.send_message( + { + "type": "connection_init", + "payload": {"test-accept": True, "ack-payload": payload}, + } + ) + + connection_ack_message: ConnectionAckMessage = await ws_raw.receive_json() + assert connection_ack_message == {"type": "connection_ack", "payload": payload} + + await ws_raw.close() + assert ws_raw.closed + + +async def test_connection_ack_payload_may_be_unset(ws_raw: WebSocketClient): + await ws_raw.send_message( + { + "type": "connection_init", + "payload": {"test-accept": True}, + } + ) + + connection_ack_message: ConnectionAckMessage = await ws_raw.receive_json() + assert connection_ack_message == {"type": "connection_ack"} + + await ws_raw.close() + assert ws_raw.closed + + +async def test_rejecting_connection_closes_socket_with_expected_code_and_message( + ws_raw: WebSocketClient, +): + await ws_raw.send_message( + {"type": "connection_init", "payload": {"test-reject": True}} + ) + + await ws_raw.receive(timeout=2) + assert ws_raw.closed + assert ws_raw.close_code == 4403 + assert ws_raw.close_reason == "Forbidden" + + +async def test_context_can_be_modified_from_within_on_ws_connect( + ws_raw: WebSocketClient, +): + await ws_raw.send_message( + { + "type": "connection_init", + "payload": {"test-modify": True}, + } + ) + + connection_ack_message: ConnectionAckMessage = await ws_raw.receive_json() + assert connection_ack_message == {"type": "connection_ack"} + + await ws_raw.send_message( + { + "type": "subscribe", + "id": "demo", + "payload": { + "query": "subscription { connectionParams }", + }, + } + ) + + next_message: NextMessage = await ws_raw.receive_json() + assert next_message["type"] == "next" + assert next_message["id"] == "demo" + assert "data" in next_message["payload"] + assert next_message["payload"]["data"] == { + "connectionParams": {"test-modify": True, "modified": True} + } + + await ws_raw.close() + assert ws_raw.closed + + async def test_ping_pong(ws: WebSocketClient): await ws.send_message({"type": "ping"}) pong_message: PongMessage = await ws.receive_json() @@ -823,7 +911,7 @@ async def test_injects_connection_params(ws_raw: WebSocketClient): ) next_message: NextMessage = await ws.receive_json() - assert_next(next_message, "sub1", {"connectionParams": "rocks"}) + assert_next(next_message, "sub1", {"connectionParams": {"strawberry": "rocks"}}) await ws.send_message({"id": "sub1", "type": "complete"}) diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index 5564238ca3..26488c9911 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -107,6 +107,137 @@ async def test_operation_selection(ws: WebSocketClient): assert complete_message["id"] == "demo" +async def test_connections_are_accepted_by_default(ws_raw: WebSocketClient): + await ws_raw.send_legacy_message({"type": "connection_init"}) + connection_ack_message: ConnectionAckMessage = await ws_raw.receive_json() + assert connection_ack_message == {"type": "connection_ack"} + + await ws_raw.close() + assert ws_raw.closed + + +async def test_setting_a_connection_ack_payload(ws_raw: WebSocketClient): + await ws_raw.send_legacy_message( + { + "type": "connection_init", + "payload": {"test-accept": True, "ack-payload": {"token": "secret"}}, + } + ) + + connection_ack_message: ConnectionAckMessage = await ws_raw.receive_json() + assert connection_ack_message == { + "type": "connection_ack", + "payload": {"token": "secret"}, + } + + await ws_raw.close() + assert ws_raw.closed + + +async def test_connection_ack_payload_may_be_unset(ws_raw: WebSocketClient): + await ws_raw.send_legacy_message( + { + "type": "connection_init", + "payload": {"test-accept": True}, + } + ) + + connection_ack_message: ConnectionAckMessage = await ws_raw.receive_json() + assert connection_ack_message == {"type": "connection_ack"} + + await ws_raw.close() + assert ws_raw.closed + + +async def test_a_connection_ack_payload_of_none_is_treated_as_unset( + ws_raw: WebSocketClient, +): + await ws_raw.send_legacy_message( + { + "type": "connection_init", + "payload": {"test-accept": True, "ack-payload": None}, + } + ) + + connection_ack_message: ConnectionAckMessage = await ws_raw.receive_json() + assert connection_ack_message == {"type": "connection_ack"} + + await ws_raw.close() + assert ws_raw.closed + + +async def test_rejecting_connection_results_in_error_message_and_socket_closure( + ws_raw: WebSocketClient, +): + await ws_raw.send_legacy_message( + {"type": "connection_init", "payload": {"test-reject": True}} + ) + + connection_error_message: ConnectionErrorMessage = await ws_raw.receive_json() + assert connection_error_message == {"type": "connection_error", "payload": {}} + + await ws_raw.receive(timeout=2) + assert ws_raw.closed + assert ws_raw.close_code == 1011 + assert not ws_raw.close_reason + + +async def test_rejecting_connection_with_custom_connection_error_payload( + ws_raw: WebSocketClient, +): + await ws_raw.send_legacy_message( + { + "type": "connection_init", + "payload": {"test-reject": True, "err-payload": {"custom": "error"}}, + } + ) + + connection_error_message: ConnectionErrorMessage = await ws_raw.receive_json() + assert connection_error_message == { + "type": "connection_error", + "payload": {"custom": "error"}, + } + + await ws_raw.receive(timeout=2) + assert ws_raw.closed + assert ws_raw.close_code == 1011 + assert not ws_raw.close_reason + + +async def test_context_can_be_modified_from_within_on_ws_connect( + ws_raw: WebSocketClient, +): + await ws_raw.send_legacy_message( + { + "type": "connection_init", + "payload": {"test-modify": True}, + } + ) + + connection_ack_message: ConnectionAckMessage = await ws_raw.receive_json() + assert connection_ack_message == {"type": "connection_ack"} + + await ws_raw.send_legacy_message( + { + "type": "start", + "id": "demo", + "payload": { + "query": "subscription { connectionParams }", + }, + } + ) + + data_message: DataMessage = await ws_raw.receive_json() + assert data_message["type"] == "data" + assert data_message["id"] == "demo" + assert data_message["payload"]["data"] == { + "connectionParams": {"test-modify": True, "modified": True} + } + + await ws_raw.close() + assert ws_raw.closed + + async def test_sends_keep_alive(aiohttp_app_client: HttpClient): aiohttp_app_client.create_app(keep_alive=True, keep_alive_interval=0.1) async with aiohttp_app_client.ws_connect( @@ -589,7 +720,9 @@ async def test_injects_connection_params(aiohttp_app_client: HttpClient): data_message: DataMessage = await ws.receive_json() assert data_message["type"] == "data" assert data_message["id"] == "demo" - assert data_message["payload"]["data"] == {"connectionParams": "rocks"} + assert data_message["payload"]["data"] == { + "connectionParams": {"strawberry": "rocks"} + } await ws.send_legacy_message({"type": "stop", "id": "demo"}) diff --git a/tests/websockets/views.py b/tests/websockets/views.py new file mode 100644 index 0000000000..eec511131e --- /dev/null +++ b/tests/websockets/views.py @@ -0,0 +1,30 @@ +from typing import Dict, Union + +from strawberry import UNSET +from strawberry.exceptions import ConnectionRejectionError +from strawberry.http.async_base_view import AsyncBaseHTTPView +from strawberry.types.unset import UnsetType + + +class OnWSConnectMixin(AsyncBaseHTTPView): + async def on_ws_connect( + self, context: Dict[str, object] + ) -> Union[UnsetType, None, Dict[str, object]]: + connection_params = context["connection_params"] + + if isinstance(connection_params, dict): + if connection_params.get("test-reject"): + if "err-payload" in connection_params: + raise ConnectionRejectionError(connection_params["err-payload"]) + raise ConnectionRejectionError() + + if connection_params.get("test-accept"): + if "ack-payload" in connection_params: + return connection_params["ack-payload"] + return UNSET + + if connection_params.get("test-modify"): + connection_params["modified"] = True + return UNSET + + return await super().on_ws_connect(context)