Skip to content

Commit

Permalink
[Demo] Vizro-AI chart UI updates (#847)
Browse files Browse the repository at this point in the history
  • Loading branch information
nadijagraca authored Nov 5, 2024
1 parent 5430972 commit 64523bb
Show file tree
Hide file tree
Showing 9 changed files with 659 additions and 400 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Added
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Changed
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Fixed
- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
4 changes: 2 additions & 2 deletions vizro-ai/docs/pages/user-guides/customize-vizro-ai.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ vizro_ai = VizroAI(model="<chosen model>")

To use Anthropic with Vizro-AI, you must have an account with paid-for credits available. None of the free accounts will suffice. You can check [all available Anthropic models including pricing on their website](https://docs.anthropic.com/en/docs/about-claude/models).

- `claude-3-5-sonnet-20240620`
- `claude-3-opus-20240229`
- `claude-3-5-sonnet-latest`
- `claude-3-opus-latest`
- `claude-3-sonnet-20240229`
- `claude-3-haiku-20240307`

Expand Down
97 changes: 77 additions & 20 deletions vizro-ai/examples/dashboard_ui/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import logging

import black
import dash
import dash_bootstrap_components as dbc
import pandas as pd
from _utils import check_file_extension
from dash.exceptions import PreventUpdate
Expand All @@ -13,16 +15,50 @@
from vizro.models.types import capture
from vizro_ai import VizroAI

try:
from langchain_anthropic import ChatAnthropic
except ImportError:
ChatAnthropic = None

try:
from langchain_mistralai import ChatMistralAI
except ImportError:
ChatMistralAI = None

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) # TODO: remove manual setting and make centrally controlled

SUPPORTED_VENDORS = {"OpenAI": ChatOpenAI}
SUPPORTED_VENDORS = {"OpenAI": ChatOpenAI, "Anthropic": ChatAnthropic, "Mistral": ChatMistralAI}

SUPPORTED_MODELS = {
"OpenAI": [
"gpt-4o-mini",
"gpt-4o",
"gpt-4",
"gpt-4-turbo",
"gpt-3.5-turbo",
],
"Anthropic": [
"claude-3-opus-latest",
"claude-3-5-sonnet-latest",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
],
"Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"],
}


def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input):
"""VizroAi plot configuration."""
vendor = SUPPORTED_VENDORS[vendor_input]
llm = vendor(model_name=model, openai_api_key=api_key, openai_api_base=api_base)

if vendor_input == "OpenAI":
llm = vendor(model_name=model, openai_api_key=api_key, openai_api_base=api_base)
if vendor_input == "Anthropic":
llm = vendor(model=model, anthropic_api_key=api_key, anthropic_api_url=api_base)
if vendor_input == "Mistral":
llm = vendor(model=model, mistral_api_key=api_key, mistral_api_url=api_base)

vizro_ai = VizroAI(model=llm)
ai_outputs = vizro_ai.plot(df, user_prompt, return_elements=True)

Expand All @@ -33,36 +69,31 @@ def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input):
def run_vizro_ai(user_prompt, n_clicks, data, model, api_key, api_base, vendor_input): # noqa: PLR0913
"""Gets the AI response and adds it to the text window."""

def create_response(ai_response, figure, user_prompt, filename):
plotly_fig = figure.to_json()
return (
ai_response,
figure,
{"ai_response": ai_response, "figure": plotly_fig, "prompt": user_prompt, "filename": filename},
)
def create_response(ai_response, figure, ai_outputs):
return (ai_response, figure, {"ai_outputs": ai_outputs})

if not n_clicks:
raise PreventUpdate

if not data:
ai_response = "Please upload data to proceed!"
figure = go.Figure()
return create_response(ai_response, figure, user_prompt, None)
return create_response(ai_response, figure, ai_outputs=None)

if not api_key:
ai_response = "API key not found. Make sure you enter your API key!"
figure = go.Figure()
return create_response(ai_response, figure, user_prompt, data["filename"])
return create_response(ai_response, figure, ai_outputs=None)

if api_key.startswith('"'):
ai_response = "Make sure you enter your API key without quotes!"
figure = go.Figure()
return create_response(ai_response, figure, user_prompt, data["filename"])
return create_response(ai_response, figure, ai_outputs=None)

if api_base is not None and api_base.startswith('"'):
ai_response = "Make sure you enter your API base without quotes!"
figure = go.Figure()
return create_response(ai_response, figure, user_prompt, data["filename"])
return create_response(ai_response, figure, ai_outputs=None)

try:
logger.info("Attempting chart code.")
Expand All @@ -75,20 +106,25 @@ def create_response(ai_response, figure, user_prompt, filename):
api_base=api_base,
vendor_input=vendor_input,
)
ai_code = ai_outputs.code
figure = ai_outputs.get_fig_object(data_frame=df)
ai_code = ai_outputs.code_vizro
figure_vizro = ai_outputs.get_fig_object(data_frame=df, vizro=True)
figure_plotly = ai_outputs.get_fig_object(data_frame=df, vizro=False)
formatted_code = black.format_str(ai_code, mode=black.Mode(line_length=100))
ai_code_outputs = {
"vizro": {"code": ai_outputs.code_vizro, "fig": figure_vizro.to_json()},
"plotly": {"code": ai_outputs.code, "fig": figure_plotly.to_json()},
}

ai_response = "\n".join(["```python", formatted_code, "```"])
logger.info("Successful query produced.")
return create_response(ai_response, figure, user_prompt, data["filename"])
return create_response(ai_response, figure_vizro, ai_outputs=ai_code_outputs)

except Exception as exc:
logger.debug(exc)
logger.info("Chart creation failed.")
ai_response = f"Sorry, I can't do that. Following Error occurred: {exc}"
figure = go.Figure()
return create_response(ai_response, figure, user_prompt, data["filename"])
return create_response(ai_response, figure, ai_outputs=None)


@capture("action")
Expand All @@ -98,7 +134,11 @@ def data_upload_action(contents, filename):
raise PreventUpdate

if not check_file_extension(filename=filename):
return {"error_message": "Unsupported file extension.. Make sure to upload either csv or an excel file."}
return (
{"error_message": "Unsupported file extension.. Make sure to upload either csv or an excel file."},
{"color": "gray"},
{"display": "none"},
)

content_type, content_string = contents.split(",")

Expand All @@ -112,11 +152,15 @@ def data_upload_action(contents, filename):
df = pd.read_excel(io.BytesIO(decoded))

data = df.to_dict("records")
return {"data": data, "filename": filename}
return {"data": data, "filename": filename}, {"cursor": "pointer"}, {}

except Exception as e:
logger.debug(e)
return {"error_message": "There was an error processing this file."}
return (
{"error_message": "There was an error processing this file."},
{"color": "gray", "cursor": "default"},
{"display": "none"},
)


@capture("action")
Expand All @@ -127,3 +171,16 @@ def display_filename(data):

display_message = data.get("filename") or data.get("error_message")
return f"Uploaded file name: '{display_message}'" if "filename" in data else display_message


@capture("action")
def update_table(data):
"""Custom action for updating data."""
if not data:
return dash.no_update
df = pd.DataFrame(data["data"])
filename = data.get("filename") or data.get("error_message")
modal_title = f"Data sample preview for {filename} file"
df_sample = df.sample(5)
table = dbc.Table.from_dataframe(df_sample, striped=False, bordered=True, hover=True)
return table, modal_title
Loading

0 comments on commit 64523bb

Please sign in to comment.