Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions aiohttp/_websocket/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,21 @@ def json(
return loads(self.data)


class WSMessageTextBytes(NamedTuple):
"""WebSocket TEXT message with raw bytes (no UTF-8 decoding)."""

data: bytes
size: int
extra: str | None = None
type: Literal[WSMsgType.TEXT] = WSMsgType.TEXT

def json(
self, *, loads: Callable[[str | bytes | bytearray], Any] = json.loads
) -> Any:
"""Return parsed JSON data."""
return loads(self.data)


class WSMessageBinary(NamedTuple):
data: bytes
size: int
Expand Down Expand Up @@ -117,6 +132,7 @@ class WSMessageError(NamedTuple):
WSMessage = Union[
WSMessageContinuation,
WSMessageText,
WSMessageTextBytes,
WSMessageBinary,
WSMessagePing,
WSMessagePong,
Expand Down
2 changes: 2 additions & 0 deletions aiohttp/_websocket/reader_c.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ cdef object TUPLE_NEW
cdef object WSMsgType

cdef object WSMessageText
cdef object WSMessageTextBytes
cdef object WSMessageBinary
cdef object WSMessagePing
cdef object WSMessagePong
Expand Down Expand Up @@ -66,6 +67,7 @@ cdef class WebSocketReader:

cdef WebSocketDataQueue queue
cdef unsigned int _max_msg_size
cdef bint _decode_text

cdef Exception _exc
cdef bytearray _partial
Expand Down
38 changes: 25 additions & 13 deletions aiohttp/_websocket/reader_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
WSMessagePing,
WSMessagePong,
WSMessageText,
WSMessageTextBytes,
WSMsgType,
)

Expand Down Expand Up @@ -139,10 +140,15 @@ def _read_from_buffer(self) -> WSMessage:

class WebSocketReader:
def __init__(
self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
self,
queue: WebSocketDataQueue,
max_msg_size: int,
compress: bool = True,
decode_text: bool = True,
) -> None:
self.queue = queue
self._max_msg_size = max_msg_size
self._decode_text = decode_text

self._exc: Exception | None = None
self._partial = bytearray()
Expand Down Expand Up @@ -270,18 +276,24 @@ def _handle_frame(

size = len(payload_merged)
if opcode == OP_CODE_TEXT:
try:
text = payload_merged.decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc

# XXX: The Text and Binary messages here can be a performance
# bottleneck, so we use tuple.__new__ to improve performance.
# This is not type safe, but many tests should fail in
# test_client_ws_functional.py if this is wrong.
msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT))
if self._decode_text:
try:
text = payload_merged.decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc

# XXX: The Text and Binary messages here can be a performance
# bottleneck, so we use tuple.__new__ to improve performance.
# This is not type safe, but many tests should fail in
# test_client_ws_functional.py if this is wrong.
msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT))
else:
# Keep as bytes for performance (e.g., for orjson parsing)
msg = TUPLE_NEW(
WSMessageTextBytes, (payload_merged, size, "", WS_MSG_TYPE_TEXT)
)
else:
msg = TUPLE_NEW(
WSMessageBinary, (payload_merged, size, "", WS_MSG_TYPE_BINARY)
Expand Down
7 changes: 6 additions & 1 deletion aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,7 @@ def ws_connect(
proxy_headers: LooseHeaders | None = None,
compress: int = 0,
max_msg_size: int = 4 * 1024 * 1024,
decode_text: bool = True,
) -> "_WSRequestContextManager":
"""Initiate websocket connection."""
return _WSRequestContextManager(
Expand All @@ -911,6 +912,7 @@ def ws_connect(
proxy_headers=proxy_headers,
compress=compress,
max_msg_size=max_msg_size,
decode_text=decode_text,
)
)

Expand All @@ -936,6 +938,7 @@ async def _ws_connect(
proxy_headers: LooseHeaders | None = None,
compress: int = 0,
max_msg_size: int = 4 * 1024 * 1024,
decode_text: bool = True,
) -> ClientWebSocketResponse:
if timeout is not sentinel:
if isinstance(timeout, ClientWSTimeout):
Expand Down Expand Up @@ -1098,7 +1101,9 @@ async def _ws_connect(
transport = conn.transport
assert transport is not None
reader = WebSocketDataQueue(conn_proto, 2**16, loop=self._loop)
conn_proto.set_parser(WebSocketReader(reader, max_msg_size), reader)
conn_proto.set_parser(
WebSocketReader(reader, max_msg_size, decode_text=decode_text), reader
)
writer = WebSocketWriter(
conn_proto,
transport,
Expand Down
10 changes: 9 additions & 1 deletion aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,12 +384,20 @@ async def receive(self, timeout: float | None = None) -> WSMessage:
return msg

async def receive_str(self, *, timeout: float | None = None) -> str:
"""Receive TEXT message.

Returns str when decode_text=True (default), bytes when decode_text=False.

Note: The return type annotation is kept as str for backwards compatibility,
but this method will return bytes when the WebSocket connection was created
with decode_text=False.
"""
msg = await self.receive(timeout)
if msg.type is not WSMsgType.TEXT:
raise WSMessageTypeError(
f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
)
return msg.data
return msg.data # type: ignore[return-value]

async def receive_bytes(self, *, timeout: float | None = None) -> bytes:
msg = await self.receive(timeout)
Expand Down
2 changes: 2 additions & 0 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
WSMessagePing,
WSMessagePong,
WSMessageText,
WSMessageTextBytes,
WSMsgType,
)
from ._websocket.reader import WebSocketReader
Expand Down Expand Up @@ -48,6 +49,7 @@
"WSMessagePong",
"WSMessageBinary",
"WSMessageText",
"WSMessageTextBytes",
"WSMessagePing",
"WSMessageContinuation",
)
17 changes: 15 additions & 2 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(
compress: bool = True,
max_msg_size: int = 4 * 1024 * 1024,
writer_limit: int = DEFAULT_LIMIT,
decode_text: bool = True,
) -> None:
super().__init__(status=101)
self._protocols = protocols
Expand All @@ -108,6 +109,7 @@ def __init__(
self._compress: bool | int = compress
self._max_msg_size = max_msg_size
self._writer_limit = writer_limit
self._decode_text = decode_text

def _cancel_heartbeat(self) -> None:
self._cancel_pong_response_cb()
Expand Down Expand Up @@ -341,7 +343,10 @@ def _post_start(
self._reader = WebSocketDataQueue(request._protocol, 2**16, loop=loop)
request.protocol.set_parser(
WebSocketReader(
self._reader, self._max_msg_size, compress=bool(self._compress)
self._reader,
self._max_msg_size,
compress=bool(self._compress),
decode_text=self._decode_text,
)
)
# disable HTTP keepalive for WebSocket
Expand Down Expand Up @@ -589,12 +594,20 @@ async def receive(self, timeout: float | None = None) -> WSMessage:
return msg

async def receive_str(self, *, timeout: float | None = None) -> str:
"""Receive TEXT message.

Returns str when decode_text=True (default), bytes when decode_text=False.

Note: The return type annotation is kept as str for backwards compatibility,
but this method will return bytes when the WebSocket connection was created
with decode_text=False.
"""
msg = await self.receive(timeout)
if msg.type is not WSMsgType.TEXT:
raise WSMessageTypeError(
f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
)
return msg.data
return msg.data # type: ignore[return-value]

async def receive_bytes(self, *, timeout: float | None = None) -> bytes:
msg = await self.receive(timeout)
Expand Down
7 changes: 6 additions & 1 deletion examples/server_simple.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# server_simple.py
from typing import TYPE_CHECKING

from aiohttp import web


Expand All @@ -14,7 +16,10 @@ async def wshandle(request: web.Request) -> web.StreamResponse:

async for msg in ws:
if msg.type is web.WSMsgType.TEXT:
await ws.send_str(f"Hello, {msg.data}")
data = msg.data
if TYPE_CHECKING:
assert isinstance(data, str)
await ws.send_str(f"Hello, {data}")
elif msg.type is web.WSMsgType.BINARY:
await ws.send_bytes(msg.data)
elif msg.type is web.WSMsgType.CLOSE:
Expand Down
6 changes: 5 additions & 1 deletion examples/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# mypy: disallow-any-expr, disallow-any-unimported, disallow-subclassing-any

import os
from typing import TYPE_CHECKING

from aiohttp import web

Expand All @@ -32,9 +33,12 @@ async def wshandler(request: web.Request) -> web.WebSocketResponse | web.Respons

async for msg in resp:
if msg.type is web.WSMsgType.TEXT:
data = msg.data
if TYPE_CHECKING:
assert isinstance(data, str)
for ws in request.app[sockets]:
if ws is not resp:
await ws.send_str(msg.data)
await ws.send_str(data)
else:
return resp
return resp
Expand Down
6 changes: 5 additions & 1 deletion tests/autobahn/client/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

import asyncio
from typing import TYPE_CHECKING

import aiohttp

Expand All @@ -19,7 +20,10 @@ async def client(url: str, name: str) -> None:
async with session.ws_connect(text_url) as ws:
async for msg in ws:
if msg.type is aiohttp.WSMsgType.TEXT:
await ws.send_str(msg.data)
data = msg.data
if TYPE_CHECKING:
assert isinstance(data, str)
await ws.send_str(data)
elif msg.type is aiohttp.WSMsgType.BINARY:
await ws.send_bytes(msg.data)
else:
Expand Down
6 changes: 5 additions & 1 deletion tests/autobahn/server/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

import logging
from typing import TYPE_CHECKING

from aiohttp import WSCloseCode, web

Expand All @@ -21,7 +22,10 @@ async def wshandler(request: web.Request) -> web.WebSocketResponse:
msg = await ws.receive()

if msg.type is web.WSMsgType.TEXT:
await ws.send_str(msg.data)
data = msg.data
if TYPE_CHECKING:
assert isinstance(data, str)
await ws.send_str(data)
elif msg.type is web.WSMsgType.BINARY:
await ws.send_bytes(msg.data)
elif msg.type is web.WSMsgType.CLOSE:
Expand Down
17 changes: 13 additions & 4 deletions tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import json
import sys
from typing import NoReturn
from typing import TYPE_CHECKING, NoReturn
from unittest import mock

import pytest
Expand Down Expand Up @@ -1080,7 +1080,10 @@ async def handler(request: web.Request) -> web.WebSocketResponse:
await ws.prepare(request)
msg = await ws.receive()
assert msg.type is WSMsgType.TEXT
await ws.send_str(msg.data + "/answer")
data = msg.data
if TYPE_CHECKING:
assert isinstance(data, str)
await ws.send_str(data + "/answer")
await ws.close()
return ws

Expand All @@ -1106,7 +1109,10 @@ async def handler(request: web.Request) -> web.WebSocketResponse:
await ws.prepare(request)
msg = await ws.receive()
assert msg.type is WSMsgType.TEXT
await ws.send_str(msg.data + "/answer")
data = msg.data
if TYPE_CHECKING:
assert isinstance(data, str)
await ws.send_str(data + "/answer")
await ws.close()
return ws

Expand All @@ -1130,7 +1136,10 @@ async def handler(request: web.Request) -> web.WebSocketResponse:
await ws.prepare(request)
msg = await ws.receive()
assert msg.type is WSMsgType.TEXT
await ws.send_str(msg.data + "/answer")
data = msg.data
if TYPE_CHECKING:
assert isinstance(data, str)
await ws.send_str(data + "/answer")
await ws.close()
return ws

Expand Down
7 changes: 5 additions & 2 deletions tests/test_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import socket
import sys
from collections.abc import Iterator, Mapping
from typing import NoReturn
from typing import TYPE_CHECKING, NoReturn
from unittest import mock

import pytest
Expand Down Expand Up @@ -45,7 +45,10 @@ async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
if msg.data == "close":
await ws.close()
else:
await ws.send_str(msg.data + "/answer")
data = msg.data
if TYPE_CHECKING:
assert isinstance(data, str)
await ws.send_str(data + "/answer")

return ws

Expand Down
4 changes: 3 additions & 1 deletion tests/test_web_websocket_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import contextlib
import sys
import weakref
from typing import NoReturn
from typing import TYPE_CHECKING, NoReturn
from unittest import mock

import pytest
Expand Down Expand Up @@ -913,6 +913,8 @@ async def handler(request: web.Request) -> web.WebSocketResponse:
async for msg in ws:
assert msg.type == aiohttp.WSMsgType.TEXT
s = msg.data
if TYPE_CHECKING:
assert isinstance(s, str)
await ws.send_str(s + "/answer")
await ws.close()
closed.set_result(1)
Expand Down
Loading