Commit 48be0ca

Refactor logic into separate files for easier reading
1 parent 325a026 commit 48be0ca

6 files changed: +205 additions, −215 deletions


bot.Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -13,9 +13,9 @@ COPY requirements.txt .
 
 RUN pip install --upgrade -r requirements.txt
 
-# COPY .env .
 COPY bot.py .
 COPY utils.py .
+COPY chains.py .
 
 EXPOSE 8501
 
bot.py

Lines changed: 21 additions & 160 deletions
@@ -1,22 +1,25 @@
 import os
-from typing import List, Any
 
 import streamlit as st
 from streamlit.logger import get_logger
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.vectorstores.neo4j_vector import Neo4jVector
-
-from langchain.chat_models import ChatOpenAI, ChatOllama
-from langchain.chains import RetrievalQAWithSourcesChain
-from langchain.chains.qa_with_sources import load_qa_with_sources_chain
 from langchain.prompts.chat import (
     ChatPromptTemplate,
     SystemMessagePromptTemplate,
     HumanMessagePromptTemplate,
 )
 from langchain.graphs import Neo4jGraph
 from dotenv import load_dotenv
-from utils import extract_title_and_question, load_embedding_model
+from utils import (
+    extract_title_and_question,
+    create_vector_index,
+)
+from chains import (
+    load_embedding_model,
+    load_llm,
+    configure_llm_only_chain,
+    configure_qa_rag_chain,
+)
 
 load_dotenv(".env")
 
@@ -33,19 +36,10 @@
 
 # if Neo4j is local, you can go to http://localhost:7474/ to browse the database
 neo4j_graph = Neo4jGraph(url=url, username=username, password=password)
-
-
-def create_vector_index(dimension: int) -> None:
-    index_query = "CALL db.index.vector.createNodeIndex('stackoverflow', 'Question', 'embedding', $dimension, 'cosine')"
-    try:
-        neo4j_graph.query(index_query, {"dimension": dimension})
-    except:  # Already exists
-        pass
-    index_query = "CALL db.index.vector.createNodeIndex('top_answers', 'Answer', 'embedding', $dimension, 'cosine')"
-    try:
-        neo4j_graph.query(index_query, {"dimension": dimension})
-    except:  # Already exists
-        pass
+embeddings, dimension = load_embedding_model(
+    embedding_model_name, config={ollama_base_url: ollama_base_url}, logger=logger
+)
+create_vector_index(neo4j_graph, dimension)
 
 
 class StreamHandler(BaseCallbackHandler):
@@ -58,142 +52,11 @@ def on_llm_new_token(self, token: str, **kwargs) -> None:
         self.container.markdown(self.text)
 
 
-embeddings, dimension = load_embedding_model(
-    embedding_model_name, config={ollama_base_url: ollama_base_url}, logger=logger
-)
-
-create_vector_index(dimension)
-
-if llm_name == "gpt-4":
-    llm = ChatOpenAI(temperature=0, model_name="gpt-4", streaming=True)
-    logger.info("LLM: Using GPT-4")
-elif llm_name == "gpt-3.5":
-    llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
-    logger.info("LLM: Using GPT-3.5 Turbo")
-elif len(llm_name):
-    llm = ChatOllama(
-        temperature=0,
-        base_url=ollama_base_url,
-        model=llm_name,
-        streaming=True,
-        top_k=10,  # A higher value (100) will give more diverse answers, while a lower value (10) will be more conservative.
-        top_p=0.3,  # Higher value (0.95) will lead to more diverse text, while a lower value (0.5) will generate more focused text.
-        num_ctx=3072,  # Sets the size of the context window used to generate the next token.
-    )
-    logger.info(f"LLM: Using Ollama ({llm_name})")
-else:
-    llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
-    logger.info("LLM: Using GPT-3.5 Turbo")
-
-# LLM only response
-template = """
-You are a helpful assistant that helps a support agent with answering programming questions.
-If you don't know the answer, just say that you don't know, don't try to make up an answer.
-"""
-system_message_prompt = SystemMessagePromptTemplate.from_template(template)
-human_template = "{text}"
-human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
-chat_prompt = ChatPromptTemplate.from_messages(
-    [system_message_prompt, human_message_prompt]
-)
-
-
-def generate_llm_output(
-    user_input: str, callbacks: List[Any], prompt=chat_prompt
-) -> str:
-    answer = llm(
-        prompt.format_prompt(
-            text=user_input,
-        ).to_messages(),
-        callbacks=callbacks,
-    ).content
-    return {"answer": answer}
-
-
-# Vector response
-neo4j_db = Neo4jVector.from_existing_index(
-    embedding=embeddings,
-    url=url,
-    username=username,
-    password=password,
-    database="neo4j",  # neo4j by default
-    index_name="top_answers",  # vector by default
-    text_node_property="body",  # text by default
-    retrieval_query="""
-    OPTIONAL MATCH (node)-[:ANSWERS]->(question)
-    RETURN 'Question: ' + question.title + '\n' + question.body + '\nAnswer: ' +
-        coalesce(node.body,"") AS text, score, {source:question.link} AS metadata
-    ORDER BY score ASC // so that best answer are the last
-    """,
-)
-
-general_system_template = """
-Use the following pieces of context to answer the question at the end.
-The context contains question-answer pairs and their links from Stackoverflow.
-You should prefer information from accepted or more upvoted answers.
-Make sure to rely on information from the answers and not on questions to provide accuate responses.
-When you find particular answer in the context useful, make sure to cite it in the answer using the link.
-If you don't know the answer, just say that you don't know, don't try to make up an answer.
-----
-{summaries}
-----
-Each answer you generate should contain a section at the end of links to
-Stackoverflow questions and answers you found useful, which are described under Source value.
-You can only use links to StackOverflow questions that are present in the context and always
-add links to the end of the answer in the style of citations.
-Generate concise answers with references sources section of links to
-relevant StackOverflow questions only at the end of the answer.
-"""
-general_user_template = "Question:```{question}```"
-messages = [
-    SystemMessagePromptTemplate.from_template(general_system_template),
-    HumanMessagePromptTemplate.from_template(general_user_template),
-]
-qa_prompt = ChatPromptTemplate.from_messages(messages)
-
-qa_chain = load_qa_with_sources_chain(
-    llm,
-    chain_type="stuff",
-    prompt=qa_prompt,
-)
-qa = RetrievalQAWithSourcesChain(
-    combine_documents_chain=qa_chain,
-    retriever=neo4j_db.as_retriever(search_kwargs={"k": 2}),
-    reduce_k_below_max_tokens=True,
-    max_tokens_limit=3375,
-)
-
-# Vector + Knowledge Graph response
-kg = Neo4jVector.from_existing_index(
-    embedding=embeddings,
-    url=url,
-    username=username,
-    password=password,
-    database="neo4j",  # neo4j by default
-    index_name="stackoverflow",  # vector by default
-    text_node_property="body",  # text by default
-    retrieval_query="""
-    WITH node AS question, score AS similarity
-    CALL { with question
-        MATCH (question)<-[:ANSWERS]-(answer)
-        WITH answer
-        ORDER BY answer.is_accepted DESC, answer.score DESC
-        WITH collect(answer)[..2] as answers
-        RETURN reduce(str='', answer IN answers | str +
-            '\n### Answer (Accepted: '+ answer.is_accepted +
-            ' Score: ' + answer.score+ '): '+ answer.body + '\n') as answerTexts
-    }
-    RETURN '##Question: ' + question.title + '\n' + question.body + '\n'
-        + answerTexts AS text, similarity as score, {source: question.link} AS metadata
-    ORDER BY similarity ASC // so that best answers are the last
-    """,
-)
+llm = load_llm(llm_name, logger=logger, config={"ollama_base_url": ollama_base_url})
 
-kg_qa = RetrievalQAWithSourcesChain(
-    combine_documents_chain=qa_chain,
-    retriever=kg.as_retriever(search_kwargs={"k": 2}),
-    reduce_k_below_max_tokens=False,
-    max_tokens_limit=3375,
+llm_chain = configure_llm_only_chain(llm)
+rag_chain = configure_qa_rag_chain(
+    llm, embeddings, embeddings_store_url=url, username=username, password=password
 )
 
 # Streamlit UI
@@ -280,11 +143,9 @@ def mode_select() -> str:
 
 name = mode_select()
 if name == "LLM only" or name == "Disabled":
-    output_function = generate_llm_output
-elif name == "Vector":
-    output_function = qa
+    output_function = llm_chain
 elif name == "Vector + Graph" or name == "Enabled":
-    output_function = kg_qa
+    output_function = rag_chain
 
 
 def generate_ticket():
@@ -337,7 +198,7 @@ def generate_ticket():
             HumanMessagePromptTemplate.from_template("{text}"),
         ]
     )
-    llm_response = generate_llm_output(
+    llm_response = llm_chain(
         f"Here's the question to rewrite in the expected format: ```{q_prompt}```",
         [],
         chat_prompt,
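
Note: the utils.py diff is not rendered in this view. Judging from the code removed from bot.py above and the new call site create_vector_index(neo4j_graph, dimension), the index helper presumably moved to utils.py and now takes the graph handle as a parameter, roughly like the sketch below (a reconstruction, not the committed code):

def create_vector_index(driver, dimension: int) -> None:
    # Create the two vector indexes the app queries; swallow the error if they already exist.
    index_query = "CALL db.index.vector.createNodeIndex('stackoverflow', 'Question', 'embedding', $dimension, 'cosine')"
    try:
        driver.query(index_query, {"dimension": dimension})
    except:  # Already exists
        pass
    index_query = "CALL db.index.vector.createNodeIndex('top_answers', 'Answer', 'embedding', $dimension, 'cosine')"
    try:
        driver.query(index_query, {"dimension": dimension})
    except:  # Already exists
        pass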

chains.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
+from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.embeddings import OllamaEmbeddings, SentenceTransformerEmbeddings
+from langchain.chat_models import ChatOpenAI, ChatOllama
+from langchain.vectorstores.neo4j_vector import Neo4jVector
+from langchain.chains import RetrievalQAWithSourcesChain
+from langchain.chains.qa_with_sources import load_qa_with_sources_chain
+from langchain.prompts.chat import (
+    ChatPromptTemplate,
+    SystemMessagePromptTemplate,
+    HumanMessagePromptTemplate,
+)
+from typing import List, Any
+from utils import BaseLogger
+
+
+def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}):
+    if embedding_model_name == "ollama":
+        embeddings = OllamaEmbeddings(base_url=config.ollama_base_url, model="llama2")
+        dimension = 4096
+        logger.info("Embedding: Using Ollama")
+    elif embedding_model_name == "openai":
+        embeddings = OpenAIEmbeddings()
+        dimension = 1536
+        logger.info("Embedding: Using OpenAI")
+    else:
+        embeddings = SentenceTransformerEmbeddings(
+            model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
+        )
+        dimension = 384
+        logger.info("Embedding: Using SentenceTransformer")
+    return embeddings, dimension
+
+
+def load_llm(llm_name: str, logger=BaseLogger(), config={}):
+    if llm_name == "gpt-4":
+        logger.info("LLM: Using GPT-4")
+        return ChatOpenAI(temperature=0, model_name="gpt-4", streaming=True)
+    elif llm_name == "gpt-3.5":
+        logger.info("LLM: Using GPT-3.5")
+        return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
+    elif len(llm_name):
+        logger.info(f"LLM: Using Ollama: {llm_name}")
+        return ChatOllama(
+            temperature=0,
+            base_url=config["ollama_base_url"],
+            model=llm_name,
+            streaming=True,
+            top_k=10,  # A higher value (100) will give more diverse answers, while a lower value (10) will be more conservative.
+            top_p=0.3,  # Higher value (0.95) will lead to more diverse text, while a lower value (0.5) will generate more focused text.
+            num_ctx=3072,  # Sets the size of the context window used to generate the next token.
+        )
+    logger.info("LLM: Using GPT-3.5")
+    return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
+
+
+def configure_llm_only_chain(llm):
+    # LLM only response
+    template = """
+    You are a helpful assistant that helps a support agent with answering programming questions.
+    If you don't know the answer, just say that you don't know, don't try to make up an answer.
+    """
+    system_message_prompt = SystemMessagePromptTemplate.from_template(template)
+    human_template = "{text}"
+    human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
+    chat_prompt = ChatPromptTemplate.from_messages(
+        [system_message_prompt, human_message_prompt]
+    )
+
+    def generate_llm_output(
+        user_input: str, callbacks: List[Any], prompt=chat_prompt
+    ) -> str:
+        answer = llm(
+            prompt.format_prompt(
+                text=user_input,
+            ).to_messages(),
+            callbacks=callbacks,
+        ).content
+        return {"answer": answer}
+
+    return generate_llm_output
+
+
+def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, password):
+    # RAG response
+    general_system_template = """
+    Use the following pieces of context to answer the question at the end.
+    The context contains question-answer pairs and their links from Stackoverflow.
+    You should prefer information from accepted or more upvoted answers.
+    Make sure to rely on information from the answers and not on questions to provide accuate responses.
+    When you find particular answer in the context useful, make sure to cite it in the answer using the link.
+    If you don't know the answer, just say that you don't know, don't try to make up an answer.
+    ----
+    {summaries}
+    ----
+    Each answer you generate should contain a section at the end of links to
+    Stackoverflow questions and answers you found useful, which are described under Source value.
+    You can only use links to StackOverflow questions that are present in the context and always
+    add links to the end of the answer in the style of citations.
+    Generate concise answers with references sources section of links to
+    relevant StackOverflow questions only at the end of the answer.
+    """
+    general_user_template = "Question:```{question}```"
+    messages = [
+        SystemMessagePromptTemplate.from_template(general_system_template),
+        HumanMessagePromptTemplate.from_template(general_user_template),
+    ]
+    qa_prompt = ChatPromptTemplate.from_messages(messages)
+
+    qa_chain = load_qa_with_sources_chain(
+        llm,
+        chain_type="stuff",
+        prompt=qa_prompt,
+    )
+
+    # Vector + Knowledge Graph response
+    kg = Neo4jVector.from_existing_index(
+        embedding=embeddings,
+        url=embeddings_store_url,
+        username=username,
+        password=password,
+        database="neo4j",  # neo4j by default
+        index_name="stackoverflow",  # vector by default
+        text_node_property="body",  # text by default
+        retrieval_query="""
+    WITH node AS question, score AS similarity
+    CALL { with question
+        MATCH (question)<-[:ANSWERS]-(answer)
+        WITH answer
+        ORDER BY answer.is_accepted DESC, answer.score DESC
+        WITH collect(answer)[..2] as answers
+        RETURN reduce(str='', answer IN answers | str +
+            '\n### Answer (Accepted: '+ answer.is_accepted +
+            ' Score: ' + answer.score+ '): '+ answer.body + '\n') as answerTexts
+    }
+    RETURN '##Question: ' + question.title + '\n' + question.body + '\n'
+        + answerTexts AS text, similarity as score, {source: question.link} AS metadata
+    ORDER BY similarity ASC // so that best answers are the last
+    """,
+    )
+
+    kg_qa = RetrievalQAWithSourcesChain(
+        combine_documents_chain=qa_chain,
+        retriever=kg.as_retriever(search_kwargs={"k": 2}),
+        reduce_k_below_max_tokens=False,
+        max_tokens_limit=3375,
+    )
+    return kg_qa
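
For a rough sense of how the new chains.py module composes outside of Streamlit, the snippet below wires the helpers together the same way bot.py now does. It is a sketch, not part of the commit: the connection details are placeholders, and it assumes OPENAI_API_KEY is set and a Neo4j instance with a populated "stackoverflow" index is reachable.

from chains import (
    load_embedding_model,
    load_llm,
    configure_llm_only_chain,
    configure_qa_rag_chain,
)

# Pick an embedding model and an LLM (OpenAI variants here for brevity).
embeddings, dimension = load_embedding_model("openai")
llm = load_llm("gpt-3.5")

# Plain LLM answer with no retrieval; the returned callable takes (text, callbacks).
llm_chain = configure_llm_only_chain(llm)
print(llm_chain("How do I reverse a list in Python?", [])["answer"])

# Vector + knowledge-graph answer over an existing Neo4j vector index
# (placeholder URL and credentials).
rag_chain = configure_qa_rag_chain(
    llm,
    embeddings,
    embeddings_store_url="bolt://localhost:7687",
    username="neo4j",
    password="password",
)
result = rag_chain({"question": "How do I reverse a list in Python?"}, callbacks=[])
print(result["answer"])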

loader.Dockerfile

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ RUN pip install --upgrade -r requirements.txt
 
 COPY loader.py .
 COPY utils.py .
+COPY chains.py .
 COPY images ./images
 
 EXPOSE 8502
