Skip to content

Commit

Permalink
adapt to newer OpenAI FunctionCall types
Browse files Browse the repository at this point in the history
  • Loading branch information
rgbkrk committed Nov 10, 2023
1 parent 5866e9c commit 9051bb7
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 25 deletions.
19 changes: 10 additions & 9 deletions chatlab/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
import asyncio
import logging
import os
from typing import AsyncIterator, Callable, List, Optional, Tuple, Type, Union, overload
from typing import Callable, List, Optional, Tuple, Type, Union, overload

import openai
from deprecation import deprecated
from IPython.core.async_helpers import get_asyncio_loop
from openai import AsyncOpenAI
from openai import AsyncOpenAI, AsyncStream
from openai.types import FunctionDefinition
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam
from pydantic import BaseModel

Expand All @@ -29,7 +30,7 @@
from .display import ChatFunctionCall
from .errors import ChatLabError
from .messaging import human
from .registry import FunctionRegistry, FunctionSchema, PythonHallucinationFunction
from .registry import FunctionRegistry, PythonHallucinationFunction
from .views.assistant import AssistantMessageView

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -138,7 +139,7 @@ async def __call__(self, *messages: Union[ChatCompletionMessageParam, str], stre
return await self.submit(*messages, stream=stream, **kwargs)

async def __process_stream(
self, resp: AsyncIterator[ChatCompletionChunk]
self, resp: AsyncStream[ChatCompletionChunk]
) -> Tuple[str, Optional[AssistantFunctionCallView]]:
assistant_view: AssistantMessageView = AssistantMessageView()
function_view: Optional[AssistantFunctionCallView] = None
Expand Down Expand Up @@ -232,15 +233,15 @@ async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream
try:
client = AsyncOpenAI()

manifest = self.function_registry.api_manifest()
api_manifest = self.function_registry.api_manifest()

# Due to the strict response typing based on `Literal` typing on `stream`, we have to process these
# two cases separately
if stream:
streaming_response = await client.chat.completions.create(
model=self.model,
messages=full_messages,
**manifest,
**api_manifest,
stream=True,
temperature=kwargs.get("temperature", 0),
)
Expand All @@ -250,7 +251,7 @@ async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream
full_response = await client.chat.completions.create(
model=self.model,
messages=full_messages,
**manifest,
**api_manifest,
stream=False,
temperature=kwargs.get("temperature", 0),
)
Expand Down Expand Up @@ -333,14 +334,14 @@ def register(
self,
function: Callable,
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
) -> FunctionSchema:
) -> FunctionDefinition:
...

def register(
self,
function: Optional[Callable] = None,
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
) -> Union[Callable, FunctionSchema]:
) -> Union[Callable, FunctionDefinition]:
"""Register a function with the ChatLab instance.
This can be used as a decorator like so:
Expand Down
43 changes: 27 additions & 16 deletions chatlab/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class WhatTime(BaseModel):
import json
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypedDict, Union, get_args, get_origin, overload

from openai.types.chat.completion_create_params import Function as FunctionSchema
from openai.types.chat.completion_create_params import FunctionCall as FunctionCallOption
from openai.types import FunctionDefinition
from openai.types.chat.completion_create_params import Function, FunctionCall
from pydantic import BaseModel, create_model

from .decorators import ChatlabMetadata
Expand All @@ -54,10 +54,10 @@ class WhatTime(BaseModel):
class APIManifest(TypedDict, total=False):
"""The schema for the API."""

functions: List[FunctionSchema]
functions: List[Function]
"""A list of functions that the model can call during the conversation."""

function_call: FunctionCallOption
function_call: FunctionCall
"""The policy for when to call functions.
One of "auto", "none", or a dictionary with a "name" key.
Expand Down Expand Up @@ -110,7 +110,7 @@ class FunctionSchemaConfig:
def generate_function_schema(
function: Callable,
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
) -> FunctionSchema:
) -> FunctionDefinition:
"""Generate a function schema for sending to OpenAI."""
doc = function.__doc__
func_name = function.__name__
Expand All @@ -122,7 +122,7 @@ def generate_function_schema(
if not doc:
raise Exception("Only functions with docstrings can be registered")

schema = FunctionSchema(
schema = FunctionDefinition(
name=func_name,
description=doc,
parameters={},
Expand Down Expand Up @@ -184,14 +184,23 @@ def generate_function_schema(
if "required" not in parameters:
parameters["required"] = []

schema["parameters"] = parameters
schema.parameters = parameters
return schema


# Declare the type for the python hallucination
PythonHallucinationFunction = Callable[[str], Any]


def adapt_function_definition(fd: FunctionDefinition) -> Function:
"""Adapt a FunctionDefinition to a Function for working with the OpenAI API."""
return {
"name": fd.name,
"parameters": fd.parameters,
"description": fd.description if fd.description is not None else "",
}


class FunctionRegistry:
"""Registry of functions and their schemas for calling them.
Expand Down Expand Up @@ -229,7 +238,7 @@ class WhatTime(BaseModel):
"""

__functions: dict[str, Callable]
__schemas: dict[str, FunctionSchema]
__schemas: dict[str, FunctionDefinition]

# Allow passing in a callable that accepts a single string for the python
# hallucination function. This is useful for testing.
Expand Down Expand Up @@ -265,14 +274,14 @@ def register(
self,
function: Callable,
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
) -> FunctionSchema:
) -> FunctionDefinition:
...

def register(
self,
function: Optional[Callable] = None,
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
) -> Union[Callable, FunctionSchema]:
) -> Union[Callable, FunctionDefinition]:
"""Register a function for use in `Chat`s. Can be used as a decorator or directly to register a function.
>>> registry = FunctionRegistry()
Expand Down Expand Up @@ -303,7 +312,7 @@ def register_function(
self,
function: Callable,
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
) -> FunctionSchema:
) -> FunctionDefinition:
"""Register a single function."""
final_schema = generate_function_schema(function, parameter_schema)

Expand All @@ -327,7 +336,7 @@ def get(self, function_name) -> Optional[Callable]:

return self.__functions.get(function_name)

def get_schema(self, function_name) -> Optional[FunctionSchema]:
def get_schema(self, function_name) -> Optional[FunctionDefinition]:
"""Get a function schema by name."""
return self.__schemas.get(function_name)

Expand All @@ -341,7 +350,7 @@ def get_chatlab_metadata(self, function_name) -> ChatlabMetadata:
chatlab_metadata = getattr(function, "chatlab_metadata", ChatlabMetadata())
return chatlab_metadata

def api_manifest(self, function_call_option: FunctionCallOption = "auto") -> APIManifest:
def api_manifest(self, function_call_option: FunctionCall = "auto") -> APIManifest:
"""Get a dictionary containing function definitions and calling options.
This is designed to be used with OpenAI's Chat Completion API, where the
Expand Down Expand Up @@ -394,12 +403,14 @@ def api_manifest(self, function_call_option: FunctionCallOption = "auto") -> API
stream=True,
)
"""
if len(self.function_definitions) == 0:
function_definitions = [adapt_function_definition(f) for f in self.__schemas.values()]

if len(function_definitions) == 0:
# When there are no functions, we can't send an empty functions array to OpenAI
return {}

return {
"functions": self.function_definitions,
"functions": function_definitions,
"function_call": function_call_option,
}

Expand Down Expand Up @@ -449,6 +460,6 @@ def __contains__(self, name) -> bool:
return name in self.__functions

@property
def function_definitions(self) -> list[FunctionSchema]:
def function_definitions(self) -> list[FunctionDefinition]:
"""Get a list of function definitions."""
return list(self.__schemas.values())

0 comments on commit 9051bb7

Please sign in to comment.