-
Notifications
You must be signed in to change notification settings - Fork 10
Open
Description
Chatlas is a really nice tool. I don't see a way to control which arguments get passed to a tool. Some of these would come from the LLM, but some might be based on user input (e.g., a dataset dropdown). In the example below, I'd like the tool to use the result of get_data()
as its argument. Is there a way to do something like this?
chat_client.register_tool(get_data_summary, dct=get_data())
from shiny import App, ui, render, reactive
import pandas as pd
import os
import json
import requests
from pydantic import BaseModel, Field
from chatlas import ChatOpenAI
from dotenv import load_dotenv
# Load environment variables from a .env file before anything reads them;
# OPENAI_API_KEY and GEO_API_KEY are looked up with os.getenv further down.
load_dotenv()
# If this app doesn't run, check that you have selected the correct
# Python interpreter.
# You can set it by opening the Command Palette and typing
# "Python: Select Interpreter".
# There is also a button in the bottom-left corner of VS Code
# that lets you select the Python interpreter.
# To stop the app, press Ctrl+C in the terminal.
class GetCurrentTemperature(BaseModel):
    """
    Get the current weather given a latitude and longitude.
    You MUST provide both latitude and longitude as numbers.
    If the user provides a location name (not latitude/longitude),
    you MUST first call the 'get_lat_long' tool to obtain the coordinates
    then call this tool with those coordinates.
    """

    # Argument schema handed to register_tool(..., model=GetCurrentTemperature)
    # for the get_current_temperature tool. The docstring above and the field
    # descriptions below are runtime text, so their wording is left untouched.
    # NOTE(review): presumably both are forwarded to the LLM as the tool and
    # parameter descriptions -- confirm against the chatlas documentation.
    latitude: float = Field(description="The latitude of the location. Must be a float.")
    longitude: float = Field(description="The longitude of the location. Must be a float.")
def get_current_temperature(latitude: float, longitude: float) -> dict:
    """Fetch the current conditions for a coordinate from the Open-Meteo forecast API.

    Args:
        latitude: Latitude of the location.
        longitude: Longitude of the location.

    Returns:
        The ``"current"`` object of the JSON response.

    Raises:
        requests.HTTPError: If the API answers with an error status code.
        requests.Timeout: If the API does not answer within the timeout.
        KeyError: If the JSON response has no ``"current"`` entry.
    """
    lat_lng = f"latitude={latitude}&longitude={longitude}"
    # The separator in front of "current=" has to be a literal "&". The previous
    # text had the mis-decoded character sequence "¤t=" here, so the `current`
    # parameter was never sent and the "current" key was missing from the reply.
    url = (
        f"https://api.open-meteo.com/v1/forecast?{lat_lng}"
        "&current=temperature_2m,wind_speed_10m"
        "&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m"
    )
    # A timeout keeps a stalled request from blocking the chat handler forever.
    response = requests.get(url, timeout=10)
    # Surface HTTP errors directly instead of as a confusing KeyError below.
    response.raise_for_status()
    return response.json()["current"]
class GetLatLong(BaseModel):
    """
    Use this tool to get the latitude and longitude for a location name provided by the user.
    If the user asks for weather and only provides a location name, call this tool first to get coordinates.
    """

    # Argument schema handed to register_tool(..., model=GetLatLong) for the
    # get_lat_long tool. The docstring and the field description are runtime
    # text, so their wording is left untouched.
    location: str = Field(
        description="The location name to get latitude and longitude for. Must be a string.",
    )
def get_lat_long(location: str) -> dict:
    """Look up the coordinates of a place name with the geocode.maps.co search API.

    Args:
        location: Free-text place name to geocode.

    Returns:
        A dict with float values under the keys ``"latitude"`` and
        ``"longitude"``, taken from the first search result.

    Raises:
        ValueError: If the search returns no results.
        requests.HTTPError: If the API answers with an error status code.
        requests.Timeout: If the API does not answer within the timeout.
    """
    # Pass the query through `params` so the location is URL-encoded; building
    # the URL by hand let characters such as "&", "#" or spaces in the location
    # corrupt the query string. The single quotes around the name are kept so
    # the search term sent to the service is unchanged.
    # A missing GEO_API_KEY yields None, which requests omits from the query.
    response = requests.get(
        "https://geocode.maps.co/search",
        params={"q": f"'{location}'", "api_key": os.getenv("GEO_API_KEY")},
        timeout=10,
    )
    response.raise_for_status()
    matches = response.json()
    if not matches:
        raise ValueError(
            f"Could not find location: {location}. Try to determine the location from the LLM response and call this tool again with the new location."
        )
    best_match = matches[0]
    return {
        "latitude": float(best_match["lat"]),
        "longitude": float(best_match["lon"]),
    }
def get_data_summary(dct: dict = {"a": [1, 2, 3], "b": [4, 5, 6]}) -> str:
    """
    Use this tool to get summary statistics for a dictionary provided by the user.
    """
    # Builds a DataFrame from `dct`, runs DataFrame.describe(), transposes the
    # result so that each row is one variable and each column one statistic,
    # and returns it as an HTML table with the classes "table table-striped".
    #
    # Raises ValueError when the data has no columns to describe.
    #
    # The docstring and the default value are deliberately left as they were:
    # this function is registered as a tool without a separate model, so its
    # signature and docstring are presumably what the LLM sees (TODO confirm).
    # The default dict is only read, never mutated, so sharing it is harmless.
    frame = pd.DataFrame(dct)
    if frame.columns.empty:
        # describe() fails on a column-less frame with an opaque message;
        # raise the same exception type with an explanation instead.
        raise ValueError("No data to summarize: the provided data has no columns.")
    summary = (
        frame.describe()
        .T.reset_index()
        # After the transpose the former index holds the variable (column)
        # names, not the statistics, so label that column accordingly.
        .rename(columns={"index": "Variable"})
    )
    return summary.to_html(classes="table table-striped", index=False)
# Sidebar: a CSV upload control (input id "file") and a data-frame output
# (id "table") that the server fills in.
_sidebar = ui.sidebar(
    ui.input_file("file", "Upload CSV File", accept=[".csv"]),
    ui.output_data_frame("table"),
)

# Page layout: the sidebar next to a titled main area holding the chat widget
# with id "my_chat".
app_ui = ui.page_sidebar(
    _sidebar,
    ui.panel_title("Weather Assistant"),
    ui.chat_ui(id="my_chat"),
    title="Weather Assistant",
)
# Opening chat message: an intro line followed by three example prompts, each
# wrapped in a span with the classes "suggestion submit".
# Line breaks use <br>; the previous "</br>" is not valid HTML.
start_messages = "".join(
    [
        "Some example questions to ask the assistant:",
        "<br><span class='suggestion submit'>What is the lat-long for the Rady School of Management?</span>",
        "<br><span class='suggestion submit'>What is the weather like in New York City today?</span>",
        "<br><span class='suggestion submit'>Summarize the uploaded file</span>",
    ]
)
def server(input, output, session):
    """Set up the chat widget, the uploaded-data reactives and the LLM client.

    Creates the chat bound to the "my_chat" UI element, exposes the uploaded
    CSV as a DataFrame (and as a list of records), renders it in the "table"
    output, registers the three tools on a ChatOpenAI client, and streams the
    client's reply into the chat whenever the user submits a message.
    """
    chat = ui.Chat(id="my_chat", messages=[start_messages])

    @reactive.calc
    def get_dataframe():
        # No upload yet: hand back an empty frame rather than failing.
        uploads = input.file()
        if not uploads:
            return pd.DataFrame()
        # Only the first uploaded file is read.
        return pd.read_csv(uploads[0]["datapath"])

    @reactive.calc
    def get_data():
        # The uploaded data as a list of row dicts. Nothing else in this
        # function references it yet.
        return get_dataframe().to_dict(orient="records")

    @render.data_frame
    def table():
        frame = get_dataframe()
        # An empty frame is replaced by a fresh, column-less one.
        return pd.DataFrame() if frame.empty else frame

    chat_client = ChatOpenAI(
        model="gpt-4o",
        api_key=os.getenv("OPENAI_API_KEY"),
    )
    # The two lookup tools are registered with explicit argument models; the
    # summary tool is registered from its function signature alone.
    chat_client.register_tool(get_lat_long, model=GetLatLong)
    chat_client.register_tool(get_current_temperature, model=GetCurrentTemperature)
    chat_client.register_tool(get_data_summary)

    @chat.on_user_submit
    async def handle_user_input(user_input: str):
        # Forward the user's message to the model and stream the reply
        # into the chat widget.
        reply_stream = await chat_client.stream_async(user_input)
        await chat.append_message_stream(reply_stream)
# Application object combining the UI definition with the server function.
app = App(app_ui, server)
Metadata
Metadata
Assignees
Labels
No labels