Skip to content

Commit

Permalink
Type internal test clients stricter (#3745)
Browse files Browse the repository at this point in the history
* Type internal test clients stricter

* Remove unneeded exception handling
  • Loading branch information
DoctorJohn authored Jan 3, 2025
1 parent e78f8c6 commit 15044cd
Show file tree
Hide file tree
Showing 14 changed files with 129 additions and 122 deletions.
3 changes: 0 additions & 3 deletions strawberry/subscriptions/protocols/graphql_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@


class BaseGraphQLWSHandler(Generic[Context, RootValue]):
context: Context
root_value: RootValue

def __init__(
self,
view: AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue],
Expand Down
14 changes: 7 additions & 7 deletions tests/http/clients/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import contextlib
import json
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Mapping
from io import BytesIO
from typing import Any, Optional
from typing import Any, Optional, Union
from typing_extensions import Literal

from aiohttp import web
Expand Down Expand Up @@ -37,7 +37,7 @@ class GraphQLView(OnWSConnectMixin, BaseGraphQLView[dict[str, object], object]):
graphql_ws_handler_class = DebuggableGraphQLWSHandler

async def get_context(
self, request: web.Request, response: web.StreamResponse
self, request: web.Request, response: Union[web.Response, web.WebSocketResponse]
) -> dict[str, object]:
context = await super().get_context(request, response)

Expand Down Expand Up @@ -95,7 +95,7 @@ def create_app(self, **kwargs: Any) -> None:
async def _graphql_request(
self,
method: Literal["get", "post"],
query: Optional[str] = None,
query: str,
variables: Optional[dict[str, object]] = None,
files: Optional[dict[str, BytesIO]] = None,
headers: Optional[dict[str, str]] = None,
Expand Down Expand Up @@ -163,7 +163,7 @@ async def post(
return Response(
status_code=response.status,
data=(await response.text()).encode(),
headers=response.headers,
headers=dict(response.headers),
)

@contextlib.asynccontextmanager
Expand All @@ -186,7 +186,7 @@ def __init__(self, ws: ClientWebSocketResponse):
async def send_text(self, payload: str) -> None:
await self.ws.send_str(payload)

async def send_json(self, payload: dict[str, Any]) -> None:
async def send_json(self, payload: Mapping[str, object]) -> None:
await self.ws.send_json(payload)

async def send_bytes(self, payload: bytes) -> None:
Expand All @@ -197,7 +197,7 @@ async def receive(self, timeout: Optional[float] = None) -> Message:
self._reason = m.extra
return Message(type=m.type, data=m.data, extra=m.extra)

async def receive_json(self, timeout: Optional[float] = None) -> Any:
async def receive_json(self, timeout: Optional[float] = None) -> object:
m = await self.ws.receive(timeout)
assert m.type == WSMsgType.TEXT
return json.loads(m.data)
Expand Down
24 changes: 7 additions & 17 deletions tests/http/clients/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

import contextlib
import json
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Mapping
from io import BytesIO
from typing import Any, Optional, Union
from typing_extensions import Literal

from starlette.requests import Request
from starlette.responses import Response as StarletteResponse
from starlette.testclient import TestClient, WebSocketTestSession
from starlette.websockets import WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocket

from strawberry.asgi import GraphQL as BaseGraphQLView
from strawberry.http import GraphQLHTTPResponse
Expand Down Expand Up @@ -86,7 +86,7 @@ def create_app(self, **kwargs: Any) -> None:
async def _graphql_request(
self,
method: Literal["get", "post"],
query: Optional[str] = None,
query: str,
variables: Optional[dict[str, object]] = None,
files: Optional[dict[str, BytesIO]] = None,
headers: Optional[dict[str, str]] = None,
Expand Down Expand Up @@ -152,7 +152,7 @@ async def post(
return Response(
status_code=response.status_code,
data=response.content,
headers=response.headers,
headers=dict(response.headers),
)

@contextlib.asynccontextmanager
Expand All @@ -162,13 +162,8 @@ async def ws_connect(
*,
protocols: list[str],
) -> AsyncGenerator[WebSocketClient, None]:
try:
with self.client.websocket_connect(url, protocols) as ws:
yield AsgiWebSocketClient(ws)
except WebSocketDisconnect as error:
ws = AsgiWebSocketClient(None)
ws.handle_disconnect(error)
yield ws
with self.client.websocket_connect(url, protocols) as ws:
yield AsgiWebSocketClient(ws)


class AsgiWebSocketClient(WebSocketClient):
Expand All @@ -178,15 +173,10 @@ def __init__(self, ws: WebSocketTestSession):
self._close_code: Optional[int] = None
self._close_reason: Optional[str] = None

def handle_disconnect(self, exc: WebSocketDisconnect) -> None:
self._closed = True
self._close_code = exc.code
self._close_reason = exc.reason

async def send_text(self, payload: str) -> None:
self.ws.send_text(payload)

async def send_json(self, payload: dict[str, Any]) -> None:
async def send_json(self, payload: Mapping[str, object]) -> None:
self.ws.send_json(payload)

async def send_bytes(self, payload: bytes) -> None:
Expand Down
15 changes: 10 additions & 5 deletions tests/http/clients/async_django.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from collections.abc import AsyncIterable

from django.core.exceptions import BadRequest, SuspiciousOperation
from django.http import Http404, HttpRequest, HttpResponse, StreamingHttpResponse
from django.test.client import RequestFactory

from strawberry.django.views import AsyncGraphQLView as BaseAsyncGraphQLView
from strawberry.http import GraphQLHTTPResponse
Expand All @@ -14,14 +15,16 @@
from .django import DjangoHttpClient


class AsyncGraphQLView(BaseAsyncGraphQLView):
class AsyncGraphQLView(BaseAsyncGraphQLView[dict[str, object], object]):
result_override: ResultOverrideFunction = None

async def get_root_value(self, request: HttpRequest) -> Query:
await super().get_root_value(request) # for coverage
return Query()

async def get_context(self, request: HttpRequest, response: HttpResponse) -> object:
async def get_context(
self, request: HttpRequest, response: HttpResponse
) -> dict[str, object]:
context = {"request": request, "response": response}

return get_context(context)
Expand All @@ -36,7 +39,7 @@ async def process_result(


class AsyncDjangoHttpClient(DjangoHttpClient):
async def _do_request(self, request: RequestFactory) -> Response:
async def _do_request(self, request: HttpRequest) -> Response:
view = AsyncGraphQLView.as_view(
schema=schema,
graphiql=self.graphiql,
Expand All @@ -56,14 +59,16 @@ async def _do_request(self, request: RequestFactory) -> Response:
data=e.args[0].encode(),
headers={},
)

data = (
response.streaming_content
if isinstance(response, StreamingHttpResponse)
and isinstance(response.streaming_content, AsyncIterable)
else response.content
)

return Response(
status_code=response.status_code,
data=data,
headers=response.headers,
headers=dict(response.headers),
)
4 changes: 2 additions & 2 deletions tests/http/clients/async_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
from .flask import FlaskHttpClient


class GraphQLView(BaseAsyncGraphQLView):
class GraphQLView(BaseAsyncGraphQLView[dict[str, object], object]):
methods = ["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD"]

result_override: ResultOverrideFunction = None

def __init__(self, *args: str, **kwargs: Any):
def __init__(self, *args: Any, **kwargs: Any):
self.result_override = kwargs.pop("result_override")
super().__init__(*args, **kwargs)

Expand Down
12 changes: 7 additions & 5 deletions tests/http/clients/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(
async def _graphql_request(
self,
method: Literal["get", "post"],
query: Optional[str] = None,
query: str,
variables: Optional[dict[str, object]] = None,
files: Optional[dict[str, BytesIO]] = None,
headers: Optional[dict[str, str]] = None,
Expand Down Expand Up @@ -141,7 +141,7 @@ async def post(

async def query(
self,
query: Optional[str] = None,
query: str,
method: Literal["get", "post"] = "post",
variables: Optional[dict[str, object]] = None,
files: Optional[dict[str, BytesIO]] = None,
Expand Down Expand Up @@ -302,7 +302,9 @@ async def send_legacy_message(self, message: OperationMessage) -> None:
await self.send_json(message)


class DebuggableGraphQLTransportWSHandler(BaseGraphQLTransportWSHandler):
class DebuggableGraphQLTransportWSHandler(
BaseGraphQLTransportWSHandler[dict[str, object], object]
):
def on_init(self) -> None:
"""This method can be patched by unit tests to get the instance of the
transport handler when it is initialized.
Expand Down Expand Up @@ -330,10 +332,10 @@ def context(self, value):
self.original_context = value


class DebuggableGraphQLWSHandler(BaseGraphQLWSHandler):
class DebuggableGraphQLWSHandler(BaseGraphQLWSHandler[dict[str, object], object]):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.original_context = self.context
self.original_context = kwargs.get("context", {})

def get_tasks(self) -> list:
return list(self.tasks.values())
Expand Down
7 changes: 4 additions & 3 deletions tests/http/clients/chalice.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .base import JSON, HttpClient, Response, ResultOverrideFunction


class GraphQLView(BaseGraphQLView):
class GraphQLView(BaseGraphQLView[dict[str, object], object]):
result_override: ResultOverrideFunction = None

def get_root_value(self, request: ChaliceRequest) -> Query:
Expand All @@ -29,7 +29,7 @@ def get_root_value(self, request: ChaliceRequest) -> Query:

def get_context(
self, request: ChaliceRequest, response: TemporalResponse
) -> object:
) -> dict[str, object]:
context = super().get_context(request, response)

return get_context(context)
Expand Down Expand Up @@ -66,12 +66,13 @@ def __init__(
"/graphql", methods=["GET", "POST"], content_types=["application/json"]
)
def handle_graphql():
assert self.app.current_request is not None
return view.execute_request(self.app.current_request)

async def _graphql_request(
self,
method: Literal["get", "post"],
query: Optional[str] = None,
query: str,
variables: Optional[dict[str, object]] = None,
files: Optional[dict[str, BytesIO]] = None,
headers: Optional[dict[str, str]] = None,
Expand Down
Loading

0 comments on commit 15044cd

Please sign in to comment.