@@ -42,7 +42,21 @@ class WhatTime(BaseModel):
42
42
import asyncio
43
43
import inspect
44
44
import json
45
- from typing import Any , Callable , Dict , Iterable , List , Optional , Type , TypedDict , Union , get_args , get_origin , overload
45
+ from typing import (
46
+ Any ,
47
+ Callable ,
48
+ Dict ,
49
+ Iterable ,
50
+ List ,
51
+ Optional ,
52
+ Type ,
53
+ TypedDict ,
54
+ Union ,
55
+ cast ,
56
+ get_args ,
57
+ get_origin ,
58
+ overload ,
59
+ )
46
60
47
61
from openai .types import FunctionDefinition
48
62
from openai .types .chat .completion_create_params import Function , FunctionCall
@@ -419,9 +433,6 @@ async def call(self, name: str, arguments: Optional[str] = None) -> Any:
419
433
if name is None :
420
434
raise UnknownFunctionError ("Function name must be provided" )
421
435
422
- function = self .get (name )
423
- parameters : dict = {}
424
-
425
436
# Handle the code interpreter hallucination
426
437
if name == "python" and self .python_hallucination_function is not None :
427
438
function = self .python_hallucination_function
@@ -433,24 +444,38 @@ async def call(self, name: str, arguments: Optional[str] = None) -> Any:
433
444
if asyncio .iscoroutinefunction (function ):
434
445
return await function (arguments )
435
446
return function (arguments )
436
- elif function is None :
447
+
448
+ possible_function = self .get (name )
449
+
450
+ if possible_function is None :
437
451
raise UnknownFunctionError (f"Function { name } is not registered" )
438
- elif arguments is None or arguments == "" :
439
- parameters = {}
440
- else :
452
+
453
+ function = possible_function
454
+
455
+ parameters : dict = {}
456
+
457
+ if arguments is not None :
441
458
try :
442
459
parameters = json .loads (arguments )
443
- # TODO: Validate parameters against schema
444
460
except json .JSONDecodeError :
445
461
raise FunctionArgumentError (f"Invalid Function call on { name } . Arguments must be a valid JSON object" )
446
462
447
- if function is None :
448
- raise UnknownFunctionError (f"Function { name } is not registered" )
463
+ prepared_arguments = {}
464
+
465
+ for param_name , param in inspect .signature (function ).parameters .items ():
466
+ param_type = param .annotation
467
+ arg_value = parameters .get (param_name )
468
+
469
+ # Check if parameter type is a subclass of BaseModel and deserialize JSON into Pydantic model
470
+ if inspect .isclass (param_type ) and issubclass (param_type , BaseModel ):
471
+ prepared_arguments [param_name ] = param_type .model_validate (arg_value )
472
+ else :
473
+ prepared_arguments [param_name ] = cast (Any , arg_value )
449
474
450
475
if asyncio .iscoroutinefunction (function ):
451
- result = await function (** parameters )
476
+ result = await function (** prepared_arguments )
452
477
else :
453
- result = function (** parameters )
478
+ result = function (** prepared_arguments )
454
479
return result
455
480
456
481
def __contains__ (self , name ) -> bool :
0 commit comments