Skip to content

Commit

Permalink
Merge pull request #115 from rgbkrk/gorilla-functions
Browse files Browse the repository at this point in the history
Custom Servers -- allow passing through base url and api key
  • Loading branch information
rgbkrk authored Nov 17, 2023
2 parents c2a7b9a + 93073a9 commit 0bcaa01
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 3 deletions.
14 changes: 11 additions & 3 deletions chatlab/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ class Chat:
def __init__(
self,
*initial_context: Union[ChatCompletionMessageParam, str],
base_url=None,
api_key=None,
model="gpt-3.5-turbo-0613",
function_registry: Optional[FunctionRegistry] = None,
chat_functions: Optional[List[Callable]] = None,
Expand All @@ -85,8 +87,8 @@ def __init__(
"""
# Sometimes people set the API key with an environment variables and sometimes
# they set it on the openai module. We'll check both.
openai_api_key = os.getenv("OPENAI_API_KEY") or openai.api_key
if openai_api_key is None:
openai_api_key = api_key or os.getenv("OPENAI_API_KEY") or openai.api_key
if openai_api_key is None or not isinstance(openai_api_key, str):
raise ChatLabError(
"You must set the environment variable `OPENAI_API_KEY` to use this package.\n"
"This key allows chatlab to communicate with OpenAI servers.\n\n"
Expand All @@ -97,6 +99,9 @@ def __init__(
else:
pass

self.api_key = openai_api_key
self.base_url = base_url

if initial_context is None:
initial_context = [] # type: ignore

Expand Down Expand Up @@ -231,7 +236,10 @@ async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream
full_messages.append(message)

try:
client = AsyncOpenAI()
client = AsyncOpenAI(
api_key=self.api_key,
base_url=self.base_url,
)

api_manifest = self.function_registry.api_manifest()

Expand Down
72 changes: 72 additions & 0 deletions notebooks/open-model-functions.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"sql.query(\"select count(*) from products order by count desc limit 10\")"
],
"text/plain": [
"sql.query(\"select count(*) from products order by count desc limit 10\")"
]
},
"metadata": {
"text/markdown": {
"chatlab": {
"default": true
}
}
},
"output_type": "display_data"
}
],
"source": [
"from chatlab import Chat, system\n",
"\n",
"chat = Chat(\n",
" system(\"You are a data engineer\"),\n",
" model=\"gorilla-openfunctions-v0\",\n",
" base_url=\"http://luigi.millennium.berkeley.edu:8000/v1\",\n",
" api_key=\"EMPTY\",\n",
")\n",
"\n",
"\n",
"def sql(query: str):\n",
" \"\"\"Runs SQL query\"\"\"\n",
" # Totally fake, returns an empty table for demonstration\n",
"\n",
" return []\n",
"\n",
"\n",
"chat.register(sql)\n",
"\n",
"await chat(\"Show me the top 10 most popular products\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "chatlab-3kMKfU-i-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 0bcaa01

Please sign in to comment.