@@ -1,22 +1,25 @@
 import os
-from typing import List, Any
 
 import streamlit as st
 from streamlit.logger import get_logger
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.vectorstores.neo4j_vector import Neo4jVector
-
-from langchain.chat_models import ChatOpenAI, ChatOllama
-from langchain.chains import RetrievalQAWithSourcesChain
-from langchain.chains.qa_with_sources import load_qa_with_sources_chain
 from langchain.prompts.chat import (
     ChatPromptTemplate,
     SystemMessagePromptTemplate,
     HumanMessagePromptTemplate,
 )
 from langchain.graphs import Neo4jGraph
 from dotenv import load_dotenv
-from utils import extract_title_and_question, load_embedding_model
+from utils import (
+    extract_title_and_question,
+    create_vector_index,
+)
+from chains import (
+    load_embedding_model,
+    load_llm,
+    configure_llm_only_chain,
+    configure_qa_rag_chain,
+)
 
 load_dotenv(".env")
 
@@ -33,19 +36,10 @@
 
 # if Neo4j is local, you can go to http://localhost:7474/ to browse the database
 neo4j_graph = Neo4jGraph(url=url, username=username, password=password)
-
-
-def create_vector_index(dimension: int) -> None:
-    index_query = "CALL db.index.vector.createNodeIndex('stackoverflow', 'Question', 'embedding', $dimension, 'cosine')"
-    try:
-        neo4j_graph.query(index_query, {"dimension": dimension})
-    except:  # Already exists
-        pass
-    index_query = "CALL db.index.vector.createNodeIndex('top_answers', 'Answer', 'embedding', $dimension, 'cosine')"
-    try:
-        neo4j_graph.query(index_query, {"dimension": dimension})
-    except:  # Already exists
-        pass
+embeddings, dimension = load_embedding_model(
+    embedding_model_name, config={ollama_base_url: ollama_base_url}, logger=logger
+)
+create_vector_index(neo4j_graph, dimension)
 
 
 class StreamHandler(BaseCallbackHandler):
@@ -58,142 +52,11 @@ def on_llm_new_token(self, token: str, **kwargs) -> None:
         self.container.markdown(self.text)
 
 
-embeddings, dimension = load_embedding_model(
-    embedding_model_name, config={ollama_base_url: ollama_base_url}, logger=logger
-)
-
-create_vector_index(dimension)
-
-if llm_name == "gpt-4":
-    llm = ChatOpenAI(temperature=0, model_name="gpt-4", streaming=True)
-    logger.info("LLM: Using GPT-4")
-elif llm_name == "gpt-3.5":
-    llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
-    logger.info("LLM: Using GPT-3.5 Turbo")
-elif len(llm_name):
-    llm = ChatOllama(
-        temperature=0,
-        base_url=ollama_base_url,
-        model=llm_name,
-        streaming=True,
-        top_k=10,  # A higher value (100) will give more diverse answers, while a lower value (10) will be more conservative.
-        top_p=0.3,  # Higher value (0.95) will lead to more diverse text, while a lower value (0.5) will generate more focused text.
-        num_ctx=3072,  # Sets the size of the context window used to generate the next token.
-    )
-    logger.info(f"LLM: Using Ollama ({llm_name})")
-else:
-    llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
-    logger.info("LLM: Using GPT-3.5 Turbo")
-
-# LLM only response
-template = """
-You are a helpful assistant that helps a support agent with answering programming questions.
-If you don't know the answer, just say that you don't know, don't try to make up an answer.
-"""
-system_message_prompt = SystemMessagePromptTemplate.from_template(template)
-human_template = "{text}"
-human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
-chat_prompt = ChatPromptTemplate.from_messages(
-    [system_message_prompt, human_message_prompt]
-)
-
-
-def generate_llm_output(
-    user_input: str, callbacks: List[Any], prompt=chat_prompt
-) -> str:
-    answer = llm(
-        prompt.format_prompt(
-            text=user_input,
-        ).to_messages(),
-        callbacks=callbacks,
-    ).content
-    return {"answer": answer}
-
-
-# Vector response
-neo4j_db = Neo4jVector.from_existing_index(
-    embedding=embeddings,
-    url=url,
-    username=username,
-    password=password,
-    database="neo4j",  # neo4j by default
-    index_name="top_answers",  # vector by default
-    text_node_property="body",  # text by default
-    retrieval_query="""
-    OPTIONAL MATCH (node)-[:ANSWERS]->(question)
-    RETURN 'Question: ' + question.title + '\n' + question.body + '\nAnswer: ' +
-    coalesce(node.body,"") AS text, score, {source:question.link} AS metadata
-    ORDER BY score ASC // so that best answers are the last
-    """,
-)
-
-general_system_template = """
-Use the following pieces of context to answer the question at the end.
-The context contains question-answer pairs and their links from Stackoverflow.
-You should prefer information from accepted or more upvoted answers.
-Make sure to rely on information from the answers and not on questions to provide accurate responses.
-When you find a particular answer in the context useful, make sure to cite it in the answer using the link.
-If you don't know the answer, just say that you don't know, don't try to make up an answer.
-----
-{summaries}
-----
-Each answer you generate should contain a section at the end of links to
-Stackoverflow questions and answers you found useful, which are described under Source value.
-You can only use links to StackOverflow questions that are present in the context and always
-add links to the end of the answer in the style of citations.
-Generate concise answers with references sources section of links to
-relevant StackOverflow questions only at the end of the answer.
-"""
-general_user_template = "Question:```{question}```"
-messages = [
-    SystemMessagePromptTemplate.from_template(general_system_template),
-    HumanMessagePromptTemplate.from_template(general_user_template),
-]
-qa_prompt = ChatPromptTemplate.from_messages(messages)
-
-qa_chain = load_qa_with_sources_chain(
-    llm,
-    chain_type="stuff",
-    prompt=qa_prompt,
-)
-qa = RetrievalQAWithSourcesChain(
-    combine_documents_chain=qa_chain,
-    retriever=neo4j_db.as_retriever(search_kwargs={"k": 2}),
-    reduce_k_below_max_tokens=True,
-    max_tokens_limit=3375,
-)
-
-# Vector + Knowledge Graph response
-kg = Neo4jVector.from_existing_index(
-    embedding=embeddings,
-    url=url,
-    username=username,
-    password=password,
-    database="neo4j",  # neo4j by default
-    index_name="stackoverflow",  # vector by default
-    text_node_property="body",  # text by default
-    retrieval_query="""
-    WITH node AS question, score AS similarity
-    CALL { with question
-        MATCH (question)<-[:ANSWERS]-(answer)
-        WITH answer
-        ORDER BY answer.is_accepted DESC, answer.score DESC
-        WITH collect(answer)[..2] as answers
-        RETURN reduce(str='', answer IN answers | str +
-            '\n### Answer (Accepted: '+ answer.is_accepted +
-            ' Score: ' + answer.score+ '): '+ answer.body + '\n') as answerTexts
-    }
-    RETURN '##Question: ' + question.title + '\n' + question.body + '\n'
-        + answerTexts AS text, similarity as score, {source: question.link} AS metadata
-    ORDER BY similarity ASC // so that best answers are the last
-    """,
-)
+llm = load_llm(llm_name, logger=logger, config={"ollama_base_url": ollama_base_url})
 
-kg_qa = RetrievalQAWithSourcesChain(
-    combine_documents_chain=qa_chain,
-    retriever=kg.as_retriever(search_kwargs={"k": 2}),
-    reduce_k_below_max_tokens=False,
-    max_tokens_limit=3375,
+llm_chain = configure_llm_only_chain(llm)
+rag_chain = configure_qa_rag_chain(
+    llm, embeddings, embeddings_store_url=url, username=username, password=password
 )
 
 # Streamlit UI
@@ -280,11 +143,9 @@ def mode_select() -> str:
 
 name = mode_select()
 if name == "LLM only" or name == "Disabled":
-    output_function = generate_llm_output
-elif name == "Vector":
-    output_function = qa
+    output_function = llm_chain
 elif name == "Vector + Graph" or name == "Enabled":
-    output_function = kg_qa
+    output_function = rag_chain
 
 
 def generate_ticket():
@@ -337,7 +198,7 @@ def generate_ticket():
             HumanMessagePromptTemplate.from_template("{text}"),
         ]
     )
-    llm_response = generate_llm_output(
+    llm_response = llm_chain(
         f"Here's the question to rewrite in the expected format: ```{q_prompt}```",
         [],
         chat_prompt,
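
Note: this diff shows only the call sites of the factored-out helpers; chains.py and the updated utils.py are not included. As a rough sketch of what the new module presumably contains, reconstructed from the inline logic removed above (the names load_llm and configure_llm_only_chain and their signatures come from the call sites in the diff; the bodies below are an assumption, not the actual contents of chains.py):

# Hypothetical sketch of the factored-out helpers, based on the removed inline code.
from typing import Any, List

from langchain.chat_models import ChatOpenAI, ChatOllama
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)


def load_llm(llm_name: str, logger, config: dict):
    # Same model selection the diff removes from the app: GPT-4, GPT-3.5, or an Ollama model.
    if llm_name == "gpt-4":
        logger.info("LLM: Using GPT-4")
        return ChatOpenAI(temperature=0, model_name="gpt-4", streaming=True)
    if llm_name == "gpt-3.5":
        logger.info("LLM: Using GPT-3.5 Turbo")
        return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
    if len(llm_name):
        logger.info(f"LLM: Using Ollama ({llm_name})")
        return ChatOllama(
            temperature=0,
            base_url=config["ollama_base_url"],
            model=llm_name,
            streaming=True,
            top_k=10,
            top_p=0.3,
            num_ctx=3072,
        )
    logger.info("LLM: Using GPT-3.5 Turbo")
    return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)


def configure_llm_only_chain(llm):
    # Wraps the removed generate_llm_output: takes (user_input, callbacks, prompt)
    # and returns {"answer": ...}, which matches how llm_chain is called in the diff.
    system_prompt = SystemMessagePromptTemplate.from_template(
        "You are a helpful assistant that helps a support agent with answering "
        "programming questions. If you don't know the answer, just say that you "
        "don't know, don't try to make up an answer."
    )
    human_prompt = HumanMessagePromptTemplate.from_template("{text}")
    chat_prompt = ChatPromptTemplate.from_messages([system_prompt, human_prompt])

    def generate_llm_output(user_input: str, callbacks: List[Any], prompt=chat_prompt):
        answer = llm(
            prompt.format_prompt(text=user_input).to_messages(), callbacks=callbacks
        ).content
        return {"answer": answer}

    return generate_llm_output

By the same reasoning, configure_qa_rag_chain presumably wraps the removed Neo4jVector.from_existing_index and RetrievalQAWithSourcesChain setup, and create_vector_index now receives the Neo4jGraph instance explicitly instead of relying on a module-level global.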