Skip to content

Commit 63f08b7

Browse files
author
yuehuazhang
committed
feature:chat api
1 parent 7c0760d commit 63f08b7

File tree

5 files changed

+45
-420
lines changed

5 files changed

+45
-420
lines changed

chatchat-server/chatchat/server/agent/graphs_factory/testinvoke.py

Lines changed: 0 additions & 56 deletions
This file was deleted.

chatchat-server/chatchat/server/api_server/api_schemas.py

Lines changed: 19 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -30,26 +30,30 @@ class Config:
3030
extra = "allow"
3131

3232

33-
class OpenAIChatInput(OpenAIBaseInput):
33+
class AgentChatInput(BaseModel):
3434
messages: List[ChatCompletionMessageParam]
3535
model: str = get_default_llm()
36-
frequency_penalty: Optional[float] = None
37-
function_call: Optional[completion_create_params.FunctionCall] = None
38-
functions: List[completion_create_params.Function] = None
39-
logit_bias: Optional[Dict[str, int]] = None
40-
logprobs: Optional[bool] = None
41-
max_tokens: Optional[int] = None
42-
n: Optional[int] = None
43-
presence_penalty: Optional[float] = None
44-
response_format: completion_create_params.ResponseFormat = None
45-
seed: Optional[int] = None
46-
stop: Union[Optional[str], List[str]] = None
47-
stream: Optional[bool] = None
36+
graph: str
37+
thread_id: int
4838
temperature: Optional[float] = Settings.model_settings.TEMPERATURE
39+
max_completion_tokens: Optional[int] = None
4940
tool_choice: Optional[Union[ChatCompletionToolChoiceOptionParam, str]] = None
5041
tools: List[Union[ChatCompletionToolParam, str]] = None
51-
top_logprobs: Optional[int] = None
52-
top_p: Optional[float] = None
42+
stream: Optional[bool] = True
43+
stream_method: Optional[Literal["streamlog", "node", "invoke"]] = "streamlog"
44+
# frequency_penalty: Optional[float] = None
45+
# function_call: Optional[completion_create_params.FunctionCall] = None
46+
# functions: List[completion_create_params.Function] = None
47+
# logit_bias: Optional[Dict[str, int]] = None
48+
# logprobs: Optional[bool] = None
49+
# max_tokens: Optional[int] = None
50+
# n: Optional[int] = None
51+
# presence_penalty: Optional[float] = None
52+
# response_format: completion_create_params.ResponseFormat = None
53+
# seed: Optional[int] = None
54+
# stop: Union[Optional[str], List[str]] = None
55+
# top_logprobs: Optional[int] = None
56+
# top_p: Optional[float] = None
5357

5458

5559
class OpenAIEmbeddingsInput(OpenAIBaseInput):
Original file line number | Diff line number | Diff line change
@@ -1,39 +1,36 @@
11
import asyncio
22
import json
3-
import logging
3+
from typing import TypedDict, Annotated
44

5-
from langgraph.graph.state import CompiledStateGraph
6-
from langgraph.prebuilt import tools_condition, ToolNode
7-
from sse_starlette.sse import EventSourceResponse
8-
from typing import Annotated
9-
from typing_extensions import TypedDict
5+
import rich
6+
from fastapi import APIRouter
7+
from langgraph.graph import add_messages
8+
from langgraph.graph.state import CompiledStateGraph, StateGraph
9+
from langgraph.prebuilt import ToolNode, tools_condition
10+
from sse_starlette import EventSourceResponse
1011

11-
from fastapi import FastAPI, Request
12-
from langgraph.graph import StateGraph
13-
from langgraph.graph.message import add_messages
12+
from chatchat.server.agent.tools_factory import search_internet, search_youtube
13+
from chatchat.server.api_server.api_schemas import AgentChatInput
14+
from chatchat.server.utils import create_agent_models
15+
from chatchat.utils import build_logger
1416

15-
from chatchat.server.agent.tools_factory import search_internet
16-
from chatchat.server.utils import create_agent_models, add_tools_if_not_exists
1717

18-
app = FastAPI()
19-
logger = logging.getLogger("uvicorn.error")
18+
logger = build_logger()
2019

21-
22-
class ClientDisconnectException(Exception):
23-
pass
20+
chat_router = APIRouter(prefix="/v1", tags=["Agent 对话接口"])
2421

2522

2623
def get_chatbot() -> CompiledStateGraph:
2724
class State(TypedDict):
2825
messages: Annotated[list, add_messages]
2926

3027
llm = create_agent_models(configs=None,
31-
model="qwen2.5-instruct",
28+
model="hunyuan-turbo",
3229
max_tokens=None,
3330
temperature=0,
3431
stream=True)
3532

36-
tools = add_tools_if_not_exists(tools_provides=[], tools_need_append=[search_internet])
33+
tools = [search_internet, search_youtube]
3734
llm_with_tools = llm.bind_tools(tools)
3835

3936
def chatbot(state: State):
@@ -57,16 +54,17 @@ def chatbot(state: State):
5754
return graph
5855

5956

60-
@app.post("/stream")
61-
async def openai_stream_output(request: Request):
57+
@chat_router.post("/chat/completions")
58+
async def openai_stream_output(
59+
body: AgentChatInput
60+
):
61+
rich.print(body)
62+
6263
async def generator():
6364
graph = get_chatbot()
64-
inputs = {"role": "user", "content": "Please introduce Trump based on the Internet search results."}
6565
try:
66-
async for event in graph.astream(input={"messages": inputs}, stream_mode="updates"):
67-
disconnected = await request.is_disconnected()
68-
if disconnected:
69-
raise ClientDisconnectException("Client disconnected")
66+
# async for event in graph.astream(input={"messages": inputs}, stream_mode="updates"):
67+
async for event in graph.astream(input={"messages": body.messages}, stream_mode="updates"):
7068
yield str(event)
7169
except asyncio.exceptions.CancelledError:
7270
logger.warning("Streaming progress has been interrupted by user.")
@@ -77,7 +75,3 @@ async def generator():
7775
return
7876

7977
return EventSourceResponse(generator())
80-
81-
if __name__ == "__main__":
82-
import uvicorn
83-
uvicorn.run(app, host="127.0.0.1", port=8000)

0 commit comments

Comments (0)