From 75c5bee2a13405c5657759ed25b4e16da26204c1 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Sun, 1 Feb 2026 11:01:30 -0800 Subject: [PATCH 1/4] first pass, async file create endpoint end to end --- .../clickhouse_trace_server_batched.py | 108 ++++++++ ...ternal_to_internal_trace_server_adapter.py | 5 + weave/trace_server/file_storage.py | 242 ++++++++++++++++++ weave/trace_server/trace_server_interface.py | 10 +- 4 files changed, 364 insertions(+), 1 deletion(-) diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index e00c61bd8095..75cc9fc96f89 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -123,13 +123,16 @@ validate_feedback_purge_req, ) from weave.trace_server.file_storage import ( + AsyncFileStorageClient, FileStorageClient, FileStorageReadError, FileStorageWriteError, key_for_project_digest, + maybe_get_async_storage_client_from_env, maybe_get_storage_client_from_env, read_from_bucket, store_in_bucket, + store_in_bucket_async, ) from weave.trace_server.file_storage_uris import FileStorageURI from weave.trace_server.ids import generate_id @@ -233,6 +236,8 @@ def __init__( self._use_async_insert = use_async_insert self._model_to_provider_info_map = read_model_to_provider_info_map() self._file_storage_client: FileStorageClient | None = None + self._async_file_storage_client: AsyncFileStorageClient | None = None + self._async_ch_client: clickhouse_connect.driver.AsyncClient | None = None self._kafka_producer: KafkaProducer | None = None self._evaluate_model_dispatcher = evaluate_model_dispatcher self._table_routing_resolver: TableRoutingResolver | None = None @@ -298,6 +303,25 @@ def file_storage_client(self) -> FileStorageClient | None: self._file_storage_client = maybe_get_storage_client_from_env() return self._file_storage_client + @property + def async_file_storage_client(self) -> AsyncFileStorageClient | None: + if self._async_file_storage_client is not None: + return self._async_file_storage_client + self._async_file_storage_client = maybe_get_async_storage_client_from_env() + return self._async_file_storage_client + + async def _get_async_ch_client(self) -> clickhouse_connect.driver.AsyncClient: + """Get or create async ClickHouse client.""" + if self._async_ch_client is None: + self._async_ch_client = await clickhouse_connect.get_async_client( + host=self._host, + port=self._port, + user=self._user, + password=self._password, + database=self._database, + ) + return self._async_ch_client + @property def kafka_producer(self) -> KafkaProducer: if self._kafka_producer is not None: @@ -4709,6 +4733,90 @@ def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: set_root_span_dd_tags({"write_bytes": len(req.content)}) return tsi.FileCreateRes(digest=digest) + async def file_create_async(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: + """Async version of file_create using async storage and ClickHouse clients.""" + digest = bytes_digest(req.content) + use_file_storage = self._should_use_file_storage_for_writes(req.project_id) + async_client = self.async_file_storage_client + + if async_client is not None and use_file_storage: + try: + await self._file_create_bucket_async(req, digest, async_client) + except FileStorageWriteError: + await self._file_create_clickhouse_async(req, digest) + else: + await self._file_create_clickhouse_async(req, digest) + set_root_span_dd_tags({"write_bytes": len(req.content)}) + return 
tsi.FileCreateRes(digest=digest) + + async def _file_create_clickhouse_async( + self, req: tsi.FileCreateReq, digest: str + ) -> None: + """Async version of _file_create_clickhouse.""" + set_root_span_dd_tags({"storage_provider": "clickhouse"}) + chunks = [ + req.content[i : i + ch_settings.FILE_CHUNK_SIZE] + for i in range(0, len(req.content), ch_settings.FILE_CHUNK_SIZE) + ] + await self._insert_file_chunks_async( + [ + FileChunkCreateCHInsertable( + project_id=req.project_id, + digest=digest, + chunk_index=i, + n_chunks=len(chunks), + name=req.name, + val_bytes=chunk, + bytes_stored=len(chunk), + file_storage_uri=None, + ) + for i, chunk in enumerate(chunks) + ] + ) + + async def _file_create_bucket_async( + self, req: tsi.FileCreateReq, digest: str, client: AsyncFileStorageClient + ) -> None: + """Async version of _file_create_bucket.""" + set_root_span_dd_tags({"storage_provider": "bucket"}) + target_file_storage_uri = await store_in_bucket_async( + client, key_for_project_digest(req.project_id, digest), req.content + ) + await self._insert_file_chunks_async( + [ + FileChunkCreateCHInsertable( + project_id=req.project_id, + digest=digest, + chunk_index=0, + n_chunks=1, + name=req.name, + val_bytes=b"", + bytes_stored=len(req.content), + file_storage_uri=target_file_storage_uri.to_uri_str(), + ) + ] + ) + + async def _insert_file_chunks_async( + self, file_chunks: list[FileChunkCreateCHInsertable] + ) -> None: + """Async version of _insert_file_chunks using async ClickHouse client.""" + data = [] + for chunk in file_chunks: + chunk_dump = chunk.model_dump() + row = [] + for col in ALL_FILE_CHUNK_INSERT_COLUMNS: + row.append(chunk_dump.get(col, None)) + data.append(row) + + if data: + client = await self._get_async_ch_client() + await client.insert( + "files", + data=data, + column_names=ALL_FILE_CHUNK_INSERT_COLUMNS, + ) + @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._file_create_clickhouse") def _file_create_clickhouse(self, req: tsi.FileCreateReq, digest: str) -> None: set_root_span_dd_tags({"storage_provider": "clickhouse"}) diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py index e7528740a0ee..166f3b9c29a7 100644 --- a/weave/trace_server/external_to_internal_trace_server_adapter.py +++ b/weave/trace_server/external_to_internal_trace_server_adapter.py @@ -297,6 +297,11 @@ def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: # Special case where refs can never be part of the request return self._internal_trace_server.file_create(req) + async def file_create_async(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: + req.project_id = self._idc.ext_to_int_project_id(req.project_id) + # Special case where refs can never be part of the request + return await self._internal_trace_server.file_create_async(req) + def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadRes: req.project_id = self._idc.ext_to_int_project_id(req.project_id) # Special case where refs can never be part of the request diff --git a/weave/trace_server/file_storage.py b/weave/trace_server/file_storage.py index 18e6b910f134..49a96be5a462 100644 --- a/weave/trace_server/file_storage.py +++ b/weave/trace_server/file_storage.py @@ -408,3 +408,245 @@ def maybe_get_storage_client_from_env() -> FileStorageClient | None: raise NotImplementedError( f"Storage client for URI type {type(file_storage_uri)} not supported" ) + + +# 
============================================================================= +# Async Storage Clients +# ============================================================================= + + +class AsyncFileStorageClient: + """Abstract base class for async cloud storage operations.""" + + base_uri: FileStorageURI + _sync_client: FileStorageClient + + def __init__(self, base_uri: FileStorageURI, sync_client: FileStorageClient): + self.base_uri = base_uri + self._sync_client = sync_client + + @abstractmethod + async def store_async(self, uri: FileStorageURI, data: bytes) -> None: + """Store data at the specified URI location in cloud storage (async).""" + pass + + @abstractmethod + async def read_async(self, uri: FileStorageURI) -> bytes: + """Read data from the specified URI location in cloud storage (async).""" + pass + + +class AsyncS3StorageClient(AsyncFileStorageClient): + """Async AWS S3 storage implementation using aiobotocore.""" + + def __init__( + self, + base_uri: FileStorageURI, + credentials: AWSCredentials, + sync_client: S3StorageClient, + ): + super().__init__(base_uri, sync_client) + assert isinstance(base_uri, S3FileStorageURI) + self._credentials = credentials + self._kms_key = credentials.get("kms_key") + + async def store_async(self, uri: FileStorageURI, data: bytes) -> None: + """Store data in S3 bucket asynchronously using aiobotocore.""" + assert isinstance(uri, S3FileStorageURI) + assert uri.to_uri_str().startswith(self.base_uri.to_uri_str()) + + # Import aiobotocore lazily to avoid hard dependency + try: + from aiobotocore.session import get_session + except ImportError: + # Fallback to sync client in thread pool if aiobotocore not available + import asyncio + + await asyncio.to_thread(self._sync_client.store, uri, data) + return + + session = get_session() + async with session.create_client( + "s3", + region_name=self._credentials.get("region"), + aws_access_key_id=self._credentials.get("access_key_id"), + aws_secret_access_key=self._credentials.get("secret_access_key"), + aws_session_token=self._credentials.get("session_token"), + ) as client: + put_object_params: dict[str, Any] = { + "Bucket": uri.bucket, + "Key": uri.path, + "Body": data, + } + if self._kms_key: + put_object_params["ServerSideEncryption"] = "aws:kms" + put_object_params["SSEKMSKeyId"] = self._kms_key + + await client.put_object(**put_object_params) + + async def read_async(self, uri: FileStorageURI) -> bytes: + """Read data from S3 bucket asynchronously.""" + assert isinstance(uri, S3FileStorageURI) + assert uri.to_uri_str().startswith(self.base_uri.to_uri_str()) + + try: + from aiobotocore.session import get_session + except ImportError: + import asyncio + + return await asyncio.to_thread(self._sync_client.read, uri) + + session = get_session() + async with session.create_client( + "s3", + region_name=self._credentials.get("region"), + aws_access_key_id=self._credentials.get("access_key_id"), + aws_secret_access_key=self._credentials.get("secret_access_key"), + aws_session_token=self._credentials.get("session_token"), + ) as client: + response = await client.get_object(Bucket=uri.bucket, Key=uri.path) + async with response["Body"] as stream: + return await stream.read() + + +class AsyncGCSStorageClient(AsyncFileStorageClient): + """Async Google Cloud Storage implementation.""" + + def __init__( + self, + base_uri: FileStorageURI, + credentials: GCPCredentials | None, + sync_client: GCSStorageClient, + ): + super().__init__(base_uri, sync_client) + assert isinstance(base_uri, GCSFileStorageURI) + 
self._credentials = credentials + + async def store_async(self, uri: FileStorageURI, data: bytes) -> None: + """Store data in GCS bucket asynchronously. + + Falls back to sync client in thread pool since gcloud-aio-storage + is not a standard dependency. + """ + assert isinstance(uri, GCSFileStorageURI) + assert uri.to_uri_str().startswith(self.base_uri.to_uri_str()) + + # GCS async requires gcloud-aio-storage; fall back to threaded sync + import asyncio + + await asyncio.to_thread(self._sync_client.store, uri, data) + + async def read_async(self, uri: FileStorageURI) -> bytes: + """Read data from GCS bucket asynchronously.""" + assert isinstance(uri, GCSFileStorageURI) + assert uri.to_uri_str().startswith(self.base_uri.to_uri_str()) + + import asyncio + + return await asyncio.to_thread(self._sync_client.read, uri) + + +class AsyncAzureStorageClient(AsyncFileStorageClient): + """Async Azure Blob Storage implementation using built-in aio support.""" + + def __init__( + self, + base_uri: FileStorageURI, + credentials: AzureConnectionCredentials | AzureAccountCredentials, + sync_client: AzureStorageClient, + ): + super().__init__(base_uri, sync_client) + assert isinstance(base_uri, AzureFileStorageURI) + self._credentials = credentials + + async def _get_async_client(self, account: str) -> Any: + """Create async Azure client based on available credentials.""" + from azure.storage.blob.aio import BlobServiceClient as AsyncBlobServiceClient + + if "connection_string" in self._credentials: + connection_creds = cast(AzureConnectionCredentials, self._credentials) + return AsyncBlobServiceClient.from_connection_string( + connection_creds["connection_string"], + connection_timeout=DEFAULT_CONNECT_TIMEOUT, + read_timeout=DEFAULT_READ_TIMEOUT, + ) + else: + account_creds = cast(AzureAccountCredentials, self._credentials) + if "account_url" in account_creds and account_creds["account_url"]: + account_url = account_creds["account_url"] + else: + account_url = f"https://{account}.blob.core.windows.net/" + return AsyncBlobServiceClient( + account_url=account_url, + credential=account_creds["access_key"], + connection_timeout=DEFAULT_CONNECT_TIMEOUT, + read_timeout=DEFAULT_READ_TIMEOUT, + ) + + async def store_async(self, uri: FileStorageURI, data: bytes) -> None: + """Store data in Azure container asynchronously.""" + assert isinstance(uri, AzureFileStorageURI) + assert uri.to_uri_str().startswith(self.base_uri.to_uri_str()) + + async with await self._get_async_client(uri.account) as client: + container_client = client.get_container_client(uri.container) + blob_client = container_client.get_blob_client(uri.path) + await blob_client.upload_blob(data, overwrite=True) + + async def read_async(self, uri: FileStorageURI) -> bytes: + """Read data from Azure container asynchronously.""" + assert isinstance(uri, AzureFileStorageURI) + assert uri.to_uri_str().startswith(self.base_uri.to_uri_str()) + + async with await self._get_async_client(uri.account) as client: + container_client = client.get_container_client(uri.container) + blob_client = container_client.get_blob_client(uri.path) + stream = await blob_client.download_blob() + return await stream.readall() + + +async def store_in_bucket_async( + client: AsyncFileStorageClient, path: str, data: bytes +) -> FileStorageURI: + """Store a file in a storage bucket asynchronously.""" + try: + target_file_storage_uri = client.base_uri.with_path(path) + await client.store_async(target_file_storage_uri, data) + except Exception as e: + logger.exception("Failed to store 
file at %s", target_file_storage_uri) + raise FileStorageWriteError(f"Failed to store file at {path}: {e!s}") from e + return target_file_storage_uri + + +def maybe_get_async_storage_client_from_env() -> AsyncFileStorageClient | None: + """Factory method that returns appropriate async storage client based on URI type.""" + file_storage_uri = wf_file_storage_uri() + if not file_storage_uri: + return None + try: + parsed_uri = FileStorageURI.parse_uri_str(file_storage_uri) + except Exception as e: + logger.warning(f"Error parsing file storage URI: {e}") + return None + if parsed_uri.has_path(): + logger.error( + f"Supplied file storage uri contains path components: {file_storage_uri}" + ) + return None + + if isinstance(parsed_uri, S3FileStorageURI): + credentials = get_aws_credentials() + sync_client = S3StorageClient(parsed_uri, credentials) + return AsyncS3StorageClient(parsed_uri, credentials, sync_client) + elif isinstance(parsed_uri, GCSFileStorageURI): + credentials = get_gcp_credentials() + sync_client = GCSStorageClient(parsed_uri, credentials) + return AsyncGCSStorageClient(parsed_uri, credentials, sync_client) + elif isinstance(parsed_uri, AzureFileStorageURI): + credentials = get_azure_credentials() + sync_client = AzureStorageClient(parsed_uri, credentials) + return AsyncAzureStorageClient(parsed_uri, credentials, sync_client) + else: + raise NotImplementedError( + f"Async storage client for URI type {type(file_storage_uri)} not supported" + ) diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index 458f80b542ae..d6ae7248e15b 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -1,5 +1,6 @@ +import asyncio import datetime -from collections.abc import Iterator +from collections.abc import Coroutine, Iterator from enum import Enum from typing import Any, Literal, Protocol @@ -2289,6 +2290,13 @@ def refs_read_batch(self, req: RefsReadBatchReq) -> RefsReadBatchRes: ... # File API def file_create(self, req: FileCreateReq) -> FileCreateRes: ... + + def file_create_async( + self, req: FileCreateReq + ) -> Coroutine[Any, Any, FileCreateRes]: + """Async version of file_create. Default wraps sync in thread pool.""" + return asyncio.to_thread(self.file_create, req) + def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes: ... def files_stats(self, req: FilesStatsReq) -> FilesStatsRes: ... From e4318156c3d92c2f3dbf6c609b009d86070d80b9 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Sun, 1 Feb 2026 11:37:03 -0800 Subject: [PATCH 2/4] new async endpoint path --- tests/trace/test_file_storage_async.py | 242 +++++++++++++ tests/trace/test_server_file_storage.py | 325 ++++++++++++++++++ weave/trace_server/file_storage.py | 9 +- weave/trace_server/sqlite_trace_server.py | 5 + weave/trace_server/trace_server_interface.py | 10 +- .../remote_http_trace_server.py | 5 + 6 files changed, 580 insertions(+), 16 deletions(-) create mode 100644 tests/trace/test_file_storage_async.py diff --git a/tests/trace/test_file_storage_async.py b/tests/trace/test_file_storage_async.py new file mode 100644 index 000000000000..673abc5bb8ef --- /dev/null +++ b/tests/trace/test_file_storage_async.py @@ -0,0 +1,242 @@ +"""Unit tests for async file storage client implementations. + +This module tests the AsyncFileStorageClient implementations directly, +independent of the trace server layer. 
+""" + +import asyncio +import time +from unittest import mock + +import pytest + +from weave.trace_server.file_storage import ( + AsyncFileStorageClient, + AsyncGCSStorageClient, + AsyncS3StorageClient, + FileStorageWriteError, + GCSStorageClient, + S3StorageClient, + store_in_bucket_async, +) +from weave.trace_server.file_storage_uris import ( + GCSFileStorageURI, + S3FileStorageURI, +) + + +class TestAsyncS3StorageClient: + """Unit tests for AsyncS3StorageClient.""" + + @pytest.fixture + def s3_uri(self): + return S3FileStorageURI(bucket="test-bucket", path="") + + @pytest.fixture + def s3_credentials(self): + return { + "access_key_id": "test-key", + "secret_access_key": "test-secret", + "region": "us-east-1", + "session_token": None, + "kms_key": None, + } + + @pytest.fixture + def mock_sync_client(self, s3_uri, s3_credentials): + """Create a mock sync S3 client.""" + client = mock.MagicMock(spec=S3StorageClient) + client.base_uri = s3_uri + return client + + @pytest.mark.asyncio + async def test_store_async_fallback_to_sync( + self, s3_uri, s3_credentials, mock_sync_client + ): + """Test that store_async falls back to sync when aiobotocore is not available.""" + async_client = AsyncS3StorageClient(s3_uri, s3_credentials, mock_sync_client) + + # Mock aiobotocore import to fail + with mock.patch.dict( + "sys.modules", {"aiobotocore": None, "aiobotocore.session": None} + ): + target_uri = s3_uri.with_path("test/file.txt") + test_data = b"test content" + + # This should fall back to sync client via to_thread + # We need to mock asyncio.to_thread since the import will fail + with mock.patch("asyncio.to_thread") as mock_to_thread: + mock_to_thread.return_value = None + await async_client.store_async(target_uri, test_data) + mock_to_thread.assert_called_once_with( + mock_sync_client.store, target_uri, test_data + ) + + @pytest.mark.asyncio + async def test_store_async_with_aiobotocore( + self, s3_uri, s3_credentials, mock_sync_client + ): + """Test store_async uses aiobotocore when available.""" + async_client = AsyncS3StorageClient(s3_uri, s3_credentials, mock_sync_client) + + # Create mock aiobotocore session and client + mock_aio_client = mock.AsyncMock() + mock_session = mock.MagicMock() + mock_session.create_client.return_value.__aenter__ = mock.AsyncMock( + return_value=mock_aio_client + ) + mock_session.create_client.return_value.__aexit__ = mock.AsyncMock() + + with mock.patch( + "weave.trace_server.file_storage.AsyncS3StorageClient.store_async" + ) as mock_store: + mock_store.return_value = None + target_uri = s3_uri.with_path("test/file.txt") + await mock_store(target_uri, b"test content") + mock_store.assert_called_once() + + +class TestAsyncGCSStorageClient: + """Unit tests for AsyncGCSStorageClient.""" + + @pytest.fixture + def gcs_uri(self): + return GCSFileStorageURI(bucket="test-bucket", path="") + + @pytest.fixture + def mock_sync_gcs_client(self, gcs_uri): + """Create a mock sync GCS client.""" + client = mock.MagicMock(spec=GCSStorageClient) + client.base_uri = gcs_uri + return client + + @pytest.mark.asyncio + async def test_store_async_uses_thread_pool(self, gcs_uri, mock_sync_gcs_client): + """Test that GCS async falls back to thread pool (no native async support).""" + async_client = AsyncGCSStorageClient(gcs_uri, None, mock_sync_gcs_client) + + target_uri = gcs_uri.with_path("test/file.txt") + test_data = b"test content" + + with mock.patch("asyncio.to_thread") as mock_to_thread: + mock_to_thread.return_value = None + await async_client.store_async(target_uri, 
test_data) + mock_to_thread.assert_called_once_with( + mock_sync_gcs_client.store, target_uri, test_data + ) + + @pytest.mark.asyncio + async def test_read_async_uses_thread_pool(self, gcs_uri, mock_sync_gcs_client): + """Test that GCS read_async falls back to thread pool.""" + async_client = AsyncGCSStorageClient(gcs_uri, None, mock_sync_gcs_client) + mock_sync_gcs_client.read.return_value = b"test content" + + target_uri = gcs_uri.with_path("test/file.txt") + + with mock.patch("asyncio.to_thread") as mock_to_thread: + mock_to_thread.return_value = b"test content" + result = await async_client.read_async(target_uri) + mock_to_thread.assert_called_once_with( + mock_sync_gcs_client.read, target_uri + ) + assert result == b"test content" + + +class TestStoreInBucketAsync: + """Tests for the store_in_bucket_async helper function.""" + + @pytest.fixture + def mock_async_client(self): + """Create a mock async storage client.""" + client = mock.AsyncMock(spec=AsyncFileStorageClient) + client.base_uri = S3FileStorageURI(bucket="test-bucket", path="") + client.base_uri.with_path = lambda p: S3FileStorageURI( + bucket="test-bucket", path=p + ) + return client + + @pytest.mark.asyncio + async def test_store_in_bucket_async_success(self, mock_async_client): + """Test successful async bucket storage.""" + mock_async_client.store_async.return_value = None + + result = await store_in_bucket_async( + mock_async_client, "test/path/file.txt", b"content" + ) + + assert result.bucket == "test-bucket" + assert result.path == "test/path/file.txt" + mock_async_client.store_async.assert_called_once() + + @pytest.mark.asyncio + async def test_store_in_bucket_async_failure(self, mock_async_client): + """Test async bucket storage failure raises FileStorageWriteError.""" + mock_async_client.store_async.side_effect = Exception("Storage failed") + + with pytest.raises(FileStorageWriteError) as exc_info: + await store_in_bucket_async( + mock_async_client, "test/path/file.txt", b"content" + ) + + assert "Failed to store file" in str(exc_info.value) + + +class TestAsyncConcurrency: + """Tests to verify async operations don't block the event loop.""" + + @pytest.mark.asyncio + async def test_concurrent_operations_complete(self): + """Test that multiple async operations can run concurrently.""" + call_times = [] + + async def mock_async_operation(delay: float, name: str) -> str: + start = time.monotonic() + await asyncio.sleep(delay) + end = time.monotonic() + call_times.append((name, start, end)) + return name + + # Run 3 operations concurrently + tasks = [ + mock_async_operation(0.1, "op1"), + mock_async_operation(0.1, "op2"), + mock_async_operation(0.1, "op3"), + ] + + start = time.monotonic() + results = await asyncio.gather(*tasks) + total_time = time.monotonic() - start + + assert set(results) == {"op1", "op2", "op3"} + # If truly concurrent, total time should be ~0.1s, not 0.3s + assert total_time < 0.25, f"Operations took {total_time}s, expected < 0.25s" + + @pytest.mark.asyncio + async def test_blocking_operation_detection(self): + """Test that simulates detecting if operations would block.""" + event_loop_blocked = False + + async def check_event_loop(): + """Background task to check if event loop is responsive.""" + nonlocal event_loop_blocked + for _ in range(5): + start = time.monotonic() + await asyncio.sleep(0.01) + elapsed = time.monotonic() - start + # If sleep takes much longer than expected, loop was blocked + if elapsed > 0.05: + event_loop_blocked = True + break + + async def fast_async_operation(): + 
"""A fast async operation that shouldn't block.""" + await asyncio.sleep(0.02) + return "done" + + # Run checker and operation concurrently + checker_task = asyncio.create_task(check_event_loop()) + result = await fast_async_operation() + await checker_task + + assert result == "done" + assert not event_loop_blocked, "Event loop was blocked during async operation" diff --git a/tests/trace/test_server_file_storage.py b/tests/trace/test_server_file_storage.py index 1114ab675e93..43420e4e08b7 100644 --- a/tests/trace/test_server_file_storage.py +++ b/tests/trace/test_server_file_storage.py @@ -5,8 +5,10 @@ specific setup requirements. """ +import asyncio import base64 import os +import time from unittest import mock import boto3 @@ -481,3 +483,326 @@ def mock_upload_fail(*args, **kwargs): # Verify GCS was attempted exactly 3 times before fallback assert attempt_count == 3, f"Expected 3 GCS attempts, got {attempt_count}" + + +# ============================================================================= +# Async file_create tests +# ============================================================================= + + +class TestAsyncFileCreate: + """Tests for the async file_create_async functionality.""" + + @pytest.mark.asyncio + async def test_async_file_create_basic(self, client: WeaveClient): + """Test basic async file creation works.""" + if client_is_sqlite(client): + pytest.skip("Skipping for SQLite - testing ClickHouse async path") + + req = FileCreateReq( + project_id=client._project_id(), + name="async_test.txt", + content=TEST_CONTENT, + ) + res = await client.server.file_create_async(req) + + assert res.digest is not None + assert res.digest != "" + + # Verify content can be read back + file = client.server.file_content_read( + FileContentReadReq(project_id=client._project_id(), digest=res.digest) + ) + assert file.content == TEST_CONTENT + + @pytest.mark.asyncio + async def test_async_file_create_consistency_with_sync(self, client: WeaveClient): + """Test that async and sync produce the same digest for same content.""" + if client_is_sqlite(client): + pytest.skip("Skipping for SQLite") + + req = FileCreateReq( + project_id=client._project_id(), + name="consistency_test.txt", + content=b"consistent content for testing", + ) + + # Create via sync + sync_res = client.server.file_create(req) + + # Create via async (should be idempotent, same digest) + async_res = await client.server.file_create_async(req) + + assert sync_res.digest == async_res.digest + + @pytest.mark.asyncio + async def test_async_file_create_concurrent(self, client: WeaveClient): + """Test concurrent async file creations don't block each other.""" + if client_is_sqlite(client): + pytest.skip("Skipping for SQLite") + + async def create_file(content: bytes, name: str) -> tuple[str, float]: + start = time.monotonic() + req = FileCreateReq( + project_id=client._project_id(), + name=name, + content=content, + ) + res = await client.server.file_create_async(req) + elapsed = time.monotonic() - start + return res.digest, elapsed + + # Create multiple files concurrently + tasks = [ + create_file(f"content_{i}".encode(), f"file_{i}.txt") for i in range(5) + ] + + start_total = time.monotonic() + results = await asyncio.gather(*tasks) + total_elapsed = time.monotonic() - start_total + + # All should succeed + digests = [r[0] for r in results] + assert all(d is not None and d != "" for d in digests) + + # All digests should be unique (different content) + assert len(set(digests)) == 5 + + # Concurrent execution should be faster 
than serial + # (sum of individual times should be greater than total time if truly concurrent) + individual_times = [r[1] for r in results] + sum_individual = sum(individual_times) + # Allow some overhead, but concurrent should show speedup + # This is a sanity check, not strict - just verify it's not completely serial + assert total_elapsed < sum_individual * 0.9 or total_elapsed < 1.0 + + @pytest.mark.asyncio + async def test_async_file_create_large_file(self, client: WeaveClient): + """Test async file creation with large files.""" + if client_is_sqlite(client): + pytest.skip("Skipping for SQLite") + + # Create a file larger than FILE_CHUNK_SIZE + chunk_size = 100000 + num_chunks = 3 + large_content = b"x" * (chunk_size * num_chunks) + + req = FileCreateReq( + project_id=client._project_id(), + name="large_async_test.bin", + content=large_content, + ) + res = await client.server.file_create_async(req) + + assert res.digest is not None + + # Verify content + file = client.server.file_content_read( + FileContentReadReq(project_id=client._project_id(), digest=res.digest) + ) + assert file.content == large_content + + +class TestAsyncS3Storage: + """Tests for async AWS S3 storage implementation.""" + + @pytest.fixture + def s3(self): + """Moto S3 mock that implements the S3 API.""" + with mock_aws(): + s3_client = boto3.client( + "s3", + aws_access_key_id="test-key", + aws_secret_access_key="test-secret", + region_name="us-east-1", + ) + s3_client.create_bucket(Bucket=TEST_BUCKET) + yield s3_client + + @pytest.fixture + def aws_storage_env_async(self): + """Setup AWS storage environment for async tests.""" + with mock.patch.dict( + os.environ, + { + "WF_FILE_STORAGE_AWS_ACCESS_KEY_ID": "test-key", + "WF_FILE_STORAGE_AWS_SECRET_ACCESS_KEY": "test-secret", + "WF_FILE_STORAGE_URI": f"s3://{TEST_BUCKET}", + "WF_FILE_STORAGE_PROJECT_ALLOW_LIST": "c2hhd24vdGVzdC1wcm9qZWN0", + }, + ): + yield + + @pytest.mark.asyncio + @pytest.mark.usefixtures("aws_storage_env_async") + async def test_async_aws_storage(self, client: WeaveClient, s3): + """Test async file storage using AWS S3.""" + if client_is_sqlite(client): + pytest.skip("Skipping for SQLite") + + req = FileCreateReq( + project_id=client._project_id(), + name="async_s3_test.txt", + content=TEST_CONTENT, + ) + res = await client.server.file_create_async(req) + + assert res.digest is not None + + # Verify the object exists in S3 + response = s3.list_objects_v2(Bucket=TEST_BUCKET) + assert "Contents" in response + assert len(response["Contents"]) >= 1 + + # Verify content via server + file = client.server.file_content_read( + FileContentReadReq(project_id=client._project_id(), digest=res.digest) + ) + assert file.content == TEST_CONTENT + + +class TestAsyncGCSStorage: + """Tests for async Google Cloud Storage implementation.""" + + @pytest.fixture + def mock_gcp_credentials(self): + """Mock GCP credentials to prevent authentication.""" + with mock.patch( + "google.oauth2.service_account.Credentials.from_service_account_info" + ) as mock_creds: + from google.auth.credentials import AnonymousCredentials + + mock_creds.return_value = AnonymousCredentials() + yield + + @pytest.fixture + def gcs_async(self): + """Google Cloud Storage mock for async tests.""" + mock_storage_client = mock.MagicMock() + mock_bucket = mock.MagicMock() + mock_blob = mock.MagicMock() + + mock_storage_client.bucket.return_value = mock_bucket + mock_bucket.blob.return_value = mock_blob + + blob_data = {} + + def mock_upload_from_string( + data, timeout=None, 
if_generation_match=None, **kwargs + ): + blob_name = mock_blob.name + if if_generation_match == 0 and blob_name in blob_data: + from google.api_core import exceptions + + raise exceptions.PreconditionFailed("Object already exists") + blob_data[blob_name] = data + + def mock_download_as_bytes(timeout=None, **kwargs): + blob_name = mock_blob.name + return blob_data.get(blob_name, b"") + + mock_blob.upload_from_string.side_effect = mock_upload_from_string + mock_blob.download_as_bytes.side_effect = mock_download_as_bytes + + with mock.patch( + "google.cloud.storage.Client", return_value=mock_storage_client + ): + yield mock_storage_client + + @pytest.fixture + def gcp_storage_env_async(self): + """Setup GCP storage environment for async tests.""" + with mock.patch.dict( + os.environ, + { + "WF_FILE_STORAGE_GCP_CREDENTIALS_JSON_B64": base64.b64encode( + b"""{ + "type": "service_account", + "project_id": "test-project", + "private_key_id": "test-key-id", + "private_key": "test-key", + "client_email": "test@test-project.iam.gserviceaccount.com", + "client_id": "test-client-id", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test@test-project.iam.gserviceaccount.com" + }""" + ).decode(), + "WF_FILE_STORAGE_URI": f"gs://{TEST_BUCKET}", + "WF_FILE_STORAGE_PROJECT_ALLOW_LIST": "c2hhd24vdGVzdC1wcm9qZWN0", + }, + ): + yield + + @pytest.mark.asyncio + @pytest.mark.usefixtures("gcp_storage_env_async", "mock_gcp_credentials") + async def test_async_gcs_storage(self, client: WeaveClient, gcs_async): + """Test async file storage using GCS.""" + if client_is_sqlite(client): + pytest.skip("Skipping for SQLite") + + req = FileCreateReq( + project_id=client._project_id(), + name="async_gcs_test.txt", + content=TEST_CONTENT, + ) + res = await client.server.file_create_async(req) + + assert res.digest is not None + + # Verify content via server + file = client.server.file_content_read( + FileContentReadReq(project_id=client._project_id(), digest=res.digest) + ) + assert file.content == TEST_CONTENT + + +class TestAsyncSQLiteStorage: + """Tests for async SQLite storage (thread pool wrapper).""" + + @pytest.mark.asyncio + async def test_sqlite_async_wrapper(self, client: WeaveClient): + """Test that SQLite async wrapper works correctly.""" + if not client_is_sqlite(client): + pytest.skip("This test is specifically for SQLite") + + req = FileCreateReq( + project_id=client._project_id(), + name="sqlite_async_test.txt", + content=TEST_CONTENT, + ) + res = await client.server.file_create_async(req) + + assert res.digest is not None + assert res.digest != "" + + # Verify content + file = client.server.file_content_read( + FileContentReadReq(project_id=client._project_id(), digest=res.digest) + ) + assert file.content == TEST_CONTENT + + @pytest.mark.asyncio + async def test_sqlite_async_concurrent(self, client: WeaveClient): + """Test concurrent SQLite async operations.""" + if not client_is_sqlite(client): + pytest.skip("This test is specifically for SQLite") + + async def create_file(i: int) -> str: + req = FileCreateReq( + project_id=client._project_id(), + name=f"sqlite_concurrent_{i}.txt", + content=f"sqlite_content_{i}".encode(), + ) + res = await client.server.file_create_async(req) + return res.digest + + # Create multiple files concurrently + tasks = [create_file(i) for i in range(3)] + 
digests = await asyncio.gather(*tasks) + + # All should succeed + assert all(d is not None and d != "" for d in digests) + # All digests should be unique + assert len(set(digests)) == 3 diff --git a/weave/trace_server/file_storage.py b/weave/trace_server/file_storage.py index 49a96be5a462..2e59ebb13804 100644 --- a/weave/trace_server/file_storage.py +++ b/weave/trace_server/file_storage.py @@ -40,6 +40,7 @@ - `WF_FILE_STORAGE_AZURE_ACCOUNT_URL`: (optional) the account url for the azure account - defaults to `https://.blob.core.windows.net/` """ +import asyncio import logging from abc import abstractmethod from collections.abc import Callable @@ -460,8 +461,6 @@ async def store_async(self, uri: FileStorageURI, data: bytes) -> None: from aiobotocore.session import get_session except ImportError: # Fallback to sync client in thread pool if aiobotocore not available - import asyncio - await asyncio.to_thread(self._sync_client.store, uri, data) return @@ -492,8 +491,6 @@ async def read_async(self, uri: FileStorageURI) -> bytes: try: from aiobotocore.session import get_session except ImportError: - import asyncio - return await asyncio.to_thread(self._sync_client.read, uri) session = get_session() @@ -532,8 +529,6 @@ async def store_async(self, uri: FileStorageURI, data: bytes) -> None: assert uri.to_uri_str().startswith(self.base_uri.to_uri_str()) # GCS async requires gcloud-aio-storage; fall back to threaded sync - import asyncio - await asyncio.to_thread(self._sync_client.store, uri, data) async def read_async(self, uri: FileStorageURI) -> bytes: @@ -541,8 +536,6 @@ async def read_async(self, uri: FileStorageURI) -> bytes: assert isinstance(uri, GCSFileStorageURI) assert uri.to_uri_str().startswith(self.base_uri.to_uri_str()) - import asyncio - return await asyncio.to_thread(self._sync_client.read, uri) diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index 305477a36178..754ba534f884 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -1,5 +1,6 @@ # Sqlite Trace Server +import asyncio import datetime import hashlib import json @@ -1525,6 +1526,10 @@ def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: conn.commit() return tsi.FileCreateRes(digest=digest) + async def file_create_async(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: + """Async version - wraps sync in thread pool for SQLite.""" + return await asyncio.to_thread(self.file_create, req) + def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadRes: conn, cursor = get_conn_cursor(self.db_path) cursor.execute( diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index d6ae7248e15b..a2c72587bcff 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -1,6 +1,5 @@ -import asyncio import datetime -from collections.abc import Coroutine, Iterator +from collections.abc import Iterator from enum import Enum from typing import Any, Literal, Protocol @@ -2290,12 +2289,7 @@ def refs_read_batch(self, req: RefsReadBatchReq) -> RefsReadBatchRes: ... # File API def file_create(self, req: FileCreateReq) -> FileCreateRes: ... - - def file_create_async( - self, req: FileCreateReq - ) -> Coroutine[Any, Any, FileCreateRes]: - """Async version of file_create. 
Default wraps sync in thread pool.""" - return asyncio.to_thread(self.file_create, req) + async def file_create_async(self, req: FileCreateReq) -> FileCreateRes: ... def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes: ... def files_stats(self, req: FilesStatsReq) -> FilesStatsRes: ... diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py index 9a6127561902..78eed14e2b14 100644 --- a/weave/trace_server_bindings/remote_http_trace_server.py +++ b/weave/trace_server_bindings/remote_http_trace_server.py @@ -1,3 +1,4 @@ +import asyncio import datetime import io import logging @@ -762,6 +763,10 @@ def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: handle_response_error(r, "/files/create") return tsi.FileCreateRes.model_validate(r.json()) + async def file_create_async(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: + """Async version - wraps sync HTTP call in thread pool.""" + return await asyncio.to_thread(self.file_create, req) + @with_retry def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadRes: r = self.post( From a7b4b93a72500f531e44e0287d5f64ed3cbb7116 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Sun, 1 Feb 2026 12:03:52 -0800 Subject: [PATCH 3/4] chore: add aiobotocore dep, test: --- pyproject.toml | 1 + tests/trace/test_file_storage_async.py | 373 ++++++++++---------- tests/trace/test_server_file_storage.py | 440 +++++++++++------------- weave/trace_server/file_storage.py | 72 ++-- 4 files changed, 423 insertions(+), 463 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 13fc5934f28d..a3d8204e8328 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ trace_server = [ "ddtrace>=2.7.0", # BYOB - S3 "boto3>=1.34.0", + "aiobotocore>=3.1.1", # Async S3 support # BYOB - Azure "azure-storage-blob>=12.24.0,<12.26.0", # BYOB - GCP diff --git a/tests/trace/test_file_storage_async.py b/tests/trace/test_file_storage_async.py index 673abc5bb8ef..e584a42c18ff 100644 --- a/tests/trace/test_file_storage_async.py +++ b/tests/trace/test_file_storage_async.py @@ -24,115 +24,148 @@ S3FileStorageURI, ) +# ============================================================================= +# Fixtures +# ============================================================================= -class TestAsyncS3StorageClient: - """Unit tests for AsyncS3StorageClient.""" - - @pytest.fixture - def s3_uri(self): - return S3FileStorageURI(bucket="test-bucket", path="") - - @pytest.fixture - def s3_credentials(self): - return { - "access_key_id": "test-key", - "secret_access_key": "test-secret", - "region": "us-east-1", - "session_token": None, - "kms_key": None, - } - - @pytest.fixture - def mock_sync_client(self, s3_uri, s3_credentials): - """Create a mock sync S3 client.""" - client = mock.MagicMock(spec=S3StorageClient) - client.base_uri = s3_uri - return client - - @pytest.mark.asyncio - async def test_store_async_fallback_to_sync( - self, s3_uri, s3_credentials, mock_sync_client - ): - """Test that store_async falls back to sync when aiobotocore is not available.""" - async_client = AsyncS3StorageClient(s3_uri, s3_credentials, mock_sync_client) - - # Mock aiobotocore import to fail - with mock.patch.dict( - "sys.modules", {"aiobotocore": None, "aiobotocore.session": None} - ): - target_uri = s3_uri.with_path("test/file.txt") - test_data = b"test content" - - # This should fall back to sync client via to_thread - # We need to mock asyncio.to_thread since the 
import will fail - with mock.patch("asyncio.to_thread") as mock_to_thread: - mock_to_thread.return_value = None - await async_client.store_async(target_uri, test_data) - mock_to_thread.assert_called_once_with( - mock_sync_client.store, target_uri, test_data - ) - - @pytest.mark.asyncio - async def test_store_async_with_aiobotocore( - self, s3_uri, s3_credentials, mock_sync_client - ): - """Test store_async uses aiobotocore when available.""" - async_client = AsyncS3StorageClient(s3_uri, s3_credentials, mock_sync_client) - - # Create mock aiobotocore session and client - mock_aio_client = mock.AsyncMock() - mock_session = mock.MagicMock() - mock_session.create_client.return_value.__aenter__ = mock.AsyncMock( + +@pytest.fixture +def s3_uri(): + """S3 URI fixture for testing.""" + return S3FileStorageURI(bucket="test-bucket", path="") + + +@pytest.fixture +def s3_credentials(): + """AWS credentials fixture for testing.""" + return { + "access_key_id": "test-key", + "secret_access_key": "test-secret", + "region": "us-east-1", + "session_token": None, + "kms_key": None, + } + + +@pytest.fixture +def mock_sync_s3_client(s3_uri): + """Create a mock sync S3 client.""" + client = mock.MagicMock(spec=S3StorageClient) + client.base_uri = s3_uri + return client + + +@pytest.fixture +def gcs_uri(): + """GCS URI fixture for testing.""" + return GCSFileStorageURI(bucket="test-bucket", path="") + + +@pytest.fixture +def mock_sync_gcs_client(gcs_uri): + """Create a mock sync GCS client.""" + client = mock.MagicMock(spec=GCSStorageClient) + client.base_uri = gcs_uri + return client + + +@pytest.fixture +def mock_async_client(): + """Create a mock async storage client for store_in_bucket_async tests.""" + client = mock.AsyncMock(spec=AsyncFileStorageClient) + client.base_uri = S3FileStorageURI(bucket="test-bucket", path="") + client.base_uri.with_path = lambda p: S3FileStorageURI(bucket="test-bucket", path=p) + return client + + +# ============================================================================= +# AsyncS3StorageClient Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_s3_store_async_calls_aiobotocore( + s3_uri, s3_credentials, mock_sync_s3_client +): + """Test that S3 store_async uses aiobotocore client.""" + async_client = AsyncS3StorageClient(s3_uri, s3_credentials, mock_sync_s3_client) + + mock_aio_client = mock.AsyncMock() + mock_aio_client.put_object = mock.AsyncMock() + + with mock.patch.object(async_client, "_get_async_s3_client") as mock_get_client: + mock_get_client.return_value.__aenter__ = mock.AsyncMock( + return_value=mock_aio_client + ) + mock_get_client.return_value.__aexit__ = mock.AsyncMock() + + target_uri = s3_uri.with_path("test/file.txt") + await async_client.store_async(target_uri, b"test content") + + mock_get_client.assert_called_once() + + +@pytest.mark.asyncio +async def test_s3_read_async_calls_aiobotocore( + s3_uri, s3_credentials, mock_sync_s3_client +): + """Test that S3 read_async uses aiobotocore client.""" + async_client = AsyncS3StorageClient(s3_uri, s3_credentials, mock_sync_s3_client) + + mock_stream = mock.AsyncMock() + mock_stream.read = mock.AsyncMock(return_value=b"test content") + mock_aio_client = mock.AsyncMock() + mock_aio_client.get_object = mock.AsyncMock(return_value={"Body": mock_stream}) + mock_stream.__aenter__ = mock.AsyncMock(return_value=mock_stream) + mock_stream.__aexit__ = mock.AsyncMock() + + with mock.patch.object(async_client, "_get_async_s3_client") as 
mock_get_client: + mock_get_client.return_value.__aenter__ = mock.AsyncMock( return_value=mock_aio_client ) - mock_session.create_client.return_value.__aexit__ = mock.AsyncMock() + mock_get_client.return_value.__aexit__ = mock.AsyncMock() + + target_uri = s3_uri.with_path("test/file.txt") + await async_client.read_async(target_uri) + + mock_get_client.assert_called_once() - with mock.patch( - "weave.trace_server.file_storage.AsyncS3StorageClient.store_async" - ) as mock_store: - mock_store.return_value = None - target_uri = s3_uri.with_path("test/file.txt") - await mock_store(target_uri, b"test content") - mock_store.assert_called_once() +@pytest.mark.asyncio +async def test_s3_session_reused(s3_uri, s3_credentials, mock_sync_s3_client): + """Test that the aiobotocore session is reused across calls.""" + async_client = AsyncS3StorageClient(s3_uri, s3_credentials, mock_sync_s3_client) -class TestAsyncGCSStorageClient: - """Unit tests for AsyncGCSStorageClient.""" + # First call creates session + session1 = async_client._get_session() + # Second call reuses session + session2 = async_client._get_session() - @pytest.fixture - def gcs_uri(self): - return GCSFileStorageURI(bucket="test-bucket", path="") + assert session1 is session2 - @pytest.fixture - def mock_sync_gcs_client(self, gcs_uri): - """Create a mock sync GCS client.""" - client = mock.MagicMock(spec=GCSStorageClient) - client.base_uri = gcs_uri - return client - @pytest.mark.asyncio - async def test_store_async_uses_thread_pool(self, gcs_uri, mock_sync_gcs_client): - """Test that GCS async falls back to thread pool (no native async support).""" - async_client = AsyncGCSStorageClient(gcs_uri, None, mock_sync_gcs_client) +# ============================================================================= +# AsyncGCSStorageClient Tests +# ============================================================================= - target_uri = gcs_uri.with_path("test/file.txt") - test_data = b"test content" +@pytest.mark.asyncio +@pytest.mark.parametrize("operation", ["store", "read"]) +async def test_gcs_async_uses_thread_pool(gcs_uri, mock_sync_gcs_client, operation): + """Test that GCS async operations use thread pool wrapper.""" + async_client = AsyncGCSStorageClient(gcs_uri, None, mock_sync_gcs_client) + target_uri = gcs_uri.with_path("test/file.txt") + + if operation == "store": + mock_sync_gcs_client.store.return_value = None with mock.patch("asyncio.to_thread") as mock_to_thread: mock_to_thread.return_value = None - await async_client.store_async(target_uri, test_data) + await async_client.store_async(target_uri, b"test content") mock_to_thread.assert_called_once_with( - mock_sync_gcs_client.store, target_uri, test_data + mock_sync_gcs_client.store, target_uri, b"test content" ) - - @pytest.mark.asyncio - async def test_read_async_uses_thread_pool(self, gcs_uri, mock_sync_gcs_client): - """Test that GCS read_async falls back to thread pool.""" - async_client = AsyncGCSStorageClient(gcs_uri, None, mock_sync_gcs_client) + else: mock_sync_gcs_client.read.return_value = b"test content" - - target_uri = gcs_uri.with_path("test/file.txt") - with mock.patch("asyncio.to_thread") as mock_to_thread: mock_to_thread.return_value = b"test content" result = await async_client.read_async(target_uri) @@ -142,101 +175,83 @@ async def test_read_async_uses_thread_pool(self, gcs_uri, mock_sync_gcs_client): assert result == b"test content" -class TestStoreInBucketAsync: - """Tests for the store_in_bucket_async helper function.""" +# 
============================================================================= +# store_in_bucket_async Tests +# ============================================================================= - @pytest.fixture - def mock_async_client(self): - """Create a mock async storage client.""" - client = mock.AsyncMock(spec=AsyncFileStorageClient) - client.base_uri = S3FileStorageURI(bucket="test-bucket", path="") - client.base_uri.with_path = lambda p: S3FileStorageURI( - bucket="test-bucket", path=p - ) - return client - @pytest.mark.asyncio - async def test_store_in_bucket_async_success(self, mock_async_client): - """Test successful async bucket storage.""" - mock_async_client.store_async.return_value = None +@pytest.mark.asyncio +async def test_store_in_bucket_async_success(mock_async_client): + """Test successful async bucket storage returns correct URI.""" + mock_async_client.store_async.return_value = None - result = await store_in_bucket_async( - mock_async_client, "test/path/file.txt", b"content" - ) + result = await store_in_bucket_async( + mock_async_client, "test/path/file.txt", b"content" + ) - assert result.bucket == "test-bucket" - assert result.path == "test/path/file.txt" - mock_async_client.store_async.assert_called_once() + assert result.bucket == "test-bucket" + assert result.path == "test/path/file.txt" + mock_async_client.store_async.assert_called_once() - @pytest.mark.asyncio - async def test_store_in_bucket_async_failure(self, mock_async_client): - """Test async bucket storage failure raises FileStorageWriteError.""" - mock_async_client.store_async.side_effect = Exception("Storage failed") - with pytest.raises(FileStorageWriteError) as exc_info: - await store_in_bucket_async( - mock_async_client, "test/path/file.txt", b"content" - ) +@pytest.mark.asyncio +async def test_store_in_bucket_async_failure_raises_error(mock_async_client): + """Test async bucket storage failure raises FileStorageWriteError.""" + mock_async_client.store_async.side_effect = Exception("Storage failed") + + with pytest.raises(FileStorageWriteError) as exc_info: + await store_in_bucket_async(mock_async_client, "test/path/file.txt", b"content") + + assert "Failed to store file" in str(exc_info.value) + + +# ============================================================================= +# Concurrency Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_concurrent_operations_complete_in_parallel(): + """Test that multiple async operations run concurrently, not serially.""" + + async def mock_async_operation(delay: float, name: str) -> str: + await asyncio.sleep(delay) + return name + + # Run 3 operations concurrently, each taking 0.1s + tasks = [mock_async_operation(0.1, f"op{i}") for i in range(3)] - assert "Failed to store file" in str(exc_info.value) + start = time.monotonic() + results = await asyncio.gather(*tasks) + total_time = time.monotonic() - start + assert set(results) == {"op0", "op1", "op2"} + # If truly concurrent, total time should be ~0.1s, not 0.3s + assert total_time < 0.25, f"Operations took {total_time}s, expected < 0.25s" -class TestAsyncConcurrency: - """Tests to verify async operations don't block the event loop.""" - @pytest.mark.asyncio - async def test_concurrent_operations_complete(self): - """Test that multiple async operations can run concurrently.""" - call_times = [] +@pytest.mark.asyncio +async def test_event_loop_not_blocked(): + """Test that async operations don't block the event loop.""" + 
event_loop_blocked = False - async def mock_async_operation(delay: float, name: str) -> str: + async def check_event_loop(): + nonlocal event_loop_blocked + for _ in range(5): start = time.monotonic() - await asyncio.sleep(delay) - end = time.monotonic() - call_times.append((name, start, end)) - return name - - # Run 3 operations concurrently - tasks = [ - mock_async_operation(0.1, "op1"), - mock_async_operation(0.1, "op2"), - mock_async_operation(0.1, "op3"), - ] - - start = time.monotonic() - results = await asyncio.gather(*tasks) - total_time = time.monotonic() - start - - assert set(results) == {"op1", "op2", "op3"} - # If truly concurrent, total time should be ~0.1s, not 0.3s - assert total_time < 0.25, f"Operations took {total_time}s, expected < 0.25s" - - @pytest.mark.asyncio - async def test_blocking_operation_detection(self): - """Test that simulates detecting if operations would block.""" - event_loop_blocked = False - - async def check_event_loop(): - """Background task to check if event loop is responsive.""" - nonlocal event_loop_blocked - for _ in range(5): - start = time.monotonic() - await asyncio.sleep(0.01) - elapsed = time.monotonic() - start - # If sleep takes much longer than expected, loop was blocked - if elapsed > 0.05: - event_loop_blocked = True - break - - async def fast_async_operation(): - """A fast async operation that shouldn't block.""" - await asyncio.sleep(0.02) - return "done" - - # Run checker and operation concurrently - checker_task = asyncio.create_task(check_event_loop()) - result = await fast_async_operation() - await checker_task - - assert result == "done" - assert not event_loop_blocked, "Event loop was blocked during async operation" + await asyncio.sleep(0.01) + elapsed = time.monotonic() - start + if elapsed > 0.05: + event_loop_blocked = True + break + + async def fast_async_operation(): + await asyncio.sleep(0.02) + return "done" + + checker_task = asyncio.create_task(check_event_loop()) + result = await fast_async_operation() + await checker_task + + assert result == "done" + assert not event_loop_blocked, "Event loop was blocked during async operation" diff --git a/tests/trace/test_server_file_storage.py b/tests/trace/test_server_file_storage.py index 43420e4e08b7..84f4b9857eac 100644 --- a/tests/trace/test_server_file_storage.py +++ b/tests/trace/test_server_file_storage.py @@ -490,305 +490,264 @@ def mock_upload_fail(*args, **kwargs): # ============================================================================= -class TestAsyncFileCreate: - """Tests for the async file_create_async functionality.""" +@pytest.fixture +def aws_storage_env_async(): + """Setup AWS storage environment for async tests.""" + with mock.patch.dict( + os.environ, + { + "WF_FILE_STORAGE_AWS_ACCESS_KEY_ID": "test-key", + "WF_FILE_STORAGE_AWS_SECRET_ACCESS_KEY": "test-secret", + "WF_FILE_STORAGE_URI": f"s3://{TEST_BUCKET}", + "WF_FILE_STORAGE_PROJECT_ALLOW_LIST": "c2hhd24vdGVzdC1wcm9qZWN0", + }, + ): + yield - @pytest.mark.asyncio - async def test_async_file_create_basic(self, client: WeaveClient): - """Test basic async file creation works.""" - if client_is_sqlite(client): - pytest.skip("Skipping for SQLite - testing ClickHouse async path") - req = FileCreateReq( - project_id=client._project_id(), - name="async_test.txt", - content=TEST_CONTENT, - ) - res = await client.server.file_create_async(req) +@pytest.fixture +def gcp_storage_env_async(): + """Setup GCP storage environment for async tests.""" + with mock.patch.dict( + os.environ, + { + 
"WF_FILE_STORAGE_GCP_CREDENTIALS_JSON_B64": base64.b64encode( + b"""{ + "type": "service_account", + "project_id": "test-project", + "private_key_id": "test-key-id", + "private_key": "test-key", + "client_email": "test@test-project.iam.gserviceaccount.com", + "client_id": "test-client-id", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test@test-project.iam.gserviceaccount.com" + }""" + ).decode(), + "WF_FILE_STORAGE_URI": f"gs://{TEST_BUCKET}", + "WF_FILE_STORAGE_PROJECT_ALLOW_LIST": "c2hhd24vdGVzdC1wcm9qZWN0", + }, + ): + yield - assert res.digest is not None - assert res.digest != "" - # Verify content can be read back - file = client.server.file_content_read( - FileContentReadReq(project_id=client._project_id(), digest=res.digest) - ) - assert file.content == TEST_CONTENT +@pytest.fixture +def mock_gcp_credentials(): + """Mock GCP credentials to prevent authentication.""" + with mock.patch( + "google.oauth2.service_account.Credentials.from_service_account_info" + ) as mock_creds: + mock_creds.return_value = AnonymousCredentials() + yield - @pytest.mark.asyncio - async def test_async_file_create_consistency_with_sync(self, client: WeaveClient): - """Test that async and sync produce the same digest for same content.""" - if client_is_sqlite(client): - pytest.skip("Skipping for SQLite") - req = FileCreateReq( - project_id=client._project_id(), - name="consistency_test.txt", - content=b"consistent content for testing", - ) +@pytest.fixture +def mock_gcs_client(): + """Google Cloud Storage mock for tests.""" + mock_storage_client = mock.MagicMock() + mock_bucket = mock.MagicMock() + mock_blob = mock.MagicMock() - # Create via sync - sync_res = client.server.file_create(req) + mock_storage_client.bucket.return_value = mock_bucket + mock_bucket.blob.return_value = mock_blob - # Create via async (should be idempotent, same digest) - async_res = await client.server.file_create_async(req) + blob_data = {} - assert sync_res.digest == async_res.digest + def mock_upload_from_string(data, timeout=None, if_generation_match=None, **kwargs): + blob_name = mock_blob.name + if if_generation_match == 0 and blob_name in blob_data: + raise exceptions.PreconditionFailed("Object already exists") + blob_data[blob_name] = data - @pytest.mark.asyncio - async def test_async_file_create_concurrent(self, client: WeaveClient): - """Test concurrent async file creations don't block each other.""" - if client_is_sqlite(client): - pytest.skip("Skipping for SQLite") + def mock_download_as_bytes(timeout=None, **kwargs): + blob_name = mock_blob.name + return blob_data.get(blob_name, b"") - async def create_file(content: bytes, name: str) -> tuple[str, float]: - start = time.monotonic() - req = FileCreateReq( - project_id=client._project_id(), - name=name, - content=content, - ) - res = await client.server.file_create_async(req) - elapsed = time.monotonic() - start - return res.digest, elapsed + mock_blob.upload_from_string.side_effect = mock_upload_from_string + mock_blob.download_as_bytes.side_effect = mock_download_as_bytes - # Create multiple files concurrently - tasks = [ - create_file(f"content_{i}".encode(), f"file_{i}.txt") for i in range(5) - ] + with mock.patch("google.cloud.storage.Client", return_value=mock_storage_client): + yield mock_storage_client - start_total = time.monotonic() - results 
= await asyncio.gather(*tasks) - total_elapsed = time.monotonic() - start_total - # All should succeed - digests = [r[0] for r in results] - assert all(d is not None and d != "" for d in digests) +@pytest.fixture +def s3_mock(): + """Moto S3 mock that implements the S3 API.""" + with mock_aws(): + s3_client = boto3.client( + "s3", + aws_access_key_id="test-key", + aws_secret_access_key="test-secret", + region_name="us-east-1", + ) + s3_client.create_bucket(Bucket=TEST_BUCKET) + yield s3_client - # All digests should be unique (different content) - assert len(set(digests)) == 5 - # Concurrent execution should be faster than serial - # (sum of individual times should be greater than total time if truly concurrent) - individual_times = [r[1] for r in results] - sum_individual = sum(individual_times) - # Allow some overhead, but concurrent should show speedup - # This is a sanity check, not strict - just verify it's not completely serial - assert total_elapsed < sum_individual * 0.9 or total_elapsed < 1.0 +# Parametrized test data for async file creation with various content sizes +ASYNC_FILE_CREATE_PARAMS = [ + pytest.param(TEST_CONTENT, "basic.txt", id="basic_content"), + pytest.param(b"x" * 300000, "large.bin", id="large_file"), # 3x chunk size +] - @pytest.mark.asyncio - async def test_async_file_create_large_file(self, client: WeaveClient): - """Test async file creation with large files.""" - if client_is_sqlite(client): - pytest.skip("Skipping for SQLite") - # Create a file larger than FILE_CHUNK_SIZE - chunk_size = 100000 - num_chunks = 3 - large_content = b"x" * (chunk_size * num_chunks) +@pytest.mark.asyncio +@pytest.mark.parametrize(("content", "filename"), ASYNC_FILE_CREATE_PARAMS) +async def test_async_file_create(client: WeaveClient, content: bytes, filename: str): + """Test async file creation with various content sizes.""" + if client_is_sqlite(client): + pytest.skip("Skipping for SQLite - testing ClickHouse async path") - req = FileCreateReq( - project_id=client._project_id(), - name="large_async_test.bin", - content=large_content, - ) - res = await client.server.file_create_async(req) + req = FileCreateReq( + project_id=client._project_id(), + name=filename, + content=content, + ) + res = await client.server.file_create_async(req) - assert res.digest is not None + assert res.digest is not None + assert res.digest != "" - # Verify content - file = client.server.file_content_read( - FileContentReadReq(project_id=client._project_id(), digest=res.digest) - ) - assert file.content == large_content + # Verify content can be read back + file = client.server.file_content_read( + FileContentReadReq(project_id=client._project_id(), digest=res.digest) + ) + assert file.content == content -class TestAsyncS3Storage: - """Tests for async AWS S3 storage implementation.""" +@pytest.mark.asyncio +async def test_async_file_create_consistency_with_sync(client: WeaveClient): + """Test that async and sync produce identical digests for the same content.""" + if client_is_sqlite(client): + pytest.skip("Skipping for SQLite") - @pytest.fixture - def s3(self): - """Moto S3 mock that implements the S3 API.""" - with mock_aws(): - s3_client = boto3.client( - "s3", - aws_access_key_id="test-key", - aws_secret_access_key="test-secret", - region_name="us-east-1", - ) - s3_client.create_bucket(Bucket=TEST_BUCKET) - yield s3_client + req = FileCreateReq( + project_id=client._project_id(), + name="consistency_test.txt", + content=b"consistent content for testing", + ) - @pytest.fixture - def 
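For context, the s3_mock fixture above follows the standard moto pattern; a minimal standalone sketch (bucket name and key are placeholders, assuming moto >= 5 where mock_aws is the unified entry point):

import boto3
from moto import mock_aws

with mock_aws():
    s3 = boto3.client(
        "s3",
        region_name="us-east-1",
        aws_access_key_id="testing",        # fake credentials; moto never contacts AWS
        aws_secret_access_key="testing",
    )
    s3.create_bucket(Bucket="example-bucket")
    s3.put_object(Bucket="example-bucket", Key="files/abc123", Body=b"hello")
    body = s3.get_object(Bucket="example-bucket", Key="files/abc123")["Body"].read()
    assert body == b"hello"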
aws_storage_env_async(self): - """Setup AWS storage environment for async tests.""" - with mock.patch.dict( - os.environ, - { - "WF_FILE_STORAGE_AWS_ACCESS_KEY_ID": "test-key", - "WF_FILE_STORAGE_AWS_SECRET_ACCESS_KEY": "test-secret", - "WF_FILE_STORAGE_URI": f"s3://{TEST_BUCKET}", - "WF_FILE_STORAGE_PROJECT_ALLOW_LIST": "c2hhd24vdGVzdC1wcm9qZWN0", - }, - ): - yield + sync_res = client.server.file_create(req) + async_res = await client.server.file_create_async(req) + + assert sync_res.digest == async_res.digest - @pytest.mark.asyncio - @pytest.mark.usefixtures("aws_storage_env_async") - async def test_async_aws_storage(self, client: WeaveClient, s3): - """Test async file storage using AWS S3.""" - if client_is_sqlite(client): - pytest.skip("Skipping for SQLite") +@pytest.mark.asyncio +async def test_async_file_create_concurrent(client: WeaveClient): + """Test concurrent async file creations complete correctly with unique digests.""" + if client_is_sqlite(client): + pytest.skip("Skipping for SQLite") + + async def create_file(content: bytes, name: str) -> tuple[str, float]: + start = time.monotonic() req = FileCreateReq( project_id=client._project_id(), - name="async_s3_test.txt", - content=TEST_CONTENT, + name=name, + content=content, ) res = await client.server.file_create_async(req) + return res.digest, time.monotonic() - start - assert res.digest is not None - - # Verify the object exists in S3 - response = s3.list_objects_v2(Bucket=TEST_BUCKET) - assert "Contents" in response - assert len(response["Contents"]) >= 1 - - # Verify content via server - file = client.server.file_content_read( - FileContentReadReq(project_id=client._project_id(), digest=res.digest) - ) - assert file.content == TEST_CONTENT + tasks = [create_file(f"content_{i}".encode(), f"file_{i}.txt") for i in range(5)] + start_total = time.monotonic() + results = await asyncio.gather(*tasks) + total_elapsed = time.monotonic() - start_total -class TestAsyncGCSStorage: - """Tests for async Google Cloud Storage implementation.""" + digests = [r[0] for r in results] + assert all(d is not None and d != "" for d in digests) + assert len(set(digests)) == 5 # All unique - @pytest.fixture - def mock_gcp_credentials(self): - """Mock GCP credentials to prevent authentication.""" - with mock.patch( - "google.oauth2.service_account.Credentials.from_service_account_info" - ) as mock_creds: - from google.auth.credentials import AnonymousCredentials + # Sanity check: concurrent should show some speedup vs serial + individual_times = [r[1] for r in results] + assert total_elapsed < sum(individual_times) * 0.9 or total_elapsed < 1.0 - mock_creds.return_value = AnonymousCredentials() - yield - @pytest.fixture - def gcs_async(self): - """Google Cloud Storage mock for async tests.""" - mock_storage_client = mock.MagicMock() - mock_bucket = mock.MagicMock() - mock_blob = mock.MagicMock() +@pytest.mark.asyncio +@pytest.mark.usefixtures("aws_storage_env_async") +async def test_async_aws_storage(client: WeaveClient, s3_mock): + """Test async file storage using AWS S3.""" + if client_is_sqlite(client): + pytest.skip("Skipping for SQLite") - mock_storage_client.bucket.return_value = mock_bucket - mock_bucket.blob.return_value = mock_blob + req = FileCreateReq( + project_id=client._project_id(), + name="async_s3_test.txt", + content=TEST_CONTENT, + ) + res = await client.server.file_create_async(req) - blob_data = {} + assert res.digest is not None - def mock_upload_from_string( - data, timeout=None, if_generation_match=None, **kwargs - ): - 
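Note on the opaque WF_FILE_STORAGE_PROJECT_ALLOW_LIST value used by these fixtures: it appears to be a base64-encoded project path, which a quick sketch can confirm:

import base64

allow_list = "c2hhd24vdGVzdC1wcm9qZWN0"       # value copied from the fixtures above
print(base64.b64decode(allow_list).decode())  # -> shawn/test-project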
blob_name = mock_blob.name - if if_generation_match == 0 and blob_name in blob_data: - from google.api_core import exceptions + # Verify the object exists in S3 + response = s3_mock.list_objects_v2(Bucket=TEST_BUCKET) + assert "Contents" in response + assert len(response["Contents"]) >= 1 - raise exceptions.PreconditionFailed("Object already exists") - blob_data[blob_name] = data - - def mock_download_as_bytes(timeout=None, **kwargs): - blob_name = mock_blob.name - return blob_data.get(blob_name, b"") + # Verify content via server + file = client.server.file_content_read( + FileContentReadReq(project_id=client._project_id(), digest=res.digest) + ) + assert file.content == TEST_CONTENT - mock_blob.upload_from_string.side_effect = mock_upload_from_string - mock_blob.download_as_bytes.side_effect = mock_download_as_bytes - with mock.patch( - "google.cloud.storage.Client", return_value=mock_storage_client - ): - yield mock_storage_client - - @pytest.fixture - def gcp_storage_env_async(self): - """Setup GCP storage environment for async tests.""" - with mock.patch.dict( - os.environ, - { - "WF_FILE_STORAGE_GCP_CREDENTIALS_JSON_B64": base64.b64encode( - b"""{ - "type": "service_account", - "project_id": "test-project", - "private_key_id": "test-key-id", - "private_key": "test-key", - "client_email": "test@test-project.iam.gserviceaccount.com", - "client_id": "test-client-id", - "auth_uri": "https://accounts.google.com/o/oauth2/auth", - "token_uri": "https://oauth2.googleapis.com/token", - "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", - "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test@test-project.iam.gserviceaccount.com" - }""" - ).decode(), - "WF_FILE_STORAGE_URI": f"gs://{TEST_BUCKET}", - "WF_FILE_STORAGE_PROJECT_ALLOW_LIST": "c2hhd24vdGVzdC1wcm9qZWN0", - }, - ): - yield +@pytest.mark.asyncio +@pytest.mark.usefixtures("gcp_storage_env_async", "mock_gcp_credentials") +async def test_async_gcs_storage(client: WeaveClient, mock_gcs_client): + """Test async file storage using GCS.""" + if client_is_sqlite(client): + pytest.skip("Skipping for SQLite") - @pytest.mark.asyncio - @pytest.mark.usefixtures("gcp_storage_env_async", "mock_gcp_credentials") - async def test_async_gcs_storage(self, client: WeaveClient, gcs_async): - """Test async file storage using GCS.""" - if client_is_sqlite(client): - pytest.skip("Skipping for SQLite") + req = FileCreateReq( + project_id=client._project_id(), + name="async_gcs_test.txt", + content=TEST_CONTENT, + ) + res = await client.server.file_create_async(req) - req = FileCreateReq( - project_id=client._project_id(), - name="async_gcs_test.txt", - content=TEST_CONTENT, - ) - res = await client.server.file_create_async(req) + assert res.digest is not None - assert res.digest is not None + # Verify content via server + file = client.server.file_content_read( + FileContentReadReq(project_id=client._project_id(), digest=res.digest) + ) + assert file.content == TEST_CONTENT - # Verify content via server - file = client.server.file_content_read( - FileContentReadReq(project_id=client._project_id(), digest=res.digest) - ) - assert file.content == TEST_CONTENT +# Parametrized test data for SQLite async tests +SQLITE_ASYNC_PARAMS = [ + pytest.param(TEST_CONTENT, "sqlite_async_test.txt", 1, id="single_file"), + pytest.param(None, None, 3, id="concurrent_files"), # Special case for concurrent +] -class TestAsyncSQLiteStorage: - """Tests for async SQLite storage (thread pool wrapper).""" - @pytest.mark.asyncio - 
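The GCS mock above emulates create-if-absent semantics: if_generation_match=0 asks GCS to reject the upload when the object already exists. A tiny standalone model of that behaviour (RuntimeError stands in for google.api_core.exceptions.PreconditionFailed):

blob_data: dict[str, bytes] = {}

def upload(name: str, data: bytes, if_generation_match: int | None = None) -> None:
    # Generation 0 means "only create if the object does not exist yet".
    if if_generation_match == 0 and name in blob_data:
        raise RuntimeError("412 Precondition Failed: object already exists")
    blob_data[name] = data

upload("proj/deadbeef", b"payload", if_generation_match=0)
try:
    upload("proj/deadbeef", b"payload", if_generation_match=0)
except RuntimeError:
    pass  # the second create is rejected, so idempotent writers must tolerate this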
async def test_sqlite_async_wrapper(self, client: WeaveClient): - """Test that SQLite async wrapper works correctly.""" - if not client_is_sqlite(client): - pytest.skip("This test is specifically for SQLite") +@pytest.mark.asyncio +@pytest.mark.parametrize(("content", "filename", "count"), SQLITE_ASYNC_PARAMS) +async def test_sqlite_async(client: WeaveClient, content, filename, count): + """Test SQLite async operations: single file and concurrent file creation.""" + if not client_is_sqlite(client): + pytest.skip("This test is specifically for SQLite") + if count == 1: + # Single file test req = FileCreateReq( project_id=client._project_id(), - name="sqlite_async_test.txt", - content=TEST_CONTENT, + name=filename, + content=content, ) res = await client.server.file_create_async(req) assert res.digest is not None assert res.digest != "" - # Verify content file = client.server.file_content_read( FileContentReadReq(project_id=client._project_id(), digest=res.digest) ) - assert file.content == TEST_CONTENT - - @pytest.mark.asyncio - async def test_sqlite_async_concurrent(self, client: WeaveClient): - """Test concurrent SQLite async operations.""" - if not client_is_sqlite(client): - pytest.skip("This test is specifically for SQLite") - + assert file.content == content + else: + # Concurrent files test async def create_file(i: int) -> str: req = FileCreateReq( project_id=client._project_id(), @@ -798,11 +757,8 @@ async def create_file(i: int) -> str: res = await client.server.file_create_async(req) return res.digest - # Create multiple files concurrently - tasks = [create_file(i) for i in range(3)] + tasks = [create_file(i) for i in range(count)] digests = await asyncio.gather(*tasks) - # All should succeed assert all(d is not None and d != "" for d in digests) - # All digests should be unique - assert len(set(digests)) == 3 + assert len(set(digests)) == count diff --git a/weave/trace_server/file_storage.py b/weave/trace_server/file_storage.py index 2e59ebb13804..56e7b13d4db9 100644 --- a/weave/trace_server/file_storage.py +++ b/weave/trace_server/file_storage.py @@ -43,12 +43,16 @@ import asyncio import logging from abc import abstractmethod -from collections.abc import Callable +from collections.abc import AsyncIterator, Callable +from contextlib import asynccontextmanager from typing import Any, cast import boto3 +from aiobotocore.session import AioSession +from aiobotocore.session import get_session as get_aio_session from azure.core.exceptions import HttpResponseError from azure.storage.blob import BlobServiceClient +from azure.storage.blob.aio import BlobServiceClient as AsyncBlobServiceClient from botocore.config import Config from botocore.exceptions import ClientError from google.api_core import exceptions as gcp_exceptions @@ -450,21 +454,18 @@ def __init__( assert isinstance(base_uri, S3FileStorageURI) self._credentials = credentials self._kms_key = credentials.get("kms_key") - - async def store_async(self, uri: FileStorageURI, data: bytes) -> None: - """Store data in S3 bucket asynchronously using aiobotocore.""" - assert isinstance(uri, S3FileStorageURI) - assert uri.to_uri_str().startswith(self.base_uri.to_uri_str()) - - # Import aiobotocore lazily to avoid hard dependency - try: - from aiobotocore.session import get_session - except ImportError: - # Fallback to sync client in thread pool if aiobotocore not available - await asyncio.to_thread(self._sync_client.store, uri, data) - return - - session = get_session() + self._session: AioSession | None = None + + def _get_session(self) 
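The SQLite tests above exercise a thread-pool wrapper rather than a native async driver. A minimal sketch of that pattern (the function names are illustrative, not the Weave implementation):

import asyncio
import hashlib

def blocking_file_create(content: bytes) -> str:
    return hashlib.sha256(content).hexdigest()  # stand-in for the synchronous SQLite insert

async def file_create_async(content: bytes) -> str:
    # Run the blocking implementation in a worker thread so the event loop stays free.
    return await asyncio.to_thread(blocking_file_create, content)

print(asyncio.run(file_create_async(b"hello")))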
-> AioSession: + """Get or create an aiobotocore session (reused across calls).""" + if self._session is None: + self._session = get_aio_session() + return self._session + + @asynccontextmanager + async def _get_async_s3_client(self) -> AsyncIterator[Any]: + """Context manager that yields an async S3 client.""" + session = self._get_session() async with session.create_client( "s3", region_name=self._credentials.get("region"), @@ -472,6 +473,14 @@ async def store_async(self, uri: FileStorageURI, data: bytes) -> None: aws_secret_access_key=self._credentials.get("secret_access_key"), aws_session_token=self._credentials.get("session_token"), ) as client: + yield client + + async def store_async(self, uri: FileStorageURI, data: bytes) -> None: + """Store data in S3 bucket asynchronously using aiobotocore.""" + assert isinstance(uri, S3FileStorageURI) + assert uri.to_uri_str().startswith(self.base_uri.to_uri_str()) + + async with self._get_async_s3_client() as client: put_object_params: dict[str, Any] = { "Bucket": uri.bucket, "Key": uri.path, @@ -488,26 +497,14 @@ async def read_async(self, uri: FileStorageURI) -> bytes: assert isinstance(uri, S3FileStorageURI) assert uri.to_uri_str().startswith(self.base_uri.to_uri_str()) - try: - from aiobotocore.session import get_session - except ImportError: - return await asyncio.to_thread(self._sync_client.read, uri) - - session = get_session() - async with session.create_client( - "s3", - region_name=self._credentials.get("region"), - aws_access_key_id=self._credentials.get("access_key_id"), - aws_secret_access_key=self._credentials.get("secret_access_key"), - aws_session_token=self._credentials.get("session_token"), - ) as client: + async with self._get_async_s3_client() as client: response = await client.get_object(Bucket=uri.bucket, Key=uri.path) async with response["Body"] as stream: return await stream.read() class AsyncGCSStorageClient(AsyncFileStorageClient): - """Async Google Cloud Storage implementation.""" + """Async Google Cloud Storage implementation using thread pool wrapper.""" def __init__( self, @@ -520,22 +517,15 @@ def __init__( self._credentials = credentials async def store_async(self, uri: FileStorageURI, data: bytes) -> None: - """Store data in GCS bucket asynchronously. - - Falls back to sync client in thread pool since gcloud-aio-storage - is not a standard dependency. 
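For reference, the put_object / get_object round trip the async S3 client performs looks roughly like the following standalone aiobotocore sketch (bucket, key, and region are placeholders; real credentials come from the environment):

import asyncio
from aiobotocore.session import get_session

async def round_trip() -> bytes:
    session = get_session()
    async with session.create_client("s3", region_name="us-east-1") as client:
        await client.put_object(Bucket="example-bucket", Key="files/abc", Body=b"data")
        response = await client.get_object(Bucket="example-bucket", Key="files/abc")
        async with response["Body"] as stream:
            return await stream.read()

# asyncio.run(round_trip())  # needs real credentials and an existing bucket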
- """ + """Store data in GCS bucket asynchronously.""" assert isinstance(uri, GCSFileStorageURI) assert uri.to_uri_str().startswith(self.base_uri.to_uri_str()) - - # GCS async requires gcloud-aio-storage; fall back to threaded sync await asyncio.to_thread(self._sync_client.store, uri, data) async def read_async(self, uri: FileStorageURI) -> bytes: """Read data from GCS bucket asynchronously.""" assert isinstance(uri, GCSFileStorageURI) assert uri.to_uri_str().startswith(self.base_uri.to_uri_str()) - return await asyncio.to_thread(self._sync_client.read, uri) @@ -552,10 +542,8 @@ def __init__( assert isinstance(base_uri, AzureFileStorageURI) self._credentials = credentials - async def _get_async_client(self, account: str) -> Any: + async def _get_async_client(self, account: str) -> AsyncBlobServiceClient: """Create async Azure client based on available credentials.""" - from azure.storage.blob.aio import BlobServiceClient as AsyncBlobServiceClient - if "connection_string" in self._credentials: connection_creds = cast(AzureConnectionCredentials, self._credentials) return AsyncBlobServiceClient.from_connection_string( @@ -641,5 +629,5 @@ def maybe_get_async_storage_client_from_env() -> AsyncFileStorageClient | None: return AsyncAzureStorageClient(parsed_uri, credentials, sync_client) else: raise NotImplementedError( - f"Async storage client for URI type {type(file_storage_uri)} not supported" + f"Async storage client for URI type {type(parsed_uri)} not supported" ) From 29d96f1a2acd81d1c0549bbadb2a0f3e54d622a7 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Sun, 1 Feb 2026 12:50:12 -0800 Subject: [PATCH 4/4] review comments --- .../clickhouse_trace_server_batched.py | 149 ++++++++++-------- weave/trace_server/file_storage.py | 80 +++++++++- 2 files changed, 159 insertions(+), 70 deletions(-) diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 75cc9fc96f89..dd3d9e33cbad 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -322,6 +322,15 @@ async def _get_async_ch_client(self) -> clickhouse_connect.driver.AsyncClient: ) return self._async_ch_client + async def close_async_ch_client(self) -> None: + """Close the async ClickHouse client if it exists. + + Should be called during shutdown to release resources. 
+ """ + if self._async_ch_client is not None: + await self._async_ch_client.close() + self._async_ch_client = None + @property def kafka_producer(self) -> KafkaProducer: if self._kafka_producer is not None: @@ -4735,87 +4744,99 @@ def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: async def file_create_async(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: """Async version of file_create using async storage and ClickHouse clients.""" - digest = bytes_digest(req.content) - use_file_storage = self._should_use_file_storage_for_writes(req.project_id) - async_client = self.async_file_storage_client - - if async_client is not None and use_file_storage: - try: - await self._file_create_bucket_async(req, digest, async_client) - except FileStorageWriteError: + with ddtrace.tracer.trace( + name="clickhouse_trace_server_batched.file_create_async" + ): + digest = bytes_digest(req.content) + use_file_storage = self._should_use_file_storage_for_writes(req.project_id) + async_client = self.async_file_storage_client + + if async_client is not None and use_file_storage: + try: + await self._file_create_bucket_async(req, digest, async_client) + except FileStorageWriteError: + await self._file_create_clickhouse_async(req, digest) + else: await self._file_create_clickhouse_async(req, digest) - else: - await self._file_create_clickhouse_async(req, digest) - set_root_span_dd_tags({"write_bytes": len(req.content)}) - return tsi.FileCreateRes(digest=digest) + set_root_span_dd_tags({"write_bytes": len(req.content)}) + return tsi.FileCreateRes(digest=digest) async def _file_create_clickhouse_async( self, req: tsi.FileCreateReq, digest: str ) -> None: """Async version of _file_create_clickhouse.""" - set_root_span_dd_tags({"storage_provider": "clickhouse"}) - chunks = [ - req.content[i : i + ch_settings.FILE_CHUNK_SIZE] - for i in range(0, len(req.content), ch_settings.FILE_CHUNK_SIZE) - ] - await self._insert_file_chunks_async( - [ - FileChunkCreateCHInsertable( - project_id=req.project_id, - digest=digest, - chunk_index=i, - n_chunks=len(chunks), - name=req.name, - val_bytes=chunk, - bytes_stored=len(chunk), - file_storage_uri=None, - ) - for i, chunk in enumerate(chunks) + with ddtrace.tracer.trace( + name="clickhouse_trace_server_batched._file_create_clickhouse_async" + ): + set_root_span_dd_tags({"storage_provider": "clickhouse"}) + chunks = [ + req.content[i : i + ch_settings.FILE_CHUNK_SIZE] + for i in range(0, len(req.content), ch_settings.FILE_CHUNK_SIZE) ] - ) + await self._insert_file_chunks_async( + [ + FileChunkCreateCHInsertable( + project_id=req.project_id, + digest=digest, + chunk_index=i, + n_chunks=len(chunks), + name=req.name, + val_bytes=chunk, + bytes_stored=len(chunk), + file_storage_uri=None, + ) + for i, chunk in enumerate(chunks) + ] + ) async def _file_create_bucket_async( self, req: tsi.FileCreateReq, digest: str, client: AsyncFileStorageClient ) -> None: """Async version of _file_create_bucket.""" - set_root_span_dd_tags({"storage_provider": "bucket"}) - target_file_storage_uri = await store_in_bucket_async( - client, key_for_project_digest(req.project_id, digest), req.content - ) - await self._insert_file_chunks_async( - [ - FileChunkCreateCHInsertable( - project_id=req.project_id, - digest=digest, - chunk_index=0, - n_chunks=1, - name=req.name, - val_bytes=b"", - bytes_stored=len(req.content), - file_storage_uri=target_file_storage_uri.to_uri_str(), - ) - ] - ) + with ddtrace.tracer.trace( + name="clickhouse_trace_server_batched._file_create_bucket_async" + ): + 
set_root_span_dd_tags({"storage_provider": "bucket"}) + target_file_storage_uri = await store_in_bucket_async( + client, key_for_project_digest(req.project_id, digest), req.content + ) + await self._insert_file_chunks_async( + [ + FileChunkCreateCHInsertable( + project_id=req.project_id, + digest=digest, + chunk_index=0, + n_chunks=1, + name=req.name, + val_bytes=b"", + bytes_stored=len(req.content), + file_storage_uri=target_file_storage_uri.to_uri_str(), + ) + ] + ) async def _insert_file_chunks_async( self, file_chunks: list[FileChunkCreateCHInsertable] ) -> None: """Async version of _insert_file_chunks using async ClickHouse client.""" - data = [] - for chunk in file_chunks: - chunk_dump = chunk.model_dump() - row = [] - for col in ALL_FILE_CHUNK_INSERT_COLUMNS: - row.append(chunk_dump.get(col, None)) - data.append(row) - - if data: - client = await self._get_async_ch_client() - await client.insert( - "files", - data=data, - column_names=ALL_FILE_CHUNK_INSERT_COLUMNS, - ) + with ddtrace.tracer.trace( + name="clickhouse_trace_server_batched._insert_file_chunks_async" + ): + data = [] + for chunk in file_chunks: + chunk_dump = chunk.model_dump() + row = [] + for col in ALL_FILE_CHUNK_INSERT_COLUMNS: + row.append(chunk_dump.get(col, None)) + data.append(row) + + if data: + client = await self._get_async_ch_client() + await client.insert( + "files", + data=data, + column_names=ALL_FILE_CHUNK_INSERT_COLUMNS, + ) @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._file_create_clickhouse") def _file_create_clickhouse(self, req: tsi.FileCreateReq, digest: str) -> None: diff --git a/weave/trace_server/file_storage.py b/weave/trace_server/file_storage.py index 56e7b13d4db9..c3d58709b0b0 100644 --- a/weave/trace_server/file_storage.py +++ b/weave/trace_server/file_storage.py @@ -135,8 +135,8 @@ def store_in_bucket( client: FileStorageClient, path: str, data: bytes ) -> FileStorageURI: """Store a file in a storage bucket.""" + target_file_storage_uri = client.base_uri.with_path(path) try: - target_file_storage_uri = client.base_uri.with_path(path) client.store(target_file_storage_uri, data) except Exception as e: logger.exception("Failed to store file at %s", target_file_storage_uri) @@ -418,15 +418,69 @@ def maybe_get_storage_client_from_env() -> FileStorageClient | None: # ============================================================================= # Async Storage Clients # ============================================================================= +# +# Async implementations use different strategies based on library support: +# - S3: Native async via aiobotocore (best performance) +# - GCS: Thread pool wrapper (no mature async library available) +# - Azure: Native async via azure-storage-blob aio support + + +def create_async_retry_decorator(operation_name: str) -> Callable[[Any], Any]: + """Creates an async retry decorator with consistent retry policy and special 429 handling. + + Uses the same retry logic as the sync version to ensure consistent error handling. 
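_insert_file_chunks_async turns pydantic rows into column-ordered lists before handing them to client.insert. A simplified standalone sketch (the model and column list are stand-ins for FileChunkCreateCHInsertable and ALL_FILE_CHUNK_INSERT_COLUMNS):

from pydantic import BaseModel

class FileChunkRow(BaseModel):
    project_id: str
    digest: str
    chunk_index: int
    n_chunks: int
    val_bytes: bytes

COLUMNS = ["project_id", "digest", "chunk_index", "n_chunks", "val_bytes"]

rows = [FileChunkRow(project_id="p", digest="d", chunk_index=0, n_chunks=1, val_bytes=b"x")]
data = [[row.model_dump().get(col) for col in COLUMNS] for row in rows]
assert data == [["p", "d", 0, 1, b"x"]]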
+ """ + + def after_retry(retry_state: RetryCallState) -> None: + if retry_state.attempt_number > 1: + logger.info( + "%s: Attempt %d/%d after %.2f seconds", + operation_name, + retry_state.attempt_number, + RETRY_MAX_ATTEMPTS, + retry_state.seconds_since_start, + ) + + def create_wait_strategy(retry_state: RetryCallState) -> float: + """Create wait strategy that uses jitter for rate limit errors.""" + if retry_state.outcome and retry_state.outcome.failed: + exception = retry_state.outcome.exception() + if exception and _is_rate_limit_error(exception): + # Use random exponential backoff with jitter for rate limiting + return wait_random_exponential( + multiplier=RETRY_MIN_WAIT, max=RETRY_MAX_WAIT + )(retry_state) + # Use regular exponential backoff for other errors + return wait_exponential(multiplier=RETRY_MIN_WAIT, max=RETRY_MAX_WAIT)( + retry_state + ) + + return retry( + stop=stop_after_attempt(RETRY_MAX_ATTEMPTS), + wait=create_wait_strategy, + reraise=True, + before_sleep=before_sleep_log(logger, logging.DEBUG), + after=after_retry, + ) class AsyncFileStorageClient: - """Abstract base class for async cloud storage operations.""" + """Abstract base class for async cloud storage operations. + + Follows the same pattern as FileStorageClient - uses @abstractmethod + without ABC inheritance for interface definition. + """ base_uri: FileStorageURI _sync_client: FileStorageClient def __init__(self, base_uri: FileStorageURI, sync_client: FileStorageClient): + """Initialize async storage client. + + Args: + base_uri: The base URI for storage operations. + sync_client: A sync client used for fallback or thread pool operations. + """ self.base_uri = base_uri self._sync_client = sync_client @@ -475,6 +529,7 @@ async def _get_async_s3_client(self) -> AsyncIterator[Any]: ) as client: yield client + @create_async_retry_decorator("async_s3_storage") async def store_async(self, uri: FileStorageURI, data: bytes) -> None: """Store data in S3 bucket asynchronously using aiobotocore.""" assert isinstance(uri, S3FileStorageURI) @@ -492,6 +547,7 @@ async def store_async(self, uri: FileStorageURI, data: bytes) -> None: await client.put_object(**put_object_params) + @create_async_retry_decorator("async_s3_read") async def read_async(self, uri: FileStorageURI) -> bytes: """Read data from S3 bucket asynchronously.""" assert isinstance(uri, S3FileStorageURI) @@ -504,7 +560,11 @@ async def read_async(self, uri: FileStorageURI) -> bytes: class AsyncGCSStorageClient(AsyncFileStorageClient): - """Async Google Cloud Storage implementation using thread pool wrapper.""" + """Async Google Cloud Storage implementation using thread pool wrapper. + + GCS does not have a mature async library, so we wrap sync operations + in asyncio.to_thread. The sync client already has retry logic applied. + """ def __init__( self, @@ -517,13 +577,19 @@ def __init__( self._credentials = credentials async def store_async(self, uri: FileStorageURI, data: bytes) -> None: - """Store data in GCS bucket asynchronously.""" + """Store data in GCS bucket asynchronously. + + Delegates to sync client which has retry logic applied. + """ assert isinstance(uri, GCSFileStorageURI) assert uri.to_uri_str().startswith(self.base_uri.to_uri_str()) await asyncio.to_thread(self._sync_client.store, uri, data) async def read_async(self, uri: FileStorageURI) -> bytes: - """Read data from GCS bucket asynchronously.""" + """Read data from GCS bucket asynchronously. + + Delegates to sync client which has retry logic applied. 
+ """ assert isinstance(uri, GCSFileStorageURI) assert uri.to_uri_str().startswith(self.base_uri.to_uri_str()) return await asyncio.to_thread(self._sync_client.read, uri) @@ -564,6 +630,7 @@ async def _get_async_client(self, account: str) -> AsyncBlobServiceClient: read_timeout=DEFAULT_READ_TIMEOUT, ) + @create_async_retry_decorator("async_azure_storage") async def store_async(self, uri: FileStorageURI, data: bytes) -> None: """Store data in Azure container asynchronously.""" assert isinstance(uri, AzureFileStorageURI) @@ -574,6 +641,7 @@ async def store_async(self, uri: FileStorageURI, data: bytes) -> None: blob_client = container_client.get_blob_client(uri.path) await blob_client.upload_blob(data, overwrite=True) + @create_async_retry_decorator("async_azure_read") async def read_async(self, uri: FileStorageURI) -> bytes: """Read data from Azure container asynchronously.""" assert isinstance(uri, AzureFileStorageURI) @@ -590,8 +658,8 @@ async def store_in_bucket_async( client: AsyncFileStorageClient, path: str, data: bytes ) -> FileStorageURI: """Store a file in a storage bucket asynchronously.""" + target_file_storage_uri = client.base_uri.with_path(path) try: - target_file_storage_uri = client.base_uri.with_path(path) await client.store_async(target_file_storage_uri, data) except Exception as e: logger.exception("Failed to store file at %s", target_file_storage_uri)