Merge code from app and multirag #20

Open · wants to merge 1 commit into main
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
__pycache__/**
insf_venv/**
*.pyc
.env/*
94 changes: 50 additions & 44 deletions app.py
@@ -1,6 +1,8 @@
import os
import uuid
import datasets
import tempfile

from langchain_huggingface import HuggingFaceEndpointEmbeddings
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
@@ -19,6 +21,8 @@
from urllib3.exceptions import ProtocolError
from langchain.retrievers import ContextualCompressionRetriever
from transformers import AutoTokenizer
from unstructured.cleaners.core import clean_extra_whitespace, group_broken_paragraphs
from langchain_community.document_loaders import PyPDFLoader

from tools import get_tools
from tei_rerank import TEIRerank
@@ -29,10 +33,10 @@
import yaml
from yaml.loader import SafeLoader

from langchain.globals import set_verbose, set_debug
# from langchain.globals import set_verbose, set_debug

set_verbose(True)
set_debug(True)
# set_verbose(True)
# set_debug(True)

st.set_page_config(layout="wide", page_title="InSightful")

@@ -129,7 +133,8 @@ def setup_huggingface_embeddings():

@st.cache_resource
def load_prompt_and_system_ins(
template_file_path="templates/prompt_template.tmpl", template=None
template_file_path: str = "templates/prompt_template.tmpl",
template: str | None = None,
):
# prompt = hub.pull("hwchase17/react-chat")
prompt = PromptTemplate.from_file(template_file_path)
@@ -149,10 +154,11 @@ def load_prompt_and_system_ins(
return prompt, system_instructions


class RAG:
def __init__(self, collection_name, db_client):
self.collection_name = collection_name
class RAG(object):
def __init__(self, llm: ChatOpenAI, db_client, embedding_function):
self.llm = llm
self.db_client = db_client
self.embedding_function = embedding_function

@retry(
retry=retry_if_exception_type(ProtocolError),
@@ -182,14 +188,14 @@ def chunk_doc(self, pages, chunk_size=512, chunk_overlap=30):
print("Document chunked")
return chunks

def insert_embeddings(self, chunks, chroma_embedding_function, batch_size=32):
def insert_embeddings(self, chunks, collection_name, batch_size=32):
print(
"Inserting embeddings into collection: {collection_name}".format(
collection_name=self.collection_name
collection_name=collection_name
)
)
collection = self.db_client.get_or_create_collection(
self.collection_name, embedding_function=chroma_embedding_function
collection_name, embedding_function=self.embedding_function
)
for i in range(0, len(chunks), batch_size):
batch = chunks[i : i + batch_size]
@@ -219,44 +225,39 @@ def get_retriever(self, vector_store, use_reranker=False):
return retriever

def query_docs(
self, model, question, vector_store, prompt, chat_history, use_reranker=False
self, question, vector_store, prompt, chat_history, use_reranker=False
):
retriever = self.get_retriever(vector_store, use_reranker)
pass_question = lambda input: input["question"]
rag_chain = (
RunnablePassthrough.assign(context=pass_question | retriever | format_docs)
| prompt
| model
| self.llm
| StrOutputParser()
)

return rag_chain.stream({"question": question, "chat_history": chat_history})

def load_pdf(self, doc):
with tempfile.NamedTemporaryFile(
delete=False, suffix=os.path.splitext(doc.name)[1]
) as tmp:
tmp.write(doc.getvalue())
tmp_path = tmp.name
loader = PyPDFLoader(tmp_path)
documents = loader.load()
cleaned_pages = []
for doc in documents:
doc.page_content = clean_extra_whitespace(doc.page_content)
doc.page_content = group_broken_paragraphs(doc.page_content)
cleaned_pages.append(doc)
return cleaned_pages


def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)


def create_retriever(
name, description, client, chroma_embedding_function, embedding_svc, reranker=False
):
collection_name = "software-slacks"
rag = RAG(collection_name=collection_name, db_client=client)
pages = rag.load_documents("spencer/software_slacks", num_docs=100)
chunks = rag.chunk_doc(pages)
rag.insert_embeddings(chunks, chroma_embedding_function)
vector_store = Chroma(
embedding_function=embedding_svc,
collection_name=collection_name,
client=client,
)
retriever = rag.get_retriever(vector_store, use_reranker=reranker)

retriever = vector_store.as_retriever(
search_type="similarity", search_kwargs={"k": 10}
)
return create_retriever_tool(retriever, name, description)

@st.cache_resource
def setup_agent(_model, _prompt, _tools):
agent = create_react_agent(
@@ -280,17 +281,25 @@ def main():
model = setup_chat_endpoint()
embedder = setup_huggingface_embeddings()
use_reranker = os.getenv("USE_RERANKER", "False") == "True"

retriever_tool = create_retriever(
"slack_conversations_retriever",
"Useful for when you need to answer from Slack conversations.",
client,
chroma_embedding_function,
embedder,
reranker=use_reranker,
rag = RAG(llm=model, db_client=client, embedding_function=chroma_embedding_function)
collection_name = "software-slacks"
pages = rag.load_documents("spencer/software_slacks", num_docs=100)
chunks = rag.chunk_doc(pages)
rag.insert_embeddings(chunks, collection_name)
vector_store = Chroma(
embedding_function=embedder,
collection_name=collection_name,
client=client,
)
retriever = rag.get_retriever(vector_store, use_reranker=use_reranker)
_tools = get_tools()
_tools.append(retriever_tool)
_tools.append(
create_retriever_tool(
retriever,
"slack_conversations_retriever",
"Useful for when you need to answer from Slack conversations.",
)
)

agent_executor = setup_agent(model, prompt, _tools)

@@ -328,7 +337,4 @@ def main():


if __name__ == "__main__":
# authenticator = authenticate()
# if st.session_state['authentication_status']:
# authenticator.logout()
main()
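
Reviewer note: the RAG constructor now takes the LLM and embedding function instead of a collection name, and insert_embeddings/query_docs take the collection name per call and drop the model argument. A minimal sketch of how a caller wires this up after the change — it mirrors main() above; helper names are assumed importable from app.py as in multi_tenant_rag.py, and the question string is made up:

    from langchain_chroma import Chroma

    from app import (
        RAG,
        hf_embedding_server,
        load_prompt_and_system_ins,
        setup_chat_endpoint,
        setup_chroma_client,
        setup_huggingface_embeddings,
    )

    client = setup_chroma_client()
    model = setup_chat_endpoint()
    prompt, _ = load_prompt_and_system_ins()

    # LLM and embedding function now live on the instance; the collection
    # name is passed per call rather than fixed at construction time.
    rag = RAG(llm=model, db_client=client, embedding_function=hf_embedding_server())
    pages = rag.load_documents("spencer/software_slacks", num_docs=100)
    rag.insert_embeddings(rag.chunk_doc(pages), "software-slacks")

    vector_store = Chroma(
        embedding_function=setup_huggingface_embeddings(),
        collection_name="software-slacks",
        client=client,
    )
    stream = rag.query_docs(  # no model= argument anymore; the LLM is self.llm
        question="What was decided about the release?",  # made-up example
        vector_store=vector_store,
        prompt=prompt,
        chat_history=[],
    )
    for token in stream:
        print(token, end="")
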
75 changes: 24 additions & 51 deletions multi_tenant_rag.py
@@ -1,14 +1,13 @@
import os
import logging
import tempfile
import os
import yaml
from yaml.loader import SafeLoader
import streamlit as st
import streamlit_authenticator as stauth
from streamlit_authenticator.utilities import RegisterError
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores.chroma import Chroma
from unstructured.cleaners.core import clean_extra_whitespace, group_broken_paragraphs
from langchain_chroma import Chroma

from tools import get_tools

from app import (
@@ -21,17 +20,19 @@
setup_agent,
)


SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"


logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
log: logging.Logger = logging.getLogger(__name__)


def configure_authenticator():
auth_config = os.getenv("AUTH_CONFIG_FILE_PATH", default=".streamlit/config.yaml")
print(f"auth_config: {auth_config}")
log.info(f"auth_config: {auth_config}")
with open(file=auth_config) as file:
config = yaml.load(file, Loader=SafeLoader)

@@ -67,49 +68,32 @@ def authenticate(op):
return authenticator


class MultiTenantRAG(RAG):
def __init__(self, user_id, collection_name, db_client):
self.user_id = user_id
super().__init__(collection_name, db_client)

def load_documents(self, doc):
with tempfile.NamedTemporaryFile(
delete=False, suffix=os.path.splitext(doc.name)[1]
) as tmp:
tmp.write(doc.getvalue())
tmp_path = tmp.name
loader = PyPDFLoader(tmp_path)
documents = loader.load()
cleaned_pages = []
for doc in documents:
doc.page_content = clean_extra_whitespace(doc.page_content)
doc.page_content = group_broken_paragraphs(doc.page_content)
cleaned_pages.append(doc)
return cleaned_pages


def main():
authenticator = authenticate("login")
if st.session_state["authentication_status"]:
st.sidebar.text(f"Welcome {st.session_state['username']}")
authenticator.logout(location="sidebar")
user_id = st.session_state["username"]
if not user_id:
st.error("Please login to continue")
return

use_reranker = st.sidebar.toggle("Use reranker", False)
use_tools = st.sidebar.toggle("Use tools", False)
uploaded_file = st.sidebar.file_uploader("Upload a document", type=["pdf"])
question = st.chat_input("Chat with your docs or apis")

llm = setup_chat_endpoint()

embedding_svc = setup_huggingface_embeddings()

chroma_embeddings = hf_embedding_server()

user_id = st.session_state["username"]

client = setup_chroma_client()

# Set up prompt template
template = """
Based on the retrieved context, respond with an accurate answer.

Be concise and always provide accurate, specific, and relevant information.
"""

template_file_path = "templates/multi_tenant_rag_prompt_template.tmpl"
if use_tools:
template_file_path = "templates/multi_tenant_rag_prompt_template_tools.tmpl"
@@ -118,6 +102,7 @@ def main():
template_file_path=template_file_path,
template=template,
)
log.info(f"prompt: {prompt} system_instructions: {system_instructions}")

chat_history = st.session_state.get(
"chat_history", [{"role": SYSTEM, "content": system_instructions.content}]
@@ -127,38 +112,31 @@ def main():
with st.chat_message(message["role"]):
st.markdown(message["content"])

if not user_id:
st.error("Please login to continue")
return

collection = client.get_or_create_collection(
f"user-collection-{user_id}", embedding_function=chroma_embeddings
)

logger = logging.getLogger(__name__)
logger.info(
log.info(
f"user_id: {user_id} use_reranker: {use_reranker} use_tools: {use_tools} question: {question}"
)
rag = MultiTenantRAG(user_id, collection.name, client)
rag = RAG(llm=llm, db_client=client, embedding_function=chroma_embeddings)

if use_tools:
tools = get_tools()
agent_executor = setup_agent(llm, prompt, tools)

# prompt = hub.pull("rlm/rag-prompt")

vectorstore = Chroma(
embedding_function=embedding_svc,
collection_name=collection.name,
client=client,
)

if uploaded_file:
document = rag.load_documents(uploaded_file)
document = rag.load_pdf(uploaded_file)
chunks = rag.chunk_doc(document)
rag.insert_embeddings(
chunks=chunks,
chroma_embedding_function=chroma_embeddings,
collection_name=collection.name,
batch_size=32,
)

@@ -174,10 +152,9 @@ def main():
)["output"]
with st.chat_message(ASSISTANT):
st.write(answer)
logger.info(f"answer: {answer}")
log.info(f"answer: {answer}")
else:
answer = rag.query_docs(
model=llm,
question=question,
vector_store=vectorstore,
prompt=prompt,
Expand All @@ -186,16 +163,12 @@ def main():
)
with st.chat_message(ASSISTANT):
answer = st.write_stream(answer)
logger.info(f"answer: {answer}")
log.info(f"answer: {answer}")

chat_history.append({"role": USER, "content": question})
chat_history.append({"role": ASSISTANT, "content": answer})
st.session_state["chat_history"] = chat_history


if __name__ == "__main__":
authenticator = authenticate("login")
if st.session_state["authentication_status"]:
st.sidebar.text(f"Welcome {st.session_state['username']}")
authenticator.logout(location="sidebar")
main()
main()
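
Reviewer note: with MultiTenantRAG removed, tenancy is now just a per-user collection name on the shared RAG class, and PDF ingestion moved to RAG.load_pdf. A minimal sketch of the upload path under that assumption — FakeUpload is a hypothetical stand-in for Streamlit's UploadedFile, which load_pdf reads only via .name and .getvalue():

    from app import RAG, hf_embedding_server, setup_chat_endpoint, setup_chroma_client

    class FakeUpload:
        # Stand-in for st.file_uploader's UploadedFile: load_pdf only needs
        # a .name (for the file suffix) and .getvalue() (the raw bytes).
        def __init__(self, path):
            self.name = path
            with open(path, "rb") as f:
                self._data = f.read()

        def getvalue(self):
            return self._data

    client = setup_chroma_client()
    chroma_embeddings = hf_embedding_server()

    user_id = "alice"  # hypothetical; the app takes this from st.session_state["username"]
    collection = client.get_or_create_collection(
        f"user-collection-{user_id}", embedding_function=chroma_embeddings
    )

    # One RAG instance serves every tenant; isolation is the collection name.
    rag = RAG(llm=setup_chat_endpoint(), db_client=client, embedding_function=chroma_embeddings)
    pages = rag.load_pdf(FakeUpload("report.pdf"))  # report.pdf is a made-up local file
    chunks = rag.chunk_doc(pages)
    rag.insert_embeddings(chunks=chunks, collection_name=collection.name, batch_size=32)
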
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,7 +1,7 @@
chromadb==0.5.3
datasets==2.20.0
langchain==0.2.12
langchain_chroma==0.1.2
langchain_chroma==0.1.3
langchain_community==0.2.11
langchain_core==0.2.28
langchain_huggingface==0.0.3