
Commit 6b4dbf5
chore: also wrap stream_text for pydantic-ai
1 parent c1174b0

1 file changed: +40 −33 lines

src/llmling_agent_providers/pydanticai/provider.py (+40 −33)
@@ -383,44 +383,51 @@ async def stream_response(  # type: ignore[override]
         ) as stream_result:
             stream_result = cast(StreamedRunResult[AgentContext[Any], Any], stream_result)
             original_stream = stream_result.stream
-
-            async def wrapped_stream(*args, **kwargs):
-                last_content = None
-                async for chunk in original_stream(*args, **kwargs):
-                    # Only emit if content has changed
-                    if chunk != last_content:
-                        self.chunk_streamed.emit(str(chunk), message_id)
-                        last_content = chunk
-                        yield chunk
-
-                if stream_result.is_complete:
-                    self.chunk_streamed.emit("", message_id)
-                    messages = stream_result.new_messages()
-                    tool_dict = {i.name: i for i in tools or []}
-                    # Extract and update tool calls
-                    tool_calls = get_tool_calls(messages, tool_dict, agent_name=self.name)
-                    for call in tool_calls:
-                        call.message_id = message_id
-                        call.context_data = self._context.data if self._context else None
-                        self.tool_used.emit(call)
-                    # Format final content
-                    responses = [m for m in messages if isinstance(m, ModelResponse)]
-                    parts = [p for msg in responses for p in msg.parts]
-                    content = "\n".join(format_part(p) for p in parts)
-                    resolved_model = (
-                        use_model.model_name
-                        if isinstance(use_model, Model)
-                        else str(use_model)
-                    )
-                    # Update stream result with formatted content
-                    stream_result.formatted_content = content  # type: ignore
-                    stream_result.model_name = resolved_model  # type: ignore
+            original_text_stream = stream_result.stream_text
+            resolved_model = (
+                use_model.model_name if isinstance(use_model, Model) else str(use_model)
+            )
+            stream_result.model_name = resolved_model  # type: ignore
+
+            def get_wrapped_stream(fn):
+                async def wrapped_stream(*args, **kwargs):
+                    last_content = None
+                    async for chunk in fn(*args, **kwargs):
+                        # Only emit if content has changed
+                        if chunk != last_content:
+                            self.chunk_streamed.emit(str(chunk), message_id)
+                            last_content = chunk
+                            yield chunk
+
+                    if stream_result.is_complete:
+                        self.chunk_streamed.emit("", message_id)
+                        messages = stream_result.new_messages()
+                        tool_dict = {i.name: i for i in tools or []}
+                        # Extract and update tool calls
+                        tool_calls = get_tool_calls(
+                            messages, tool_dict, agent_name=self.name
+                        )
+                        for call in tool_calls:
+                            call.message_id = message_id
+                            call.context_data = (
+                                self._context.data if self._context else None
+                            )
+                            self.tool_used.emit(call)
+                        # Format final content
+                        responses = [m for m in messages if isinstance(m, ModelResponse)]
+                        parts = [p for msg in responses for p in msg.parts]
+                        content = "\n".join(format_part(p) for p in parts)
+                        # Update stream result with formatted content
+                        stream_result.formatted_content = content  # type: ignore
+
+                return wrapped_stream
 
             if model:
                 original = self.model
                 if isinstance(original, str):
                     original = infer_model(original)
                 self.model_changed.emit(original)
 
-            stream_result.stream = wrapped_stream  # type: ignore
+            stream_result.stream = get_wrapped_stream(original_stream)  # type: ignore
+            stream_result.stream_text = get_wrapped_stream(original_text_stream)  # type: ignore
             yield stream_result  # type: ignore
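
The heart of the change is turning the one-off wrapped_stream closure into a get_wrapped_stream(fn) factory: fn is bound per call, so the same dedup-and-emit logic can wrap both stream_result.stream and the newly covered stream_result.stream_text without duplicating the bookkeeping. Below is a minimal, self-contained sketch of that pattern; ToySignal, structured_stream, and text_stream are hypothetical stand-ins (not llmling-agent or pydantic-ai APIs), and the is_complete / tool-call handling of the real wrapper is omitted.

import asyncio
from collections.abc import AsyncIterator, Callable


class ToySignal:
    """Stand-in for the provider's chunk_streamed signal."""

    def emit(self, chunk: str, message_id: str) -> None:
        print(f"[{message_id}] {chunk!r}")


def get_wrapped_stream(
    fn: Callable[..., AsyncIterator[str]],
    signal: ToySignal,
    message_id: str,
):
    """Wrap any async stream callable so duplicate chunks are neither emitted nor yielded."""

    async def wrapped_stream(*args, **kwargs) -> AsyncIterator[str]:
        last_content = None
        async for chunk in fn(*args, **kwargs):
            # Only emit if content has changed (same rule as the commit).
            if chunk != last_content:
                signal.emit(str(chunk), message_id)
                last_content = chunk
                yield chunk

    return wrapped_stream


async def structured_stream() -> AsyncIterator[str]:
    # Stand-in for a stream that may repeat partial content.
    for chunk in ["Hel", "Hel", "Hello", "Hello world"]:
        yield chunk


async def text_stream() -> AsyncIterator[str]:
    # Stand-in for a plain-text stream.
    for chunk in ["Hello", "Hello world"]:
        yield chunk


async def main() -> None:
    signal = ToySignal()
    # One factory, two wrapped streams, mirroring the commit's assignments
    # to stream_result.stream and stream_result.stream_text.
    wrapped = get_wrapped_stream(structured_stream, signal, "msg-1")
    wrapped_text = get_wrapped_stream(text_stream, signal, "msg-1")
    async for _ in wrapped():
        pass
    async for _ in wrapped_text():
        pass


asyncio.run(main())

As in the commit, yield chunk sits inside the if chunk != last_content branch, so a repeated identical chunk is neither emitted as a signal nor passed downstream.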
