Skip to content

Commit e946e61

Browse files
author
yuehuazhang
committed
feature:fastapi
1 parent 2586722 commit e946e61

File tree

2 files changed

+139
-0
lines changed

2 files changed

+139
-0
lines changed
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import asyncio
2+
import json
3+
import logging
4+
5+
from langgraph.graph.state import CompiledStateGraph
6+
from langgraph.prebuilt import tools_condition, ToolNode
7+
from sse_starlette.sse import EventSourceResponse
8+
from typing import Annotated
9+
from typing_extensions import TypedDict
10+
11+
from fastapi import FastAPI, Request
12+
from langgraph.graph import StateGraph
13+
from langgraph.graph.message import add_messages
14+
15+
from chatchat.server.agent.tools_factory import search_internet
16+
from chatchat.server.utils import create_agent_models, add_tools_if_not_exists
17+
18+
app = FastAPI()
19+
logger = logging.getLogger("uvicorn.error")
20+
21+
22+
class ClientDisconnectException(Exception):
    """Raised inside the SSE generator when the client drops the connection."""
24+
25+
26+
def get_chatbot() -> CompiledStateGraph:
    """Build and compile a tool-enabled chat agent graph.

    The compiled graph has two nodes: ``chatbot`` (the LLM with tools
    bound) and ``tools`` (executes any tool calls the LLM emits); tool
    results are routed back to the LLM until it stops calling tools.

    Returns:
        The compiled langgraph state graph, ready for ``astream``/``invoke``.
    """

    class State(TypedDict):
        # Conversation history; ``add_messages`` appends instead of replacing.
        messages: Annotated[list, add_messages]

    llm = create_agent_models(
        configs=None,
        model="Qwen2.5-72B-Instruct",
        max_tokens=None,
        temperature=0,
        stream=True,
    )

    tools = add_tools_if_not_exists(tools_provides=[], tools_need_append=[search_internet])
    agent = llm.bind_tools(tools)

    def chatbot(state: State):
        # One LLM step over the accumulated message history.
        return {"messages": [agent.invoke(state["messages"])]}

    builder = StateGraph(State)
    builder.add_node("chatbot", chatbot)
    builder.add_node("tools", ToolNode(tools=tools))
    # Route to the tool node whenever the last LLM message contains tool calls.
    builder.add_conditional_edges("chatbot", tools_condition)
    # Any time a tool is called, we return to the chatbot to decide the next step.
    builder.add_edge("tools", "chatbot")
    builder.set_entry_point("chatbot")
    return builder.compile()
58+
59+
60+
@app.get("/stream")
async def openai_stream_output(request: Request):
    """Stream agent graph updates to the client over Server-Sent Events.

    Runs the chat agent on a fixed demo prompt and forwards each state
    update as one SSE event. The stream ends early when the client
    disconnects or the request task is cancelled.

    Returns:
        An ``EventSourceResponse`` wrapping the async event generator.
    """

    async def generator():
        # NOTE(review): the graph is rebuilt on every request; consider
        # building it once at startup if construction proves expensive.
        graph = get_chatbot()
        inputs = {"role": "user", "content": "Please introduce Trump based on the Internet search results."}
        try:
            async for event in graph.astream(input={"messages": inputs}, stream_mode="updates"):
                # Stop pushing events as soon as the peer goes away.
                if await request.is_disconnected():
                    raise ClientDisconnectException("Client disconnected")
                # Always yield dicts so every SSE frame has the same shape.
                yield {"data": str(event)}
        except ClientDisconnectException:
            # The peer is gone; there is nobody left to send an error to,
            # so do not fall through to the generic error branch below.
            logger.warning("Client disconnected, streaming stopped.")
            return
        except asyncio.CancelledError:
            logger.warning("Streaming progress has been interrupted by user.")
            return
        except Exception as e:
            logger.error(f"Error in stream: {e}")
            yield {"data": json.dumps({"error": str(e)})}
            return

    return EventSourceResponse(generator())
80+
81+
if __name__ == "__main__":
    # Run a local development server when executed directly.
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from typing import TypedDict, Annotated
2+
3+
from langgraph.checkpoint.memory import MemorySaver
4+
from langgraph.graph import END, StateGraph, add_messages
5+
from langgraph.graph.state import CompiledStateGraph
6+
from langgraph.prebuilt import ToolNode, tools_condition
7+
from pydantic import BaseModel
8+
9+
from chatchat.server.agent.tools_factory import search_internet
10+
from chatchat.server.utils import create_agent_models, add_tools_if_not_exists
11+
12+
13+
class State(BaseModel):
    """Minimal pydantic state model holding a single integer.

    NOTE(review): this module-level ``State`` is shadowed by the
    ``TypedDict`` of the same name defined inside ``get_chatbot`` and is
    not referenced anywhere else in this file — confirm it is still needed.
    """

    # Single integer payload; semantics not established by this file.
    n: int
15+
16+
17+
def get_chatbot() -> CompiledStateGraph:
    """Construct and compile the tool-calling chat agent graph.

    Wires an LLM node ("chatbot") to a tool-execution node ("tools"):
    tool calls emitted by the LLM are executed and their results fed
    back into the LLM until no further tools are requested.

    Returns:
        The compiled langgraph state graph.
    """

    class State(TypedDict):
        # Conversation history; ``add_messages`` merges new messages in.
        messages: Annotated[list, add_messages]

    llm = create_agent_models(
        configs=None,
        model="Qwen2.5-72B-Instruct",
        max_tokens=None,
        temperature=0,
        stream=True,
    )

    tools = add_tools_if_not_exists(tools_provides=[], tools_need_append=[search_internet])
    model_with_tools = llm.bind_tools(tools)

    def chatbot(state: State):
        # Single LLM step over the current message history.
        return {"messages": [model_with_tools.invoke(state["messages"])]}

    workflow = StateGraph(State)
    workflow.add_node("chatbot", chatbot)
    workflow.add_node("tools", ToolNode(tools=tools))
    # Branch to the tool node when the last message carries tool calls.
    workflow.add_conditional_edges("chatbot", tools_condition)
    # Any time a tool is called, we return to the chatbot to decide the next step.
    workflow.add_edge("tools", "chatbot")
    workflow.set_entry_point("chatbot")
    return workflow.compile()
49+
50+
51+
if __name__ == "__main__":
    # Smoke-test the agent graph with a single hard-coded query.
    # (The former thread-id ``config`` dict was removed: the graph is
    # compiled without a checkpointer and the config was never passed
    # to ``invoke``, so it was dead code.)
    graph = get_chatbot()
    inputs = {"role": "user", "content": "Please introduce Trump based on the Internet search results."}

    # stream_mode="updates" returns the per-step state updates rather
    # than only the final state.
    print(graph.invoke(input={"messages": inputs}, stream_mode="updates"))

0 commit comments

Comments
 (0)