diff --git a/packages/smithy-http/src/smithy_http/aio/crt.py b/packages/smithy-http/src/smithy_http/aio/crt.py index ea72e822d..fd9c541f3 100644 --- a/packages/smithy-http/src/smithy_http/aio/crt.py +++ b/packages/smithy-http/src/smithy_http/aio/crt.py @@ -3,12 +3,11 @@ # pyright: reportMissingTypeStubs=false,reportUnknownMemberType=false # flake8: noqa: F811 import asyncio +from concurrent.futures import Future as ConcurrentFuture from collections import deque -from collections.abc import AsyncGenerator, AsyncIterable, Awaitable -from concurrent.futures import Future +from collections.abc import AsyncGenerator, AsyncIterable from copy import deepcopy from io import BytesIO, BufferedIOBase -from threading import Lock from typing import TYPE_CHECKING, Any if TYPE_CHECKING: @@ -61,23 +60,100 @@ def _initialize_default_loop(self) -> "crt_io.ClientBootstrap": class AWSCRTHTTPResponse(http_aio_interfaces.HTTPResponse): - def __init__(self) -> None: + def __init__(self, *, status: int, fields: Fields, body: "CRTResponseBody") -> None: _assert_crt() - self._stream: crt_http.HttpClientStream | None = None - self._status_code_future: Future[int] = Future() - self._headers_future: Future[Fields] = Future() - self._chunk_futures: list[Future[bytes]] = [] - self._received_chunks: list[bytes] = [] - self._chunk_lock: Lock = Lock() - - def _set_stream(self, stream: "crt_http.HttpClientStream") -> None: + self._status = status + self._fields = fields + self._body = body + + @property + def status(self) -> int: + return self._status + + @property + def fields(self) -> Fields: + return self._fields + + @property + def body(self) -> AsyncIterable[bytes]: + return self.chunks() + + @property + def reason(self) -> str | None: + """Optional string provided by the server explaining the status.""" + # TODO: See how CRT exposes reason. + return None + + async def chunks(self) -> AsyncGenerator[bytes, None]: + while True: + chunk = await self._body.next() + if chunk: + yield chunk + else: + break + + def __repr__(self) -> str: + return ( + f"AWSCRTHTTPResponse(" + f"status={self.status}, " + f"fields={self.fields!r}, body=...)" + ) + + +class CRTResponseBody: + def __init__(self) -> None: + self._stream: "crt_http.HttpClientStream | None" = None + self._chunk_futures: deque[ConcurrentFuture[bytes]] = deque() + + # deque is thread safe and the crt is only going to be writing + # with one thread anyway, so we *shouldn't* need to gate this + # behind a lock. In an ideal world, the CRT would expose + # an interface that better matches python's async. + self._received_chunks: deque[bytes] = deque() + + def set_stream(self, stream: "crt_http.HttpClientStream") -> None: if self._stream is not None: raise SmithyHTTPException("Stream already set on AWSCRTHTTPResponse object") self._stream = stream self._stream.completion_future.add_done_callback(self._on_complete) self._stream.activate() - def _on_headers( + def on_body(self, chunk: bytes, **kwargs: Any) -> None: # pragma: crt-callback + # TODO: update back pressure window once CRT supports it + if self._chunk_futures: + future = self._chunk_futures.popleft() + future.set_result(chunk) + else: + self._received_chunks.append(chunk) + + async def next(self) -> bytes: + if self._stream is None: + raise SmithyHTTPException("Stream not set") + + # TODO: update backpressure window once CRT supports it + if self._received_chunks: + return self._received_chunks.popleft() + elif self._stream.completion_future.done(): + return b"" + else: + future = ConcurrentFuture[bytes]() + self._chunk_futures.append(future) + return await asyncio.wrap_future(future) + + def _on_complete( + self, completion_future: ConcurrentFuture[int] + ) -> None: # pragma: crt-callback + for future in self._chunk_futures: + future.set_result(b"") + self._chunk_futures.clear() + + +class CRTResponseFactory: + def __init__(self, body: CRTResponseBody) -> None: + self._body = body + self._response_future = ConcurrentFuture[AWSCRTHTTPResponse]() + + def on_response( self, status_code: int, headers: list[tuple[str, str]], **kwargs: Any ) -> None: # pragma: crt-callback fields = Fields() @@ -90,76 +166,24 @@ def _on_headers( values=[header_val], kind=FieldPosition.HEADER, ) - self._status_code_future.set_result(status_code) - self._headers_future.set_result(fields) - - def _on_body(self, chunk: bytes, **kwargs: Any) -> None: # pragma: crt-callback - with self._chunk_lock: - # TODO: update back pressure window once CRT supports it - if self._chunk_futures: - future = self._chunk_futures.pop(0) - future.set_result(chunk) - else: - self._received_chunks.append(chunk) - - def _get_chunk_future(self) -> Future[bytes]: - if self._stream is None: - raise SmithyHTTPException("Stream not set") - with self._chunk_lock: - future: Future[bytes] = Future() - # TODO: update backpressure window once CRT supports it - if self._received_chunks: - chunk = self._received_chunks.pop(0) - future.set_result(chunk) - elif self._stream.completion_future.done(): - future.set_result(b"") - else: - self._chunk_futures.append(future) - return future - - def _on_complete( - self, completion_future: Future[int] - ) -> None: # pragma: crt-callback - with self._chunk_lock: - if self._chunk_futures: - future = self._chunk_futures.pop(0) - future.set_result(b"") - - @property - def body(self) -> AsyncIterable[bytes]: - return self.chunks() - - @property - def status(self) -> int: - """The 3 digit response status code (1xx, 2xx, 3xx, 4xx, 5xx).""" - return self._status_code_future.result() - @property - def fields(self) -> Fields: - """List of HTTP header fields.""" - if self._stream is None: - raise SmithyHTTPException("Stream not set") - if not self._headers_future.done(): - raise SmithyHTTPException("Headers not received yet") - return self._headers_future.result() + self._response_future.set_result( + AWSCRTHTTPResponse( + status=status_code, + fields=fields, + body=self._body, + ) + ) - @property - def reason(self) -> str | None: - """Optional string provided by the server explaining the status.""" - # TODO: See how CRT exposes reason. - return None + async def await_response(self) -> AWSCRTHTTPResponse: + return await asyncio.wrap_future(self._response_future) - def get_chunk(self) -> Awaitable[bytes]: - future = self._get_chunk_future() - return asyncio.wrap_future(future) + def set_done_callback(self, stream: "crt_http.HttpClientStream") -> None: + stream.completion_future.add_done_callback(self._cancel) - async def chunks(self) -> AsyncGenerator[bytes, None]: - while True: - chunk = await self.get_chunk() - if chunk: - yield chunk - else: - break + def _cancel(self, completion_future: ConcurrentFuture[int | Exception]) -> None: + if not self._response_future.done(): + self._response_future.cancel() ConnectionPoolKey = tuple[str, str, int | None] @@ -208,20 +232,21 @@ async def send( """ crt_request = await self._marshal_request(request) connection = await self._get_connection(request.destination) - crt_response = AWSCRTHTTPResponse() + response_body = CRTResponseBody() + response_factory = CRTResponseFactory(response_body) crt_stream = connection.request( crt_request, - crt_response._on_headers, # pyright: ignore[reportPrivateUsage] - crt_response._on_body, # pyright: ignore[reportPrivateUsage] + response_factory.on_response, + response_body.on_body, ) - crt_response._set_stream(crt_stream) # pyright: ignore[reportPrivateUsage] - return crt_response + response_factory.set_done_callback(crt_stream) + response_body.set_stream(crt_stream) + return await response_factory.await_response() async def _create_connection( self, url: core_interfaces.URI ) -> "crt_http.HttpClientConnection": - """Builds and validates connection to ``url``, returns it as - ``asyncio.Future``""" + """Builds and validates connection to ``url``""" connect_future = self._build_new_connection(url) connection = await asyncio.wrap_future(connect_future) self._validate_connection(connection) @@ -241,7 +266,7 @@ async def _get_connection( def _build_new_connection( self, url: core_interfaces.URI - ) -> Future["crt_http.HttpClientConnection"]: + ) -> ConcurrentFuture["crt_http.HttpClientConnection"]: if url.scheme == "http": port = self._HTTP_PORT tls_connection_options = None @@ -258,7 +283,7 @@ def _build_new_connection( if url.port is not None: port = url.port - connect_future: Future[crt_http.HttpClientConnection] = ( + connect_future: ConcurrentFuture[crt_http.HttpClientConnection] = ( crt_http.HttpClientConnection.new( bootstrap=self._client_bootstrap, host_name=url.host,