Skip to content

Commit f6cb57a

Browse files
Add streamlit sql sandbox demo (#5)
* Add streamlit SQL sandbox * Update gitignore, remove files
1 parent c705e9b commit f6cb57a

File tree

15 files changed

+216
-8
lines changed

15 files changed

+216
-8
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
.envrc
22
.direnv
3-
.streamlit/secrets.toml
3+
*/.streamlit/secrets.toml
44
.mypy_cache
5+
__pycache__
56
node_modules
67
build
78
index.json

README.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,20 @@ There are two main example folders
3434
![react_frontend screenshot](./flask_react/react_frontend.png)
3535

3636

37-
- streamlit (runs on localhost:8501)
37+
- streamlit_vector (runs on localhost:8501)
3838
- `streamlit run streamlit_demo.py`
39-
- creates a simple UI using streamlit
39+
- creates a simple UI using streamlit and GPTSimpleVectorIndex
4040
- loads text from the documents folder (using `st.cache_resource`, so it only loads once)
4141
- provides an input text-box and a button to run the query
4242
- the string response is displayed after it finishes
4343
- want to see this example in action? Check it out [here](https://llama-index.streamlit.app/)
4444

45+
- streamlit_sql_sandbox (runs on localhost:8501)
46+
- `streamlit run streamlit_demo.py`
47+
- creates a streamlit app using a local SQL database about restaurant inspections in San Francisco ([data sample](https://docs.google.com/spreadsheets/d/1Ag5DBIviYiuRrt2yr3nXmbPFV-FOg5fDH5SM3ZEDnpw/edit#gid=780513932))
48+
- The "Setup" tab allows you to configure various LLM and LLama Index settings
49+
- The "Llama Index" tab demos some basic Text2SQL capabilities using only Llama Index
50+
- The "Langchain+Llama Index" tab uses a custom langchain agent, and uses the SQL index from Llama Index as a tool during conversations.
4551

4652
## Docker
4753
Each example contains a `Dockerfile`. You can run `docker build -t my_tag_name .` to build a python3.11-slim docker image inside your desired folder. It ends up being about 600MB-900MB depending on the example.

flask_react/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
Flask==2.2.3
22
Flask-Cors==3.0.10
3-
langchain==0.0.115
4-
llama-index==0.4.30
3+
langchain==0.0.123
4+
llama-index==0.4.39
55
PyPDF2==3.0.1

streamlit/requirements.txt

Lines changed: 0 additions & 3 deletions
This file was deleted.
File renamed without changes.

streamlit_sql_sandbox/constants.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
DEFAULT_SQL_PATH = "sqlite:///sfscores.sqlite"
2+
DEFAULT_BUSINESS_TABLE_DESCRP = (
3+
"This table gives information on the IDs, addresses, and other location "
4+
"information for several restaurants in San Francisco. This table will "
5+
"need to be referenced when users ask about specific businesses."
6+
)
7+
DEFAULT_VIOLATIONS_TABLE_DESCRP = (
8+
"This table gives information on which business IDs have recorded health violations, "
9+
"including the date, risk, and description of each violation. The user may query "
10+
"about specific businesses, whose names can be found by mapping the business_id "
11+
"to the 'businesses' table."
12+
)
13+
DEFAULT_INSPECTIONS_TABLE_DESCRP = (
14+
"This table gives information on when each business ID was inspected, including "
15+
"the score, date, and type of inspection. The user may query about specific "
16+
"businesses, whose names can be found by mapping the business_id to the 'businesses' table."
17+
)
18+
DEFAULT_LC_TOOL_DESCRP = "Useful for when you want to answer queries about violations and inspections of businesses."
19+
20+
DEFAULT_INGEST_DOCUMENT = (
21+
"The restaurant KING-KONG had an routine unscheduled inspection on 2023/12/31. "
22+
"The business achieved a score of 50. We two violations, a high risk "
23+
"vermin infestation as well as a high risk food holding temperatures."
24+
)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
langchain==0.0.123
2+
llama-index==0.4.39
3+
streamlit==1.19.0
4+
streamlit-chat==0.0.2.2
9.19 MB
Binary file not shown.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"index_struct_id": "b52fad59-0c00-4392-b775-f9cd3fdb6deb", "docstore": {"docs": {"b52fad59-0c00-4392-b775-f9cd3fdb6deb": {"text": null, "doc_id": "b52fad59-0c00-4392-b775-f9cd3fdb6deb", "embedding": null, "doc_hash": "08a14830cef184731c6b6a0bdd67fa351d923556941aa99027b276bd839a07a4", "extra_info": null, "context_dict": {}, "__type__": "sql"}}, "ref_doc_info": {"b52fad59-0c00-4392-b775-f9cd3fdb6deb": {"doc_hash": "08a14830cef184731c6b6a0bdd67fa351d923556941aa99027b276bd839a07a4"}}}, "sql_context_container": {"context_dict": {"violations": "Schema of table violations:\nTable 'violations' has columns: business_id (TEXT), date (TEXT), ViolationTypeID (TEXT), risk_category (TEXT), description (TEXT) and foreign keys: .\nContext of table violations:\nThis table gives information on which business IDs have recorded health violations, including the date, risk, and description of each violation. The user may query about specific businesses, whose names can be found by mapping the business_id to the 'businesses' table.", "businesses": "Schema of table businesses:\nTable 'businesses' has columns: business_id (INTEGER), name (VARCHAR(64)), address (VARCHAR(50)), city (VARCHAR(23)), postal_code (VARCHAR(9)), latitude (FLOAT), longitude (FLOAT), phone_number (BIGINT), TaxCode (VARCHAR(4)), business_certificate (INTEGER), application_date (DATE), owner_name (VARCHAR(99)), owner_address (VARCHAR(74)), owner_city (VARCHAR(22)), owner_state (VARCHAR(14)), owner_zip (VARCHAR(15)) and foreign keys: .\nContext of table businesses:\nThis table gives information on the IDs, addresses, and other location information for several restaruants in San Fransisco. 
This table will need to be referenced when users ask about specific bussinesses.", "inspections": "Schema of table inspections:\nTable 'inspections' has columns: business_id (TEXT), Score (INTEGER), date (TEXT), type (VARCHAR(33)) and foreign keys: .\nContext of table inspections:\nThis table gives information on when each bussiness ID was inspected, including the score, date, and type of inspection. The user may query about specific businesses, whose names can be found by mapping the business_id to the 'businesses' table."}, "context_str": null}}
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import os
2+
import streamlit as st
3+
from streamlit_chat import message as st_message
4+
from sqlalchemy import create_engine
5+
6+
from langchain.agents import Tool, initialize_agent
7+
from langchain.chains.conversation.memory import ConversationBufferMemory
8+
9+
from llama_index import GPTSQLStructStoreIndex, LLMPredictor
10+
from llama_index import SQLDatabase as llama_SQLDatabase
11+
from llama_index.indices.struct_store import SQLContextContainerBuilder
12+
13+
from constants import (
14+
DEFAULT_SQL_PATH,
15+
DEFAULT_BUSINESS_TABLE_DESCRP,
16+
DEFAULT_VIOLATIONS_TABLE_DESCRP,
17+
DEFAULT_INSPECTIONS_TABLE_DESCRP,
18+
DEFAULT_LC_TOOL_DESCRP
19+
)
20+
from utils import get_sql_index_tool, get_llm
21+
22+
# NOTE(review): local testing only -- never deploy with the key hardcoded.
# Provide the key via .streamlit/secrets.toml (key name: openai_api_key).
# Streamlit secrets docs: https://docs.streamlit.io/
os.environ["OPENAI_API_KEY"] = st.secrets["openai_api_key"]
26+
27+
28+
@st.cache_resource
def initialize_index(llm_name, model_temperature, table_context_dict, sql_path=DEFAULT_SQL_PATH):
    """Build (and cache) a GPTSQLStructStoreIndex over the sandbox database.

    Args:
        llm_name: model identifier passed through to ``get_llm``.
        model_temperature: sampling temperature for the LLM.
        table_context_dict: optional mapping of table name -> context string;
            when ``None`` the index is built without a context container.
        sql_path: SQLAlchemy connection string for the database.

    Returns:
        The constructed ``GPTSQLStructStoreIndex``.
    """
    database = llama_SQLDatabase(create_engine(sql_path))

    if table_context_dict is None:
        context_container = None
    else:
        builder = SQLContextContainerBuilder(database, context_dict=table_context_dict)
        context_container = builder.build_context_container()

    return GPTSQLStructStoreIndex(
        [],
        sql_database=database,
        sql_context_container=context_container,
        llm_predictor=LLMPredictor(llm=get_llm(llm_name, model_temperature)),
    )
47+
48+
49+
@st.cache_resource
def initialize_chain(llm_name, model_temperature, lc_descrp, _sql_index):
    """Build (and cache) a Langchain conversational agent wrapping the SQL index.

    Args:
        llm_name: model identifier passed through to ``get_llm``.
        model_temperature: sampling temperature for the LLM.
        lc_descrp: tool description shown to the agent.
        _sql_index: the Llama Index SQL index (underscore-prefixed so
            ``st.cache_resource`` skips hashing it).

    Returns:
        The initialized Langchain agent chain.
    """
    query_fn = get_sql_index_tool(_sql_index, _sql_index.sql_context_container.context_dict)
    tools = [Tool(name="SQL Index", func=query_fn, description=lc_descrp)]
    chat_memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    return initialize_agent(
        tools,
        get_llm(llm_name, model_temperature),
        agent="chat-conversational-react-description",
        verbose=True,
        memory=chat_memory,
    )
63+
64+
65+
# --- UI ---------------------------------------------------------------------
st.title("🦙 Llama Index SQL Sandbox 🦙")
st.markdown((
    "This sandbox uses a sqlite database by default, containing information on health violations and inspections at restaurants in San Francisco.\n\n"
    "The database contains three tables - businesses, inspections, and violations.\n\n"
    "Using the setup page, you can adjust LLM settings, change the context for the SQL tables, and change the tool description for Langchain."
))


setup_tab, llama_tab, lc_tab = st.tabs(["Setup", "Llama Index", "Langchain+Llama Index"])

with setup_tab:
    st.subheader("LLM Setup")
    model_temperature = st.slider("LLM Temperature", min_value=0.0, max_value=1.0, step=0.1)
    llm_name = st.selectbox('Which LLM?', ["text-davinci-003", "gpt-3.5-turbo", "gpt-4"])

    st.subheader("Table Setup")
    # BUG FIX: all three text areas previously used the same label
    # ("Business table description"), which both mislabeled the violations and
    # inspections inputs and triggers Streamlit's DuplicateWidgetID error,
    # since widget identity is derived from the label when no key is given.
    business_table_descrp = st.text_area("Business table description", value=DEFAULT_BUSINESS_TABLE_DESCRP)
    violations_table_descrp = st.text_area("Violations table description", value=DEFAULT_VIOLATIONS_TABLE_DESCRP)
    inspections_table_descrp = st.text_area("Inspections table description", value=DEFAULT_INSPECTIONS_TABLE_DESCRP)

    table_context_dict = {"businesses": business_table_descrp,
                          "inspections": inspections_table_descrp,
                          "violations": violations_table_descrp}

    use_table_descrp = st.checkbox("Use table descriptions?", value=True)
    lc_descrp = st.text_area("LangChain Tool Description", value=DEFAULT_LC_TOOL_DESCRP)

with llama_tab:
    st.subheader("Text2SQL with Llama Index")
    if st.button("Initialize Index", key="init_index_1"):
        st.session_state['llama_index'] = initialize_index(llm_name, model_temperature, table_context_dict if use_table_descrp else None)

    if "llama_index" in st.session_state:
        query_text = st.text_input("Query:")
        if st.button("Run Query") and query_text:
            with st.spinner("Getting response..."):
                try:
                    response = st.session_state['llama_index'].query(query_text)
                    response_text = str(response)
                    response_sql = response.extra_info['sql_query']
                except Exception as e:
                    # Surface the failure in the UI instead of crashing the app;
                    # the raw error is shown where the SQL would normally appear.
                    response_text = "Error running SQL Query."
                    response_sql = str(e)

            col1, col2 = st.columns(2)
            with col1:
                st.text("SQL Result:")
                st.markdown(response_text)

            with col2:
                st.text("SQL Query:")
                st.markdown(response_sql)

with lc_tab:
    st.subheader("Langchain + Llama Index SQL Demo")

    if st.button("Initialize Agent"):
        st.session_state['llama_index'] = initialize_index(llm_name, model_temperature, table_context_dict if use_table_descrp else None)
        st.session_state['lc_agent'] = initialize_chain(llm_name, model_temperature, lc_descrp, st.session_state['llama_index'])
        st.session_state['chat_history'] = []

    model_input = st.text_input("Message:")
    if 'lc_agent' in st.session_state and st.button("Send"):
        # The "User: " prefix marks user turns so the history renderer below
        # can tell user messages from agent responses.
        model_input = "User: " + model_input
        st.session_state['chat_history'].append(model_input)
        with st.spinner("Getting response..."):
            response = st.session_state['lc_agent'].run(input=model_input)
            st.session_state['chat_history'].append(response)

    if 'chat_history' in st.session_state:
        for msg in st.session_state['chat_history']:
            st_message(msg.split("User: ")[-1], is_user="User: " in msg)
137+

0 commit comments

Comments
 (0)