Support tool calling by default in Chat #133

Merged 3 commits on Feb 27, 2024
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [2.0.0]

- Support parallel tool calling by default in `Chat`.
- Legacy support for function calling is available by passing `legacy_function_calling=True` to the `Chat` constructor.

## [1.3.0]

- Support tool call format from `FunctionRegistry`. Enables parallel function calling (note: not in `Chat` yet). https://github.com/rgbkrk/chatlab/pull/122
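Below, a minimal usage sketch of the two modes described in the 2.0.0 changelog entry. It assumes the `Chat` constructor as changed in `chatlab/chat.py` later in this diff; `flip_a_coin` is an invented example function.

```python
import random

from chatlab import Chat

def flip_a_coin() -> str:
    """Flip a coin. Invented example function for this sketch."""
    return random.choice(["heads", "tails"])

# Default in 2.0.0: the tools API, which supports parallel tool calls.
chat = Chat(chat_functions=[flip_a_coin])

# Opt back into the older function-calling API when a model requires it.
legacy_chat = Chat(chat_functions=[flip_a_coin], legacy_function_calling=True)

# Either instance is driven the same way, e.g. `await chat("Flip a coin")`.
```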
118 changes: 93 additions & 25 deletions chatlab/chat.py
@@ -26,7 +26,7 @@

from ._version import __version__
from .errors import ChatLabError
from .messaging import human
from .messaging import assistant_tool_calls, human
from .registry import FunctionRegistry, PythonHallucinationFunction
from .views import ToolArguments, AssistantMessageView

@@ -73,6 +73,7 @@ def __init__(
chat_functions: Optional[List[Callable]] = None,
allow_hallucinated_python: bool = False,
python_hallucination_function: Optional[PythonHallucinationFunction] = None,
legacy_function_calling: bool = False,
):
"""Initialize a Chat with an optional initial context of messages.

@@ -99,6 +100,8 @@ def __init__(
self.api_key = openai_api_key
self.base_url = base_url

self.legacy_function_calling = legacy_function_calling

if initial_context is None:
initial_context = [] # type: ignore

@@ -142,11 +145,13 @@ async def __call__(self, *messages: Union[ChatCompletionMessageParam, str], stre

async def __process_stream(
self, resp: AsyncStream[ChatCompletionChunk]
) -> Tuple[str, Optional[ToolArguments]]:
) -> Tuple[str, Optional[ToolArguments], List[ToolArguments]]:
assistant_view: AssistantMessageView = AssistantMessageView()
function_view: Optional[ToolArguments] = None
finish_reason = None

tool_calls: list[ToolArguments] = []

async for result in resp: # Go through the results of the stream
choices = result.choices

@@ -158,18 +163,50 @@

# Is stream choice?
if choice.delta is not None:
if choice.delta.content is not None:
if choice.delta.content is not None and choice.delta.content != "":
assistant_view.display_once()
assistant_view.append(choice.delta.content)
elif choice.delta.tool_calls is not None:
if not assistant_view.finished:
assistant_view.finished = True

if assistant_view.content != "":
# Flush out the finished assistant message
message = assistant_view.get_message()
self.append(message)
for tool_call in choice.delta.tool_calls:
if tool_call.function is None:
# This should not occur; we could `continue` here instead of raising.
raise ValueError("Tool call without function")
# If this is a continuation of a tool call, then we have to change the tool argument
if tool_call.index < len(tool_calls):
tool_argument = tool_calls[tool_call.index]
if tool_call.function.arguments is not None:
tool_argument.append_arguments(tool_call.function.arguments)
elif (
tool_call.function.name is not None
and tool_call.function.arguments is not None
and tool_call.id is not None
):
# Must build up
tool_argument = ToolArguments(
id=tool_call.id, name=tool_call.function.name, arguments=tool_call.function.arguments
)
tool_argument.display()
tool_calls.append(tool_argument)

elif choice.delta.function_call is not None:
function_call = choice.delta.function_call
if function_call.name is not None:
if not assistant_view.finished:
# Flush out the finished assistant message
message = assistant_view.get_message()
self.append(message)
assistant_view.finished = True
if assistant_view.content != "":
# Flush out the finished assistant message
message = assistant_view.get_message()
self.append(message)

# IDs are for the tool calling apparatus from newer versions of the API
# We will make use of it later.
# Function call just uses the name. It's 1:1, whereas tools allow for multiple calls.
function_view = ToolArguments(id="TBD", name=function_call.name)
function_view.display()
if function_call.arguments is not None:
@@ -189,15 +226,19 @@
if finish_reason is None:
raise ValueError("No finish reason provided by OpenAI")

return (finish_reason, function_view)
return (finish_reason, function_view, tool_calls)

async def __process_full_completion(self, resp: ChatCompletion) -> Tuple[str, Optional[ToolArguments]]:
async def __process_full_completion(
self, resp: ChatCompletion
) -> Tuple[str, Optional[ToolArguments], List[ToolArguments]]:
assistant_view: AssistantMessageView = AssistantMessageView()
function_view: Optional[ToolArguments] = None

tool_calls: list[ToolArguments] = []

if len(resp.choices) == 0:
logger.warning(f"Result has no choices: {resp}")
return ("stop", None) # TODO
return ("stop", None, tool_calls) # TODO

choice = resp.choices[0]

@@ -211,8 +252,18 @@ async def __process_full_completion(self, resp: ChatCompletion) -> Tuple[str, Op
function_call = message.function_call
function_view = ToolArguments(id="TBD", name=function_call.name, arguments=function_call.arguments)
function_view.display()
if message.tool_calls is not None:
for tool_call in message.tool_calls:
tool_argument = ToolArguments(
id=tool_call.id, name=tool_call.function.name, arguments=tool_call.function.arguments
)
tool_argument.display()
tool_calls.append(tool_argument)

# TODO: verify that appending the full tool-calls payload here is correct
self.append(message.model_dump()) # type: ignore

return choice.finish_reason, function_view
return choice.finish_reason, function_view, tool_calls

async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream=True, **kwargs):
"""Send messages to the chat model and display the response.
@@ -231,6 +282,10 @@ async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream
"""
full_messages: List[ChatCompletionMessageParam] = []
full_messages.extend(self.messages)

# TODO: Just keeping this aside while working on both stream and non-stream
tool_arguments: List[ToolArguments] = []

for message in messages:
if isinstance(message, str):
full_messages.append(human(message))
@@ -243,37 +298,39 @@
base_url=self.base_url,
)

api_manifest = self.function_registry.api_manifest()
chat_create_kwargs = {
"model": self.model,
"messages": full_messages,
"temperature": kwargs.get("temperature", 0),
}

# Due to the strict response typing based on `Literal` typing on `stream`, we have to process these
# two cases separately
if stream:
if self.legacy_function_calling:
chat_create_kwargs.update(self.function_registry.api_manifest())
else:
chat_create_kwargs["tools"] = self.function_registry.tools

streaming_response = await client.chat.completions.create(
model=self.model,
messages=full_messages,
**api_manifest,
**chat_create_kwargs,
stream=True,
temperature=kwargs.get("temperature", 0),
)

self.append(*messages)

finish_reason, function_call_request = await self.__process_stream(streaming_response)
finish_reason, function_call_request, tool_arguments = await self.__process_stream(streaming_response)
else:
full_response = await client.chat.completions.create(
model=self.model,
messages=full_messages,
**api_manifest,
**chat_create_kwargs,
stream=False,
temperature=kwargs.get("temperature", 0),
)

self.append(*messages)

(
finish_reason,
function_call_request,
) = await self.__process_full_completion(full_response)
(finish_reason, function_call_request, tool_arguments) = await self.__process_full_completion(
full_response
)

except openai.RateLimitError as e:
logger.error(f"Rate limited: {e}. Waiting 5 seconds and trying again.")
@@ -299,6 +356,17 @@ async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream
await self.submit(stream=stream, **kwargs)
return

if finish_reason == "tool_calls":
self.append(assistant_tool_calls(tool_arguments))
for tool_argument in tool_arguments:
# The assistant message carrying these tool calls was appended just above.
function_called = await tool_argument.call(self.function_registry)
# TODO: Format the tool message
self.append(function_called.get_tool_called_message())

await self.submit(stream=stream, **kwargs)
return

# All other finish reasons are valid for regular assistant messages
if finish_reason == "stop":
return
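For readers following the streaming path: `__process_stream` accumulates tool calls from deltas, where the first chunk of each call carries an `id` and `function.name` and later chunks carry only an `index` plus the next fragment of the arguments JSON. A standalone sketch of that pattern, using hypothetical stand-ins for the OpenAI delta types and for chatlab's `ToolArguments`:

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class FunctionDelta:  # hypothetical stand-in for the OpenAI function delta
    name: Optional[str] = None
    arguments: Optional[str] = None

@dataclass
class ToolCallDelta:  # hypothetical stand-in for the OpenAI tool-call delta
    index: int
    id: Optional[str] = None
    function: FunctionDelta = field(default_factory=FunctionDelta)

@dataclass
class PartialToolCall:  # hypothetical stand-in for chatlab's ToolArguments
    id: str
    name: str
    arguments: str = ""

def accumulate(tool_calls: List[PartialToolCall], delta: ToolCallDelta) -> None:
    """Fold one streamed delta into the running list, keyed by delta.index."""
    if delta.index < len(tool_calls):
        # Continuation chunk: append the next fragment of the arguments JSON.
        if delta.function.arguments is not None:
            tool_calls[delta.index].arguments += delta.function.arguments
    elif delta.id is not None and delta.function.name is not None:
        # First chunk of a new call: record its id, name, and initial fragment.
        tool_calls.append(
            PartialToolCall(delta.id, delta.function.name, delta.function.arguments or "")
        )

# Two parallel calls interleaved across three chunks:
calls: List[PartialToolCall] = []
accumulate(calls, ToolCallDelta(0, "call_a", FunctionDelta("flip_a_coin", "{")))
accumulate(calls, ToolCallDelta(1, "call_b", FunctionDelta("roll_die", "{}")))
accumulate(calls, ToolCallDelta(0, function=FunctionDelta(arguments="}")))
assert len(calls) == 2 and calls[0].arguments == "{}"
```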
34 changes: 31 additions & 3 deletions chatlab/messaging.py
@@ -10,9 +10,13 @@

"""

from typing import Optional
from typing import Optional, Iterable, Protocol, List

from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionMessageToolCallParam,
)


def assistant(content: str) -> ChatCompletionMessageParam:
@@ -100,7 +104,29 @@ def function_result(name: str, content: str) -> ChatCompletionMessageParam:
}


def tool_result(tool_call_id: str, content: str) -> ChatCompletionToolMessageParam:
class HasGetToolArgumentsParameter(Protocol):
def get_tool_arguments_parameter(self) -> ChatCompletionMessageToolCallParam:
...


def assistant_tool_calls(tool_calls: Iterable[HasGetToolArgumentsParameter]) -> ChatCompletionMessageParam:
converted_tool_calls: List[ChatCompletionMessageToolCallParam] = []

for tool_call in tool_calls:
converted_tool_calls.append(tool_call.get_tool_arguments_parameter())

return {
"role": "assistant",
"tool_calls": converted_tool_calls,
}


class ChatCompletionToolMessageParamWithName(ChatCompletionToolMessageParam):
name: Optional[str]
"""The name of the tool."""


def tool_result(tool_call_id: str, content: str, name: str) -> ChatCompletionToolMessageParamWithName:
"""Create a tool result message.

Args:
@@ -112,10 +138,12 @@ def tool_result(tool_call_id: str, content: str) -> ChatCompletionToolMessagePar
"""
return {
"role": "tool",
"name": name,
"content": content,
"tool_call_id": tool_call_id,
}


# Aliases
narrate = system
human = user
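For context on why both helpers exist: after a `tool_calls` finish, the transcript needs one assistant message listing every call, followed by one `role="tool"` message per call echoing its `tool_call_id`. A sketch of the shapes these helpers produce; the ids and values below are invented:

```python
# Roughly what assistant_tool_calls(...) produces (invented id and values):
assistant_message = {
    "role": "assistant",
    "tool_calls": [
        {
            "id": "call_abc123",
            "type": "function",
            "function": {"name": "flip_a_coin", "arguments": "{}"},
        },
    ],
}

# Roughly what tool_result(...) produces: one per call, echoing its id.
tool_message = {
    "role": "tool",
    "name": "flip_a_coin",
    "content": "heads",
    "tool_call_id": "call_abc123",
}
```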
2 changes: 1 addition & 1 deletion chatlab/registry.py
@@ -460,7 +460,7 @@ async def call(self, name: str, arguments: Optional[str] = None) -> Any:

parameters: dict = {}

if arguments is not None:
if arguments is not None and arguments != "":
try:
parameters = json.loads(arguments)
except json.JSONDecodeError:
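The added empty-string guard matters because a zero-argument tool call can arrive with `arguments` set to `""`, and `json.loads` raises on an empty string rather than returning `{}`. A quick illustration:

```python
import json

try:
    json.loads("")  # this is what the arguments != "" guard avoids
except json.JSONDecodeError as e:
    print(f"empty arguments string fails to parse: {e}")

print(json.loads("{}"))  # a genuine empty-object payload parses fine: {}
```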
5 changes: 2 additions & 3 deletions chatlab/views/assistant.py
@@ -2,8 +2,9 @@

from ..messaging import assistant


class AssistantMessageView(Markdown):
content: str= ""
content: str = ""
finished: bool = False
has_displayed: bool = False

@@ -17,5 +18,3 @@ def display(self):
def display_once(self):
if not self.has_displayed:
self.display()

