Skip to content

Commit 15044cd

Browse files
authored
Type internal test clients stricter (#3745)
* Type internal test clients stricter * Remove unneeded exception handling
1 parent e78f8c6 commit 15044cd

File tree

14 files changed

+129
-122
lines changed

14 files changed

+129
-122
lines changed

strawberry/subscriptions/protocols/graphql_ws/handlers.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,6 @@
3434

3535

3636
class BaseGraphQLWSHandler(Generic[Context, RootValue]):
37-
context: Context
38-
root_value: RootValue
39-
4037
def __init__(
4138
self,
4239
view: AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue],

tests/http/clients/aiohttp.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import contextlib
44
import json
5-
from collections.abc import AsyncGenerator
5+
from collections.abc import AsyncGenerator, Mapping
66
from io import BytesIO
7-
from typing import Any, Optional
7+
from typing import Any, Optional, Union
88
from typing_extensions import Literal
99

1010
from aiohttp import web
@@ -37,7 +37,7 @@ class GraphQLView(OnWSConnectMixin, BaseGraphQLView[dict[str, object], object]):
3737
graphql_ws_handler_class = DebuggableGraphQLWSHandler
3838

3939
async def get_context(
40-
self, request: web.Request, response: web.StreamResponse
40+
self, request: web.Request, response: Union[web.Response, web.WebSocketResponse]
4141
) -> dict[str, object]:
4242
context = await super().get_context(request, response)
4343

@@ -95,7 +95,7 @@ def create_app(self, **kwargs: Any) -> None:
9595
async def _graphql_request(
9696
self,
9797
method: Literal["get", "post"],
98-
query: Optional[str] = None,
98+
query: str,
9999
variables: Optional[dict[str, object]] = None,
100100
files: Optional[dict[str, BytesIO]] = None,
101101
headers: Optional[dict[str, str]] = None,
@@ -163,7 +163,7 @@ async def post(
163163
return Response(
164164
status_code=response.status,
165165
data=(await response.text()).encode(),
166-
headers=response.headers,
166+
headers=dict(response.headers),
167167
)
168168

169169
@contextlib.asynccontextmanager
@@ -186,7 +186,7 @@ def __init__(self, ws: ClientWebSocketResponse):
186186
async def send_text(self, payload: str) -> None:
187187
await self.ws.send_str(payload)
188188

189-
async def send_json(self, payload: dict[str, Any]) -> None:
189+
async def send_json(self, payload: Mapping[str, object]) -> None:
190190
await self.ws.send_json(payload)
191191

192192
async def send_bytes(self, payload: bytes) -> None:
@@ -197,7 +197,7 @@ async def receive(self, timeout: Optional[float] = None) -> Message:
197197
self._reason = m.extra
198198
return Message(type=m.type, data=m.data, extra=m.extra)
199199

200-
async def receive_json(self, timeout: Optional[float] = None) -> Any:
200+
async def receive_json(self, timeout: Optional[float] = None) -> object:
201201
m = await self.ws.receive(timeout)
202202
assert m.type == WSMsgType.TEXT
203203
return json.loads(m.data)

tests/http/clients/asgi.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22

33
import contextlib
44
import json
5-
from collections.abc import AsyncGenerator
5+
from collections.abc import AsyncGenerator, Mapping
66
from io import BytesIO
77
from typing import Any, Optional, Union
88
from typing_extensions import Literal
99

1010
from starlette.requests import Request
1111
from starlette.responses import Response as StarletteResponse
1212
from starlette.testclient import TestClient, WebSocketTestSession
13-
from starlette.websockets import WebSocket, WebSocketDisconnect
13+
from starlette.websockets import WebSocket
1414

1515
from strawberry.asgi import GraphQL as BaseGraphQLView
1616
from strawberry.http import GraphQLHTTPResponse
@@ -86,7 +86,7 @@ def create_app(self, **kwargs: Any) -> None:
8686
async def _graphql_request(
8787
self,
8888
method: Literal["get", "post"],
89-
query: Optional[str] = None,
89+
query: str,
9090
variables: Optional[dict[str, object]] = None,
9191
files: Optional[dict[str, BytesIO]] = None,
9292
headers: Optional[dict[str, str]] = None,
@@ -152,7 +152,7 @@ async def post(
152152
return Response(
153153
status_code=response.status_code,
154154
data=response.content,
155-
headers=response.headers,
155+
headers=dict(response.headers),
156156
)
157157

158158
@contextlib.asynccontextmanager
@@ -162,13 +162,8 @@ async def ws_connect(
162162
*,
163163
protocols: list[str],
164164
) -> AsyncGenerator[WebSocketClient, None]:
165-
try:
166-
with self.client.websocket_connect(url, protocols) as ws:
167-
yield AsgiWebSocketClient(ws)
168-
except WebSocketDisconnect as error:
169-
ws = AsgiWebSocketClient(None)
170-
ws.handle_disconnect(error)
171-
yield ws
165+
with self.client.websocket_connect(url, protocols) as ws:
166+
yield AsgiWebSocketClient(ws)
172167

173168

174169
class AsgiWebSocketClient(WebSocketClient):
@@ -178,15 +173,10 @@ def __init__(self, ws: WebSocketTestSession):
178173
self._close_code: Optional[int] = None
179174
self._close_reason: Optional[str] = None
180175

181-
def handle_disconnect(self, exc: WebSocketDisconnect) -> None:
182-
self._closed = True
183-
self._close_code = exc.code
184-
self._close_reason = exc.reason
185-
186176
async def send_text(self, payload: str) -> None:
187177
self.ws.send_text(payload)
188178

189-
async def send_json(self, payload: dict[str, Any]) -> None:
179+
async def send_json(self, payload: Mapping[str, object]) -> None:
190180
self.ws.send_json(payload)
191181

192182
async def send_bytes(self, payload: bytes) -> None:

tests/http/clients/async_django.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

3+
from collections.abc import AsyncIterable
4+
35
from django.core.exceptions import BadRequest, SuspiciousOperation
46
from django.http import Http404, HttpRequest, HttpResponse, StreamingHttpResponse
5-
from django.test.client import RequestFactory
67

78
from strawberry.django.views import AsyncGraphQLView as BaseAsyncGraphQLView
89
from strawberry.http import GraphQLHTTPResponse
@@ -14,14 +15,16 @@
1415
from .django import DjangoHttpClient
1516

1617

17-
class AsyncGraphQLView(BaseAsyncGraphQLView):
18+
class AsyncGraphQLView(BaseAsyncGraphQLView[dict[str, object], object]):
1819
result_override: ResultOverrideFunction = None
1920

2021
async def get_root_value(self, request: HttpRequest) -> Query:
2122
await super().get_root_value(request) # for coverage
2223
return Query()
2324

24-
async def get_context(self, request: HttpRequest, response: HttpResponse) -> object:
25+
async def get_context(
26+
self, request: HttpRequest, response: HttpResponse
27+
) -> dict[str, object]:
2528
context = {"request": request, "response": response}
2629

2730
return get_context(context)
@@ -36,7 +39,7 @@ async def process_result(
3639

3740

3841
class AsyncDjangoHttpClient(DjangoHttpClient):
39-
async def _do_request(self, request: RequestFactory) -> Response:
42+
async def _do_request(self, request: HttpRequest) -> Response:
4043
view = AsyncGraphQLView.as_view(
4144
schema=schema,
4245
graphiql=self.graphiql,
@@ -56,14 +59,16 @@ async def _do_request(self, request: RequestFactory) -> Response:
5659
data=e.args[0].encode(),
5760
headers={},
5861
)
62+
5963
data = (
6064
response.streaming_content
6165
if isinstance(response, StreamingHttpResponse)
66+
and isinstance(response.streaming_content, AsyncIterable)
6267
else response.content
6368
)
6469

6570
return Response(
6671
status_code=response.status_code,
6772
data=data,
68-
headers=response.headers,
73+
headers=dict(response.headers),
6974
)

tests/http/clients/async_flask.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
from .flask import FlaskHttpClient
1717

1818

19-
class GraphQLView(BaseAsyncGraphQLView):
19+
class GraphQLView(BaseAsyncGraphQLView[dict[str, object], object]):
2020
methods = ["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD"]
2121

2222
result_override: ResultOverrideFunction = None
2323

24-
def __init__(self, *args: str, **kwargs: Any):
24+
def __init__(self, *args: Any, **kwargs: Any):
2525
self.result_override = kwargs.pop("result_override")
2626
super().__init__(*args, **kwargs)
2727

tests/http/clients/base.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def __init__(
108108
async def _graphql_request(
109109
self,
110110
method: Literal["get", "post"],
111-
query: Optional[str] = None,
111+
query: str,
112112
variables: Optional[dict[str, object]] = None,
113113
files: Optional[dict[str, BytesIO]] = None,
114114
headers: Optional[dict[str, str]] = None,
@@ -141,7 +141,7 @@ async def post(
141141

142142
async def query(
143143
self,
144-
query: Optional[str] = None,
144+
query: str,
145145
method: Literal["get", "post"] = "post",
146146
variables: Optional[dict[str, object]] = None,
147147
files: Optional[dict[str, BytesIO]] = None,
@@ -302,7 +302,9 @@ async def send_legacy_message(self, message: OperationMessage) -> None:
302302
await self.send_json(message)
303303

304304

305-
class DebuggableGraphQLTransportWSHandler(BaseGraphQLTransportWSHandler):
305+
class DebuggableGraphQLTransportWSHandler(
306+
BaseGraphQLTransportWSHandler[dict[str, object], object]
307+
):
306308
def on_init(self) -> None:
307309
"""This method can be patched by unit tests to get the instance of the
308310
transport handler when it is initialized.
@@ -330,10 +332,10 @@ def context(self, value):
330332
self.original_context = value
331333

332334

333-
class DebuggableGraphQLWSHandler(BaseGraphQLWSHandler):
335+
class DebuggableGraphQLWSHandler(BaseGraphQLWSHandler[dict[str, object], object]):
334336
def __init__(self, *args: Any, **kwargs: Any):
335337
super().__init__(*args, **kwargs)
336-
self.original_context = self.context
338+
self.original_context = kwargs.get("context", {})
337339

338340
def get_tasks(self) -> list:
339341
return list(self.tasks.values())

tests/http/clients/chalice.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .base import JSON, HttpClient, Response, ResultOverrideFunction
2121

2222

23-
class GraphQLView(BaseGraphQLView):
23+
class GraphQLView(BaseGraphQLView[dict[str, object], object]):
2424
result_override: ResultOverrideFunction = None
2525

2626
def get_root_value(self, request: ChaliceRequest) -> Query:
@@ -29,7 +29,7 @@ def get_root_value(self, request: ChaliceRequest) -> Query:
2929

3030
def get_context(
3131
self, request: ChaliceRequest, response: TemporalResponse
32-
) -> object:
32+
) -> dict[str, object]:
3333
context = super().get_context(request, response)
3434

3535
return get_context(context)
@@ -66,12 +66,13 @@ def __init__(
6666
"/graphql", methods=["GET", "POST"], content_types=["application/json"]
6767
)
6868
def handle_graphql():
69+
assert self.app.current_request is not None
6970
return view.execute_request(self.app.current_request)
7071

7172
async def _graphql_request(
7273
self,
7374
method: Literal["get", "post"],
74-
query: Optional[str] = None,
75+
query: str,
7576
variables: Optional[dict[str, object]] = None,
7677
files: Optional[dict[str, BytesIO]] = None,
7778
headers: Optional[dict[str, str]] = None,

0 commit comments

Comments
 (0)