diff --git a/pyproject.toml b/pyproject.toml
index 13fc5934f28d..a3d8204e8328 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -72,6 +72,7 @@ trace_server = [
     "ddtrace>=2.7.0",
     # BYOB - S3
     "boto3>=1.34.0",
+    "aiobotocore>=3.1.1",  # Async S3 support
     # BYOB - Azure
     "azure-storage-blob>=12.24.0,<12.26.0",
     # BYOB - GCP
diff --git a/tests/trace/test_file_storage_async.py b/tests/trace/test_file_storage_async.py
new file mode 100644
index 000000000000..e584a42c18ff
--- /dev/null
+++ b/tests/trace/test_file_storage_async.py
@@ -0,0 +1,257 @@
+"""Unit tests for async file storage client implementations.
+
+This module tests the AsyncFileStorageClient implementations directly,
+independent of the trace server layer.
+"""
+
+import asyncio
+import time
+from unittest import mock
+
+import pytest
+
+from weave.trace_server.file_storage import (
+    AsyncFileStorageClient,
+    AsyncGCSStorageClient,
+    AsyncS3StorageClient,
+    FileStorageWriteError,
+    GCSStorageClient,
+    S3StorageClient,
+    store_in_bucket_async,
+)
+from weave.trace_server.file_storage_uris import (
+    GCSFileStorageURI,
+    S3FileStorageURI,
+)
+
+# =============================================================================
+# Fixtures
+# =============================================================================
+
+
+@pytest.fixture
+def s3_uri():
+    """S3 URI fixture for testing."""
+    return S3FileStorageURI(bucket="test-bucket", path="")
+
+
+@pytest.fixture
+def s3_credentials():
+    """AWS credentials fixture for testing."""
+    return {
+        "access_key_id": "test-key",
+        "secret_access_key": "test-secret",
+        "region": "us-east-1",
+        "session_token": None,
+        "kms_key": None,
+    }
+
+
+@pytest.fixture
+def mock_sync_s3_client(s3_uri):
+    """Create a mock sync S3 client."""
+    client = mock.MagicMock(spec=S3StorageClient)
+    client.base_uri = s3_uri
+    return client
+
+
+@pytest.fixture
+def gcs_uri():
+    """GCS URI fixture for testing."""
+    return GCSFileStorageURI(bucket="test-bucket", path="")
+
+
+@pytest.fixture
+def mock_sync_gcs_client(gcs_uri):
+    """Create a mock sync GCS client."""
+    client = mock.MagicMock(spec=GCSStorageClient)
+    client.base_uri = gcs_uri
+    return client
+
+
+@pytest.fixture
+def mock_async_client():
+    """Create a mock async storage client for store_in_bucket_async tests."""
+    client = mock.AsyncMock(spec=AsyncFileStorageClient)
+    client.base_uri = S3FileStorageURI(bucket="test-bucket", path="")
+    client.base_uri.with_path = lambda p: S3FileStorageURI(bucket="test-bucket", path=p)
+    return client
+
+
+# =============================================================================
+# AsyncS3StorageClient Tests
+# =============================================================================
+
+
+@pytest.mark.asyncio
+async def test_s3_store_async_calls_aiobotocore(
+    s3_uri, s3_credentials, mock_sync_s3_client
+):
+    """Test that S3 store_async uses aiobotocore client."""
+    async_client = AsyncS3StorageClient(s3_uri, s3_credentials, mock_sync_s3_client)
+
+    mock_aio_client = mock.AsyncMock()
+    mock_aio_client.put_object = mock.AsyncMock()
+
+    with mock.patch.object(async_client, "_get_async_s3_client") as mock_get_client:
+        mock_get_client.return_value.__aenter__ = mock.AsyncMock(
+            return_value=mock_aio_client
+        )
+        mock_get_client.return_value.__aexit__ = mock.AsyncMock()
+
+        target_uri = s3_uri.with_path("test/file.txt")
+        await async_client.store_async(target_uri, b"test content")
+
+        mock_get_client.assert_called_once()
+
+
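For readers new to aiobotocore, the pattern the S3 tests above and below exercise is roughly the following (a minimal illustrative sketch, not part of this change; the bucket and key names are placeholders):

    import asyncio
    from aiobotocore.session import get_session

    async def put_example() -> None:
        session = get_session()  # reusable session, analogous to _get_session()
        async with session.create_client("s3", region_name="us-east-1") as s3:
            await s3.put_object(Bucket="example-bucket", Key="example.txt", Body=b"hello")

    asyncio.run(put_example())

The production client wraps exactly this shape: a cached session, a create_client context manager, and awaited put_object/get_object calls.
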
+@pytest.mark.asyncio
+async def test_s3_read_async_calls_aiobotocore(
+    s3_uri, s3_credentials, mock_sync_s3_client
+):
+    """Test that S3 read_async uses aiobotocore client."""
+    async_client = AsyncS3StorageClient(s3_uri, s3_credentials, mock_sync_s3_client)
+
+    mock_stream = mock.AsyncMock()
+    mock_stream.read = mock.AsyncMock(return_value=b"test content")
+    mock_aio_client = mock.AsyncMock()
+    mock_aio_client.get_object = mock.AsyncMock(return_value={"Body": mock_stream})
+    mock_stream.__aenter__ = mock.AsyncMock(return_value=mock_stream)
+    mock_stream.__aexit__ = mock.AsyncMock()
+
+    with mock.patch.object(async_client, "_get_async_s3_client") as mock_get_client:
+        mock_get_client.return_value.__aenter__ = mock.AsyncMock(
+            return_value=mock_aio_client
+        )
+        mock_get_client.return_value.__aexit__ = mock.AsyncMock()
+
+        target_uri = s3_uri.with_path("test/file.txt")
+        await async_client.read_async(target_uri)
+
+        mock_get_client.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_s3_session_reused(s3_uri, s3_credentials, mock_sync_s3_client):
+    """Test that the aiobotocore session is reused across calls."""
+    async_client = AsyncS3StorageClient(s3_uri, s3_credentials, mock_sync_s3_client)
+
+    # First call creates session
+    session1 = async_client._get_session()
+    # Second call reuses session
+    session2 = async_client._get_session()
+
+    assert session1 is session2
+
+
+# =============================================================================
+# AsyncGCSStorageClient Tests
+# =============================================================================
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("operation", ["store", "read"])
+async def test_gcs_async_uses_thread_pool(gcs_uri, mock_sync_gcs_client, operation):
+    """Test that GCS async operations use thread pool wrapper."""
+    async_client = AsyncGCSStorageClient(gcs_uri, None, mock_sync_gcs_client)
+    target_uri = gcs_uri.with_path("test/file.txt")
+
+    if operation == "store":
+        mock_sync_gcs_client.store.return_value = None
+        with mock.patch("asyncio.to_thread") as mock_to_thread:
+            mock_to_thread.return_value = None
+            await async_client.store_async(target_uri, b"test content")
+            mock_to_thread.assert_called_once_with(
+                mock_sync_gcs_client.store, target_uri, b"test content"
+            )
+    else:
+        mock_sync_gcs_client.read.return_value = b"test content"
+        with mock.patch("asyncio.to_thread") as mock_to_thread:
+            mock_to_thread.return_value = b"test content"
+            result = await async_client.read_async(target_uri)
+            mock_to_thread.assert_called_once_with(
+                mock_sync_gcs_client.read, target_uri
+            )
+            assert result == b"test content"
+
+
+# =============================================================================
+# store_in_bucket_async Tests
+# =============================================================================
+
+
+@pytest.mark.asyncio
+async def test_store_in_bucket_async_success(mock_async_client):
+    """Test successful async bucket storage returns correct URI."""
+    mock_async_client.store_async.return_value = None
+
+    result = await store_in_bucket_async(
+        mock_async_client, "test/path/file.txt", b"content"
+    )
+
+    assert result.bucket == "test-bucket"
+    assert result.path == "test/path/file.txt"
+    mock_async_client.store_async.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_store_in_bucket_async_failure_raises_error(mock_async_client):
+    """Test async bucket storage failure raises FileStorageWriteError."""
+    mock_async_client.store_async.side_effect = Exception("Storage failed")
+
+    with pytest.raises(FileStorageWriteError) as exc_info:
+        await store_in_bucket_async(mock_async_client, "test/path/file.txt", b"content")
+
+    assert "Failed to store file" in str(exc_info.value)
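The concurrency tests that follow rely on a simple property: awaits must yield to the event loop so independent operations can overlap. The same property is what makes the thread-pool-backed GCS client viable, since asyncio.to_thread moves the blocking call off the loop and gather() can still overlap several of them (an illustrative sketch using only the standard library, not part of this change):

    import asyncio
    import time

    def blocking_read() -> bytes:
        time.sleep(0.1)  # stand-in for a synchronous GCS call
        return b"data"

    async def main() -> None:
        start = time.monotonic()
        results = await asyncio.gather(*(asyncio.to_thread(blocking_read) for _ in range(3)))
        assert len(results) == 3
        assert time.monotonic() - start < 0.25  # roughly 0.1s if the calls overlapped

    asyncio.run(main())
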
+
+
+# =============================================================================
+# Concurrency Tests
+# =============================================================================
+
+
+@pytest.mark.asyncio
+async def test_concurrent_operations_complete_in_parallel():
+    """Test that multiple async operations run concurrently, not serially."""
+
+    async def mock_async_operation(delay: float, name: str) -> str:
+        await asyncio.sleep(delay)
+        return name
+
+    # Run 3 operations concurrently, each taking 0.1s
+    tasks = [mock_async_operation(0.1, f"op{i}") for i in range(3)]
+
+    start = time.monotonic()
+    results = await asyncio.gather(*tasks)
+    total_time = time.monotonic() - start
+
+    assert set(results) == {"op0", "op1", "op2"}
+    # If truly concurrent, total time should be ~0.1s, not 0.3s
+    assert total_time < 0.25, f"Operations took {total_time}s, expected < 0.25s"
+
+
+@pytest.mark.asyncio
+async def test_event_loop_not_blocked():
+    """Test that async operations don't block the event loop."""
+    event_loop_blocked = False
+
+    async def check_event_loop():
+        nonlocal event_loop_blocked
+        for _ in range(5):
+            start = time.monotonic()
+            await asyncio.sleep(0.01)
+            elapsed = time.monotonic() - start
+            if elapsed > 0.05:
+                event_loop_blocked = True
+                break
+
+    async def fast_async_operation():
+        await asyncio.sleep(0.02)
+        return "done"
+
+    checker_task = asyncio.create_task(check_event_loop())
+    result = await fast_async_operation()
+    await checker_task
+
+    assert result == "done"
+    assert not event_loop_blocked, "Event loop was blocked during async operation"
diff --git a/tests/trace/test_server_file_storage.py b/tests/trace/test_server_file_storage.py
index 1114ab675e93..84f4b9857eac 100644
--- a/tests/trace/test_server_file_storage.py
+++ b/tests/trace/test_server_file_storage.py
@@ -5,8 +5,10 @@ specific setup requirements.
 """
+import asyncio
 import base64
 import os
+import time
 from unittest import mock
 
 import boto3
@@ -481,3 +483,282 @@ def mock_upload_fail(*args, **kwargs):
 
     # Verify GCS was attempted exactly 3 times before fallback
     assert attempt_count == 3, f"Expected 3 GCS attempts, got {attempt_count}"
+
+
+# =============================================================================
+# Async file_create tests
+# =============================================================================
+
+
+@pytest.fixture
+def aws_storage_env_async():
+    """Setup AWS storage environment for async tests."""
+    with mock.patch.dict(
+        os.environ,
+        {
+            "WF_FILE_STORAGE_AWS_ACCESS_KEY_ID": "test-key",
+            "WF_FILE_STORAGE_AWS_SECRET_ACCESS_KEY": "test-secret",
+            "WF_FILE_STORAGE_URI": f"s3://{TEST_BUCKET}",
+            "WF_FILE_STORAGE_PROJECT_ALLOW_LIST": "c2hhd24vdGVzdC1wcm9qZWN0",
+        },
+    ):
+        yield
+
+
+@pytest.fixture
+def gcp_storage_env_async():
+    """Setup GCP storage environment for async tests."""
+    with mock.patch.dict(
+        os.environ,
+        {
+            "WF_FILE_STORAGE_GCP_CREDENTIALS_JSON_B64": base64.b64encode(
+                b"""{
+                "type": "service_account",
+                "project_id": "test-project",
+                "private_key_id": "test-key-id",
+                "private_key": "test-key",
+                "client_email": "test@test-project.iam.gserviceaccount.com",
+                "client_id": "test-client-id",
+                "auth_uri": "https://accounts.google.com/o/oauth2/auth",
+                "token_uri": "https://oauth2.googleapis.com/token",
+                "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
+                "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test@test-project.iam.gserviceaccount.com"
+            }"""
+            ).decode(),
+            "WF_FILE_STORAGE_URI": f"gs://{TEST_BUCKET}",
+            "WF_FILE_STORAGE_PROJECT_ALLOW_LIST": "c2hhd24vdGVzdC1wcm9qZWN0",
+        },
+    ):
+        yield
+
+
+@pytest.fixture
+def mock_gcp_credentials():
+    """Mock GCP credentials to prevent authentication."""
+    with mock.patch(
+        "google.oauth2.service_account.Credentials.from_service_account_info"
+    ) as mock_creds:
+        mock_creds.return_value = AnonymousCredentials()
+        yield
+
+
+@pytest.fixture
+def mock_gcs_client():
+    """Google Cloud Storage mock for tests."""
+    mock_storage_client = mock.MagicMock()
+    mock_bucket = mock.MagicMock()
+    mock_blob = mock.MagicMock()
+
+    mock_storage_client.bucket.return_value = mock_bucket
+    mock_bucket.blob.return_value = mock_blob
+
+    blob_data = {}
+
+    def mock_upload_from_string(data, timeout=None, if_generation_match=None, **kwargs):
+        blob_name = mock_blob.name
+        if if_generation_match == 0 and blob_name in blob_data:
+            raise exceptions.PreconditionFailed("Object already exists")
+        blob_data[blob_name] = data
+
+    def mock_download_as_bytes(timeout=None, **kwargs):
+        blob_name = mock_blob.name
+        return blob_data.get(blob_name, b"")
+
+    mock_blob.upload_from_string.side_effect = mock_upload_from_string
+    mock_blob.download_as_bytes.side_effect = mock_download_as_bytes
+
+    with mock.patch("google.cloud.storage.Client", return_value=mock_storage_client):
+        yield mock_storage_client
+
+
+@pytest.fixture
+def s3_mock():
+    """Moto S3 mock that implements the S3 API."""
+    with mock_aws():
+        s3_client = boto3.client(
+            "s3",
+            aws_access_key_id="test-key",
+            aws_secret_access_key="test-secret",
+            region_name="us-east-1",
+        )
+        s3_client.create_bucket(Bucket=TEST_BUCKET)
+        yield s3_client
+
+
+# Parametrized test data for async file creation with various content sizes
+ASYNC_FILE_CREATE_PARAMS = [
+    pytest.param(TEST_CONTENT, "basic.txt", id="basic_content"),
+    pytest.param(b"x" * 300000, "large.bin", id="large_file"),  # 3x chunk size
+]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(("content", "filename"), ASYNC_FILE_CREATE_PARAMS)
+async def test_async_file_create(client: WeaveClient, content: bytes, filename: str):
+    """Test async file creation with various content sizes."""
+    if client_is_sqlite(client):
+        pytest.skip("Skipping for SQLite - testing ClickHouse async path")
+
+    req = FileCreateReq(
+        project_id=client._project_id(),
+        name=filename,
+        content=content,
+    )
+    res = await client.server.file_create_async(req)
+
+    assert res.digest is not None
+    assert res.digest != ""
+
+    # Verify content can be read back
+    file = client.server.file_content_read(
+        FileContentReadReq(project_id=client._project_id(), digest=res.digest)
+    )
+    assert file.content == content
+
+
+@pytest.mark.asyncio
+async def test_async_file_create_consistency_with_sync(client: WeaveClient):
+    """Test that async and sync produce identical digests for the same content."""
+    if client_is_sqlite(client):
+        pytest.skip("Skipping for SQLite")
+
+    req = FileCreateReq(
+        project_id=client._project_id(),
+        name="consistency_test.txt",
+        content=b"consistent content for testing",
+    )
+
+    sync_res = client.server.file_create(req)
+    async_res = await client.server.file_create_async(req)
+
+    assert sync_res.digest == async_res.digest
+
+
+@pytest.mark.asyncio
+async def test_async_file_create_concurrent(client: WeaveClient):
+    """Test concurrent async file creations complete correctly with unique digests."""
+    if client_is_sqlite(client):
+        pytest.skip("Skipping for SQLite")
+
+    async def create_file(content: bytes, name: str) -> tuple[str, float]:
+        start = time.monotonic()
+        req = FileCreateReq(
+            project_id=client._project_id(),
+            name=name,
+            content=content,
+        )
+        res = await client.server.file_create_async(req)
+        return res.digest, time.monotonic() - start
+
+    tasks = [create_file(f"content_{i}".encode(), f"file_{i}.txt") for i in range(5)]
+
+    start_total = time.monotonic()
+    results = await asyncio.gather(*tasks)
+    total_elapsed = time.monotonic() - start_total
+
+    digests = [r[0] for r in results]
+    assert all(d is not None and d != "" for d in digests)
+    assert len(set(digests)) == 5  # All unique
+
+    # Sanity check: concurrent should show some speedup vs serial
+    individual_times = [r[1] for r in results]
+    assert total_elapsed < sum(individual_times) * 0.9 or total_elapsed < 1.0
+
+
+@pytest.mark.asyncio
+@pytest.mark.usefixtures("aws_storage_env_async")
+async def test_async_aws_storage(client: WeaveClient, s3_mock):
+    """Test async file storage using AWS S3."""
+    if client_is_sqlite(client):
+        pytest.skip("Skipping for SQLite")
+
+    req = FileCreateReq(
+        project_id=client._project_id(),
+        name="async_s3_test.txt",
+        content=TEST_CONTENT,
+    )
+    res = await client.server.file_create_async(req)
+
+    assert res.digest is not None
+
+    # Verify the object exists in S3
+    response = s3_mock.list_objects_v2(Bucket=TEST_BUCKET)
+    assert "Contents" in response
+    assert len(response["Contents"]) >= 1
+
+    # Verify content via server
+    file = client.server.file_content_read(
+        FileContentReadReq(project_id=client._project_id(), digest=res.digest)
+    )
+    assert file.content == TEST_CONTENT
+
+
+@pytest.mark.asyncio
+@pytest.mark.usefixtures("gcp_storage_env_async", "mock_gcp_credentials")
+async def test_async_gcs_storage(client: WeaveClient, mock_gcs_client):
+    """Test async file storage using GCS."""
+    if client_is_sqlite(client):
+        pytest.skip("Skipping for SQLite")
+
+    req = FileCreateReq(
+        project_id=client._project_id(),
+        name="async_gcs_test.txt",
+        content=TEST_CONTENT,
+    )
+    res = await client.server.file_create_async(req)
+
+    assert res.digest is not None
+
+    # Verify content via server
+    file = client.server.file_content_read(
+        FileContentReadReq(project_id=client._project_id(), digest=res.digest)
+    )
+    assert file.content == TEST_CONTENT
+
+
+# Parametrized test data for SQLite async tests
+SQLITE_ASYNC_PARAMS = [
+    pytest.param(TEST_CONTENT, "sqlite_async_test.txt", 1, id="single_file"),
+    pytest.param(None, None, 3, id="concurrent_files"),  # Special case for concurrent
+]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(("content", "filename", "count"), SQLITE_ASYNC_PARAMS)
+async def test_sqlite_async(client: WeaveClient, content, filename, count):
+    """Test SQLite async operations: single file and concurrent file creation."""
+    if not client_is_sqlite(client):
+        pytest.skip("This test is specifically for SQLite")
+
+    if count == 1:
+        # Single file test
+        req = FileCreateReq(
+            project_id=client._project_id(),
+            name=filename,
+            content=content,
+        )
+        res = await client.server.file_create_async(req)
+
+        assert res.digest is not None
+        assert res.digest != ""
+
+        file = client.server.file_content_read(
+            FileContentReadReq(project_id=client._project_id(), digest=res.digest)
+        )
+        assert file.content == content
+    else:
+        # Concurrent files test
+        async def create_file(i: int) -> str:
+            req = FileCreateReq(
+                project_id=client._project_id(),
+                name=f"sqlite_concurrent_{i}.txt",
+                content=f"sqlite_content_{i}".encode(),
+            )
+            res = await client.server.file_create_async(req)
+            return res.digest
+
+        tasks = [create_file(i) for i in range(count)]
+        digests = await asyncio.gather(*tasks)
+
+        assert all(d is not None and d != "" for d in digests)
+        assert len(set(digests)) == count
diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py
index e00c61bd8095..dd3d9e33cbad 100644
--- a/weave/trace_server/clickhouse_trace_server_batched.py
+++ b/weave/trace_server/clickhouse_trace_server_batched.py
@@ -123,13 +123,16 @@ validate_feedback_purge_req,
 )
 from weave.trace_server.file_storage import (
+    AsyncFileStorageClient,
     FileStorageClient,
     FileStorageReadError,
     FileStorageWriteError,
     key_for_project_digest,
+    maybe_get_async_storage_client_from_env,
     maybe_get_storage_client_from_env,
     read_from_bucket,
     store_in_bucket,
+    store_in_bucket_async,
 )
 from weave.trace_server.file_storage_uris import FileStorageURI
 from weave.trace_server.ids import generate_id
@@ -233,6 +236,8 @@ def __init__(
         self._use_async_insert = use_async_insert
         self._model_to_provider_info_map = read_model_to_provider_info_map()
         self._file_storage_client: FileStorageClient | None = None
+        self._async_file_storage_client: AsyncFileStorageClient | None = None
+        self._async_ch_client: clickhouse_connect.driver.AsyncClient | None = None
         self._kafka_producer: KafkaProducer | None = None
         self._evaluate_model_dispatcher = evaluate_model_dispatcher
         self._table_routing_resolver: TableRoutingResolver | None = None
@@ -298,6 +303,34 @@ def file_storage_client(self) -> FileStorageClient | None:
         self._file_storage_client = maybe_get_storage_client_from_env()
         return self._file_storage_client
 
+    @property
+    def async_file_storage_client(self) -> AsyncFileStorageClient | None:
+        if self._async_file_storage_client is not None:
+            return self._async_file_storage_client
+        self._async_file_storage_client = maybe_get_async_storage_client_from_env()
+        return self._async_file_storage_client
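Like its sync counterpart just above, the property resolves the client from the environment on first access and then reuses it; a sketch of the implied behavior (illustrative only, not part of the change):

    # First access reads WF_FILE_STORAGE_URI and builds the client (or yields None if unset);
    # once a client has been built, later accesses return the cached instance.
    client = server.async_file_storage_client
    assert client is server.async_file_storage_client
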
+
+    async def _get_async_ch_client(self) -> clickhouse_connect.driver.AsyncClient:
+        """Get or create async ClickHouse client."""
+        if self._async_ch_client is None:
+            self._async_ch_client = await clickhouse_connect.get_async_client(
+                host=self._host,
+                port=self._port,
+                user=self._user,
+                password=self._password,
+                database=self._database,
+            )
+        return self._async_ch_client
+
+    async def close_async_ch_client(self) -> None:
+        """Close the async ClickHouse client if it exists.
+
+        Should be called during shutdown to release resources.
+        """
+        if self._async_ch_client is not None:
+            await self._async_ch_client.close()
+            self._async_ch_client = None
 
     @property
     def kafka_producer(self) -> KafkaProducer:
         if self._kafka_producer is not None:
@@ -4709,6 +4742,102 @@ def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes:
         set_root_span_dd_tags({"write_bytes": len(req.content)})
         return tsi.FileCreateRes(digest=digest)
 
+    async def file_create_async(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes:
+        """Async version of file_create using async storage and ClickHouse clients."""
+        with ddtrace.tracer.trace(
+            name="clickhouse_trace_server_batched.file_create_async"
+        ):
+            digest = bytes_digest(req.content)
+            use_file_storage = self._should_use_file_storage_for_writes(req.project_id)
+            async_client = self.async_file_storage_client
+
+            if async_client is not None and use_file_storage:
+                try:
+                    await self._file_create_bucket_async(req, digest, async_client)
+                except FileStorageWriteError:
+                    await self._file_create_clickhouse_async(req, digest)
+            else:
+                await self._file_create_clickhouse_async(req, digest)
+            set_root_span_dd_tags({"write_bytes": len(req.content)})
+            return tsi.FileCreateRes(digest=digest)
+
+    async def _file_create_clickhouse_async(
+        self, req: tsi.FileCreateReq, digest: str
+    ) -> None:
+        """Async version of _file_create_clickhouse."""
+        with ddtrace.tracer.trace(
+            name="clickhouse_trace_server_batched._file_create_clickhouse_async"
+        ):
+            set_root_span_dd_tags({"storage_provider": "clickhouse"})
+            chunks = [
+                req.content[i : i + ch_settings.FILE_CHUNK_SIZE]
+                for i in range(0, len(req.content), ch_settings.FILE_CHUNK_SIZE)
+            ]
+            await self._insert_file_chunks_async(
+                [
+                    FileChunkCreateCHInsertable(
+                        project_id=req.project_id,
+                        digest=digest,
+                        chunk_index=i,
+                        n_chunks=len(chunks),
+                        name=req.name,
+                        val_bytes=chunk,
+                        bytes_stored=len(chunk),
+                        file_storage_uri=None,
+                    )
+                    for i, chunk in enumerate(chunks)
+                ]
+            )
+
+    async def _file_create_bucket_async(
+        self, req: tsi.FileCreateReq, digest: str, client: AsyncFileStorageClient
+    ) -> None:
+        """Async version of _file_create_bucket."""
+        with ddtrace.tracer.trace(
+            name="clickhouse_trace_server_batched._file_create_bucket_async"
+        ):
+            set_root_span_dd_tags({"storage_provider": "bucket"})
+            target_file_storage_uri = await store_in_bucket_async(
+                client, key_for_project_digest(req.project_id, digest), req.content
+            )
+            await self._insert_file_chunks_async(
+                [
+                    FileChunkCreateCHInsertable(
+                        project_id=req.project_id,
+                        digest=digest,
+                        chunk_index=0,
+                        n_chunks=1,
+                        name=req.name,
+                        val_bytes=b"",
+                        bytes_stored=len(req.content),
+                        file_storage_uri=target_file_storage_uri.to_uri_str(),
+                    )
+                ]
+            )
+
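Taken together, file_create_async prefers the bucket path and degrades to ClickHouse storage when the bucket write raises FileStorageWriteError. A caller-side sketch (illustrative only, not part of this change; `server` is assumed to be an instance of this trace server):

    from weave.trace_server import trace_server_interface as tsi

    async def upload_file(server, project_id: str) -> str:
        req = tsi.FileCreateReq(project_id=project_id, name="example.txt", content=b"payload")
        res = await server.file_create_async(req)  # bucket first, ClickHouse on FileStorageWriteError
        return res.digest

    # On shutdown, the lazily created async ClickHouse client should be released:
    #     await server.close_async_ch_client()
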
+    async def _insert_file_chunks_async(
+        self, file_chunks: list[FileChunkCreateCHInsertable]
+    ) -> None:
+        """Async version of _insert_file_chunks using async ClickHouse client."""
+        with ddtrace.tracer.trace(
+            name="clickhouse_trace_server_batched._insert_file_chunks_async"
+        ):
+            data = []
+            for chunk in file_chunks:
+                chunk_dump = chunk.model_dump()
+                row = []
+                for col in ALL_FILE_CHUNK_INSERT_COLUMNS:
+                    row.append(chunk_dump.get(col, None))
+                data.append(row)
+
+            if data:
+                client = await self._get_async_ch_client()
+                await client.insert(
+                    "files",
+                    data=data,
+                    column_names=ALL_FILE_CHUNK_INSERT_COLUMNS,
+                )
+
     @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._file_create_clickhouse")
     def _file_create_clickhouse(self, req: tsi.FileCreateReq, digest: str) -> None:
         set_root_span_dd_tags({"storage_provider": "clickhouse"})
diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py
index e7528740a0ee..166f3b9c29a7 100644
--- a/weave/trace_server/external_to_internal_trace_server_adapter.py
+++ b/weave/trace_server/external_to_internal_trace_server_adapter.py
@@ -297,6 +297,11 @@ def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes:
         # Special case where refs can never be part of the request
         return self._internal_trace_server.file_create(req)
 
+    async def file_create_async(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes:
+        req.project_id = self._idc.ext_to_int_project_id(req.project_id)
+        # Special case where refs can never be part of the request
+        return await self._internal_trace_server.file_create_async(req)
+
     def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadRes:
         req.project_id = self._idc.ext_to_int_project_id(req.project_id)
         # Special case where refs can never be part of the request
diff --git a/weave/trace_server/file_storage.py b/weave/trace_server/file_storage.py
index 18e6b910f134..c3d58709b0b0 100644
--- a/weave/trace_server/file_storage.py
+++ b/weave/trace_server/file_storage.py
@@ -40,14 +40,19 @@
 - `WF_FILE_STORAGE_AZURE_ACCOUNT_URL`: (optional) the account url for the azure account
     - defaults to `https://.blob.core.windows.net/`
 """
+import asyncio
 import logging
 from abc import abstractmethod
-from collections.abc import Callable
+from collections.abc import AsyncIterator, Callable
+from contextlib import asynccontextmanager
 from typing import Any, cast
 
 import boto3
+from aiobotocore.session import AioSession
+from aiobotocore.session import get_session as get_aio_session
 from azure.core.exceptions import HttpResponseError
 from azure.storage.blob import BlobServiceClient
+from azure.storage.blob.aio import BlobServiceClient as AsyncBlobServiceClient
 from botocore.config import Config
 from botocore.exceptions import ClientError
 from google.api_core import exceptions as gcp_exceptions
@@ -130,8 +135,8 @@ def store_in_bucket(
     client: FileStorageClient, path: str, data: bytes
 ) -> FileStorageURI:
     """Store a file in a storage bucket."""
+    target_file_storage_uri = client.base_uri.with_path(path)
     try:
-        target_file_storage_uri = client.base_uri.with_path(path)
         client.store(target_file_storage_uri, data)
     except Exception as e:
         logger.exception("Failed to store file at %s", target_file_storage_uri)
@@ -408,3 +413,289 @@ def maybe_get_storage_client_from_env() -> FileStorageClient | None:
     raise NotImplementedError(
         f"Storage client for URI type {type(file_storage_uri)} not supported"
     )
+
+
+# =============================================================================
+# Async Storage Clients
+# =============================================================================
+#
+# Async implementations use different strategies based on library support:
+# - S3: Native async via aiobotocore (best performance)
+# - GCS: Thread pool wrapper (no mature async library available)
+# - Azure: Native async via azure-storage-blob aio support
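For orientation, the pieces defined below are meant to compose roughly as follows (an illustrative sketch, not part of this change; it assumes WF_FILE_STORAGE_URI and matching credentials are configured in the environment):

    from weave.trace_server.file_storage import (
        key_for_project_digest,
        maybe_get_async_storage_client_from_env,
        store_in_bucket_async,
    )

    async def round_trip(project_id: str, digest: str, data: bytes) -> bytes:
        client = maybe_get_async_storage_client_from_env()
        if client is None:
            raise RuntimeError("no bucket storage configured")
        key = key_for_project_digest(project_id, digest)
        uri = await store_in_bucket_async(client, key, data)  # raises FileStorageWriteError on failure
        return await client.read_async(uri)
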
+
+
+def create_async_retry_decorator(operation_name: str) -> Callable[[Any], Any]:
+    """Creates an async retry decorator with consistent retry policy and special 429 handling.
+
+    Uses the same retry logic as the sync version to ensure consistent error handling.
+    """
+
+    def after_retry(retry_state: RetryCallState) -> None:
+        if retry_state.attempt_number > 1:
+            logger.info(
+                "%s: Attempt %d/%d after %.2f seconds",
+                operation_name,
+                retry_state.attempt_number,
+                RETRY_MAX_ATTEMPTS,
+                retry_state.seconds_since_start,
+            )
+
+    def create_wait_strategy(retry_state: RetryCallState) -> float:
+        """Create wait strategy that uses jitter for rate limit errors."""
+        if retry_state.outcome and retry_state.outcome.failed:
+            exception = retry_state.outcome.exception()
+            if exception and _is_rate_limit_error(exception):
+                # Use random exponential backoff with jitter for rate limiting
+                return wait_random_exponential(
+                    multiplier=RETRY_MIN_WAIT, max=RETRY_MAX_WAIT
+                )(retry_state)
+        # Use regular exponential backoff for other errors
+        return wait_exponential(multiplier=RETRY_MIN_WAIT, max=RETRY_MAX_WAIT)(
+            retry_state
+        )
+
+    return retry(
+        stop=stop_after_attempt(RETRY_MAX_ATTEMPTS),
+        wait=create_wait_strategy,
+        reraise=True,
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        after=after_retry,
+    )
+
+
+class AsyncFileStorageClient:
+    """Abstract base class for async cloud storage operations.
+
+    Follows the same pattern as FileStorageClient - uses @abstractmethod
+    without ABC inheritance for interface definition.
+    """
+
+    base_uri: FileStorageURI
+    _sync_client: FileStorageClient
+
+    def __init__(self, base_uri: FileStorageURI, sync_client: FileStorageClient):
+        """Initialize async storage client.
+
+        Args:
+            base_uri: The base URI for storage operations.
+            sync_client: A sync client used for fallback or thread pool operations.
+        """
+        self.base_uri = base_uri
+        self._sync_client = sync_client
+
+    @abstractmethod
+    async def store_async(self, uri: FileStorageURI, data: bytes) -> None:
+        """Store data at the specified URI location in cloud storage (async)."""
+        pass
+
+    @abstractmethod
+    async def read_async(self, uri: FileStorageURI) -> bytes:
+        """Read data from the specified URI location in cloud storage (async)."""
+        pass
+
+
+class AsyncS3StorageClient(AsyncFileStorageClient):
+    """Async AWS S3 storage implementation using aiobotocore."""
+
+    def __init__(
+        self,
+        base_uri: FileStorageURI,
+        credentials: AWSCredentials,
+        sync_client: S3StorageClient,
+    ):
+        super().__init__(base_uri, sync_client)
+        assert isinstance(base_uri, S3FileStorageURI)
+        self._credentials = credentials
+        self._kms_key = credentials.get("kms_key")
+        self._session: AioSession | None = None
+
+    def _get_session(self) -> AioSession:
+        """Get or create an aiobotocore session (reused across calls)."""
+        if self._session is None:
+            self._session = get_aio_session()
+        return self._session
+
+    @asynccontextmanager
+    async def _get_async_s3_client(self) -> AsyncIterator[Any]:
+        """Context manager that yields an async S3 client."""
+        session = self._get_session()
+        async with session.create_client(
+            "s3",
+            region_name=self._credentials.get("region"),
+            aws_access_key_id=self._credentials.get("access_key_id"),
+            aws_secret_access_key=self._credentials.get("secret_access_key"),
+            aws_session_token=self._credentials.get("session_token"),
+        ) as client:
+            yield client
+
+    @create_async_retry_decorator("async_s3_storage")
+    async def store_async(self, uri: FileStorageURI, data: bytes) -> None:
+        """Store data in S3 bucket asynchronously using aiobotocore."""
+        assert isinstance(uri, S3FileStorageURI)
+        assert uri.to_uri_str().startswith(self.base_uri.to_uri_str())
+
+        async with self._get_async_s3_client() as client:
+            put_object_params: dict[str, Any] = {
+                "Bucket": uri.bucket,
+                "Key": uri.path,
+                "Body": data,
+            }
+            if self._kms_key:
+                put_object_params["ServerSideEncryption"] = "aws:kms"
+                put_object_params["SSEKMSKeyId"] = self._kms_key
+
+            await client.put_object(**put_object_params)
+
+    @create_async_retry_decorator("async_s3_read")
+    async def read_async(self, uri: FileStorageURI) -> bytes:
+        """Read data from S3 bucket asynchronously."""
+        assert isinstance(uri, S3FileStorageURI)
+        assert uri.to_uri_str().startswith(self.base_uri.to_uri_str())
+
+        async with self._get_async_s3_client() as client:
+            response = await client.get_object(Bucket=uri.bucket, Key=uri.path)
+            async with response["Body"] as stream:
+                return await stream.read()
+
+
+class AsyncGCSStorageClient(AsyncFileStorageClient):
+    """Async Google Cloud Storage implementation using thread pool wrapper.
+
+    GCS does not have a mature async library, so we wrap sync operations
+    in asyncio.to_thread. The sync client already has retry logic applied.
+    """
+
+    def __init__(
+        self,
+        base_uri: FileStorageURI,
+        credentials: GCPCredentials | None,
+        sync_client: GCSStorageClient,
+    ):
+        super().__init__(base_uri, sync_client)
+        assert isinstance(base_uri, GCSFileStorageURI)
+        self._credentials = credentials
+
+    async def store_async(self, uri: FileStorageURI, data: bytes) -> None:
+        """Store data in GCS bucket asynchronously.
+
+        Delegates to sync client which has retry logic applied.
+        """
+        assert isinstance(uri, GCSFileStorageURI)
+        assert uri.to_uri_str().startswith(self.base_uri.to_uri_str())
+        await asyncio.to_thread(self._sync_client.store, uri, data)
+
+    async def read_async(self, uri: FileStorageURI) -> bytes:
+        """Read data from GCS bucket asynchronously.
+
+        Delegates to sync client which has retry logic applied.
+        """
+        assert isinstance(uri, GCSFileStorageURI)
+        assert uri.to_uri_str().startswith(self.base_uri.to_uri_str())
+        return await asyncio.to_thread(self._sync_client.read, uri)
+
+
+class AsyncAzureStorageClient(AsyncFileStorageClient):
+    """Async Azure Blob Storage implementation using built-in aio support."""
+
+    def __init__(
+        self,
+        base_uri: FileStorageURI,
+        credentials: AzureConnectionCredentials | AzureAccountCredentials,
+        sync_client: AzureStorageClient,
+    ):
+        super().__init__(base_uri, sync_client)
+        assert isinstance(base_uri, AzureFileStorageURI)
+        self._credentials = credentials
+
+    async def _get_async_client(self, account: str) -> AsyncBlobServiceClient:
+        """Create async Azure client based on available credentials."""
+        if "connection_string" in self._credentials:
+            connection_creds = cast(AzureConnectionCredentials, self._credentials)
+            return AsyncBlobServiceClient.from_connection_string(
+                connection_creds["connection_string"],
+                connection_timeout=DEFAULT_CONNECT_TIMEOUT,
+                read_timeout=DEFAULT_READ_TIMEOUT,
+            )
+        else:
+            account_creds = cast(AzureAccountCredentials, self._credentials)
+            if "account_url" in account_creds and account_creds["account_url"]:
+                account_url = account_creds["account_url"]
+            else:
+                account_url = f"https://{account}.blob.core.windows.net/"
+            return AsyncBlobServiceClient(
+                account_url=account_url,
+                credential=account_creds["access_key"],
+                connection_timeout=DEFAULT_CONNECT_TIMEOUT,
+                read_timeout=DEFAULT_READ_TIMEOUT,
+            )
+
+    @create_async_retry_decorator("async_azure_storage")
+    async def store_async(self, uri: FileStorageURI, data: bytes) -> None:
+        """Store data in Azure container asynchronously."""
+        assert isinstance(uri, AzureFileStorageURI)
+        assert uri.to_uri_str().startswith(self.base_uri.to_uri_str())
+
+        async with await self._get_async_client(uri.account) as client:
+            container_client = client.get_container_client(uri.container)
+            blob_client = container_client.get_blob_client(uri.path)
+            await blob_client.upload_blob(data, overwrite=True)
+
+    @create_async_retry_decorator("async_azure_read")
+    async def read_async(self, uri: FileStorageURI) -> bytes:
+        """Read data from Azure container asynchronously."""
+        assert isinstance(uri, AzureFileStorageURI)
+        assert uri.to_uri_str().startswith(self.base_uri.to_uri_str())
+
+        async with await self._get_async_client(uri.account) as client:
+            container_client = client.get_container_client(uri.container)
+            blob_client = container_client.get_blob_client(uri.path)
+            stream = await blob_client.download_blob()
+            return await stream.readall()
+
+
+async def store_in_bucket_async(
+    client: AsyncFileStorageClient, path: str, data: bytes
+) -> FileStorageURI:
+    """Store a file in a storage bucket asynchronously."""
+    target_file_storage_uri = client.base_uri.with_path(path)
+    try:
+        await client.store_async(target_file_storage_uri, data)
+    except Exception as e:
+        logger.exception("Failed to store file at %s", target_file_storage_uri)
+        raise FileStorageWriteError(f"Failed to store file at {path}: {e!s}") from e
+    return target_file_storage_uri
+
+
+def maybe_get_async_storage_client_from_env() -> AsyncFileStorageClient | None:
+    """Factory method that returns appropriate async storage client based on URI type."""
+    file_storage_uri = wf_file_storage_uri()
+    if not file_storage_uri:
+        return None
+    try:
+        parsed_uri = FileStorageURI.parse_uri_str(file_storage_uri)
+    except Exception as e:
+        logger.warning(f"Error parsing file storage URI: {e}")
+        return None
+    if parsed_uri.has_path():
+        logger.error(
+            f"Supplied file storage uri contains path components: {file_storage_uri}"
+        )
+        return None
+
+    if isinstance(parsed_uri, S3FileStorageURI):
+        credentials = get_aws_credentials()
+        sync_client = S3StorageClient(parsed_uri, credentials)
+        return AsyncS3StorageClient(parsed_uri, credentials, sync_client)
+    elif isinstance(parsed_uri, GCSFileStorageURI):
+        credentials = get_gcp_credentials()
+        sync_client = GCSStorageClient(parsed_uri, credentials)
+        return AsyncGCSStorageClient(parsed_uri, credentials, sync_client)
+    elif isinstance(parsed_uri, AzureFileStorageURI):
+        credentials = get_azure_credentials()
+        sync_client = AzureStorageClient(parsed_uri, credentials)
+        return AsyncAzureStorageClient(parsed_uri, credentials, sync_client)
+    else:
+        raise NotImplementedError(
+            f"Async storage client for URI type {type(parsed_uri)} not supported"
+        )
diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py
index 305477a36178..754ba534f884 100644
--- a/weave/trace_server/sqlite_trace_server.py
+++ b/weave/trace_server/sqlite_trace_server.py
@@ -1,5 +1,6 @@
 # Sqlite Trace Server
 
+import asyncio
 import datetime
 import hashlib
 import json
@@ -1525,6 +1526,10 @@ def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes:
             conn.commit()
         return tsi.FileCreateRes(digest=digest)
 
+    async def file_create_async(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes:
+        """Async version - wraps sync in thread pool for SQLite."""
+        return await asyncio.to_thread(self.file_create, req)
+
     def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadRes:
         conn, cursor = get_conn_cursor(self.db_path)
         cursor.execute(
diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py
index 458f80b542ae..a2c72587bcff 100644
--- a/weave/trace_server/trace_server_interface.py
+++ b/weave/trace_server/trace_server_interface.py
@@ -2289,6 +2289,8 @@ def refs_read_batch(self, req: RefsReadBatchReq) -> RefsReadBatchRes: ...
     # File API
     def file_create(self, req: FileCreateReq) -> FileCreateRes: ...
 
+    async def file_create_async(self, req: FileCreateReq) -> FileCreateRes: ...
+
     def file_content_read(self, req: FileContentReadReq) -> FileContentReadRes: ...
 
     def files_stats(self, req: FilesStatsReq) -> FilesStatsRes: ...
diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py
index 9a6127561902..78eed14e2b14 100644
--- a/weave/trace_server_bindings/remote_http_trace_server.py
+++ b/weave/trace_server_bindings/remote_http_trace_server.py
@@ -1,3 +1,4 @@
+import asyncio
 import datetime
 import io
 import logging
@@ -762,6 +763,10 @@ def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes:
         handle_response_error(r, "/files/create")
         return tsi.FileCreateRes.model_validate(r.json())
 
+    async def file_create_async(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes:
+        """Async version - wraps sync HTTP call in thread pool."""
+        return await asyncio.to_thread(self.file_create, req)
+
     @with_retry
     def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadRes:
         r = self.post(