
Commit 234c0b3

fix(refacto): token tracking in response and readme (#34)
Co-authored-by: enitrat <[email protected]>
1 parent 5db5edc commit 234c0b3

File tree

7 files changed: +195 additions, -8 deletions


README.md

Lines changed: 21 additions & 4 deletions

@@ -84,7 +84,24 @@ Using Docker is highly recommended for a streamlined setup. For instructions on
    SIMILARITY_MEASURE="cosine"
    ```
 
-3. **Configure LangSmith (Optional but Recommended)**
+3. **Configure Agents Package (`packages/agents/config.toml`)**
+   The ingester requires a configuration file in the `packages/agents` directory. Create `packages/agents/config.toml` with the following content:
+
+   ```toml
+   [API_KEYS]
+   OPENAI = "your-openai-api-key-here"
+
+   [VECTOR_DB]
+   POSTGRES_USER = "cairocoder"
+   POSTGRES_HOST = "postgres"
+   POSTGRES_DB = "cairocoder"
+   POSTGRES_PASSWORD = "cairocoder"
+   POSTGRES_PORT = "5432"
+   ```
+
+   Replace `"your-openai-api-key-here"` with your actual OpenAI API key. The database credentials should match those configured in your `.env` file.
+
+4. **Configure LangSmith (Optional but Recommended)**
    To monitor and debug LLM calls, configure LangSmith.
 
    - Create an account at [LangSmith](https://smith.langchain.com/) and create a project.

@@ -95,7 +112,7 @@ Using Docker is highly recommended for a streamlined setup. For instructions on
    LANGSMITH_API_KEY="lsv2..."
    ```
 
-4. **Add your API keys to `python/.env`: (mandatory)**
+5. **Add your API keys to `python/.env`: (mandatory)**
 
    ```yaml
    OPENAI_API_KEY="sk-..."

@@ -105,7 +122,7 @@ Using Docker is highly recommended for a streamlined setup. For instructions on
 
    Add the API keys required for the LLMs you want to use.
 
-5. **Run the ingesters (mandatory)**
+6. **Run the ingesters (mandatory)**
 
    The ingesters are responsible for populating the vector database with the documentation sources. They need to be run a first time, in isolation, so that the database is created.
 

@@ -115,7 +132,7 @@ Using Docker is highly recommended for a streamlined setup. For instructions on
 
    Once the ingester completes, the database will be populated with embeddings from all supported documentation sources, making them available for the RAG pipeline. Stop the database when you no longer need it.
 
-6. **Run the Application**
+7. **Run the Application**
    Once the ingesters are done, start the database and the Python backend service using Docker Compose:
    ```bash
    docker compose up postgres backend --build
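As a quick sanity check that the new `packages/agents/config.toml` is well-formed, here is a short sketch using Python's stdlib `tomllib` (an assumption on top of the README: it requires Python 3.11+ and running from the repository root):

```python
# Sketch: confirm config.toml parses and contains the keys the ingester expects.
import tomllib
from pathlib import Path

cfg = tomllib.loads(Path("packages/agents/config.toml").read_text())

assert "OPENAI" in cfg["API_KEYS"], "missing OpenAI API key"
for key in ("POSTGRES_USER", "POSTGRES_HOST", "POSTGRES_DB",
            "POSTGRES_PASSWORD", "POSTGRES_PORT"):
    assert key in cfg["VECTOR_DB"], f"missing {key}"
print("config.toml looks complete")
```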

python/src/cairo_coder/core/rag_pipeline.py

Lines changed: 16 additions & 2 deletions

@@ -264,8 +264,22 @@ def get_lm_usage(self) -> dict[str, int]:
         """
         generation_usage = self.generation_program.get_lm_usage()
         query_usage = self.query_processor.get_lm_usage()
-        # merge both dictionaries
-        return {**generation_usage, **query_usage}
+
+        # Additive merge strategy
+        merged_usage = {}
+
+        # Helper function to merge usage dictionaries
+        def merge_usage_dict(target: dict, source: dict) -> None:
+            for model_name, metrics in source.items():
+                if model_name not in target:
+                    target[model_name] = {}
+                for metric_name, value in metrics.items():
+                    target[model_name][metric_name] = target[model_name].get(metric_name, 0) + value
+
+        merge_usage_dict(merged_usage, generation_usage)
+        merge_usage_dict(merged_usage, query_usage)
+
+        return merged_usage
 
     def _format_chat_history(self, chat_history: list[Message]) -> str:
         """

python/src/cairo_coder/dspy/document_retriever.py

Lines changed: 37 additions & 0 deletions

@@ -703,3 +703,40 @@ def _enhance_context(self, query: str, context: list[Document]) -> list[Document
             )
         )
         return context
+
+    def get_lm_usage(self) -> dict[str, int]:
+        """
+        Get the total number of tokens used by the LLM.
+        Note: Document retrieval doesn't use LLM tokens directly, but embedding tokens might be tracked.
+        """
+        # Document retrieval doesn't use LLM tokens, but we return an empty dict for consistency
+        return {}
+
+
+def create_document_retriever(
+    vector_store_config: VectorStoreConfig,
+    vector_db: SourceFilteredPgVectorRM | None = None,
+    max_source_count: int = 5,
+    similarity_threshold: float = 0.4,
+    embedding_model: str = "text-embedding-3-large",
+) -> DocumentRetrieverProgram:
+    """
+    Factory function to create a DocumentRetrieverProgram instance.
+
+    Args:
+        vector_store_config: VectorStoreConfig for document retrieval
+        vector_db: Optional pre-initialized vector database instance
+        max_source_count: Maximum number of documents to retrieve
+        similarity_threshold: Minimum similarity score for document inclusion
+        embedding_model: OpenAI embedding model to use for reranking
+
+    Returns:
+        Configured DocumentRetrieverProgram instance
+    """
+    return DocumentRetrieverProgram(
+        vector_store_config=vector_store_config,
+        vector_db=vector_db,
+        max_source_count=max_source_count,
+        similarity_threshold=similarity_threshold,
+        embedding_model=embedding_model,
+    )
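A usage sketch for the new factory. This is hypothetical: the diff shows the factory's signature but not how a `VectorStoreConfig` is built, so an already-built instance is taken as a parameter, and the keyword overrides are illustrative:

```python
# Hypothetical usage of the factory above; VectorStoreConfig construction
# is not shown in this commit, so a pre-built instance is assumed.
from cairo_coder.dspy.document_retriever import create_document_retriever

def build_retriever(vector_store_config):
    # Default left untouched: embedding_model="text-embedding-3-large".
    return create_document_retriever(
        vector_store_config=vector_store_config,
        max_source_count=10,       # retrieve more documents than the default 5
        similarity_threshold=0.5,  # demand a higher similarity score than 0.4
    )
```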

python/src/cairo_coder/dspy/generation_program.py

Lines changed: 7 additions & 0 deletions

@@ -269,6 +269,13 @@ def forward(self, documents: list[Document]) -> dspy.Prediction:
 
         return dspy.Prediction(answer='\n'.join(formatted_docs))
 
+    def get_lm_usage(self) -> dict[str, int]:
+        """
+        Get the total number of tokens used by the LLM.
+        Note: MCP mode doesn't use LLM generation, so no tokens are consumed.
+        """
+        # MCP mode doesn't use LLM generation, return empty dict
+        return {}
 
 
 def create_generation_program(program_type: str = "general") -> GenerationProgram:

python/src/cairo_coder/dspy/query_processor.py

Lines changed: 6 additions & 0 deletions

@@ -226,6 +226,12 @@ def _is_test_query(self, query: str) -> bool:
         query_lower = query.lower()
         return any(keyword in query_lower for keyword in self.test_keywords)
 
+    def get_lm_usage(self) -> dict[str, int]:
+        """
+        Get the total number of tokens used by the LLM.
+        """
+        return self.retrieval_program.get_lm_usage()
+
 
 def create_query_processor() -> QueryProcessorProgram:
     """

python/src/cairo_coder/server/app.py

Lines changed: 4 additions & 0 deletions

@@ -456,7 +456,10 @@ async def _generate_chat_completion(
     # Somehow this is not always returning something (None). In that case, we're not capable of getting the
     # tracked usage.
     lm_usage = response.get_lm_usage()
+    logger.info(f"LM usage from response: {lm_usage}")
+
     if not lm_usage:
+        logger.warning("No LM usage data available, setting defaults to 0")
         total_prompt_tokens = 0
         total_completion_tokens = 0
         total_tokens = 0

@@ -467,6 +470,7 @@ async def _generate_chat_completion(
             entry.get("completion_tokens", 0) for entry in lm_usage.values()
         )
         total_tokens = sum(entry.get("total_tokens", 0) for entry in lm_usage.values())
+        logger.info(f"Token usage - prompt: {total_prompt_tokens}, completion: {total_completion_tokens}, total: {total_tokens}")
 
     return ChatCompletionResponse(
         id=response_id,
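The new totals logging sums each metric across all models in the usage dict. A runnable sketch of that aggregation with illustrative numbers (the dict shape matches what `get_lm_usage` returns elsewhere in this commit):

```python
# Sketch: collapse per-model usage into request-level totals, as app.py does.
lm_usage = {
    "gpt-4o-mini": {"prompt_tokens": 200, "completion_tokens": 100, "total_tokens": 300},
    "gpt-4o": {"prompt_tokens": 1000, "completion_tokens": 500, "total_tokens": 1500},
}

total_prompt_tokens = sum(entry.get("prompt_tokens", 0) for entry in lm_usage.values())
total_completion_tokens = sum(entry.get("completion_tokens", 0) for entry in lm_usage.values())
total_tokens = sum(entry.get("total_tokens", 0) for entry in lm_usage.values())

assert (total_prompt_tokens, total_completion_tokens, total_tokens) == (1200, 600, 1800)
```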

python/tests/unit/test_rag_pipeline.py

Lines changed: 104 additions & 2 deletions

@@ -7,8 +7,8 @@
 
 from unittest.mock import AsyncMock, Mock, patch
 
-import pytest
 import dspy
+import pytest
 
 from cairo_coder.core.rag_pipeline import (
     RagPipeline,

@@ -22,6 +22,18 @@
 from cairo_coder.dspy.query_processor import QueryProcessorProgram
 
 
+# Helper function to merge usage dictionaries
+def merge_usage_dict(sources: list[dict]) -> dict:
+    """Merge usage dictionaries."""
+    merged_usage = {}
+    for source in sources:
+        for model_name, metrics in source.items():
+            if model_name not in merged_usage:
+                merged_usage[model_name] = {}
+            for metric_name, value in metrics.items():
+                merged_usage[model_name][metric_name] = merged_usage[model_name].get(metric_name, 0) + value
+    return merged_usage
+
 @pytest.fixture(scope='function')
 def mock_pgvector_rm():
     """Patch the vector database for the document retriever."""

@@ -57,7 +69,8 @@ def mock_query_processor(self):
             resources=[DocumentSource.CAIRO_BOOK, DocumentSource.STARKNET_DOCS],
         )
         processor.forward.return_value = mock_res
-        processor.aforward.return_value = mock_res
+        processor.aforward = AsyncMock(return_value=mock_res)
+        processor.get_lm_usage.return_value = {}
         return processor
 
     @pytest.fixture

@@ -84,6 +97,7 @@ def mock_document_retriever(self):
         ]
         retriever.aforward = AsyncMock(return_value=mock_return_value)
         retriever.forward = Mock(return_value=mock_return_value)
+        retriever.get_lm_usage.return_value = {}
         return retriever
 
     @pytest.fixture

@@ -101,6 +115,7 @@ async def mock_streaming(*args, **kwargs):
             yield chunk
 
         program.forward_streaming = mock_streaming
+        program.get_lm_usage.return_value = {}
         return program
 
     @pytest.fixture

@@ -125,6 +140,7 @@ def mock_mcp_generation_program(self):
         Storage variables use #[storage] attribute.
         """
         program.forward.return_value = dspy.Prediction(answer=mock_res)
+        program.get_lm_usage.return_value = {}
         return program
 
     @pytest.fixture

@@ -409,6 +425,92 @@ def test_get_current_state(self, pipeline):
         assert state["config"]["max_source_count"] == 10
         assert state["config"]["similarity_threshold"] == 0.4
 
+    # Define reusable usage constants to keep tests DRY
+    _QUERY_USAGE_MINI = {
+        "gpt-4o-mini": {"prompt_tokens": 200, "completion_tokens": 100, "total_tokens": 300}
+    }
+    _GEN_USAGE_MINI = {
+        "gpt-4o-mini": {"prompt_tokens": 1000, "completion_tokens": 500, "total_tokens": 1500}
+    }
+    _GEN_USAGE_FULL = {
+        "gpt-4o": {"prompt_tokens": 1000, "completion_tokens": 500, "total_tokens": 1500}
+    }
+
+
+    @pytest.mark.parametrize(
+        "query_usage, generation_usage, expected_usage",
+        [
+            pytest.param(
+                _QUERY_USAGE_MINI,
+                _GEN_USAGE_MINI,
+                merge_usage_dict([_QUERY_USAGE_MINI, _GEN_USAGE_MINI]),
+                id="same_model_aggregation",
+            ),
+            pytest.param(
+                _QUERY_USAGE_MINI,
+                _GEN_USAGE_FULL,
+                merge_usage_dict([_QUERY_USAGE_MINI, _GEN_USAGE_FULL]),
+                id="different_model_aggregation",
+            ),
+            pytest.param({}, {}, {}, id="empty_usage"),
+            pytest.param(
+                _QUERY_USAGE_MINI, {}, _QUERY_USAGE_MINI, id="partial_empty_usage"
+            ),
+        ],
+    )
+    def test_get_lm_usage_aggregation(
+        self, pipeline, query_usage, generation_usage, expected_usage
+    ):
+        """Tests that get_lm_usage correctly aggregates token usage from its components."""
+        # The RAG pipeline implementation merges usage dictionaries additively across components
+        pipeline.query_processor.get_lm_usage.return_value = query_usage
+        pipeline.generation_program.get_lm_usage.return_value = generation_usage
+
+        result = pipeline.get_lm_usage()
+
+        pipeline.query_processor.get_lm_usage.assert_called_once()
+        pipeline.generation_program.get_lm_usage.assert_called_once()
+
+        assert result == expected_usage
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "mcp_mode, expected_usage",
+        [
+            pytest.param(True, _QUERY_USAGE_MINI, id="mcp_mode"),
+            pytest.param(
+                False, merge_usage_dict([_QUERY_USAGE_MINI, _GEN_USAGE_FULL]), id="normal_mode"
+            ),
+        ],
+    )
+    async def test_get_lm_usage_after_streaming(
+        self, pipeline, mcp_mode, expected_usage
+    ):
+        """Tests that get_lm_usage works correctly after a streaming execution."""
+        # To test token aggregation, we mock the return values of sub-components'
+        # get_lm_usage methods. The test logic simulates which components would
+        # be "active" in each mode by setting others to return empty usage.
+        pipeline.query_processor.get_lm_usage.return_value = self._QUERY_USAGE_MINI
+        if mcp_mode:
+            pipeline.generation_program.get_lm_usage.return_value = {}
+            # MCP program doesn't use an LM, so its usage is empty
+            pipeline.mcp_generation_program.get_lm_usage.return_value = {}
+        else:
+            pipeline.generation_program.get_lm_usage.return_value = self._GEN_USAGE_FULL
+            pipeline.mcp_generation_program.get_lm_usage.return_value = {}
+
+        # Execute the pipeline to ensure the full flow is invoked.
+        async for _ in pipeline.forward_streaming(
+            query="How do I create a Cairo contract?", mcp_mode=mcp_mode
+        ):
+            pass
+
+        result = pipeline.get_lm_usage()
+
+        assert result == expected_usage
+        pipeline.query_processor.get_lm_usage.assert_called()
+        pipeline.generation_program.get_lm_usage.assert_called()
+
 
 class TestRagPipelineFactory:
     """Test suite for RagPipelineFactory."""
