From 46a232effc95912d8a871f03c28b44f52c7eb43a Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Mon, 2 Feb 2026 13:24:36 -0800 Subject: [PATCH 1/4] chore(weave): add call-id sharding to distributed cluster migration --- .../test_clickhouse_trace_server_migrator.py | 10 ++++++++++ .../clickhouse_trace_server_migrator.py | 19 +++++++++++++++++-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/tests/trace_server/test_clickhouse_trace_server_migrator.py b/tests/trace_server/test_clickhouse_trace_server_migrator.py index 30f37878c571..c07d3f53d08c 100644 --- a/tests/trace_server/test_clickhouse_trace_server_migrator.py +++ b/tests/trace_server/test_clickhouse_trace_server_migrator.py @@ -387,6 +387,16 @@ def test_create_distributed_table_sql(): assert sql.strip() == expected.strip() +def test_create_distributed_table_sql_id_sharded(): + """Test distributed table creation SQL for ID-sharded tables.""" + distributed_migrator = DistributedClickHouseTraceServerMigrator( + Mock(), replicated_cluster="test_cluster", migration_dir=DEFAULT_MIGRATION_DIR + ) + sql = distributed_migrator._create_distributed_table_sql("calls_complete") + expected = "CREATE TABLE IF NOT EXISTS calls_complete ON CLUSTER test_cluster\n AS calls_complete_local\n ENGINE = Distributed(test_cluster, currentDatabase(), calls_complete_local, sipHash64(id))" + assert sql.strip() == expected.strip() + + def test_format_distributed_sql(): """Test distributed SQL formatting for CREATE TABLE and other DDL.""" distributed_migrator = DistributedClickHouseTraceServerMigrator( diff --git a/weave/trace_server/clickhouse_trace_server_migrator.py b/weave/trace_server/clickhouse_trace_server_migrator.py index 750492db2d7b..6a12788b5c3e 100644 --- a/weave/trace_server/clickhouse_trace_server_migrator.py +++ b/weave/trace_server/clickhouse_trace_server_migrator.py @@ -85,6 +85,12 @@ # Constants for table naming conventions VIEW_SUFFIX = "_view" +# Tables that use ID-based sharding (sipHash64(field)) instead of random sharding +# in distributed mode. Maps table name to the field used for sharding. +# This ensures all data for a specific ID goes to the same shard, enabling +# efficient point lookups. +ID_SHARDED_TABLES: dict[str, str] = {"calls_complete": "id"} + @dataclass(frozen=True) class PostMigrationHookContext: @@ -781,12 +787,21 @@ def _format_distributed_sql(self, sql_query: str) -> DistributedTransformResult: ) def _create_distributed_table_sql(self, table_name: str) -> str: - """Generate SQL to create a distributed table.""" + """Generate SQL to create a distributed table. + + For tables in ID_SHARDED_TABLES, uses sipHash64(field) as the sharding key + to ensure all data for a specific ID goes to the same shard, enabling + efficient point lookups. Other tables use rand() for even distribution. 
+ """ local_table_name = table_name + ch_settings.LOCAL_TABLE_SUFFIX + if shard_field := ID_SHARDED_TABLES.get(table_name): + sharding_key = f"sipHash64({shard_field})" + else: + sharding_key = "rand()" return f""" CREATE TABLE IF NOT EXISTS {table_name} ON CLUSTER {self.replicated_cluster} AS {local_table_name} - ENGINE = Distributed({self.replicated_cluster}, currentDatabase(), {local_table_name}, rand()) + ENGINE = Distributed({self.replicated_cluster}, currentDatabase(), {local_table_name}, {sharding_key}) """ @staticmethod From fe9e89318f6cb47d6fa80290103bbc98765d0f46 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Tue, 3 Feb 2026 15:00:39 -0800 Subject: [PATCH 2/4] lint --- scripts/slack_digest.py | 5 +-- .../test_cases/media_cases.py | 16 ++++++---- .../test_clickhouse_trace_server_migrator.py | 6 +++- .../huggingface_inference_client_sdk.py | 10 +++--- weave/integrations/openai/openai_sdk.py | 32 +++++++++++-------- .../openai_realtime_websocket_patcher.py | 6 ++-- weave/integrations/vertexai/vertexai_sdk.py | 10 +++--- .../clickhouse_trace_server_migrator.py | 16 +++++++--- 8 files changed, 63 insertions(+), 38 deletions(-) diff --git a/scripts/slack_digest.py b/scripts/slack_digest.py index b3dee4251d13..8f8d137fb6d7 100644 --- a/scripts/slack_digest.py +++ b/scripts/slack_digest.py @@ -128,8 +128,9 @@ class CategoryRule: name="Python SDK", column_header="Py", emoji="🐍", - matcher=lambda path: path.startswith("weave/") - and not path.startswith("weave/trace_server/"), + matcher=lambda path: ( + path.startswith("weave/") and not path.startswith("weave/trace_server/") + ), ), ] diff --git a/tests/trace/data_serialization/test_cases/media_cases.py b/tests/trace/data_serialization/test_cases/media_cases.py index cd37b253946d..ba1478d41517 100644 --- a/tests/trace/data_serialization/test_cases/media_cases.py +++ b/tests/trace/data_serialization/test_cases/media_cases.py @@ -202,8 +202,10 @@ def markdown_equality_check(a, b): "exp_content": b'import weave\nfrom typing import Any\nfrom rich.markdown import Markdown\n\n@weave.op\ndef load(artifact: "MemTraceFilesArtifact", name: str, val: Any) -> Markdown:\n """Load markdown from file and metadata."""\n if "markup" in val:\n markup = val["markup"]\n else:\n with artifact.open("markup.md", binary=False) as f:\n markup = f.read()\n\n kwargs = {}\n if val and isinstance(val, dict) and "code_theme" in val:\n kwargs["code_theme"] = val["code_theme"]\n\n return Markdown(markup=markup, **kwargs)\n', }, ], - equality_check=lambda a, b: markdown_equality_check(a["inline"], b["inline"]) - and markdown_equality_check(a["file"], b["file"]), + equality_check=lambda a, b: ( + markdown_equality_check(a["inline"], b["inline"]) + and markdown_equality_check(a["file"], b["file"]) + ), python_version_code_capture=(3, 13), ), # Video @@ -239,8 +241,9 @@ def markdown_equality_check(a, b): "exp_content": VIDEO_BYTES, }, ], - equality_check=lambda a, b: a.duration - == b.duration, # could do better, but this is a good start + equality_check=lambda a, b: ( + a.duration == b.duration + ), # could do better, but this is a good start python_version_code_capture=(3, 13), ), # Content @@ -463,8 +466,9 @@ def markdown_equality_check(a, b): "exp_content": VIDEO_BYTES, }, ], - equality_check=lambda a, b: a.duration - == b.duration, # could do better, but this is a good start + equality_check=lambda a, b: ( + a.duration == b.duration + ), # could do better, but this is a good start python_version_code_capture=(3, 13), ), SerializationTestCase( diff --git 
a/tests/trace_server/test_clickhouse_trace_server_migrator.py b/tests/trace_server/test_clickhouse_trace_server_migrator.py index c07d3f53d08c..07f3f320c5e4 100644 --- a/tests/trace_server/test_clickhouse_trace_server_migrator.py +++ b/tests/trace_server/test_clickhouse_trace_server_migrator.py @@ -393,7 +393,11 @@ def test_create_distributed_table_sql_id_sharded(): Mock(), replicated_cluster="test_cluster", migration_dir=DEFAULT_MIGRATION_DIR ) sql = distributed_migrator._create_distributed_table_sql("calls_complete") - expected = "CREATE TABLE IF NOT EXISTS calls_complete ON CLUSTER test_cluster\n AS calls_complete_local\n ENGINE = Distributed(test_cluster, currentDatabase(), calls_complete_local, sipHash64(id))" + expected = """ + CREATE TABLE IF NOT EXISTS calls_complete ON CLUSTER test_cluster + AS calls_complete_local + ENGINE = Distributed(test_cluster, currentDatabase(), calls_complete_local, sipHash64(id)) + """ assert sql.strip() == expected.strip() diff --git a/weave/integrations/huggingface/huggingface_inference_client_sdk.py b/weave/integrations/huggingface/huggingface_inference_client_sdk.py index 00adab8ec513..77516d276bc5 100644 --- a/weave/integrations/huggingface/huggingface_inference_client_sdk.py +++ b/weave/integrations/huggingface/huggingface_inference_client_sdk.py @@ -83,8 +83,9 @@ def wrapper(fn: Callable) -> Callable: return _add_accumulator( op, # type: ignore make_accumulator=lambda inputs: huggingface_accumulator, - should_accumulate=lambda inputs: isinstance(inputs, dict) - and bool(inputs.get("stream")), + should_accumulate=lambda inputs: ( + isinstance(inputs, dict) and bool(inputs.get("stream")) + ), ) return wrapper @@ -107,8 +108,9 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return _add_accumulator( op, # type: ignore make_accumulator=lambda inputs: huggingface_accumulator, - should_accumulate=lambda inputs: isinstance(inputs, dict) - and bool(inputs.get("stream")), + should_accumulate=lambda inputs: ( + isinstance(inputs, dict) and bool(inputs.get("stream")) + ), ) return wrapper diff --git a/weave/integrations/openai/openai_sdk.py b/weave/integrations/openai/openai_sdk.py index 77be84ae7425..667fe7888eae 100644 --- a/weave/integrations/openai/openai_sdk.py +++ b/weave/integrations/openai/openai_sdk.py @@ -415,11 +415,13 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: op._set_on_input_handler(openai_on_input_handler) return _add_accumulator( op, # type: ignore - make_accumulator=lambda inputs: lambda acc, value: openai_accumulator( - acc, - value, - skip_last=not _openai_stream_options_is_set(inputs), - stream_start_time=inputs.get(WEAVE_STREAM_START_TIME), + make_accumulator=lambda inputs: ( + lambda acc, value: openai_accumulator( + acc, + value, + skip_last=not _openai_stream_options_is_set(inputs), + stream_start_time=inputs.get(WEAVE_STREAM_START_TIME), + ) ), should_accumulate=should_use_accumulator, on_finish_post_processor=openai_on_finish_post_processor, @@ -459,11 +461,13 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: op._set_on_input_handler(openai_on_input_handler) return _add_accumulator( op, # type: ignore - make_accumulator=lambda inputs: lambda acc, value: openai_accumulator( - acc, - value, - skip_last=not _openai_stream_options_is_set(inputs), - stream_start_time=inputs.get(WEAVE_STREAM_START_TIME), + make_accumulator=lambda inputs: ( + lambda acc, value: openai_accumulator( + acc, + value, + skip_last=not _openai_stream_options_is_set(inputs), + 
stream_start_time=inputs.get(WEAVE_STREAM_START_TIME), + ) ), should_accumulate=should_use_accumulator, on_finish_post_processor=openai_on_finish_post_processor, @@ -691,8 +695,8 @@ def _inner(*args: Any, **kwargs: Any) -> Any: op._set_on_input_handler(openai_on_input_handler) return _add_accumulator( op, # type: ignore - make_accumulator=lambda inputs: lambda acc, value: responses_accumulator( - acc, value + make_accumulator=lambda inputs: ( + lambda acc, value: responses_accumulator(acc, value) ), should_accumulate=should_use_responses_accumulator, on_finish_post_processor=responses_on_finish_post_processor, @@ -715,8 +719,8 @@ async def _inner(*args: Any, **kwargs: Any) -> Any: op._set_on_input_handler(openai_on_input_handler) return _add_accumulator( op, # type: ignore - make_accumulator=lambda inputs: lambda acc, value: responses_accumulator( - acc, value + make_accumulator=lambda inputs: ( + lambda acc, value: responses_accumulator(acc, value) ), should_accumulate=should_use_responses_accumulator, on_finish_post_processor=responses_on_finish_post_processor, diff --git a/weave/integrations/openai_realtime/openai_realtime_websocket_patcher.py b/weave/integrations/openai_realtime/openai_realtime_websocket_patcher.py index eb305e31358a..a658367180b2 100644 --- a/weave/integrations/openai_realtime/openai_realtime_websocket_patcher.py +++ b/weave/integrations/openai_realtime/openai_realtime_websocket_patcher.py @@ -180,9 +180,9 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: make_new_value=make_new_async_value, ), SymbolPatcher( - get_base_symbol=lambda: importlib.import_module( - "aiohttp" - ).ClientSession, + get_base_symbol=lambda: ( + importlib.import_module("aiohttp").ClientSession + ), attribute_name="ws_connect", make_new_value=make_aiohttp_ws_connect, ), diff --git a/weave/integrations/vertexai/vertexai_sdk.py b/weave/integrations/vertexai/vertexai_sdk.py index 479b61b1da86..83373c12a740 100644 --- a/weave/integrations/vertexai/vertexai_sdk.py +++ b/weave/integrations/vertexai/vertexai_sdk.py @@ -99,8 +99,9 @@ def wrapper(fn: Callable) -> Callable: return _add_accumulator( op, # type: ignore make_accumulator=lambda inputs: vertexai_accumulator, - should_accumulate=lambda inputs: isinstance(inputs, dict) - and bool(inputs.get("stream")), + should_accumulate=lambda inputs: ( + isinstance(inputs, dict) and bool(inputs.get("stream")) + ), ) return wrapper @@ -124,8 +125,9 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return _add_accumulator( op, # type: ignore make_accumulator=lambda inputs: vertexai_accumulator, - should_accumulate=lambda inputs: isinstance(inputs, dict) - and bool(inputs.get("stream")), + should_accumulate=lambda inputs: ( + isinstance(inputs, dict) and bool(inputs.get("stream")) + ), ) return wrapper diff --git a/weave/trace_server/clickhouse_trace_server_migrator.py b/weave/trace_server/clickhouse_trace_server_migrator.py index 6a12788b5c3e..89894d39bc3a 100644 --- a/weave/trace_server/clickhouse_trace_server_migrator.py +++ b/weave/trace_server/clickhouse_trace_server_migrator.py @@ -516,28 +516,36 @@ def _add_on_cluster_clause(self, sql_query: str) -> str: # ALTER TABLE if SQLPatterns.ALTER_TABLE_STMT.search(sql_query): return SQLPatterns.ALTER_TABLE_NAME_PATTERN.sub( - lambda m: f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}", + lambda m: ( + f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}" + ), sql_query, ) # CREATE TABLE if SQLPatterns.CREATE_TABLE_STMT.search(sql_query): return 
SQLPatterns.CREATE_TABLE_NAME_PATTERN.sub( - lambda m: f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}", + lambda m: ( + f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}" + ), sql_query, ) # DROP VIEW if SQLPatterns.DROP_VIEW_STMT.search(sql_query): return SQLPatterns.DROP_VIEW_NAME_PATTERN.sub( - lambda m: f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}", + lambda m: ( + f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}" + ), sql_query, ) # CREATE VIEW / CREATE MATERIALIZED VIEW if SQLPatterns.CREATE_VIEW_STMT.search(sql_query): return SQLPatterns.CREATE_VIEW_NAME_PATTERN.sub( - lambda m: f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}", + lambda m: ( + f"{m.group(1)}{m.group(2)} ON CLUSTER {self.replicated_cluster}{m.group(3)}" + ), sql_query, ) From d7552f1646c4c7eb050fc68997223ce0cf2e8263 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Tue, 3 Feb 2026 15:02:45 -0800 Subject: [PATCH 3/4] undo --- scripts/slack_digest.py | 5 +- .../test_cases/media_cases.py | 16 ++-- weave/integrations/anthropic/anthropic_sdk.py | 22 ------ .../integrations/google_genai/gemini_utils.py | 77 +++++++------------ .../huggingface_inference_client_sdk.py | 10 +-- weave/integrations/openai/openai_sdk.py | 32 ++++---- .../openai_realtime_websocket_patcher.py | 6 +- weave/integrations/vertexai/vertexai_sdk.py | 10 +-- 8 files changed, 61 insertions(+), 117 deletions(-) diff --git a/scripts/slack_digest.py b/scripts/slack_digest.py index 8f8d137fb6d7..b3dee4251d13 100644 --- a/scripts/slack_digest.py +++ b/scripts/slack_digest.py @@ -128,9 +128,8 @@ class CategoryRule: name="Python SDK", column_header="Py", emoji="🐍", - matcher=lambda path: ( - path.startswith("weave/") and not path.startswith("weave/trace_server/") - ), + matcher=lambda path: path.startswith("weave/") + and not path.startswith("weave/trace_server/"), ), ] diff --git a/tests/trace/data_serialization/test_cases/media_cases.py b/tests/trace/data_serialization/test_cases/media_cases.py index ba1478d41517..cd37b253946d 100644 --- a/tests/trace/data_serialization/test_cases/media_cases.py +++ b/tests/trace/data_serialization/test_cases/media_cases.py @@ -202,10 +202,8 @@ def markdown_equality_check(a, b): "exp_content": b'import weave\nfrom typing import Any\nfrom rich.markdown import Markdown\n\n@weave.op\ndef load(artifact: "MemTraceFilesArtifact", name: str, val: Any) -> Markdown:\n """Load markdown from file and metadata."""\n if "markup" in val:\n markup = val["markup"]\n else:\n with artifact.open("markup.md", binary=False) as f:\n markup = f.read()\n\n kwargs = {}\n if val and isinstance(val, dict) and "code_theme" in val:\n kwargs["code_theme"] = val["code_theme"]\n\n return Markdown(markup=markup, **kwargs)\n', }, ], - equality_check=lambda a, b: ( - markdown_equality_check(a["inline"], b["inline"]) - and markdown_equality_check(a["file"], b["file"]) - ), + equality_check=lambda a, b: markdown_equality_check(a["inline"], b["inline"]) + and markdown_equality_check(a["file"], b["file"]), python_version_code_capture=(3, 13), ), # Video @@ -241,9 +239,8 @@ def markdown_equality_check(a, b): "exp_content": VIDEO_BYTES, }, ], - equality_check=lambda a, b: ( - a.duration == b.duration - ), # could do better, but this is a good start + equality_check=lambda a, b: a.duration + == b.duration, # could do better, but this is a good start python_version_code_capture=(3, 13), ), # Content @@ -466,9 +463,8 @@ def 
markdown_equality_check(a, b): "exp_content": VIDEO_BYTES, }, ], - equality_check=lambda a, b: ( - a.duration == b.duration - ), # could do better, but this is a good start + equality_check=lambda a, b: a.duration + == b.duration, # could do better, but this is a good start python_version_code_capture=(3, 13), ), SerializationTestCase( diff --git a/weave/integrations/anthropic/anthropic_sdk.py b/weave/integrations/anthropic/anthropic_sdk.py index 3fbc0fc089cd..b6d34a75795b 100644 --- a/weave/integrations/anthropic/anthropic_sdk.py +++ b/weave/integrations/anthropic/anthropic_sdk.py @@ -237,18 +237,6 @@ def get_anthropic_patcher( "kind": base.kind or "llm", } ) - beta_messages_parse_settings = base.model_copy( - update={ - "name": base.name or "anthropic.beta.Messages.parse", - "kind": base.kind or "llm", - } - ) - beta_async_messages_parse_settings = base.model_copy( - update={ - "name": base.name or "anthropic.beta.AsyncMessages.parse", - "kind": base.kind or "llm", - } - ) beta_stream_settings = base.model_copy( update={ "name": base.name or "anthropic.beta.Messages.stream", @@ -295,16 +283,6 @@ def get_anthropic_patcher( "AsyncMessages.create", create_wrapper_async(beta_async_messages_create_settings), ), - SymbolPatcher( - lambda: importlib.import_module("anthropic.resources.beta.messages"), - "Messages.parse", - create_wrapper_sync(beta_messages_parse_settings), - ), - SymbolPatcher( - lambda: importlib.import_module("anthropic.resources.beta.messages"), - "AsyncMessages.parse", - create_wrapper_async(beta_async_messages_parse_settings), - ), SymbolPatcher( lambda: importlib.import_module("anthropic.resources.beta.messages"), "Messages.stream", diff --git a/weave/integrations/google_genai/gemini_utils.py b/weave/integrations/google_genai/gemini_utils.py index d97db785fdc0..bbd279f12c68 100644 --- a/weave/integrations/google_genai/gemini_utils.py +++ b/weave/integrations/google_genai/gemini_utils.py @@ -1,8 +1,6 @@ -from __future__ import annotations - from collections.abc import Callable from functools import wraps -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional import weave from weave.trace.autopatch import OpSettings @@ -76,63 +74,44 @@ def google_genai_gemini_on_finish( def google_genai_gemini_accumulator( - acc: GenerateContentResponse | None, value: GenerateContentResponse -) -> GenerateContentResponse: + acc: Optional["GenerateContentResponse"], value: "GenerateContentResponse" +) -> "GenerateContentResponse": if acc is None: return value - value_candidates = value.candidates or [] - acc_candidates = acc.candidates or [] - for i, value_candidate in enumerate(value_candidates): - if i >= len(acc_candidates): + for i, value_candidate in enumerate(value.candidates): + if i >= len(acc.candidates): break - - value_parts = value_candidate.content.parts or [] - for value_part in value_parts: - if value_part.text is None: - continue - - # Check if this part is thinking content (thought=True) - value_part_is_thought = getattr(value_part, "thought", False) - - # Find matching part by type (thought vs non-thought), not by index - matched = False - for acc_part in acc.candidates[i].content.parts: - acc_part_is_thought = getattr(acc_part, "thought", False) - if acc_part_is_thought == value_part_is_thought: - acc_part.text += value_part.text - matched = True - break - - # If no matching part found, append as new part - if not matched: - acc.candidates[i].content.parts.append(value_part) - - # Replace token counts with latest non-None values (Gemini 
returns cumulative counts) - # Per Google docs: "When streaming output, the usageMetadata attribute only appears - # on the last chunk of the stream." - if value.usage_metadata.prompt_token_count is not None: - acc.usage_metadata.prompt_token_count = value.usage_metadata.prompt_token_count - - if value.usage_metadata.candidates_token_count is not None: - acc.usage_metadata.candidates_token_count = ( + for j, value_part in enumerate(value_candidate.content.parts): + if j >= len(acc.candidates[i].content.parts): + break + if value_part.text is not None: + acc.candidates[i].content.parts[j].text += value_part.text + + if acc.usage_metadata.prompt_token_count is None: + acc.usage_metadata.prompt_token_count = 0 + elif value.usage_metadata.prompt_token_count is not None: + acc.usage_metadata.prompt_token_count += value.usage_metadata.prompt_token_count + + if acc.usage_metadata.candidates_token_count is None: + acc.usage_metadata.candidates_token_count = 0 + elif value.usage_metadata.candidates_token_count is not None: + acc.usage_metadata.candidates_token_count += ( value.usage_metadata.candidates_token_count ) - if value.usage_metadata.total_token_count is not None: - acc.usage_metadata.total_token_count = value.usage_metadata.total_token_count + if acc.usage_metadata.total_token_count is None: + acc.usage_metadata.total_token_count = 0 + elif value.usage_metadata.total_token_count is not None: + acc.usage_metadata.total_token_count += value.usage_metadata.total_token_count - if value.usage_metadata.cached_content_token_count is not None: - acc.usage_metadata.cached_content_token_count = ( + if acc.usage_metadata.cached_content_token_count is None: + acc.usage_metadata.cached_content_token_count = 0 + elif value.usage_metadata.cached_content_token_count is not None: + acc.usage_metadata.cached_content_token_count += ( value.usage_metadata.cached_content_token_count ) - # Also handle thoughts_token_count for thinking models - if getattr(value.usage_metadata, "thoughts_token_count", None) is not None: - acc.usage_metadata.thoughts_token_count = ( - value.usage_metadata.thoughts_token_count - ) - return acc diff --git a/weave/integrations/huggingface/huggingface_inference_client_sdk.py b/weave/integrations/huggingface/huggingface_inference_client_sdk.py index 77516d276bc5..00adab8ec513 100644 --- a/weave/integrations/huggingface/huggingface_inference_client_sdk.py +++ b/weave/integrations/huggingface/huggingface_inference_client_sdk.py @@ -83,9 +83,8 @@ def wrapper(fn: Callable) -> Callable: return _add_accumulator( op, # type: ignore make_accumulator=lambda inputs: huggingface_accumulator, - should_accumulate=lambda inputs: ( - isinstance(inputs, dict) and bool(inputs.get("stream")) - ), + should_accumulate=lambda inputs: isinstance(inputs, dict) + and bool(inputs.get("stream")), ) return wrapper @@ -108,9 +107,8 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return _add_accumulator( op, # type: ignore make_accumulator=lambda inputs: huggingface_accumulator, - should_accumulate=lambda inputs: ( - isinstance(inputs, dict) and bool(inputs.get("stream")) - ), + should_accumulate=lambda inputs: isinstance(inputs, dict) + and bool(inputs.get("stream")), ) return wrapper diff --git a/weave/integrations/openai/openai_sdk.py b/weave/integrations/openai/openai_sdk.py index 667fe7888eae..77be84ae7425 100644 --- a/weave/integrations/openai/openai_sdk.py +++ b/weave/integrations/openai/openai_sdk.py @@ -415,13 +415,11 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: 
op._set_on_input_handler(openai_on_input_handler) return _add_accumulator( op, # type: ignore - make_accumulator=lambda inputs: ( - lambda acc, value: openai_accumulator( - acc, - value, - skip_last=not _openai_stream_options_is_set(inputs), - stream_start_time=inputs.get(WEAVE_STREAM_START_TIME), - ) + make_accumulator=lambda inputs: lambda acc, value: openai_accumulator( + acc, + value, + skip_last=not _openai_stream_options_is_set(inputs), + stream_start_time=inputs.get(WEAVE_STREAM_START_TIME), ), should_accumulate=should_use_accumulator, on_finish_post_processor=openai_on_finish_post_processor, @@ -461,13 +459,11 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: op._set_on_input_handler(openai_on_input_handler) return _add_accumulator( op, # type: ignore - make_accumulator=lambda inputs: ( - lambda acc, value: openai_accumulator( - acc, - value, - skip_last=not _openai_stream_options_is_set(inputs), - stream_start_time=inputs.get(WEAVE_STREAM_START_TIME), - ) + make_accumulator=lambda inputs: lambda acc, value: openai_accumulator( + acc, + value, + skip_last=not _openai_stream_options_is_set(inputs), + stream_start_time=inputs.get(WEAVE_STREAM_START_TIME), ), should_accumulate=should_use_accumulator, on_finish_post_processor=openai_on_finish_post_processor, @@ -695,8 +691,8 @@ def _inner(*args: Any, **kwargs: Any) -> Any: op._set_on_input_handler(openai_on_input_handler) return _add_accumulator( op, # type: ignore - make_accumulator=lambda inputs: ( - lambda acc, value: responses_accumulator(acc, value) + make_accumulator=lambda inputs: lambda acc, value: responses_accumulator( + acc, value ), should_accumulate=should_use_responses_accumulator, on_finish_post_processor=responses_on_finish_post_processor, @@ -719,8 +715,8 @@ async def _inner(*args: Any, **kwargs: Any) -> Any: op._set_on_input_handler(openai_on_input_handler) return _add_accumulator( op, # type: ignore - make_accumulator=lambda inputs: ( - lambda acc, value: responses_accumulator(acc, value) + make_accumulator=lambda inputs: lambda acc, value: responses_accumulator( + acc, value ), should_accumulate=should_use_responses_accumulator, on_finish_post_processor=responses_on_finish_post_processor, diff --git a/weave/integrations/openai_realtime/openai_realtime_websocket_patcher.py b/weave/integrations/openai_realtime/openai_realtime_websocket_patcher.py index a658367180b2..eb305e31358a 100644 --- a/weave/integrations/openai_realtime/openai_realtime_websocket_patcher.py +++ b/weave/integrations/openai_realtime/openai_realtime_websocket_patcher.py @@ -180,9 +180,9 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: make_new_value=make_new_async_value, ), SymbolPatcher( - get_base_symbol=lambda: ( - importlib.import_module("aiohttp").ClientSession - ), + get_base_symbol=lambda: importlib.import_module( + "aiohttp" + ).ClientSession, attribute_name="ws_connect", make_new_value=make_aiohttp_ws_connect, ), diff --git a/weave/integrations/vertexai/vertexai_sdk.py b/weave/integrations/vertexai/vertexai_sdk.py index 83373c12a740..479b61b1da86 100644 --- a/weave/integrations/vertexai/vertexai_sdk.py +++ b/weave/integrations/vertexai/vertexai_sdk.py @@ -99,9 +99,8 @@ def wrapper(fn: Callable) -> Callable: return _add_accumulator( op, # type: ignore make_accumulator=lambda inputs: vertexai_accumulator, - should_accumulate=lambda inputs: ( - isinstance(inputs, dict) and bool(inputs.get("stream")) - ), + should_accumulate=lambda inputs: isinstance(inputs, dict) + and bool(inputs.get("stream")), ) return wrapper @@ -125,9 
+124,8 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return _add_accumulator( op, # type: ignore make_accumulator=lambda inputs: vertexai_accumulator, - should_accumulate=lambda inputs: ( - isinstance(inputs, dict) and bool(inputs.get("stream")) - ), + should_accumulate=lambda inputs: isinstance(inputs, dict) + and bool(inputs.get("stream")), ) return wrapper From 22cdf4c9ad544faabb2ccaef2a68f841c3808427 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Tue, 3 Feb 2026 15:22:07 -0800 Subject: [PATCH 4/4] undo --- weave/integrations/anthropic/anthropic_sdk.py | 22 ++++++ .../integrations/google_genai/gemini_utils.py | 77 ++++++++++++------- 2 files changed, 71 insertions(+), 28 deletions(-) diff --git a/weave/integrations/anthropic/anthropic_sdk.py b/weave/integrations/anthropic/anthropic_sdk.py index b6d34a75795b..3fbc0fc089cd 100644 --- a/weave/integrations/anthropic/anthropic_sdk.py +++ b/weave/integrations/anthropic/anthropic_sdk.py @@ -237,6 +237,18 @@ def get_anthropic_patcher( "kind": base.kind or "llm", } ) + beta_messages_parse_settings = base.model_copy( + update={ + "name": base.name or "anthropic.beta.Messages.parse", + "kind": base.kind or "llm", + } + ) + beta_async_messages_parse_settings = base.model_copy( + update={ + "name": base.name or "anthropic.beta.AsyncMessages.parse", + "kind": base.kind or "llm", + } + ) beta_stream_settings = base.model_copy( update={ "name": base.name or "anthropic.beta.Messages.stream", @@ -283,6 +295,16 @@ def get_anthropic_patcher( "AsyncMessages.create", create_wrapper_async(beta_async_messages_create_settings), ), + SymbolPatcher( + lambda: importlib.import_module("anthropic.resources.beta.messages"), + "Messages.parse", + create_wrapper_sync(beta_messages_parse_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("anthropic.resources.beta.messages"), + "AsyncMessages.parse", + create_wrapper_async(beta_async_messages_parse_settings), + ), SymbolPatcher( lambda: importlib.import_module("anthropic.resources.beta.messages"), "Messages.stream", diff --git a/weave/integrations/google_genai/gemini_utils.py b/weave/integrations/google_genai/gemini_utils.py index bbd279f12c68..d97db785fdc0 100644 --- a/weave/integrations/google_genai/gemini_utils.py +++ b/weave/integrations/google_genai/gemini_utils.py @@ -1,6 +1,8 @@ +from __future__ import annotations + from collections.abc import Callable from functools import wraps -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import weave from weave.trace.autopatch import OpSettings @@ -74,44 +76,63 @@ def google_genai_gemini_on_finish( def google_genai_gemini_accumulator( - acc: Optional["GenerateContentResponse"], value: "GenerateContentResponse" -) -> "GenerateContentResponse": + acc: GenerateContentResponse | None, value: GenerateContentResponse +) -> GenerateContentResponse: if acc is None: return value - for i, value_candidate in enumerate(value.candidates): - if i >= len(acc.candidates): + value_candidates = value.candidates or [] + acc_candidates = acc.candidates or [] + for i, value_candidate in enumerate(value_candidates): + if i >= len(acc_candidates): break - for j, value_part in enumerate(value_candidate.content.parts): - if j >= len(acc.candidates[i].content.parts): - break - if value_part.text is not None: - acc.candidates[i].content.parts[j].text += value_part.text - - if acc.usage_metadata.prompt_token_count is None: - acc.usage_metadata.prompt_token_count = 0 - elif value.usage_metadata.prompt_token_count is not None: - 
acc.usage_metadata.prompt_token_count += value.usage_metadata.prompt_token_count - - if acc.usage_metadata.candidates_token_count is None: - acc.usage_metadata.candidates_token_count = 0 - elif value.usage_metadata.candidates_token_count is not None: - acc.usage_metadata.candidates_token_count += ( + + value_parts = value_candidate.content.parts or [] + for value_part in value_parts: + if value_part.text is None: + continue + + # Check if this part is thinking content (thought=True) + value_part_is_thought = getattr(value_part, "thought", False) + + # Find matching part by type (thought vs non-thought), not by index + matched = False + for acc_part in acc.candidates[i].content.parts: + acc_part_is_thought = getattr(acc_part, "thought", False) + if acc_part_is_thought == value_part_is_thought: + acc_part.text += value_part.text + matched = True + break + + # If no matching part found, append as new part + if not matched: + acc.candidates[i].content.parts.append(value_part) + + # Replace token counts with latest non-None values (Gemini returns cumulative counts) + # Per Google docs: "When streaming output, the usageMetadata attribute only appears + # on the last chunk of the stream." + if value.usage_metadata.prompt_token_count is not None: + acc.usage_metadata.prompt_token_count = value.usage_metadata.prompt_token_count + + if value.usage_metadata.candidates_token_count is not None: + acc.usage_metadata.candidates_token_count = ( value.usage_metadata.candidates_token_count ) - if acc.usage_metadata.total_token_count is None: - acc.usage_metadata.total_token_count = 0 - elif value.usage_metadata.total_token_count is not None: - acc.usage_metadata.total_token_count += value.usage_metadata.total_token_count + if value.usage_metadata.total_token_count is not None: + acc.usage_metadata.total_token_count = value.usage_metadata.total_token_count - if acc.usage_metadata.cached_content_token_count is None: - acc.usage_metadata.cached_content_token_count = 0 - elif value.usage_metadata.cached_content_token_count is not None: - acc.usage_metadata.cached_content_token_count += ( + if value.usage_metadata.cached_content_token_count is not None: + acc.usage_metadata.cached_content_token_count = ( value.usage_metadata.cached_content_token_count ) + # Also handle thoughts_token_count for thinking models + if getattr(value.usage_metadata, "thoughts_token_count", None) is not None: + acc.usage_metadata.thoughts_token_count = ( + value.usage_metadata.thoughts_token_count + ) + return acc
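
Note on the sharding change in PATCH 1/4: the standalone sketch below reproduces
the sharding-key selection so the generated DDL can be inspected outside the
migrator. create_distributed_table_sql is a simplified stand-in for the
_create_distributed_table_sql method, and LOCAL_TABLE_SUFFIX is assumed to be
"_local" (the expected calls_complete_local name in the tests implies this).

# Minimal sketch of the sharding-key logic from PATCH 1/4; standalone, not
# importable weave code.
ID_SHARDED_TABLES: dict[str, str] = {"calls_complete": "id"}
LOCAL_TABLE_SUFFIX = "_local"  # assumption: mirrors ch_settings.LOCAL_TABLE_SUFFIX


def create_distributed_table_sql(table_name: str, cluster: str) -> str:
    """Build Distributed-engine DDL, choosing the sharding key per table."""
    local_table_name = table_name + LOCAL_TABLE_SUFFIX
    # ID-sharded tables hash a column so every row for one ID lands on the
    # same shard; all other tables spread rows randomly.
    if shard_field := ID_SHARDED_TABLES.get(table_name):
        sharding_key = f"sipHash64({shard_field})"
    else:
        sharding_key = "rand()"
    return (
        f"CREATE TABLE IF NOT EXISTS {table_name} ON CLUSTER {cluster}\n"
        f"AS {local_table_name}\n"
        f"ENGINE = Distributed({cluster}, currentDatabase(), "
        f"{local_table_name}, {sharding_key})"
    )


print(create_distributed_table_sql("calls_complete", "test_cluster"))
print(create_distributed_table_sql("other_table", "test_cluster"))

The first print emits the sipHash64(id) engine clause asserted in
test_create_distributed_table_sql_id_sharded; the second falls back to rand().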
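
Note on the ON CLUSTER rewriting touched by PATCH 2/4: the lint change only
re-wraps the lambdas in _add_on_cluster_clause, but the underlying regex-sub
technique is easy to miss in the diff. The sketch below illustrates it with an
assumed CREATE TABLE pattern; the real SQLPatterns constants are not shown in
this series, so the regex here is illustrative only.

import re

# Assumed shape of SQLPatterns.CREATE_TABLE_NAME_PATTERN: capture the CREATE
# prefix, the table name, and the following whitespace, then re-emit them with
# ON CLUSTER spliced in between.
CREATE_TABLE_NAME_PATTERN = re.compile(
    r"(CREATE TABLE (?:IF NOT EXISTS )?)(\S+)(\s)", re.IGNORECASE
)


def add_on_cluster(sql: str, cluster: str) -> str:
    return CREATE_TABLE_NAME_PATTERN.sub(
        lambda m: f"{m.group(1)}{m.group(2)} ON CLUSTER {cluster}{m.group(3)}",
        sql,
    )


print(add_on_cluster("CREATE TABLE IF NOT EXISTS foo (id String)", "test_cluster"))
# CREATE TABLE IF NOT EXISTS foo ON CLUSTER test_cluster (id String)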
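
Note on the Gemini accumulator restored in PATCH 4/4: a dependency-free sketch
of its two merge rules follows. Part, Usage, and Chunk are simplified stand-ins
for the google-genai GenerateContentResponse types; only the thought-flag part
matching and the replace-not-sum usage handling from the patch are reproduced.

from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Part:
    text: str
    thought: bool = False  # True for thinking-model reasoning parts


@dataclass
class Usage:
    total_token_count: int | None = None  # Gemini reports cumulative totals


@dataclass
class Chunk:
    parts: list[Part] = field(default_factory=list)
    usage: Usage = field(default_factory=Usage)


def accumulate(acc: Chunk | None, value: Chunk) -> Chunk:
    if acc is None:
        return value
    for value_part in value.parts:
        # Match parts by kind (thought vs. answer), not by index, since chunk
        # shapes vary mid-stream; append when no same-kind part exists yet.
        for acc_part in acc.parts:
            if acc_part.thought == value_part.thought:
                acc_part.text += value_part.text
                break
        else:
            acc.parts.append(value_part)
    # Counts are cumulative, so the latest non-None value replaces, never sums.
    if value.usage.total_token_count is not None:
        acc.usage.total_token_count = value.usage.total_token_count
    return acc


result = None
for chunk in [
    Chunk([Part("Plan...", thought=True)]),
    Chunk([Part("Hello")], Usage(total_token_count=12)),
    Chunk([Part(" world")], Usage(total_token_count=15)),
]:
    result = accumulate(result, chunk)
assert [p.text for p in result.parts] == ["Plan...", "Hello world"]
assert result.usage.total_token_count == 15  # replaced, not summed to 27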