11import asyncio
22import json
3- import logging
3+ from typing import TypedDict , Annotated
44
5- from langgraph .graph .state import CompiledStateGraph
6- from langgraph .prebuilt import tools_condition , ToolNode
7- from sse_starlette .sse import EventSourceResponse
8- from typing import Annotated
9- from typing_extensions import TypedDict
5+ import rich
6+ from fastapi import APIRouter
7+ from langgraph .graph import add_messages
8+ from langgraph .graph .state import CompiledStateGraph , StateGraph
9+ from langgraph .prebuilt import ToolNode , tools_condition
10+ from sse_starlette import EventSourceResponse
1011
11- from fastapi import FastAPI , Request
12- from langgraph .graph import StateGraph
13- from langgraph .graph .message import add_messages
12+ from chatchat .server .agent .tools_factory import search_internet , search_youtube
13+ from chatchat .server .api_server .api_schemas import AgentChatInput
14+ from chatchat .server .utils import create_agent_models
15+ from chatchat .utils import build_logger
1416
15- from chatchat .server .agent .tools_factory import search_internet
16- from chatchat .server .utils import create_agent_models , add_tools_if_not_exists
1717
18- app = FastAPI ()
19- logger = logging .getLogger ("uvicorn.error" )
18+ logger = build_logger ()
2019
21-
22- class ClientDisconnectException (Exception ):
23- pass
20+ chat_router = APIRouter (prefix = "/v1" , tags = ["Agent 对话接口" ])
2421
2522
2623def get_chatbot () -> CompiledStateGraph :
2724 class State (TypedDict ):
2825 messages : Annotated [list , add_messages ]
2926
3027 llm = create_agent_models (configs = None ,
31- model = "qwen2.5-instruct " ,
28+ model = "hunyuan-turbo " ,
3229 max_tokens = None ,
3330 temperature = 0 ,
3431 stream = True )
3532
36- tools = add_tools_if_not_exists ( tools_provides = [], tools_need_append = [ search_internet ])
33+ tools = [ search_internet , search_youtube ]
3734 llm_with_tools = llm .bind_tools (tools )
3835
3936 def chatbot (state : State ):
@@ -57,16 +54,17 @@ def chatbot(state: State):
5754 return graph
5855
5956
60- @app .post ("/stream" )
61- async def openai_stream_output (request : Request ):
57+ @chat_router .post ("/chat/completions" )
58+ async def openai_stream_output (
59+ body : AgentChatInput
60+ ):
61+ rich .print (body )
62+
6263 async def generator ():
6364 graph = get_chatbot ()
64- inputs = {"role" : "user" , "content" : "Please introduce Trump based on the Internet search results." }
6565 try :
66- async for event in graph .astream (input = {"messages" : inputs }, stream_mode = "updates" ):
67- disconnected = await request .is_disconnected ()
68- if disconnected :
69- raise ClientDisconnectException ("Client disconnected" )
66+ # async for event in graph.astream(input={"messages": inputs}, stream_mode="updates"):
67+ async for event in graph .astream (input = {"messages" : body .messages }, stream_mode = "updates" ):
7068 yield str (event )
7169 except asyncio .exceptions .CancelledError :
7270 logger .warning ("Streaming progress has been interrupted by user." )
@@ -77,7 +75,3 @@ async def generator():
7775 return
7876
7977 return EventSourceResponse (generator ())
80-
81- if __name__ == "__main__" :
82- import uvicorn
83- uvicorn .run (app , host = "127.0.0.1" , port = 8000 )
0 commit comments