Skip to content

Commit

Permalink
automatically detect pydantic models and instantiate them
Browse files Browse the repository at this point in the history
  • Loading branch information
rgbkrk committed Nov 30, 2023
1 parent 45a8fd5 commit 7acb147
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions chatlab/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,9 +419,6 @@ async def call(self, name: str, arguments: Optional[str] = None) -> Any:
if name is None:
raise UnknownFunctionError("Function name must be provided")

function = self.get(name)
parameters: dict = {}

# Handle the code interpreter hallucination
if name == "python" and self.python_hallucination_function is not None:
function = self.python_hallucination_function
Expand All @@ -433,24 +430,39 @@ async def call(self, name: str, arguments: Optional[str] = None) -> Any:
if asyncio.iscoroutinefunction(function):
return await function(arguments)
return function(arguments)
elif function is None:

possible_function = self.get(name)

if possible_function is None:
raise UnknownFunctionError(f"Function {name} is not registered")
elif arguments is None or arguments == "":
parameters = {}
else:

function = possible_function

parameters: dict = {}

if arguments is not None:
try:
parameters = json.loads(arguments)
# TODO: Validate parameters against schema
except json.JSONDecodeError:
raise FunctionArgumentError(f"Invalid Function call on {name}. Arguments must be a valid JSON object")

if function is None:
raise UnknownFunctionError(f"Function {name} is not registered")
prepared_arguments = {}

for param_name, param in inspect.signature(function).parameters.items():
param_type = param.annotation
arg_value = parameters.get(param_name)

# Check if parameter type is a subclass of BaseModel and deserialize JSON into Pydantic model
if issubclass(param_type, BaseModel):
prepared_arguments[param_name] = param_type.model_validate(arg_value)
else:
prepared_arguments[param_name] = arg_value


if asyncio.iscoroutinefunction(function):
result = await function(**parameters)
result = await function(**prepared_arguments)
else:
result = function(**parameters)
result = function(**prepared_arguments)
return result

def __contains__(self, name) -> bool:
Expand Down

0 comments on commit 7acb147

Please sign in to comment.