Skip to content

Commit 502c252

Browse files
committed
tmp
1 parent a636e92 commit 502c252

File tree

10 files changed

+204
-218
lines changed

10 files changed

+204
-218
lines changed

python/src/cairo_coder/agents/registry.py

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
agent system with a simple, in-memory registry of available agents.
66
"""
77

8-
from dataclasses import dataclass
8+
from collections.abc import Callable
9+
from dataclasses import dataclass, field
910
from enum import Enum
11+
from typing import Any
1012

1113
from cairo_coder.core.config import VectorStoreConfig
1214
from cairo_coder.core.rag_pipeline import RagPipeline, RagPipelineFactory
@@ -33,7 +35,8 @@ class AgentSpec:
3335
name: str
3436
description: str
3537
sources: list[DocumentSource]
36-
generation_program_type: AgentId
38+
pipeline_builder: Callable[..., RagPipeline]
39+
builder_kwargs: dict[str, Any] = field(default_factory=dict)
3740
max_source_count: int = 5
3841
similarity_threshold: float = 0.4
3942

@@ -48,31 +51,15 @@ def build(self, vector_db: SourceFilteredPgVectorRM, vector_store_config: Vector
4851
Returns:
4952
Configured RagPipeline instance
5053
"""
51-
match self.generation_program_type:
52-
case AgentId.STARKNET:
53-
return RagPipelineFactory.create_pipeline(
54-
name=self.name,
55-
vector_store_config=vector_store_config,
56-
sources=self.sources,
57-
query_processor=create_query_processor(),
58-
generation_program=create_generation_program(AgentId.STARKNET),
59-
mcp_generation_program=create_mcp_generation_program(),
60-
max_source_count=self.max_source_count,
61-
similarity_threshold=self.similarity_threshold,
62-
vector_db=vector_db,
63-
)
64-
case AgentId.CAIRO_CODER:
65-
return RagPipelineFactory.create_pipeline(
66-
name=self.name,
67-
vector_store_config=vector_store_config,
68-
sources=self.sources,
69-
query_processor=create_query_processor(),
70-
generation_program=create_generation_program(AgentId.CAIRO_CODER),
71-
mcp_generation_program=create_mcp_generation_program(),
72-
max_source_count=self.max_source_count,
73-
similarity_threshold=self.similarity_threshold,
74-
vector_db=vector_db,
75-
)
54+
return self.pipeline_builder(
55+
name=self.name,
56+
vector_store_config=vector_store_config,
57+
vector_db=vector_db,
58+
sources=self.sources,
59+
max_source_count=self.max_source_count,
60+
similarity_threshold=self.similarity_threshold,
61+
**self.builder_kwargs,
62+
)
7663

7764

7865
# The global registry of available agents
@@ -81,15 +68,25 @@ def build(self, vector_db: SourceFilteredPgVectorRM, vector_store_config: Vector
8168
name="Cairo Coder",
8269
description="General Cairo programming assistant",
8370
sources=list(DocumentSource), # All sources
84-
generation_program_type=AgentId.CAIRO_CODER,
71+
pipeline_builder=RagPipelineFactory.create_pipeline,
72+
builder_kwargs={
73+
"query_processor": create_query_processor(),
74+
"generation_program": create_generation_program(AgentId.CAIRO_CODER),
75+
"mcp_generation_program": create_mcp_generation_program(),
76+
},
8577
max_source_count=5,
8678
similarity_threshold=0.4,
8779
),
8880
AgentId.STARKNET: AgentSpec(
8981
name="Starknet Agent",
9082
description="Assistant for the Starknet ecosystem (contracts, tools, docs).",
9183
sources=list(DocumentSource),
92-
generation_program_type=AgentId.STARKNET,
84+
pipeline_builder=RagPipelineFactory.create_pipeline,
85+
builder_kwargs={
86+
"query_processor": create_query_processor(),
87+
"generation_program": create_generation_program(AgentId.STARKNET),
88+
"mcp_generation_program": create_mcp_generation_program(),
89+
},
9390
max_source_count=5,
9491
similarity_threshold=0.4,
9592
),

python/src/cairo_coder/core/rag_pipeline.py

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from cairo_coder.core.types import (
2020
Document,
2121
DocumentSource,
22+
FormattedSource,
2223
Message,
2324
ProcessedQuery,
2425
StreamEvent,
@@ -82,11 +83,34 @@ def __init__(self, config: RagPipelineConfig):
8283
self._current_processed_query: ProcessedQuery | None = None
8384
self._current_documents: list[Document] = []
8485

86+
# Token usage accumulator
87+
self._accumulated_usage: dict[str, dict[str, int]] = {}
88+
8589
@property
8690
def last_retrieved_documents(self) -> list[Document]:
8791
"""Documents retrieved during the most recent pipeline execution."""
8892
return self._current_documents
8993

94+
def _accumulate_usage(self, prediction: dspy.Prediction) -> None:
95+
"""
96+
Accumulate token usage from a prediction.
97+
98+
Args:
99+
prediction: DSPy prediction object with usage information
100+
"""
101+
usage = prediction.get_lm_usage();
102+
for model_name, metrics in usage.items():
103+
if model_name not in self._accumulated_usage:
104+
self._accumulated_usage[model_name] = {}
105+
for metric_name, value in metrics.items():
106+
self._accumulated_usage[model_name][metric_name] = (
107+
self._accumulated_usage[model_name].get(metric_name, 0) + value
108+
)
109+
110+
def _reset_usage(self) -> None:
111+
"""Reset accumulated usage for a new request."""
112+
self._accumulated_usage = {}
113+
90114
async def _aprocess_query_and_retrieve_docs(
91115
self,
92116
query: str,
@@ -97,6 +121,7 @@ async def _aprocess_query_and_retrieve_docs(
97121
processed_query = await self.query_processor.aforward(
98122
query=query, chat_history=chat_history_str
99123
)
124+
self._accumulate_usage(processed_query)
100125
self._current_processed_query = processed_query
101126

102127
# Use provided sources or fall back to processed query sources
@@ -158,6 +183,9 @@ async def aforward(
158183
mcp_mode: bool = False,
159184
sources: list[DocumentSource] | None = None,
160185
) -> dspy.Prediction:
186+
# Reset usage for this request
187+
self._reset_usage()
188+
161189
chat_history_str = self._format_chat_history(chat_history or [])
162190
processed_query, documents = await self._aprocess_query_and_retrieve_docs(
163191
query, chat_history_str, sources
@@ -167,13 +195,17 @@ async def aforward(
167195
)
168196

169197
if mcp_mode:
170-
return await self.mcp_generation_program.aforward(documents)
198+
result = await self.mcp_generation_program.aforward(documents)
199+
self._accumulate_usage(result)
200+
return result
171201

172202
context = self._prepare_context(documents)
173203

174-
return await self.generation_program.aforward(
204+
result = await self.generation_program.aforward(
175205
query=query, context=context, chat_history=chat_history_str
176206
)
207+
self._accumulate_usage(result)
208+
return result
177209

178210

179211
async def aforward_streaming(
@@ -268,28 +300,12 @@ async def aforward_streaming(
268300

269301
def get_lm_usage(self) -> dict[str, dict[str, int]]:
270302
"""
271-
Get the total number of tokens used by the LLM.
272-
"""
273-
generation_usage = self.generation_program.get_lm_usage()
274-
query_usage = self.query_processor.get_lm_usage()
275-
judge_usage = self.retrieval_judge.get_lm_usage()
276-
277-
# Additive merge strategy
278-
merged_usage = {}
279-
280-
# Helper function to merge usage dictionaries
281-
def merge_usage_dict(target: dict, source: dict) -> None:
282-
for model_name, metrics in source.items():
283-
if model_name not in target:
284-
target[model_name] = {}
285-
for metric_name, value in metrics.items():
286-
target[model_name][metric_name] = target[model_name].get(metric_name, 0) + value
303+
Get accumulated token usage from all predictions in the pipeline.
287304
288-
merge_usage_dict(merged_usage, generation_usage)
289-
merge_usage_dict(merged_usage, query_usage)
290-
merge_usage_dict(merged_usage, judge_usage)
291-
292-
return merged_usage
305+
Returns:
306+
Dictionary mapping model names to usage metrics
307+
"""
308+
return self._accumulated_usage
293309

294310
def _format_chat_history(self, chat_history: list[Message]) -> str:
295311
"""
@@ -311,7 +327,7 @@ def _format_chat_history(self, chat_history: list[Message]) -> str:
311327

312328
return "\n".join(formatted_messages)
313329

314-
def _format_sources(self, documents: list[Document]) -> list[dict[str, Any]]:
330+
def _format_sources(self, documents: list[Document]) -> list[FormattedSource]:
315331
"""
316332
Format documents for the frontend-friendly sources event.
317333
@@ -322,9 +338,9 @@ def _format_sources(self, documents: list[Document]) -> list[dict[str, Any]]:
322338
documents: List of retrieved documents
323339
324340
Returns:
325-
List of dicts: [{"title": str, "url": str}, ...]
341+
List of formatted sources with metadata
326342
"""
327-
sources: list[dict[str, str]] = []
343+
sources: list[FormattedSource] = []
328344
seen_urls: set[str] = set()
329345

330346

python/src/cairo_coder/core/types.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,29 @@ class ProcessedQuery:
7474
is_test_related: bool = False
7575
resources: list[DocumentSource] = field(default_factory=list)
7676

77+
LMUsageEntry = dict[str, Any]
78+
LMUsage = dict[str, LMUsageEntry]
79+
80+
81+
class RetrievedSourceData(TypedDict):
82+
"""Structure for retrieved source data stored in database."""
83+
84+
page_content: str
85+
metadata: DocumentMetadata
86+
87+
88+
class FormattedSourceMetadata(TypedDict):
89+
"""Metadata structure for formatted sources sent to frontend."""
90+
91+
title: str
92+
url: str
93+
source_type: str
94+
95+
96+
class FormattedSource(TypedDict):
97+
"""Structure for formatted sources sent to frontend."""
98+
99+
metadata: FormattedSourceMetadata
77100

78101
# Helper to extract domain title
79102
def title_from_url(url: str) -> str:

python/src/cairo_coder/db/models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
from pydantic import BaseModel, Field
1212

13+
from cairo_coder.core.types import RetrievedSourceData
14+
1315

1416
class UserInteraction(BaseModel):
1517
"""Represents a record in the user_interactions table."""
@@ -21,5 +23,5 @@ class UserInteraction(BaseModel):
2123
chat_history: Optional[list[dict[str, Any]]] = None
2224
query: str
2325
generated_answer: Optional[str] = None
24-
retrieved_sources: Optional[list[dict[str, Any]]] = None
26+
retrieved_sources: Optional[list[RetrievedSourceData]] = None
2527
llm_usage: Optional[dict[str, Any]] = None

python/src/cairo_coder/dspy/generation_program.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -192,12 +192,6 @@ def __init__(self, program_type):
192192
raise FileNotFoundError(f"{compiled_program_path} not found")
193193
self.generation_program.load(compiled_program_path)
194194

195-
def get_lm_usage(self) -> dict[str, int]:
196-
"""
197-
Get the total number of tokens used by the LLM.
198-
"""
199-
return self.generation_program.get_lm_usage()
200-
201195
@traceable(
202196
name="GenerationProgram", run_type="llm", metadata={"llm_provider": dspy.settings.lm}
203197
)
@@ -339,14 +333,6 @@ async def aforward(self, documents: list[Document]) -> dspy.Prediction:
339333
"""
340334
return self(documents)
341335

342-
def get_lm_usage(self) -> dict[str, int]:
343-
"""
344-
Get the total number of tokens used by the LLM.
345-
Note: MCP mode doesn't use LLM generation, so no tokens are consumed.
346-
"""
347-
# MCP mode doesn't use LLM generation, return empty dict
348-
return {}
349-
350336

351337
def create_generation_program(program_type: str) -> GenerationProgram:
352338
"""

python/src/cairo_coder/dspy/grok_search.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ async def aforward(self, processed_query: ProcessedQuery, chat_history: str) ->
103103
{', '.join(processed_query.search_queries)}. \
104104
Make sure that your final answer will contain links to the relevant sources used to construct your answer.
105105
"""
106+
# TODO: track LM usage
106107
chat = self.client.chat.create(
107108
model=DEFAULT_GROK_MODEL,
108109
tools=[web_search(), x_search()],

python/src/cairo_coder/dspy/query_processor.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from langsmith import traceable
1414

1515
import dspy
16-
from cairo_coder.core.types import DocumentSource, ProcessedQuery
16+
from cairo_coder.core.types import DocumentSource, ProcessedQuery, LMUsage
1717

1818
logger = structlog.get_logger(__name__)
1919

@@ -125,7 +125,7 @@ def __init__(self):
125125
}
126126

127127
@traceable(name="QueryProcessorProgram", run_type="llm", metadata={"llm_provider": dspy.settings.lm})
128-
async def aforward(self, query: str, chat_history: Optional[str] = None) -> ProcessedQuery:
128+
async def aforward(self, query: str, chat_history: Optional[str] = None) -> tuple[ProcessedQuery, LMUsage]:
129129
"""
130130
Process a user query into a structured format for document retrieval.
131131
@@ -144,19 +144,14 @@ async def aforward(self, query: str, chat_history: Optional[str] = None) -> Proc
144144
resources = self._validate_resources(result.resources)
145145

146146
# Build structured query result
147-
return ProcessedQuery(
147+
processed_query = ProcessedQuery(
148148
original=query,
149149
search_queries=search_queries,
150150
is_contract_related=self._is_contract_query(query),
151151
is_test_related=self._is_test_query(query),
152152
resources=resources,
153153
)
154-
155-
def get_lm_usage(self) -> dict[str, int]:
156-
"""
157-
Get the total number of tokens used by the LLM.
158-
"""
159-
return self.retrieval_program.get_lm_usage()
154+
return processed_query, result.get_lm_usage()
160155

161156
def _validate_resources(self, resources: list[str]) -> list[DocumentSource]:
162157
"""

python/src/cairo_coder/dspy/retrieval_judge.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,6 @@ async def judge_one(doc_string: str):
171171

172172
return keep_docs
173173

174-
def get_lm_usage(self) -> dict[str, int]:
175-
"""
176-
Get the total number of tokens used by the LLM.
177-
"""
178-
return self.rater.get_lm_usage()
179-
180174
# =========================
181175
# Internal Helpers
182176
# =========================

0 commit comments

Comments
 (0)