1515 DEFAULT_BUSINESS_TABLE_DESCRP ,
1616 DEFAULT_VIOLATIONS_TABLE_DESCRP ,
1717 DEFAULT_INSPECTIONS_TABLE_DESCRP ,
18- DEFAULT_LC_TOOL_DESCRP
18+ DEFAULT_LC_TOOL_DESCRP ,
1919)
2020from utils import get_sql_index_tool , get_llm
2121
2222
2323@st .cache_resource
24- def initialize_index (llm_name , model_temperature , table_context_dict , api_key , sql_path = DEFAULT_SQL_PATH ):
24+ def initialize_index (
25+ llm_name , model_temperature , table_context_dict , api_key , sql_path = DEFAULT_SQL_PATH
26+ ):
2527 """Create the GPTSQLStructStoreIndex object."""
2628 llm = get_llm (llm_name , model_temperature , api_key )
2729
@@ -30,78 +32,116 @@ def initialize_index(llm_name, model_temperature, table_context_dict, api_key, s
3032
3133 context_container = None
3234 if table_context_dict is not None :
33- context_builder = SQLContextContainerBuilder (sql_database , context_dict = table_context_dict )
35+ context_builder = SQLContextContainerBuilder (
36+ sql_database , context_dict = table_context_dict
37+ )
3438 context_container = context_builder .build_context_container ()
35-
39+
3640 service_context = ServiceContext .from_defaults (llm_predictor = LLMPredictor (llm = llm ))
37- index = GPTSQLStructStoreIndex ([],
38- sql_database = sql_database ,
39- sql_context_container = context_container ,
40- service_context = service_context )
41+ index = GPTSQLStructStoreIndex (
42+ [],
43+ sql_database = sql_database ,
44+ sql_context_container = context_container ,
45+ service_context = service_context ,
46+ )
4147
4248 return index
4349
4450
4551@st .cache_resource
4652def initialize_chain (llm_name , model_temperature , lc_descrp , api_key , _sql_index ):
4753 """Create a (rather hacky) custom agent and sql_index tool."""
48- sql_tool = Tool (name = "SQL Index" ,
49- func = get_sql_index_tool (_sql_index , _sql_index .sql_context_container .context_dict ),
50- description = lc_descrp )
54+ sql_tool = Tool (
55+ name = "SQL Index" ,
56+ func = get_sql_index_tool (
57+ _sql_index , _sql_index .sql_context_container .context_dict
58+ ),
59+ description = lc_descrp ,
60+ )
5161
5262 llm = get_llm (llm_name , model_temperature , api_key = api_key )
5363
5464 memory = ConversationBufferMemory (memory_key = "chat_history" , return_messages = True )
5565
56- agent_chain = initialize_agent ([sql_tool ], llm , agent = "chat-conversational-react-description" , verbose = True , memory = memory )
66+ agent_chain = initialize_agent (
67+ [sql_tool ],
68+ llm ,
69+ agent = "chat-conversational-react-description" ,
70+ verbose = True ,
71+ memory = memory ,
72+ )
5773
5874 return agent_chain
5975
6076
6177st .title ("🦙 Llama Index SQL Sandbox 🦙" )
62- st .markdown ((
63- "This sandbox uses a sqlite database by default, powered by [Llama Index](https://gpt-index.readthedocs.io/en/latest/index.html) ChatGPT, and LangChain.\n \n "
64- "The database contains information on health violations and inspections at restaurants in San Francisco."
65- "This data is spread across three tables - businesses, inspections, and violations.\n \n "
66- "Using the setup page, you can adjust LLM settings, change the context for the SQL tables, and change the tool description for Langchain."
67- "The other tabs will perform chatbot and text2sql operations.\n \n "
68- "Read more about LlamaIndexes structured data support [here!](https://gpt-index.readthedocs.io/en/latest/guides/tutorials/sql_guide.html)"
69- ))
78+ st .markdown (
79+ (
80+ "This sandbox uses a sqlite database by default, powered by [Llama Index](https://gpt-index.readthedocs.io/en/latest/index.html) ChatGPT, and LangChain.\n \n "
81+ "The database contains information on health violations and inspections at restaurants in San Francisco."
82+ "This data is spread across three tables - businesses, inspections, and violations.\n \n "
83+ "Using the setup page, you can adjust LLM settings, change the context for the SQL tables, and change the tool description for Langchain."
84+ "The other tabs will perform chatbot and text2sql operations.\n \n "
85+ "Read more about LlamaIndexes structured data support [here!](https://gpt-index.readthedocs.io/en/latest/guides/tutorials/sql_guide.html)"
86+ )
87+ )
7088
7189
72- setup_tab , llama_tab , lc_tab = st .tabs (["Setup" , "Llama Index" , "Langchain+Llama Index" ])
90+ setup_tab , llama_tab , lc_tab = st .tabs (
91+ ["Setup" , "Llama Index" , "Langchain+Llama Index" ]
92+ )
7393
7494with setup_tab :
7595 st .subheader ("LLM Setup" )
7696 api_key = st .text_input ("Enter your OpenAI API key here" , type = "password" )
77- llm_name = st .selectbox ('Which LLM?' , ["text-davinci-003" , "gpt-3.5-turbo" , "gpt-4" ])
78- model_temperature = st .slider ("LLM Temperature" , min_value = 0.0 , max_value = 1.0 , step = 0.1 )
97+ llm_name = st .selectbox (
98+ "Which LLM?" , ["text-davinci-003" , "gpt-3.5-turbo" , "gpt-4" ]
99+ )
100+ model_temperature = st .slider (
101+ "LLM Temperature" , min_value = 0.0 , max_value = 1.0 , step = 0.1
102+ )
79103
80104 st .subheader ("Table Setup" )
81- business_table_descrp = st .text_area ("Business table description" , value = DEFAULT_BUSINESS_TABLE_DESCRP )
82- violations_table_descrp = st .text_area ("Business table description" , value = DEFAULT_VIOLATIONS_TABLE_DESCRP )
83- inspections_table_descrp = st .text_area ("Business table description" , value = DEFAULT_INSPECTIONS_TABLE_DESCRP )
84-
85- table_context_dict = {"businesses" : business_table_descrp ,
86- "inspections" : inspections_table_descrp ,
87- "violations" : violations_table_descrp }
88-
105+ business_table_descrp = st .text_area (
106+ "Business table description" , value = DEFAULT_BUSINESS_TABLE_DESCRP
107+ )
108+ violations_table_descrp = st .text_area (
109+ "Business table description" , value = DEFAULT_VIOLATIONS_TABLE_DESCRP
110+ )
111+ inspections_table_descrp = st .text_area (
112+ "Business table description" , value = DEFAULT_INSPECTIONS_TABLE_DESCRP
113+ )
114+
115+ table_context_dict = {
116+ "businesses" : business_table_descrp ,
117+ "inspections" : inspections_table_descrp ,
118+ "violations" : violations_table_descrp ,
119+ }
120+
89121 use_table_descrp = st .checkbox ("Use table descriptions?" , value = True )
90122 lc_descrp = st .text_area ("LangChain Tool Description" , value = DEFAULT_LC_TOOL_DESCRP )
91123
92124with llama_tab :
93125 st .subheader ("Text2SQL with Llama Index" )
94126 if st .button ("Initialize Index" , key = "init_index_1" ):
95- st .session_state ['llama_index' ] = initialize_index (llm_name , model_temperature , table_context_dict if use_table_descrp else None , api_key )
96-
127+ st .session_state ["llama_index" ] = initialize_index (
128+ llm_name ,
129+ model_temperature ,
130+ table_context_dict if use_table_descrp else None ,
131+ api_key ,
132+ )
133+
97134 if "llama_index" in st .session_state :
98- query_text = st .text_input ("Query:" , value = "Which restaurant has the most violations?" )
135+ query_text = st .text_input (
136+ "Query:" , value = "Which restaurant has the most violations?"
137+ )
138+ use_nl = st .checkbox ("Return natural language response?" )
99139 if st .button ("Run Query" ) and query_text :
100140 with st .spinner ("Getting response..." ):
101141 try :
102- response = st .session_state [' llama_index' ] .query (query_text )
142+ response = st .session_state [" llama_index" ]. as_query_engine ( synthesize_response = use_nl ) .query (query_text )
103143 response_text = str (response )
104- response_sql = response .extra_info [' sql_query' ]
144+ response_sql = response .extra_info [" sql_query" ]
105145 except Exception as e :
106146 response_text = "Error running SQL Query."
107147 response_sql = str (e )
@@ -119,19 +159,31 @@ def initialize_chain(llm_name, model_temperature, lc_descrp, api_key, _sql_index
119159 st .subheader ("Langchain + Llama Index SQL Demo" )
120160
121161 if st .button ("Initialize Agent" ):
122- st .session_state ['llama_index' ] = initialize_index (llm_name , model_temperature , table_context_dict if use_table_descrp else None , api_key )
123- st .session_state ['lc_agent' ] = initialize_chain (llm_name , model_temperature , lc_descrp , api_key , st .session_state ['llama_index' ])
124- st .session_state ['chat_history' ] = []
125-
126- model_input = st .text_input ("Message:" , value = "Which restaurant has the most violations?" )
127- if 'lc_agent' in st .session_state and st .button ("Send" ):
162+ st .session_state ["llama_index" ] = initialize_index (
163+ llm_name ,
164+ model_temperature ,
165+ table_context_dict if use_table_descrp else None ,
166+ api_key ,
167+ )
168+ st .session_state ["lc_agent" ] = initialize_chain (
169+ llm_name ,
170+ model_temperature ,
171+ lc_descrp ,
172+ api_key ,
173+ st .session_state ["llama_index" ],
174+ )
175+ st .session_state ["chat_history" ] = []
176+
177+ model_input = st .text_input (
178+ "Message:" , value = "Which restaurant has the most violations?"
179+ )
180+ if "lc_agent" in st .session_state and st .button ("Send" ):
128181 model_input = "User: " + model_input
129- st .session_state [' chat_history' ].append (model_input )
182+ st .session_state [" chat_history" ].append (model_input )
130183 with st .spinner ("Getting response..." ):
131- response = st .session_state [' lc_agent' ].run (input = model_input )
132- st .session_state [' chat_history' ].append (response )
184+ response = st .session_state [" lc_agent" ].run (input = model_input )
185+ st .session_state [" chat_history" ].append (response )
133186
134- if ' chat_history' in st .session_state :
135- for msg in st .session_state [' chat_history' ]:
187+ if " chat_history" in st .session_state :
188+ for msg in st .session_state [" chat_history" ]:
136189 st_message (msg .split ("User: " )[- 1 ], is_user = "User: " in msg )
137-
0 commit comments