Skip to content

Commit 6e72333

Browse files
authored
Merge pull request #40 from chatchat-space/dev
Dev
2 parents 417a6d6 + 71f938f commit 6e72333

File tree

6 files changed

+130
-21
lines changed

6 files changed

+130
-21
lines changed

libs/chatchat-server/chatchat/server/utils.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from langchain_openai.chat_models import ChatOpenAI
2626
from langgraph.checkpoint.memory import MemorySaver
2727
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
28-
# from langgraph.checkpoint.postgres import PostgresSaver
28+
from langgraph.checkpoint.postgres import PostgresSaver
2929
from memoization import cached, CachingAlgorithmFlag
3030

3131
from chatchat.settings import Settings, XF_MODELS_TYPES
@@ -774,28 +774,17 @@ def get_tool_config(name: str = None) -> Dict:
774774
return Settings.tool_settings.model_dump().get(name, {})
775775

776776

777-
def get_checkpoint(memory_type: Optional[Literal["memory", "sqlite", "postgres"]] = None) -> (Union)[
778-
MemorySaver,
779-
AsyncSqliteSaver,
780-
# PostgresSaver
781-
]:
777+
def get_checkpointer(memory_type: Optional[Literal["memory", "sqlite", "postgres"]] = None) -> (Union)[MemorySaver, AsyncSqliteSaver]:
782778
"""
783-
获取 graph 的 memory
779+
获取 graph 的 checkpointer(MemorySaver 和 AsyncSqliteSaver)
784780
"""
785-
import sqlalchemy as sa
786-
787781
if memory_type is None:
788782
memory_type = Settings.tool_settings.GRAPH_MEMORY_TYPE
789783

790784
if memory_type == "memory":
791785
return MemorySaver()
792786
elif memory_type == "sqlite":
793787
return AsyncSqliteSaver.from_conn_string(Settings.basic_settings.SQLITE_GRAPH_DATABASE_URI)
794-
# elif memory_type == "postgres":
795-
# import sqlalchemy as sa
796-
# engine = sa.create_engine(Settings.basic_settings.SQLALCHEMY_DATABASE_URI)
797-
# conn = engine.connect().connection
798-
# return PostgresSaver(conn)
799788

800789
raise ValueError("Invalid memory_type provided. Must be 'memory', 'sqlite', or 'postgres'.")
801790

libs/chatchat-server/chatchat/settings.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,18 @@ def BASE_TEMP_DIR(self) -> Path:
101101
SQLITE_GRAPH_DATABASE_URI: str = str(CHATCHAT_ROOT / "data/graph.db")
102102
"""工作流 SQLITE CHECKPOINT 数据库连接URI"""
103103

104+
POSTGRESQL_GRAPH_DATABASE_URI: str = "postgresql://username:password@localhost:5442/langgraph_chatchat"
105+
"""工作流 POSTGRESQL CHECKPOINT 数据库连接URI"""
106+
107+
POSTGRESQL_GRAPH_CONNECTION_POOLS_MAX_SIZE: int = 20
108+
"""工作流 POSTGRESQL CHECKPOINT 数据库连接池最大连接数限制"""
109+
110+
POSTGRESQL_GRAPH_CONNECTION_POOLS_KWARGS: dict = {
111+
"autocommit": True,
112+
"prepare_threshold": 0,
113+
}
114+
"""工作流 POSTGRESQL CHECKPOINT 数据库连接池关键字参数配置"""
115+
104116
OPEN_CROSS_DOMAIN: bool = False
105117
"""API 是否开启跨域"""
106118

libs/chatchat-server/chatchat/webui_pages/graph_agent/graph.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
get_tool,
1414
list_tools,
1515
create_agent_models,
16-
get_checkpoint
16+
get_checkpointer
1717
)
1818

1919
logger = build_logger()
@@ -153,7 +153,7 @@ async def create_graph(
153153
):
154154
if st.session_state["checkpoint_type"] == "memory":
155155
if "memory" not in st.session_state:
156-
st.session_state["memory"] = get_checkpoint()
156+
st.session_state["memory"] = get_checkpointer()
157157
checkpoint = st.session_state["memory"]
158158
graph_class = graph_class(llm=graph_llm,
159159
tools=graph_tools,
@@ -164,7 +164,7 @@ async def create_graph(
164164
raise ValueError(f"Graph '{graph_class}' is not registered.")
165165
await process_graph(graph_class=graph_class, graph=graph, graph_input=graph_input, graph_config=graph_config)
166166
elif st.session_state["checkpoint_type"] == "sqlite":
167-
checkpoint_class = get_checkpoint()
167+
checkpoint_class = get_checkpointer()
168168
async with checkpoint_class as checkpoint:
169169
graph_class = graph_class(llm=graph_llm,
170170
tools=graph_tools,
@@ -174,6 +174,25 @@ async def create_graph(
174174
if not graph:
175175
raise ValueError(f"Graph '{graph_class}' is not registered.")
176176
await process_graph(graph_class=graph_class, graph=graph, graph_input=graph_input, graph_config=graph_config)
177+
elif st.session_state["checkpoint_type"] == "postgres":
178+
from psycopg_pool import AsyncConnectionPool
179+
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
180+
async with AsyncConnectionPool(
181+
conninfo=Settings.basic_settings.POSTGRESQL_GRAPH_DATABASE_URI,
182+
max_size=Settings.basic_settings.POSTGRESQL_GRAPH_CONNECTION_POOLS_MAX_SIZE,
183+
kwargs=Settings.basic_settings.POSTGRESQL_GRAPH_CONNECTION_POOLS_KWARGS,
184+
) as pool:
185+
checkpoint = AsyncPostgresSaver(pool)
186+
# NOTE: you need to call .setup() the first time you're using your checkpointer
187+
await checkpoint.setup()
188+
graph_class = graph_class(llm=graph_llm,
189+
tools=graph_tools,
190+
history_len=graph_history_len,
191+
checkpoint=checkpoint)
192+
graph = graph_class.get_graph()
193+
if not graph:
194+
raise ValueError(f"Graph '{graph_class}' is not registered.")
195+
await process_graph(graph_class=graph_class, graph=graph, graph_input=graph_input, graph_config=graph_config)
177196

178197

179198
async def update_state(graph: CompiledStateGraph, graph_config: Dict, update_message: Dict, as_node: str):

libs/chatchat-server/chatchat/webui_pages/graph_rag/rag.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
get_tool,
1414
create_agent_models,
1515
list_tools,
16-
get_checkpoint,
16+
get_checkpointer,
1717
)
1818

1919
logger = build_logger()
@@ -32,7 +32,7 @@ async def create_graph(
3232
):
3333
if st.session_state["checkpoint_type"] == "memory":
3434
if "memory" not in st.session_state:
35-
st.session_state["memory"] = get_checkpoint()
35+
st.session_state["memory"] = get_checkpointer()
3636
checkpoint = st.session_state["memory"]
3737
graph_class = graph_class(llm=graph_llm,
3838
tools=graph_tools,
@@ -46,7 +46,7 @@ async def create_graph(
4646
raise ValueError(f"Graph '{graph_class}' is not registered.")
4747
await process_graph(graph_class=graph_class, graph=graph, graph_input=graph_input, graph_config=graph_config)
4848
elif st.session_state["checkpoint_type"] == "sqlite":
49-
checkpoint_class = get_checkpoint()
49+
checkpoint_class = get_checkpointer()
5050
async with checkpoint_class as checkpoint:
5151
graph_class = graph_class(llm=graph_llm,
5252
tools=graph_tools,
@@ -59,6 +59,28 @@ async def create_graph(
5959
if not graph:
6060
raise ValueError(f"Graph '{graph_class}' is not registered.")
6161
await process_graph(graph_class=graph_class, graph=graph, graph_input=graph_input, graph_config=graph_config)
62+
elif st.session_state["checkpoint_type"] == "postgres":
63+
from psycopg_pool import AsyncConnectionPool
64+
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
65+
async with AsyncConnectionPool(
66+
conninfo=Settings.basic_settings.POSTGRESQL_GRAPH_DATABASE_URI,
67+
max_size=Settings.basic_settings.POSTGRESQL_GRAPH_CONNECTION_POOLS_MAX_SIZE,
68+
kwargs=Settings.basic_settings.POSTGRESQL_GRAPH_CONNECTION_POOLS_KWARGS,
69+
) as pool:
70+
checkpoint = AsyncPostgresSaver(pool)
71+
# NOTE: you need to call .setup() the first time you're using your checkpointer
72+
await checkpoint.setup()
73+
graph_class = graph_class(llm=graph_llm,
74+
tools=graph_tools,
75+
history_len=graph_history_len,
76+
checkpoint=checkpoint,
77+
knowledge_base=knowledge_base,
78+
top_k=top_k,
79+
score_threshold=score_threshold)
80+
graph = graph_class.get_graph()
81+
if not graph:
82+
raise ValueError(f"Graph '{graph_class}' is not registered.")
83+
await process_graph(graph_class=graph_class, graph=graph, graph_input=graph_input, graph_config=graph_config)
6284

6385

6486
async def graph_rag_page(api: ApiRequest):

libs/chatchat-server/poetry.lock

Lines changed: 66 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

libs/chatchat-server/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ langchain-community = "^0.3.5"
7777
mysqlclient = "^2.2.5"
7878
metaphor-python = "^0.1.23"
7979
xinference = "^0.16.3"
80+
psycopg-pool = "^3.2.3"
81+
langgraph-checkpoint-postgres = "^2.0.2"
8082
[tool.poetry.extras]
8183
xinference = ["xinference_client"]
8284
zhipuai = ["zhipuai"]

0 commit comments

Comments
 (0)