Feature: Implemented an option to use DeepSeek Reasoner for intelligent semantic node re-ranking and retrieval #1400

Status: Open · wants to merge 1 commit into base: main
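In short, the PR threads two new `QueryParam` fields through the query path. A minimal sketch of the intended usage, assuming an already-initialized `LightRAG` instance named `rag` (the demo below shows the full setup):

# Both fields are introduced by this PR (see lightrag/base.py further down);
# use_reasoning_reranking defaults to False, reasoning_model_name to "deepseek_r1".
from lightrag import QueryParam

param = QueryParam(
    mode="local",
    use_reasoning_reranking=True,
    reasoning_model_name="deepseek_r1",
)
result = rag.query("Why does Scrooge manage to have a happy ending?", param=param)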
96 changes: 96 additions & 0 deletions examples/lightrag_reasoning_deepseek_rerank_demo.py
@@ -0,0 +1,96 @@
#!/usr/bin/env python
import os
import asyncio
import logging
import sys

from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_embed, gpt_4o_mini_complete
from lightrag.kg.shared_storage import initialize_pipeline_status

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)

# Set LightRAG logger to DEBUG level
lightrag_logger = logging.getLogger("lightrag")
lightrag_logger.setLevel(logging.DEBUG)

# Set API credentials
if "DEEPSEEK_API_KEY" not in os.environ:
os.environ["DEEPSEEK_API_KEY"] = "YOUR DEEPSEEK API KEY FOR REASONING"
if "DEEPSEEK_API_BASE" not in os.environ:
os.environ["DEEPSEEK_API_BASE"] = "https://api.deepseek.com/v1"
if "OPENAI_API_KEY" not in os.environ:
os.environ["OPENAI_API_KEY"] = "YOUR OPENAI API KEY FOR EMBEDDING"

WORKING_DIR = "./YOUR WORKING DIRECTORY"

if not os.path.exists(WORKING_DIR):
    os.makedirs(WORKING_DIR)


async def initialize_rag():
    """Initialize LightRAG with the necessary configuration."""
    rag = LightRAG(
        working_dir=WORKING_DIR,
        embedding_func=openai_embed,
        llm_model_func=gpt_4o_mini_complete,
    )

    await rag.initialize_storages()
    await initialize_pipeline_status()
    return rag


def main():
    # Initialize LightRAG
    rag = asyncio.run(initialize_rag())

    print("\n===== LIGHTRAG REASONING RE-RANKING DEMO =====")
    print("This demo shows the step-by-step reasoning process for re-ranking nodes")
    print(
        "You'll see: 1) Original node ranking, 2) Reasoning chain of thought, 3) Re-ranked nodes"
    )

    print("\n===== STANDARD RANKING (NO REASONING) =====")
    query = "Why does Scrooge manage to have a happy ending?"
    standard_result = rag.query(query, param=QueryParam(mode="local"))
    print(f"\n{standard_result}")

    print("\n===== WITH REASONING RE-RANKING =====")
    print("Now the same query, but with reasoning-based re-ranking of nodes:")
    print(
        "Watch for the ORIGINAL NODE RANKING, CHAIN OF THOUGHT REASONING, and RE-RANKED NODE ORDER"
    )
    reasoning_result = rag.query(
        query,
        param=QueryParam(
            mode="local",
            use_reasoning_reranking=True,
            reasoning_model_name="deepseek_r1",
        ),
    )
    print("\n===== FINAL ANSWER WITH REASONING RE-RANKING =====")
    print(f"{reasoning_result}")

    print("\n===== HYBRID MODE WITH REASONING RE-RANKING =====")
    complex_query = "How does Scrooge make a lot of money at the end of the story?"
    print("Using a different query in hybrid mode with reasoning re-ranking:")
    hybrid_result = rag.query(
        complex_query,
        param=QueryParam(
            mode="hybrid",
            use_reasoning_reranking=True,
            reasoning_model_name="deepseek_r1",
        ),
    )
    print("\n===== FINAL ANSWER =====")
    print(f"{hybrid_result}")


if __name__ == "__main__":
    main()
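Note that the demo assumes WORKING_DIR already holds an indexed corpus that can answer the Scrooge queries (the text of A Christmas Carol, for instance). A minimal, hypothetical ingestion step using LightRAG's standard insert API (the file path is illustrative):

# One-time ingestion before querying; "./book.txt" is a placeholder path.
with open("./book.txt", encoding="utf-8") as f:
    rag.insert(f.read())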
10 changes: 10 additions & 0 deletions lightrag/api/routers/query_routes.py
@@ -88,6 +88,16 @@ class QueryRequest(BaseModel):
        description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.",
    )

    use_reasoning_reranking: Optional[bool] = Field(
        default=None,
        description="Whether to use reasoning model-based re-ranking for nodes. When enabled, a reasoning model will evaluate query-node relevance beyond vector similarity.",
    )

    reasoning_model_name: Optional[str] = Field(
        default=None,
        description="The name of the reasoning model to use for node re-ranking. Options: 'deepseek_r1', 'gpt_4o', 'gpt_4o_mini', or any other model supported by the system.",
    )

    @field_validator("query", mode="after")
    @classmethod
    def query_strip_after(cls, query: str) -> str:
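For HTTP clients, the two new `QueryRequest` fields simply ride along in the JSON body of a query request. A hedged sketch with `requests`; the endpoint path and port are assumptions based on the LightRAG API server's defaults, not something this diff changes:

import requests

# Both new fields are optional; when omitted they default to None and the
# server falls back to the QueryParam defaults.
resp = requests.post(
    "http://localhost:9621/query",  # assumed LightRAG API server address
    json={
        "query": "Why does Scrooge manage to have a happy ending?",
        "mode": "local",
        "use_reasoning_reranking": True,
        "reasoning_model_name": "deepseek_r1",
    },
)
print(resp.json())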
10 changes: 10 additions & 0 deletions lightrag/base.py
@@ -93,6 +93,16 @@ class QueryParam:
    This allows using different models for different query modes.
    """

    use_reasoning_reranking: bool = False
    """Whether to use reasoning model-based re-ranking for nodes.
    When enabled, a reasoning model will evaluate query-node relevance beyond vector similarity.
    """

    reasoning_model_name: str = "deepseek_r1"
    """The name of the reasoning model to use for node re-ranking.
    Options: "deepseek_r1", "gpt_4o", "gpt_4o_mini", or any other model supported by the system.
    """


@dataclass
class StorageNameSpace(ABC):
160 changes: 157 additions & 3 deletions lightrag/llm/openai.py
@@ -78,7 +78,7 @@ def create_openai_async_client(
        "default_headers": default_headers,
        "api_key": api_key,
    }

    # Default OPENAI_API_BASE only when the caller has not already set it,
    # so an existing environment override is not silently clobbered.
    os.environ.setdefault("OPENAI_API_BASE", "https://api.openai.com/v1")
    if base_url is not None:
        merged_configs["base_url"] = base_url
    else:
@@ -102,8 +102,9 @@ async def openai_complete_if_cache(
    base_url: str | None = None,
    api_key: str | None = None,
    token_tracker: Any | None = None,
    extract_reasoning: bool = False,
    **kwargs: Any,
) -> str | tuple[str, str]:
    """Complete a prompt using OpenAI's API with caching support.

    Args:
@@ -113,6 +114,7 @@
        history_messages: Optional list of previous messages in the conversation.
        base_url: Optional base URL for the OpenAI API.
        api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
        extract_reasoning: Whether to extract and return reasoning content when available.
        **kwargs: Additional keyword arguments to pass to the OpenAI API.
            Special kwargs:
            - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
@@ -122,7 +124,8 @@
            - keyword_extraction: Will be removed from kwargs before passing to OpenAI.

    Returns:
        Either the completed text string, or a tuple of (completed_text, reasoning_content)
        if extract_reasoning is True or the model supports reasoning.

    Raises:
        InvalidResponseError: If the response from OpenAI is invalid or empty.
@@ -235,6 +238,30 @@ async def inner():
        logger.debug(f"Response content len: {len(content)}")
        verbose_debug(f"Response: {response}")

        # Try to extract reasoning_content if requested or if the model supports it
        reasoning_content = ""

        # Look for reasoning_content among the message object's attributes
        try:
            if hasattr(response.choices[0].message, "reasoning_content"):
                reasoning_content = response.choices[0].message.reasoning_content
                logger.info("Found reasoning_content in message attributes")
        except Exception as e:
            logger.warning(f"Error checking for reasoning_content: {e}")

        # Log the reasoning content if found
        if reasoning_content:
            logger.info("Successfully extracted chain-of-thought reasoning")
            logger.debug(f"Reasoning content: {reasoning_content}")
            print(f"==========Reasoning content==========: {reasoning_content}")
        elif extract_reasoning:
            logger.info("No reasoning content found, but extraction was requested")

        # Return a tuple if reasoning was requested or found
        if extract_reasoning or reasoning_content:
            return content, reasoning_content

        return content


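Callers that want the chain of thought can now opt into the tuple return. A minimal sketch, assuming a reasoning-capable model behind an OpenAI-compatible endpoint (the model name and prompts are illustrative):

# With extract_reasoning=True the function returns (content, reasoning_content);
# reasoning_content is "" when the model exposes no reasoning.
content, reasoning = await openai_complete_if_cache(
    "deepseek-reasoner",
    "Rank these candidate nodes by relevance to the query.",
    system_prompt="You are a careful ranking assistant.",
    extract_reasoning=True,
)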
@@ -325,6 +352,133 @@ async def nvidia_openai_complete(
    return result


async def deepseek_r1_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    **kwargs,
) -> tuple[str, str]:
"""Complete a prompt using DeepSeek Reasoning-1 model.

This model is specialized for reasoning and analytical tasks, making it
useful for tasks like node re-ranking in knowledge graphs where reasoning
about relevance and importance is required.

Args:
prompt: The prompt to complete.
system_prompt: Optional system prompt to include.
history_messages: Optional list of previous messages in the conversation.
keyword_extraction: Whether to extract keywords from the response.
**kwargs: Additional keyword arguments to pass to the OpenAI API.

Returns:
A tuple containing (completed_text, reasoning_content) where reasoning_content
contains the chain of thought reasoning if available, otherwise empty string.
"""
if history_messages is None:
history_messages = []
keyword_extraction = kwargs.pop("keyword_extraction", None)

    # Ensure we have the right configuration for the DeepSeek API
    base_url = os.environ.get("DEEPSEEK_API_BASE", "https://api.deepseek.com/v1")
    api_key = os.environ.get("DEEPSEEK_API_KEY", None)

    # Create the OpenAI-compatible client directly to get more control over the response
    print("\n===== CALLING DEEPSEEK REASONING MODEL =====")
    client = create_openai_async_client(
        api_key=api_key,
        base_url=base_url,
        client_configs=kwargs.pop("openai_client_configs", {}),
    )

    # Prepare messages
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    # Make a direct API call to get the full response
    try:
        response = await client.chat.completions.create(
            model="deepseek-reasoner", messages=messages, **kwargs
        )

        # Extract the content from the response
        content = (
            response.choices[0].message.content
            if response.choices and hasattr(response.choices[0].message, "content")
            else ""
        )

        # Try to extract reasoning content
        reasoning_content = ""

        # Print the entire response structure for debugging
        print("\n===== DEEPSEEK API RESPONSE STRUCTURE =====")
        print(f"Response type: {type(response)}")
        print(f"Response attributes: {dir(response)}")
        if hasattr(response, "choices") and response.choices:
            print(f"Message attributes: {dir(response.choices[0].message)}")

        # Try various ways to access reasoning_content
        try:
            # Direct attribute access
            if hasattr(response.choices[0].message, "reasoning_content"):
                reasoning_content = response.choices[0].message.reasoning_content
                print("\n===== FOUND REASONING CONTENT DIRECTLY =====")

            # Look in the private _kwargs dictionary
            elif hasattr(response.choices[0].message, "_kwargs"):
                kwargs_dict = response.choices[0].message._kwargs
                if "reasoning_content" in kwargs_dict:
                    reasoning_content = kwargs_dict["reasoning_content"]
                    print("\n===== FOUND REASONING CONTENT IN _KWARGS =====")

            # Check whether it appears in the model_dump
            elif hasattr(response, "model_dump"):
                dump = response.model_dump()
                print(f"\n===== MODEL DUMP KEYS =====\n{list(dump.keys())}")
                if "choices" in dump and dump["choices"]:
                    choice = dump["choices"][0]
                    if "message" in choice:
                        message = choice["message"]
                        if "reasoning_content" in message:
                            reasoning_content = message["reasoning_content"]
                            print("\n===== FOUND REASONING CONTENT IN MODEL_DUMP =====")

            # If we have reasoning content, print it
            if reasoning_content:
                print("\n===== CHAIN OF THOUGHT REASONING FROM DEEPSEEK =====")
                print(reasoning_content)
            else:
                # Try to extract reasoning from the content itself:
                # if the content includes prose before a JSON array, split it off
                if not content.startswith("[") and "[" in content:
                    parts = content.split("[", 1)
                    if parts[0].strip():
                        reasoning_content = parts[0].strip()
                        content = "[" + parts[1]
                        print("\n===== EXTRACTED REASONING FROM CONTENT =====")
                        print(reasoning_content)

            if not reasoning_content:
                print("\n===== NO REASONING CONTENT FOUND =====")
        except Exception as e:
            print(f"\n===== ERROR EXTRACTING REASONING CONTENT =====\n{str(e)}")

        if keyword_extraction and content:
            return locate_json_string_body_from_string(content), reasoning_content

        return content, reasoning_content

    except Exception as e:
        print(f"\n===== ERROR CALLING DEEPSEEK API =====\n{str(e)}")
        logger.error(f"Error calling DeepSeek API: {e}")
        return f"Error: {str(e)}", ""

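Because deepseek_r1_complete always returns a (content, reasoning_content) tuple, a caller can log the reasoning separately from the answer. A minimal sketch, assuming DEEPSEEK_API_KEY is set (the prompts are illustrative):

answer, chain_of_thought = await deepseek_r1_complete(
    "Re-rank these candidate nodes for the query about Scrooge's ending.",
    system_prompt="Think step by step, then output a JSON array of node ids.",
)
if chain_of_thought:
    logger.debug(f"DeepSeek reasoning: {chain_of_thought}")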

@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),