Skip to content

Commit 5f6123b

Browse files
RCR sample (#57)
1 parent f67a91d commit 5f6123b

File tree

3 files changed

+295
-0
lines changed

3 files changed

+295
-0
lines changed

rcr/create_service.sql

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
-- Run these first
-- https://github.com/Snowflake-Labs/sfguide-getting-started-with-cortex-analyst/blob/main/create_snowflake_objects.sql
-- https://github.com/Snowflake-Labs/sfguide-getting-started-with-cortex-analyst/blob/main/load_data.sql
-- https://github.com/Snowflake-Labs/sfguide-getting-started-with-cortex-analyst/blob/main/cortex_search_create.sql


-- Create the role the SPCS service runs as and give it ownership of the
-- application database.
USE ROLE accountadmin;
CREATE ROLE service_user_role;

CREATE DATABASE IF NOT EXISTS app_db;
GRANT OWNERSHIP ON DATABASE app_db TO ROLE service_user_role COPY CURRENT GRANTS;

-- Required for the service to expose a public endpoint.
GRANT BIND SERVICE ENDPOINT ON ACCOUNT TO ROLE service_user_role;
-- Smallest single-node pool; auto-suspends after one hour idle.
CREATE COMPUTE POOL app_compute_pool
  MIN_NODES = 1
  MAX_NODES = 1
  INSTANCE_FAMILY = CPU_XS
  AUTO_SUSPEND_SECS = 3600;

GRANT USAGE, MONITOR ON COMPUTE POOL app_compute_pool TO ROLE service_user_role;
GRANT OWNERSHIP ON SCHEMA app_db.public TO ROLE service_user_role;
GRANT CREATE NETWORK RULE ON SCHEMA app_db.public TO ROLE accountadmin;

-- Egress rule so containers can reach PyPI and GitHub for dependencies.
CREATE OR REPLACE NETWORK RULE app_db.public.dependencies_network_rule
  MODE = EGRESS
  TYPE = HOST_PORT
  VALUE_LIST = ('pypi.python.org', 'pypi.org', 'cdn.pypi.org','pythonhosted.org', 'files.pythonhosted.org', 'github.com', 'githubusercontent.com');

CREATE EXTERNAL ACCESS INTEGRATION dependencies_access_integration
  ALLOWED_NETWORK_RULES = (app_db.public.dependencies_network_rule)
  ENABLED = true;

GRANT USAGE ON INTEGRATION dependencies_access_integration TO ROLE service_user_role;

-- Grant restricted caller privileges to service_user_role
-- (with executeAsCaller, the service acts as the calling user; these
-- CALLER/INHERITED CALLER grants bound what it may do on the caller's behalf).
GRANT CALLER USAGE ON DATABASE cortex_analyst_demo TO ROLE service_user_role;
GRANT INHERITED CALLER USAGE ON ALL SCHEMAS IN DATABASE cortex_analyst_demo TO ROLE service_user_role;
GRANT INHERITED CALLER USAGE,READ ON ALL STAGES IN SCHEMA cortex_analyst_demo.revenue_timeseries TO ROLE service_user_role;
GRANT INHERITED CALLER SELECT ON ALL TABLES IN DATABASE cortex_analyst_demo TO ROLE service_user_role;
GRANT CALLER USAGE ON CORTEX SEARCH SERVICE cortex_analyst_demo.revenue_timeseries.product_line_search_service TO ROLE service_user_role;
GRANT CALLER USAGE ON DATABASE snowflake TO ROLE service_user_role;
GRANT INHERITED CALLER USAGE ON ALL SCHEMAS IN DATABASE snowflake TO ROLE service_user_role;
GRANT INHERITED CALLER USAGE ON ALL FUNCTIONS IN DATABASE snowflake TO ROLE service_user_role;


-- Switch to the service role to create the service-owned objects.
USE ROLE service_user_role;
USE DATABASE app_db;

CREATE IMAGE REPOSITORY IF NOT EXISTS repo;
SHOW IMAGE REPOSITORIES IN SCHEMA app_db.public;
SHOW IMAGES IN IMAGE REPOSITORY app_db.public.repo;

-- Create the UI service. Replace <registry>/<repo>/<image>:<version> with the
-- image pushed to app_db.public.repo. The readinessProbe path matches the
-- /healthcheck endpoint in chatbot.py; executeAsCaller enables the restricted
-- caller's-rights behavior the CALLER grants above are scoped for.
CREATE SERVICE analyst_ui
  IN COMPUTE POOL app_compute_pool
  FROM SPECIFICATION $$
spec:
  containers:
  - name: ui
    image: <registry>/<repo>/<image>:<version>
    readinessProbe:
      port: 8080
      path: /healthcheck
  endpoints:
  - name: chat
    port: 8080
    public: true
  capabilities:
    securityContext:
      executeAsCaller: true
$$
  MIN_INSTANCES=1
  MAX_INSTANCES=1;

rcr/ui/Dockerfile

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Image for the Cortex Analyst chatbot UI (run by chatbot.py's uvicorn server).
# NOTE(review): python:3.10-slim-buster is Debian 10, which is end-of-life;
# consider a -slim-bookworm tag. Kept as the default for compatibility.
ARG BASE_IMAGE=python:3.10-slim-buster
FROM $BASE_IMAGE

COPY chatbot.py ./

# One pip invocation resolves all packages together (consistent dependency
# resolution, fewer layers); --no-cache-dir keeps the image small.
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir --upgrade gradio requests uvicorn fastapi

CMD ["python3", "chatbot.py"]
10+

rcr/ui/chatbot.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
import logging
2+
import os
3+
import sys
4+
import gradio as gr
5+
import requests
6+
from fastapi import FastAPI, Request
7+
8+
# Bind address/port for the embedded uvicorn server; overridable via env.
SERVICE_HOST = os.getenv('SERVER_HOST', '0.0.0.0')
# NOTE: int when defaulted, str when SERVER_PORT is set in the environment;
# the single use site applies int() before passing it to uvicorn.
SERVICE_PORT = os.getenv('SERVER_PORT', 8080)
# Host for Snowflake REST API calls — presumably injected by SPCS; verify in
# the service environment.
SNOWFLAKE_HOST = os.getenv("SNOWFLAKE_HOST")

# Stage path of the Cortex Analyst semantic model file.
SEMANTIC_MODEL_PATH = "@cortex_analyst_demo.revenue_timeseries.raw_data/revenue_timeseries.yaml"
API_TIMEOUT = 60  # in seconds
14+
15+
def get_logger(logger_name):
    """Return a DEBUG-level logger named *logger_name* that writes to stdout.

    Idempotent: ``logging.getLogger`` returns the same logger object per
    name, so repeated calls would previously stack duplicate StreamHandlers
    and emit every record once per extra handler. The guard below attaches
    the handler only once.
    """
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.DEBUG)
    if not logger.handlers:
        handler = logging.StreamHandler(sys.stdout)
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(
            logging.Formatter(
                '%(name)s [%(asctime)s] [%(levelname)s] %(message)s'))
        logger.addHandler(handler)
    return logger
25+
26+
# Module-wide logger and the FastAPI application instance that both the
# middleware and the mounted Gradio UI hang off of.
logger = get_logger('chatbot')

app = FastAPI()
29+
30+
@app.middleware("http")
async def get_ingress_user_token(request: Request, call_next):
    """Capture the current user token from ingress"""
    # NOTE(review): a single module-level global is shared by every in-flight
    # request. Under concurrent use, one user's token can overwrite another's
    # before get_request_headers() reads it — consider carrying the token on
    # request.state (or a contextvar) instead. Confirm expected concurrency.
    global ingress_user_token
    ingress_user_token = request.headers.get('Sf-Context-Current-User-Token')
    response = await call_next(request)
    return response
37+
38+
def get_login_token():
    """Read the service's OAuth session token mounted into the container."""
    with open("/snowflake/session/token", "r") as token_file:
        return token_file.read()
41+
42+
def get_request_headers():
    """Build headers for Snowflake REST calls.

    Combines the service login token with the captured ingress user token
    into a single bearer credential (restricted caller's-rights auth).
    """
    bearer_token = f"{get_login_token()}.{ingress_user_token}"
    headers = {
        "Authorization": f"Bearer {bearer_token}",
        "X-Snowflake-Authorization-Token-Type": "OAUTH",
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    return headers
49+
50+
def clear_state():
    """Reset the Cortex Analyst conversation history to an empty list."""
    global analyst_history
    analyst_history = []

# Conversation history sent with every Cortex Analyst request; initialized
# here and reset via clear_state() when the user clears the chat.
analyst_history = None
clear_state()
57+
58+
def send_sql(sql_query):
    """Executes a SQL query using Snowflake's REST API.

    Returns a markdown table of the result set, or a status string when the
    response carries no result-set metadata. Raises gr.Error on HTTP or
    transport failure.
    """
    logger.debug(f"Executing SQL query: {sql_query}")

    request_body = {
        "statement": sql_query,
    }

    url = f"https://{SNOWFLAKE_HOST}/api/v2/statements"

    try:
        resp = requests.post(url=url, json=request_body, headers=get_request_headers(), timeout=API_TIMEOUT)

        if resp.status_code >= 400:
            raise gr.Error(f"SQL Error: HTTP {resp.status_code} - {resp.text}")

        response_data = resp.json()
        logger.debug(f"SQL Response data: {response_data}")

        # Format the SQL result as a markdown table for display.
        if "resultSetMetaData" in response_data and "rowType" in response_data["resultSetMetaData"]:
            # Get column names.
            columns = [col["name"] for col in response_data["resultSetMetaData"].get("rowType", [])]
            # "data" can be absent for some statement types; treat as no rows
            # instead of raising KeyError.
            rows = response_data.get("data", [])

            # Build the table with a join rather than repeated string +=.
            lines = ["| " + " | ".join(columns) + " |",
                     "| " + " | ".join("---" for _ in columns) + " |"]
            lines.extend("| " + " | ".join(str(cell) for cell in row) + " |" for row in rows)
            return "\n".join(lines) + "\n"

        return "SQL query executed successfully but returned no data."

    except gr.Error:
        # Already a user-facing error (the HTTP >= 400 branch above); re-raise
        # as-is instead of double-wrapping it into "SQL Error: SQL Error: ...".
        raise
    except Exception as e:
        raise gr.Error(f"SQL Error: {str(e)}") from e
95+
96+
97+
async def send_message(message, history):
    """Gradio ChatInterface handler: relay *message* to Cortex Analyst.

    Async generator — yields the growing list of gr.ChatMessage objects so
    the UI streams partial results (e.g. shows "Running SQL" before the
    query finishes). Appends both the user turn and the analyst reply to the
    module-level analyst_history; on any error, clears that history and
    raises gr.Error. *history* (Gradio's own view) is unused — the server
    keeps its own analyst_history.
    """
    logger.debug(f"Received message: {message}");
    # Record the user turn in the server-side conversation history.
    analyst_history.append({
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": message
            }
        ]
    })
    request_body = {
        "messages": analyst_history,
        "semantic_model_file": SEMANTIC_MODEL_PATH
    }

    url = f"https://{SNOWFLAKE_HOST}/api/v2/cortex/analyst/message"

    try:
        resp = requests.post(url=url, json=request_body, headers=get_request_headers(), timeout=API_TIMEOUT)

        if resp.status_code >= 400:
            # NOTE(review): this gr.Error is caught by the broad `except`
            # below and re-wrapped as "Unexpected error: ..." — likely
            # unintended; consider re-raising gr.Error untouched.
            raise gr.Error(f"HTTP Error: {resp.status_code} - {resp.text}")

        response_data = resp.json()
        logger.debug(f"Response data: {response_data}")

        # Process the response message content
        response_messages = []
        yield response_messages

        # Extract text content from the message
        if "message" in response_data and "content" in response_data["message"]:
            analyst_history.append(response_data["message"])
            logger.debug(f"History: {analyst_history}")

            # One chat message per content item; SQL items are executed and
            # their results appended as an extra message.
            for content_item in response_data["message"]["content"]:
                if content_item.get("type") == "text":
                    response_messages.append(
                        gr.ChatMessage(
                            content = content_item.get("text")
                        )
                    )
                elif content_item.get("type") == "sql":
                    statement = content_item.get('statement')
                    # Show the pending SQL first, then run it and flip the
                    # status so the UI reflects progress.
                    m = gr.ChatMessage(
                        content = f"Executing the following SQL query:\n```sql\n{statement}\n```",
                        metadata={"title": "Running SQL", "status": "pending"}
                    )
                    response_messages.append(m)
                    yield response_messages
                    sql_result = send_sql(statement);
                    m.metadata["status"] = "done"
                    response_messages.append(
                        gr.ChatMessage(
                            content=f"Response data:\n{sql_result}"
                        )
                    )
                elif content_item.get("type") == "suggestions":
                    # Render suggested follow-up questions as clickable options.
                    m = gr.ChatMessage(
                        content= "",
                        options = []
                    )
                    for suggestion_index, suggestion in enumerate(content_item["suggestions"]):
                        m.options.append({
                            "value": suggestion
                        })
                    response_messages.append(m)
                else:
                    # Unknown content type — silently skipped.
                    pass
                # Stream the updated message list after each content item.
                yield response_messages

        # Add warnings if they exist
        if "warnings" in response_data and response_data["warnings"]:
            for warning in response_data["warnings"]:
                gr.Warning(warning.get('message', ''))

        # If no content was found, return a default message
        if len(response_messages) == 0:
            response_messages.append(
                gr.ChatMessage(content = "No response text received")
            )

    except Exception as e:
        # Reset server-side history so the conversation does not carry a
        # user turn the analyst never answered.
        err = f"Unexpected error: {e}"
        logger.error(err)
        clear_state()
        raise gr.Error(err)

    logger.debug(f"Response messages: {response_messages}")
    yield response_messages
188+
189+
@app.get("/healthcheck")
async def readiness_probe():
    """Readiness endpoint backing the service spec's readinessProbe."""
    return "I'm ready"
192+
193+
# Build chatbot
with gr.Blocks() as bot:
    chatbot = gr.Chatbot(type="messages")
    # Reset the server-side analyst history whenever the user clears the chat,
    # keeping UI state and analyst_history in sync.
    chatbot.clear(clear_state)
    gr.ChatInterface(
        send_message,
        type="messages",
        examples=["What questions can I ask?"],
        title="Cortex Analyst",
        chatbot=chatbot
    )


# Mount the Gradio app to FastAPI at the root path, so the token-capturing
# middleware on `app` sees every UI request.
gr.mount_gradio_app(app, bot, path="")
208+
209+
# Start the app
if __name__ == "__main__":
    import uvicorn

    # Log before starting: uvicorn.run() blocks until the server shuts down,
    # so a message placed after it would only ever appear on shutdown.
    logger.debug(f"Chatbot app running on {SERVICE_HOST}:{SERVICE_PORT}")
    uvicorn.run(app, host=SERVICE_HOST, port=int(SERVICE_PORT))

0 commit comments

Comments
 (0)