Skip to content

Commit

Permalink
Enable conditional acknowledgment of WS connections. (#3720)
Browse files Browse the repository at this point in the history
Co-authored-by: Patrick Arminio <[email protected]>
  • Loading branch information
DoctorJohn and patrick91 authored Dec 13, 2024
1 parent 45f4e50 commit e105975
Show file tree
Hide file tree
Showing 22 changed files with 609 additions and 34 deletions.
1 change: 1 addition & 0 deletions .alexrc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"executed",
"executes",
"execution",
"reject",
"special",
"primitive",
"invalid",
Expand Down
32 changes: 32 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
Release type: minor

This release adds a new `on_ws_connect` method to all HTTP view integrations.
The method is called when a `graphql-transport-ws` or `graphql-ws` connection is
established and can be used to customize the connection acknowledgment behavior.

This is particularly useful for authentication, authorization, and sending a
custom acknowledgment payload to clients when a connection is accepted. For
example:

```python
class MyGraphQLView(GraphQLView):
async def on_ws_connect(self, context: Dict[str, object]):
connection_params = context["connection_params"]

if not isinstance(connection_params, dict):
# Reject without a custom graphql-ws error payload
raise ConnectionRejectionError()

if connection_params.get("password") != "secret:
# Reject with a custom graphql-ws error payload
raise ConnectionRejectionError({"reason": "Invalid password"})

if username := connection_params.get("username"):
# Accept with a custom acknowledgement payload
return {"message": f"Hello, {username}!"}

# Accept without a acknowledgement payload
return await super().on_ws_connect(context)
```

Take a look at our documentation to learn more.
45 changes: 45 additions & 0 deletions docs/integrations/aiohttp.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ methods:
- `def decode_json(self, data: Union[str, bytes]) -> object`
- `def encode_json(self, data: object) -> str`
- `async def render_graphql_ide(self, request: Request) -> Response`
- `async def on_ws_connect(self, context: Context) -> Union[UnsetType, None, Dict[str, object]]`

### get_context

Expand Down Expand Up @@ -199,3 +200,47 @@ class MyGraphQLView(GraphQLView):

return Response(text=custom_html, content_type="text/html")
```

### on_ws_connect

By overriding `on_ws_connect` you can customize the behavior when a `graphql-ws`
or `graphql-transport-ws` connection is established. This is particularly useful
for authentication and authorization. By default, all connections are accepted.

To manually accept a connection, return `strawberry.UNSET` or a connection
acknowledgment payload. The acknowledgment payload will be sent to the client.

Note that the legacy protocol does not support `None`/`null` acknowledgment
payloads, while the new protocol does. Our implementation will treat
`None`/`null` payloads the same as `strawberry.UNSET` in the context of the
legacy protocol.

To reject a connection, raise a `ConnectionRejectionError`. You can optionally
provide a custom error payload that will be sent to the client when the legacy
GraphQL over WebSocket protocol is used.

```python
from typing import Dict
from strawberry.exceptions import ConnectionRejectionError
from strawberry.aiohttp.views import GraphQLView


class MyGraphQLView(GraphQLView):
async def on_ws_connect(self, context: Dict[str, object]):
connection_params = context["connection_params"]

if not isinstance(connection_params, dict):
# Reject without a custom graphql-ws error payload
raise ConnectionRejectionError()

if connection_params.get("password") != "secret":
# Reject with a custom graphql-ws error payload
raise ConnectionRejectionError({"reason": "Invalid password"})

if username := connection_params.get("username"):
# Accept with a custom acknowledgment payload
return {"message": f"Hello, {username}!"}

# Accept without a acknowledgment payload
return await super().on_ws_connect(context)
```
45 changes: 45 additions & 0 deletions docs/integrations/asgi.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ methods:
- `def decode_json(self, data: Union[str, bytes]) -> object`
- `def encode_json(self, data: object) -> str`
- `async def render_graphql_ide(self, request: Request) -> Response`
- `async def on_ws_connect(self, context: Context) -> Union[UnsetType, None, Dict[str, object]]`

### get_context

Expand Down Expand Up @@ -241,3 +242,47 @@ class MyGraphQL(GraphQL):

return HTMLResponse(custom_html)
```

### on_ws_connect

By overriding `on_ws_connect` you can customize the behavior when a `graphql-ws`
or `graphql-transport-ws` connection is established. This is particularly useful
for authentication and authorization. By default, all connections are accepted.

To manually accept a connection, return `strawberry.UNSET` or a connection
acknowledgment payload. The acknowledgment payload will be sent to the client.

Note that the legacy protocol does not support `None`/`null` acknowledgment
payloads, while the new protocol does. Our implementation will treat
`None`/`null` payloads the same as `strawberry.UNSET` in the context of the
legacy protocol.

To reject a connection, raise a `ConnectionRejectionError`. You can optionally
provide a custom error payload that will be sent to the client when the legacy
GraphQL over WebSocket protocol is used.

```python
from typing import Dict
from strawberry.exceptions import ConnectionRejectionError
from strawberry.asgi import GraphQL


class MyGraphQL(GraphQL):
async def on_ws_connect(self, context: Dict[str, object]):
connection_params = context["connection_params"]

if not isinstance(connection_params, dict):
# Reject without a custom graphql-ws error payload
raise ConnectionRejectionError()

if connection_params.get("password") != "secret":
# Reject with a custom graphql-ws error payload
raise ConnectionRejectionError({"reason": "Invalid password"})

if username := connection_params.get("username"):
# Accept with a custom acknowledgment payload
return {"message": f"Hello, {username}!"}

# Accept without a acknowledgment payload
return await super().on_ws_connect(context)
```
45 changes: 45 additions & 0 deletions docs/integrations/channels.md
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,51 @@ following methods:
- `async def get_root_value(self, request: GraphQLWSConsumer) -> Optional[RootValue]`
- `def decode_json(self, data: Union[str, bytes]) -> object`
- `def encode_json(self, data: object) -> str`
- `async def on_ws_connect(self, context: Context) -> Union[UnsetType, None, Dict[str, object]]`

### on_ws_connect

By overriding `on_ws_connect` you can customize the behavior when a `graphql-ws`
or `graphql-transport-ws` connection is established. This is particularly useful
for authentication and authorization. By default, all connections are accepted.

To manually accept a connection, return `strawberry.UNSET` or a connection
acknowledgment payload. The acknowledgment payload will be sent to the client.

Note that the legacy protocol does not support `None`/`null` acknowledgment
payloads, while the new protocol does. Our implementation will treat
`None`/`null` payloads the same as `strawberry.UNSET` in the context of the
legacy protocol.

To reject a connection, raise a `ConnectionRejectionError`. You can optionally
provide a custom error payload that will be sent to the client when the legacy
GraphQL over WebSocket protocol is used.

```python
from typing import Dict
from strawberry.exceptions import ConnectionRejectionError
from strawberry.channels import GraphQLWSConsumer


class MyGraphQLWSConsumer(GraphQLWSConsumer):
async def on_ws_connect(self, context: Dict[str, object]):
connection_params = context["connection_params"]

if not isinstance(connection_params, dict):
# Reject without a custom graphql-ws error payload
raise ConnectionRejectionError()

if connection_params.get("password") != "secret":
# Reject with a custom graphql-ws error payload
raise ConnectionRejectionError({"reason": "Invalid password"})

if username := connection_params.get("username"):
# Accept with a custom acknowledgment payload
return {"message": f"Hello, {username}!"}

# Accept without a acknowledgment payload
return await super().on_ws_connect(context)
```

### Context

Expand Down
45 changes: 45 additions & 0 deletions docs/integrations/fastapi.md
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ following methods:
- `def decode_json(self, data: Union[str, bytes]) -> object`
- `def encode_json(self, data: object) -> str`
- `async def render_graphql_ide(self, request: Request) -> HTMLResponse`
- `async def on_ws_connect(self, context: Context) -> Union[UnsetType, None, Dict[str, object]]`

### process_result

Expand Down Expand Up @@ -354,3 +355,47 @@ class MyGraphQLRouter(GraphQLRouter):

return HTMLResponse(custom_html)
```

### on_ws_connect

By overriding `on_ws_connect` you can customize the behavior when a `graphql-ws`
or `graphql-transport-ws` connection is established. This is particularly useful
for authentication and authorization. By default, all connections are accepted.

To manually accept a connection, return `strawberry.UNSET` or a connection
acknowledgment payload. The acknowledgment payload will be sent to the client.

Note that the legacy protocol does not support `None`/`null` acknowledgment
payloads, while the new protocol does. Our implementation will treat
`None`/`null` payloads the same as `strawberry.UNSET` in the context of the
legacy protocol.

To reject a connection, raise a `ConnectionRejectionError`. You can optionally
provide a custom error payload that will be sent to the client when the legacy
GraphQL over WebSocket protocol is used.

```python
from typing import Dict
from strawberry.exceptions import ConnectionRejectionError
from strawberry.fastapi import GraphQLRouter


class MyGraphQLRouter(GraphQLRouter):
async def on_ws_connect(self, context: Dict[str, object]):
connection_params = context["connection_params"]

if not isinstance(connection_params, dict):
# Reject without a custom graphql-ws error payload
raise ConnectionRejectionError()

if connection_params.get("password") != "secret":
# Reject with a custom graphql-ws error payload
raise ConnectionRejectionError({"reason": "Invalid password"})

if username := connection_params.get("username"):
# Accept with a custom acknowledgment payload
return {"message": f"Hello, {username}!"}

# Accept without a acknowledgment payload
return await super().on_ws_connect(context)
```
61 changes: 61 additions & 0 deletions docs/integrations/litestar.md
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ extended by overriding any of the following methods:
- `def decode_json(self, data: Union[str, bytes]) -> object`
- `def encode_json(self, data: object) -> str`
- `async def render_graphql_ide(self, request: Request) -> Response`
- `async def on_ws_connect(self, context: Context) -> Union[UnsetType, None, Dict[str, object]]`

### process_result

Expand Down Expand Up @@ -476,3 +477,63 @@ class MyGraphQLController(GraphQLController):

return Response(custom_html, media_type=MediaType.HTML)
```

### on_ws_connect

By overriding `on_ws_connect` you can customize the behavior when a `graphql-ws`
or `graphql-transport-ws` connection is established. This is particularly useful
for authentication and authorization. By default, all connections are accepted.

To manually accept a connection, return `strawberry.UNSET` or a connection
acknowledgment payload. The acknowledgment payload will be sent to the client.

Note that the legacy protocol does not support `None`/`null` acknowledgment
payloads, while the new protocol does. Our implementation will treat
`None`/`null` payloads the same as `strawberry.UNSET` in the context of the
legacy protocol.

To reject a connection, raise a `ConnectionRejectionError`. You can optionally
provide a custom error payload that will be sent to the client when the legacy
GraphQL over WebSocket protocol is used.

```python
import strawberry
from typing import Dict
from strawberry.exceptions import ConnectionRejectionError
from strawberry.litestar import make_graphql_controller


@strawberry.type
class Query:
@strawberry.field
def hello(self) -> str:
return "world"


schema = strawberry.Schema(Query)

GraphQLController = make_graphql_controller(
schema,
path="/graphql",
)


class MyGraphQLController(GraphQLController):
async def on_ws_connect(self, context: Dict[str, object]):
connection_params = context["connection_params"]

if not isinstance(connection_params, dict):
# Reject without a custom graphql-ws error payload
raise ConnectionRejectionError()

if connection_params.get("password") != "secret":
# Reject with a custom graphql-ws error payload
raise ConnectionRejectionError({"reason": "Invalid password"})

if username := connection_params.get("username"):
# Accept with a custom acknowledgment payload
return {"message": f"Hello, {username}!"}

# Accept without a acknowledgment payload
return await super().on_ws_connect(context)
```
9 changes: 8 additions & 1 deletion strawberry/exceptions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Optional, Set, Union
from typing import TYPE_CHECKING, Dict, Optional, Set, Union

from graphql import GraphQLError

Expand Down Expand Up @@ -157,6 +157,13 @@ class StrawberryGraphQLError(GraphQLError):
"""Use it when you want to override the graphql.GraphQLError in custom extensions."""


class ConnectionRejectionError(Exception):
"""Use it when you want to reject a WebSocket connection."""

def __init__(self, payload: Dict[str, object] = {}) -> None:
self.payload = payload


__all__ = [
"StrawberryException",
"UnableToFindExceptionSource",
Expand Down
9 changes: 8 additions & 1 deletion strawberry/http/async_base_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

from graphql import GraphQLError

from strawberry import UNSET
from strawberry.exceptions import MissingQueryError
from strawberry.file_uploads.utils import replace_placeholders_with_files
from strawberry.http import (
Expand All @@ -39,6 +38,7 @@
from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler
from strawberry.types import ExecutionResult, SubscriptionExecutionResult
from strawberry.types.graphql import OperationType
from strawberry.types.unset import UNSET, UnsetType

from .base import BaseView
from .exceptions import HTTPException
Expand Down Expand Up @@ -279,6 +279,7 @@ async def run(

if websocket_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
await self.graphql_transport_ws_handler_class(
view=self,
websocket=websocket,
context=context,
root_value=root_value,
Expand All @@ -288,6 +289,7 @@ async def run(
).handle()
elif websocket_subprotocol == GRAPHQL_WS_PROTOCOL:
await self.graphql_ws_handler_class(
view=self,
websocket=websocket,
context=context,
root_value=root_value,
Expand Down Expand Up @@ -476,5 +478,10 @@ async def process_result(
) -> GraphQLHTTPResponse:
return process_result(result)

async def on_ws_connect(
self, context: Context
) -> Union[UnsetType, None, Dict[str, object]]:
return UNSET


__all__ = ["AsyncBaseHTTPView"]
Loading

0 comments on commit e105975

Please sign in to comment.