Commit b299af4

fix: token usage calculation

1 parent 502c252

19 files changed: +453 additions, -178 deletions

python/AGENTS.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -12,6 +12,7 @@ This repo uses a unified, deterministic testing infrastructure to keep tests fas
 - Unit client uses `mock_agent_factory` and `mock_vector_db`.
 - Integration client injects a real `RagPipeline` wired to `mock_query_processor` + `mock_vector_db` (via the same `mock_agent_factory`).
 - Replace ad‑hoc stubs with shared fixtures: `sample_processed_query`, `mock_query_processor`, `sample_documents`, and `mock_returned_documents` (built from `sample_documents`).
+- Respect declared types. When a signature says the argument is type `T`, never guard it with `is None` or `hasattr` checks for `T`'s own surface area—just call the method and let the type system show bugs. (Example: if something is typed `dspy.Prediction`, call `get_lm_usage()` directly and set usage via `set_lm_usage`. Don't assume these attributes are not present.)
 
 ## DSPy/LLM Behavior
```
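
The new rule is easiest to see in code. A minimal sketch (the helper name and the fallback dict are illustrative, not from the repo):

```python
import dspy

def record_usage(pred: dspy.Prediction) -> dict:
    """Read a prediction's LM usage, trusting the declared type."""
    # Anti-pattern the rule forbids: `hasattr(pred, "get_lm_usage")` or
    # `pred is not None` guards on a parameter already typed dspy.Prediction.
    # Preferred: call the declared API directly; get_lm_usage() may return {},
    # but the method itself always exists on a real Prediction.
    return pred.get_lm_usage() or {}
```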

python/src/cairo_coder/core/rag_pipeline.py

Lines changed: 22 additions & 14 deletions

```diff
@@ -24,6 +24,7 @@
     ProcessedQuery,
     StreamEvent,
     StreamEventType,
+    combine_usage,
     title_from_url,
 )
 from cairo_coder.dspy.document_retriever import DocumentRetrieverProgram
@@ -98,14 +99,8 @@ def _accumulate_usage(self, prediction: dspy.Prediction) -> None:
         Args:
             prediction: DSPy prediction object with usage information
         """
-        usage = prediction.get_lm_usage();
-        for model_name, metrics in usage.items():
-            if model_name not in self._accumulated_usage:
-                self._accumulated_usage[model_name] = {}
-            for metric_name, value in metrics.items():
-                self._accumulated_usage[model_name][metric_name] = (
-                    self._accumulated_usage[model_name].get(metric_name, 0) + value
-                )
+        usage = prediction.get_lm_usage()
+        self._accumulated_usage = combine_usage(self._accumulated_usage, usage)
 
     def _reset_usage(self) -> None:
         """Reset accumulated usage for a new request."""
@@ -118,22 +113,28 @@ async def _aprocess_query_and_retrieve_docs(
         sources: list[DocumentSource] | None = None,
     ) -> tuple[ProcessedQuery, list[Document]]:
         """Process query and retrieve documents - shared async logic."""
-        processed_query = await self.query_processor.aforward(
+        qp_prediction = await self.query_processor.aforward(
             query=query, chat_history=chat_history_str
         )
-        self._accumulate_usage(processed_query)
+        self._accumulate_usage(qp_prediction)
+        processed_query = qp_prediction.processed_query
         self._current_processed_query = processed_query
 
         # Use provided sources or fall back to processed query sources
         retrieval_sources = sources or processed_query.resources
-        documents = await self.document_retriever.aforward(
+        dr_prediction = await self.document_retriever.aforward(
             processed_query=processed_query, sources=retrieval_sources
         )
+        self._accumulate_usage(dr_prediction)
+        documents = dr_prediction.documents
 
         # Optional Grok web/X augmentation: activate when STARKNET_BLOG is among sources.
         try:
             if DocumentSource.STARKNET_BLOG in retrieval_sources:
-                grok_docs = await self.grok_search.aforward(processed_query, chat_history_str)
+                grok_pred = await self.grok_search.aforward(processed_query, chat_history_str)
+                self._accumulate_usage(grok_pred)
+                grok_docs = grok_pred.documents
+
                 self._grok_citations = list(self.grok_search.last_citations)
                 if grok_docs:
                     documents.extend(grok_docs)
@@ -151,7 +152,9 @@ async def _aprocess_query_and_retrieve_docs(
                 lm=dspy.LM("gemini/gemini-flash-lite-latest", max_tokens=10000, temperature=0.5),
                 adapter=XMLAdapter(),
             ):
-                documents = await self.retrieval_judge.aforward(query=query, documents=documents)
+                judge_pred = await self.retrieval_judge.aforward(query=query, documents=documents)
+                self._accumulate_usage(judge_pred)
+                documents = judge_pred.documents
         except Exception as e:
             logger.warning(
                 "Retrieval judge failed (async), using all documents",
@@ -197,14 +200,18 @@ async def aforward(
         if mcp_mode:
             result = await self.mcp_generation_program.aforward(documents)
             self._accumulate_usage(result)
+            result.set_lm_usage(self._accumulated_usage)
             return result
 
         context = self._prepare_context(documents)
 
         result = await self.generation_program.aforward(
             query=query, context=context, chat_history=chat_history_str
         )
-        self._accumulate_usage(result)
+        if result:
+            self._accumulate_usage(result)
+        # Update the result's usage to include accumulated usage from previous steps
+        result.set_lm_usage(self._accumulated_usage)
         return result
 
 
@@ -283,6 +290,7 @@ async def aforward_streaming(
                     logger.warning(f"Unknown signature field name: {chunk.signature_field_name}")
             elif isinstance(chunk, dspy.Prediction):
                 # Final complete answer
+                self._accumulate_usage(chunk)
                 final_text = getattr(chunk, "answer", None) or chunk_accumulator
                 yield StreamEvent(type=StreamEventType.FINAL_RESPONSE, data=final_text)
                 rt.end(outputs={"output": final_text})
```
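
Taken together, these hunks establish one contract: every stage returns a `dspy.Prediction` carrying its own usage, the pipeline folds each into a running total, and the final result is stamped with the combined figure. A condensed sketch of that flow (the stage list and loop are simplified placeholders, not the pipeline's real control flow):

```python
import dspy
from cairo_coder.core.types import combine_usage

async def run_stages(stages: list, query: str) -> dspy.Prediction:
    accumulated: dict = {}
    result = dspy.Prediction()
    for stage in stages:
        # Each stage (query processor, retriever, judge, generator) returns
        # a dspy.Prediction with usage attached via set_lm_usage.
        result = await stage.aforward(query=query)
        accumulated = combine_usage(accumulated, result.get_lm_usage())
    # The caller-visible prediction reports usage for the whole request.
    result.set_lm_usage(accumulated)
    return result
```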

python/src/cairo_coder/core/types.py

Lines changed: 27 additions & 0 deletions

```diff
@@ -197,6 +197,33 @@ def to_dict(self) -> dict[str, Any]:
             "details": self.details,
             "timestamp": self.timestamp.isoformat(),
         }
+
+
+def combine_usage(usage1: LMUsage, usage2: LMUsage) -> LMUsage:
+    """Combine two LM usage dictionaries, tolerating missing inputs."""
+    result: LMUsage = {model: (metrics or {}).copy() for model, metrics in usage1.items()}
+
+    for model, metrics in usage2.items():
+        if model not in result:
+            result[model] = metrics.copy()
+        else:
+            # Merge metrics
+            for key, value in metrics.items():
+                if isinstance(value, int | float):
+                    result[model][key] = result[model].get(key, 0) + value
+                elif isinstance(value, dict):
+                    if key not in result[model] or result[model][key] is None:
+                        result[model][key] = value.copy()
+                    else:
+                        # Recursive merge for nested dicts
+                        for detail_key, detail_value in value.items():
+                            if isinstance(detail_value, int | float):
+                                result[model][key][detail_key] = (
+                                    result[model][key].get(detail_key, 0) + detail_value
+                                )
+    return result
+
+
 class AgentResponse(BaseModel):
     """Response from agent processing."""
```
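
The merge semantics in one worked example (model names and token figures are made up; the dict shape is assumed to mirror what `get_lm_usage()` returns):

```python
from cairo_coder.core.types import combine_usage

step1 = {
    "gemini/gemini-flash-lite-latest": {
        "prompt_tokens": 120,
        "completion_tokens": 40,
        "completion_tokens_details": {"reasoning_tokens": 10},
    }
}
step2 = {
    "gemini/gemini-flash-lite-latest": {
        "prompt_tokens": 80,
        "completion_tokens": 25,
        "completion_tokens_details": {"reasoning_tokens": 5},
    },
    "grok/grok-beta": {"prompt_tokens": 300, "completion_tokens": 60},
}

merged = combine_usage(step1, step2)
# Numeric metrics sum per model: prompt_tokens -> 200, completion_tokens -> 65.
# Nested detail dicts merge one level deep: reasoning_tokens -> 15.
# Models present in only one input ("grok/grok-beta") are copied through.
```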

python/src/cairo_coder/dspy/document_retriever.py

Lines changed: 19 additions & 6 deletions

```diff
@@ -565,7 +565,7 @@ def __init__(
 
     async def aforward(
         self, processed_query: ProcessedQuery, sources: list[DocumentSource] | None = None
-    ) -> list[Document]:
+    ) -> dspy.Prediction:
         """
         Execute the document retrieval process asynchronously.
 
@@ -574,7 +574,7 @@ async def aforward(
             sources: Optional list of DocumentSource to filter by
 
         Returns:
-            List of relevant Document objects, ranked by similarity
+            dspy.Prediction containing list of relevant Document objects, ranked by similarity
         """
         # Use sources from processed query if not provided
         if sources is None:
@@ -584,10 +584,15 @@
         documents = await self._afetch_documents(processed_query, sources)
 
         if not documents:
-            return []
+            empty_prediction = dspy.Prediction(documents=[])
+            empty_prediction.set_lm_usage({})
+            return empty_prediction
 
         # Step 2: Enrich context with appropriate templates based on query type.
-        return self._enhance_context(processed_query, documents)
+        enhanced_documents = self._enhance_context(processed_query, documents)
+        prediction = dspy.Prediction(documents=enhanced_documents)
+        prediction.set_lm_usage({})
+        return prediction
 
     def forward(
         self, processed_query: ProcessedQuery, sources: list[DocumentSource] | None = None
@@ -701,7 +706,11 @@ def _enhance_context(self, processed_query: ProcessedQuery, context: list[Docume
             context.append(
                 Document(
                     page_content=CONTRACT_TEMPLATE,
-                    metadata={"title": CONTRACT_TEMPLATE_TITLE, "source": CONTRACT_TEMPLATE_TITLE, "sourceLink": "https://www.starknet.io/cairo-book/ch103-06-01-deploying-and-interacting-with-a-voting-contract.html"},
+                    metadata={
+                        "title": CONTRACT_TEMPLATE_TITLE,
+                        "source": DocumentSource.CAIRO_BOOK,
+                        "sourceLink": "https://www.starknet.io/cairo-book/ch103-06-01-deploying-and-interacting-with-a-voting-contract.html",
+                    },
                 )
             )
 
@@ -710,7 +719,11 @@ def _enhance_context(self, processed_query: ProcessedQuery, context: list[Docume
             context.append(
                 Document(
                     page_content=TEST_TEMPLATE,
-                    metadata={"title": TEST_TEMPLATE_TITLE, "source": TEST_TEMPLATE_TITLE, "sourceLink": "https://www.starknet.io/cairo-book/ch104-02-testing-smart-contracts.html"},
+                    metadata={
+                        "title": TEST_TEMPLATE_TITLE,
+                        "source": DocumentSource.CAIRO_BOOK,
+                        "sourceLink": "https://www.starknet.io/cairo-book/ch104-02-testing-smart-contracts.html",
+                    },
                 )
             )
         return context
```
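
Besides the `dspy.Prediction` wrapping, note the second change in `_enhance_context`: the template documents' `source` field previously reused the human-readable title string, while every other retrieved document carries a `DocumentSource` value. A small before/after sketch (the title text stands in for the real constant):

```python
from cairo_coder.core.types import Document, DocumentSource

# Before: "source" held the same string as "title", breaking consumers
# that expect a DocumentSource member:
#   metadata={"title": TEST_TEMPLATE_TITLE, "source": TEST_TEMPLATE_TITLE, ...}

doc = Document(
    page_content="...template body...",
    metadata={
        "title": "Testing Smart Contracts",   # stands in for TEST_TEMPLATE_TITLE
        "source": DocumentSource.CAIRO_BOOK,  # enum, like real retrieved docs
        "sourceLink": "https://www.starknet.io/cairo-book/ch104-02-testing-smart-contracts.html",
    },
)
assert doc.metadata["source"] is DocumentSource.CAIRO_BOOK
```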

python/src/cairo_coder/dspy/grok_search.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -96,7 +96,7 @@ def _domain_from_url(url: str) -> str:
         return url
 
     @traceable(name="GrokSearchProgram", run_type="llm")
-    async def aforward(self, processed_query: ProcessedQuery, chat_history: str) -> list[Document]:
+    async def aforward(self, processed_query: ProcessedQuery, chat_history: str) -> dspy.Prediction:
         formatted_query = f"""Answer the following query: {processed_query.original}. \
 Here is the chat history: {chat_history}, that might be relevant to the question. \
 For more context, here are some semantic terms associated with the question: \
@@ -148,4 +148,6 @@ async def aforward(self, processed_query: ProcessedQuery, chat_history: str) ->
             )
         )
 
-        return documents
+        prediction = dspy.Prediction(documents=documents)
+        prediction.set_lm_usage({})
+        return prediction
```

python/src/cairo_coder/dspy/query_processor.py

Lines changed: 8 additions & 4 deletions

```diff
@@ -13,7 +13,7 @@
 from langsmith import traceable
 
 import dspy
-from cairo_coder.core.types import DocumentSource, ProcessedQuery, LMUsage
+from cairo_coder.core.types import DocumentSource, ProcessedQuery
 
 logger = structlog.get_logger(__name__)
 
@@ -125,7 +125,7 @@ def __init__(self):
         }
 
     @traceable(name="QueryProcessorProgram", run_type="llm", metadata={"llm_provider": dspy.settings.lm})
-    async def aforward(self, query: str, chat_history: Optional[str] = None) -> tuple[ProcessedQuery, LMUsage]:
+    async def aforward(self, query: str, chat_history: Optional[str] = None) -> dspy.Prediction:
         """
         Process a user query into a structured format for document retrieval.
 
@@ -134,7 +134,7 @@ async def aforward(self, query: str, chat_history: Optional[str] = None) -> tupl
             chat_history: Previous conversation context (optional)
 
         Returns:
-            ProcessedQuery with search terms, resource identification, and categorization
+            dspy.Prediction containing processed_query and attached usage
         """
         # Execute the DSPy retrieval program
         result = await self.retrieval_program.aforward(query=query, chat_history=chat_history)
@@ -151,7 +151,11 @@ async def aforward(self, query: str, chat_history: Optional[str] = None) -> tupl
             is_test_related=self._is_test_query(query),
             resources=resources,
         )
-        return processed_query, result.get_lm_usage()
+
+        prediction = dspy.Prediction(processed_query=processed_query)
+        prediction.set_lm_usage(result.get_lm_usage() or {})
+
+        return prediction
 
     def _validate_resources(self, resources: list[str]) -> list[DocumentSource]:
         """
```

python/src/cairo_coder/dspy/retrieval_judge.py

Lines changed: 15 additions & 5 deletions

```diff
@@ -16,7 +16,7 @@
 from langsmith import traceable
 
 import dspy
-from cairo_coder.core.types import Document
+from cairo_coder.core.types import Document, combine_usage
 from cairo_coder.dspy.document_retriever import CONTRACT_TEMPLATE_TITLE, TEST_TEMPLATE_TITLE
 
 logger = structlog.get_logger(__name__)
@@ -135,14 +135,16 @@ def __init__(self):
     @traceable(
         name="RetrievalJudge", run_type="llm", metadata={"llm_provider": dspy.settings.lm}
     )
-    async def aforward(self, query: str, documents: list[Document]) -> list[Document]:
+    async def aforward(self, query: str, documents: list[Document]) -> dspy.Prediction:
         """Async judge."""
         if not documents:
-            return documents
+            return dspy.Prediction(documents=documents)
 
         keep_docs, judged_indices, judged_payloads = self._split_templates_and_prepare_docs(
             documents
         )
+
+        aggregated_usage = {}
 
         # TODO: can we use dspy.Parallel here instead of asyncio gather?
         if judged_payloads:
@@ -154,6 +156,12 @@ async def judge_one(doc_string: str):
             results = await asyncio.gather(
                 *[judge_one(ds) for ds in judged_payloads], return_exceptions=True
             )
+
+            # Aggregate usage from results
+            for res in results:
+                if isinstance(res, dspy.Prediction):
+                    aggregated_usage = combine_usage(aggregated_usage, res.get_lm_usage())
+
             self._attach_scores_and_filter_async(
                 query=query,
                 documents=documents,
@@ -167,9 +175,11 @@
                 error=str(e),
                 exc_info=True,
             )
-            return documents
+            return dspy.Prediction(documents=documents)
 
-        return keep_docs
+        pred = dspy.Prediction(documents=keep_docs)
+        pred.set_lm_usage(aggregated_usage)
+        return pred
 
     # =========================
     # Internal Helpers
```
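
The `isinstance` guard in the aggregation loop exists because `asyncio.gather(..., return_exceptions=True)` hands back a mix of predictions and exceptions. A self-contained toy showing why only real predictions contribute usage (model name and token counts are invented):

```python
import asyncio
import dspy
from cairo_coder.core.types import combine_usage

async def ok() -> dspy.Prediction:
    pred = dspy.Prediction(score=1.0)
    pred.set_lm_usage({"judge-lm": {"prompt_tokens": 50, "completion_tokens": 5}})
    return pred

async def boom() -> dspy.Prediction:
    raise RuntimeError("judge call failed")

async def main() -> None:
    results = await asyncio.gather(ok(), boom(), ok(), return_exceptions=True)
    usage: dict = {}
    for res in results:
        if isinstance(res, dspy.Prediction):  # the RuntimeError is skipped
            usage = combine_usage(usage, res.get_lm_usage())
    print(usage)  # {'judge-lm': {'prompt_tokens': 100, 'completion_tokens': 10}}

asyncio.run(main())
```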

python/src/cairo_coder/server/app.py

Lines changed: 33 additions & 19 deletions

```diff
@@ -171,8 +171,7 @@ async def log_interaction_task(
             query=query,
             generated_answer=response.choices[0].message.content if response.choices else None,
             retrieved_sources=sources_data,
-            # TODO: fix LLM usage metrics
-            llm_usage={}
+            llm_usage=agent.get_lm_usage(),
         )
         await create_user_interaction(interaction)
 
@@ -203,7 +202,7 @@ async def log_interaction_raw(
             query=query,
             generated_answer=generated_answer,
             retrieved_sources=sources_data,
-            llm_usage={},
+            llm_usage=agent.get_lm_usage()
         )
         await create_user_interaction(interaction)
 
@@ -270,13 +269,15 @@ async def value_error_handler(request: Request, exc: ValueError):
             logger.warning("Bad request", error=str(exc), path=request.url.path)
             return JSONResponse(
                 status_code=400,
-                content=ErrorResponse(
-                    error=ErrorDetail(
-                        message=str(exc),
-                        type="invalid_request_error",
-                        code="invalid_request",
-                    )
-                ).model_dump(),
+                content={
+                    "detail": ErrorResponse(
+                        error=ErrorDetail(
+                            message=str(exc),
+                            type="invalid_request_error",
+                            code="invalid_request",
+                        )
+                    ).model_dump()
+                },
             )
 
         @self.app.exception_handler(Exception)
@@ -285,13 +286,15 @@ async def global_exception_handler(request: Request, exc: Exception):
             logger.error("Unhandled exception", error=str(exc), path=request.url.path, exc_info=True)
             return JSONResponse(
                 status_code=500,
-                content=ErrorResponse(
-                    error=ErrorDetail(
-                        message="Internal server error",
-                        type="server_error",
-                        code="internal_error",
-                    )
-                ).model_dump(),
+                content={
+                    "detail": ErrorResponse(
+                        error=ErrorDetail(
+                            message=f"Internal server error: {str(exc)}",
+                            type="server_error",
+                            code="internal_error",
+                        )
+                    ).model_dump()
+                },
             )
 
     def _setup_routes(self):
@@ -340,8 +343,19 @@ async def agent_chat_completions(
             agent_factory: AgentFactory = Depends(get_agent_factory),
         ):
             """Agent-specific chat completions"""
-            # Validate agent exists (will raise ValueError if not found, handled by global handler)
-            agent_factory.get_agent_info(agent_id=agent_id)
+            try:
+                agent_factory.get_agent_info(agent_id=agent_id)
+            except ValueError as exc:
+                raise HTTPException(
+                    status_code=404,
+                    detail={
+                        "error": {
+                            "message": str(exc),
+                            "type": "invalid_request_error",
+                            "code": "agent_not_found",
+                        }
+                    },
+                ) from exc
 
             # Determine MCP mode
             mcp_mode = bool(mcp or x_mcp_mode)
```
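
The new 404 path is directly testable. A sketch of the expected response shape (the route prefix and the `app` import are assumptions about this server's layout; FastAPI wraps `HTTPException.detail` under a top-level `"detail"` key):

```python
from fastapi.testclient import TestClient
from cairo_coder.server.app import app  # assumed module-level app export

def test_unknown_agent_returns_404() -> None:
    client = TestClient(app)
    resp = client.post(
        "/v1/agents/does-not-exist/chat/completions",  # assumed route shape
        json={"messages": [{"role": "user", "content": "hi"}]},
    )
    assert resp.status_code == 404
    body = resp.json()
    assert body["detail"]["error"]["code"] == "agent_not_found"
    assert body["detail"]["error"]["type"] == "invalid_request_error"
```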
