# pdf_bot.py
import os
import streamlit as st
from PyPDF2 import PdfReader
from langchain.callbacks.base import BaseCallbackHandler
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import ChatPromptTemplate
from langchain_neo4j import Neo4jVector
from streamlit.logger import get_logger
from chains import (
    load_embedding_model,
    load_llm,
)
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from utils import format_docs
# Load Neo4j credentials and model settings from the .env file
from dotenv import load_dotenv

load_dotenv(".env")

url = os.getenv("NEO4J_URI")
username = os.getenv("NEO4J_USERNAME")
password = os.getenv("NEO4J_PASSWORD")
ollama_base_url = os.getenv("OLLAMA_BASE_URL")
embedding_model_name = os.getenv("EMBEDDING_MODEL")
llm_name = os.getenv("LLM")
# Remapping for the LangChain Neo4j integration
os.environ["NEO4J_URL"] = url
logger = get_logger(__name__)
embeddings, dimension = load_embedding_model(
    embedding_model_name, config={"ollama_base_url": ollama_base_url}, logger=logger
)

class StreamHandler(BaseCallbackHandler):
    """Streams LLM tokens into a Streamlit container as they are generated."""

    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.text += token
        self.container.markdown(self.text)
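
# Note: the handler above hooks BaseCallbackHandler.on_llm_new_token, so tokens
# only arrive if the underlying LLM actually streams them; it is attached
# per-request via the callbacks config passed to qa.invoke() at the end of main().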
llm = load_llm(llm_name, logger=logger, config={"ollama_base_url": ollama_base_url})

def main():
    st.header("📄 Chat with your PDF file")

    # Upload a PDF file
    pdf = st.file_uploader("Upload your PDF", type="pdf")

    if pdf is not None:
        pdf_reader = PdfReader(pdf)

        text = ""
        for page in pdf_reader.pages:
            # extract_text() may yield nothing for pages without extractable text
            text += page.extract_text() or ""
        # Split the text into overlapping chunks with LangChain's text splitter
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=200, length_function=len
        )
        chunks = text_splitter.split_text(text=text)
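        # A rough sense of scale: with a 1000-character window advancing
        # 800 characters per step (200 overlap), a ~10,000-character document
        # yields roughly a dozen chunks, and the overlap keeps sentences that
        # straddle a chunk boundary retrievable from either side.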
        qa_prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "human",
                    "Based on the provided summary: {summaries} \n Answer the following question:{question}",
                )
            ]
        )
        # Store the chunks in Neo4j as a vector index
        vectorstore = Neo4jVector.from_texts(
            chunks,
            url=url,
            username=username,
            password=password,
            embedding=embeddings,
            index_name="pdf_bot",
            node_label="PdfBotChunk",
            pre_delete_collection=True,  # Delete existing PDF data
        )
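        # Note: pre_delete_collection=True drops any existing "PdfBotChunk"
        # nodes first, so each upload replaces the previous document's chunks
        # rather than mixing answers across uploads.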
        qa = (
            RunnableParallel(
                {
                    "summaries": vectorstore.as_retriever(search_kwargs={"k": 2})
                    | format_docs,
                    "question": RunnablePassthrough(),
                }
            )
            | qa_prompt
            | llm
            | StrOutputParser()
        )
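        # Sketch of the data flow through the chain above:
        #   query --> retriever (top-2 chunks) --> format_docs --> {summaries}
        #   query --> RunnablePassthrough() ----------------------> {question}
        #   {summaries, question} --> qa_prompt --> llm --> plain answer string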
        # Accept a user question about the uploaded PDF
        query = st.text_input("Ask questions about your PDF file")

        if query:
            stream_handler = StreamHandler(st.empty())
            qa.invoke(query, {"callbacks": [stream_handler]})

if __name__ == "__main__":
    main()
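
# A minimal way to run this app, assuming Neo4j, Ollama, and the .env values
# read above are in place:
#   streamlit run pdf_bot.py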