Skip to content

Commit

Permalink
Add default values to Context and RootValue type vars (#3732)
Browse files Browse the repository at this point in the history
* Add default values to Context and RootValue type vars

* Mypy

* Use stable version of mypy

* Remove cache from lint

* Fix mypy

* Fix lint

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert

* Test

* Test

* Add release file

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
patrick91 and pre-commit-ci[bot] authored Dec 20, 2024
1 parent 633c9bc commit 71ac2d6
Show file tree
Hide file tree
Showing 10 changed files with 321 additions and 221 deletions.
32 changes: 11 additions & 21 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -117,31 +117,21 @@ jobs:

steps:
- uses: actions/checkout@v4
- run: pipx install poetry
- run: pipx install coverage
- uses: actions/setup-python@v5
id: setup-python
with:
python-version: |
3.8
3.9
3.10
3.11
3.12
python-version: "3.12"
cache: "poetry"

- name: Pip and nox cache
id: cache
uses: actions/cache@v4
with:
path: |
~/.cache
~/.nox
.nox
key:
${{ runner.os }}-nox-lint-${{ env.pythonLocation }}-${{
hashFiles('**/poetry.lock') }}-${{ hashFiles('**/noxfile.py') }}
restore-keys: |
${{ runner.os }}-nox-lint-${{ env.pythonLocation }}
- run: poetry install --with integrations
if: steps.setup-python.outputs.cache-hit != 'true'

- run: pip install poetry nox nox-poetry uv
- run: nox -r -t lint
- run: |
mkdir .mypy_cache
poetry run mypy --install-types --non-interactive --cache-dir=.mypy_cache/ --config-file mypy.ini
unit-tests-on-windows:
name: 🪟 Tests on Windows
Expand Down
5 changes: 5 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Release type: patch

This release updates the Context and RootValue vars to have
a default value of `None`, this makes it easier to use the views
without having to pass in a value for these vars.
10 changes: 1 addition & 9 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def tests_typecheckers(session: Session) -> None:

session.install("pyright")
session.install("pydantic")
session.install("git+https://github.com/python/mypy.git#master")
session.install("mypy")

session.run(
"pytest",
Expand All @@ -181,11 +181,3 @@ def tests_cli(session: Session) -> None:
"tests/cli",
"-vv",
)


@session(name="Mypy", tags=["lint"])
def mypy(session: Session) -> None:
session.run_always("poetry", "install", "--with", "integrations", external=True)
session.install("mypy")

session.run("mypy", "--config-file", "mypy.ini")
405 changes: 246 additions & 159 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ poetry-plugin-export = "^1.6.0"
urllib3 = "<2"
graphlib_backport = {version = "*", python = "<3.9", optional = false}
inline-snapshot = "^0.10.1"
types-deprecated = "^1.2.15.20241117"
types-six = "^1.17.0.20241205"
types-pyyaml = "^6.0.12.20240917"
mypy = "^1.13.0"

[tool.poetry.group.integrations]
optional = true
Expand Down
23 changes: 16 additions & 7 deletions strawberry/http/async_base_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Mapping,
Optional,
Tuple,
Type,
Union,
cast,
overload,
Expand Down Expand Up @@ -116,11 +117,19 @@ class AsyncBaseHTTPView(
connection_init_wait_timeout: timedelta = timedelta(minutes=1)
request_adapter_class: Callable[[Request], AsyncHTTPRequestAdapter]
websocket_adapter_class: Callable[
["AsyncBaseHTTPView", WebSocketRequest, WebSocketResponse],
[
"AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue]",
WebSocketRequest,
WebSocketResponse,
],
AsyncWebSocketAdapter,
]
graphql_transport_ws_handler_class = BaseGraphQLTransportWSHandler
graphql_ws_handler_class = BaseGraphQLWSHandler
graphql_transport_ws_handler_class: Type[
BaseGraphQLTransportWSHandler[Context, RootValue]
] = BaseGraphQLTransportWSHandler[Context, RootValue]
graphql_ws_handler_class: Type[BaseGraphQLWSHandler[Context, RootValue]] = (
BaseGraphQLWSHandler[Context, RootValue]
)

@property
@abc.abstractmethod
Expand Down Expand Up @@ -281,8 +290,8 @@ async def run(
await self.graphql_transport_ws_handler_class(
view=self,
websocket=websocket,
context=context,
root_value=root_value,
context=context, # type: ignore
root_value=root_value, # type: ignore
schema=self.schema,
debug=self.debug,
connection_init_wait_timeout=self.connection_init_wait_timeout,
Expand All @@ -291,8 +300,8 @@ async def run(
await self.graphql_ws_handler_class(
view=self,
websocket=websocket,
context=context,
root_value=root_value,
context=context, # type: ignore
root_value=root_value, # type: ignore
schema=self.schema,
debug=self.debug,
keep_alive=self.keep_alive,
Expand Down
10 changes: 5 additions & 5 deletions strawberry/http/typevars.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from typing import TypeVar
from typing_extensions import TypeVar

Request = TypeVar("Request", contravariant=True)
Response = TypeVar("Response")
SubResponse = TypeVar("SubResponse")
WebSocketRequest = TypeVar("WebSocketRequest")
WebSocketResponse = TypeVar("WebSocketResponse")
Context = TypeVar("Context")
RootValue = TypeVar("RootValue")
Context = TypeVar("Context", default=None)
RootValue = TypeVar("RootValue", default=None)


__all__ = [
"Context",
"Request",
"Response",
"RootValue",
"SubResponse",
"WebSocketRequest",
"WebSocketResponse",
"Context",
"RootValue",
]
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from contextlib import suppress
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Dict,
Generic,
List,
Optional,
cast,
Expand All @@ -20,6 +22,7 @@
NonTextMessageReceived,
WebSocketDisconnected,
)
from strawberry.http.typevars import Context, RootValue
from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
CompleteMessage,
ConnectionInitMessage,
Expand All @@ -44,15 +47,15 @@
from strawberry.schema.subscribe import SubscriptionResult


class BaseGraphQLTransportWSHandler:
class BaseGraphQLTransportWSHandler(Generic[Context, RootValue]):
task_logger: logging.Logger = logging.getLogger("strawberry.ws.task")

def __init__(
self,
view: AsyncBaseHTTPView,
view: AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue],
websocket: AsyncWebSocketAdapter,
context: object,
root_value: object,
context: Context,
root_value: RootValue,
schema: BaseSchema,
debug: bool,
connection_init_wait_timeout: timedelta,
Expand All @@ -68,7 +71,7 @@ def __init__(
self.connection_init_received = False
self.connection_acknowledged = False
self.connection_timed_out = False
self.operations: Dict[str, Operation] = {}
self.operations: Dict[str, Operation[Context, RootValue]] = {}
self.completed_tasks: List[asyncio.Task] = []

async def handle(self) -> None:
Expand Down Expand Up @@ -184,6 +187,8 @@ async def handle_connection_init(self, message: ConnectionInitMessage) -> None:
elif hasattr(self.context, "connection_params"):
self.context.connection_params = payload

self.context = cast(Context, self.context)

try:
connection_ack_payload = await self.view.on_ws_connect(self.context)
except ConnectionRejectionError:
Expand Down Expand Up @@ -250,7 +255,7 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None:
operation.task = asyncio.create_task(self.run_operation(operation))
self.operations[message["id"]] = operation

async def run_operation(self, operation: Operation) -> None:
async def run_operation(self, operation: Operation[Context, RootValue]) -> None:
"""The operation task's top level method. Cleans-up and de-registers the operation once it is done."""
# TODO: Handle errors in this method using self.handle_task_exception()

Expand Down Expand Up @@ -334,7 +339,7 @@ async def reap_completed_tasks(self) -> None:
await task


class Operation:
class Operation(Generic[Context, RootValue]):
"""A class encapsulating a single operation with its id. Helps enforce protocol state transition."""

__slots__ = [
Expand All @@ -350,7 +355,7 @@ class Operation:

def __init__(
self,
handler: BaseGraphQLTransportWSHandler,
handler: BaseGraphQLTransportWSHandler[Context, RootValue],
id: str,
operation_type: OperationType,
query: str,
Expand Down
16 changes: 12 additions & 4 deletions strawberry/subscriptions/protocols/graphql_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
from contextlib import suppress
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Dict,
Generic,
Optional,
cast,
)

from strawberry.exceptions import ConnectionRejectionError
from strawberry.http.exceptions import NonTextMessageReceived, WebSocketDisconnected
from strawberry.http.typevars import Context, RootValue
from strawberry.subscriptions.protocols.graphql_ws.types import (
ConnectionInitMessage,
ConnectionTerminateMessage,
Expand All @@ -29,13 +32,16 @@
from strawberry.schema import BaseSchema


class BaseGraphQLWSHandler:
class BaseGraphQLWSHandler(Generic[Context, RootValue]):
context: Context
root_value: RootValue

def __init__(
self,
view: AsyncBaseHTTPView,
view: AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue],
websocket: AsyncWebSocketAdapter,
context: object,
root_value: object,
context: Context,
root_value: RootValue,
schema: BaseSchema,
debug: bool,
keep_alive: bool,
Expand Down Expand Up @@ -100,6 +106,8 @@ async def handle_connection_init(self, message: ConnectionInitMessage) -> None:
elif hasattr(self.context, "connection_params"):
self.context.connection_params = payload

self.context = cast(Context, self.context)

try:
connection_ack_payload = await self.view.on_ws_connect(self.context)
except ConnectionRejectionError as e:
Expand Down
16 changes: 8 additions & 8 deletions tests/fastapi/test_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import Any, AsyncGenerator, Dict
from typing import AsyncGenerator, Dict

import pytest

Expand Down Expand Up @@ -47,7 +47,7 @@ def get_context(custom_context: CustomContext = Depends(custom_context_dependenc

app = FastAPI()
schema = strawberry.Schema(query=Query)
graphql_app = GraphQLRouter[Any, None](schema=schema, context_getter=get_context)
graphql_app = GraphQLRouter(schema=schema, context_getter=get_context)
app.include_router(graphql_app, prefix="/graphql")

test_client = TestClient(app)
Expand Down Expand Up @@ -81,7 +81,7 @@ def get_context(custom_context: CustomContext = Depends()):

app = FastAPI()
schema = strawberry.Schema(query=Query)
graphql_app = GraphQLRouter[Any, None](schema=schema, context_getter=get_context)
graphql_app = GraphQLRouter(schema=schema, context_getter=get_context)
app.include_router(graphql_app, prefix="/graphql")

test_client = TestClient(app)
Expand Down Expand Up @@ -113,7 +113,7 @@ def get_context(value: str = Depends(custom_context_dependency)) -> Dict[str, st

app = FastAPI()
schema = strawberry.Schema(query=Query)
graphql_app = GraphQLRouter[Any, None](schema=schema, context_getter=get_context)
graphql_app = GraphQLRouter(schema=schema, context_getter=get_context)
app.include_router(graphql_app, prefix="/graphql")

test_client = TestClient(app)
Expand All @@ -138,7 +138,7 @@ def abc(self, info: strawberry.Info) -> str:

app = FastAPI()
schema = strawberry.Schema(query=Query)
graphql_app = GraphQLRouter[None, None](schema, context_getter=None)
graphql_app = GraphQLRouter(schema, context_getter=None)
app.include_router(graphql_app, prefix="/graphql")

test_client = TestClient(app)
Expand Down Expand Up @@ -169,7 +169,7 @@ def get_context(value: str = Depends(custom_context_dependency)) -> str:

app = FastAPI()
schema = strawberry.Schema(query=Query)
graphql_app = GraphQLRouter[Any, None](schema=schema, context_getter=get_context)
graphql_app = GraphQLRouter(schema=schema, context_getter=get_context)
app.include_router(graphql_app, prefix="/graphql")

test_client = TestClient(app)
Expand Down Expand Up @@ -213,7 +213,7 @@ def get_context(context: Context = Depends()) -> Context:

app = FastAPI()
schema = strawberry.Schema(query=Query, subscription=Subscription)
graphql_app = GraphQLRouter[Any, None](schema=schema, context_getter=get_context)
graphql_app = GraphQLRouter(schema=schema, context_getter=get_context)
app.include_router(graphql_app, prefix="/graphql")
test_client = TestClient(app)

Expand Down Expand Up @@ -287,7 +287,7 @@ def get_context(context: Context = Depends()) -> Context:

app = FastAPI()
schema = strawberry.Schema(query=Query, subscription=Subscription)
graphql_app = GraphQLRouter[Any, None](schema=schema, context_getter=get_context)
graphql_app = GraphQLRouter(schema=schema, context_getter=get_context)
app.include_router(graphql_app, prefix="/graphql")
test_client = TestClient(app)

Expand Down

0 comments on commit 71ac2d6

Please sign in to comment.