Skip to content

Commit e105975

Browse files
Enable conditional acknowledgment of WS connections. (#3720)
Co-authored-by: Patrick Arminio <[email protected]>
1 parent 45f4e50 commit e105975

File tree

22 files changed

+609
-34
lines changed

22 files changed

+609
-34
lines changed

.alexrc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"executed",
1010
"executes",
1111
"execution",
12+
"reject",
1213
"special",
1314
"primitive",
1415
"invalid",

RELEASE.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
Release type: minor
2+
3+
This release adds a new `on_ws_connect` method to all HTTP view integrations.
4+
The method is called when a `graphql-transport-ws` or `graphql-ws` connection is
5+
established and can be used to customize the connection acknowledgment behavior.
6+
7+
This is particularly useful for authentication, authorization, and sending a
8+
custom acknowledgment payload to clients when a connection is accepted. For
9+
example:
10+
11+
```python
12+
class MyGraphQLView(GraphQLView):
13+
async def on_ws_connect(self, context: Dict[str, object]):
14+
connection_params = context["connection_params"]
15+
16+
if not isinstance(connection_params, dict):
17+
# Reject without a custom graphql-ws error payload
18+
raise ConnectionRejectionError()
19+
20+
if connection_params.get("password") != "secret:
21+
# Reject with a custom graphql-ws error payload
22+
raise ConnectionRejectionError({"reason": "Invalid password"})
23+
24+
if username := connection_params.get("username"):
25+
# Accept with a custom acknowledgement payload
26+
return {"message": f"Hello, {username}!"}
27+
28+
# Accept without a acknowledgement payload
29+
return await super().on_ws_connect(context)
30+
```
31+
32+
Take a look at our documentation to learn more.

docs/integrations/aiohttp.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ methods:
5353
- `def decode_json(self, data: Union[str, bytes]) -> object`
5454
- `def encode_json(self, data: object) -> str`
5555
- `async def render_graphql_ide(self, request: Request) -> Response`
56+
- `async def on_ws_connect(self, context: Context) -> Union[UnsetType, None, Dict[str, object]]`
5657

5758
### get_context
5859

@@ -199,3 +200,47 @@ class MyGraphQLView(GraphQLView):
199200

200201
return Response(text=custom_html, content_type="text/html")
201202
```
203+
204+
### on_ws_connect
205+
206+
By overriding `on_ws_connect` you can customize the behavior when a `graphql-ws`
207+
or `graphql-transport-ws` connection is established. This is particularly useful
208+
for authentication and authorization. By default, all connections are accepted.
209+
210+
To manually accept a connection, return `strawberry.UNSET` or a connection
211+
acknowledgment payload. The acknowledgment payload will be sent to the client.
212+
213+
Note that the legacy protocol does not support `None`/`null` acknowledgment
214+
payloads, while the new protocol does. Our implementation will treat
215+
`None`/`null` payloads the same as `strawberry.UNSET` in the context of the
216+
legacy protocol.
217+
218+
To reject a connection, raise a `ConnectionRejectionError`. You can optionally
219+
provide a custom error payload that will be sent to the client when the legacy
220+
GraphQL over WebSocket protocol is used.
221+
222+
```python
223+
from typing import Dict
224+
from strawberry.exceptions import ConnectionRejectionError
225+
from strawberry.aiohttp.views import GraphQLView
226+
227+
228+
class MyGraphQLView(GraphQLView):
229+
async def on_ws_connect(self, context: Dict[str, object]):
230+
connection_params = context["connection_params"]
231+
232+
if not isinstance(connection_params, dict):
233+
# Reject without a custom graphql-ws error payload
234+
raise ConnectionRejectionError()
235+
236+
if connection_params.get("password") != "secret":
237+
# Reject with a custom graphql-ws error payload
238+
raise ConnectionRejectionError({"reason": "Invalid password"})
239+
240+
if username := connection_params.get("username"):
241+
# Accept with a custom acknowledgment payload
242+
return {"message": f"Hello, {username}!"}
243+
244+
# Accept without a acknowledgment payload
245+
return await super().on_ws_connect(context)
246+
```

docs/integrations/asgi.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ methods:
5353
- `def decode_json(self, data: Union[str, bytes]) -> object`
5454
- `def encode_json(self, data: object) -> str`
5555
- `async def render_graphql_ide(self, request: Request) -> Response`
56+
- `async def on_ws_connect(self, context: Context) -> Union[UnsetType, None, Dict[str, object]]`
5657

5758
### get_context
5859

@@ -241,3 +242,47 @@ class MyGraphQL(GraphQL):
241242

242243
return HTMLResponse(custom_html)
243244
```
245+
246+
### on_ws_connect
247+
248+
By overriding `on_ws_connect` you can customize the behavior when a `graphql-ws`
249+
or `graphql-transport-ws` connection is established. This is particularly useful
250+
for authentication and authorization. By default, all connections are accepted.
251+
252+
To manually accept a connection, return `strawberry.UNSET` or a connection
253+
acknowledgment payload. The acknowledgment payload will be sent to the client.
254+
255+
Note that the legacy protocol does not support `None`/`null` acknowledgment
256+
payloads, while the new protocol does. Our implementation will treat
257+
`None`/`null` payloads the same as `strawberry.UNSET` in the context of the
258+
legacy protocol.
259+
260+
To reject a connection, raise a `ConnectionRejectionError`. You can optionally
261+
provide a custom error payload that will be sent to the client when the legacy
262+
GraphQL over WebSocket protocol is used.
263+
264+
```python
265+
from typing import Dict
266+
from strawberry.exceptions import ConnectionRejectionError
267+
from strawberry.asgi import GraphQL
268+
269+
270+
class MyGraphQL(GraphQL):
271+
async def on_ws_connect(self, context: Dict[str, object]):
272+
connection_params = context["connection_params"]
273+
274+
if not isinstance(connection_params, dict):
275+
# Reject without a custom graphql-ws error payload
276+
raise ConnectionRejectionError()
277+
278+
if connection_params.get("password") != "secret":
279+
# Reject with a custom graphql-ws error payload
280+
raise ConnectionRejectionError({"reason": "Invalid password"})
281+
282+
if username := connection_params.get("username"):
283+
# Accept with a custom acknowledgment payload
284+
return {"message": f"Hello, {username}!"}
285+
286+
# Accept without a acknowledgment payload
287+
return await super().on_ws_connect(context)
288+
```

docs/integrations/channels.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,51 @@ following methods:
592592
- `async def get_root_value(self, request: GraphQLWSConsumer) -> Optional[RootValue]`
593593
- `def decode_json(self, data: Union[str, bytes]) -> object`
594594
- `def encode_json(self, data: object) -> str`
595+
- `async def on_ws_connect(self, context: Context) -> Union[UnsetType, None, Dict[str, object]]`
596+
597+
### on_ws_connect
598+
599+
By overriding `on_ws_connect` you can customize the behavior when a `graphql-ws`
600+
or `graphql-transport-ws` connection is established. This is particularly useful
601+
for authentication and authorization. By default, all connections are accepted.
602+
603+
To manually accept a connection, return `strawberry.UNSET` or a connection
604+
acknowledgment payload. The acknowledgment payload will be sent to the client.
605+
606+
Note that the legacy protocol does not support `None`/`null` acknowledgment
607+
payloads, while the new protocol does. Our implementation will treat
608+
`None`/`null` payloads the same as `strawberry.UNSET` in the context of the
609+
legacy protocol.
610+
611+
To reject a connection, raise a `ConnectionRejectionError`. You can optionally
612+
provide a custom error payload that will be sent to the client when the legacy
613+
GraphQL over WebSocket protocol is used.
614+
615+
```python
616+
from typing import Dict
617+
from strawberry.exceptions import ConnectionRejectionError
618+
from strawberry.channels import GraphQLWSConsumer
619+
620+
621+
class MyGraphQLWSConsumer(GraphQLWSConsumer):
622+
async def on_ws_connect(self, context: Dict[str, object]):
623+
connection_params = context["connection_params"]
624+
625+
if not isinstance(connection_params, dict):
626+
# Reject without a custom graphql-ws error payload
627+
raise ConnectionRejectionError()
628+
629+
if connection_params.get("password") != "secret":
630+
# Reject with a custom graphql-ws error payload
631+
raise ConnectionRejectionError({"reason": "Invalid password"})
632+
633+
if username := connection_params.get("username"):
634+
# Accept with a custom acknowledgment payload
635+
return {"message": f"Hello, {username}!"}
636+
637+
# Accept without a acknowledgment payload
638+
return await super().on_ws_connect(context)
639+
```
595640

596641
### Context
597642

docs/integrations/fastapi.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ following methods:
268268
- `def decode_json(self, data: Union[str, bytes]) -> object`
269269
- `def encode_json(self, data: object) -> str`
270270
- `async def render_graphql_ide(self, request: Request) -> HTMLResponse`
271+
- `async def on_ws_connect(self, context: Context) -> Union[UnsetType, None, Dict[str, object]]`
271272

272273
### process_result
273274

@@ -354,3 +355,47 @@ class MyGraphQLRouter(GraphQLRouter):
354355

355356
return HTMLResponse(custom_html)
356357
```
358+
359+
### on_ws_connect
360+
361+
By overriding `on_ws_connect` you can customize the behavior when a `graphql-ws`
362+
or `graphql-transport-ws` connection is established. This is particularly useful
363+
for authentication and authorization. By default, all connections are accepted.
364+
365+
To manually accept a connection, return `strawberry.UNSET` or a connection
366+
acknowledgment payload. The acknowledgment payload will be sent to the client.
367+
368+
Note that the legacy protocol does not support `None`/`null` acknowledgment
369+
payloads, while the new protocol does. Our implementation will treat
370+
`None`/`null` payloads the same as `strawberry.UNSET` in the context of the
371+
legacy protocol.
372+
373+
To reject a connection, raise a `ConnectionRejectionError`. You can optionally
374+
provide a custom error payload that will be sent to the client when the legacy
375+
GraphQL over WebSocket protocol is used.
376+
377+
```python
378+
from typing import Dict
379+
from strawberry.exceptions import ConnectionRejectionError
380+
from strawberry.fastapi import GraphQLRouter
381+
382+
383+
class MyGraphQLRouter(GraphQLRouter):
384+
async def on_ws_connect(self, context: Dict[str, object]):
385+
connection_params = context["connection_params"]
386+
387+
if not isinstance(connection_params, dict):
388+
# Reject without a custom graphql-ws error payload
389+
raise ConnectionRejectionError()
390+
391+
if connection_params.get("password") != "secret":
392+
# Reject with a custom graphql-ws error payload
393+
raise ConnectionRejectionError({"reason": "Invalid password"})
394+
395+
if username := connection_params.get("username"):
396+
# Accept with a custom acknowledgment payload
397+
return {"message": f"Hello, {username}!"}
398+
399+
# Accept without a acknowledgment payload
400+
return await super().on_ws_connect(context)
401+
```

docs/integrations/litestar.md

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ extended by overriding any of the following methods:
327327
- `def decode_json(self, data: Union[str, bytes]) -> object`
328328
- `def encode_json(self, data: object) -> str`
329329
- `async def render_graphql_ide(self, request: Request) -> Response`
330+
- `async def on_ws_connect(self, context: Context) -> Union[UnsetType, None, Dict[str, object]]`
330331

331332
### process_result
332333

@@ -476,3 +477,63 @@ class MyGraphQLController(GraphQLController):
476477

477478
return Response(custom_html, media_type=MediaType.HTML)
478479
```
480+
481+
### on_ws_connect
482+
483+
By overriding `on_ws_connect` you can customize the behavior when a `graphql-ws`
484+
or `graphql-transport-ws` connection is established. This is particularly useful
485+
for authentication and authorization. By default, all connections are accepted.
486+
487+
To manually accept a connection, return `strawberry.UNSET` or a connection
488+
acknowledgment payload. The acknowledgment payload will be sent to the client.
489+
490+
Note that the legacy protocol does not support `None`/`null` acknowledgment
491+
payloads, while the new protocol does. Our implementation will treat
492+
`None`/`null` payloads the same as `strawberry.UNSET` in the context of the
493+
legacy protocol.
494+
495+
To reject a connection, raise a `ConnectionRejectionError`. You can optionally
496+
provide a custom error payload that will be sent to the client when the legacy
497+
GraphQL over WebSocket protocol is used.
498+
499+
```python
500+
import strawberry
501+
from typing import Dict
502+
from strawberry.exceptions import ConnectionRejectionError
503+
from strawberry.litestar import make_graphql_controller
504+
505+
506+
@strawberry.type
507+
class Query:
508+
@strawberry.field
509+
def hello(self) -> str:
510+
return "world"
511+
512+
513+
schema = strawberry.Schema(Query)
514+
515+
GraphQLController = make_graphql_controller(
516+
schema,
517+
path="/graphql",
518+
)
519+
520+
521+
class MyGraphQLController(GraphQLController):
522+
async def on_ws_connect(self, context: Dict[str, object]):
523+
connection_params = context["connection_params"]
524+
525+
if not isinstance(connection_params, dict):
526+
# Reject without a custom graphql-ws error payload
527+
raise ConnectionRejectionError()
528+
529+
if connection_params.get("password") != "secret":
530+
# Reject with a custom graphql-ws error payload
531+
raise ConnectionRejectionError({"reason": "Invalid password"})
532+
533+
if username := connection_params.get("username"):
534+
# Accept with a custom acknowledgment payload
535+
return {"message": f"Hello, {username}!"}
536+
537+
# Accept without a acknowledgment payload
538+
return await super().on_ws_connect(context)
539+
```

strawberry/exceptions/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from functools import cached_property
4-
from typing import TYPE_CHECKING, Optional, Set, Union
4+
from typing import TYPE_CHECKING, Dict, Optional, Set, Union
55

66
from graphql import GraphQLError
77

@@ -157,6 +157,13 @@ class StrawberryGraphQLError(GraphQLError):
157157
"""Use it when you want to override the graphql.GraphQLError in custom extensions."""
158158

159159

160+
class ConnectionRejectionError(Exception):
161+
"""Use it when you want to reject a WebSocket connection."""
162+
163+
def __init__(self, payload: Dict[str, object] = {}) -> None:
164+
self.payload = payload
165+
166+
160167
__all__ = [
161168
"StrawberryException",
162169
"UnableToFindExceptionSource",

strawberry/http/async_base_view.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
from graphql import GraphQLError
2323

24-
from strawberry import UNSET
2524
from strawberry.exceptions import MissingQueryError
2625
from strawberry.file_uploads.utils import replace_placeholders_with_files
2726
from strawberry.http import (
@@ -39,6 +38,7 @@
3938
from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler
4039
from strawberry.types import ExecutionResult, SubscriptionExecutionResult
4140
from strawberry.types.graphql import OperationType
41+
from strawberry.types.unset import UNSET, UnsetType
4242

4343
from .base import BaseView
4444
from .exceptions import HTTPException
@@ -279,6 +279,7 @@ async def run(
279279

280280
if websocket_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
281281
await self.graphql_transport_ws_handler_class(
282+
view=self,
282283
websocket=websocket,
283284
context=context,
284285
root_value=root_value,
@@ -288,6 +289,7 @@ async def run(
288289
).handle()
289290
elif websocket_subprotocol == GRAPHQL_WS_PROTOCOL:
290291
await self.graphql_ws_handler_class(
292+
view=self,
291293
websocket=websocket,
292294
context=context,
293295
root_value=root_value,
@@ -476,5 +478,10 @@ async def process_result(
476478
) -> GraphQLHTTPResponse:
477479
return process_result(result)
478480

481+
async def on_ws_connect(
482+
self, context: Context
483+
) -> Union[UnsetType, None, Dict[str, object]]:
484+
return UNSET
485+
479486

480487
__all__ = ["AsyncBaseHTTPView"]

0 commit comments

Comments
 (0)