Skip to content

Commit 604e4d5

Browse files
authored
Merge pull request #28 from chatchat-space/dev
Dev
2 parents 9aa6c4d + 668d151 commit 604e4d5

23 files changed

+319
-266
lines changed
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from .reflexion import reflexion
2-
from .article_generation import article_generation
31
from .base_rag import BaseRagGraph
42
from .base_graph import BaseAgentGraph
5-
from .text_to_sql import TextToSQLGraph
63
from .plan_and_execute import PlanExecuteGraph
4+
from .reflexion import ReflexionGraph
5+
from .text_to_sql import TextToSQLGraph
6+
from .article_generation import article_generation

libs/chatchat-server/chatchat/server/agent/graphs_factory/base_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
class BaseAgentGraph(Graph):
1616
name = "base_agent"
1717
label = "agent"
18-
title = "聊天机器人[纯净版]"
18+
title = "聊天机器人"
1919

2020
def __init__(self,
2121
llm: ChatOpenAI,

libs/chatchat-server/chatchat/server/agent/graphs_factory/reflexion.py

Lines changed: 211 additions & 178 deletions
Large diffs are not rendered by default.

libs/chatchat-server/chatchat/server/agent/graphs_factory/text_to_sql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
class TextToSQLGraph(Graph):
1717
name = "text_to_sql"
1818
label = "agent"
19-
title = "数据库查询机器人"
19+
title = "数据库查询机器人[Beta]"
2020

2121
def __init__(self,
2222
llm: ChatOpenAI,

libs/chatchat-server/chatchat/server/agent/tools_factory/amap_poi_search.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import requests
2-
from chatchat.server.pydantic_v1 import Field
2+
from pydantic import Field
3+
34
from .tools_registry import BaseToolOutput, regist_tool
45
from chatchat.server.utils import get_tool_config
56

67
BASE_URL = "https://restapi.amap.com/v5/place/text"
78

9+
810
def amap_poi_search_engine(keywords: str,types: str,config: dict):
911
API_KEY = config["api_key"]
1012
params = {
@@ -19,10 +21,9 @@ def amap_poi_search_engine(keywords: str,types: str,config: dict):
1921
return {"error": "API request failed"}
2022

2123

22-
2324
@regist_tool(title="高德地图POI搜索")
2425
def amap_poi_search(location: str = Field(description="'实际地名'或者'具体的地址',不能使用简称或者别称"),
25-
types: str = Field(description="POI类型,比如商场、学校、医院等等")):
26+
types: str = Field(description="POI类型,比如商场、学校、医院等等")):
2627
""" A wrapper that uses Amap to search."""
2728
tool_config = get_tool_config("amap")
2829
return BaseToolOutput(amap_poi_search_engine(keywords=location,types=types,config=tool_config))

libs/chatchat-server/chatchat/server/agent/tools_factory/amap_weather.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import requests
2-
from chatchat.server.pydantic_v1 import Field
2+
from pydantic import Field
33
from .tools_registry import BaseToolOutput, regist_tool
44
from chatchat.server.utils import get_tool_config
55

66
BASE_DISTRICT_URL = "https://restapi.amap.com/v3/config/district"
77
BASE_WEATHER_URL = "https://restapi.amap.com/v3/weather/weatherInfo"
88

9+
910
def get_adcode(city: str, config: dict) -> str:
1011
"""Get the adcode"""
1112
API_KEY = config["api_key"]
@@ -22,6 +23,7 @@ def get_adcode(city: str, config: dict) -> str:
2223
else:
2324
return None
2425

26+
2527
def get_weather(adcode: str, config: dict) -> dict:
2628
"""Get weather information."""
2729
API_KEY = config["api_key"]
@@ -36,6 +38,7 @@ def get_weather(adcode: str, config: dict) -> dict:
3638
else:
3739
return {"error": "API request failed"}
3840

41+
3942
@regist_tool(title="高德地图天气查询")
4043
def amap_weather(city: str = Field(description="城市名")):
4144
"""A wrapper that uses Amap to get weather information."""

libs/chatchat-server/chatchat/server/agent/tools_factory/calculate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from chatchat.server.pydantic_v1 import Field
1+
from pydantic import Field
22

33
from .tools_registry import BaseToolOutput, regist_tool
44

libs/chatchat-server/chatchat/server/agent/tools_factory/querysql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
22
from langchain_community.utilities.sql_database import SQLDatabase
3-
from chatchat.server.pydantic_v1 import Field
3+
from pydantic import Field
44
from chatchat.server.utils import get_tool_config, build_logger
55
from .tools_registry import BaseToolOutput, regist_tool
66

libs/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py

Lines changed: 24 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,29 @@
11
from typing import Dict, List
2-
2+
from pydantic import Field
33
from langchain.docstore.document import Document
4-
from langchain.text_splitter import RecursiveCharacterTextSplitter
5-
from langchain_community.utilities.bing_search import BingSearchAPIWrapper
6-
from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
7-
from langchain_community.utilities.searx_search import SearxSearchWrapper
8-
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
9-
from markdownify import markdownify
10-
from strsimpy.normalized_levenshtein import NormalizedLevenshtein
114

125
from chatchat.settings import Settings
13-
from chatchat.server.pydantic_v1 import Field
146
from chatchat.server.utils import get_tool_config
157

16-
from .tools_registry import BaseToolOutput, regist_tool, format_context
8+
from .tools_registry import BaseToolOutput, regist_tool
9+
10+
11+
def duckduckgo_search(text, top_k: int):
12+
from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
13+
search = DuckDuckGoSearchAPIWrapper()
14+
return search.results(text, top_k)
15+
16+
17+
def tavily_search(text, config, top_k: int):
18+
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
19+
search = TavilySearchAPIWrapper(
20+
tavily_api_key=config["tavily_key"],
21+
)
22+
return search.results(text, top_k)
1723

1824

1925
def searx_search(text, config, top_k: int):
26+
from langchain_community.utilities.searx_search import SearxSearchWrapper
2027
search = SearxSearchWrapper(
2128
searx_host=config["host"],
2229
engines=config["engines"],
@@ -27,31 +34,24 @@ def searx_search(text, config, top_k: int):
2734

2835

2936
def bing_search(text, config, top_k: int):
37+
from langchain_community.utilities.bing_search import BingSearchAPIWrapper
3038
search = BingSearchAPIWrapper(
3139
bing_subscription_key=config["bing_key"],
3240
bing_search_url=config["bing_search_url"],
3341
)
3442
return search.results(text, top_k)
3543

3644

37-
def tavily_search(text, config, top_k: int):
38-
search = TavilySearchAPIWrapper(
39-
tavily_api_key=config["tavily_key"],
40-
)
41-
return search.results(text, top_k)
42-
43-
44-
def duckduckgo_search(text, config, top_k: int):
45-
search = DuckDuckGoSearchAPIWrapper()
46-
return search.results(text, top_k)
47-
48-
4945
def metaphor_search(
5046
text: str,
5147
config: dict,
52-
top_k:int
48+
top_k: int
5349
) -> List[Dict]:
5450
from metaphor_python import Metaphor
51+
from langchain.text_splitter import RecursiveCharacterTextSplitter
52+
from strsimpy.normalized_levenshtein import NormalizedLevenshtein
53+
54+
from markdownify import markdownify
5555

5656
client = Metaphor(config["metaphor_api_key"])
5757
search = client.search(text, num_results=top_k, use_autoprompt=True)
@@ -101,20 +101,6 @@ def metaphor_search(
101101
}
102102

103103

104-
def search_result2docs(search_results) -> List[Document]:
105-
docs = []
106-
for result in search_results:
107-
doc = Document(
108-
page_content=result["snippet"] if "snippet" in result.keys() else "",
109-
metadata={
110-
"source": result["link"] if "link" in result.keys() else "",
111-
"filename": result["title"] if "title" in result.keys() else "",
112-
},
113-
)
114-
docs.append(doc)
115-
return docs
116-
117-
118104
def search_engine(query: str, top_k: int = 0, engine_name: str = "", config: dict = {}):
119105
config = config or get_tool_config("search_internet")
120106
if top_k <= 0:
@@ -124,12 +110,10 @@ def search_engine(query: str, top_k: int = 0, engine_name: str = "", config: dic
124110
results = search_engine_use(
125111
text=query, config=config["search_engine_config"][engine_name], top_k=top_k
126112
)
127-
docs = [x for x in search_result2docs(results) if x.page_content and x.page_content.strip()]
128-
return {"docs": docs, "search_engine": engine_name}
113+
return results
129114

130115

131116
@regist_tool(title="互联网搜索")
132117
def search_internet(query: str = Field(description="query for Internet search")):
133118
"""Use this tool to use bing search engine to search the internet and get information."""
134-
# return BaseToolOutput(search_engine(query=query), format=format_context)
135119
return BaseToolOutput(search_engine(query=query))

libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from chatchat.server.agent.tools_factory.tools_registry import (
55
BaseToolOutput,
66
regist_tool,
7-
format_context,
87
)
98
from chatchat.server.knowledge_base.kb_api import list_kbs
109
from chatchat.server.knowledge_base.kb_doc_api import search_docs

libs/chatchat-server/chatchat/server/agent/tools_factory/shell.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# LangChain 的 Shell 工具
2+
from pydantic import Field
23
from langchain_community.tools import ShellTool
3-
4-
from chatchat.server.pydantic_v1 import Field
5-
64
from .tools_registry import BaseToolOutput, regist_tool
75

86

libs/chatchat-server/chatchat/server/agent/tools_factory/text2image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
import openai
88
from PIL import Image
9+
from pydantic import Field
910

1011
from chatchat.settings import Settings
11-
from chatchat.server.pydantic_v1 import Field
1212
from chatchat.server.utils import MsgType, get_tool_config, get_model_info
1313

1414
from .tools_registry import BaseToolOutput, regist_tool

libs/chatchat-server/chatchat/server/agent/tools_factory/text2promql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
from requests.auth import HTTPBasicAuth
33
from urllib.parse import parse_qs
44
from typing import Optional
5+
from pydantic import Field
56

67
from langchain_core.prompts import ChatPromptTemplate
78
from langchain_core.output_parsers import StrOutputParser
89
from langchain_core.runnables import RunnablePassthrough
910

10-
from chatchat.server.pydantic_v1 import Field
1111
from chatchat.server.utils import (
1212
get_tool_config,
1313
get_default_llm,

libs/chatchat-server/chatchat/server/agent/tools_factory/text2sql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from langchain_experimental.sql import SQLDatabaseChain, SQLDatabaseSequentialChain
55
from sqlalchemy import event
66
from sqlalchemy.exc import OperationalError
7+
from pydantic import Field
78

8-
from chatchat.server.pydantic_v1 import Field
99
from chatchat.server.utils import get_tool_config
1010

1111
from .tools_registry import BaseToolOutput, regist_tool

libs/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,24 @@ def format_context(self: BaseToolOutput) -> str:
165165
context += doc + "\n\n"
166166

167167
return context
168+
169+
170+
def format_context(self: BaseToolOutput) -> str:
171+
"""
172+
将包含知识库输出的ToolOutput格式化为 LLM 需要的字符串
173+
"""
174+
context = ""
175+
docs = self.data["docs"]
176+
source_documents = []
177+
178+
for inum, doc in enumerate(docs):
179+
doc = DocumentWithVSId.parse_obj(doc)
180+
source_documents.append(doc.page_content)
181+
182+
if len(source_documents) == 0:
183+
context = "没有找到相关文档,请更换关键词重试"
184+
else:
185+
for doc in source_documents:
186+
context += doc + "\n\n"
187+
188+
return context

libs/chatchat-server/chatchat/server/agent/tools_factory/url_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
"""
44
import requests
55
import re
6+
from pydantic import Field
67

7-
from chatchat.server.pydantic_v1 import Field
88
from chatchat.server.agent.tools_factory.tools_registry import format_context
99
from chatchat.server.utils import get_tool_config, build_logger
1010

libs/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
简单的单参数输入工具实现,用于查询现在天气的情况
33
"""
44
import requests
5+
from pydantic import Field
56

6-
from chatchat.server.pydantic_v1 import Field
77
from chatchat.server.utils import get_tool_config
88

99
from .tools_registry import BaseToolOutput, regist_tool

libs/chatchat-server/chatchat/server/agent/tools_factory/wikipedia_search.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
# LangChain 的 WikipediaQueryRun 工具
2+
from pydantic import Field
23
from langchain_community.tools import WikipediaQueryRun
34
from langchain_community.utilities import WikipediaAPIWrapper
4-
from chatchat.server.pydantic_v1 import Field
5-
6-
7-
85
from .tools_registry import BaseToolOutput, regist_tool
96

107

libs/chatchat-server/chatchat/server/agent/tools_factory/wolfram.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# Langchain 自带的 Wolfram Alpha API 封装
2-
3-
from chatchat.server.pydantic_v1 import Field
2+
from pydantic import Field
43
from chatchat.server.utils import get_tool_config
5-
64
from .tools_registry import BaseToolOutput, regist_tool
75

86

libs/chatchat-server/chatchat/webui_pages/graph_agent/graph.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ async def handle_user_input(
256256
response_last = response["content"]
257257
elif "response" in response: # plan_execute_agent
258258
response_last = response["response"]
259+
elif "answer" in response: # reflexion
260+
response_last = response["answer"]
259261

260262
with st.status(node, expanded=True) as status:
261263
st.json(response, expanded=True)
@@ -356,7 +358,7 @@ def graph_agent_page():
356358

357359
tools_list = list_tools()
358360
# tool_names = ["None"] + list(tools_list)
359-
if selected_graph == "数据库查询机器人":
361+
if selected_graph == "数据库查询机器人[Beta]":
360362
selected_tools = st.multiselect(
361363
label="选择工具",
362364
options=["query_sql_data"],
@@ -394,7 +396,7 @@ def graph_agent_page():
394396
st.title("自媒体文章生成")
395397
with st.chat_message("assistant"):
396398
st.write("Hello 👋😊,我是自媒体文章生成 Agent,输入任意内容以启动工作流~")
397-
elif selected_graph == "数据库查询机器人":
399+
elif selected_graph == "数据库查询机器人[Beta]":
398400
st.title("数据库查询")
399401
with st.chat_message("assistant"):
400402
st.write("Hello 👋😊,我是数据库查询机器人,输入你想查询的内容~")
@@ -412,7 +414,7 @@ def graph_agent_page():
412414
st.rerun()
413415
if selected_graph == "article_generation":
414416
user_input = cols[2].chat_input("请你帮我生成一篇自媒体文章 (换行:Shift+Enter)")
415-
elif selected_graph == "数据库查询机器人":
417+
elif selected_graph == "数据库查询机器人[Beta]":
416418
user_input = cols[2].chat_input("请你帮忙调用工具, 查看组织`tcs_public`的成员有哪些?(换行:Shift+Enter)")
417419
else:
418420
user_input = cols[2].chat_input("尝试输入任何内容和我聊天呦 (换行:Shift+Enter)")
@@ -434,6 +436,7 @@ def graph_agent_page():
434436
max_tokens=None,
435437
temperature=st.session_state["temperature"],
436438
stream=True)
439+
logger.info(f"Loaded llm: {llm}")
437440

438441
# 创建 langgraph 实例
439442
graph_class = get_graph_class_by_label_and_title(label="agent", title=selected_graph)
@@ -446,15 +449,15 @@ def graph_agent_page():
446449
graph = graph_class.get_graph()
447450
if not graph:
448451
raise ValueError(f"Graph '{selected_graph}' is not registered.")
452+
st.toast(f"已加载工作流: {selected_graph}")
449453

450454
# langgraph 配置文件
451455
graph_config = {
452456
"configurable": {
453457
"thread_id": st.session_state["conversation_id"]
454458
},
455459
}
456-
457-
logger.info(f"graph: '{selected_graph}', configurable: '{graph_config}'")
460+
logger.info(f"Loaded graph: '{selected_graph}', configurable: '{graph_config}'")
458461

459462
# 绘制流程图并缓存
460463
graph_flow_image_name = f"{selected_graph}_flow_image"

0 commit comments

Comments
 (0)