1
1
import asyncio
2
2
import json
3
- import logging
3
+ from typing import TypedDict , Annotated
4
4
5
- from langgraph .graph .state import CompiledStateGraph
6
- from langgraph .prebuilt import tools_condition , ToolNode
7
- from sse_starlette .sse import EventSourceResponse
8
- from typing import Annotated
9
- from typing_extensions import TypedDict
5
+ import rich
6
+ from fastapi import APIRouter
7
+ from langgraph .graph import add_messages
8
+ from langgraph .graph .state import CompiledStateGraph , StateGraph
9
+ from langgraph .prebuilt import ToolNode , tools_condition
10
+ from sse_starlette import EventSourceResponse
10
11
11
- from fastapi import FastAPI , Request
12
- from langgraph .graph import StateGraph
13
- from langgraph .graph .message import add_messages
12
+ from chatchat .server .agent .tools_factory import search_internet , search_youtube
13
+ from chatchat .server .api_server .api_schemas import AgentChatInput
14
+ from chatchat .server .utils import create_agent_models
15
+ from chatchat .utils import build_logger
14
16
15
- from chatchat .server .agent .tools_factory import search_internet
16
- from chatchat .server .utils import create_agent_models , add_tools_if_not_exists
17
17
18
- app = FastAPI ()
19
- logger = logging .getLogger ("uvicorn.error" )
18
+ logger = build_logger ()
20
19
21
-
22
- class ClientDisconnectException (Exception ):
23
- pass
20
+ chat_router = APIRouter (prefix = "/v1" , tags = ["Agent 对话接口" ])
24
21
25
22
26
23
def get_chatbot () -> CompiledStateGraph :
27
24
class State (TypedDict ):
28
25
messages : Annotated [list , add_messages ]
29
26
30
27
llm = create_agent_models (configs = None ,
31
- model = "qwen2.5-instruct " ,
28
+ model = "hunyuan-turbo " ,
32
29
max_tokens = None ,
33
30
temperature = 0 ,
34
31
stream = True )
35
32
36
- tools = add_tools_if_not_exists ( tools_provides = [], tools_need_append = [ search_internet ])
33
+ tools = [ search_internet , search_youtube ]
37
34
llm_with_tools = llm .bind_tools (tools )
38
35
39
36
def chatbot (state : State ):
@@ -57,16 +54,17 @@ def chatbot(state: State):
57
54
return graph
58
55
59
56
60
- @app .post ("/stream" )
61
- async def openai_stream_output (request : Request ):
57
+ @chat_router .post ("/chat/completions" )
58
+ async def openai_stream_output (
59
+ body : AgentChatInput
60
+ ):
61
+ rich .print (body )
62
+
62
63
async def generator ():
63
64
graph = get_chatbot ()
64
- inputs = {"role" : "user" , "content" : "Please introduce Trump based on the Internet search results." }
65
65
try :
66
- async for event in graph .astream (input = {"messages" : inputs }, stream_mode = "updates" ):
67
- disconnected = await request .is_disconnected ()
68
- if disconnected :
69
- raise ClientDisconnectException ("Client disconnected" )
66
+ # async for event in graph.astream(input={"messages": inputs}, stream_mode="updates"):
67
+ async for event in graph .astream (input = {"messages" : body .messages }, stream_mode = "updates" ):
70
68
yield str (event )
71
69
except asyncio .exceptions .CancelledError :
72
70
logger .warning ("Streaming progress has been interrupted by user." )
@@ -77,7 +75,3 @@ async def generator():
77
75
return
78
76
79
77
return EventSourceResponse (generator ())
80
-
81
- if __name__ == "__main__" :
82
- import uvicorn
83
- uvicorn .run (app , host = "127.0.0.1" , port = 8000 )
0 commit comments