Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ This repo uses a unified, deterministic testing infrastructure to keep tests fas
- Unit client uses `mock_agent_factory` and `mock_vector_db`.
- Integration client injects a real `RagPipeline` wired to `mock_query_processor` + `mock_vector_db` (via the same `mock_agent_factory`).
- Replace ad‑hoc stubs with shared fixtures: `sample_processed_query`, `mock_query_processor`, `sample_documents`, and `mock_returned_documents` (built from `sample_documents`).
- Respect declared types. When a signature declares an argument as type `T`, do not guard it with `is None` or `hasattr` checks for `T`'s own surface area—call the method directly and let the type checker surface bugs. (Example: if a value is typed `dspy.Prediction`, call `get_lm_usage()` directly and set usage via `set_lm_usage()`; do not assume these methods might be missing.)

## DSPy/LLM Behavior

Expand Down
55 changes: 26 additions & 29 deletions python/src/cairo_coder/agents/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
agent system with a simple, in-memory registry of available agents.
"""

from dataclasses import dataclass
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import Any

from cairo_coder.core.config import VectorStoreConfig
from cairo_coder.core.rag_pipeline import RagPipeline, RagPipelineFactory
Expand All @@ -33,7 +35,8 @@ class AgentSpec:
name: str
description: str
sources: list[DocumentSource]
generation_program_type: AgentId
pipeline_builder: Callable[..., RagPipeline]
builder_kwargs: dict[str, Any] = field(default_factory=dict)
max_source_count: int = 5
similarity_threshold: float = 0.4

Expand All @@ -48,31 +51,15 @@ def build(self, vector_db: SourceFilteredPgVectorRM, vector_store_config: Vector
Returns:
Configured RagPipeline instance
"""
match self.generation_program_type:
case AgentId.STARKNET:
return RagPipelineFactory.create_pipeline(
name=self.name,
vector_store_config=vector_store_config,
sources=self.sources,
query_processor=create_query_processor(),
generation_program=create_generation_program(AgentId.STARKNET),
mcp_generation_program=create_mcp_generation_program(),
max_source_count=self.max_source_count,
similarity_threshold=self.similarity_threshold,
vector_db=vector_db,
)
case AgentId.CAIRO_CODER:
return RagPipelineFactory.create_pipeline(
name=self.name,
vector_store_config=vector_store_config,
sources=self.sources,
query_processor=create_query_processor(),
generation_program=create_generation_program(AgentId.CAIRO_CODER),
mcp_generation_program=create_mcp_generation_program(),
max_source_count=self.max_source_count,
similarity_threshold=self.similarity_threshold,
vector_db=vector_db,
)
return self.pipeline_builder(
name=self.name,
vector_store_config=vector_store_config,
vector_db=vector_db,
sources=self.sources,
max_source_count=self.max_source_count,
similarity_threshold=self.similarity_threshold,
**self.builder_kwargs,
)


# The global registry of available agents
Expand All @@ -81,15 +68,25 @@ def build(self, vector_db: SourceFilteredPgVectorRM, vector_store_config: Vector
name="Cairo Coder",
description="General Cairo programming assistant",
sources=list(DocumentSource), # All sources
generation_program_type=AgentId.CAIRO_CODER,
pipeline_builder=RagPipelineFactory.create_pipeline,
builder_kwargs={
"query_processor": create_query_processor(),
"generation_program": create_generation_program(AgentId.CAIRO_CODER),
"mcp_generation_program": create_mcp_generation_program(),
},
max_source_count=5,
similarity_threshold=0.4,
),
AgentId.STARKNET: AgentSpec(
name="Starknet Agent",
description="Assistant for the Starknet ecosystem (contracts, tools, docs).",
sources=list(DocumentSource),
generation_program_type=AgentId.STARKNET,
pipeline_builder=RagPipelineFactory.create_pipeline,
builder_kwargs={
"query_processor": create_query_processor(),
"generation_program": create_generation_program(AgentId.STARKNET),
"mcp_generation_program": create_mcp_generation_program(),
},
max_source_count=5,
similarity_threshold=0.4,
),
Expand Down
84 changes: 54 additions & 30 deletions python/src/cairo_coder/core/rag_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
from cairo_coder.core.types import (
Document,
DocumentSource,
FormattedSource,
Message,
ProcessedQuery,
StreamEvent,
StreamEventType,
combine_usage,
title_from_url,
)
from cairo_coder.dspy.document_retriever import DocumentRetrieverProgram
Expand Down Expand Up @@ -82,33 +84,57 @@ def __init__(self, config: RagPipelineConfig):
self._current_processed_query: ProcessedQuery | None = None
self._current_documents: list[Document] = []

# Token usage accumulator
self._accumulated_usage: dict[str, dict[str, int]] = {}

@property
def last_retrieved_documents(self) -> list[Document]:
"""Documents retrieved during the most recent pipeline execution."""
return self._current_documents

def _accumulate_usage(self, prediction: dspy.Prediction) -> None:
    """
    Fold the LM usage reported by one prediction into the running total.

    The total lives in ``self._accumulated_usage`` and is replaced by the
    value ``combine_usage`` returns for the current total and the usage
    read from ``prediction``.

    Args:
        prediction: DSPy prediction whose ``get_lm_usage()`` result is added
    """
    self._accumulated_usage = combine_usage(
        self._accumulated_usage,
        prediction.get_lm_usage(),
    )

def _reset_usage(self) -> None:
    """Reset accumulated usage for a new request.

    Called at the start of ``aforward`` so that each request reports only
    the usage of its own predictions.
    """
    # Rebind to a fresh dict instead of clearing in place: ``aforward`` passes
    # the accumulator object to ``result.set_lm_usage(...)``, so the previous
    # dict may still be referenced by an already-returned prediction.
    self._accumulated_usage = {}

async def _aprocess_query_and_retrieve_docs(
self,
query: str,
chat_history_str: str,
sources: list[DocumentSource] | None = None,
) -> tuple[ProcessedQuery, list[Document]]:
"""Process query and retrieve documents - shared async logic."""
processed_query = await self.query_processor.aforward(
qp_prediction = await self.query_processor.aforward(
query=query, chat_history=chat_history_str
)
self._accumulate_usage(qp_prediction)
processed_query = qp_prediction.processed_query
self._current_processed_query = processed_query

# Use provided sources or fall back to processed query sources
retrieval_sources = sources or processed_query.resources
documents = await self.document_retriever.aforward(
dr_prediction = await self.document_retriever.aforward(
processed_query=processed_query, sources=retrieval_sources
)
self._accumulate_usage(dr_prediction)
documents = dr_prediction.documents

# Optional Grok web/X augmentation: activate when STARKNET_BLOG is among sources.
try:
if DocumentSource.STARKNET_BLOG in retrieval_sources:
grok_docs = await self.grok_search.aforward(processed_query, chat_history_str)
grok_pred = await self.grok_search.aforward(processed_query, chat_history_str)
self._accumulate_usage(grok_pred)
grok_docs = grok_pred.documents

self._grok_citations = list(self.grok_search.last_citations)
if grok_docs:
documents.extend(grok_docs)
Expand All @@ -126,7 +152,9 @@ async def _aprocess_query_and_retrieve_docs(
lm=dspy.LM("gemini/gemini-flash-lite-latest", max_tokens=10000, temperature=0.5),
adapter=XMLAdapter(),
):
documents = await self.retrieval_judge.aforward(query=query, documents=documents)
judge_pred = await self.retrieval_judge.aforward(query=query, documents=documents)
self._accumulate_usage(judge_pred)
documents = judge_pred.documents
except Exception as e:
logger.warning(
"Retrieval judge failed (async), using all documents",
Expand Down Expand Up @@ -158,6 +186,9 @@ async def aforward(
mcp_mode: bool = False,
sources: list[DocumentSource] | None = None,
) -> dspy.Prediction:
# Reset usage for this request
self._reset_usage()

chat_history_str = self._format_chat_history(chat_history or [])
processed_query, documents = await self._aprocess_query_and_retrieve_docs(
query, chat_history_str, sources
Expand All @@ -167,13 +198,21 @@ async def aforward(
)

if mcp_mode:
return await self.mcp_generation_program.aforward(documents)
result = await self.mcp_generation_program.aforward(documents)
self._accumulate_usage(result)
result.set_lm_usage(self._accumulated_usage)
return result

context = self._prepare_context(documents)

return await self.generation_program.aforward(
result = await self.generation_program.aforward(
query=query, context=context, chat_history=chat_history_str
)
if result:
self._accumulate_usage(result)
# Update the result's usage to include accumulated usage from previous steps
result.set_lm_usage(self._accumulated_usage)
return result


async def aforward_streaming(
Expand Down Expand Up @@ -251,6 +290,7 @@ async def aforward_streaming(
logger.warning(f"Unknown signature field name: {chunk.signature_field_name}")
elif isinstance(chunk, dspy.Prediction):
# Final complete answer
self._accumulate_usage(chunk)
final_text = getattr(chunk, "answer", None) or chunk_accumulator
yield StreamEvent(type=StreamEventType.FINAL_RESPONSE, data=final_text)
rt.end(outputs={"output": final_text})
Expand All @@ -268,28 +308,12 @@ async def aforward_streaming(

def get_lm_usage(self) -> dict[str, dict[str, int]]:
"""
Get the total number of tokens used by the LLM.
"""
generation_usage = self.generation_program.get_lm_usage()
query_usage = self.query_processor.get_lm_usage()
judge_usage = self.retrieval_judge.get_lm_usage()

# Additive merge strategy
merged_usage = {}

# Helper function to merge usage dictionaries
def merge_usage_dict(target: dict, source: dict) -> None:
for model_name, metrics in source.items():
if model_name not in target:
target[model_name] = {}
for metric_name, value in metrics.items():
target[model_name][metric_name] = target[model_name].get(metric_name, 0) + value
Get accumulated token usage from all predictions in the pipeline.

merge_usage_dict(merged_usage, generation_usage)
merge_usage_dict(merged_usage, query_usage)
merge_usage_dict(merged_usage, judge_usage)

return merged_usage
Returns:
Dictionary mapping model names to usage metrics
"""
return self._accumulated_usage

def _format_chat_history(self, chat_history: list[Message]) -> str:
"""
Expand All @@ -311,7 +335,7 @@ def _format_chat_history(self, chat_history: list[Message]) -> str:

return "\n".join(formatted_messages)

def _format_sources(self, documents: list[Document]) -> list[dict[str, Any]]:
def _format_sources(self, documents: list[Document]) -> list[FormattedSource]:
"""
Format documents for the frontend-friendly sources event.

Expand All @@ -322,9 +346,9 @@ def _format_sources(self, documents: list[Document]) -> list[dict[str, Any]]:
documents: List of retrieved documents

Returns:
List of dicts: [{"title": str, "url": str}, ...]
List of formatted sources with metadata
"""
sources: list[dict[str, str]] = []
sources: list[FormattedSource] = []
seen_urls: set[str] = set()


Expand Down
50 changes: 50 additions & 0 deletions python/src/cairo_coder/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,29 @@ class ProcessedQuery:
is_test_related: bool = False
resources: list[DocumentSource] = field(default_factory=list)

# Usage metrics for a single model: metric name -> value. ``combine_usage``
# sums the numeric values and merges values that are themselves dicts of
# numeric details; other value types are carried along without being summed.
LMUsageEntry = dict[str, Any]
# Usage for a group of LM calls, keyed by model name.
LMUsage = dict[str, LMUsageEntry]


class RetrievedSourceData(TypedDict):
    """Structure for retrieved source data stored in database.

    Used as the element type of ``UserInteraction.retrieved_sources``.
    """

    # Presumably the text of the retrieved document — confirm against the
    # code that builds these records.
    page_content: str
    # Document metadata, typed as ``DocumentMetadata``.
    metadata: DocumentMetadata


class FormattedSourceMetadata(TypedDict):
    """Metadata structure for formatted sources sent to frontend.

    Carried under the ``metadata`` key of ``FormattedSource``.
    """

    # Title shown for the source.
    title: str
    # URL of the source.
    url: str
    # NOTE(review): presumably the string form of a ``DocumentSource`` value —
    # confirm against ``RagPipeline._format_sources``.
    source_type: str


class FormattedSource(TypedDict):
    """Structure for formatted sources sent to frontend.

    Element type of the list returned by ``RagPipeline._format_sources``.
    """

    # Title, URL and source type of the formatted source.
    metadata: FormattedSourceMetadata

# Helper to extract domain title
def title_from_url(url: str) -> str:
Expand Down Expand Up @@ -174,6 +197,33 @@ def to_dict(self) -> dict[str, Any]:
"details": self.details,
"timestamp": self.timestamp.isoformat(),
}


def combine_usage(usage1: LMUsage, usage2: LMUsage) -> LMUsage:
"""Combine two LM usage dictionaries, tolerating missing inputs."""
result: LMUsage = {model: (metrics or {}).copy() for model, metrics in usage1.items()}

for model, metrics in usage2.items():
if model not in result:
result[model] = metrics.copy()
else:
# Merge metrics
for key, value in metrics.items():
if isinstance(value, int | float):
result[model][key] = result[model].get(key, 0) + value
elif isinstance(value, dict):
if key not in result[model] or result[model][key] is None:
result[model][key] = value.copy()
else:
# Recursive merge for nested dicts
for detail_key, detail_value in value.items():
if isinstance(detail_value, int | float):
result[model][key][detail_key] = (
result[model][key].get(detail_key, 0) + detail_value
)
return result


class AgentResponse(BaseModel):
"""Response from agent processing."""

Expand Down
4 changes: 3 additions & 1 deletion python/src/cairo_coder/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from pydantic import BaseModel, Field

from cairo_coder.core.types import RetrievedSourceData


class UserInteraction(BaseModel):
"""Represents a record in the user_interactions table."""
Expand All @@ -21,5 +23,5 @@ class UserInteraction(BaseModel):
chat_history: Optional[list[dict[str, Any]]] = None
query: str
generated_answer: Optional[str] = None
retrieved_sources: Optional[list[dict[str, Any]]] = None
retrieved_sources: Optional[list[RetrievedSourceData]] = None
llm_usage: Optional[dict[str, Any]] = None
Loading