From 93073a9c0659493584c9510ce3ff3bed999511e9 Mon Sep 17 00:00:00 2001 From: Kyle Kelley Date: Thu, 16 Nov 2023 21:38:54 -0800 Subject: [PATCH] allow passing through base url and api key for custom servers --- chatlab/chat.py | 14 ++++-- notebooks/open-model-functions.ipynb | 72 ++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 3 deletions(-) create mode 100644 notebooks/open-model-functions.ipynb diff --git a/chatlab/chat.py b/chatlab/chat.py index 1edd0d1..7e35077 100644 --- a/chatlab/chat.py +++ b/chatlab/chat.py @@ -69,6 +69,8 @@ class Chat: def __init__( self, *initial_context: Union[ChatCompletionMessageParam, str], + base_url=None, + api_key=None, model="gpt-3.5-turbo-0613", function_registry: Optional[FunctionRegistry] = None, chat_functions: Optional[List[Callable]] = None, @@ -85,8 +87,8 @@ def __init__( """ # Sometimes people set the API key with an environment variables and sometimes # they set it on the openai module. We'll check both. - openai_api_key = os.getenv("OPENAI_API_KEY") or openai.api_key - if openai_api_key is None: + openai_api_key = api_key or os.getenv("OPENAI_API_KEY") or openai.api_key + if openai_api_key is None or not isinstance(openai_api_key, str): raise ChatLabError( "You must set the environment variable `OPENAI_API_KEY` to use this package.\n" "This key allows chatlab to communicate with OpenAI servers.\n\n" @@ -97,6 +99,9 @@ def __init__( else: pass + self.api_key = openai_api_key + self.base_url = base_url + if initial_context is None: initial_context = [] # type: ignore @@ -231,7 +236,10 @@ async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream full_messages.append(message) try: - client = AsyncOpenAI() + client = AsyncOpenAI( + api_key=self.api_key, + base_url=self.base_url, + ) api_manifest = self.function_registry.api_manifest() diff --git a/notebooks/open-model-functions.ipynb b/notebooks/open-model-functions.ipynb new file mode 100644 index 0000000..0579b0c --- /dev/null +++ b/notebooks/open-model-functions.ipynb @@ -0,0 +1,72 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "sql.query(\"select count(*) from products order by count desc limit 10\")" + ], + "text/plain": [ + "sql.query(\"select count(*) from products order by count desc limit 10\")" + ] + }, + "metadata": { + "text/markdown": { + "chatlab": { + "default": true + } + } + }, + "output_type": "display_data" + } + ], + "source": [ + "from chatlab import Chat, system\n", + "\n", + "chat = Chat(\n", + " system(\"You are a data engineer\"),\n", + " model=\"gorilla-openfunctions-v0\",\n", + " base_url=\"http://luigi.millennium.berkeley.edu:8000/v1\",\n", + " api_key=\"EMPTY\",\n", + ")\n", + "\n", + "\n", + "def sql(query: str):\n", + " \"\"\"Runs SQL query\"\"\"\n", + " # Totally fake, returns an empty table for demonstration\n", + "\n", + " return []\n", + "\n", + "\n", + "chat.register(sql)\n", + "\n", + "await chat(\"Show me the top 10 most popular products\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "chatlab-3kMKfU-i-py3.11", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}