1919from cairo_coder .core .types import (
2020 Document ,
2121 DocumentSource ,
22+ FormattedSource ,
2223 Message ,
2324 ProcessedQuery ,
2425 StreamEvent ,
@@ -82,11 +83,34 @@ def __init__(self, config: RagPipelineConfig):
8283 self ._current_processed_query : ProcessedQuery | None = None
8384 self ._current_documents : list [Document ] = []
8485
86+ # Token usage accumulator
87+ self ._accumulated_usage : dict [str , dict [str , int ]] = {}
88+
@property
def last_retrieved_documents(self) -> list[Document]:
    """Return the documents stored by the most recent pipeline execution."""
    documents = self._current_documents
    return documents
8993
def _accumulate_usage(self, prediction: dspy.Prediction) -> None:
    """
    Accumulate token usage from a prediction.

    Every metric reported by ``prediction.get_lm_usage()`` is added to the
    per-model running totals kept in ``self._accumulated_usage``. Models and
    metrics seen for the first time start from zero.

    Args:
        prediction: DSPy prediction object with usage information
    """
    usage = prediction.get_lm_usage()
    if not usage:
        # Nothing was reported (``None`` or an empty mapping) -- presumably
        # usage tracking is disabled. Keep the totals untouched rather than
        # failing on ``None.items()``.
        return

    for model_name, metrics in usage.items():
        model_totals = self._accumulated_usage.setdefault(model_name, {})
        for metric_name, value in (metrics or {}).items():
            if value is None:
                # A metric reported without a value contributes nothing;
                # adding ``None`` to the running total would raise TypeError.
                continue
            model_totals[metric_name] = model_totals.get(metric_name, 0) + value
109+
def _reset_usage(self) -> None:
    """Begin a new request with an empty usage tally."""
    fresh_totals: dict[str, dict[str, int]] = {}
    # Rebind instead of clearing in place, so a dict handed out earlier is
    # left exactly as it was.
    self._accumulated_usage = fresh_totals
113+
90114 async def _aprocess_query_and_retrieve_docs (
91115 self ,
92116 query : str ,
@@ -97,6 +121,7 @@ async def _aprocess_query_and_retrieve_docs(
97121 processed_query = await self .query_processor .aforward (
98122 query = query , chat_history = chat_history_str
99123 )
124+ self ._accumulate_usage (processed_query )
100125 self ._current_processed_query = processed_query
101126
102127 # Use provided sources or fall back to processed query sources
@@ -158,6 +183,9 @@ async def aforward(
158183 mcp_mode : bool = False ,
159184 sources : list [DocumentSource ] | None = None ,
160185 ) -> dspy .Prediction :
186+ # Reset usage for this request
187+ self ._reset_usage ()
188+
161189 chat_history_str = self ._format_chat_history (chat_history or [])
162190 processed_query , documents = await self ._aprocess_query_and_retrieve_docs (
163191 query , chat_history_str , sources
@@ -167,13 +195,17 @@ async def aforward(
167195 )
168196
169197 if mcp_mode :
170- return await self .mcp_generation_program .aforward (documents )
198+ result = await self .mcp_generation_program .aforward (documents )
199+ self ._accumulate_usage (result )
200+ return result
171201
172202 context = self ._prepare_context (documents )
173203
174- return await self .generation_program .aforward (
204+ result = await self .generation_program .aforward (
175205 query = query , context = context , chat_history = chat_history_str
176206 )
207+ self ._accumulate_usage (result )
208+ return result
177209
178210
179211 async def aforward_streaming (
@@ -268,28 +300,12 @@ async def aforward_streaming(
268300
def get_lm_usage(self) -> dict[str, dict[str, int]]:
    """
    Get accumulated token usage from all predictions in the pipeline.

    Returns:
        Dictionary mapping model names to usage metrics. The result is a
        snapshot: changing it does not alter the pipeline's internal
        totals, and later accumulation does not change a snapshot that
        was already returned.
    """
    # Copy both levels so callers never hold a reference to internal state.
    return {
        model_name: dict(metrics)
        for model_name, metrics in self._accumulated_usage.items()
    }
293309
294310 def _format_chat_history (self , chat_history : list [Message ]) -> str :
295311 """
@@ -311,7 +327,7 @@ def _format_chat_history(self, chat_history: list[Message]) -> str:
311327
312328 return "\n " .join (formatted_messages )
313329
314- def _format_sources (self , documents : list [Document ]) -> list [dict [ str , Any ] ]:
330+ def _format_sources (self , documents : list [Document ]) -> list [FormattedSource ]:
315331 """
316332 Format documents for the frontend-friendly sources event.
317333
@@ -322,9 +338,9 @@ def _format_sources(self, documents: list[Document]) -> list[dict[str, Any]]:
322338 documents: List of retrieved documents
323339
324340 Returns:
325- List of dicts: [{"title": str, "url": str}, ...]
341+ List of formatted sources with metadata
326342 """
327- sources : list [dict [ str , str ] ] = []
343+ sources : list [FormattedSource ] = []
328344 seen_urls : set [str ] = set ()
329345
330346
0 commit comments