Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 113 additions & 88 deletions packages/smithy-http/src/smithy_http/aio/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
# pyright: reportMissingTypeStubs=false,reportUnknownMemberType=false
# flake8: noqa: F811
import asyncio
from concurrent.futures import Future as ConcurrentFuture
from collections import deque
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable
from concurrent.futures import Future
from collections.abc import AsyncGenerator, AsyncIterable
from copy import deepcopy
from io import BytesIO, BufferedIOBase
from threading import Lock
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
Expand Down Expand Up @@ -61,23 +60,100 @@ def _initialize_default_loop(self) -> "crt_io.ClientBootstrap":


class AWSCRTHTTPResponse(http_aio_interfaces.HTTPResponse):
def __init__(self) -> None:
def __init__(self, *, status: int, fields: Fields, body: "CRTResponseBody") -> None:
_assert_crt()
self._stream: crt_http.HttpClientStream | None = None
self._status_code_future: Future[int] = Future()
self._headers_future: Future[Fields] = Future()
self._chunk_futures: list[Future[bytes]] = []
self._received_chunks: list[bytes] = []
self._chunk_lock: Lock = Lock()

def _set_stream(self, stream: "crt_http.HttpClientStream") -> None:
self._status = status
self._fields = fields
self._body = body

@property
def status(self) -> int:
return self._status

@property
def fields(self) -> Fields:
return self._fields

@property
def body(self) -> AsyncIterable[bytes]:
return self.chunks()

@property
def reason(self) -> str | None:
"""Optional string provided by the server explaining the status."""
# TODO: See how CRT exposes reason.
return None

async def chunks(self) -> AsyncGenerator[bytes, None]:
while True:
chunk = await self._body.next()
if chunk:
yield chunk
else:
break

def __repr__(self) -> str:
return (
f"AWSCRTHTTPResponse("
f"status={self.status}, "
f"fields={self.fields!r}, body=...)"
)


class CRTResponseBody:
def __init__(self) -> None:
self._stream: "crt_http.HttpClientStream | None" = None
self._chunk_futures: deque[ConcurrentFuture[bytes]] = deque()

# deque is thread safe and the crt is only going to be writing
# with one thread anyway, so we *shouldn't* need to gate this
# behind a lock. In an ideal world, the CRT would expose
# an interface that better matches python's async.
self._received_chunks: deque[bytes] = deque()

def set_stream(self, stream: "crt_http.HttpClientStream") -> None:
if self._stream is not None:
raise SmithyHTTPException("Stream already set on AWSCRTHTTPResponse object")
self._stream = stream
self._stream.completion_future.add_done_callback(self._on_complete)
self._stream.activate()

def _on_headers(
def on_body(self, chunk: bytes, **kwargs: Any) -> None: # pragma: crt-callback
# TODO: update back pressure window once CRT supports it
if self._chunk_futures:
future = self._chunk_futures.popleft()
future.set_result(chunk)
else:
self._received_chunks.append(chunk)

async def next(self) -> bytes:
if self._stream is None:
raise SmithyHTTPException("Stream not set")

# TODO: update backpressure window once CRT supports it
if self._received_chunks:
return self._received_chunks.popleft()
elif self._stream.completion_future.done():
return b""
else:
future = ConcurrentFuture[bytes]()
self._chunk_futures.append(future)
return await asyncio.wrap_future(future)

def _on_complete(
self, completion_future: ConcurrentFuture[int]
) -> None: # pragma: crt-callback
for future in self._chunk_futures:
future.set_result(b"")
self._chunk_futures.clear()


class CRTResponseFactory:
def __init__(self, body: CRTResponseBody) -> None:
self._body = body
self._response_future = ConcurrentFuture[AWSCRTHTTPResponse]()

def on_response(
self, status_code: int, headers: list[tuple[str, str]], **kwargs: Any
) -> None: # pragma: crt-callback
fields = Fields()
Expand All @@ -90,76 +166,24 @@ def _on_headers(
values=[header_val],
kind=FieldPosition.HEADER,
)
self._status_code_future.set_result(status_code)
self._headers_future.set_result(fields)

def _on_body(self, chunk: bytes, **kwargs: Any) -> None: # pragma: crt-callback
with self._chunk_lock:
# TODO: update back pressure window once CRT supports it
if self._chunk_futures:
future = self._chunk_futures.pop(0)
future.set_result(chunk)
else:
self._received_chunks.append(chunk)

def _get_chunk_future(self) -> Future[bytes]:
if self._stream is None:
raise SmithyHTTPException("Stream not set")
with self._chunk_lock:
future: Future[bytes] = Future()
# TODO: update backpressure window once CRT supports it
if self._received_chunks:
chunk = self._received_chunks.pop(0)
future.set_result(chunk)
elif self._stream.completion_future.done():
future.set_result(b"")
else:
self._chunk_futures.append(future)
return future

def _on_complete(
self, completion_future: Future[int]
) -> None: # pragma: crt-callback
with self._chunk_lock:
if self._chunk_futures:
future = self._chunk_futures.pop(0)
future.set_result(b"")

@property
def body(self) -> AsyncIterable[bytes]:
return self.chunks()

@property
def status(self) -> int:
"""The 3 digit response status code (1xx, 2xx, 3xx, 4xx, 5xx)."""
return self._status_code_future.result()

@property
def fields(self) -> Fields:
"""List of HTTP header fields."""
if self._stream is None:
raise SmithyHTTPException("Stream not set")
if not self._headers_future.done():
raise SmithyHTTPException("Headers not received yet")
return self._headers_future.result()
self._response_future.set_result(
AWSCRTHTTPResponse(
status=status_code,
fields=fields,
body=self._body,
)
)

@property
def reason(self) -> str | None:
"""Optional string provided by the server explaining the status."""
# TODO: See how CRT exposes reason.
return None
async def await_response(self) -> AWSCRTHTTPResponse:
return await asyncio.wrap_future(self._response_future)

def get_chunk(self) -> Awaitable[bytes]:
future = self._get_chunk_future()
return asyncio.wrap_future(future)
def set_done_callback(self, stream: "crt_http.HttpClientStream") -> None:
stream.completion_future.add_done_callback(self._cancel)

async def chunks(self) -> AsyncGenerator[bytes, None]:
while True:
chunk = await self.get_chunk()
if chunk:
yield chunk
else:
break
def _cancel(self, completion_future: ConcurrentFuture[int | Exception]) -> None:
if not self._response_future.done():
self._response_future.cancel()


ConnectionPoolKey = tuple[str, str, int | None]
Expand Down Expand Up @@ -208,20 +232,21 @@ async def send(
"""
crt_request = await self._marshal_request(request)
connection = await self._get_connection(request.destination)
crt_response = AWSCRTHTTPResponse()
response_body = CRTResponseBody()
response_factory = CRTResponseFactory(response_body)
crt_stream = connection.request(
crt_request,
crt_response._on_headers, # pyright: ignore[reportPrivateUsage]
crt_response._on_body, # pyright: ignore[reportPrivateUsage]
response_factory.on_response,
response_body.on_body,
)
crt_response._set_stream(crt_stream) # pyright: ignore[reportPrivateUsage]
return crt_response
response_factory.set_done_callback(crt_stream)
response_body.set_stream(crt_stream)
return await response_factory.await_response()

async def _create_connection(
self, url: core_interfaces.URI
) -> "crt_http.HttpClientConnection":
"""Builds and validates connection to ``url``, returns it as
``asyncio.Future``"""
"""Builds and validates connection to ``url``"""
connect_future = self._build_new_connection(url)
connection = await asyncio.wrap_future(connect_future)
self._validate_connection(connection)
Expand All @@ -241,7 +266,7 @@ async def _get_connection(

def _build_new_connection(
self, url: core_interfaces.URI
) -> Future["crt_http.HttpClientConnection"]:
) -> ConcurrentFuture["crt_http.HttpClientConnection"]:
if url.scheme == "http":
port = self._HTTP_PORT
tls_connection_options = None
Expand All @@ -258,7 +283,7 @@ def _build_new_connection(
if url.port is not None:
port = url.port

connect_future: Future[crt_http.HttpClientConnection] = (
connect_future: ConcurrentFuture[crt_http.HttpClientConnection] = (
crt_http.HttpClientConnection.new(
bootstrap=self._client_bootstrap,
host_name=url.host,
Expand Down