Commit
Merge pull request #133 from rgbkrk/tool-calling
Support tool calling by default in Chat
rgbkrk authored Feb 27, 2024
2 parents f700287 + c9ae2b9 commit c3e75e6
Showing 7 changed files with 485 additions and 406 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [2.0.0]

- Support parallel tool calling by default in `Chat`.
- Legacy support for function calling is available by passing `legacy_function_calling=True` to the `Chat` constructor.

## [1.3.0]

- Support tool call format from `FunctionRegistry`. Enables parallel function calling (note: not in `Chat` yet). https://github.com/rgbkrk/chatlab/pull/122
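A minimal usage sketch of the two modes described in the changelog entry above. The `Chat` constructor arguments match this PR, but the example tool and prompts are hypothetical, and an `OPENAI_API_KEY` in the environment (plus `Chat` being importable from the package root) is assumed:

```python
import asyncio
from datetime import datetime

from chatlab import Chat


def what_time(tz: str) -> str:
    """Hypothetical example tool: report the current time."""
    return datetime.now().isoformat()


async def main():
    # New default: the model may request several tool calls in parallel,
    # and Chat resolves each one before producing its final reply.
    chat = Chat(chat_functions=[what_time])
    await chat("What time is it?")

    # Opt back into the legacy, one-call-at-a-time function_call API.
    legacy = Chat(chat_functions=[what_time], legacy_function_calling=True)
    await legacy("What time is it?")


asyncio.run(main())
```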
118 changes: 93 additions & 25 deletions chatlab/chat.py
@@ -26,7 +26,7 @@

from ._version import __version__
from .errors import ChatLabError
from .messaging import human
from .messaging import assistant_tool_calls, human
from .registry import FunctionRegistry, PythonHallucinationFunction
from .views import ToolArguments, AssistantMessageView

@@ -73,6 +73,7 @@ def __init__(
chat_functions: Optional[List[Callable]] = None,
allow_hallucinated_python: bool = False,
python_hallucination_function: Optional[PythonHallucinationFunction] = None,
legacy_function_calling: bool = False,
):
"""Initialize a Chat with an optional initial context of messages.
@@ -99,6 +100,8 @@ def __init__(
self.api_key = openai_api_key
self.base_url = base_url

self.legacy_function_calling = legacy_function_calling

if initial_context is None:
initial_context = [] # type: ignore

@@ -142,11 +145,13 @@ async def __call__(self, *messages: Union[ChatCompletionMessageParam, str], stre

async def __process_stream(
self, resp: AsyncStream[ChatCompletionChunk]
) -> Tuple[str, Optional[ToolArguments]]:
) -> Tuple[str, Optional[ToolArguments], List[ToolArguments]]:
assistant_view: AssistantMessageView = AssistantMessageView()
function_view: Optional[ToolArguments] = None
finish_reason = None

tool_calls: list[ToolArguments] = []

async for result in resp: # Go through the results of the stream
choices = result.choices

@@ -158,18 +163,50 @@

# Is stream choice?
if choice.delta is not None:
if choice.delta.content is not None:
if choice.delta.content is not None and choice.delta.content != "":
assistant_view.display_once()
assistant_view.append(choice.delta.content)
elif choice.delta.tool_calls is not None:
if not assistant_view.finished:
assistant_view.finished = True

if assistant_view.content != "":
# Flush out the finished assistant message
message = assistant_view.get_message()
self.append(message)
for tool_call in choice.delta.tool_calls:
if tool_call.function is None:
# This should not occur; we raise here rather than silently continuing.
raise ValueError("Tool call without function")
# If this is a continuation of a tool call, then we have to change the tool argument
if tool_call.index < len(tool_calls):
tool_argument = tool_calls[tool_call.index]
if tool_call.function.arguments is not None:
tool_argument.append_arguments(tool_call.function.arguments)
elif (
tool_call.function.name is not None
and tool_call.function.arguments is not None
and tool_call.id is not None
):
# First chunk for this tool call; start accumulating a new entry
tool_argument = ToolArguments(
id=tool_call.id, name=tool_call.function.name, arguments=tool_call.function.arguments
)
tool_argument.display()
tool_calls.append(tool_argument)

elif choice.delta.function_call is not None:
function_call = choice.delta.function_call
if function_call.name is not None:
if not assistant_view.finished:
# Flush out the finished assistant message
message = assistant_view.get_message()
self.append(message)
assistant_view.finished = True
if assistant_view.content != "":
# Flush out the finished assistant message
message = assistant_view.get_message()
self.append(message)

# IDs belong to the tool-calling apparatus in newer versions of the API;
# we will make use of them later.
# Legacy function calling uses only the name and is 1:1, whereas tools allow multiple parallel calls.
function_view = ToolArguments(id="TBD", name=function_call.name)
function_view.display()
if function_call.arguments is not None:
@@ -189,15 +226,19 @@
if finish_reason is None:
raise ValueError("No finish reason provided by OpenAI")

return (finish_reason, function_view)
return (finish_reason, function_view, tool_calls)

async def __process_full_completion(self, resp: ChatCompletion) -> Tuple[str, Optional[ToolArguments]]:
async def __process_full_completion(
self, resp: ChatCompletion
) -> Tuple[str, Optional[ToolArguments], List[ToolArguments]]:
assistant_view: AssistantMessageView = AssistantMessageView()
function_view: Optional[ToolArguments] = None

tool_calls: list[ToolArguments] = []

if len(resp.choices) == 0:
logger.warning(f"Result has no choices: {resp}")
return ("stop", None) # TODO
return ("stop", None, tool_calls) # TODO

choice = resp.choices[0]

@@ -211,8 +252,18 @@ async def __process_full_completion(self, resp: ChatCompletion) -> Tuple[str, Op
function_call = message.function_call
function_view = ToolArguments(id="TBD", name=function_call.name, arguments=function_call.arguments)
function_view.display()
if message.tool_calls is not None:
for tool_call in message.tool_calls:
tool_argument = ToolArguments(
id=tool_call.id, name=tool_call.function.name, arguments=tool_call.function.arguments
)
tool_argument.display()
tool_calls.append(tool_argument)

# TODO: verify that appending the full tool-calls payload here is correct
self.append(message.model_dump()) # type: ignore

return choice.finish_reason, function_view
return choice.finish_reason, function_view, tool_calls

async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream=True, **kwargs):
"""Send messages to the chat model and display the response.
@@ -231,6 +282,10 @@ async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream
"""
full_messages: List[ChatCompletionMessageParam] = []
full_messages.extend(self.messages)

# TODO: declared up front so both the stream and non-stream paths can populate it
tool_arguments: List[ToolArguments] = []

for message in messages:
if isinstance(message, str):
full_messages.append(human(message))
@@ -243,37 +298,39 @@
base_url=self.base_url,
)

api_manifest = self.function_registry.api_manifest()
chat_create_kwargs = {
"model": self.model,
"messages": full_messages,
"temperature": kwargs.get("temperature", 0),
}

# Due to the strict response typing based on `Literal` typing on `stream`, we have to process these
# two cases separately
if stream:
if self.legacy_function_calling:
chat_create_kwargs.update(self.function_registry.api_manifest())
else:
chat_create_kwargs["tools"] = self.function_registry.tools

streaming_response = await client.chat.completions.create(
model=self.model,
messages=full_messages,
**api_manifest,
**chat_create_kwargs,
stream=True,
temperature=kwargs.get("temperature", 0),
)

self.append(*messages)

finish_reason, function_call_request = await self.__process_stream(streaming_response)
finish_reason, function_call_request, tool_arguments = await self.__process_stream(streaming_response)
else:
full_response = await client.chat.completions.create(
model=self.model,
messages=full_messages,
**api_manifest,
**chat_create_kwargs,
stream=False,
temperature=kwargs.get("temperature", 0),
)

self.append(*messages)

(
finish_reason,
function_call_request,
) = await self.__process_full_completion(full_response)
(finish_reason, function_call_request, tool_arguments) = await self.__process_full_completion(
full_response
)

except openai.RateLimitError as e:
logger.error(f"Rate limited: {e}. Waiting 5 seconds and trying again.")
@@ -299,6 +356,17 @@ async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream
await self.submit(stream=stream, **kwargs)
return

if finish_reason == "tool_calls":
self.append(assistant_tool_calls(tool_arguments))
for tool_argument in tool_arguments:
# The assistant message carrying the full tool_calls list was appended above.
function_called = await tool_argument.call(self.function_registry)
# TODO: Format the tool message
self.append(function_called.get_tool_called_message())

await self.submit(stream=stream, **kwargs)
return

# All other finish reasons are valid for regular assistant messages
if finish_reason == "stop":
return
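Context for the `tool_call.index` bookkeeping in `__process_stream` above: in the OpenAI streaming API, each parallel tool call arrives as a series of chunks that share an index. The first chunk carries the id and function name, and the JSON arguments trickle in as text fragments. A standalone sketch of that accumulation pattern, using plain dicts in place of the SDK's delta objects:

```python
from dataclasses import dataclass


@dataclass
class PendingToolCall:
    id: str
    name: str
    arguments: str = ""  # JSON text accumulates fragment by fragment


def accumulate(deltas) -> list[PendingToolCall]:
    """Merge streamed tool-call deltas by index, as __process_stream does."""
    calls: list[PendingToolCall] = []
    for tc in deltas:
        if tc["index"] < len(calls):
            # Continuation chunk: only an argument fragment is new.
            calls[tc["index"]].arguments += tc["function"].get("arguments", "")
        else:
            # The first chunk for a call carries its id and name.
            calls.append(
                PendingToolCall(
                    id=tc["id"],
                    name=tc["function"]["name"],
                    arguments=tc["function"].get("arguments", ""),
                )
            )
    return calls


# Two parallel calls, interleaved across chunks:
deltas = [
    {"index": 0, "id": "call_a", "function": {"name": "what_time", "arguments": ""}},
    {"index": 1, "id": "call_b", "function": {"name": "what_time", "arguments": ""}},
    {"index": 0, "function": {"arguments": '{"tz": "UTC"}'}},
    {"index": 1, "function": {"arguments": '{"tz": "America/New_York"}'}},
]
assert [c.arguments for c in accumulate(deltas)] == [
    '{"tz": "UTC"}',
    '{"tz": "America/New_York"}',
]
```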
34 changes: 31 additions & 3 deletions chatlab/messaging.py
@@ -10,9 +10,13 @@
"""

from typing import Optional
from typing import Optional, Iterable, Protocol, List

from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionMessageToolCallParam,
)


def assistant(content: str) -> ChatCompletionMessageParam:
@@ -100,7 +104,29 @@ def function_result(name: str, content: str) -> ChatCompletionMessageParam:
}


def tool_result(tool_call_id: str, content: str) -> ChatCompletionToolMessageParam:
class HasGetToolArgumentsParameter(Protocol):
def get_tool_arguments_parameter(self) -> ChatCompletionMessageToolCallParam:
...


def assistant_tool_calls(tool_calls: Iterable[HasGetToolArgumentsParameter]) -> ChatCompletionMessageParam:
converted_tool_calls: List[ChatCompletionMessageToolCallParam] = []

for tool_call in tool_calls:
converted_tool_calls.append(tool_call.get_tool_arguments_parameter())

return {
"role": "assistant",
"tool_calls": converted_tool_calls,
}


class ChatCompletionToolMessageParamWithName(ChatCompletionToolMessageParam):
name: Optional[str]
"""The name of the tool."""


def tool_result(tool_call_id: str, content: str, name: str) -> ChatCompletionToolMessageParamWithName:
"""Create a tool result message.
Args:
@@ -112,10 +138,12 @@ def tool_result(tool_call_id: str, content: str) -> ChatCompletionToolMessagePar
"""
return {
"role": "tool",
"name": name,
"content": content,
"tool_call_id": tool_call_id,
}


# Aliases
narrate = system
human = user
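For reference, what the two new helpers emit, with illustrative values. Any object with a `get_tool_arguments_parameter()` method satisfies the protocol, so a minimal stand-in works:

```python
from chatlab.messaging import assistant_tool_calls, tool_result


class FakeCall:
    """Minimal stand-in satisfying HasGetToolArgumentsParameter."""

    def get_tool_arguments_parameter(self):
        return {
            "id": "call_a",
            "type": "function",
            "function": {"name": "what_time", "arguments": '{"tz": "UTC"}'},
        }


print(assistant_tool_calls([FakeCall()]))
# {'role': 'assistant', 'tool_calls': [{'id': 'call_a', 'type': 'function',
#   'function': {'name': 'what_time', 'arguments': '{"tz": "UTC"}'}}]}

print(tool_result(tool_call_id="call_a", content="12:00 UTC", name="what_time"))
# {'role': 'tool', 'name': 'what_time', 'content': '12:00 UTC',
#  'tool_call_id': 'call_a'}
```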
2 changes: 1 addition & 1 deletion chatlab/registry.py
@@ -460,7 +460,7 @@ async def call(self, name: str, arguments: Optional[str] = None) -> Any:

parameters: dict = {}

if arguments is not None:
if arguments is not None and arguments != "":
try:
parameters = json.loads(arguments)
except json.JSONDecodeError:
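Why the added empty-string guard in `FunctionRegistry.call` matters: a streamed tool call can legitimately finish with no argument text, and `json.loads("")` raises rather than returning an empty dict:

```python
import json

try:
    json.loads("")
except json.JSONDecodeError as e:
    print(e)  # Expecting value: line 1 column 1 (char 0)
```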
5 changes: 2 additions & 3 deletions chatlab/views/assistant.py
@@ -2,8 +2,9 @@

from ..messaging import assistant


class AssistantMessageView(Markdown):
content: str= ""
content: str = ""
finished: bool = False
has_displayed: bool = False

@@ -17,5 +18,3 @@ def display(self):
def display_once(self):
if not self.has_displayed:
self.display()

