diff --git a/python/AGENTS.md b/python/AGENTS.md index de20bcc..6dd03c7 100644 --- a/python/AGENTS.md +++ b/python/AGENTS.md @@ -12,6 +12,7 @@ This repo uses a unified, deterministic testing infrastructure to keep tests fas - Unit client uses `mock_agent_factory` and `mock_vector_db`. - Integration client injects a real `RagPipeline` wired to `mock_query_processor` + `mock_vector_db` (via the same `mock_agent_factory`). - Replace ad‑hoc stubs with shared fixtures: `sample_processed_query`, `mock_query_processor`, `sample_documents`, and `mock_returned_documents` (built from `sample_documents`). +- Respect declared types. When a signature says an argument is of type `T`, do not guard it with `is None` or `hasattr` checks for attributes `T` already declares; call the method directly and let the type checker surface bugs. (Example: if a value is typed `dspy.Prediction`, call `get_lm_usage()` directly and set usage via `set_lm_usage`, rather than coding defensively as if those attributes might be missing.) ## DSPy/LLM Behavior diff --git a/python/src/cairo_coder/agents/registry.py b/python/src/cairo_coder/agents/registry.py index 1a9ac00..e6e7b55 100644 --- a/python/src/cairo_coder/agents/registry.py +++ b/python/src/cairo_coder/agents/registry.py @@ -5,8 +5,10 @@ agent system with a simple, in-memory registry of available agents. """ -from dataclasses import dataclass +from collections.abc import Callable +from dataclasses import dataclass, field from enum import Enum +from typing import Any from cairo_coder.core.config import VectorStoreConfig from cairo_coder.core.rag_pipeline import RagPipeline, RagPipelineFactory @@ -33,7 +35,8 @@ class AgentSpec: name: str description: str sources: list[DocumentSource] - generation_program_type: AgentId + pipeline_builder: Callable[..., RagPipeline] + builder_kwargs: dict[str, Any] = field(default_factory=dict) max_source_count: int = 5 similarity_threshold: float = 0.4 @@ -48,31 +51,15 @@ def build(self, vector_db: SourceFilteredPgVectorRM, vector_store_config: Vector Returns: Configured RagPipeline instance """ - match self.generation_program_type: - case AgentId.STARKNET: - return RagPipelineFactory.create_pipeline( - name=self.name, - vector_store_config=vector_store_config, - sources=self.sources, - query_processor=create_query_processor(), - generation_program=create_generation_program(AgentId.STARKNET), - mcp_generation_program=create_mcp_generation_program(), - max_source_count=self.max_source_count, - similarity_threshold=self.similarity_threshold, - vector_db=vector_db, - ) - case AgentId.CAIRO_CODER: - return RagPipelineFactory.create_pipeline( - name=self.name, - vector_store_config=vector_store_config, - sources=self.sources, - query_processor=create_query_processor(), - generation_program=create_generation_program(AgentId.CAIRO_CODER), - mcp_generation_program=create_mcp_generation_program(), - max_source_count=self.max_source_count, - similarity_threshold=self.similarity_threshold, - vector_db=vector_db, - ) + return self.pipeline_builder( + name=self.name, + vector_store_config=vector_store_config, + vector_db=vector_db, + sources=self.sources, + max_source_count=self.max_source_count, + similarity_threshold=self.similarity_threshold, + **self.builder_kwargs, + ) # The global registry of available agents @@ -81,7 +68,12 @@ def build(self, vector_db: SourceFilteredPgVectorRM, vector_store_config: Vector name="Cairo Coder", description="General Cairo programming assistant", sources=list(DocumentSource), # All sources - generation_program_type=AgentId.CAIRO_CODER, +
pipeline_builder=RagPipelineFactory.create_pipeline, + builder_kwargs={ + "query_processor": create_query_processor(), + "generation_program": create_generation_program(AgentId.CAIRO_CODER), + "mcp_generation_program": create_mcp_generation_program(), + }, max_source_count=5, similarity_threshold=0.4, ), @@ -89,7 +81,12 @@ def build(self, vector_db: SourceFilteredPgVectorRM, vector_store_config: Vector name="Starknet Agent", description="Assistant for the Starknet ecosystem (contracts, tools, docs).", sources=list(DocumentSource), - generation_program_type=AgentId.STARKNET, + pipeline_builder=RagPipelineFactory.create_pipeline, + builder_kwargs={ + "query_processor": create_query_processor(), + "generation_program": create_generation_program(AgentId.STARKNET), + "mcp_generation_program": create_mcp_generation_program(), + }, max_source_count=5, similarity_threshold=0.4, ), diff --git a/python/src/cairo_coder/core/rag_pipeline.py b/python/src/cairo_coder/core/rag_pipeline.py index 91c3496..1397885 100644 --- a/python/src/cairo_coder/core/rag_pipeline.py +++ b/python/src/cairo_coder/core/rag_pipeline.py @@ -19,10 +19,12 @@ from cairo_coder.core.types import ( Document, DocumentSource, + FormattedSource, Message, ProcessedQuery, StreamEvent, StreamEventType, + combine_usage, title_from_url, ) from cairo_coder.dspy.document_retriever import DocumentRetrieverProgram @@ -82,11 +84,28 @@ def __init__(self, config: RagPipelineConfig): self._current_processed_query: ProcessedQuery | None = None self._current_documents: list[Document] = [] + # Token usage accumulator + self._accumulated_usage: dict[str, dict[str, int]] = {} + @property def last_retrieved_documents(self) -> list[Document]: """Documents retrieved during the most recent pipeline execution.""" return self._current_documents + def _accumulate_usage(self, prediction: dspy.Prediction) -> None: + """ + Accumulate token usage from a prediction. + + Args: + prediction: DSPy prediction object with usage information + """ + usage = prediction.get_lm_usage() + self._accumulated_usage = combine_usage(self._accumulated_usage, usage) + + def _reset_usage(self) -> None: + """Reset accumulated usage for a new request.""" + self._accumulated_usage = {} + async def _aprocess_query_and_retrieve_docs( self, query: str, @@ -94,21 +113,28 @@ async def _aprocess_query_and_retrieve_docs( sources: list[DocumentSource] | None = None, ) -> tuple[ProcessedQuery, list[Document]]: """Process query and retrieve documents - shared async logic.""" - processed_query = await self.query_processor.aforward( + qp_prediction = await self.query_processor.aforward( query=query, chat_history=chat_history_str ) + self._accumulate_usage(qp_prediction) + processed_query = qp_prediction.processed_query self._current_processed_query = processed_query # Use provided sources or fall back to processed query sources retrieval_sources = sources or processed_query.resources - documents = await self.document_retriever.aforward( + dr_prediction = await self.document_retriever.aforward( processed_query=processed_query, sources=retrieval_sources ) + self._accumulate_usage(dr_prediction) + documents = dr_prediction.documents # Optional Grok web/X augmentation: activate when STARKNET_BLOG is among sources. 
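The registry refactor above turns per-agent pipeline construction into data: each `AgentSpec` carries a `pipeline_builder` callable plus `builder_kwargs`, and `AgentSpec.build()` just forwards the spec's retrieval settings to it. As a rough sketch of what adding another agent would now look like, an entry next to the existing ones in `registry.py` could read as follows (the `AgentId.SCARB` member and the agent's details are hypothetical, not part of this diff):

```python
# Hypothetical entry for illustration only; AgentId.SCARB is not defined in this PR.
# It would sit next to the existing entries in python/src/cairo_coder/agents/registry.py.
AgentId.SCARB: AgentSpec(
    name="Scarb Assistant",
    description="Assistant focused on Scarb build-tool questions.",
    sources=[DocumentSource.SCARB_DOCS],
    pipeline_builder=RagPipelineFactory.create_pipeline,
    builder_kwargs={
        "query_processor": create_query_processor(),
        # Reusing an existing generation program type here is an assumption.
        "generation_program": create_generation_program(AgentId.CAIRO_CODER),
        "mcp_generation_program": create_mcp_generation_program(),
    },
    max_source_count=5,
    similarity_threshold=0.4,
),
```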
try: if DocumentSource.STARKNET_BLOG in retrieval_sources: - grok_docs = await self.grok_search.aforward(processed_query, chat_history_str) + grok_pred = await self.grok_search.aforward(processed_query, chat_history_str) + self._accumulate_usage(grok_pred) + grok_docs = grok_pred.documents + self._grok_citations = list(self.grok_search.last_citations) if grok_docs: documents.extend(grok_docs) @@ -126,7 +152,9 @@ async def _aprocess_query_and_retrieve_docs( lm=dspy.LM("gemini/gemini-flash-lite-latest", max_tokens=10000, temperature=0.5), adapter=XMLAdapter(), ): - documents = await self.retrieval_judge.aforward(query=query, documents=documents) + judge_pred = await self.retrieval_judge.aforward(query=query, documents=documents) + self._accumulate_usage(judge_pred) + documents = judge_pred.documents except Exception as e: logger.warning( "Retrieval judge failed (async), using all documents", @@ -158,6 +186,9 @@ async def aforward( mcp_mode: bool = False, sources: list[DocumentSource] | None = None, ) -> dspy.Prediction: + # Reset usage for this request + self._reset_usage() + chat_history_str = self._format_chat_history(chat_history or []) processed_query, documents = await self._aprocess_query_and_retrieve_docs( query, chat_history_str, sources @@ -167,13 +198,21 @@ async def aforward( ) if mcp_mode: - return await self.mcp_generation_program.aforward(documents) + result = await self.mcp_generation_program.aforward(documents) + self._accumulate_usage(result) + result.set_lm_usage(self._accumulated_usage) + return result context = self._prepare_context(documents) - return await self.generation_program.aforward( + result = await self.generation_program.aforward( query=query, context=context, chat_history=chat_history_str ) + if result: + self._accumulate_usage(result) + # Update the result's usage to include accumulated usage from previous steps + result.set_lm_usage(self._accumulated_usage) + return result async def aforward_streaming( @@ -251,6 +290,7 @@ async def aforward_streaming( logger.warning(f"Unknown signature field name: {chunk.signature_field_name}") elif isinstance(chunk, dspy.Prediction): # Final complete answer + self._accumulate_usage(chunk) final_text = getattr(chunk, "answer", None) or chunk_accumulator yield StreamEvent(type=StreamEventType.FINAL_RESPONSE, data=final_text) rt.end(outputs={"output": final_text}) @@ -268,28 +308,12 @@ async def aforward_streaming( def get_lm_usage(self) -> dict[str, dict[str, int]]: """ - Get the total number of tokens used by the LLM. - """ - generation_usage = self.generation_program.get_lm_usage() - query_usage = self.query_processor.get_lm_usage() - judge_usage = self.retrieval_judge.get_lm_usage() - - # Additive merge strategy - merged_usage = {} - - # Helper function to merge usage dictionaries - def merge_usage_dict(target: dict, source: dict) -> None: - for model_name, metrics in source.items(): - if model_name not in target: - target[model_name] = {} - for metric_name, value in metrics.items(): - target[model_name][metric_name] = target[model_name].get(metric_name, 0) + value + Get accumulated token usage from all predictions in the pipeline. 
- merge_usage_dict(merged_usage, generation_usage) - merge_usage_dict(merged_usage, query_usage) - merge_usage_dict(merged_usage, judge_usage) - - return merged_usage + Returns: + Dictionary mapping model names to usage metrics + """ + return self._accumulated_usage def _format_chat_history(self, chat_history: list[Message]) -> str: """ @@ -311,7 +335,7 @@ def _format_chat_history(self, chat_history: list[Message]) -> str: return "\n".join(formatted_messages) - def _format_sources(self, documents: list[Document]) -> list[dict[str, Any]]: + def _format_sources(self, documents: list[Document]) -> list[FormattedSource]: """ Format documents for the frontend-friendly sources event. @@ -322,9 +346,9 @@ def _format_sources(self, documents: list[Document]) -> list[dict[str, Any]]: documents: List of retrieved documents Returns: - List of dicts: [{"title": str, "url": str}, ...] + List of formatted sources with metadata """ - sources: list[dict[str, str]] = [] + sources: list[FormattedSource] = [] seen_urls: set[str] = set() diff --git a/python/src/cairo_coder/core/types.py b/python/src/cairo_coder/core/types.py index 5a5b7ad..8486f50 100644 --- a/python/src/cairo_coder/core/types.py +++ b/python/src/cairo_coder/core/types.py @@ -74,6 +74,29 @@ class ProcessedQuery: is_test_related: bool = False resources: list[DocumentSource] = field(default_factory=list) +LMUsageEntry = dict[str, Any] +LMUsage = dict[str, LMUsageEntry] + + +class RetrievedSourceData(TypedDict): + """Structure for retrieved source data stored in database.""" + + page_content: str + metadata: DocumentMetadata + + +class FormattedSourceMetadata(TypedDict): + """Metadata structure for formatted sources sent to frontend.""" + + title: str + url: str + source_type: str + + +class FormattedSource(TypedDict): + """Structure for formatted sources sent to frontend.""" + + metadata: FormattedSourceMetadata # Helper to extract domain title def title_from_url(url: str) -> str: @@ -174,6 +197,33 @@ def to_dict(self) -> dict[str, Any]: "details": self.details, "timestamp": self.timestamp.isoformat(), } + + +def combine_usage(usage1: LMUsage, usage2: LMUsage) -> LMUsage: + """Combine two LM usage dictionaries, tolerating missing inputs.""" + result: LMUsage = {model: (metrics or {}).copy() for model, metrics in usage1.items()} + + for model, metrics in usage2.items(): + if model not in result: + result[model] = metrics.copy() + else: + # Merge metrics + for key, value in metrics.items(): + if isinstance(value, int | float): + result[model][key] = result[model].get(key, 0) + value + elif isinstance(value, dict): + if key not in result[model] or result[model][key] is None: + result[model][key] = value.copy() + else: + # Recursive merge for nested dicts + for detail_key, detail_value in value.items(): + if isinstance(detail_value, int | float): + result[model][key][detail_key] = ( + result[model][key].get(detail_key, 0) + detail_value + ) + return result + + class AgentResponse(BaseModel): """Response from agent processing.""" diff --git a/python/src/cairo_coder/db/models.py b/python/src/cairo_coder/db/models.py index fd87b7b..6a1a243 100644 --- a/python/src/cairo_coder/db/models.py +++ b/python/src/cairo_coder/db/models.py @@ -10,6 +10,8 @@ from pydantic import BaseModel, Field +from cairo_coder.core.types import RetrievedSourceData + class UserInteraction(BaseModel): """Represents a record in the user_interactions table.""" @@ -21,5 +23,5 @@ class UserInteraction(BaseModel): chat_history: Optional[list[dict[str, Any]]] = None query: str 
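The new `combine_usage` helper adds numeric metrics per model and merges one level of nested detail dicts; models seen only in the second argument are copied across. A quick sketch of the expected behaviour (the model name and metric keys are illustrative):

```python
from cairo_coder.core.types import combine_usage

first = {
    "gemini/gemini-flash": {
        "prompt_tokens": 120,
        "completion_tokens": 40,
        "completion_tokens_details": {"reasoning_tokens": 10},
    }
}
second = {
    "gemini/gemini-flash": {
        "prompt_tokens": 30,
        "completion_tokens": 5,
        "completion_tokens_details": {"reasoning_tokens": 2},
    },
    "openai/gpt-4o-mini": {"prompt_tokens": 15},
}

merged = combine_usage(first, second)

assert merged["gemini/gemini-flash"]["prompt_tokens"] == 150  # numeric metrics are summed
assert merged["gemini/gemini-flash"]["completion_tokens_details"] == {"reasoning_tokens": 12}
assert merged["openai/gpt-4o-mini"] == {"prompt_tokens": 15}  # unseen models are copied over
```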
generated_answer: Optional[str] = None - retrieved_sources: Optional[list[dict[str, Any]]] = None + retrieved_sources: Optional[list[RetrievedSourceData]] = None llm_usage: Optional[dict[str, Any]] = None diff --git a/python/src/cairo_coder/dspy/document_retriever.py b/python/src/cairo_coder/dspy/document_retriever.py index 0595dad..192aa15 100644 --- a/python/src/cairo_coder/dspy/document_retriever.py +++ b/python/src/cairo_coder/dspy/document_retriever.py @@ -565,7 +565,7 @@ def __init__( async def aforward( self, processed_query: ProcessedQuery, sources: list[DocumentSource] | None = None - ) -> list[Document]: + ) -> dspy.Prediction: """ Execute the document retrieval process asynchronously. @@ -574,7 +574,7 @@ async def aforward( sources: Optional list of DocumentSource to filter by Returns: - List of relevant Document objects, ranked by similarity + dspy.Prediction containing list of relevant Document objects, ranked by similarity """ # Use sources from processed query if not provided if sources is None: @@ -584,10 +584,15 @@ async def aforward( documents = await self._afetch_documents(processed_query, sources) if not documents: - return [] + empty_prediction = dspy.Prediction(documents=[]) + empty_prediction.set_lm_usage({}) + return empty_prediction # Step 2: Enrich context with appropriate templates based on query type. - return self._enhance_context(processed_query, documents) + enhanced_documents = self._enhance_context(processed_query, documents) + prediction = dspy.Prediction(documents=enhanced_documents) + prediction.set_lm_usage({}) + return prediction def forward( self, processed_query: ProcessedQuery, sources: list[DocumentSource] | None = None @@ -701,7 +706,11 @@ def _enhance_context(self, processed_query: ProcessedQuery, context: list[Docume context.append( Document( page_content=CONTRACT_TEMPLATE, - metadata={"title": CONTRACT_TEMPLATE_TITLE, "source": CONTRACT_TEMPLATE_TITLE, "sourceLink": "https://www.starknet.io/cairo-book/ch103-06-01-deploying-and-interacting-with-a-voting-contract.html"}, + metadata={ + "title": CONTRACT_TEMPLATE_TITLE, + "source": DocumentSource.CAIRO_BOOK, + "sourceLink": "https://www.starknet.io/cairo-book/ch103-06-01-deploying-and-interacting-with-a-voting-contract.html", + }, ) ) @@ -710,7 +719,11 @@ def _enhance_context(self, processed_query: ProcessedQuery, context: list[Docume context.append( Document( page_content=TEST_TEMPLATE, - metadata={"title": TEST_TEMPLATE_TITLE, "source": TEST_TEMPLATE_TITLE, "sourceLink": "https://www.starknet.io/cairo-book/ch104-02-testing-smart-contracts.html"}, + metadata={ + "title": TEST_TEMPLATE_TITLE, + "source": DocumentSource.CAIRO_BOOK, + "sourceLink": "https://www.starknet.io/cairo-book/ch104-02-testing-smart-contracts.html", + }, ) ) return context diff --git a/python/src/cairo_coder/dspy/generation_program.py b/python/src/cairo_coder/dspy/generation_program.py index 49370a4..e93eb19 100644 --- a/python/src/cairo_coder/dspy/generation_program.py +++ b/python/src/cairo_coder/dspy/generation_program.py @@ -192,12 +192,6 @@ def __init__(self, program_type): raise FileNotFoundError(f"{compiled_program_path} not found") self.generation_program.load(compiled_program_path) - def get_lm_usage(self) -> dict[str, int]: - """ - Get the total number of tokens used by the LLM. 
- """ - return self.generation_program.get_lm_usage() - @traceable( name="GenerationProgram", run_type="llm", metadata={"llm_provider": dspy.settings.lm} ) @@ -339,14 +333,6 @@ async def aforward(self, documents: list[Document]) -> dspy.Prediction: """ return self(documents) - def get_lm_usage(self) -> dict[str, int]: - """ - Get the total number of tokens used by the LLM. - Note: MCP mode doesn't use LLM generation, so no tokens are consumed. - """ - # MCP mode doesn't use LLM generation, return empty dict - return {} - def create_generation_program(program_type: str) -> GenerationProgram: """ diff --git a/python/src/cairo_coder/dspy/grok_search.py b/python/src/cairo_coder/dspy/grok_search.py index 0815658..7091905 100644 --- a/python/src/cairo_coder/dspy/grok_search.py +++ b/python/src/cairo_coder/dspy/grok_search.py @@ -96,13 +96,14 @@ def _domain_from_url(url: str) -> str: return url @traceable(name="GrokSearchProgram", run_type="llm") - async def aforward(self, processed_query: ProcessedQuery, chat_history: str) -> list[Document]: + async def aforward(self, processed_query: ProcessedQuery, chat_history: str) -> dspy.Prediction: formatted_query = f"""Answer the following query: {processed_query.original}. \ Here is the chat history: {chat_history}, that might be relevant to the question. \ For more context, here are some semantic terms associated with the question: \ {', '.join(processed_query.search_queries)}. \ Make sure that your final answer will contain links to the relevant sources used to construct your answer. """ + # TODO: track LM usage chat = self.client.chat.create( model=DEFAULT_GROK_MODEL, tools=[web_search(), x_search()], @@ -147,4 +148,6 @@ async def aforward(self, processed_query: ProcessedQuery, chat_history: str) -> ) ) - return documents + prediction = dspy.Prediction(documents=documents) + prediction.set_lm_usage({}) + return prediction diff --git a/python/src/cairo_coder/dspy/query_processor.py b/python/src/cairo_coder/dspy/query_processor.py index 6886bfc..c0d8447 100644 --- a/python/src/cairo_coder/dspy/query_processor.py +++ b/python/src/cairo_coder/dspy/query_processor.py @@ -125,7 +125,7 @@ def __init__(self): } @traceable(name="QueryProcessorProgram", run_type="llm", metadata={"llm_provider": dspy.settings.lm}) - async def aforward(self, query: str, chat_history: Optional[str] = None) -> ProcessedQuery: + async def aforward(self, query: str, chat_history: Optional[str] = None) -> dspy.Prediction: """ Process a user query into a structured format for document retrieval. 
@@ -134,7 +134,7 @@ async def aforward(self, query: str, chat_history: Optional[str] = None) -> Proc chat_history: Previous conversation context (optional) Returns: - ProcessedQuery with search terms, resource identification, and categorization + dspy.Prediction containing processed_query and attached usage """ # Execute the DSPy retrieval program result = await self.retrieval_program.aforward(query=query, chat_history=chat_history) @@ -144,7 +144,7 @@ async def aforward(self, query: str, chat_history: Optional[str] = None) -> Proc resources = self._validate_resources(result.resources) # Build structured query result - return ProcessedQuery( + processed_query = ProcessedQuery( original=query, search_queries=search_queries, is_contract_related=self._is_contract_query(query), @@ -152,11 +152,10 @@ async def aforward(self, query: str, chat_history: Optional[str] = None) -> Proc resources=resources, ) - def get_lm_usage(self) -> dict[str, int]: - """ - Get the total number of tokens used by the LLM. - """ - return self.retrieval_program.get_lm_usage() + prediction = dspy.Prediction(processed_query=processed_query) + prediction.set_lm_usage(result.get_lm_usage() or {}) + + return prediction def _validate_resources(self, resources: list[str]) -> list[DocumentSource]: """ diff --git a/python/src/cairo_coder/dspy/retrieval_judge.py b/python/src/cairo_coder/dspy/retrieval_judge.py index cb6eff4..72546f1 100644 --- a/python/src/cairo_coder/dspy/retrieval_judge.py +++ b/python/src/cairo_coder/dspy/retrieval_judge.py @@ -16,7 +16,7 @@ from langsmith import traceable import dspy -from cairo_coder.core.types import Document +from cairo_coder.core.types import Document, combine_usage from cairo_coder.dspy.document_retriever import CONTRACT_TEMPLATE_TITLE, TEST_TEMPLATE_TITLE logger = structlog.get_logger(__name__) @@ -135,14 +135,16 @@ def __init__(self): @traceable( name="RetrievalJudge", run_type="llm", metadata={"llm_provider": dspy.settings.lm} ) - async def aforward(self, query: str, documents: list[Document]) -> list[Document]: + async def aforward(self, query: str, documents: list[Document]) -> dspy.Prediction: """Async judge.""" if not documents: - return documents + return dspy.Prediction(documents=documents) keep_docs, judged_indices, judged_payloads = self._split_templates_and_prepare_docs( documents ) + + aggregated_usage = {} # TODO: can we use dspy.Parallel here instead of asyncio gather? if judged_payloads: @@ -154,6 +156,12 @@ async def judge_one(doc_string: str): results = await asyncio.gather( *[judge_one(ds) for ds in judged_payloads], return_exceptions=True ) + + # Aggregate usage from results + for res in results: + if isinstance(res, dspy.Prediction): + aggregated_usage = combine_usage(aggregated_usage, res.get_lm_usage()) + self._attach_scores_and_filter_async( query=query, documents=documents, @@ -167,15 +175,11 @@ async def judge_one(doc_string: str): error=str(e), exc_info=True, ) - return documents - - return keep_docs + return dspy.Prediction(documents=documents) - def get_lm_usage(self) -> dict[str, int]: - """ - Get the total number of tokens used by the LLM. 
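Inside the retrieval judge, the same contract is applied per document: each `asyncio.gather` result that is a real `dspy.Prediction` contributes its usage, and exceptions returned by `return_exceptions=True` are skipped. A small sketch of that aggregation step, mirroring the judge (the helper name is illustrative):

```python
import dspy

from cairo_coder.core.types import combine_usage

def aggregate_judge_usage(
    results: list[dspy.Prediction | BaseException],
) -> dict[str, dict[str, int]]:
    """Sum usage over gather() results; only successful predictions contribute."""
    aggregated: dict[str, dict[str, int]] = {}
    for result in results:
        if isinstance(result, dspy.Prediction):
            aggregated = combine_usage(aggregated, result.get_lm_usage())
    return aggregated
```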
- """ - return self.rater.get_lm_usage() + pred = dspy.Prediction(documents=keep_docs) + pred.set_lm_usage(aggregated_usage) + return pred # ========================= # Internal Helpers diff --git a/python/src/cairo_coder/server/app.py b/python/src/cairo_coder/server/app.py index 9a9aa9c..bab90f6 100644 --- a/python/src/cairo_coder/server/app.py +++ b/python/src/cairo_coder/server/app.py @@ -20,7 +20,7 @@ from dspy.adapters import ChatAdapter, XMLAdapter from fastapi import BackgroundTasks, Depends, FastAPI, Header, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse +from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel, Field, field_validator from cairo_coder.config.manager import ConfigManager @@ -171,8 +171,7 @@ async def log_interaction_task( query=query, generated_answer=response.choices[0].message.content if response.choices else None, retrieved_sources=sources_data, - # TODO: fix LLM usage metrics - llm_usage={} + llm_usage=agent.get_lm_usage(), ) await create_user_interaction(interaction) @@ -203,7 +202,7 @@ async def log_interaction_raw( query=query, generated_answer=generated_answer, retrieved_sources=sources_data, - llm_usage={}, + llm_usage=agent.get_lm_usage() ) await create_user_interaction(interaction) @@ -247,6 +246,9 @@ def __init__( self.app.include_router(insights_router) + # Setup global exception handler + self._setup_exception_handlers() + # Setup routes self._setup_routes() @@ -258,6 +260,43 @@ def __init__( track_usage=True, ) + def _setup_exception_handlers(self): + """Setup global exception handlers for the application.""" + + @self.app.exception_handler(ValueError) + async def value_error_handler(request: Request, exc: ValueError): + """Handle ValueError as 400 Bad Request.""" + logger.warning("Bad request", error=str(exc), path=request.url.path) + return JSONResponse( + status_code=400, + content={ + "detail": ErrorResponse( + error=ErrorDetail( + message=str(exc), + type="invalid_request_error", + code="invalid_request", + ) + ).model_dump() + }, + ) + + @self.app.exception_handler(Exception) + async def global_exception_handler(request: Request, exc: Exception): + """Handle all unhandled exceptions as 500 Internal Server Error.""" + logger.error("Unhandled exception", error=str(exc), path=request.url.path, exc_info=True) + return JSONResponse( + status_code=500, + content={ + "detail": ErrorResponse( + error=ErrorDetail( + message=f"Internal server error: {str(exc)}", + type="server_error", + code="internal_error", + ) + ).model_dump() + }, + ) + def _setup_routes(self): """Setup FastAPI routes matching TypeScript backend.""" @@ -271,40 +310,26 @@ async def list_agents( agent_factory: AgentFactory = Depends(get_agent_factory), ): """List all available agents.""" - try: - # Create agent factory with injected vector_db - available_agents = agent_factory.get_available_agents() - agents_info = [] - - for agent_id in available_agents: - try: - info = agent_factory.get_agent_info(agent_id=agent_id) - agents_info.append( - AgentInfo( - id=info["id"], - name=info["name"], - description=info["description"], - sources=info["sources"], - ) - ) - except Exception as e: - logger.warning( - "Failed to get agent info", agent_id=agent_id, error=str(e), exc_info=True + available_agents = agent_factory.get_available_agents() + agents_info = [] + + for agent_id in available_agents: + try: + info = agent_factory.get_agent_info(agent_id=agent_id) + agents_info.append( + 
AgentInfo( + id=info["id"], + name=info["name"], + description=info["description"], + sources=info["sources"], ) + ) + except Exception as e: + logger.warning( + "Failed to get agent info", agent_id=agent_id, error=str(e), exc_info=True + ) - return agents_info - except Exception as e: - logger.error("Failed to list agents", error=str(e), exc_info=True) - raise HTTPException( - status_code=500, - detail=ErrorResponse( - error=ErrorDetail( - message="Failed to list agents", - type="server_error", - code="internal_error", - ) - ).dict(), - ) from e + return agents_info @self.app.post("/v1/agents/{agent_id}/chat/completions") async def agent_chat_completions( @@ -318,21 +343,19 @@ async def agent_chat_completions( agent_factory: AgentFactory = Depends(get_agent_factory), ): """Agent-specific chat completions""" - # Create agent factory to validate agent exists try: agent_factory.get_agent_info(agent_id=agent_id) - except ValueError as e: + except ValueError as exc: raise HTTPException( status_code=404, - detail=ErrorResponse( - error=ErrorDetail( - message=f"Agent '{agent_id}' not found", - type="invalid_request_error", - code="agent_not_found", - param="agent_id", - ) - ).dict(), - ) from e + detail={ + "error": { + "message": str(exc), + "type": "invalid_request_error", + "code": "agent_not_found", + } + }, + ) from exc # Determine MCP mode mcp_mode = bool(mcp or x_mcp_mode) @@ -380,28 +403,14 @@ async def chat_completions( @self.app.post("/v1/suggestions", response_model=SuggestionResponse) async def generate_suggestions(request: SuggestionRequest): """Generate follow-up conversation suggestions based on chat history.""" - try: - formatted_history = self._format_chat_history_for_suggestions(request.chat_history) - suggestion_program = dspy.Predict(SuggestionGeneration) - with dspy.context( - lm=dspy.LM("gemini/gemini-flash-lite-latest", max_tokens=10000), adapter=XMLAdapter() - ): - result = await suggestion_program.aforward(chat_history=formatted_history) - suggestions = result.suggestions if isinstance(result.suggestions, list) else [] - return SuggestionResponse(suggestions=suggestions) - - except Exception as e: - logger.error("Error generating suggestions", error=str(e), exc_info=True) - raise HTTPException( - status_code=500, - detail=ErrorResponse( - error=ErrorDetail( - message="Failed to generate suggestions", - type="server_error", - code="internal_error", - ) - ).dict(), - ) from e + formatted_history = self._format_chat_history_for_suggestions(request.chat_history) + suggestion_program = dspy.Predict(SuggestionGeneration) + with dspy.context( + lm=dspy.LM("gemini/gemini-flash-lite-latest", max_tokens=10000), adapter=XMLAdapter() + ): + result = await suggestion_program.aforward(chat_history=formatted_history) + suggestions = result.suggestions if isinstance(result.suggestions, list) else [] + return SuggestionResponse(suggestions=suggestions) async def _handle_chat_completion( self, @@ -414,70 +423,48 @@ async def _handle_chat_completion( vector_db: SourceFilteredPgVectorRM | None = None, ): """Handle chat completion request.""" - try: - # Convert messages to internal format - messages = [] - for msg in request.messages: - messages.append(Message(role=msg.role, content=msg.content)) + # Convert messages to internal format + messages = [] + for msg in request.messages: + messages.append(Message(role=msg.role, content=msg.content)) - # Get last user message as query - query = request.messages[-1].content + # Get last user message as query + query = request.messages[-1].content - 
# Determine agent ID (fallback to cairo-coder) - effective_agent_id = agent_id or "cairo-coder" + # Determine agent ID (fallback to cairo-coder) + effective_agent_id = agent_id or "cairo-coder" - # Create agent - agent = agent_factory.get_or_create_agent( - agent_id=effective_agent_id, - mcp_mode=mcp_mode, - ) + # Create agent + agent = agent_factory.get_or_create_agent( + agent_id=effective_agent_id, + mcp_mode=mcp_mode, + ) - # Handle streaming vs non-streaming - if request.stream: - return StreamingResponse( - self._stream_chat_completion(agent, query, messages[:-1], mcp_mode, effective_agent_id), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - chat_history = messages[:-1] - response = await self._generate_chat_completion(agent, query, chat_history, mcp_mode) - - background_tasks.add_task( - log_interaction_task, - agent_id=effective_agent_id, - mcp_mode=mcp_mode, - query=query, - chat_history=chat_history, - response=response, - agent=agent, + # Handle streaming vs non-streaming + if request.stream: + return StreamingResponse( + self._stream_chat_completion(agent, query, messages[:-1], mcp_mode, effective_agent_id), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, ) + chat_history = messages[:-1] + response = await self._generate_chat_completion(agent, query, chat_history, mcp_mode) + + background_tasks.add_task( + log_interaction_task, + agent_id=effective_agent_id, + mcp_mode=mcp_mode, + query=query, + chat_history=chat_history, + response=response, + agent=agent, + ) - return response - - except ValueError as e: - raise HTTPException( - status_code=400, - detail=ErrorResponse( - error=ErrorDetail( - message=str(e), type="invalid_request_error", code="invalid_request" - ) - ).dict(), - ) from e - - except Exception as e: - logger.error("Error in chat completion", error=str(e), exc_info=True) - raise HTTPException( - status_code=500, - detail=ErrorResponse( - error=ErrorDetail( - message="Internal server error", type="server_error", code="internal_error" - ) - ).model_dump(), - ) from e + return response async def _stream_chat_completion( self, agent: RagPipeline, query: str, history: list[Message], mcp_mode: bool, agent_id: str @@ -644,10 +631,9 @@ async def _generate_chat_completion( answer = response.answer - # Somehow this is not always returning something (None). In that case, we're not capable of getting the - # tracked usage. 
- lm_usage = response.get_lm_usage() - logger.info(f"LM usage from response: {lm_usage}") + # Get accumulated usage from the pipeline (not the prediction) + lm_usage = agent.get_lm_usage() + logger.info(f"LM usage from pipeline: {lm_usage}") if not lm_usage: logger.warning("No LM usage data available, setting defaults to 0") diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 42bf3ba..93f57a7 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -287,7 +287,7 @@ def client(server, postgres_container, real_pipeline, mock_vector_db, mock_agent server.app.dependency_overrides[get_vector_db] = lambda: mock_vector_db server.app.dependency_overrides[get_agent_factory] = lambda: mock_agent_factory - return TestClient(server.app) + return TestClient(server.app, raise_server_exceptions=False) # ============================================================================= # Sample Data Fixtures @@ -403,8 +403,11 @@ def clean_config_env_vars(monkeypatch): def mock_query_processor(sample_processed_query): """Create a mock QueryProcessorProgram.""" processor = Mock(spec=QueryProcessorProgram) - processor.forward = Mock(return_value=sample_processed_query) - processor.aforward = AsyncMock(return_value=sample_processed_query) + prediction = dspy.Prediction(processed_query=sample_processed_query) + prediction.set_lm_usage({}) + + processor.forward = Mock(return_value=prediction) + processor.aforward = AsyncMock(return_value=prediction) processor.get_lm_usage = Mock(return_value={}) return processor @@ -413,8 +416,11 @@ def mock_query_processor(sample_processed_query): def mock_document_retriever(sample_documents): """Create a mock DocumentRetrieverProgram.""" retriever = Mock(spec=DocumentRetrieverProgram) - retriever.forward = Mock(return_value=sample_documents) - retriever.aforward = AsyncMock(return_value=sample_documents) + prediction = dspy.Prediction(documents=sample_documents) + prediction.set_lm_usage({}) + + retriever.forward = Mock(return_value=prediction) + retriever.aforward = AsyncMock(return_value=prediction) retriever.get_lm_usage = Mock(return_value={}) return retriever @@ -424,14 +430,21 @@ def mock_generation_program(): """Create a mock GenerationProgram.""" program = Mock(spec=GenerationProgram) answer = "Here's how to write Cairo contracts..." - program.forward = Mock(return_value=dspy.Prediction(answer=answer)) - program.aforward = AsyncMock(return_value=dspy.Prediction(answer=answer)) + + # Create predictions with usage tracking + prediction = dspy.Prediction(answer=answer) + prediction.set_lm_usage({}) + + program.forward = Mock(return_value=prediction) + program.aforward = AsyncMock(return_value=prediction) program.get_lm_usage = Mock(return_value={}) async def mock_streaming(*args, **kwargs): yield dspy.streaming.StreamResponse(predict_name="GenerationProgram", signature_field_name="answer", chunk="Here's how to write ", is_last_chunk=False) yield dspy.streaming.StreamResponse(predict_name="GenerationProgram", signature_field_name="answer", chunk="Cairo contracts...", is_last_chunk=True) - yield dspy.Prediction(answer=answer) + final_prediction = dspy.Prediction(answer=answer) + final_prediction.set_lm_usage({}) + yield final_prediction program.aforward_streaming = mock_streaming return program @@ -458,7 +471,11 @@ def mock_mcp_generation_program(): Storage variables use #[storage] attribute. 
""" - program.aforward = AsyncMock(return_value=dspy.Prediction(answer=mcp_answer)) + # Create prediction with usage tracking + prediction = dspy.Prediction(answer=mcp_answer) + prediction.set_lm_usage({}) + + program.aforward = AsyncMock(return_value=prediction) program.get_lm_usage = Mock(return_value={}) return program @@ -532,15 +549,44 @@ def pipeline_config( +@pytest.fixture +def pipeline_config_for_pipeline( + mock_vector_store_config, + mock_query_processor, + mock_document_retriever, + mock_generation_program, + mock_mcp_generation_program, +): + """Create a pipeline configuration with prediction-returning mocks.""" + return RagPipelineConfig( + name="test_pipeline", + vector_store_config=mock_vector_store_config, + query_processor=mock_query_processor, + document_retriever=mock_document_retriever, + generation_program=mock_generation_program, + mcp_generation_program=mock_mcp_generation_program, + sources=list(DocumentSource), + max_source_count=10, + similarity_threshold=0.4, + ) + + @pytest.fixture(scope="function") -def pipeline(pipeline_config): +def pipeline(pipeline_config_for_pipeline): """Create a RagPipeline instance.""" with patch("cairo_coder.core.rag_pipeline.RetrievalJudge") as mock_judge_class: mock_judge = Mock() mock_judge.get_lm_usage.return_value = {} - mock_judge.aforward = AsyncMock(side_effect=lambda query, documents: documents) + + # Judge should return prediction with documents + async def judge_aforward(query, documents): + prediction = dspy.Prediction(documents=documents) + prediction.set_lm_usage({}) + return prediction + + mock_judge.aforward = AsyncMock(side_effect=judge_aforward) mock_judge_class.return_value = mock_judge - return RagPipeline(pipeline_config) + return RagPipeline(pipeline_config_for_pipeline) @pytest.fixture(scope="function") def rag_pipeline(pipeline_config): diff --git a/python/tests/integration/conftest.py b/python/tests/integration/conftest.py index 9902fc5..22ba3f9 100644 --- a/python/tests/integration/conftest.py +++ b/python/tests/integration/conftest.py @@ -86,7 +86,12 @@ def real_pipeline(mock_query_processor, mock_vector_store_config, mock_vector_db # Avoid LLM calls in the judge and non-streaming generation from unittest.mock import AsyncMock, Mock - pipeline.retrieval_judge.aforward = AsyncMock(side_effect=lambda query, documents: documents) + async def _judge_aforward(query, documents): + prediction = dspy.Prediction(documents=documents) + prediction.set_lm_usage({}) + return prediction + + pipeline.retrieval_judge.aforward = AsyncMock(side_effect=_judge_aforward) pipeline.retrieval_judge.get_lm_usage = Mock(return_value={}) # Patch non-streaming generation to mimic conversation turns using chat_history @@ -101,24 +106,32 @@ async def _fake_gen_aforward(query: str, context: str, chat_history: str | None "You can deploy it using Scarb with the deploy command.", ] idx = min((len(lines)) // 2, len(responses) - 1) - return _dspy.Prediction(answer=responses[idx]) + prediction = _dspy.Prediction(answer=responses[idx]) + prediction.set_lm_usage({}) + return prediction async def _fake_gen_aforward_streaming(query: str, context: str, chat_history: str | None = None): yield dspy.streaming.StreamResponse(predict_name="GenerationProgram", signature_field_name="answer", chunk="Hello! I'm Cairo Coder, ", is_last_chunk=False) yield dspy.streaming.StreamResponse(predict_name="GenerationProgram", signature_field_name="answer", chunk="ready to help with Cairo programming.", is_last_chunk=True) - yield dspy.Prediction(answer="Hello! 
I'm Cairo Coder, ready to help with Cairo programming.") + final_prediction = dspy.Prediction(answer="Hello! I'm Cairo Coder, ready to help with Cairo programming.") + final_prediction.set_lm_usage({}) + yield final_prediction pipeline.generation_program.aforward = AsyncMock(side_effect=_fake_gen_aforward) pipeline.generation_program.aforward_streaming =_fake_gen_aforward_streaming pipeline.generation_program.get_lm_usage = Mock(return_value={}) # Patch MCP generation to a deterministic simple string as tests expect - pipeline.mcp_generation_program.aforward = AsyncMock( - return_value=_dspy.Prediction(answer="Cairo is a programming language") - ) - pipeline.mcp_generation_program.forward = lambda documents: _dspy.Prediction( - answer="Cairo is a programming language" - ) + mcp_prediction = _dspy.Prediction(answer="Cairo is a programming language") + mcp_prediction.set_lm_usage({}) + pipeline.mcp_generation_program.aforward = AsyncMock(return_value=mcp_prediction) + + def _mcp_forward(documents): # noqa: ARG001 - deterministic response + prediction = _dspy.Prediction(answer="Cairo is a programming language") + prediction.set_lm_usage({}) + return prediction + + pipeline.mcp_generation_program.forward = _mcp_forward return pipeline diff --git a/python/tests/integration/test_insights_api.py b/python/tests/integration/test_insights_api.py index c49e212..8974a6b 100644 --- a/python/tests/integration/test_insights_api.py +++ b/python/tests/integration/test_insights_api.py @@ -163,6 +163,10 @@ async def test_chat_completion_logs_interaction_to_db(self, client, test_db_pool "messages": [{"role": "user", "content": "Hello"}], "stream": False, } + + async with test_db_pool.acquire() as conn: + initial_count = await conn.fetchval("SELECT COUNT(*) FROM user_interactions") + resp = client.post("/v1/chat/completions", json=payload) assert resp.status_code == 200 @@ -171,11 +175,11 @@ async def test_chat_completion_logs_interaction_to_db(self, client, test_db_pool for _ in range(50): async with test_db_pool.acquire() as conn: count = await conn.fetchval("SELECT COUNT(*) FROM user_interactions") - if count >= 1: + if count >= initial_count + 1: break await asyncio.sleep(0.05) - assert count >= 1 + assert count >= initial_count + 1 # Verify content matches request/response shape async with test_db_pool.acquire() as conn: @@ -199,6 +203,10 @@ async def test_streaming_chat_completion_logs_interaction_to_db(self, client, te "messages": [{"role": "user", "content": "Hello streaming"}], "stream": True, } + + async with test_db_pool.acquire() as conn: + initial_count = await conn.fetchval("SELECT COUNT(*) FROM user_interactions") + # The `with` statement ensures the full request/response cycle completes with client.stream("POST", "/v1/chat/completions", json=payload) as response: assert response.status_code == 200 @@ -210,11 +218,11 @@ async def test_streaming_chat_completion_logs_interaction_to_db(self, client, te for _ in range(50): # Wait up to ~2.5 seconds async with test_db_pool.acquire() as conn: count = await conn.fetchval("SELECT COUNT(*) FROM user_interactions") - if count >= 1: + if count >= initial_count + 1: break await asyncio.sleep(0.05) - assert count >= 1, "Interaction was not logged for streaming request" + assert count >= initial_count + 1, "Interaction was not logged for streaming request" # Verify the logged data async with test_db_pool.acquire() as conn: diff --git a/python/tests/integration/test_server_integration.py b/python/tests/integration/test_server_integration.py index 
72fee07..d34aca9 100644 --- a/python/tests/integration/test_server_integration.py +++ b/python/tests/integration/test_server_integration.py @@ -49,7 +49,7 @@ def test_list_agents_error_handling(self, client: TestClient, mock_agent_factory data = response.json() assert "detail" in data - assert data["detail"]["error"]["message"] == "Failed to list agents" + assert data["detail"]["error"]["message"] == "Internal server error: Database error" assert data["detail"]["error"]["type"] == "server_error" def test_full_agent_workflow(self, client: TestClient, mock_agent: Mock): @@ -258,8 +258,6 @@ def test_agent_chat_completions_valid_agent(self, client: TestClient): def test_agent_chat_completions_invalid_agent(self, client: TestClient, mock_agent_factory: Mock): """Test agent-specific chat completions with invalid agent.""" - mock_agent_factory.get_agent_info.side_effect = ValueError("Agent not found") - response = client.post( "/v1/agents/unknown-agent/chat/completions", json={"messages": [{"role": "user", "content": "Hello"}]}, diff --git a/python/tests/unit/db/test_repository.py b/python/tests/unit/db/test_repository.py index d50dcb7..b6ee69e 100644 --- a/python/tests/unit/db/test_repository.py +++ b/python/tests/unit/db/test_repository.py @@ -15,6 +15,8 @@ import pytest +from cairo_coder.core.types import DocumentSource + # Import shared fixtures from integration conftest pytest_plugins = ["tests.integration.conftest"] @@ -47,7 +49,9 @@ async def test_create_user_interaction(test_db_pool, db_connection): chat_history=[{"role": "user", "content": "Hello"}], query="Hello", generated_answer="Hi", - retrieved_sources=[{"pageContent": "Cairo", "metadata": {"source": "cairo_book"}}], + retrieved_sources=[ + {"page_content": "Cairo", "metadata": {"source": DocumentSource.CAIRO_BOOK}} + ], llm_usage={"model": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}}, ) @@ -187,4 +191,3 @@ async def test_migrate_user_interaction_upsert(test_db_pool, db_connection): # Verify still only one record count = await db_connection.fetchval("SELECT COUNT(*) FROM user_interactions WHERE id = $1", interaction_id) assert count == 1 - diff --git a/python/tests/unit/test_agent_factory.py b/python/tests/unit/test_agent_factory.py index 09cc545..53a777c 100644 --- a/python/tests/unit/test_agent_factory.py +++ b/python/tests/unit/test_agent_factory.py @@ -155,34 +155,46 @@ def test_get_agent_by_string_id_invalid(self): with pytest.raises(ValueError, match="Agent not found: invalid"): get_agent_by_string_id("invalid") - @patch("cairo_coder.core.rag_pipeline.RagPipelineFactory.create_pipeline") - def test_agent_spec_build_general(self, mock_create_pipeline, mock_vector_db, mock_vector_store_config): + def test_agent_spec_build_general(self, mock_vector_db, mock_vector_store_config): """Test building a general agent from spec.""" spec = registry[AgentId.CAIRO_CODER] mock_pipeline = Mock(spec=RagPipeline) - mock_create_pipeline.return_value = mock_pipeline - pipeline = spec.build(mock_vector_db, mock_vector_store_config) + # Patch the spec's pipeline_builder directly + original_builder = spec.pipeline_builder + spec.pipeline_builder = Mock(return_value=mock_pipeline) - assert pipeline == mock_pipeline - mock_create_pipeline.assert_called_once() - call_args = mock_create_pipeline.call_args[1] - assert call_args["name"] == "Cairo Coder" - assert call_args["vector_db"] == mock_vector_db - assert call_args["vector_store_config"] == mock_vector_store_config + try: + pipeline = spec.build(mock_vector_db, 
mock_vector_store_config) - @patch("cairo_coder.core.rag_pipeline.RagPipelineFactory.create_pipeline") - def test_agent_spec_build_scarb(self, mock_create_scarb, mock_vector_db, mock_vector_store_config): + assert pipeline == mock_pipeline + spec.pipeline_builder.assert_called_once() + call_args = spec.pipeline_builder.call_args[1] + assert call_args["name"] == "Cairo Coder" + assert call_args["vector_db"] == mock_vector_db + assert call_args["vector_store_config"] == mock_vector_store_config + finally: + # Restore original builder + spec.pipeline_builder = original_builder + + def test_agent_spec_build_scarb(self, mock_vector_db, mock_vector_store_config): """Test building a Starknet agent from spec.""" spec = registry[AgentId.STARKNET] mock_pipeline = Mock(spec=RagPipeline) - mock_create_scarb.return_value = mock_pipeline - - pipeline = spec.build(mock_vector_db, mock_vector_store_config) - assert pipeline == mock_pipeline - mock_create_scarb.assert_called_once() - call_args = mock_create_scarb.call_args[1] - assert call_args["name"] == "Starknet Agent" - assert call_args["vector_db"] == mock_vector_db - assert call_args["vector_store_config"] == mock_vector_store_config + # Patch the spec's pipeline_builder directly + original_builder = spec.pipeline_builder + spec.pipeline_builder = Mock(return_value=mock_pipeline) + + try: + pipeline = spec.build(mock_vector_db, mock_vector_store_config) + + assert pipeline == mock_pipeline + spec.pipeline_builder.assert_called_once() + call_args = spec.pipeline_builder.call_args[1] + assert call_args["name"] == "Starknet Agent" + assert call_args["vector_db"] == mock_vector_db + assert call_args["vector_store_config"] == mock_vector_store_config + finally: + # Restore original builder + spec.pipeline_builder = original_builder diff --git a/python/tests/unit/test_document_retriever.py b/python/tests/unit/test_document_retriever.py index 9a39898..101b9cb 100644 --- a/python/tests/unit/test_document_retriever.py +++ b/python/tests/unit/test_document_retriever.py @@ -51,9 +51,11 @@ async def test_basic_document_retrieval( retriever.vector_db.aforward.return_value = mock_dspy_examples # Execute retrieval - use async version since we're in async test - result = await retriever.aforward(sample_processed_query) + prediction = await retriever.aforward(sample_processed_query) - # Verify results + # Verify results - aforward now returns a Prediction with documents attribute + assert isinstance(prediction, dspy.Prediction) + result = prediction.documents assert len(result) != 0 assert all(isinstance(doc, Document) for doc in result) @@ -78,10 +80,11 @@ async def test_retrieval_with_empty_transformed_terms(self, retriever: DocumentR resources=[DocumentSource.CAIRO_BOOK], ) - result = await retriever.aforward(query) + prediction = await retriever.aforward(query) # Should still work with empty transformed terms - assert len(result) != 0 + assert isinstance(prediction, dspy.Prediction) + assert len(prediction.documents) != 0 # Query should just be the reasoning with empty tags expected_query = query.original @@ -95,9 +98,11 @@ async def test_retrieval_with_custom_sources(self, retriever, sample_processed_q # Override sources custom_sources = [DocumentSource.SCARB_DOCS, DocumentSource.OPENZEPPELIN_DOCS] - result = await retriever.aforward(sample_processed_query, sources=custom_sources) + prediction = await retriever.aforward(sample_processed_query, sources=custom_sources) # Verify result + assert isinstance(prediction, dspy.Prediction) + result = prediction.documents 
assert len(result) != 0 # Note: sources filtering is not currently implemented in PgVectorRM call @@ -109,9 +114,11 @@ async def test_empty_document_handling(self, retriever, sample_processed_query): """Test handling of empty document results.""" retriever.vector_db.aforward = AsyncMock(return_value=[]) - result = await retriever.aforward(sample_processed_query) + prediction = await retriever.aforward(sample_processed_query) - assert result == [] + # Should return prediction with empty documents list + assert isinstance(prediction, dspy.Prediction) + assert prediction.documents == [] @pytest.mark.asyncio async def test_pgvector_rm_error_handling(self, retriever, sample_processed_query): @@ -174,7 +181,8 @@ async def test_document_conversion( retriever.vector_db.aforward = AsyncMock(return_value=mock_examples) - result = await retriever.aforward(sample_processed_query) + prediction = await retriever.aforward(sample_processed_query) + result = prediction.documents # Verify conversion to Document objects # Ran 3 times the query, returned 2 docs each - but de-duped @@ -235,12 +243,13 @@ async def test_context_enhancement( ) mock_vector_db.aforward.return_value = mock_dspy_examples - result: list[Document] = await retriever.aforward(query) + prediction = await retriever.aforward(query) + result: list[Document] = prediction.documents found_templates = { - doc.source + doc.title for doc in result - if "Template" in doc.source + if "Template" in doc.title } assert set(expected_templates) == found_templates diff --git a/python/tests/unit/test_grok_integration.py b/python/tests/unit/test_grok_integration.py index a7700e7..bc99c5b 100644 --- a/python/tests/unit/test_grok_integration.py +++ b/python/tests/unit/test_grok_integration.py @@ -10,6 +10,7 @@ from unittest.mock import AsyncMock +import dspy import pytest from cairo_coder.core.types import Document, DocumentSource @@ -46,7 +47,9 @@ async def test_grok_citations_emitted_in_sources_and_summary_excluded( ): # Mock Grok module on the pipeline instance grok_doc = _make_grok_summary_doc(GROK_ANSWER) - pipeline.grok_search.aforward = AsyncMock(return_value=[grok_doc]) + grok_prediction = dspy.Prediction(documents=[grok_doc]) + grok_prediction.set_lm_usage({}) + pipeline.grok_search.aforward = AsyncMock(return_value=grok_prediction) pipeline.grok_search.last_citations = list(GROK_CITATIONS) # Stream to get SOURCES event @@ -78,7 +81,9 @@ async def test_grok_summary_is_first_in_generation_context( ): # Mock Grok module on the pipeline instance grok_doc = _make_grok_summary_doc(GROK_ANSWER) - pipeline.grok_search.aforward = AsyncMock(return_value=[grok_doc]) + grok_prediction = dspy.Prediction(documents=[grok_doc]) + grok_prediction.set_lm_usage({}) + pipeline.grok_search.aforward = AsyncMock(return_value=grok_prediction) pipeline.grok_search.last_citations = list(GROK_CITATIONS) await pipeline.aforward( diff --git a/python/tests/unit/test_query_processor.py b/python/tests/unit/test_query_processor.py index 3d793f3..cba3a77 100644 --- a/python/tests/unit/test_query_processor.py +++ b/python/tests/unit/test_query_processor.py @@ -34,8 +34,11 @@ async def test_contract_query_processing(self, mock_lm_predict, processor): query = "How do I define storage variables in a Cairo contract?" 
-        result = await processor.aforward(query)
+        result_prediction = await processor.aforward(query)
+        # aforward now returns a Prediction with processed_query attribute
+        assert isinstance(result_prediction, dspy.Prediction)
+        result = result_prediction.processed_query

         assert isinstance(result, ProcessedQuery)
         assert result.original == query
         assert result.is_contract_related is True
@@ -96,8 +99,11 @@ async def test_empty_query_handling(self, processor):
             )
         )

-        result = await processor.aforward("")
+        result_prediction = await processor.aforward("")
+        # aforward now returns a Prediction
+        assert isinstance(result_prediction, dspy.Prediction)
+        result = result_prediction.processed_query

         assert result.original == ""
         assert result.resources == list(DocumentSource)  # Default fallback
diff --git a/python/tests/unit/test_rag_pipeline.py b/python/tests/unit/test_rag_pipeline.py
index 55acd62..63c9147 100644
--- a/python/tests/unit/test_rag_pipeline.py
+++ b/python/tests/unit/test_rag_pipeline.py
@@ -7,6 +7,7 @@
 from unittest.mock import AsyncMock, Mock, patch

+import dspy
 import pytest

 from cairo_coder.core.rag_pipeline import (
@@ -149,10 +150,12 @@ async def test_pipeline_with_custom_sources(self, pipeline):
         assert call_args["sources"] == sources

     @pytest.mark.asyncio
-    async def test_empty_documents_handling(self, pipeline, mock_document_retriever):
+    async def test_empty_documents_handling(self, pipeline):
         """Test pipeline handling of empty document list."""
-        # Configure retriever to return empty list
-        mock_document_retriever.aforward.return_value = []
+        # Configure retriever to return empty prediction
+        empty_prediction = dspy.Prediction(documents=[])
+        empty_prediction.set_lm_usage({})
+        pipeline.document_retriever.aforward.return_value = empty_prediction

         await pipeline.aforward("test query")

@@ -161,17 +164,17 @@ async def test_empty_documents_handling(self, pipeline, mock_document_retriever)
         assert "No relevant documentation found" in call_args[1]["context"]

     @pytest.mark.asyncio
-    async def test_pipeline_error_handling(self, pipeline, mock_document_retriever):
+    async def test_pipeline_error_handling(self, pipeline):
         """Test pipeline error handling."""
-        # Configure retriever to fail
-        mock_document_retriever.aforward.side_effect = Exception("Retrieval error")
+        # Configure pipeline's retriever to fail
+        pipeline.document_retriever.aforward.side_effect = Exception("Retrieval error")

         events = []
         async for event in pipeline.aforward_streaming("test query"):
             events.append(event)

         # Should have an error event
-        error_events = [e for e in events if e.type == "error"]
+        error_events = [e for e in events if e.type == StreamEventType.ERROR]
         assert len(error_events) == 1
         assert "Retrieval error" in error_events[0].data

@@ -182,7 +185,7 @@ class TestRagPipelineWithJudge:
     @patch("cairo_coder.core.rag_pipeline.RetrievalJudge")
     @pytest.mark.asyncio
     async def test_judge_enabled_filters_documents(
-        self, mock_judge_class, pipeline, mock_document_retriever
+        self, mock_judge_class, pipeline
     ):
         """Test that judge filters out low-scoring documents."""
         # Create documents with varying relevance
@@ -193,7 +196,10 @@ async def test_judge_enabled_filters_documents(
                 ("Cairo Storage", "Cairo storage content", "cairo_book"),
             ]
         )
-        mock_document_retriever.aforward.return_value = docs
+        # Return prediction with documents - modify the pipeline's document retriever
+        dr_prediction = dspy.Prediction(documents=docs)
+        dr_prediction.set_lm_usage({})
+        pipeline.document_retriever.aforward.return_value = dr_prediction

         # Setup judge with specific scores
         judge = create_custom_retrieval_judge(
@@ -204,7 +210,14 @@
             }
         )
         # Configure the mock instance that the pipeline will use
-        pipeline.retrieval_judge.aforward.side_effect = judge.aforward
+
+        async def judge_aforward_with_prediction(query, documents):
+            result_docs = await judge.aforward(query, documents)
+            prediction = dspy.Prediction(documents=result_docs)
+            prediction.set_lm_usage({})
+            return prediction
+
+        pipeline.retrieval_judge.aforward.side_effect = judge_aforward_with_prediction
         pipeline.retrieval_judge.threshold = judge.threshold

         await pipeline.aforward("Cairo question")
@@ -243,10 +256,13 @@ async def test_judge_disabled_passes_all_documents(
     @patch("cairo_coder.core.rag_pipeline.RetrievalJudge")
     @pytest.mark.asyncio
     async def test_judge_threshold_parameterization(
-        self, mock_judge_class, threshold, sample_documents, pipeline, mock_document_retriever
+        self, mock_judge_class, threshold, sample_documents, pipeline
     ):
         """Test different judge thresholds."""
-        mock_document_retriever.aforward.return_value = sample_documents
+        # Return prediction with documents
+        dr_prediction = dspy.Prediction(documents=sample_documents)
+        dr_prediction.set_lm_usage({})
+        pipeline.document_retriever.aforward.return_value = dr_prediction

         # Judge with scores: 0.9, 0.8, 0.7, 0.6 (based on sample_documents)
         score_map = {
@@ -257,7 +273,14 @@
         }

         judge = create_custom_retrieval_judge(score_map, threshold=threshold)
-        pipeline.retrieval_judge.aforward.side_effect = judge.aforward
+
+        async def judge_aforward_with_prediction(query, documents):
+            result_docs = await judge.aforward(query, documents)
+            prediction = dspy.Prediction(documents=result_docs)
+            prediction.set_lm_usage({})
+            return prediction
+
+        pipeline.retrieval_judge.aforward.side_effect = judge_aforward_with_prediction
         pipeline.retrieval_judge.threshold = judge.threshold

         await pipeline.aforward("test query")
@@ -295,7 +318,7 @@ async def test_judge_failure_fallback(self, mock_judge_class, sample_documents,
     @patch("cairo_coder.core.rag_pipeline.RetrievalJudge")
     @pytest.mark.asyncio
     async def test_judge_parse_error_handling(
-        self, mock_judge_class, pipeline, mock_document_retriever
+        self, mock_judge_class, pipeline
     ):
         """Test handling of parse errors in judge scores."""
         docs = create_custom_documents(
@@ -304,12 +327,13 @@ async def test_judge_parse_error_handling(
             [
                 ("Doc1", "Content1", "source1"),
                 ("Doc2", "Content2", "source2"),
             ]
         )
-        mock_document_retriever.aforward.return_value = docs
+        # Return prediction with documents
+        dr_prediction = dspy.Prediction(documents=docs)
+        dr_prediction.set_lm_usage({})
+        pipeline.document_retriever.aforward.return_value = dr_prediction

         # Create judge that returns invalid score
-        judge = Mock(spec=RetrievalJudge)
-
-        def filter_with_parse_error(query, documents):
+        async def filter_with_parse_error(query, documents):
             # First doc gets invalid score, second gets valid
             documents[0].metadata["llm_judge_score"] = "invalid"  # Will cause parse error
             documents[0].metadata["llm_judge_reason"] = "Parse error"
@@ -319,13 +343,12 @@ def filter_with_parse_error(query, documents):
             # In the real implementation, docs with parse errors are now DROPPED.
             # The mock's side effect must replicate the real judge's behavior.
-            return [documents[1]]
+            prediction = dspy.Prediction(documents=[documents[1]])
+            prediction.set_lm_usage({})
+            return prediction

-        judge.aforward = AsyncMock(side_effect=filter_with_parse_error)
-        judge.threshold = 0.5
-
-        pipeline.retrieval_judge.aforward.side_effect = judge.aforward
-        pipeline.retrieval_judge.threshold = judge.threshold
+        pipeline.retrieval_judge.aforward.side_effect = filter_with_parse_error
+        pipeline.retrieval_judge.threshold = 0.5

         await pipeline.aforward("test query")
@@ -349,9 +372,21 @@ async def test_async_judge_execution(self, mock_judge_class, pipeline, mock_retr
     @patch("cairo_coder.core.rag_pipeline.RetrievalJudge")
     @pytest.mark.asyncio
-    async def test_streaming_with_judge(self, mock_judge_class, pipeline, mock_retrieval_judge):
+    async def test_streaming_with_judge(self, mock_judge_class, pipeline, mock_retrieval_judge, sample_documents):
         """Test streaming execution with judge."""
-        pipeline.retrieval_judge.aforward.side_effect = mock_retrieval_judge.aforward
+        # Return prediction with documents
+        dr_prediction = dspy.Prediction(documents=sample_documents)
+        dr_prediction.set_lm_usage({})
+        pipeline.document_retriever.aforward.return_value = dr_prediction
+
+        # Set up judge to return prediction
+        async def judge_aforward_with_prediction(query, documents):
+            result_docs = await mock_retrieval_judge.aforward(query, documents)
+            prediction = dspy.Prediction(documents=result_docs)
+            prediction.set_lm_usage({})
+            return prediction
+
+        pipeline.retrieval_judge.aforward.side_effect = judge_aforward_with_prediction

         events = []
         async for event in pipeline.aforward_streaming("test query"):
@@ -369,14 +404,24 @@ async def test_streaming_with_judge(self, mock_judge_class, pipeline, mock_retri
     @patch("cairo_coder.core.rag_pipeline.RetrievalJudge")
     @pytest.mark.asyncio
     async def test_judge_metadata_enrichment(
-        self, mock_judge_class, pipeline, mock_document_retriever
+        self, mock_judge_class, pipeline
     ):
         """Test that judge adds metadata to documents."""
         docs = create_custom_documents([("Test Doc", "Test content", "test_source")])
-        mock_document_retriever.aforward.return_value = docs
+        # Return prediction with documents
+        dr_prediction = dspy.Prediction(documents=docs)
+        dr_prediction.set_lm_usage({})
+        pipeline.document_retriever.aforward.return_value = dr_prediction

         judge = create_custom_retrieval_judge({"Test Doc": 0.75})
-        pipeline.retrieval_judge.aforward.side_effect = judge.aforward
+
+        async def judge_aforward_with_prediction(query, documents):
+            result_docs = await judge.aforward(query, documents)
+            prediction = dspy.Prediction(documents=result_docs)
+            prediction.set_lm_usage({})
+            return prediction
+
+        pipeline.retrieval_judge.aforward.side_effect = judge_aforward_with_prediction

         await pipeline.aforward("test query")
@@ -633,28 +678,46 @@ def test_get_current_state(self, sample_documents, sample_processed_query, pipel
             pytest.param({}, {}, _JUDGE_USAGE, _JUDGE_USAGE, id="only_judge_usage"),
         ],
     )
+    @pytest.mark.asyncio
     @patch("cairo_coder.core.rag_pipeline.RetrievalJudge")
-    def test_get_lm_usage(
+    async def test_get_lm_usage(
         self,
         mock_judge_class,
-        pipeline,
-        mock_query_processor,
-        mock_generation_program,
+        pipeline_config,
+        sample_processed_query,
+        sample_documents,
         query_usage,
         generation_usage,
         judge_usage,
         expected_usage,
     ):
         """Tests that get_lm_usage correctly aggregates token usage from its components."""
-        mock_query_processor.get_lm_usage.return_value = query_usage
-        mock_generation_program.get_lm_usage.return_value = generation_usage
-        pipeline.retrieval_judge.get_lm_usage.return_value = judge_usage
+        # Set up mocks to return predictions with usage
+        qp_prediction = dspy.Prediction(processed_query=sample_processed_query)
+        qp_prediction.set_lm_usage(query_usage)
+        pipeline_config.query_processor.aforward = AsyncMock(return_value=qp_prediction)
+
+        dr_prediction = dspy.Prediction(documents=sample_documents)
+        dr_prediction.set_lm_usage({})
+        pipeline_config.document_retriever.aforward = AsyncMock(return_value=dr_prediction)
+
+        gen_prediction = dspy.Prediction(answer="Test answer")
+        gen_prediction.set_lm_usage(generation_usage)
+        pipeline_config.generation_program.aforward = AsyncMock(return_value=gen_prediction)
+
+        # Set up judge mock
+        judge_prediction = dspy.Prediction(documents=sample_documents)
+        judge_prediction.set_lm_usage(judge_usage)
+        mock_judge = Mock()
+        mock_judge.aforward = AsyncMock(return_value=judge_prediction)
+        mock_judge_class.return_value = mock_judge
+
+        # Create pipeline and execute it
+        pipeline = RagPipeline(pipeline_config)
+        await pipeline.aforward("test query")
+
+        # Check accumulated usage
         result = pipeline.get_lm_usage()
-
-        pipeline.query_processor.get_lm_usage.assert_called_once()
-        pipeline.generation_program.get_lm_usage.assert_called_once()
-        pipeline.retrieval_judge.get_lm_usage.assert_called_once()
         assert result == expected_usage

     @pytest.mark.asyncio
@@ -667,24 +730,52 @@ def test_get_lm_usage(
             ),
         ],
     )
-    async def test_get_lm_usage_after_streaming(self, pipeline_config, mcp_mode, expected_usage):
+    async def test_get_lm_usage_after_streaming(
+        self, pipeline_config, sample_processed_query, sample_documents, mcp_mode, expected_usage
+    ):
         """Tests that get_lm_usage works correctly after a streaming execution."""
-        # To test token aggregation, we mock the return values of sub-components'
-        # get_lm_usage methods. The test logic simulates which components would
-        # be "active" in each mode by setting others to return empty usage.
-        pipeline_config.query_processor.get_lm_usage.return_value = self._QUERY_USAGE_MINI
+        # Set up query processor to return prediction with usage
+        qp_prediction = dspy.Prediction(processed_query=sample_processed_query)
+        qp_prediction.set_lm_usage(self._QUERY_USAGE_MINI)
+        pipeline_config.query_processor.aforward = AsyncMock(return_value=qp_prediction)
+
+        # Set up document retriever
+        dr_prediction = dspy.Prediction(documents=sample_documents)
+        dr_prediction.set_lm_usage({})
+        pipeline_config.document_retriever.aforward = AsyncMock(return_value=dr_prediction)
+
         if mcp_mode:
-            pipeline_config.generation_program.get_lm_usage.return_value = {}
-            # MCP program doesn't use an LM, so its usage is empty
-            pipeline_config.mcp_generation_program.get_lm_usage.return_value = {}
+            # MCP mode - set MCP program with no usage
+            mcp_prediction = dspy.Prediction(answer="MCP answer")
+            mcp_prediction.set_lm_usage({})
+            pipeline_config.mcp_generation_program.aforward = AsyncMock(return_value=mcp_prediction)
         else:
-            pipeline_config.generation_program.get_lm_usage.return_value = self._GEN_USAGE_FULL
-            pipeline_config.mcp_generation_program.get_lm_usage.return_value = {}
+            # Normal mode - set generation program with usage
+            async def mock_streaming(*args, **kwargs):
+                yield dspy.streaming.StreamResponse(
+                    predict_name="GenerationProgram",
+                    signature_field_name="answer",
+                    chunk="Test ",
+                    is_last_chunk=False,
+                )
+                yield dspy.streaming.StreamResponse(
+                    predict_name="GenerationProgram",
+                    signature_field_name="answer",
+                    chunk="answer",
+                    is_last_chunk=True,
+                )
+                gen_prediction = dspy.Prediction(answer="Test answer")
+                gen_prediction.set_lm_usage(self._GEN_USAGE_FULL)
+                yield gen_prediction
+
+            pipeline_config.generation_program.aforward_streaming = mock_streaming

-        # Patch the RetrievalJudge to have a proper get_lm_usage method
+        # Patch the RetrievalJudge
         with patch("cairo_coder.core.rag_pipeline.RetrievalJudge") as mock_judge_class:
+            judge_prediction = dspy.Prediction(documents=sample_documents)
+            judge_prediction.set_lm_usage({})
             mock_judge = Mock()
-            mock_judge.get_lm_usage.return_value = {}
+            mock_judge.aforward = AsyncMock(return_value=judge_prediction)
             mock_judge_class.return_value = mock_judge

             pipeline = RagPipeline(pipeline_config)
@@ -698,8 +789,6 @@ async def test_get_lm_usage_after_streaming(self, pipeline_config, mcp_mode, exp
             result = pipeline.get_lm_usage()

             assert result == expected_usage
-            pipeline.query_processor.get_lm_usage.assert_called()
-            pipeline.generation_program.get_lm_usage.assert_called()


 class TestConvenienceFunctions:
diff --git a/python/tests/unit/test_retrieval_judge.py b/python/tests/unit/test_retrieval_judge.py
index f2d885d..9fbe584 100644
--- a/python/tests/unit/test_retrieval_judge.py
+++ b/python/tests/unit/test_retrieval_judge.py
@@ -43,8 +43,9 @@ async def test_retrieval_judge_initialization(self):
     async def test_aforward_empty_documents(self):
         """Test forward with empty document list."""
         judge = RetrievalJudge()
-        result = await judge.aforward("test query", [])
-        assert result == []
+        prediction = await judge.aforward("test query", [])
+        assert isinstance(prediction, dspy.Prediction)
+        assert prediction.documents == []

     @pytest.mark.asyncio
     async def test_aforward_with_mocked_rater(self, sample_documents):
@@ -58,7 +59,8 @@ async def test_aforward_with_mocked_rater(self, sample_documents):
         ])

         documents = sample_documents
-        filtered_docs = await judge.aforward("How to write Cairo programs?", documents)
+        prediction = await judge.aforward("How to write Cairo programs?", documents)
+        filtered_docs = prediction.documents

         # Assertions
         assert len(filtered_docs) == 1  # Only first doc passes threshold
@@ -77,7 +79,8 @@ async def test_forward_with_parse_error(self, sample_documents):
             MagicMock(resource_note=0.7, reasoning="Valid result"),
         ])
         documents = sample_documents[:2]
-        filtered_docs = await judge.aforward("test query", documents)
+        prediction = await judge.aforward("test query", documents)
+        filtered_docs = prediction.documents

         # Should only keep the doc that was successfully parsed and scored above threshold.
         assert len(filtered_docs) == 1
@@ -95,7 +98,8 @@ async def test_aforward_with_exception(self, sample_documents):
         judge = RetrievalJudge()
         judge.rater.acall = AsyncMock(side_effect=Exception("Parallel execution failed"))
         documents = sample_documents
-        filtered_docs = await judge.aforward("test query", documents)
+        prediction = await judge.aforward("test query", documents)
+        filtered_docs = prediction.documents

         # Should return all documents on failure
         assert len(filtered_docs) == len(documents)
@@ -105,7 +109,7 @@ async def test_aforward_with_exception(self, sample_documents):
     async def test_aforward_with_contract_and_test_templates(self, sample_documents):
         """Test forward with contract template."""
         judge = RetrievalJudge()
-        result = await judge.aforward(
+        prediction = await judge.aforward(
             "test query",
             [
                 Document(
@@ -118,6 +122,7 @@ async def test_aforward_with_contract_and_test_templates(self, sample_documents)
                 ),
             ],
         )
+        result = prediction.documents
         assert result == [
             Document(
                 page_content="",
@@ -133,7 +138,7 @@ async def test_aforward_with_contract_and_test_templates(self, sample_documents)
     async def test_aforward_with_contract_template(self, sample_documents):
         """Test async forward with contract template."""
         judge = RetrievalJudge()
-        result = await judge.aforward(
+        prediction = await judge.aforward(
             "test query",
             [
                 Document(
@@ -146,6 +151,7 @@ async def test_aforward_with_contract_template(self, sample_documents):
                 ),
             ],
         )
+        result = prediction.documents
         assert result == [
             Document(
                 page_content="",
@@ -168,7 +174,8 @@ async def test_score_clamping(self, sample_documents):
             MagicMock(resource_note=0.5, reasoning="Valid score"),
         ])
         documents = sample_documents
-        filtered_docs = await judge.aforward("test", documents)
+        prediction = await judge.aforward("test", documents)
+        filtered_docs = prediction.documents

         # Check scores are clamped and filtering works
         assert len(filtered_docs) == 2  # Only 2 docs pass threshold of 0.4
diff --git a/python/uv.lock b/python/uv.lock
index 4f33b46..638cd07 100644
--- a/python/uv.lock
+++ b/python/uv.lock
@@ -345,7 +345,7 @@ wheels = [

 [[package]]
 name = "cairo-coder"
-version = "0.3.0"
+version = "0.3.1"
 source = { editable = "." }
 dependencies = [
     { name = "aiohttp" },