@@ -44,8 +44,8 @@ class WhatTime(BaseModel):
44
44
import json
45
45
from typing import Any , Callable , Dict , Iterable , List , Optional , Type , TypedDict , Union , get_args , get_origin , overload
46
46
47
- from openai .types . chat . completion_create_params import Function as FunctionSchema
48
- from openai .types .chat .completion_create_params import FunctionCall as FunctionCallOption
47
+ from openai .types import FunctionDefinition
48
+ from openai .types .chat .completion_create_params import Function , FunctionCall
49
49
from pydantic import BaseModel , create_model
50
50
51
51
from .decorators import ChatlabMetadata
@@ -54,10 +54,10 @@ class WhatTime(BaseModel):
54
54
class APIManifest (TypedDict , total = False ):
55
55
"""The schema for the API."""
56
56
57
- functions : List [FunctionSchema ]
57
+ functions : List [Function ]
58
58
"""A list of functions that the model can call during the conversation."""
59
59
60
- function_call : FunctionCallOption
60
+ function_call : FunctionCall
61
61
"""The policy for when to call functions.
62
62
63
63
One of "auto", "none", or a dictionary with a "name" key.
@@ -110,7 +110,7 @@ class FunctionSchemaConfig:
110
110
def generate_function_schema (
111
111
function : Callable ,
112
112
parameter_schema : Optional [Union [Type ["BaseModel" ], dict ]] = None ,
113
- ) -> FunctionSchema :
113
+ ) -> FunctionDefinition :
114
114
"""Generate a function schema for sending to OpenAI."""
115
115
doc = function .__doc__
116
116
func_name = function .__name__
@@ -122,7 +122,7 @@ def generate_function_schema(
122
122
if not doc :
123
123
raise Exception ("Only functions with docstrings can be registered" )
124
124
125
- schema = FunctionSchema (
125
+ schema = FunctionDefinition (
126
126
name = func_name ,
127
127
description = doc ,
128
128
parameters = {},
@@ -184,14 +184,23 @@ def generate_function_schema(
184
184
if "required" not in parameters :
185
185
parameters ["required" ] = []
186
186
187
- schema [ " parameters" ] = parameters
187
+ schema . parameters = parameters
188
188
return schema
189
189
190
190
191
191
# Declare the type for the python hallucination
192
192
PythonHallucinationFunction = Callable [[str ], Any ]
193
193
194
194
195
+ def adapt_function_definition (fd : FunctionDefinition ) -> Function :
196
+ """Adapt a FunctionDefinition to a Function for working with the OpenAI API."""
197
+ return {
198
+ "name" : fd .name ,
199
+ "parameters" : fd .parameters ,
200
+ "description" : fd .description if fd .description is not None else "" ,
201
+ }
202
+
203
+
195
204
class FunctionRegistry :
196
205
"""Registry of functions and their schemas for calling them.
197
206
@@ -229,7 +238,7 @@ class WhatTime(BaseModel):
229
238
"""
230
239
231
240
__functions : dict [str , Callable ]
232
- __schemas : dict [str , FunctionSchema ]
241
+ __schemas : dict [str , FunctionDefinition ]
233
242
234
243
# Allow passing in a callable that accepts a single string for the python
235
244
# hallucination function. This is useful for testing.
@@ -265,14 +274,14 @@ def register(
265
274
self ,
266
275
function : Callable ,
267
276
parameter_schema : Optional [Union [Type ["BaseModel" ], dict ]] = None ,
268
- ) -> FunctionSchema :
277
+ ) -> FunctionDefinition :
269
278
...
270
279
271
280
def register (
272
281
self ,
273
282
function : Optional [Callable ] = None ,
274
283
parameter_schema : Optional [Union [Type ["BaseModel" ], dict ]] = None ,
275
- ) -> Union [Callable , FunctionSchema ]:
284
+ ) -> Union [Callable , FunctionDefinition ]:
276
285
"""Register a function for use in `Chat`s. Can be used as a decorator or directly to register a function.
277
286
278
287
>>> registry = FunctionRegistry()
@@ -303,7 +312,7 @@ def register_function(
303
312
self ,
304
313
function : Callable ,
305
314
parameter_schema : Optional [Union [Type ["BaseModel" ], dict ]] = None ,
306
- ) -> FunctionSchema :
315
+ ) -> FunctionDefinition :
307
316
"""Register a single function."""
308
317
final_schema = generate_function_schema (function , parameter_schema )
309
318
@@ -327,7 +336,7 @@ def get(self, function_name) -> Optional[Callable]:
327
336
328
337
return self .__functions .get (function_name )
329
338
330
- def get_schema (self , function_name ) -> Optional [FunctionSchema ]:
339
+ def get_schema (self , function_name ) -> Optional [FunctionDefinition ]:
331
340
"""Get a function schema by name."""
332
341
return self .__schemas .get (function_name )
333
342
@@ -341,7 +350,7 @@ def get_chatlab_metadata(self, function_name) -> ChatlabMetadata:
341
350
chatlab_metadata = getattr (function , "chatlab_metadata" , ChatlabMetadata ())
342
351
return chatlab_metadata
343
352
344
- def api_manifest (self , function_call_option : FunctionCallOption = "auto" ) -> APIManifest :
353
+ def api_manifest (self , function_call_option : FunctionCall = "auto" ) -> APIManifest :
345
354
"""Get a dictionary containing function definitions and calling options.
346
355
347
356
This is designed to be used with OpenAI's Chat Completion API, where the
@@ -394,12 +403,14 @@ def api_manifest(self, function_call_option: FunctionCallOption = "auto") -> API
394
403
stream=True,
395
404
)
396
405
"""
397
- if len (self .function_definitions ) == 0 :
406
+ function_definitions = [adapt_function_definition (f ) for f in self .__schemas .values ()]
407
+
408
+ if len (function_definitions ) == 0 :
398
409
# When there are no functions, we can't send an empty functions array to OpenAI
399
410
return {}
400
411
401
412
return {
402
- "functions" : self . function_definitions ,
413
+ "functions" : function_definitions ,
403
414
"function_call" : function_call_option ,
404
415
}
405
416
@@ -449,6 +460,6 @@ def __contains__(self, name) -> bool:
449
460
return name in self .__functions
450
461
451
462
@property
452
- def function_definitions (self ) -> list [FunctionSchema ]:
463
+ def function_definitions (self ) -> list [FunctionDefinition ]:
453
464
"""Get a list of function definitions."""
454
465
return list (self .__schemas .values ())
0 commit comments