|
| 1 | +from __future__ import annotations |
| 2 | + |
1 | 3 | from collections.abc import Callable |
2 | 4 | from functools import wraps |
3 | | -from typing import TYPE_CHECKING, Any, Optional |
| 5 | +from typing import TYPE_CHECKING, Any |
4 | 6 |
|
5 | 7 | import weave |
6 | 8 | from weave.trace.autopatch import OpSettings |
@@ -74,44 +76,63 @@ def google_genai_gemini_on_finish( |
74 | 76 |
|
75 | 77 |
|
def google_genai_gemini_accumulator(
    acc: GenerateContentResponse | None, value: GenerateContentResponse
) -> GenerateContentResponse:
    """Fold one streamed Gemini chunk ``value`` into the running response ``acc``.

    Text is merged per-candidate by part *type* (thought vs. non-thought), not
    by index, because thinking models interleave ``thought=True`` parts with
    normal text parts across chunks. Token counts are replaced, not summed:
    Gemini reports cumulative usage, and per Google docs "when streaming
    output, the usageMetadata attribute only appears on the last chunk of the
    stream."

    Args:
        acc: Accumulated response so far, or None on the first chunk.
        value: The newly received streamed chunk.

    Returns:
        The updated accumulated response (``value`` itself on the first chunk,
        otherwise ``acc`` mutated in place).
    """
    if acc is None:
        return value

    value_candidates = value.candidates or []
    acc_candidates = acc.candidates or []
    for i, value_candidate in enumerate(value_candidates):
        if i >= len(acc_candidates):
            break

        # Candidate content and its parts list may each be absent on a chunk.
        value_content = value_candidate.content
        value_parts = (value_content.parts if value_content is not None else None) or []
        acc_content = acc_candidates[i].content
        if acc_content is None:
            continue
        if acc_content.parts is None:
            acc_content.parts = []

        for value_part in value_parts:
            if value_part.text is None:
                continue

            # thought=True marks thinking content; merge like with like.
            value_part_is_thought = getattr(value_part, "thought", False)

            # Find a matching text-bearing part by type, not by index.
            # Requiring acc_part.text is not None avoids `None += str` on
            # non-text parts (e.g. function calls) of the same thought type.
            matched = False
            for acc_part in acc_content.parts:
                if (
                    getattr(acc_part, "thought", False) == value_part_is_thought
                    and acc_part.text is not None
                ):
                    acc_part.text += value_part.text
                    matched = True
                    break

            # No matching part found: carry the chunk's part over verbatim.
            if not matched:
                acc_content.parts.append(value_part)

    # usage_metadata may be None on intermediate chunks (it only appears on
    # the last chunk); replace counts with the latest non-None values since
    # Gemini returns cumulative totals.
    if value.usage_metadata is not None:
        if acc.usage_metadata is None:
            acc.usage_metadata = value.usage_metadata
        else:
            for field in (
                "prompt_token_count",
                "candidates_token_count",
                "total_token_count",
                "cached_content_token_count",
                # thoughts_token_count only exists for thinking models.
                "thoughts_token_count",
            ):
                count = getattr(value.usage_metadata, field, None)
                if count is not None:
                    setattr(acc.usage_metadata, field, count)

    return acc
116 | 137 |
|
117 | 138 |
|
|
0 commit comments