Skip to content

Commit 9051bb7

Browse files
committed
adapt to newer OpenAI FunctionCall types
1 parent 5866e9c commit 9051bb7

File tree

2 files changed

+37
-25
lines changed

2 files changed

+37
-25
lines changed

chatlab/chat.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
import asyncio
1515
import logging
1616
import os
17-
from typing import AsyncIterator, Callable, List, Optional, Tuple, Type, Union, overload
17+
from typing import Callable, List, Optional, Tuple, Type, Union, overload
1818

1919
import openai
2020
from deprecation import deprecated
2121
from IPython.core.async_helpers import get_asyncio_loop
22-
from openai import AsyncOpenAI
22+
from openai import AsyncOpenAI, AsyncStream
23+
from openai.types import FunctionDefinition
2324
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam
2425
from pydantic import BaseModel
2526

@@ -29,7 +30,7 @@
2930
from .display import ChatFunctionCall
3031
from .errors import ChatLabError
3132
from .messaging import human
32-
from .registry import FunctionRegistry, FunctionSchema, PythonHallucinationFunction
33+
from .registry import FunctionRegistry, PythonHallucinationFunction
3334
from .views.assistant import AssistantMessageView
3435

3536
logger = logging.getLogger(__name__)
@@ -138,7 +139,7 @@ async def __call__(self, *messages: Union[ChatCompletionMessageParam, str], stre
138139
return await self.submit(*messages, stream=stream, **kwargs)
139140

140141
async def __process_stream(
141-
self, resp: AsyncIterator[ChatCompletionChunk]
142+
self, resp: AsyncStream[ChatCompletionChunk]
142143
) -> Tuple[str, Optional[AssistantFunctionCallView]]:
143144
assistant_view: AssistantMessageView = AssistantMessageView()
144145
function_view: Optional[AssistantFunctionCallView] = None
@@ -232,15 +233,15 @@ async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream
232233
try:
233234
client = AsyncOpenAI()
234235

235-
manifest = self.function_registry.api_manifest()
236+
api_manifest = self.function_registry.api_manifest()
236237

237238
# Due to the strict response typing based on `Literal` typing on `stream`, we have to process these
238239
# two cases separately
239240
if stream:
240241
streaming_response = await client.chat.completions.create(
241242
model=self.model,
242243
messages=full_messages,
243-
**manifest,
244+
**api_manifest,
244245
stream=True,
245246
temperature=kwargs.get("temperature", 0),
246247
)
@@ -250,7 +251,7 @@ async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream
250251
full_response = await client.chat.completions.create(
251252
model=self.model,
252253
messages=full_messages,
253-
**manifest,
254+
**api_manifest,
254255
stream=False,
255256
temperature=kwargs.get("temperature", 0),
256257
)
@@ -333,14 +334,14 @@ def register(
333334
self,
334335
function: Callable,
335336
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
336-
) -> FunctionSchema:
337+
) -> FunctionDefinition:
337338
...
338339

339340
def register(
340341
self,
341342
function: Optional[Callable] = None,
342343
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
343-
) -> Union[Callable, FunctionSchema]:
344+
) -> Union[Callable, FunctionDefinition]:
344345
"""Register a function with the ChatLab instance.
345346
346347
This can be used as a decorator like so:

chatlab/registry.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ class WhatTime(BaseModel):
4444
import json
4545
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypedDict, Union, get_args, get_origin, overload
4646

47-
from openai.types.chat.completion_create_params import Function as FunctionSchema
48-
from openai.types.chat.completion_create_params import FunctionCall as FunctionCallOption
47+
from openai.types import FunctionDefinition
48+
from openai.types.chat.completion_create_params import Function, FunctionCall
4949
from pydantic import BaseModel, create_model
5050

5151
from .decorators import ChatlabMetadata
@@ -54,10 +54,10 @@ class WhatTime(BaseModel):
5454
class APIManifest(TypedDict, total=False):
5555
"""The schema for the API."""
5656

57-
functions: List[FunctionSchema]
57+
functions: List[Function]
5858
"""A list of functions that the model can call during the conversation."""
5959

60-
function_call: FunctionCallOption
60+
function_call: FunctionCall
6161
"""The policy for when to call functions.
6262
6363
One of "auto", "none", or a dictionary with a "name" key.
@@ -110,7 +110,7 @@ class FunctionSchemaConfig:
110110
def generate_function_schema(
111111
function: Callable,
112112
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
113-
) -> FunctionSchema:
113+
) -> FunctionDefinition:
114114
"""Generate a function schema for sending to OpenAI."""
115115
doc = function.__doc__
116116
func_name = function.__name__
@@ -122,7 +122,7 @@ def generate_function_schema(
122122
if not doc:
123123
raise Exception("Only functions with docstrings can be registered")
124124

125-
schema = FunctionSchema(
125+
schema = FunctionDefinition(
126126
name=func_name,
127127
description=doc,
128128
parameters={},
@@ -184,14 +184,23 @@ def generate_function_schema(
184184
if "required" not in parameters:
185185
parameters["required"] = []
186186

187-
schema["parameters"] = parameters
187+
schema.parameters = parameters
188188
return schema
189189

190190

191191
# Declare the type for the python hallucination
192192
PythonHallucinationFunction = Callable[[str], Any]
193193

194194

195+
def adapt_function_definition(fd: FunctionDefinition) -> Function:
196+
"""Adapt a FunctionDefinition to a Function for working with the OpenAI API."""
197+
return {
198+
"name": fd.name,
199+
"parameters": fd.parameters,
200+
"description": fd.description if fd.description is not None else "",
201+
}
202+
203+
195204
class FunctionRegistry:
196205
"""Registry of functions and their schemas for calling them.
197206
@@ -229,7 +238,7 @@ class WhatTime(BaseModel):
229238
"""
230239

231240
__functions: dict[str, Callable]
232-
__schemas: dict[str, FunctionSchema]
241+
__schemas: dict[str, FunctionDefinition]
233242

234243
# Allow passing in a callable that accepts a single string for the python
235244
# hallucination function. This is useful for testing.
@@ -265,14 +274,14 @@ def register(
265274
self,
266275
function: Callable,
267276
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
268-
) -> FunctionSchema:
277+
) -> FunctionDefinition:
269278
...
270279

271280
def register(
272281
self,
273282
function: Optional[Callable] = None,
274283
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
275-
) -> Union[Callable, FunctionSchema]:
284+
) -> Union[Callable, FunctionDefinition]:
276285
"""Register a function for use in `Chat`s. Can be used as a decorator or directly to register a function.
277286
278287
>>> registry = FunctionRegistry()
@@ -303,7 +312,7 @@ def register_function(
303312
self,
304313
function: Callable,
305314
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
306-
) -> FunctionSchema:
315+
) -> FunctionDefinition:
307316
"""Register a single function."""
308317
final_schema = generate_function_schema(function, parameter_schema)
309318

@@ -327,7 +336,7 @@ def get(self, function_name) -> Optional[Callable]:
327336

328337
return self.__functions.get(function_name)
329338

330-
def get_schema(self, function_name) -> Optional[FunctionSchema]:
339+
def get_schema(self, function_name) -> Optional[FunctionDefinition]:
331340
"""Get a function schema by name."""
332341
return self.__schemas.get(function_name)
333342

@@ -341,7 +350,7 @@ def get_chatlab_metadata(self, function_name) -> ChatlabMetadata:
341350
chatlab_metadata = getattr(function, "chatlab_metadata", ChatlabMetadata())
342351
return chatlab_metadata
343352

344-
def api_manifest(self, function_call_option: FunctionCallOption = "auto") -> APIManifest:
353+
def api_manifest(self, function_call_option: FunctionCall = "auto") -> APIManifest:
345354
"""Get a dictionary containing function definitions and calling options.
346355
347356
This is designed to be used with OpenAI's Chat Completion API, where the
@@ -394,12 +403,14 @@ def api_manifest(self, function_call_option: FunctionCallOption = "auto") -> API
394403
stream=True,
395404
)
396405
"""
397-
if len(self.function_definitions) == 0:
406+
function_definitions = [adapt_function_definition(f) for f in self.__schemas.values()]
407+
408+
if len(function_definitions) == 0:
398409
# When there are no functions, we can't send an empty functions array to OpenAI
399410
return {}
400411

401412
return {
402-
"functions": self.function_definitions,
413+
"functions": function_definitions,
403414
"function_call": function_call_option,
404415
}
405416

@@ -449,6 +460,6 @@ def __contains__(self, name) -> bool:
449460
return name in self.__functions
450461

451462
@property
452-
def function_definitions(self) -> list[FunctionSchema]:
463+
def function_definitions(self) -> list[FunctionDefinition]:
453464
"""Get a list of function definitions."""
454465
return list(self.__schemas.values())

0 commit comments

Comments
 (0)