Skip to content

Commit 17f12fc

Browse files
authored
feat: llm judge for docs retrieval (#32)
1 parent 696f312 commit 17f12fc

File tree

13 files changed

+1405
-442
lines changed

13 files changed

+1405
-442
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,9 @@ The RAG pipeline is implemented in the `python/src/cairo_coder/core/` directory
182182

183183
1. **QueryProcessorProgram**: Analyzes user queries to extract semantic search queries and identify relevant documentation sources.
184184
2. **DocumentRetrieverProgram**: Retrieves relevant Cairo documentation from the vector database.
185-
3. **GenerationProgram**: Generates Cairo code and explanations based on the retrieved context.
186-
4. **RagPipeline**: Orchestrates the entire RAG process, chaining the modules together.
185+
3. **RetrievalJudge**: LLM-based judge that scores retrieved documents for relevance, filtering out low-quality results.
186+
4. **GenerationProgram**: Generates Cairo code and explanations based on the retrieved context.
187+
5. **RagPipeline**: Orchestrates the entire RAG process, chaining the modules together.
187188

188189
## Development
189190

python/optimizers/results/optimized_rag.json

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,33 @@
7777
"document_retriever.vector_db": {
7878
"k": 5
7979
},
80+
"retrieval_judge.rater": {
81+
"traces": [],
82+
"train": [],
83+
"demos": [],
84+
"signature": {
85+
"instructions": "Compare a system's retrieval response to the query and rate how much it can be leveraged to answer the query. When asked to reason, enumerate key ideas in each response, and whether they are present in the expected output. A document is considered useful if it is directly relevant to the query, or if it is informative and can be useful for context. For example, if the query is about creating or fixing a smart contract, then, an example of a smart contract, even if not _directly_ related, is considered useful. If the query is about a specific Cairo language feature, then a document about that feature is considered useful. Contract and test templates are always considered useful.",
86+
"fields": [
87+
{
88+
"prefix": "Query:",
89+
"description": "User's specific Cairo programming question or request for code generation"
90+
},
91+
{
92+
"prefix": "System Resource:",
93+
"description": "Single resource text (content + minimal metadata/title)"
94+
},
95+
{
96+
"prefix": "Reasoning:",
97+
"description": "A short sentence, on why a selected resource will be useful. If it's not selected, reason about why it's not going to be useful. Start by Resource <resource_title>..."
98+
},
99+
{
100+
"prefix": "Resource Note",
101+
"description": "A note between 0 and 1.0 on how useful the resource is to directly answer the query. 0 being completely unrelated, 1.0 being very relevant, 0.5 being 'not directly relatd but still informative and can be useful for context."
102+
}
103+
]
104+
},
105+
"lm": null
106+
},
80107
"generation_program.generation_program.predict": {
81108
"traces": [],
82109
"train": [],

python/src/cairo_coder/core/agent_factory.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
from cairo_coder.core.config import AgentConfiguration, VectorStoreConfig
1313
from cairo_coder.core.rag_pipeline import RagPipeline, RagPipelineFactory
1414
from cairo_coder.core.types import DocumentSource, Message
15+
from cairo_coder.utils.logging import get_logger
16+
17+
logger = get_logger(__name__)
1518

1619

1720
@dataclass
@@ -91,7 +94,6 @@ def create_agent(
9194
vector_db=vector_db,
9295
)
9396

94-
9597
@staticmethod
9698
def create_agent_by_id(
9799
query: str,
@@ -178,7 +180,7 @@ def get_or_create_agent(
178180

179181
return agent
180182

181-
def clear_cache(self):
183+
def clear_cache(self) -> None:
182184
"""Clear the agent cache."""
183185
self._agent_cache.clear()
184186

@@ -276,7 +278,7 @@ def _create_pipeline_from_config(
276278
277279
Args:
278280
agent_config: Agent configuration
279-
vector_store: Vector store for document retrieval
281+
vector_store_config: Vector store for document retrieval
280282
query: User's query
281283
history: Chat history
282284
mcp_mode: Whether to use MCP mode

python/src/cairo_coder/core/rag_pipeline.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,13 @@
2626
from cairo_coder.dspy.document_retriever import DocumentRetrieverProgram
2727
from cairo_coder.dspy.generation_program import GenerationProgram, McpGenerationProgram
2828
from cairo_coder.dspy.query_processor import QueryProcessorProgram
29+
from cairo_coder.dspy.retrieval_judge import RetrievalJudge
2930
from cairo_coder.utils.logging import get_logger
3031

3132
logger = get_logger(__name__)
3233

34+
SOURCE_PREVIEW_MAX_LEN = 200
35+
3336

3437
# 1. Define a custom callback class that extends BaseCallback class
3538
class AgentLoggingCallback(BaseCallback):
@@ -38,28 +41,28 @@ def on_module_start(
3841
call_id: str,
3942
instance: Any,
4043
inputs: dict[str, Any],
41-
):
44+
) -> None:
4245
logger.debug("Starting module", call_id=call_id, inputs=inputs)
4346

4447
# 2. Implement on_module_end handler to run a custom logging code.
45-
def on_module_end(self, call_id, outputs, exception):
48+
def on_module_end(self, call_id: str, outputs: dict[str, Any], exception: Exception | None) -> None:
4649
step = "Reasoning" if self._is_reasoning_output(outputs) else "Acting"
4750
logger.debug(f"== {step} Step ===")
4851
for k, v in outputs.items():
4952
logger.debug(f" {k}: {v}")
5053
logger.debug("\n")
5154

52-
def _is_reasoning_output(self, outputs):
55+
def _is_reasoning_output(self, outputs: dict[str, Any]) -> bool:
5356
return any(k.startswith("Thought") for k in outputs if isinstance(k, str))
5457

5558

5659
class LangsmithTracingCallback(BaseCallback):
5760
@traceable()
58-
def on_lm_start(self, call_id, instance, inputs):
61+
def on_lm_start(self, call_id: str, instance: Any, inputs: dict[str, Any]) -> None:
5962
pass
6063

6164
@traceable()
62-
def on_lm_end(self, call_id, outputs, exception):
65+
def on_lm_end(self, call_id: str, outputs: dict[str, Any], exception: Exception | None) -> None:
6366
pass
6467

6568

@@ -103,6 +106,7 @@ def __init__(self, config: RagPipelineConfig):
103106
self.document_retriever = config.document_retriever
104107
self.generation_program = config.generation_program
105108
self.mcp_generation_program = config.mcp_generation_program
109+
self.retrieval_judge = RetrievalJudge()
106110

107111
# Pipeline state
108112
self._current_processed_query: ProcessedQuery | None = None
@@ -122,6 +126,19 @@ def _process_query_and_retrieve_docs(
122126
documents = self.document_retriever.forward(
123127
processed_query=processed_query, sources=retrieval_sources
124128
)
129+
130+
# Apply LLM judge if enabled
131+
try:
132+
with dspy.context(lm=dspy.LM("gemini/gemini-2.5-flash-lite", max_tokens=10000)):
133+
documents = self.retrieval_judge.forward(query=query, documents=documents)
134+
except Exception as e:
135+
logger.warning(
136+
"Retrieval judge failed (sync), using all documents",
137+
error=str(e),
138+
exc_info=True,
139+
)
140+
# documents already contains all retrieved docs, no action needed
141+
125142
self._current_documents = documents
126143

127144
return processed_query, documents
@@ -142,6 +159,18 @@ async def _aprocess_query_and_retrieve_docs(
142159
documents = await self.document_retriever.aforward(
143160
processed_query=processed_query, sources=retrieval_sources
144161
)
162+
163+
try:
164+
with dspy.context(lm=dspy.LM("gemini/gemini-2.5-flash-lite", max_tokens=10000)):
165+
documents = await self.retrieval_judge.aforward(query=query, documents=documents)
166+
except Exception as e:
167+
logger.warning(
168+
"Retrieval judge failed (async), using all documents",
169+
error=str(e),
170+
exc_info=True,
171+
)
172+
# documents already contains all retrieved docs, no action needed
173+
145174
self._current_documents = documents
146175

147176
return processed_query, documents
@@ -258,12 +287,13 @@ async def forward_streaming(
258287
logger.error("Pipeline error", error=e)
259288
yield StreamEvent(StreamEventType.ERROR, data=f"Pipeline error: {str(e)}")
260289

261-
def get_lm_usage(self) -> dict[str, int]:
290+
def get_lm_usage(self) -> dict[str, dict[str, int]]:
262291
"""
263292
Get the total number of tokens used by the LLM.
264293
"""
265294
generation_usage = self.generation_program.get_lm_usage()
266295
query_usage = self.query_processor.get_lm_usage()
296+
judge_usage = self.retrieval_judge.get_lm_usage()
267297

268298
# Additive merge strategy
269299
merged_usage = {}
@@ -278,6 +308,7 @@ def merge_usage_dict(target: dict, source: dict) -> None:
278308

279309
merge_usage_dict(merged_usage, generation_usage)
280310
merge_usage_dict(merged_usage, query_usage)
311+
merge_usage_dict(merged_usage, judge_usage)
281312

282313
return merged_usage
283314

@@ -317,8 +348,8 @@ def _format_sources(self, documents: list[Document]) -> list[dict[str, Any]]:
317348
"title": doc.metadata.get("title", "Untitled"),
318349
"url": doc.metadata.get("url", "#"),
319350
"source_display": doc.metadata.get("source_display", "Unknown Source"),
320-
"content_preview": doc.page_content[:200]
321-
+ ("..." if len(doc.page_content) > 200 else ""),
351+
"content_preview": doc.page_content[:SOURCE_PREVIEW_MAX_LEN]
352+
+ ("..." if len(doc.page_content) > SOURCE_PREVIEW_MAX_LEN else ""),
322353
}
323354
sources.append(source_info)
324355

@@ -481,7 +512,7 @@ def create_pipeline(
481512

482513
@staticmethod
483514
def create_scarb_pipeline(
484-
name: str, vector_store_config: VectorStoreConfig, **kwargs
515+
name: str, vector_store_config: VectorStoreConfig, **kwargs: Any
485516
) -> RagPipeline:
486517
"""
487518
Create a Scarb-specialized RAG Pipeline.
@@ -511,7 +542,7 @@ def create_scarb_pipeline(
511542
)
512543

513544

514-
def create_rag_pipeline(name: str, vector_store_config: VectorStoreConfig, **kwargs) -> RagPipeline:
545+
def create_rag_pipeline(name: str, vector_store_config: VectorStoreConfig, **kwargs: Any) -> RagPipeline:
515546
"""
516547
Convenience function to create a RAG Pipeline.
517548

python/src/cairo_coder/dspy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
create_mcp_generation_program,
1616
)
1717
from .query_processor import QueryProcessorProgram, create_query_processor
18+
from .retrieval_judge import RetrievalJudge
1819

1920
__all__ = [
2021
"QueryProcessorProgram",
@@ -24,4 +25,5 @@
2425
"McpGenerationProgram",
2526
"create_generation_program",
2627
"create_mcp_generation_program",
28+
"RetrievalJudge",
2729
]

python/src/cairo_coder/dspy/document_retriever.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
logger = structlog.get_logger()
2020

2121
# Templates for different types of requests
22+
CONTRACT_TEMPLATE_TITLE = "Contract Template"
2223
CONTRACT_TEMPLATE = """
23-
contract>
24+
<contract>
2425
use starknet::ContractAddress;
2526
2627
// Define the contract interface
@@ -61,15 +62,20 @@
6162
6263
#[derive(Drop, starknet::Event)]
6364
pub struct DataRegistered {
64-
user: ContractAddress,
65-
data: felt252,
65+
pub user: ContractAddress,
66+
pub data: felt252,
6667
}
6768
6869
#[derive(Drop, starknet::Event)]
6970
pub struct DataUpdated {
70-
user: ContractAddress,
71-
index: u64,
72-
new_data: felt252,
71+
pub user: ContractAddress,
72+
pub index: u64,
73+
pub new_data: felt252,
74+
}
75+
76+
#[constructor]
77+
fn constructor(ref self: ContractState, initial_data: usize) {
78+
self.foo.write(initial_data);
7379
}
7480
7581
// Implement the contract interface
@@ -137,8 +143,9 @@
137143
Never add comments with urls to sources in the code that you produce.
138144
"""
139145

146+
TEST_TEMPLATE_TITLE = "Contract Testing Template"
140147
TEST_TEMPLATE = """
141-
contract_test>
148+
<contract_test>
142149
// Import the contract module itself
143150
use registry::Registry;
144151
// Make the required inner structs available in scope
@@ -167,7 +174,7 @@
167174
// 4. Create a dispatcher to interact with the contract
168175
let contract = declare("Registry");
169176
let mut constructor_args = array![];
170-
Serde::serialize(@1_u8, ref constructor_args);
177+
Serde::serialize(@0_u8, ref constructor_args);
171178
let (contract_address, _err) = contract
172179
.unwrap()
173180
.contract_class()
@@ -194,11 +201,11 @@
194201
195202
// Verify the data was stored correctly
196203
let stored_data = dispatcher.get_data(0);
197-
assert(stored_data == 42, 'Wrong stored data');
204+
assert_eq!(stored_data, 42);
198205
199206
// Verify user-specific data
200207
let user_data = dispatcher.get_user_data(caller);
201-
assert(user_data == 42, 'Wrong user data');
208+
assert_eq!(user_data, 42);
202209
203210
// Verify event emission:
204211
// 1. Create the expected event
@@ -231,11 +238,11 @@
231238
232239
// Verify the update
233240
let updated_data = dispatcher.get_data(0);
234-
assert(updated_data == 100, 'Wrong updated data');
241+
assert_eq!(updated_data, 100);
235242
236243
// Verify user data was updated
237244
let user_data = dispatcher.get_user_data(caller);
238-
assert(user_data == 100, 'Wrong updated user data');
245+
assert_eq!(user_data, 100);
239246
240247
// Verify update event
241248
let expected_updated_event = Registry::Event::DataUpdated(
@@ -264,16 +271,16 @@
264271
let all_data = dispatcher.get_all_data();
265272
266273
// Verify array contents
267-
assert(*all_data.at(0) == 10, 'Wrong data at index 0');
268-
assert(*all_data.at(1) == 20, 'Wrong data at index 1');
269-
assert(*all_data.at(2) == 30, 'Wrong data at index 2');
270-
assert(all_data.len() == 3, 'Wrong array length');
274+
assert_eq!(*all_data.at(0), 10);
275+
assert_eq!(*all_data.at(1), 20);
276+
assert_eq!(*all_data.at(2), 30);
277+
assert_eq!(all_data.len(), 3);
271278
272279
stop_cheat_caller_address(dispatcher.contract_address);
273280
}
274281
275282
#[test]
276-
#[should_panic(expected: "Index out of bounds")]
283+
#[should_panic(expected: "Index out of bounds")]
277284
fn test_get_data_out_of_bounds() {
278285
let dispatcher = deploy_contract();
279286
@@ -690,7 +697,7 @@ def _enhance_context(self, processed_query: ProcessedQuery, context: list[Docume
690697
context.append(
691698
Document(
692699
page_content=CONTRACT_TEMPLATE,
693-
metadata={"title": "contract_template", "source": "contract_template"},
700+
metadata={"title": CONTRACT_TEMPLATE_TITLE, "source": CONTRACT_TEMPLATE_TITLE},
694701
)
695702
)
696703

@@ -699,7 +706,7 @@ def _enhance_context(self, processed_query: ProcessedQuery, context: list[Docume
699706
context.append(
700707
Document(
701708
page_content=TEST_TEMPLATE,
702-
metadata={"title": "test_template", "source": "test_template"},
709+
metadata={"title": TEST_TEMPLATE_TITLE, "source": TEST_TEMPLATE_TITLE},
703710
)
704711
)
705712
return context
@@ -739,4 +746,4 @@ def create_document_retriever(
739746
max_source_count=max_source_count,
740747
similarity_threshold=similarity_threshold,
741748
embedding_model=embedding_model,
742-
)
749+
)

0 commit comments

Comments
 (0)