Skip to content

Commit 1ffc573

Browse files
authored
Merge pull request #119 from rgbkrk/automatic-model-creation
automatically detect pydantic models and instantiate them
2 parents 45a8fd5 + f7b2383 commit 1ffc573

File tree

3 files changed

+371
-429
lines changed

3 files changed

+371
-429
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
- ⏃ Automatically instantiate pydantic models when they are passed as parameters to a function call
11+
1012
## [1.1.1]
1113

1214
- Support setting a custom `base_url`

chatlab/registry.py

+38-13
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,21 @@ class WhatTime(BaseModel):
4242
import asyncio
4343
import inspect
4444
import json
45-
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypedDict, Union, get_args, get_origin, overload
45+
from typing import (
46+
Any,
47+
Callable,
48+
Dict,
49+
Iterable,
50+
List,
51+
Optional,
52+
Type,
53+
TypedDict,
54+
Union,
55+
cast,
56+
get_args,
57+
get_origin,
58+
overload,
59+
)
4660

4761
from openai.types import FunctionDefinition
4862
from openai.types.chat.completion_create_params import Function, FunctionCall
@@ -419,9 +433,6 @@ async def call(self, name: str, arguments: Optional[str] = None) -> Any:
419433
if name is None:
420434
raise UnknownFunctionError("Function name must be provided")
421435

422-
function = self.get(name)
423-
parameters: dict = {}
424-
425436
# Handle the code interpreter hallucination
426437
if name == "python" and self.python_hallucination_function is not None:
427438
function = self.python_hallucination_function
@@ -433,24 +444,38 @@ async def call(self, name: str, arguments: Optional[str] = None) -> Any:
433444
if asyncio.iscoroutinefunction(function):
434445
return await function(arguments)
435446
return function(arguments)
436-
elif function is None:
447+
448+
possible_function = self.get(name)
449+
450+
if possible_function is None:
437451
raise UnknownFunctionError(f"Function {name} is not registered")
438-
elif arguments is None or arguments == "":
439-
parameters = {}
440-
else:
452+
453+
function = possible_function
454+
455+
parameters: dict = {}
456+
457+
if arguments is not None:
441458
try:
442459
parameters = json.loads(arguments)
443-
# TODO: Validate parameters against schema
444460
except json.JSONDecodeError:
445461
raise FunctionArgumentError(f"Invalid Function call on {name}. Arguments must be a valid JSON object")
446462

447-
if function is None:
448-
raise UnknownFunctionError(f"Function {name} is not registered")
463+
prepared_arguments = {}
464+
465+
for param_name, param in inspect.signature(function).parameters.items():
466+
param_type = param.annotation
467+
arg_value = parameters.get(param_name)
468+
469+
# Check if parameter type is a subclass of BaseModel and deserialize JSON into Pydantic model
470+
if inspect.isclass(param_type) and issubclass(param_type, BaseModel):
471+
prepared_arguments[param_name] = param_type.model_validate(arg_value)
472+
else:
473+
prepared_arguments[param_name] = cast(Any, arg_value)
449474

450475
if asyncio.iscoroutinefunction(function):
451-
result = await function(**parameters)
476+
result = await function(**prepared_arguments)
452477
else:
453-
result = function(**parameters)
478+
result = function(**prepared_arguments)
454479
return result
455480

456481
def __contains__(self, name) -> bool:

0 commit comments

Comments
 (0)