diff --git a/examples/lightrag_reasoning_deepseek_rerank_demo.py b/examples/lightrag_reasoning_deepseek_rerank_demo.py new file mode 100644 index 000000000..63cb17286 --- /dev/null +++ b/examples/lightrag_reasoning_deepseek_rerank_demo.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python +import os +import asyncio +import logging +import sys + +from lightrag import LightRAG, QueryParam +from lightrag.llm.openai import openai_embed, gpt_4o_mini_complete +from lightrag.kg.shared_storage import initialize_pipeline_status + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], +) + +# Set LightRAG logger to DEBUG level +lightrag_logger = logging.getLogger("lightrag") +lightrag_logger.setLevel(logging.DEBUG) + +# Set API credentials +if "DEEPSEEK_API_KEY" not in os.environ: + os.environ["DEEPSEEK_API_KEY"] = "YOUR DEEPSEEK API KEY FOR REASONING" +if "DEEPSEEK_API_BASE" not in os.environ: + os.environ["DEEPSEEK_API_BASE"] = "https://api.deepseek.com/v1" +if "OPENAI_API_KEY" not in os.environ: + os.environ["OPENAI_API_KEY"] = "YOUR OPENAI API KEY FOR EMBEDDING" + +WORKING_DIR = "./YOUR WORKING DIRECTORY" + +if not os.path.exists(WORKING_DIR): + os.makedirs(WORKING_DIR) + + +async def initialize_rag(): + """Initialize LightRAG with the necessary configuration.""" + rag = LightRAG( + working_dir=WORKING_DIR, + embedding_func=openai_embed, + llm_model_func=gpt_4o_mini_complete, + ) + + await rag.initialize_storages() + await initialize_pipeline_status() + return rag + + +def main(): + # Initialize LightRAG + rag = asyncio.run(initialize_rag()) + + print("\n===== LIGHTRAG REASONING RE-RANKING DEMO =====") + print("This demo shows the step-by-step reasoning process for re-ranking nodes") + print( + "You'll see: 1) Original node ranking, 2) Reasoning chain of thought, 3) Re-ranked nodes" + ) + + print("\n===== STANDARD RANKING (NO REASONING) =====") + query = "Why does Scrooge manage to have a happy ending?" + standard_result = rag.query(query, param=QueryParam(mode="local")) + print(f"\n{standard_result}") + + print("\n===== WITH REASONING RE-RANKING =====") + print("Now the same query but with reasoning-based re-ranking of nodes:") + print( + "Watch for the ORIGINAL NODE RANKING, CHAIN OF THOUGHT REASONING, and RE-RANKED NODE ORDER" + ) + reasoning_result = rag.query( + query, + param=QueryParam( + mode="local", + use_reasoning_reranking=True, + reasoning_model_name="deepseek_r1", + ), + ) + print("\n===== FINAL ANSWER WITH REASONING RE-RANKING =====") + print(f"{reasoning_result}") + + print("\n===== HYBRID MODE WITH REASONING RE-RANKING =====") + complex_query = "How does Scrooge make a lot of money in the end of the story?" 
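+    # Hybrid mode combines local (entity-level) and global (relationship-level) retrieval;
+    # the reasoning re-ranking enabled below is applied to the entity candidates it retrieves.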
+ print("Using a different query in hybrid mode with reasoning re-ranking:") + hybrid_result = rag.query( + complex_query, + param=QueryParam( + mode="hybrid", + use_reasoning_reranking=True, + reasoning_model_name="deepseek_r1", + ), + ) + print("\n===== FINAL ANSWER =====") + print(f"{hybrid_result}") + + +if __name__ == "__main__": + main() diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index 816034877..fe5a5299a 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -88,6 +88,16 @@ class QueryRequest(BaseModel): description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.", ) + use_reasoning_reranking: Optional[bool] = Field( + default=None, + description="Whether to use reasoning model-based re-ranking for nodes. When enabled, a reasoning model will evaluate query-node relevance beyond vector similarity.", + ) + + reasoning_model_name: Optional[str] = Field( + default=None, + description="The name of the reasoning model to use for node re-ranking. Options: 'deepseek_r1', 'gpt_4o', 'gpt_4o_mini', or any other model supported by the system.", + ) + @field_validator("query", mode="after") @classmethod def query_strip_after(cls, query: str) -> str: diff --git a/lightrag/base.py b/lightrag/base.py index b1f63fa54..1167a4cfb 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -93,6 +93,16 @@ class QueryParam: This allows using different models for different query modes. """ + use_reasoning_reranking: bool = False + """Whether to use reasoning model-based re-ranking for nodes. + When enabled, a reasoning model will evaluate query-node relevance beyond vector similarity. + """ + + reasoning_model_name: str = "deepseek_r1" + """The name of the reasoning model to use for node re-ranking. + Options: "deepseek_r1", "gpt_4o", "gpt_4o_mini", or any other model supported by the system. + """ + @dataclass class StorageNameSpace(ABC): diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index 68b92e837..3556f5021 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -78,7 +78,7 @@ def create_openai_async_client( "default_headers": default_headers, "api_key": api_key, } - + os.environ["OPENAI_API_BASE"] = "https://api.openai.com/v1" if base_url is not None: merged_configs["base_url"] = base_url else: @@ -102,8 +102,9 @@ async def openai_complete_if_cache( base_url: str | None = None, api_key: str | None = None, token_tracker: Any | None = None, + extract_reasoning: bool = False, **kwargs: Any, -) -> str: +) -> str | tuple[str, str]: """Complete a prompt using OpenAI's API with caching support. Args: @@ -113,6 +114,7 @@ async def openai_complete_if_cache( history_messages: Optional list of previous messages in the conversation. base_url: Optional base URL for the OpenAI API. api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable. + extract_reasoning: Whether to extract and return reasoning content when available. **kwargs: Additional keyword arguments to pass to the OpenAI API. Special kwargs: - openai_client_configs: Dict of configuration options for the AsyncOpenAI client. @@ -122,7 +124,8 @@ async def openai_complete_if_cache( - keyword_extraction: Will be removed from kwargs before passing to OpenAI. Returns: - The completed text or an async iterator of text chunks if streaming. 
+ Either the completed text string, or a tuple of (completed_text, reasoning_content) + if extract_reasoning is True or the model supports reasoning. Raises: InvalidResponseError: If the response from OpenAI is invalid or empty. @@ -235,6 +238,30 @@ async def inner(): logger.debug(f"Response content len: {len(content)}") verbose_debug(f"Response: {response}") + # Try to extract reasoning_content if requested or if model supports it + reasoning_content = "" + + # First check: look for reasoning_content in the message object's attributes or _kwargs + try: + # Look directly in the message object + if hasattr(response.choices[0].message, "reasoning_content"): + reasoning_content = response.choices[0].message.reasoning_content + logger.info("Found reasoning_content in message attributes") + except Exception as e: + logger.warning(f"Error checking for reasoning_content: {e}") + + # Log the reasoning content if found + if reasoning_content: + logger.info("Successfully extracted chain of thought reasoning") + logger.debug(f"Reasoning content: {reasoning_content}") + print(f"==========Reasoning content==========: {reasoning_content}") + elif extract_reasoning: + logger.info("No reasoning content found, but extraction was requested") + + # Return tuple if reasoning was requested or found + if extract_reasoning or reasoning_content: + return content, reasoning_content + return content @@ -325,6 +352,133 @@ async def nvidia_openai_complete( return result +async def deepseek_r1_complete( + prompt, + system_prompt=None, + history_messages=None, + keyword_extraction=False, + **kwargs, +) -> tuple[str, str]: + """Complete a prompt using DeepSeek Reasoning-1 model. + + This model is specialized for reasoning and analytical tasks, making it + useful for tasks like node re-ranking in knowledge graphs where reasoning + about relevance and importance is required. + + Args: + prompt: The prompt to complete. + system_prompt: Optional system prompt to include. + history_messages: Optional list of previous messages in the conversation. + keyword_extraction: Whether to extract keywords from the response. + **kwargs: Additional keyword arguments to pass to the OpenAI API. + + Returns: + A tuple containing (completed_text, reasoning_content) where reasoning_content + contains the chain of thought reasoning if available, otherwise empty string. 
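+
+    Example (illustrative sketch; assumes DEEPSEEK_API_KEY is set in the environment):
+        >>> answer, reasoning = await deepseek_r1_complete(
+        ...     "Rank these entities by relevance to the query.",
+        ...     system_prompt="You are a reasoning assistant.",
+        ... )
+        >>> # 'reasoning' holds the chain of thought, or "" if none was returned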
+ """ + if history_messages is None: + history_messages = [] + keyword_extraction = kwargs.pop("keyword_extraction", None) + + # Ensure we have the right configuration for DeepSeek API + base_url = os.environ.get("DEEPSEEK_API_BASE", "https://api.deepseek.com/v1") + api_key = os.environ.get("DEEPSEEK_API_KEY", None) + + # Create the OpenAI client directly to get more control over the response + print("\n===== CALLING DEEPSEEK REASONING MODEL =====") + client = create_openai_async_client( + api_key=api_key, + base_url=base_url, + client_configs=kwargs.pop("openai_client_configs", {}), + ) + + # Prepare messages + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + + # Make direct API call to get full response + try: + response = await client.chat.completions.create( + model="deepseek-reasoner", messages=messages, **kwargs + ) + + # Extract the content from the response + content = ( + response.choices[0].message.content + if response.choices and hasattr(response.choices[0].message, "content") + else "" + ) + + # Try to extract reasoning content + reasoning_content = "" + + # Print the entire response for debugging + print("\n===== DEEPSEEK API RESPONSE STRUCTURE =====") + print(f"Response type: {type(response)}") + print(f"Response attributes: {dir(response)}") + if hasattr(response, "choices") and response.choices: + print(f"Message attributes: {dir(response.choices[0].message)}") + + # Try various ways to access reasoning_content + try: + # Direct access + if hasattr(response.choices[0].message, "reasoning_content"): + reasoning_content = response.choices[0].message.reasoning_content + print("\n===== FOUND REASONING CONTENT DIRECTLY =====") + + # Look in _kwargs dictionary + elif hasattr(response.choices[0].message, "_kwargs"): + kwargs_dict = response.choices[0].message._kwargs + if "reasoning_content" in kwargs_dict: + reasoning_content = kwargs_dict["reasoning_content"] + print("\n===== FOUND REASONING CONTENT IN _KWARGS =====") + + # Check if it's in the model_dump + elif hasattr(response, "model_dump"): + dump = response.model_dump() + print(f"\n===== MODEL DUMP KEYS =====\n{list(dump.keys())}") + if "choices" in dump and dump["choices"]: + choice = dump["choices"][0] + if "message" in choice: + message = choice["message"] + if "reasoning_content" in message: + reasoning_content = message["reasoning_content"] + print("\n===== FOUND REASONING CONTENT IN MODEL_DUMP =====") + + # If we have reasoning content, print it + if reasoning_content: + print("\n===== CHAIN OF THOUGHT REASONING FROM DEEPSEEK =====") + print(reasoning_content) + else: + # Try to extract reasoning from the content itself + # If the content includes reasoning before JSON, try to separate it + if not content.startswith("[") and "[" in content: + parts = content.split("[", 1) + if parts[0].strip(): + reasoning_content = parts[0].strip() + content = "[" + parts[1] + print("\n===== EXTRACTED REASONING FROM CONTENT =====") + print(reasoning_content) + + if not reasoning_content: + print("\n===== NO REASONING CONTENT FOUND =====") + except Exception as e: + print(f"\n===== ERROR EXTRACTING REASONING CONTENT =====\n{str(e)}") + + if keyword_extraction and content: + return locate_json_string_body_from_string(content), reasoning_content + + return content, reasoning_content + + except Exception as e: + print(f"\n===== ERROR CALLING DEEPSEEK API =====\n{str(e)}") + logger.error(f"Error 
calling DeepSeek API: {e}") + return f"Error: {str(e)}", "" + + @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) @retry( stop=stop_after_attempt(3), diff --git a/lightrag/operate.py b/lightrag/operate.py index 84e1364e5..17fb1cbdb 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -280,7 +280,7 @@ async def _merge_nodes_then_upsert( if num_fragment > 1: if num_fragment >= force_llm_summary_on_merge: - status_message = f"LLM merge N: {entity_name} | {num_new_fragment}+{num_fragment-num_new_fragment}" + status_message = f"LLM merge N: {entity_name} | {num_new_fragment}+{num_fragment - num_new_fragment}" logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: @@ -295,7 +295,7 @@ async def _merge_nodes_then_upsert( llm_response_cache, ) else: - status_message = f"Merge N: {entity_name} | {num_new_fragment}+{num_fragment-num_new_fragment}" + status_message = f"Merge N: {entity_name} | {num_new_fragment}+{num_fragment - num_new_fragment}" logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: @@ -421,7 +421,7 @@ async def _merge_edges_then_upsert( if num_fragment > 1: if num_fragment >= force_llm_summary_on_merge: - status_message = f"LLM merge E: {src_id} - {tgt_id} | {num_new_fragment}+{num_fragment-num_new_fragment}" + status_message = f"LLM merge E: {src_id} - {tgt_id} | {num_new_fragment}+{num_fragment - num_new_fragment}" logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: @@ -436,7 +436,7 @@ async def _merge_edges_then_upsert( llm_response_cache, ) else: - status_message = f"Merge E: {src_id} - {tgt_id} | {num_new_fragment}+{num_fragment-num_new_fragment}" + status_message = f"Merge E: {src_id} - {tgt_id} | {num_new_fragment}+{num_fragment - num_new_fragment}" logger.info(status_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: @@ -1372,7 +1372,15 @@ async def _get_node_data( {**n, "entity_name": k["entity_name"], "rank": d} for k, n, d in zip(results, node_datas, node_degrees) if n is not None - ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram. + ] + + # Apply reasoning-based re-ranking if enabled + if query_param.use_reasoning_reranking and len(node_datas) > 1: + node_datas = await _rerank_nodes_with_reasoning( + query, node_datas, query_param.reasoning_model_name + ) + logger.info(f"Re-ranked {len(node_datas)} nodes using reasoning model") + # get entitytext chunk use_text_units = await _find_most_related_text_unit_from_entities( node_datas, query_param, text_chunks_db, knowledge_graph_inst @@ -1474,6 +1482,174 @@ async def _get_node_data( return entities_context, relations_context, text_units_context +async def _rerank_nodes_with_reasoning( + query: str, node_datas: list[dict], model_name: str = "deepseek_r1" +) -> list[dict]: + """Re-rank nodes using a reasoning model based on their relevance to the query. 
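+
+    The query and a compact JSON summary of each candidate node (name, description,
+    type and degree) are sent to the reasoning model, which returns the entity names
+    ordered from most to least relevant; node_datas is then reordered to match, with
+    any unranked nodes appended at the end.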
+ + Args: + query: The user's query + node_datas: List of node data dictionaries with entity information + model_name: Name of the reasoning model to use (defaults to DeepSeek R1) + + Returns: + Re-ordered list of node data based on reasoning model's ranking + """ + from lightrag.llm.openai import deepseek_r1_complete + + # Don't re-rank if we have 0 or 1 nodes (nothing to reorder) + if len(node_datas) <= 1: + return node_datas + + logger.info( + f"Re-ranking {len(node_datas)} nodes using reasoning model: {model_name}" + ) + + # Print original ranking order to terminal + print("\n===== ORIGINAL NODE RANKING =====") + for i, node in enumerate(node_datas[:10]): # Show top 10 for brevity + print(f" {i + 1}. {node['entity_name']} (degree: {node.get('rank', 0)})") + desc = node.get("description", "") + if desc and len(desc) > 100: + desc = desc[:100] + "..." + print(f" Description: {desc}") + + # Log original ranking order + logger.info("Original node ranking:") + for i, node in enumerate(node_datas): + logger.info(f" {i + 1}. {node['entity_name']} (degree: {node.get('rank', 0)})") + + logger.debug("Original node ranking (top 10):") + for i, node in enumerate(node_datas[:10]): + logger.debug( + f" {i + 1}. {node['entity_name']} (degree: {node.get('rank', 0)})" + ) + + # Prepare nodes for reasoning model in a simplified format + nodes_for_ranking = [] + for node in node_datas: + nodes_for_ranking.append( + { + "entity_name": node["entity_name"], + "description": node.get("description", ""), + "entity_type": node.get("entity_type", "UNKNOWN"), + "degree": node.get("rank", 0), # Using the node degree as 'rank' here + } + ) + + # Get the reasoning prompt template from PROMPTS + reasoning_prompt = PROMPTS["node_reasoning_rerank"].format( + query=query, nodes=json.dumps(nodes_for_ranking, indent=2) + ) + + # Print that we're calling the reasoning model + print("\n===== CALLING REASONING MODEL =====") + print(f"Query: {query}") + print(f"Using model: {model_name}") + + try: + # Call the reasoning model + logger.info(f"Calling {model_name} for re-ranking...") + + response, reasoning_content = await deepseek_r1_complete( + prompt=reasoning_prompt, + system_prompt=PROMPTS["node_reasoning_system_prompt"], + ) + + # Print the chain of thought reasoning if available + if reasoning_content: + print("\n===== CHAIN OF THOUGHT REASONING =====") + print(reasoning_content) + logger.info("Chain of thought reasoning:") + logger.info(reasoning_content) + else: + print("\n===== NO CHAIN OF THOUGHT REASONING AVAILABLE =====") + + logger.debug(f"Raw reasoning model response: {response}") + + try: + # Parse the JSON response to get the ordered entity names + if not response.startswith("["): + # Try to find and extract the JSON array from the response + json_match = re.search(r"\[(.*?)\]", response, re.DOTALL) + if json_match: + array_str = json_match.group(0) + ranked_entities = json.loads(array_str) + logger.debug(f"Extracted JSON array: {array_str}") + else: + logger.warning( + "Could not extract JSON array from reasoning model response" + ) + print("\n===== COULD NOT EXTRACT RANKING FROM RESPONSE =====") + print("Using original ranking order") + return node_datas # Return original ranking if parsing fails + else: + ranked_entities = json.loads(response) + logger.debug(f"Parsed JSON response: {ranked_entities}") + + # Reorder node_datas based on the reasoning model's ranking + reordered_node_datas = [] + + # First add nodes in the order specified by the reasoning model + for entity_name in ranked_entities: + for node 
in node_datas: + if ( + node["entity_name"] == entity_name + and node not in reordered_node_datas + ): + reordered_node_datas.append(node) + logger.debug(f"Added {entity_name} to re-ranked list") + + # Then add any remaining nodes that weren't ranked (shouldn't happen, but as a fallback) + for node in node_datas: + if node not in reordered_node_datas: + reordered_node_datas.append(node) + logger.debug( + f"Added missing node {node['entity_name']} to re-ranked list" + ) + + # Log the new ranking order + logger.info("Re-ranked nodes (reasoning model order):") + for i, node in enumerate(reordered_node_datas): + logger.info( + f" {i + 1}. {node['entity_name']} (degree: {node.get('rank', 0)})" + ) + + # Print the re-ranked order to terminal + print("\n===== RE-RANKED NODE ORDER =====") + for i, node in enumerate( + reordered_node_datas[:10] + ): # Show top 10 for brevity + print( + f" {i + 1}. {node['entity_name']} (degree: {node.get('rank', 0)})" + ) + desc = node.get("description", "") + if desc and len(desc) > 100: + desc = desc[:100] + "..." + print(f" Description: {desc}") + + logger.debug("Re-ranked nodes - top 10:") + for i, node in enumerate(reordered_node_datas[:10]): + logger.debug( + f" {i + 1}. {node['entity_name']} (degree: {node.get('rank', 0)})" + ) + + return reordered_node_datas + + except (json.JSONDecodeError, TypeError) as e: + logger.warning(f"Error parsing reasoning model response: {e}") + logger.debug(f"Raw response: {response}") + print("\n===== ERROR PARSING REASONING MODEL RESPONSE =====") + print(f"Error: {e}") + return node_datas # Return original ranking if parsing fails + + except Exception as e: + logger.warning(f"Error calling reasoning model for node re-ranking: {e}") + print("\n===== ERROR CALLING REASONING MODEL =====") + print(f"Error: {e}") + return node_datas # Return original ranking if reasoning model fails + + async def _find_most_related_text_unit_from_entities( node_datas: list[dict], query_param: QueryParam, diff --git a/lightrag/prompt.py b/lightrag/prompt.py index d6d46e1ff..95694a015 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -385,3 +385,28 @@ - List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] file_path - If you don't know the answer, just say so. Do not make anything up. - Do not include information not provided by the Data Sources.""" + +# Node re-ranking with reasoning model +PROMPTS["node_reasoning_system_prompt"] = ( + "You are a reasoning assistant that analyzes and ranks knowledge graph nodes based on their relevance to a query. Your responses must be in valid JSON format." +) + +PROMPTS["node_reasoning_rerank"] = """ +Query: {query} + +I need to determine which of the following knowledge graph nodes are most relevant and important +for answering this query. Please analyze each node and re-rank them from most to least relevant, +considering: + +1. Semantic relevance to the query +2. Information richness and completeness +3. How central this node is to addressing the query's needs +4. The node's relationships (higher degree nodes may contain more useful context) + +Nodes: +{nodes} + +YOUR RESPONSE MUST BE VALID JSON. Return an array of entity_name strings, ordered from most to least relevant. +Format your entire response as a valid JSON array like this: ["most_relevant_entity", "second_most_relevant", ...]. 
+Do not include any explanations or text outside the JSON array. +"""
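+
+# Illustrative example of a valid model response to the rerank prompt
+# (entity names are hypothetical): ["SCROOGE", "JACOB MARLEY", "TINY TIM"]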