Async Mysql Memory for kb_chat #5124

Open · wants to merge 10 commits into base: master
@@ -1,16 +1,22 @@
from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult

from chatchat.server.db.repository import update_message
from typing import (
    Dict,
    Any,
    List
)


class ConversationCallbackHandler(BaseCallbackHandler):
    raise_error: bool = True

    def __init__(
        self, conversation_id: str, message_id: str, chat_type: str, query: str
        self,
        conversation_id: str,
        message_id: str,
        chat_type: str,
        query: str
    ):
        self.conversation_id = conversation_id
        self.message_id = message_id
@@ -20,15 +26,20 @@ def __init__(

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        **kwargs: Any,
    ) -> None:
        # TODO: if more information should be stored, the prompts also need to be persisted; unused prompt templates need special handling
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
    async def on_llm_end(
        self,
        response: LLMResult,
        **kwargs: Any,
    ) -> None:
        answer = response.generations[0][0].text
        update_message(self.message_id, answer)
        await update_message(self.message_id, answer)
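Note: making on_llm_end a coroutine only works if update_message is itself awaitable, and that repository change is not shown in this diff. Below is a minimal sketch of what an async update_message might look like, assuming a MessageModel with id and response columns and the AsyncSessionLocal factory added in server/db/base.py later in this PR (the import paths and model fields are assumptions, not taken from the diff).

# Hypothetical sketch, not part of this PR.
from sqlalchemy import select

from chatchat.server.db.base import AsyncSessionLocal  # assumed: async factory from this PR's base.py
from chatchat.server.db.models.message_model import MessageModel  # assumed model path


async def update_message(message_id: str, response: str):
    # Persist the final LLM answer onto an existing message row.
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(MessageModel).where(MessageModel.id == message_id)
        )
        message = result.scalars().first()
        if message is not None:
            message.response = response
            await session.commit()
        return message_id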
42 changes: 34 additions & 8 deletions libs/chatchat-server/chatchat/server/chat/kb_chat.py
@@ -68,9 +68,16 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=[
    async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
        try:
            nonlocal history, prompt_name, max_tokens

            history = [History.from_data(h) for h in history]

            message_id = await add_message_to_db(user_id=user_id,
                                                 conversation_id=conversation_id,
                                                 conversation_name=conversation_name,
                                                 prompt_name=prompt_name,
                                                 query=query)
            # history = [History.from_data(h) for h in history]
            conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id,
                                                                message_id=message_id,
                                                                chat_type=prompt_name,
                                                                query=query)
            if mode == "local_kb":
                kb = KBServiceFactory.get_service_by_name(kb_name)
                ok, msg = kb.check_embed_model()
@@ -138,7 +145,9 @@ async def knowledge_base_chat_iterator() -> AsyncIterable[str]:

            if max_tokens in [None, 0]:
                max_tokens = Settings.model_settings.MAX_TOKENS

            callback = AsyncIteratorCallbackHandler()
            callbacks = [callback]
            callbacks.append(conversation_callback)
            llm = get_ChatOpenAI(
                model_name=model,
                temperature=temperature,
@@ -164,13 +173,30 @@ async def knowledge_base_chat_iterator() -> AsyncIterable[str]:

            if len(docs) == 0:  # if no relevant documents were found, fall back to the "empty" prompt template
                prompt_name = "empty"
            prompt_template = get_prompt_template("rag", prompt_name)
            #prompt_template = get_prompt_template("rag", prompt_name)
            #input_msg = History(role="user", content=prompt_template).to_msg_template(False)
            #chat_prompt = ChatPromptTemplate.from_messages(
            #    [i.to_msg_template() for i in history] + [input_msg])
            memory = ConversationBufferDBMemory(conversation_id=conversation_id,
                                                llm=llm,
                                                chat_type=prompt_name,
                                                message_limit=10)
            history = await memory.buffer()
            import numpy as np

            prompt_template = default_prompt.get(prompt_name)
            system_msg = History(role="system",
                                 content="你是一位善于结合历史对话信息,以及相关文档回答问题的高智商助手").to_msg_template(
                                     is_raw=False)
            input_msg = History(role="user", content=prompt_template).to_msg_template(False)
            chat_prompt = ChatPromptTemplate.from_messages(
                [i.to_msg_template() for i in history] + [input_msg])

            history = [History.from_data(h) for h in history]
            chat_prompt = ChatPromptTemplate.from_messages(
                [i.to_msg_template() for i in history] + [system_msg, input_msg]  # [input_msg]
            )
            chain = chat_prompt | llm

            #chain = chat_prompt | llm

            # Begin a task that runs in the background.
            task = asyncio.create_task(wrap_done(
                chain.ainvoke({"context": context, "question": query}),
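The iterator above awaits memory.buffer() on a ConversationBufferDBMemory, which implies that class's DB access also has to become async; it is not included in this diff. Below is a hypothetical sketch of one way an async buffer() could be structured, assuming a filter_message repository helper that has been converted to a coroutine and returns recent message rows as dicts with "query" and "response" keys (both the helper and its return shape are assumptions, not shown in the PR). The dicts it returns are the kind of data History.from_data can consume in the hunk above.

# Hypothetical sketch, not part of this PR.
from typing import Any, Dict, List

from chatchat.server.db.repository import filter_message  # assumed async helper


class ConversationBufferDBMemory:
    def __init__(self, conversation_id: str, llm=None, chat_type: str = "llm_chat",
                 message_limit: int = 10):
        self.conversation_id = conversation_id
        self.llm = llm
        self.chat_type = chat_type
        self.message_limit = message_limit

    async def buffer(self) -> List[Dict[str, Any]]:
        # Newest rows come back first; reverse so the prompt reads oldest-first.
        messages = await filter_message(self.conversation_id, limit=self.message_limit)
        messages = list(reversed(messages))

        history: List[Dict[str, Any]] = []
        for m in messages:
            history.append({"role": "user", "content": m["query"]})
            history.append({"role": "assistant", "content": m["response"]})
        return history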
24 changes: 15 additions & 9 deletions libs/chatchat-server/chatchat/server/db/base.py
@@ -1,17 +1,23 @@
import json

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
from sqlalchemy.orm import sessionmaker

from chatchat.settings import Settings

username = 'root'
hostname = ''
database_name = ''
password = "123456"

SQLALCHEMY_DATABASE_URI = f"mysql+asyncmy://{username}:{password}@{hostname}/{database_name}?charset=utf8mb4"
print(SQLALCHEMY_DATABASE_URI)

engine = create_engine(
    Settings.basic_settings.SQLALCHEMY_DATABASE_URI,
    json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
async_engine = create_async_engine(
    SQLALCHEMY_DATABASE_URI,
    echo=True,
)

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
AsyncSessionLocal = sessionmaker(bind=async_engine, class_=AsyncSession, expire_on_commit=False)

Base: DeclarativeMeta = declarative_base()
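The repository diff further down imports with_async_session and async_session_scope from chatchat.server.db.session, but that module is not part of this diff. A minimal sketch of what those helpers could look like on top of the AsyncSessionLocal factory defined here (the names follow the imports used below; the bodies are assumptions). Note that the mysql+asyncmy:// URL above also assumes the asyncmy driver is installed, for example via pip install asyncmy.

# Hypothetical sketch, not part of this PR. Mirrors the sync with_session /
# session_scope helpers, but built on AsyncSessionLocal.
from contextlib import asynccontextmanager
from functools import wraps

from chatchat.server.db.base import AsyncSessionLocal


@asynccontextmanager
async def async_session_scope():
    # Provide an AsyncSession that commits on success and rolls back on error.
    async with AsyncSessionLocal() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise


def with_async_session(f):
    # Inject a managed AsyncSession as the first argument of a coroutine.
    @wraps(f)
    async def wrapper(*args, **kwargs):
        async with async_session_scope() as session:
            return await f(session, *args, **kwargs)
    return wrapper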


11 changes: 11 additions & 0 deletions libs/chatchat-server/chatchat/server/db/create_all_models
@@ -0,0 +1,11 @@
import asyncio
from NorthinfoChat.server.db.base import Base
from NorthinfoChat.server.db.base import async_engine, AsyncSessionLocal


async def create_tables(engine):
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

if __name__ == '__main__':
    asyncio.run(create_tables(async_engine))
25 changes: 25 additions & 0 deletions libs/chatchat-server/chatchat/server/db/create_all_table.py
@@ -0,0 +1,25 @@
import asyncio
from sqlalchemy.ext.asyncio import create_async_engine
from NorthinfoChat.server.db.base import Base

username = ''
hostname = ''
database_name = ''
password = ""

SQLALCHEMY_DATABASE_URI = f"mysql+asyncmy://{username}:{password}@{hostname}/{database_name}?charset=utf8mb4"

async_engine = create_async_engine(
    SQLALCHEMY_DATABASE_URI,
    echo=True
)


async def drop_tables(engine):
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)


if __name__ == '__main__':
    asyncio.run(drop_tables(async_engine))

@@ -1,85 +1,73 @@
from chatchat.server.db.models.knowledge_base_model import (
    KnowledgeBaseModel,
    KnowledgeBaseSchema,
)
from chatchat.server.db.session import with_session


@with_session
def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model):
    # create the knowledge base record
    kb = (
        session.query(KnowledgeBaseModel)
        .filter(KnowledgeBaseModel.kb_name.ilike(kb_name))
        .first()
from chatchat.server.db.model.knowledge_base_model import KnowledgeBaseModel, KnowledgeBaseSchema
from sqlalchemy import select, delete
from chatchat.server.db.session import with_async_session, async_session_scope
import asyncio


@with_async_session
async def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model, user_id):
    kb = await session.execute(
        select(KnowledgeBaseModel)
        .where(KnowledgeBaseModel.kb_name.ilike(kb_name))
    )
    kb = kb.scalars().first()

    if not kb:
        kb = KnowledgeBaseModel(
            kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model
            kb_name=kb_name,
            kb_info=kb_info,
            vs_type=vs_type,
            embed_model=embed_model,
            user_id=user_id
        )
        session.add(kb)
    else:  # update kb with new vs_type and embed_model
    else:
        kb.kb_info = kb_info
        kb.vs_type = vs_type
        kb.embed_model = embed_model
        kb.user_id = user_id

    await session.commit()
    return True


@with_session
def list_kbs_from_db(session, min_file_count: int = -1):
    kbs = (
        session.query(KnowledgeBaseModel)
        .filter(KnowledgeBaseModel.file_count > min_file_count)
        .all()
@with_async_session
async def list_kbs_from_db(session, min_file_count: int = -1):
    result = await session.execute(
        select(KnowledgeBaseModel.kb_name)
        .where(KnowledgeBaseModel.file_count > min_file_count)
    )
    kbs = [KnowledgeBaseSchema.model_validate(kb) for kb in kbs]
    return kbs

    kbs = [kb for kb in result.scalars().all()]
    return kbs

@with_session
def kb_exists(session, kb_name):
    kb = (
        session.query(KnowledgeBaseModel)
        .filter(KnowledgeBaseModel.kb_name.ilike(kb_name))
        .first()
@with_async_session
async def list_kbs_from_db2(session, min_file_count: int = -1):
    result = await session.execute(
        select(KnowledgeBaseModel)
        .where(KnowledgeBaseModel.file_count > min_file_count)
    )
    status = True if kb else False
    return status

    kbs = [KnowledgeBaseSchema.model_validate(kb) for kb in result.scalars().all()]
    return kbs

@with_session
def load_kb_from_db(session, kb_name):
    kb = (
        session.query(KnowledgeBaseModel)
        .filter(KnowledgeBaseModel.kb_name.ilike(kb_name))
        .first()
    )
    if kb:
        kb_name, vs_type, embed_model = kb.kb_name, kb.vs_type, kb.embed_model
    else:
        kb_name, vs_type, embed_model = None, None, None
    return kb_name, vs_type, embed_model


@with_session
def delete_kb_from_db(session, kb_name):
    kb = (
        session.query(KnowledgeBaseModel)
        .filter(KnowledgeBaseModel.kb_name.ilike(kb_name))
        .first()
    )
    if kb:
        session.delete(kb)
    return True
@with_async_session
async def kb_exists(session, kb_name):
    kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
    status = True if kb else False
    return status


@with_session
def get_kb_detail(session, kb_name: str) -> dict:
    kb: KnowledgeBaseModel = (
        session.query(KnowledgeBaseModel)
        .filter(KnowledgeBaseModel.kb_name.ilike(kb_name))
        .first()
    )
@with_async_session
async def get_kb_detail(session, kb_name: str) -> dict:
    stmt = select(
        KnowledgeBaseModel
    ).where(KnowledgeBaseModel.kb_name.ilike(kb_name))
    result = await session.execute(stmt)
    kb = result.scalars().first()

    if kb:
        return {
            "kb_name": kb.kb_name,
@@ -91,3 +79,34 @@ def get_kb_detail(session, kb_name: str) -> dict:
        }
    else:
        return {}


@with_async_session
async def delete_kb_from_db(session, kb_name):
    await session.execute(
        delete(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name))
    )
    return True


@with_async_session
async def load_kb_from_db(session, kb_name):
    stmt = select(KnowledgeBaseModel).where(KnowledgeBaseModel.kb_name.ilike(kb_name))
    result = await session.execute(stmt)
    kb = result.scalar_one_or_none()

    if kb:
        kb_name, vs_type, embed_model = kb.kb_name, kb.vs_type, kb.embed_model
    else:
        kb_name, vs_type, embed_model = None, None, None
    return kb_name, vs_type, embed_model


if __name__ == '__main__':
    r = asyncio.run(list_kbs_from_db())
    print(r)


if __name__ == '__main__':
    r = asyncio.run(list_kbs_from_db())
    print(r)