Skip to content

Commit 8a179d8

Browse files
committed
fix: fix UI and add some extra-features
1 parent a5f89dd commit 8a179d8

File tree

7 files changed

+161
-12
lines changed

7 files changed

+161
-12
lines changed

.env

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
GROQ_API_KEY=""
1+
GROQ_API_KEY="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"
22
GROQ_MODEL="llama-3.3-70b-specdec"
3-
ANALYZING_MODEL="deepseek-r1-distill-llama-70b-specdec"
3+
ANALYZING_MODEL="deepseek-r1-distill-llama-70b"
44
OLLAMA_MODEL="llama3.2"

app.py

+61-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import streamlit as st
22
from performer.performer import graph
33
from agentstate.agent_state import AgentState
4+
from utils.sql_utils import extract_sql_queries
45

56
st.title("PostgreSQL Database Optimization Assistant")
67
st.markdown("""
@@ -29,28 +30,83 @@ def run_analysis():
2930
if "analysis" in event:
3031
st.session_state.analysis_history.append(event["analysis"])
3132

33+
def extract_thinking_content(analysis):
34+
"""Extracts the content between <think> and </think> tags."""
35+
start_tag = "<think>"
36+
end_tag = "</think>"
37+
start_index = analysis.find(start_tag)
38+
end_index = analysis.find(end_tag)
39+
if start_index != -1 and end_index != -1:
40+
return analysis[start_index + len(start_tag):end_index].strip()
41+
return None
42+
43+
def execute_sql_queries(edited_queries):
44+
"""Extract and execute SQL queries from the final analysis."""
45+
current_state = graph.get_state(thread)
46+
schema = current_state.values.get("schema", "")
47+
48+
# Update state with edited queries
49+
graph.update_state(thread, {"execute_query": edited_queries})
50+
51+
# Execute each query
52+
for query in edited_queries.split(";"):
53+
query = query.strip()
54+
if query: # Skip empty queries
55+
st.write(f"Executing: {query}")
56+
try:
57+
# Simulate execution (replace with actual SQL execution logic if needed)
58+
st.write("Query executed successfully.")
59+
except Exception as e:
60+
st.error(f"Failed to execute query: {query}. Error: {str(e)}")
61+
3262
if st.button("Analyze"):
3363
if query and schema:
3464
run_analysis()
3565
else:
3666
st.error("Please provide both query and schema")
3767

3868
if st.session_state.analysis_history:
69+
latest_analysis = st.session_state.analysis_history[-1]
70+
start_tag = "<think>"
71+
end_tag = "</think>"
72+
if start_tag in latest_analysis and end_tag in latest_analysis:
73+
start_index = latest_analysis.index(start_tag) + len(start_tag)
74+
end_index = latest_analysis.index(end_tag)
75+
thinking_content = latest_analysis[start_index:end_index].strip()
76+
else:
77+
thinking_content = "No detailed reasoning available for this analysis."
78+
79+
with st.expander("Thinking Mode: View Detailed Reasoning", expanded=False):
80+
st.markdown(thinking_content)
81+
3982
st.subheader("Analysis History")
4083
for i, analysis in enumerate(st.session_state.analysis_history, 1):
4184
st.write(f"Iteration {i}:")
42-
st.code(analysis)
85+
st.markdown(analysis)
4386

4487
st.subheader("Feedback")
4588
col1, col2 = st.columns(2)
4689

4790
with col1:
4891
if st.button("Yes - Accept Analysis"):
4992
graph.update_state(thread, {"execute": True})
50-
st.success("Analysis accepted! Proceeding to execution...")
51-
st.query_params = {"status": "completed"}
52-
st.rerun()
53-
93+
st.success("Analysis accepted! Proceeding to SQL execution...")
94+
95+
# Extract SQL queries from the latest analysis
96+
current_state = graph.get_state(thread)
97+
analysis = current_state.values.get("analysis", "")
98+
sql_queries = extract_sql_queries(analysis)
99+
100+
if not sql_queries:
101+
st.warning("No SQL queries found in the analysis.")
102+
else:
103+
# Display SQL queries in an editable text area
104+
edited_queries = st.text_area("Edit SQL Queries:", value=sql_queries), height=200)
105+
106+
# Add an "Execute" button to confirm and execute the edited queries
107+
if st.button("Execute Edited Queries"):
108+
execute_sql_queries(edited_queries)
109+
54110
with col2:
55111
if st.button("No - Revise Analysis"):
56112
st.session_state.show_feedback = True

performance_test.py

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import sys
2+
from io import StringIO
3+
import logging
4+
from performer.performer import graph # Import your LangGraph setup
5+
6+
# Configure logging
7+
logging.basicConfig(
8+
level=logging.INFO,
9+
format="%(asctime)s - %(levelname)s - %(message)s",
10+
handlers=[logging.StreamHandler(sys.stdout)]
11+
)
12+
13+
def get_mock_schema():
14+
"""Returns a mock schema for testing."""
15+
return """
16+
table_name | column_name
17+
-------------+-------------
18+
customers | id
19+
customers | name
20+
customers | email
21+
customers | address
22+
orders | id
23+
orders | customer_id
24+
orders | order_date
25+
orders | total
26+
products | id
27+
products | name
28+
products | description
29+
products | price
30+
products | stock_quantity
31+
"""
32+
33+
def test_workflow():
34+
"""Simulates the workflow of the PostgreSQL optimization assistant with detailed logging."""
35+
logging.info("Starting test workflow...")
36+
37+
mock_inputs = [
38+
"no",
39+
"Just give me details about index creation.",
40+
"yes"
41+
]
42+
mock_query = "Identify and give me solutions to optimize my postgres database"
43+
mock_schema = get_mock_schema()
44+
45+
logging.info(f"Mock query: {mock_query}")
46+
logging.info(f"Mock schema:\n{mock_schema}")
47+
48+
original_stdin = sys.stdin
49+
original_stdout = sys.stdout
50+
51+
try:
52+
sys.stdin = StringIO("\n".join(mock_inputs))
53+
54+
captured_output = StringIO()
55+
sys.stdout = captured_output
56+
57+
thread = {"configurable": {"thread_id": "performance_optimization_test"}}
58+
logging.info(f"Initialized thread with ID: {thread['configurable']['thread_id']}")
59+
60+
logging.info("Starting analysis phase...")
61+
for event in graph.stream(
62+
{"query": mock_query, "schema": mock_schema},
63+
thread,
64+
stream_mode="values"
65+
):
66+
if "analysis" in event:
67+
logging.info("Analysis generated:")
68+
logging.info(event["analysis"])
69+
print("\n**Analysis**")
70+
print(event["analysis"])
71+
72+
current_state = graph.get_state(thread)
73+
logging.info(f"Current state after analysis: {current_state.values}")
74+
75+
if current_state.values.get("execute"):
76+
logging.info("User approved analysis. Proceeding to SQL execution...")
77+
sql_queries = current_state.values.get("execute_query", [])
78+
logging.info(f"Extracted SQL queries: {sql_queries}")
79+
80+
print("\n**Executing SQL Commands**")
81+
for cmd in sql_queries:
82+
logging.info(f"Executing SQL command: {cmd}")
83+
print(f"Executing: {cmd}")
84+
85+
finally:
86+
sys.stdin = original_stdin
87+
sys.stdout = original_stdout
88+
89+
logging.info("Captured output from the workflow:")
90+
logging.info(captured_output.getvalue())
91+
92+
if __name__ == "__main__":
93+
test_workflow()
44 Bytes
Binary file not shown.

performer/performer.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from agentstate.agent_state import AgentState
22
from langgraph.checkpoint.memory import MemorySaver
33
from langchain_core.messages import SystemMessage, HumanMessage, RemoveMessage
4-
from llm.llm import llm # Change according to where you want to test with
4+
from llm.llm import llm, analyze_llm, ollama_llm
55
from langgraph.graph import START, END, StateGraph
66
from sql.sql_agent import SQLAgent
77
from langchain_core.tools import tool
@@ -14,13 +14,13 @@
1414
db_config = {
1515
"user": "postgres",
1616
"password": "postgres",
17-
"host": "localhost",
17+
"host": "0.0.0.0",
1818
"port": "5432",
1919
"database": "ecommerce_db"
2020
}
2121

2222
sql_agent = SQLAgent(db_config=db_config,name="SQLAgent")
23-
get_dB_schema = sql_agent.get_schema() # Use this schema for dynamic databases
23+
get_dB_schema = sql_agent.get_schema()
2424

2525

2626
# @tool
@@ -64,7 +64,7 @@ def analyze_database(state: AgentState):
6464
]
6565

6666
response = llm.invoke(message)
67-
# sql_commands = extract_sql_commands(response.content) # New helper function
67+
# sql_commands = extract_sql_commands(response.content)
6868
return {
6969
"analysis": response.content,
7070
"feedback": fdb,
3 Bytes
Binary file not shown.

utils/sql_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def get_schema_info(db_config: dict):
6262
cursor.execute("""
6363
SELECT table_name, column_name
6464
FROM information_schema.columns
65-
WHERE table_schema = 'public'
65+
WHERE table_schema = 'ecommerce'
6666
ORDER BY table_name, ordinal_position;
6767
""")
6868
rows = cursor.fetchall()

0 commit comments

Comments
 (0)