
Commit 6011233

Avoid unexpected error when stream chat doesn't yield (run-llama#13422)
Fix nonyielding stream chat bug

Co-authored-by: Logan Markewich <[email protected]>
1 parent 662e0f6 commit 6011233

File tree

4 files changed: +41 -8 lines changed

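Why this matters: if the underlying stream_chat generator yields nothing, any code that reads the stream's loop variable after the loop sees an unbound name. A minimal, self-contained sketch of that failure mode and of the fix pattern follows; it is illustrative only and is not the actual llama_index callbacks code.

# Minimal sketch of the failure mode this commit fixes (illustrative only).

def empty_stream():
    # A stream_chat implementation that never yields a chunk.
    yield from []


def end_event_before_fix():
    for x in empty_stream():
        pass  # normally each chunk would be dispatched here
    # With zero iterations, `x` was never assigned, so this raises
    # UnboundLocalError instead of producing a well-formed end event.
    return {"response": x}


def end_event_after_fix():
    last_response = None
    for x in empty_stream():
        last_response = x
    # Safe: falls back to None when the stream produced nothing.
    return {"response": last_response}


if __name__ == "__main__":
    print(end_event_after_fix())  # {'response': None}
    try:
        end_event_before_fix()
    except UnboundLocalError as err:
        print(f"pre-fix behaviour: {err}")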

llama-index-core/llama_index/core/instrumentation/events/llm.py (+3 -4)

@@ -1,5 +1,4 @@
 from typing import Any, List, Optional
-
 from llama_index.core.bridge.pydantic import BaseModel
 from llama_index.core.base.llms.types import (
     ChatMessage,
@@ -138,7 +137,7 @@ class LLMChatInProgressEvent(BaseEvent):

     Args:
         messages (List[ChatMessage]): List of chat messages.
-        response (ChatResponse): Chat response currently beiung streamed.
+        response (ChatResponse): Chat response currently being streamed.
     """

     messages: List[ChatMessage]
@@ -155,11 +154,11 @@ class LLMChatEndEvent(BaseEvent):

     Args:
         messages (List[ChatMessage]): List of chat messages.
-        response (ChatResponse): Chat response.
+        response (Optional[ChatResponse]): Last chat response.
     """

     messages: List[ChatMessage]
-    response: ChatResponse
+    response: Optional[ChatResponse]

     @classmethod
     def class_name(cls):
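Because LLMChatEndEvent.response is now Optional[ChatResponse], downstream consumers should guard against None before reading the message. A hedged sketch of such a guard; the on_chat_end function is an illustrative assumption and is not part of this commit.

# Consumer-side guard for the now-Optional response field (illustrative).
from typing import Optional

from llama_index.core.base.llms.types import ChatResponse
from llama_index.core.instrumentation.events.llm import LLMChatEndEvent


def on_chat_end(event: LLMChatEndEvent) -> str:
    # response may now be None when the stream never yielded a chunk.
    response: Optional[ChatResponse] = event.response
    if response is None:
        return "<stream produced no chunks>"
    return response.message.content or ""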

llama-index-core/llama_index/core/llms/callbacks.py (+2 -2)

@@ -97,7 +97,7 @@ async def wrapped_gen() -> ChatResponseAsyncGen:
             dispatcher.event(
                 LLMChatEndEvent(
                     messages=messages,
-                    response=x,
+                    response=last_response,
                     span_id=span_id,
                 )
             )
@@ -173,7 +173,7 @@ def wrapped_gen() -> ChatResponseGen:
             dispatcher.event(
                 LLMChatEndEvent(
                     messages=messages,
-                    response=x,
+                    response=last_response,
                     span_id=span_id,
                 )
             )
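The hunks above only swap response=x for response=last_response; the surrounding wrapper (not shown in the diff) is assumed to track last_response as chunks are yielded. A simplified sketch of that pattern, with wrap_stream and emit_end_event as illustrative stand-ins for the real wrapped generator and dispatcher call:

# Simplified, hedged sketch of the generator-wrapping pattern (illustrative).
from typing import Generator, Optional

from llama_index.core.base.llms.types import ChatResponse


def wrap_stream(
    inner: Generator[ChatResponse, None, None],
) -> Generator[ChatResponse, None, None]:
    last_response: Optional[ChatResponse] = None
    for chunk in inner:
        yield chunk
        last_response = chunk
    # After exhaustion, report the final chunk (or None for an empty stream)
    # rather than the loop variable, which would be unbound for zero chunks.
    emit_end_event(last_response)


def emit_end_event(response: Optional[ChatResponse]) -> None:
    # Stand-in for dispatcher.event(LLMChatEndEvent(...)); illustrative only.
    print("stream finished; last response:", response)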

llama-index-core/llama_index/core/llms/mock.py (+10 -2)

@@ -1,13 +1,13 @@
 from typing import Any, Callable, Optional, Sequence
-
 from llama_index.core.base.llms.types import (
     ChatMessage,
+    ChatResponseGen,
     CompletionResponse,
     CompletionResponseGen,
     LLMMetadata,
 )
 from llama_index.core.callbacks import CallbackManager
-from llama_index.core.llms.callbacks import llm_completion_callback
+from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback
 from llama_index.core.llms.custom import CustomLLM
 from llama_index.core.types import PydanticProgramMode

@@ -76,3 +76,11 @@ def gen_response(max_tokens: int) -> CompletionResponseGen:
                 )

         return gen_response(self.max_tokens) if self.max_tokens else gen_prompt()
+
+
+class MockLLMWithNonyieldingChatStream(MockLLM):
+    @llm_chat_callback()
+    def stream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseGen:
+        yield from []
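The yield from [] body makes stream_chat a generator function that completes without producing a single chunk, which is exactly the edge case the fix targets. A quick, hedged usage sketch mirroring the new tests below (assuming llama-index-core with this commit installed):

# Usage sketch: exercising the non-yielding mock directly (illustrative).
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.llms.mock import MockLLMWithNonyieldingChatStream

llm = MockLLMWithNonyieldingChatStream()
chunks = list(llm.stream_chat([ChatMessage(role="user", content="hello")]))
# With this commit, draining the empty stream completes cleanly and the
# end-of-chat event carries response=None instead of raising an error.
assert chunks == []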

llama-index-core/tests/llms/test_callbacks.py (+26)

@@ -1,6 +1,13 @@
 import pytest
+from llama_index.core.base.llms.types import ChatMessage
 from llama_index.core.llms.llm import LLM
 from llama_index.core.llms.mock import MockLLM
+from llama_index.core.llms.mock import MockLLMWithNonyieldingChatStream
+
+
+@pytest.fixture()
+def nonyielding_llm() -> LLM:
+    return MockLLMWithNonyieldingChatStream()


 @pytest.fixture()
@@ -13,6 +20,25 @@ def prompt() -> str:
     return "test prompt"


+def test_llm_stream_chat_handles_nonyielding_stream(
+    nonyielding_llm: LLM, prompt: str
+) -> None:
+    response = nonyielding_llm.stream_chat([ChatMessage(role="user", content=prompt)])
+    for _ in response:
+        pass
+
+
+@pytest.mark.asyncio()
+async def test_llm_astream_chat_handles_nonyielding_stream(
+    nonyielding_llm: LLM, prompt: str
+) -> None:
+    response = await nonyielding_llm.astream_chat(
+        [ChatMessage(role="user", content=prompt)]
+    )
+    async for _ in response:
+        pass
+
+
 def test_llm_complete_prompt_arg(llm: LLM, prompt: str) -> None:
     res = llm.complete(prompt)
     expected_res_text = prompt
