Skip to content

Commit 3c72cb2

Browse files
committed
feat: send full prediction on streaming final_response
1 parent 7893a03 commit 3c72cb2

File tree

8 files changed

+47
-68
lines changed

8 files changed

+47
-68
lines changed

python/src/cairo_coder/core/rag_pipeline.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,9 @@ async def aforward_streaming(
210210
)
211211

212212
mcp_prediction = self.mcp_generation_program(documents)
213+
# Emit single response plus a final response event for clients that rely on it
213214
yield StreamEvent(type=StreamEventType.RESPONSE, data=mcp_prediction.answer)
215+
yield StreamEvent(type=StreamEventType.FINAL_RESPONSE, data=mcp_prediction.answer)
214216
else:
215217
# Normal mode: Generate response
216218
yield StreamEvent(type=StreamEventType.PROCESSING, data="Generating response...")
@@ -223,12 +225,19 @@ async def aforward_streaming(
223225
adapter=dspy.adapters.ChatAdapter()
224226
), ls.trace(name="GenerationProgramStreaming", run_type="llm", inputs={"query": query, "chat_history": chat_history_str, "context": context}) as rt:
225227
chunk_accumulator = ""
228+
final_text: str | None = None
226229
async for chunk in self.generation_program.aforward_streaming(
227230
query=query, context=context, chat_history=chat_history_str
228231
):
229-
chunk_accumulator += chunk
230-
yield StreamEvent(type=StreamEventType.RESPONSE, data=chunk)
231-
rt.end(outputs={"output": chunk_accumulator})
232+
if isinstance(chunk, dspy.streaming.StreamResponse):
233+
# Incremental token
234+
chunk_accumulator += chunk.chunk
235+
yield StreamEvent(type=StreamEventType.RESPONSE, data=chunk.chunk)
236+
elif isinstance(chunk, dspy.Prediction):
237+
# Final complete answer
238+
final_text = getattr(chunk, "answer", None) or chunk_accumulator
239+
yield StreamEvent(type=StreamEventType.FINAL_RESPONSE, data=final_text)
240+
rt.end(outputs={"output": final_text})
232241

233242
# Pipeline completed
234243
yield StreamEvent(type=StreamEventType.END, data=None)

python/src/cairo_coder/core/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ class StreamEventType(str, Enum):
113113
SOURCES = "sources"
114114
PROCESSING = "processing"
115115
RESPONSE = "response"
116+
FINAL_RESPONSE = "final_response"
116117
END = "end"
117118
ERROR = "error"
118119

python/src/cairo_coder/dspy/generation_program.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ async def aforward(self, query: str, context: str, chat_history: Optional[str] =
227227

228228
async def aforward_streaming(
229229
self, query: str, context: str, chat_history: Optional[str] = None
230-
) -> AsyncGenerator[str, None]:
230+
) -> AsyncGenerator[object, None]:
231231
"""
232232
Generate Cairo code response with streaming support using DSPy's native streaming.
233233
@@ -255,18 +255,8 @@ async def aforward_streaming(
255255
query=query, context=context, chat_history=chat_history
256256
)
257257

258-
# Process the stream and yield tokens
259-
is_cached = True
260258
async for chunk in output_stream:
261-
if isinstance(chunk, dspy.streaming.StreamResponse):
262-
# No streaming if cached
263-
is_cached = False
264-
# Yield the actual token content
265-
yield chunk.chunk
266-
elif isinstance(chunk, dspy.Prediction):
267-
if is_cached:
268-
yield chunk.answer
269-
# Final output received - streaming is complete
259+
yield chunk
270260

271261
def _format_chat_history(self, chat_history: list[Message]) -> str:
272262
"""

python/src/cairo_coder/server/app.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
AgentLoggingCallback,
3131
RagPipeline,
3232
)
33-
from cairo_coder.core.types import Message, Role
33+
from cairo_coder.core.types import Message, Role, StreamEventType
3434
from cairo_coder.dspy.document_retriever import SourceFilteredPgVectorRM
3535
from cairo_coder.dspy.suggestion_program import SuggestionGeneration
3636
from cairo_coder.utils.logging import setup_logging
@@ -401,7 +401,7 @@ async def _handle_chat_completion(
401401
) from e
402402

403403
async def _stream_chat_completion(
404-
self, agent, query: str, history: list[Message], mcp_mode: bool
404+
self, agent: RagPipeline, query: str, history: list[Message], mcp_mode: bool
405405
) -> AsyncGenerator[str, None]:
406406
"""Stream chat completion response - replicates TypeScript streaming."""
407407
response_id = str(uuid.uuid4())
@@ -425,14 +425,14 @@ async def _stream_chat_completion(
425425
async for event in agent.aforward_streaming(
426426
query=query, chat_history=history, mcp_mode=mcp_mode
427427
):
428-
if event.type == "sources":
428+
if event.type == StreamEventType.SOURCES:
429429
# Emit sources event for clients to display
430430
sources_chunk = {
431431
"type": "sources",
432432
"data": event.data,
433433
}
434434
yield f"data: {json.dumps(sources_chunk)}\n\n"
435-
elif event.type == "response":
435+
elif event.type == StreamEventType.RESPONSE:
436436
content_buffer += event.data
437437

438438
# Send content chunk
@@ -446,7 +446,14 @@ async def _stream_chat_completion(
446446
],
447447
}
448448
yield f"data: {json.dumps(chunk)}\n\n"
449-
elif event.type == "error":
449+
elif event.type == StreamEventType.FINAL_RESPONSE:
450+
# Emit an explicit final response event for clients
451+
final_event = {
452+
"type": "final_response",
453+
"data": event.data,
454+
}
455+
yield f"data: {json.dumps(final_event)}\n\n"
456+
elif event.type == StreamEventType.ERROR:
450457
# Emit an error as a final delta and stop
451458
error_chunk = {
452459
"id": response_id,
@@ -463,7 +470,7 @@ async def _stream_chat_completion(
463470
}
464471
yield f"data: {json.dumps(error_chunk)}\n\n"
465472
break
466-
elif event.type == "end":
473+
elif event.type == StreamEventType.END:
467474
break
468475
rt.end(outputs={"output": content_buffer})
469476

python/tests/conftest.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,12 @@ async def mock_aforward_streaming(
168168
],
169169
)
170170
yield StreamEvent(type=StreamEventType.RESPONSE, data="Cairo is a programming language")
171+
yield StreamEvent(type=StreamEventType.FINAL_RESPONSE, data="Cairo is a programming language")
171172
else:
172173
# Normal mode returns response
173174
yield StreamEvent(type=StreamEventType.RESPONSE, data="Hello! I'm Cairo Coder.")
174175
yield StreamEvent(type=StreamEventType.RESPONSE, data=" How can I help you?")
176+
yield StreamEvent(type=StreamEventType.FINAL_RESPONSE, data="Hello! I'm Cairo Coder. How can I help you?")
175177
yield StreamEvent(type=StreamEventType.END, data="")
176178

177179
def mock_forward(query: str, chat_history: list[Message] | None = None, mcp_mode: bool = False):
@@ -369,8 +371,9 @@ def mock_generation_program():
369371
program.get_lm_usage = Mock(return_value={})
370372

371373
async def mock_streaming(*args, **kwargs):
372-
yield "Here's how to write "
373-
yield "Cairo contracts..."
374+
yield dspy.streaming.StreamResponse(predict_name="GenerationProgram", signature_field_name="answer", chunk="Here's how to write ", is_last_chunk=False)
375+
yield dspy.streaming.StreamResponse(predict_name="GenerationProgram", signature_field_name="answer", chunk="Cairo contracts...", is_last_chunk=True)
376+
yield dspy.Prediction(answer="Here's how to write Cairo contracts...")
374377

375378
program.aforward_streaming = mock_streaming
376379
return program

python/tests/integration/conftest.py

Lines changed: 4 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -7,48 +7,14 @@
77

88
from unittest.mock import AsyncMock, Mock
99

10+
import dspy
1011
import pytest
1112
from fastapi.testclient import TestClient
1213

1314
from cairo_coder.agents.registry import AgentId
1415
from cairo_coder.server.app import get_agent_factory, get_vector_db
1516

1617

17-
@pytest.fixture
18-
def patch_dspy_streaming_success(monkeypatch):
19-
"""Patch dspy.streamify to emit token-like chunks and provide StreamListener.
20-
21-
Yields two chunks: "Hello " and "world".
22-
"""
23-
import dspy
24-
25-
class FakeStreamResponse:
26-
def __init__(self, chunk: str):
27-
self.chunk = chunk
28-
29-
class FakeStreamListener:
30-
def __init__(self, signature_field_name: str): # noqa: ARG002
31-
pass
32-
33-
monkeypatch.setattr(
34-
dspy,
35-
"streaming",
36-
type("S", (), {"StreamResponse": FakeStreamResponse, "StreamListener": FakeStreamListener}),
37-
)
38-
39-
def fake_streamify(_program, stream_listeners=None): # noqa: ARG001
40-
def runner(**kwargs): # noqa: ARG001
41-
async def gen():
42-
yield FakeStreamResponse("Hello ")
43-
yield FakeStreamResponse("world")
44-
45-
return gen()
46-
47-
return runner
48-
49-
monkeypatch.setattr(dspy, "streamify", fake_streamify)
50-
51-
5218
@pytest.fixture
5319
def patch_dspy_streaming_error(monkeypatch, real_pipeline):
5420
"""Patch dspy.streamify to raise an error mid-stream and provide StreamListener.
@@ -140,8 +106,9 @@ async def _fake_gen_aforward(query: str, context: str, chat_history: str | None
140106
return _dspy.Prediction(answer=responses[idx])
141107

142108
async def _fake_gen_aforward_streaming(query: str, context: str, chat_history: str | None = None):
143-
yield "Hello! I'm Cairo Coder, "
144-
yield "ready to help with Cairo programming."
109+
yield dspy.streaming.StreamResponse(predict_name="GenerationProgram", signature_field_name="answer", chunk="Hello! I'm Cairo Coder, ", is_last_chunk=False)
110+
yield dspy.streaming.StreamResponse(predict_name="GenerationProgram", signature_field_name="answer", chunk="ready to help with Cairo programming.", is_last_chunk=True)
111+
yield dspy.Prediction(answer="Hello! I'm Cairo Coder, ready to help with Cairo programming.")
145112

146113
pipeline.generation_program.aforward = AsyncMock(side_effect=_fake_gen_aforward)
147114
pipeline.generation_program.aforward_streaming = _fake_gen_aforward_streaming

python/tests/integration/test_server_integration.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ async def mock_aforward(query: str, chat_history=None, mcp_mode=False, **kwargs)
115115
def test_streaming_integration(
116116
self,
117117
client: TestClient,
118-
patch_dspy_streaming_success,
119118
):
120119
"""Test streaming response end-to-end using a real pipeline with low-level patches."""
121120

@@ -457,8 +456,10 @@ def test_openai_streaming_response_structure(self, client: TestClient):
457456
if data_str != "[DONE]":
458457
chunks.append(json.loads(data_str))
459458

460-
# Filter out sources events (custom event type for frontend)
461-
openai_chunks = [chunk for chunk in chunks if chunk.get("type") != "sources"]
459+
# Filter out custom frontend events (sources, final_response)
460+
openai_chunks = [
461+
chunk for chunk in chunks if chunk.get("type") not in ("sources", "final_response")
462+
]
462463

463464
for chunk in openai_chunks:
464465
required_fields = ["id", "object", "created", "model", "choices"]

python/tests/unit/test_rag_pipeline.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
RagPipelineConfig,
1515
RagPipelineFactory,
1616
)
17-
from cairo_coder.core.types import Document, DocumentSource, Message, Role
17+
from cairo_coder.core.types import Document, DocumentSource, Message, Role, StreamEventType
1818
from cairo_coder.dspy.retrieval_judge import RetrievalJudge
1919

2020

@@ -148,10 +148,11 @@ async def test_streaming_pipeline_execution(self, pipeline):
148148

149149
# Verify event sequence
150150
event_types = [e.type for e in events]
151-
assert "processing" in event_types
152-
assert "sources" in event_types
153-
assert "response" in event_types
154-
assert "end" in event_types
151+
assert StreamEventType.PROCESSING in event_types
152+
assert StreamEventType.SOURCES in event_types
153+
assert StreamEventType.RESPONSE in event_types
154+
assert StreamEventType.FINAL_RESPONSE in event_types
155+
assert StreamEventType.END in event_types
155156

156157
@pytest.mark.asyncio
157158
async def test_mcp_mode_execution(self, pipeline):

0 commit comments

Comments
 (0)