Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions tests/integrations/anthropic/test_iterator_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest

from weave.integrations.anthropic.anthropic_sdk import AnthropicIteratorWrapper


class _AsyncContextIterator:
def __init__(self, open_handles: list["_AsyncContextIterator"]) -> None:
self.entered = False
self.exited = False
self._open_handles = open_handles
self._items = iter([1, 2])

async def __aenter__(self) -> "_AsyncContextIterator":
self.entered = True
self._open_handles.append(self)
return self

async def __aexit__(self, exc_type, exc, tb) -> None:
self.exited = True
self._open_handles.remove(self)

def __aiter__(self) -> "_AsyncContextIterator":
return self

async def __anext__(self) -> int:
try:
return next(self._items)
except StopIteration as exc:
raise StopAsyncIteration from exc


@pytest.mark.asyncio
async def test_anthropic_iterator_wrapper_delegates_aexit() -> None:
open_handles: list[_AsyncContextIterator] = []
stream = _AsyncContextIterator(open_handles)
wrapper = AnthropicIteratorWrapper(
stream, lambda _: None, lambda _: None, lambda: None
)

async with wrapper as wrapped:
assert stream.entered is True
values = [value async for value in wrapped]

assert values == [1, 2]
# If this fails, AnthropicIteratorWrapper.__aexit__ is not delegating to the
# wrapped async context manager, which can leak streaming connections.
assert stream.exited is True
assert open_handles == []
14 changes: 14 additions & 0 deletions weave/integrations/anthropic/anthropic_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,20 @@ async def __async_stream_text__(self) -> AsyncIterator[str]: # type: ignore
def text_stream(self) -> Iterator[str] | AsyncIterator[str]:
return self.__stream_text__()

async def __aexit__(
self,
exc_type: Exception | None,
exc_value: BaseException | None,
traceback: Any,
) -> None:
if exc_type and isinstance(exc_value, Exception):
self._call_on_error_once(exc_value)
if hasattr(self._iterator_or_ctx_manager, "__aexit__"):
await self._iterator_or_ctx_manager.__aexit__(
exc_type, exc_value, traceback
)
self._call_on_close_once()


def create_stream_wrapper(settings: OpSettings) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
Expand Down