Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ws impl #39718

Draft
wants to merge 12 commits into
base: feature/eventhub/geodr-preview
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
self.http_proxy = http_proxy
self.transport_type = TransportType.AmqpOverWebsocket if self.http_proxy else transport_type
# if transport_type is not provided, it is None, we will default to Amqp
self.legacy_ws = kwargs.get("legacy_ws", False)
self.transport_type = self.transport_type or TransportType.Amqp
self.auth_timeout = auth_timeout
self.prefetch = prefetch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def _create_consumer(
idle_timeout=self._idle_timeout,
track_last_enqueued_event_properties=track_last_enqueued_event_properties,
amqp_transport=self._amqp_transport,
legacy_ws=self._config.legacy_ws,
)
return handler

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ def _create_producer(
idle_timeout=self._idle_timeout,
amqp_transport=self._amqp_transport,
keep_alive=self._keep_alive,
legacy_ws=self._config.legacy_ws,
)
return handler

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import Any, Dict, List, Tuple, Optional, NamedTuple, Type, Union, cast

from ._transport import Transport
from .sasl import SASLTransport, SASLWithWebSocket
from .sasl import SASLTransport, SASLWithWebSocket, SASLWithLegacyWebSocket
from .session import Session
from .performatives import OpenFrame, CloseFrame
from .constants import (
Expand Down Expand Up @@ -159,9 +159,13 @@ def __init__( # pylint:disable=too-many-locals
if transport:
self._transport = transport
elif "sasl_credential" in kwargs:
sasl_transport: Union[Type[SASLTransport], Type[SASLWithWebSocket]] = SASLTransport
sasl_transport: Union[
Type[SASLTransport],
Type[SASLWithWebSocket],
Type[SASLWithLegacyWebSocket]] = SASLTransport
if self._transport_type.name == "AmqpOverWebsocket" or kwargs.get("http_proxy"):
sasl_transport = SASLWithWebSocket
use_legacy_ws = kwargs.get("legacy_ws", False)
sasl_transport = SASLWithWebSocket if not use_legacy_ws else SASLWithLegacyWebSocket
endpoint = parsed_url.hostname + parsed_url.path
self._transport = sasl_transport(
host=endpoint,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from enum import Enum
from typing import Literal


WS_VERSION: Literal[13] = 13
WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'

EOF = b'\r\n\r\n'

class ConnectionStatus(Enum):
# a client opens a connection and sends a handshake
CONNECTING = 1

# a client has successfully completed & validated the handshake response from the server
OPEN = 2

# upon either sending or receiving a Close control frame
CLOSING = 3

# when the underlying TCP connection is closed
CLOSED = 4
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

class WebSocketException(Exception):
"""
Base exception for all WebSocket related errors.
"""

class WebSocketPayloadError(WebSocketException):
"""
Raised when there is an error in the WebSocket payload.
"""

class WebSocketConnectionError(WebSocketException):
"""
Raised when there is an eror while establishing a connection.
"""

class WebSocketConnectionClosed(WebSocketConnectionError):
"""
Raised when the connection is closed.
"""

class WebSocketProtocolError(WebSocketException):
"""
Raised when the WebSocket protocol is violated.
"""
173 changes: 173 additions & 0 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_frame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from dataclasses import dataclass
import io
import os
import struct
import enum

from ._exceptions import WebSocketPayloadError, WebSocketProtocolError
from ._utils import mask_payload


class Opcode(enum.IntEnum):
CONTINUATION = 0x00
TEXT = 0x01
BINARY = 0x02
CLOSE = 0x08
PING = 0x09
PONG = 0x0A

FRAMES = {
Opcode.CONTINUATION,
Opcode.TEXT,
Opcode.BINARY,
Opcode.CLOSE,
Opcode.PING,
Opcode.PONG
}

CONTROL_FRAMES = {
Opcode.CLOSE,
Opcode.PING,
Opcode.PONG,
}

DATA_FRAMES = {
Opcode.TEXT,
Opcode.BINARY,
Opcode.CONTINUATION,
}


class CloseReason(enum.IntEnum):
'''
The status codes for the close frame as per RFC 6455.
https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1
'''
NORMAL = 1000
GOING_AWAY = 1001
PROTOCOL_ERROR = 1002
UNSUPPORTED_DATA = 1003
NO_STATUS_RCVD = 1005
ABNORMAL_CLOSURE = 1006
INVALID_PAYLOAD = 1007
POLICY_VIOLATION = 1008
MESSAGE_TOO_BIG = 1009
MANDATORY_EXT = 1010
INTERNAL_ERROR = 1011
SERVICE_RESTART = 1012
TRY_AGAIN_LATER = 1013
BAD_GATEWAY = 1014
TLS_HANDSHAKE = 1015

ALLOWED_CLOSED_REASONS = {
CloseReason.NORMAL,
CloseReason.GOING_AWAY,
CloseReason.PROTOCOL_ERROR,
CloseReason.UNSUPPORTED_DATA,
CloseReason.INVALID_PAYLOAD,
CloseReason.POLICY_VIOLATION,
CloseReason.MESSAGE_TOO_BIG,
CloseReason.MANDATORY_EXT,
CloseReason.INTERNAL_ERROR,
CloseReason.SERVICE_RESTART,
CloseReason.TRY_AGAIN_LATER,
CloseReason.BAD_GATEWAY,
}


@dataclass
class Frame:
"""
Represents a WebSocket frame
"""
data: bytes
opcode: int = Opcode.TEXT
fin: bool = True
mask: bool = True
rsv1: int = 0
rsv2: int = 0
rsv3: int = 0

def encode(self) -> bytes:
output = io.BytesIO()

header = output.write(struct.pack('!B', # pylint: disable=unused-variable
(
self.fin << 7 | # FIN
self.rsv1 << 6 | # RSV1
self.rsv2 << 5 | # RSV2
self.rsv3 << 4 | # RSV3
self.opcode
)))
mask_bit = 1 << 7
length = len(self.data)
payload = bytearray(self.data)

if length < 126:
output.write(struct.pack('!B', mask_bit | length))
elif length < 65536:
output.write(struct.pack('!BH', mask_bit | 126, length))
else:
output.write(struct.pack('!BQ', mask_bit | 127, length))

if self.mask:
masking_key = os.urandom(4)
output.write(masking_key)
mask_payload(masking_key, payload)
output.write(payload)

return output.getvalue()

def validate(self) -> None:
"""
Validate the frame according to the RFC 6455
"""
if self.rsv1 or self.rsv2 or self.rsv3:
raise WebSocketProtocolError("Reserved bits are set")

if self.opcode not in FRAMES:
raise WebSocketProtocolError(f"Invalid opcode {self.opcode}. Expected one of {FRAMES}")

# Validations for control frames
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.5

if self.opcode in CONTROL_FRAMES:
if not self.fin:
raise WebSocketProtocolError("Control frames must not be fragmented")

if len(self.data) > 125:
raise WebSocketProtocolError(f"Control frames must not be larger than 125 bytes. Got {len(self.data)}")

if self.opcode == Opcode.CLOSE:
if not self.data:
return

payload_length = len(self.data)

# if there is a body, it must be at least 2 bytes long in order to contain the status code
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1
if payload_length <= 1:
raise WebSocketProtocolError("Close frame with payload length of less than 2 bytes")

code = struct.unpack('!H', self.data[:2])[0]
payload = self.data[2:]

if not (code in ALLOWED_CLOSED_REASONS or 3000 <= code <= 4999):
raise WebSocketProtocolError(f"Invalid close code {code}")

try:
payload.decode("utf-8")
except UnicodeDecodeError as ude:
raise WebSocketPayloadError('Invalid UTF-8 payload') from ude

elif self.opcode == Opcode.TEXT and self.fin:
try:
self.data.decode("utf-8")
except UnicodeDecodeError as ude:
raise WebSocketPayloadError('Invalid UTF-8 payload') from ude
98 changes: 98 additions & 0 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from hashlib import sha1
import os
import base64
from typing import Dict, List, Tuple, Optional
from ._constants import WS_KEY


def build_request_headers(
resource: str,
host: str,
port: int,
is_secure: bool,
key: bytes,
*,
subprotocols: Optional[List[str]] = None
) -> bytes:
request: List[bytes] = [
f'GET {resource} HTTP/1.1'.encode(),
]
origin: str = 'https://' if is_secure else 'http://'
origin = f'{origin}{build_host(host, port)}'

headers: Dict[bytes, bytes] = {
b'Host: ': build_host(host, port).encode(),
b'Connection: ': b'Upgrade',
b'Upgrade: ': b'websocket',
b'Origin: ': origin.encode(),
b'Sec-WebSocket-Key: ': key,
b'Sec-WebSocket-Version: ': b'13'
}

if subprotocols:
headers[b'Sec-WebSocket-Protocol: '] = ','.join(subprotocols).encode()

for header, value in headers.items():
request.append(header + value)

request.append(b'\r\n')
request_bytes: bytes = b'\r\n'.join(request)
return request_bytes

def build_host(host: str, port: int) -> str:
if ':' in host:
return f'[{host}]:{port}'

if port in (80, 443):
return host

return f'{host}:{port}'

def build_key() -> bytes:
random_bytes: bytes = os.urandom(16)
return base64.b64encode(random_bytes)

def match_key(client_header_key: bytes, server_header_key: bytes) -> bool:
# the use of sha1 is a websocket protocol requirement and is used for hashing and not cryptography
# other impls have the same behavior
# https://github.com/dotnet/runtime/blob/45caaf85faa654114f7a3744910df86d8e92882f/src/libraries/System.Net.HttpListener/src/System/Net/WebSockets/HttpWebSocket.cs#L18
match = base64.b64encode(sha1(client_header_key + WS_KEY).digest()).lower()

return match == server_header_key

def parse_response_headers(response_headers: bytes) -> Tuple[bytes, bytes, Dict[bytes, bytes]]:
response_line = response_headers.split(b'\r\n')[0]
_, code, status = response_line.split(b' ', 2)

headers: Dict[bytes, bytes] = {}

for header in response_headers.split(b'\r\n')[1:]:
if not header:
break
key, value = header.split(b': ')
headers[key.lower()] = value.lower()

return code, status, headers

def parse_proxy_response(response_headers: bytes) -> Tuple[bytes, bytes, bytes]:
if not response_headers:
return b'', b'', b''

response_line = response_headers.split(b'\r\n')[0]

try:
version, status, reason = response_line.split(b' ', 2)
except ValueError:
try:
version, status = response_line.split(b' ', 1)
reason = b''
except ValueError:
version = b''

return version, status, reason
Loading
Loading