Skip to content

Chatlas in a shinyapp: How to specify an uploaded file as the tool argument? #104

@vnijs

Description

@vnijs

Chatlas is a really nice tool. However, I don't see a way to control which arguments get passed to a tool. Some of them would come from the LLM, but others might be based on user input (e.g., a dataset dropdown). In the example below, I'd like the tool to use the result of `get_data()` as its argument. Is there a way to do something like this?

chat_client.register_tool(get_data_summary, dct=get_data())

import json
import os
from typing import Optional

import pandas as pd
import requests
from chatlas import ChatOpenAI
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from shiny import App, ui, render, reactive

# Load environment variables (python-dotenv reads a local .env file if present)
# so the os.getenv() calls for OPENAI_API_KEY and GEO_API_KEY can find them.
load_dotenv()

# If this app doesn't run, check that you have selected the correct
# Python interpreter.
# You can set it from the Command Palette by typing
# "Python: Select Interpreter".
# You can also click the interpreter button in the VS Code status bar
# to select the Python interpreter.

# to stop the app, press Ctrl+C in the terminal


# Argument schema for the get_current_temperature tool; it is passed as
# `model=` to `chat_client.register_tool` in `server`. The docstring and the
# Field descriptions are presumably forwarded to the LLM as the tool's
# description, so treat them as runtime prompt text, not just documentation.
# NOTE(review): if chatlas names the tool after this model ("GetCurrentTemperature")
# rather than after the function, the docstring's reference to 'get_lat_long'
# may not match the registered tool name -- confirm against chatlas docs.
class GetCurrentTemperature(BaseModel):
    """
    Get the current weather given a latitude and longitude.
    You MUST provide both latitude and longitude as numbers.
    If the user provides a location name (not latitude/longitude),
    you MUST first call the 'get_lat_long' tool to obtain the coordinates
    then call this tool with those coordinates.
    """

    # Presumably decimal degrees; the range is not validated here.
    latitude: float = Field(
        description="The latitude of the location. Must be a float."
    )
    # Presumably decimal degrees; the range is not validated here.
    longitude: float = Field(
        description="The longitude of the location. Must be a float."
    )


def get_current_temperature(latitude: float, longitude: float) -> dict:
    """Return current conditions for a coordinate from the Open-Meteo API.

    Args:
        latitude: Latitude of the location.
        longitude: Longitude of the location.

    Returns:
        The ``current`` section of the Open-Meteo forecast response; the
        request asks for ``temperature_2m`` and ``wind_speed_10m``.

    Raises:
        requests.HTTPError: If Open-Meteo answers with an error status.
        requests.Timeout: If Open-Meteo does not answer in time.
    """
    # Only the "current" block is returned, so hourly data is not requested.
    params = {
        "latitude": latitude,
        "longitude": longitude,
        "current": "temperature_2m,wind_speed_10m",
    }
    # Without a timeout, a stalled upstream would hang the chat response forever.
    response = requests.get(
        "https://api.open-meteo.com/v1/forecast", params=params, timeout=10
    )
    # Report HTTP failures directly instead of a misleading KeyError on "current".
    response.raise_for_status()
    return response.json()["current"]


# Argument schema for the get_lat_long tool; it is passed as `model=` to
# `chat_client.register_tool` in `server`. The docstring and the Field
# description are presumably forwarded to the LLM as the tool's description,
# so treat them as runtime prompt text, not just documentation.
class GetLatLong(BaseModel):
    """
    Use this tool to get the latitude and longitude for a location name provided by the user.
    If the user asks for weather and only provides a location name, call this tool first to get coordinates.
    """

    # Free-text place name; get_lat_long sends it as the geocoding search query.
    location: str = Field(
        description="The location name to get latitude and longitude for. Must be a string."
    )


def get_lat_long(location: str) -> dict:
    """Geocode a place name with the geocode.maps.co search API.

    Args:
        location: Free-text place name, e.g. "New York City".

    Returns:
        ``{"latitude": float, "longitude": float}`` taken from the first
        search result.

    Raises:
        ValueError: If the search returns no results. The message is written
            as an instruction to the LLM, which presumably receives it when the
            tool call fails.
        requests.HTTPError: If the geocoding service answers with an error status.
        requests.Timeout: If the geocoding service does not answer in time.
    """
    # Pass the query through `params` so requests URL-encodes it; a name
    # containing "&" or "#" would otherwise corrupt the query string.
    response = requests.get(
        "https://geocode.maps.co/search",
        params={"q": location, "api_key": os.getenv("GEO_API_KEY")},
        timeout=10,
    )
    response.raise_for_status()
    results = response.json()
    if not results:
        raise ValueError(
            f"Could not find location: {location}. Try to determine the location from the LLM response and call this tool again with the new location."
        )
    best = results[0]
    return {"latitude": float(best["lat"]), "longitude": float(best["lon"])}


def get_data_summary(dct: Optional[dict] = None) -> str:
    """
    Use this tool to get summary statistics for a dictionary provided by the user.
    """
    # Developer notes are kept out of the docstring above, which presumably
    # becomes the tool description the LLM sees.
    #   dct: column-oriented data ({"column": [values, ...]}) or any other
    #        input pd.DataFrame accepts (e.g. a list of row dicts). None uses
    #        a small built-in example, the same data as the old default.
    #   Returns: an HTML table with one row per described column of `dct`
    #        and one column per statistic from DataFrame.describe().
    #   Raises: ValueError when the data has no columns to describe.
    if dct is None:
        # Build the example per call instead of sharing a mutable default.
        dct = {"a": [1, 2, 3], "b": [4, 5, 6]}
    frame = pd.DataFrame(dct)
    if frame.columns.empty:
        # describe() would fail here with a generic pandas message.
        raise ValueError(
            "There is no data to summarize. Ask the user to upload a CSV file "
            "and call this tool again."
        )
    summary = (
        frame.describe()
        # After transposing, rows are the input columns and columns are the
        # statistics, so the former index holds variable names.
        .T.reset_index()
        .rename(columns={"index": "Variable"})
    )
    return summary.to_html(classes="table table-striped", index=False)


# Layout: a sidebar holding the CSV upload and a preview table of the
# uploaded data, next to the chat panel.
app_ui = ui.page_sidebar(
    ui.sidebar(
        ui.input_file("file", "Upload CSV File", accept=[".csv"]),
        ui.output_data_frame("table"),
    ),
    ui.panel_title("Weather Assistant"),
    ui.chat_ui(id="my_chat"),
    title="Weather Assistant",
)

# Example prompts shown in the chat's opening message. The "suggestion submit"
# classes presumably make Shiny's chat UI send a prompt when it is clicked;
# confirm against the Shiny chat documentation.
_SUGGESTED_QUESTIONS = (
    "What is the lat-long for the Rady School of Management?",
    "What is the weather like in New York City today?",
    "Summarize the uploaded file",
)
# `<br>` is the valid line-break tag; `</br>` is an HTML parse error.
start_messages = "Some example questions to ask the assistant:" + "".join(
    f"<br><span class='suggestion submit'>{question}</span>"
    for question in _SUGGESTED_QUESTIONS
)


def server(input, output, session):
    """Shiny server function: wires the CSV upload, preview table and chat.

    A new chat client (with its own conversation history) is created on every
    call, so presumably each browser session gets an independent conversation.
    """
    chat = ui.Chat(id="my_chat", messages=[start_messages])

    @reactive.calc
    def get_dataframe():
        """Return the uploaded CSV as a DataFrame, or an empty one before any upload."""
        if not input.file():
            return pd.DataFrame()

        # input.file() is a list of upload records; "datapath" points at the
        # server-side copy of the first uploaded file.
        file_path = input.file()[0]["datapath"]
        return pd.read_csv(file_path)

    @reactive.calc
    def get_data():
        """Return the uploaded data column-wise: {column_name: [values, ...]}."""
        # Column-oriented, matching the `dct: dict` shape get_data_summary takes.
        return get_dataframe().to_dict(orient="list")

    @render.data_frame
    def table():
        """Preview the uploaded data in the sidebar."""
        df = get_dataframe()
        if df.empty:
            return pd.DataFrame()
        return df

    chat_client = ChatOpenAI(
        model="gpt-4o",
        api_key=os.getenv("OPENAI_API_KEY"),
    )

    def summarize_uploaded_file() -> str:
        """
        Use this tool to get summary statistics for the CSV file the user uploaded.
        It takes no arguments: the data is read directly from the app.
        """
        # This closure binds app state (the upload) to the tool, so the LLM
        # never has to supply the data itself. isolate() allows reading the
        # reactive calc even if chatlas invokes the tool outside a reactive
        # context (e.g. while the response stream is consumed).
        with reactive.isolate():
            data = get_data()
        if not data:
            raise ValueError(
                "No file has been uploaded yet. Ask the user to upload a CSV "
                "file, then call this tool again."
            )
        return get_data_summary(data)

    chat_client.register_tool(get_lat_long, model=GetLatLong)
    chat_client.register_tool(get_current_temperature, model=GetCurrentTemperature)
    # Register the closure rather than get_data_summary itself; otherwise the
    # LLM would have to provide (i.e. invent) the data as the `dct` argument.
    chat_client.register_tool(summarize_uploaded_file)

    @chat.on_user_submit
    async def handle_user_input(user_input: str):
        """Send the user's message to the LLM and stream the reply into the chat."""
        response = await chat_client.stream_async(user_input)
        await chat.append_message_stream(response)


# The Shiny application object: combines the UI layout (`app_ui`) with the
# `server` function defined above.
app = App(app_ui, server)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions