
Commit 54568ab

Update to LlamaIndex v0.6.13 (#18)
* update streamlit demos to llama-index v0.6.13
* update flask demo to use llama_index v0.6.13
1 parent 336256b commit 54568ab

File tree

18 files changed: +213 additions, −104 deletions


flask_react/index_server.py

Lines changed: 7 additions & 7 deletions
@@ -6,13 +6,13 @@
 
 from multiprocessing import Lock
 from multiprocessing.managers import BaseManager
-from llama_index import SimpleDirectoryReader, GPTSimpleVectorIndex, Document, ServiceContext
+from llama_index import SimpleDirectoryReader, GPTVectorStoreIndex, Document, ServiceContext, StorageContext, load_index_from_storage
 
 index = None
 stored_docs = {}
 lock = Lock()
 
-index_name = "./index.json"
+index_name = "./saved_index"
 pkl_name = "stored_documents.pkl"
 
 
@@ -23,10 +23,10 @@ def initialize_index():
     service_context = ServiceContext.from_defaults(chunk_size_limit=512)
     with lock:
         if os.path.exists(index_name):
-            index = GPTSimpleVectorIndex.load_from_disk(index_name, service_context=service_context)
+            index = load_index_from_storage(StorageContext.from_defaults(persist_dir=index_name), service_context=service_context)
         else:
-            index = GPTSimpleVectorIndex([], service_context=service_context)
-            index.save_to_disk(index_name)
+            index = GPTVectorStoreIndex([], service_context=service_context)
+            index.storage_context.persist(persist_dir=index_name)
         if os.path.exists(pkl_name):
             with open(pkl_name, "rb") as f:
                 stored_docs = pickle.load(f)
@@ -35,7 +35,7 @@ def initialize_index():
 def query_index(query_text):
     """Query the global index."""
     global index
-    response = index.query(query_text)
+    response = index.as_query_engine().query(query_text)
     return response
 
 
@@ -51,7 +51,7 @@ def insert_into_index(doc_file_path, doc_id=None):
         stored_docs[document.doc_id] = document.text[0:200]  # only take the first 200 chars
 
         index.insert(document)
-        index.save_to_disk(index_name)
+        index.storage_context.persist(persist_dir=index_name)
 
         with open(pkl_name, "wb") as f:
            pickle.dump(stored_docs, f)
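
Taken together, the index_server.py changes replace the old single-file persistence (save_to_disk / load_from_disk) with the v0.6 StorageContext API and move queries behind a query engine. Below is a minimal sketch of that flow; the persist directory and the inserted document are illustrative, everything else mirrors the calls in the diff above.

# Minimal sketch of the llama-index 0.6.x persistence and query flow used above;
# the persist directory and the example document are illustrative only.
import os

from llama_index import (
    Document,
    GPTVectorStoreIndex,
    ServiceContext,
    StorageContext,
    load_index_from_storage,
)

persist_dir = "./saved_index"
service_context = ServiceContext.from_defaults(chunk_size_limit=512)

if os.path.exists(persist_dir):
    # v0.6 replacement for GPTSimpleVectorIndex.load_from_disk(...)
    storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
    index = load_index_from_storage(storage_context, service_context=service_context)
else:
    # v0.6 replacement for GPTSimpleVectorIndex([], ...) + index.save_to_disk(...)
    index = GPTVectorStoreIndex([], service_context=service_context)
    index.storage_context.persist(persist_dir=persist_dir)

# Inserts still mutate the index in place; persist again to write the update to disk.
index.insert(Document(text="Example text for this sketch."))
index.storage_context.persist(persist_dir=persist_dir)

# Queries now go through a query engine instead of index.query(...).
response = index.as_query_engine().query("What is this example about?")
print(response)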

flask_react/requirements.txt

Lines changed: 3 additions & 3 deletions
@@ -1,5 +1,5 @@
 Flask==2.2.3
 Flask-Cors==3.0.10
-langchain==0.0.128
-llama-index==0.5.4
-PyPDF2==3.0.1
+langchain==0.0.154
+llama-index==0.6.13
+pypdf==3.9.0

streamlit_sql_sandbox/constants.py

Lines changed: 1 addition & 1 deletion
@@ -21,4 +21,4 @@
     "The restaurant KING-KONG had an routine unscheduled inspection on 2023/12/31. "
     "The business achieved a score of 50. We two violations, a high risk "
     "vermin infestation as well as a high risk food holding temperatures."
-)
+)
Lines changed: 4 additions & 2 deletions
@@ -1,4 +1,6 @@
-langchain==0.0.128
-llama-index==0.5.4
+altair==4.2.2
+langchain==0.0.154
+llama-index==0.6.13
 streamlit==1.19.0
 streamlit-chat==0.0.2.2
+transformers==4.29.2

streamlit_sql_sandbox/sql_index.json

Lines changed: 0 additions & 1 deletion
This file was deleted.

streamlit_sql_sandbox/streamlit_demo.py

Lines changed: 100 additions & 48 deletions
@@ -15,13 +15,15 @@
     DEFAULT_BUSINESS_TABLE_DESCRP,
     DEFAULT_VIOLATIONS_TABLE_DESCRP,
     DEFAULT_INSPECTIONS_TABLE_DESCRP,
-    DEFAULT_LC_TOOL_DESCRP
+    DEFAULT_LC_TOOL_DESCRP,
 )
 from utils import get_sql_index_tool, get_llm
 
 
 @st.cache_resource
-def initialize_index(llm_name, model_temperature, table_context_dict, api_key, sql_path=DEFAULT_SQL_PATH):
+def initialize_index(
+    llm_name, model_temperature, table_context_dict, api_key, sql_path=DEFAULT_SQL_PATH
+):
     """Create the GPTSQLStructStoreIndex object."""
     llm = get_llm(llm_name, model_temperature, api_key)
 
@@ -30,78 +32,116 @@ def initialize_index(llm_name, model_temperature, table_context_dict, api_key, s
 
     context_container = None
     if table_context_dict is not None:
-        context_builder = SQLContextContainerBuilder(sql_database, context_dict=table_context_dict)
+        context_builder = SQLContextContainerBuilder(
+            sql_database, context_dict=table_context_dict
+        )
         context_container = context_builder.build_context_container()
-
+
     service_context = ServiceContext.from_defaults(llm_predictor=LLMPredictor(llm=llm))
-    index = GPTSQLStructStoreIndex([],
-                                   sql_database=sql_database,
-                                   sql_context_container=context_container,
-                                   service_context=service_context)
+    index = GPTSQLStructStoreIndex(
+        [],
+        sql_database=sql_database,
+        sql_context_container=context_container,
+        service_context=service_context,
+    )
 
     return index
 
 
 @st.cache_resource
 def initialize_chain(llm_name, model_temperature, lc_descrp, api_key, _sql_index):
     """Create a (rather hacky) custom agent and sql_index tool."""
-    sql_tool = Tool(name="SQL Index",
-                    func=get_sql_index_tool(_sql_index, _sql_index.sql_context_container.context_dict),
-                    description=lc_descrp)
+    sql_tool = Tool(
+        name="SQL Index",
+        func=get_sql_index_tool(
+            _sql_index, _sql_index.sql_context_container.context_dict
+        ),
+        description=lc_descrp,
+    )
 
     llm = get_llm(llm_name, model_temperature, api_key=api_key)
 
     memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
 
-    agent_chain = initialize_agent([sql_tool], llm, agent="chat-conversational-react-description", verbose=True, memory=memory)
+    agent_chain = initialize_agent(
+        [sql_tool],
+        llm,
+        agent="chat-conversational-react-description",
+        verbose=True,
+        memory=memory,
+    )
 
     return agent_chain
 
 
 st.title("🦙 Llama Index SQL Sandbox 🦙")
-st.markdown((
-    "This sandbox uses a sqlite database by default, powered by [Llama Index](https://gpt-index.readthedocs.io/en/latest/index.html) ChatGPT, and LangChain.\n\n"
-    "The database contains information on health violations and inspections at restaurants in San Francisco."
-    "This data is spread across three tables - businesses, inspections, and violations.\n\n"
-    "Using the setup page, you can adjust LLM settings, change the context for the SQL tables, and change the tool description for Langchain."
-    "The other tabs will perform chatbot and text2sql operations.\n\n"
-    "Read more about LlamaIndexes structured data support [here!](https://gpt-index.readthedocs.io/en/latest/guides/tutorials/sql_guide.html)"
-))
+st.markdown(
+    (
+        "This sandbox uses a sqlite database by default, powered by [Llama Index](https://gpt-index.readthedocs.io/en/latest/index.html) ChatGPT, and LangChain.\n\n"
+        "The database contains information on health violations and inspections at restaurants in San Francisco."
+        "This data is spread across three tables - businesses, inspections, and violations.\n\n"
+        "Using the setup page, you can adjust LLM settings, change the context for the SQL tables, and change the tool description for Langchain."
+        "The other tabs will perform chatbot and text2sql operations.\n\n"
+        "Read more about LlamaIndexes structured data support [here!](https://gpt-index.readthedocs.io/en/latest/guides/tutorials/sql_guide.html)"
+    )
+)
 
 
-setup_tab, llama_tab, lc_tab = st.tabs(["Setup", "Llama Index", "Langchain+Llama Index"])
+setup_tab, llama_tab, lc_tab = st.tabs(
+    ["Setup", "Llama Index", "Langchain+Llama Index"]
+)
 
 with setup_tab:
     st.subheader("LLM Setup")
     api_key = st.text_input("Enter your OpenAI API key here", type="password")
-    llm_name = st.selectbox('Which LLM?', ["text-davinci-003", "gpt-3.5-turbo", "gpt-4"])
-    model_temperature = st.slider("LLM Temperature", min_value=0.0, max_value=1.0, step=0.1)
+    llm_name = st.selectbox(
+        "Which LLM?", ["text-davinci-003", "gpt-3.5-turbo", "gpt-4"]
+    )
+    model_temperature = st.slider(
+        "LLM Temperature", min_value=0.0, max_value=1.0, step=0.1
+    )
 
     st.subheader("Table Setup")
-    business_table_descrp = st.text_area("Business table description", value=DEFAULT_BUSINESS_TABLE_DESCRP)
-    violations_table_descrp = st.text_area("Business table description", value=DEFAULT_VIOLATIONS_TABLE_DESCRP)
-    inspections_table_descrp = st.text_area("Business table description", value=DEFAULT_INSPECTIONS_TABLE_DESCRP)
-
-    table_context_dict = {"businesses": business_table_descrp,
-                          "inspections": inspections_table_descrp,
-                          "violations": violations_table_descrp}
-
+    business_table_descrp = st.text_area(
+        "Business table description", value=DEFAULT_BUSINESS_TABLE_DESCRP
+    )
+    violations_table_descrp = st.text_area(
+        "Business table description", value=DEFAULT_VIOLATIONS_TABLE_DESCRP
+    )
+    inspections_table_descrp = st.text_area(
+        "Business table description", value=DEFAULT_INSPECTIONS_TABLE_DESCRP
+    )
+
+    table_context_dict = {
+        "businesses": business_table_descrp,
+        "inspections": inspections_table_descrp,
+        "violations": violations_table_descrp,
+    }
+
    use_table_descrp = st.checkbox("Use table descriptions?", value=True)
    lc_descrp = st.text_area("LangChain Tool Description", value=DEFAULT_LC_TOOL_DESCRP)
 
 with llama_tab:
     st.subheader("Text2SQL with Llama Index")
     if st.button("Initialize Index", key="init_index_1"):
-        st.session_state['llama_index'] = initialize_index(llm_name, model_temperature, table_context_dict if use_table_descrp else None, api_key)
-
+        st.session_state["llama_index"] = initialize_index(
+            llm_name,
+            model_temperature,
+            table_context_dict if use_table_descrp else None,
+            api_key,
+        )
+
     if "llama_index" in st.session_state:
-        query_text = st.text_input("Query:", value="Which restaurant has the most violations?")
+        query_text = st.text_input(
+            "Query:", value="Which restaurant has the most violations?"
+        )
+        use_nl = st.checkbox("Return natural language response?")
         if st.button("Run Query") and query_text:
             with st.spinner("Getting response..."):
                 try:
-                    response = st.session_state['llama_index'].query(query_text)
+                    response = st.session_state["llama_index"].as_query_engine(synthesize_response=use_nl).query(query_text)
                     response_text = str(response)
-                    response_sql = response.extra_info['sql_query']
+                    response_sql = response.extra_info["sql_query"]
                 except Exception as e:
                     response_text = "Error running SQL Query."
                     response_sql = str(e)
@@ -119,19 +159,31 @@ def initialize_chain(llm_name, model_temperature, lc_descrp, api_key, _sql_index
     st.subheader("Langchain + Llama Index SQL Demo")
 
     if st.button("Initialize Agent"):
-        st.session_state['llama_index'] = initialize_index(llm_name, model_temperature, table_context_dict if use_table_descrp else None, api_key)
-        st.session_state['lc_agent'] = initialize_chain(llm_name, model_temperature, lc_descrp, api_key, st.session_state['llama_index'])
-        st.session_state['chat_history'] = []
-
-    model_input = st.text_input("Message:", value="Which restaurant has the most violations?")
-    if 'lc_agent' in st.session_state and st.button("Send"):
+        st.session_state["llama_index"] = initialize_index(
+            llm_name,
+            model_temperature,
+            table_context_dict if use_table_descrp else None,
+            api_key,
+        )
+        st.session_state["lc_agent"] = initialize_chain(
+            llm_name,
+            model_temperature,
+            lc_descrp,
+            api_key,
+            st.session_state["llama_index"],
+        )
+        st.session_state["chat_history"] = []
+
+    model_input = st.text_input(
+        "Message:", value="Which restaurant has the most violations?"
+    )
+    if "lc_agent" in st.session_state and st.button("Send"):
         model_input = "User: " + model_input
-        st.session_state['chat_history'].append(model_input)
+        st.session_state["chat_history"].append(model_input)
         with st.spinner("Getting response..."):
-            response = st.session_state['lc_agent'].run(input=model_input)
-            st.session_state['chat_history'].append(response)
+            response = st.session_state["lc_agent"].run(input=model_input)
+            st.session_state["chat_history"].append(response)
 
-    if 'chat_history' in st.session_state:
-        for msg in st.session_state['chat_history']:
+    if "chat_history" in st.session_state:
+        for msg in st.session_state["chat_history"]:
             st_message(msg.split("User: ")[-1], is_user="User: " in msg)
-
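
The demo's query path now builds a query engine from the SQL index: synthesize_response controls whether the raw SQL result gets an extra natural-language pass, and the executed SQL is still exposed on response.extra_info. Below is a minimal sketch of that call pattern under stated assumptions; the sqlite path and the question are placeholders, the calls themselves mirror the diff above.

# Sketch of the 0.6.x text-to-SQL call pattern the demo migrates to.
# The sqlite path and the question are placeholders.
from langchain import OpenAI
from llama_index import GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
from sqlalchemy import create_engine

sql_database = SQLDatabase(create_engine("sqlite:///example.db"))  # placeholder path

service_context = ServiceContext.from_defaults(
    llm_predictor=LLMPredictor(llm=OpenAI(temperature=0, model_name="text-davinci-003"))
)
index = GPTSQLStructStoreIndex(
    [], sql_database=sql_database, service_context=service_context
)

# synthesize_response=True adds a natural-language answer on top of the raw SQL result.
query_engine = index.as_query_engine(synthesize_response=True)
response = query_engine.query("Which restaurant has the most violations?")

print(str(response))                     # the answer
print(response.extra_info["sql_query"])  # the SQL that was executed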

streamlit_sql_sandbox/utils.py

Lines changed: 6 additions & 4 deletions
@@ -5,20 +5,22 @@
 
 def get_sql_index_tool(sql_index, table_context_dict):
     table_context_str = "\n".join(table_context_dict.values())
+
     def run_sql_index_query(query_text):
         try:
-            response = sql_index.query(query_text)
+            response = sql_index.as_query_engine(synthesize_response=False).query(query_text)
         except Exception as e:
             return f"Error running SQL {e}.\nNot able to retrieve answer."
         text = str(response)
-        sql = response.extra_info['sql_query']
+        sql = response.extra_info["sql_query"]
         return f"Here are the details on the SQL table: {table_context_str}\nSQL Query Used: {sql}\nSQL Result: {text}\n"
-        #return f"SQL Query Used: {sql}\nSQL Result: {text}\n"
+        # return f"SQL Query Used: {sql}\nSQL Result: {text}\n"
+
     return run_sql_index_query
 
 
 def get_llm(llm_name, model_temperature, api_key):
-    os.environ['OPENAI_API_KEY'] = api_key
+    os.environ["OPENAI_API_KEY"] = api_key
     if llm_name == "text-davinci-003":
         return OpenAI(temperature=model_temperature, model_name=llm_name)
     else:
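
get_sql_index_tool returns a plain function, which the demo then wraps in a LangChain Tool and hands to a conversational agent (see initialize_chain in the streamlit_demo.py diff). Below is a rough sketch of that wiring under stated assumptions: the stubbed query function stands in for the real query-engine call above, and the tool description and model are placeholders.

# Sketch of wrapping the SQL query helper as a LangChain tool for the agent tab.
# The stub body, tool description, and model choice are placeholders.
from langchain.agents import Tool, initialize_agent
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory


def run_sql_index_query(query_text: str) -> str:
    # In the demo this calls sql_index.as_query_engine(synthesize_response=False).query(...)
    return "SQL Query Used: ...\nSQL Result: ...\n"


sql_tool = Tool(
    name="SQL Index",
    func=run_sql_index_query,
    description="Useful for answering questions about the inspections database.",  # placeholder
)

memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
agent_chain = initialize_agent(
    [sql_tool],
    ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo"),
    agent="chat-conversational-react-description",
    verbose=True,
    memory=memory,
)

print(agent_chain.run(input="Which restaurant has the most violations?"))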

streamlit_term_definition/constants.py

Lines changed: 4 additions & 4 deletions
@@ -13,7 +13,7 @@
     "---------------------\n"
     "{context_str}"
     "\n---------------------\n"
-    "Given the context information answer the following question "
+    "Given the context information, directly answer the following question "
     "(if you don't know the answer, use the best of your knowledge): {query_str}\n"
 )
 TEXT_QA_TEMPLATE = QuestionAnswerPrompt(DEFAULT_TEXT_QA_PROMPT_TMPL)
@@ -29,6 +29,7 @@
     "------------\n"
     "Given the new context and using the best of your knowledge, improve the existing answer. "
     "If you can't improve the existing answer, just repeat it again. "
+    "Do not include un-needed or un-helpful information that is shown in the new context. "
     "Do not mention that you've read the above context."
 )
 DEFAULT_REFINE_PROMPT = RefinePrompt(DEFAULT_REFINE_PROMPT_TMPL)
@@ -44,6 +45,7 @@
         "------------\n"
         "Given the new context and using the best of your knowledge, improve the existing answer. "
         "If you can't improve the existing answer, just repeat it again. "
+        "Do not include un-needed or un-helpful information that is shown in the new context. "
         "Do not mention that you've read the above context."
     ),
 ]
@@ -56,9 +58,7 @@
     default_prompt=DEFAULT_REFINE_PROMPT.get_langchain_prompt(),
     conditionals=[(is_chat_model, CHAT_REFINE_PROMPT.get_langchain_prompt())],
 )
-REFINE_TEMPLATE = RefinePrompt(
-    langchain_prompt_selector=DEFAULT_REFINE_PROMPT_SEL_LC
-)
+REFINE_TEMPLATE = RefinePrompt(langchain_prompt_selector=DEFAULT_REFINE_PROMPT_SEL_LC)
 
 DEFAULT_TERM_STR = (
     "Make a list of terms and definitions that are defined in the context, "

streamlit_term_definition/initial_index/docstore.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.
