Open
Description
Hi,..
I am part of the GenAI Bootcamp (cohort 4) and we have just completed clinic #2. I have tried to take the script that Matt gave us (10_ai_data_science_team.py) and make a Streamlit web app from it. I do get the web app up and running; however, the output is total nonsense: no matter what prompt I give, it just runs the exact same SQL query every time.
I believe I have made all the necessary changes to the code so that it correctly takes the LangGraph dictionary outputs from the agent and incorporates them into the "Render current messages from StreamlitChatMessageHistory" section of the Streamlit app. Please find my code below:
# Import Libraries
import sqlalchemy as sql
import pandas as pd
from langchain_openai import ChatOpenAI
import os
import yaml
from ai_data_science_team.multiagents import SQLDataAnalyst
from ai_data_science_team.agents import SQLDatabaseAgent, DataVisualizationAgent
from pprint import pprint
from IPython.display import display, Markdown
import sqlparse
import streamlit as st
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
import plotly.io as pio
# * Setup
# The SQL agent has logging capabilities, but they are turned off by default.
MODEL = 'gpt-4.1-mini'
LOG = False
LOG_PATH = os.path.join(os.getcwd(), "logs/")

# Setup AI
# Read the OpenAI key from the local credentials file. The context manager
# closes the file handle deterministically instead of leaving it open until
# the garbage collector gets to it.
with open('credentials.yml') as credentials_file:
    os.environ["OPENAI_API_KEY"] = yaml.safe_load(credentials_file)['openai']

llm = ChatOpenAI(model = MODEL)

# Database the agents run their SQL against.
sql_engine = sql.create_engine("sqlite:///database/leads_scored.db")
conn = sql_engine.connect()
# * STREAMLIT APP SETUP ----
# The browser-tab title and the on-page heading share the same text.
APP_TITLE = "Your Business Intelligence AI Copilot"

st.set_page_config(page_title=APP_TITLE)
st.title(APP_TITLE)

# Short introduction rendered directly under the heading.
st.markdown("""
I'm a handy business intelligence agent that connects to a selected SQLite database mimicking an ERP system. You can ask me Business Intelligence, Customer Analytics, and Data Visualization Questions, and I will report the results. If I accidentally output a table where I should have output a chart, please select a higher performance LLM model and try again.
""")
# * Make the agent
# Build the two sub-agents first (SQL agent, then visualization agent), then
# hand both to the coordinating analyst. All three share the same LLM and
# the same logging settings.
sql_database_agent = SQLDatabaseAgent(
    model=llm,
    connection=conn,
    n_samples=1,
    log=LOG,
    log_path=LOG_PATH,
    bypass_recommended_steps=True,
)

data_visualization_agent = DataVisualizationAgent(
    model=llm,
    n_samples=1000,
    log=LOG,
    log_path=LOG_PATH,
)

sql_data_analyst = SQLDataAnalyst(
    model=llm,
    sql_database_agent=sql_database_agent,
    data_visualization_agent=data_visualization_agent,
)
# * STREAMLIT
# Markdown shown inside the collapsible "example questions" panel.
EXAMPLE_QUESTIONS_MD = """
Example Questions:
5. What are the fields in the leads_scored table?
6. What is the average p1 lead score of leads in the database?
7. What is the average p1 lead score of leads by member rating in the database?
8. Calculate the average p1 lead score of leads by member rating. Return a scatter plot. Show member rating on the x-axis. Include a linear regression line in orange. Show the formula for the linear regression in black font on a white background.
9. Which 10 customers have the highest p1 probability of purchase who have NOT purchased "Learning Labs Pro - Paid Course"?
10. What are the top 5 products for sales revenue, group by product name? Make a donut chart. Use suggested price for the sales revenue and a unit quantity of 1 for all transactions. Make sure each donut slice is a different colour. Make sure the legend is on the right, middle.
"""

example_questions = st.expander("Try out these example questions")
with example_questions:
    # Render explicitly rather than leaving a bare string literal for
    # Streamlit "magic" to pick up: the panel keeps working even when magic
    # is disabled in the Streamlit configuration.
    st.markdown(EXAMPLE_QUESTIONS_MD)
# Set up memory
msgs = StreamlitChatMessageHistory(key="langchain_messages")
if not msgs.messages:
    # Brand-new conversation: open with a greeting from the assistant.
    msgs.add_ai_message("How can I help you?")

# Initialize plot and dataframe storage in session state. Chat messages only
# carry an index into these lists, so the lists are created once and then
# left alone on later runs of the script.
for _store_name in ("plots", "dataframes"):
    if _store_name not in st.session_state:
        st.session_state[_store_name] = []
# Function to display chat messages including Plotly charts and dataframes
_PLOT_PREFIX = "PLOT_INDEX:"
_DATAFRAME_PREFIX = "DATAFRAME_INDEX:"


def _placeholder_index(content, prefix, size):
    """Return the index encoded in ``<prefix><n>``, or None if there is none.

    None is returned when ``content`` does not start with ``prefix``, when
    the remainder is not an integer, or when the integer is not a valid
    position in a list of ``size`` items.
    """
    if not content.startswith(prefix):
        return None
    try:
        index = int(content[len(prefix):])
    except ValueError:
        return None
    return index if 0 <= index < size else None


def display_chat_history():
    """Re-render every stored chat message, resolving chart/table placeholders.

    An AI message of the form ``PLOT_INDEX:<n>`` or ``DATAFRAME_INDEX:<n>``
    stands for the n-th entry of ``st.session_state.plots`` or
    ``st.session_state.dataframes`` and is drawn as that chart or table.
    Every other message is written out as text.
    """
    for msg in msgs.messages:
        plot_index = None
        df_index = None
        # Only the assistant writes placeholders. Ignoring them in human
        # messages means a question that happens to contain "PLOT_INDEX:"
        # cannot crash the history rendering on every subsequent rerun.
        if msg.type == "ai":
            plot_index = _placeholder_index(
                msg.content, _PLOT_PREFIX, len(st.session_state.plots)
            )
            df_index = _placeholder_index(
                msg.content, _DATAFRAME_PREFIX, len(st.session_state.dataframes)
            )

        with st.chat_message(msg.type):
            if plot_index is not None:
                st.plotly_chart(
                    st.session_state.plots[plot_index],
                    key=f"history_plot_{plot_index}",
                )
            elif df_index is not None:
                st.dataframe(
                    st.session_state.dataframes[df_index],
                    key=f"history_dataframe_{df_index}",
                )
            else:
                st.write(msg.content)
# Render current messages from StreamlitChatMessageHistory
display_chat_history()


def _sql_markdown(result):
    """Return result['sql_database_function'] wrapped in a fenced SQL block."""
    return f"SQL Query:\n```sql\n{result['sql_database_function']}\n```"


def _reply_with_table(intro, result):
    """Answer with result['data_sql'] as a table and record it in the history.

    ``intro`` is the sentence shown above the SQL block. The dataframe itself
    is kept in ``st.session_state.dataframes``; the chat history only stores
    a ``DATAFRAME_INDEX:<n>`` placeholder pointing at it.
    """
    response_text = f"{intro}\n\n{_sql_markdown(result)}"
    response_df = pd.DataFrame(result['data_sql'])

    # Store the dataframe and keep its index
    df_index = len(st.session_state.dataframes)
    st.session_state.dataframes.append(response_df)

    # Store the response text and dataframe index in the messages
    msgs.add_ai_message(response_text)
    msgs.add_ai_message(f"DATAFRAME_INDEX:{df_index}")

    st.chat_message("ai").write(response_text)
    # Draw the table inside its own AI bubble so the live render matches
    # what display_chat_history() shows for the same message on later runs.
    with st.chat_message("ai"):
        st.dataframe(response_df)


def _reply_with_chart(result):
    """Answer with the Plotly figure in result['chart_plotly_json'].

    The figure is kept in ``st.session_state.plots``; the chat history only
    stores a ``PLOT_INDEX:<n>`` placeholder pointing at it.
    """
    response_text = f"Returning the plot...\n\n{_sql_markdown(result)}"
    response_plot = pio.from_json(result["chart_plotly_json"])

    # Store the plot and keep its index
    plot_index = len(st.session_state.plots)
    st.session_state.plots.append(response_plot)

    # Store the response text and plot index in the messages
    msgs.add_ai_message(response_text)
    msgs.add_ai_message(f"PLOT_INDEX:{plot_index}")

    st.chat_message("ai").write(response_text)
    # Same bubble layout as the history render.
    with st.chat_message("ai"):
        st.plotly_chart(response_plot)


if question := st.chat_input("Enter your question here:", key="query_input"):
    with st.spinner("Thinking..."):
        st.chat_message("human").write(question)
        msgs.add_user_message(question)

        # Run the app
        # NOTE(review): the reported symptom (the same SQL for every prompt)
        # would be explained by the question never reaching the agent.
        # Confirm against the installed ai_data_science_team version that the
        # graph really reads the "user_question" key -- the library's own
        # entry point appears to be invoke_agent(user_instructions=...) --
        # and that the result really exposes "chart_plotly_error" and
        # "chart_plotly_json" under those names.
        inputs = {"user_question": question}
        error_occurred = False
        result = None
        try:
            result = sql_data_analyst.invoke(inputs)
        except Exception as e:
            # Top-level boundary: report on the server console and show the
            # user an apology instead of a traceback.
            error_occurred = True
            print(e)

        if error_occurred:
            # SQL error occurred
            response_text = (
                "An error occurred in generating the SQL. I apologize. "
                "Please try again or format the question differently and "
                "I'll try my best to provide a helpful answer."
            )
            msgs.add_ai_message(response_text)
            st.chat_message("ai").write(response_text)
        elif result['routing_preprocessor_decision'] == 'table':
            # Table was requested
            _reply_with_table("Returning the table...", result)
        elif (
            result['routing_preprocessor_decision'] == 'chart'
            and result['chart_plotly_error'] is False
        ):
            # Chart was requested and produced correctly
            _reply_with_chart(result)
        else:
            # Chart error occurred, return Table instead
            _reply_with_table(
                "I apologize. There was an error during the plotting process. "
                "Returning the table instead...",
                result,
            )
Metadata
Metadata
Assignees
Labels
No labels