diff --git a/CHANGELOG.md b/CHANGELOG.md
index bb9a52e..30b9692 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [2.0.0]
+
+- Support parallel tool calling by default in `Chat`.
+- Legacy support for function calling is available by passing `legacy_function_calling=True` to the `Chat` constructor.
+
 ## [1.3.0]
 
 - Support tool call format from `FunctionRegistry`. Enables parallel function calling (note: not in `Chat` yet). https://github.com/rgbkrk/chatlab/pull/122
 
diff --git a/chatlab/chat.py b/chatlab/chat.py
index 790d552..0d908b8 100644
--- a/chatlab/chat.py
+++ b/chatlab/chat.py
@@ -26,7 +26,7 @@
 
 from ._version import __version__
 from .errors import ChatLabError
-from .messaging import human
+from .messaging import assistant_tool_calls, human
 from .registry import FunctionRegistry, PythonHallucinationFunction
 from .views import ToolArguments, AssistantMessageView
@@ -73,6 +73,7 @@ def __init__(
         chat_functions: Optional[List[Callable]] = None,
         allow_hallucinated_python: bool = False,
         python_hallucination_function: Optional[PythonHallucinationFunction] = None,
+        legacy_function_calling: bool = False,
     ):
         """Initialize a Chat with an optional initial context of messages.
@@ -99,6 +100,8 @@ def __init__(
         self.api_key = openai_api_key
         self.base_url = base_url
 
+        self.legacy_function_calling = legacy_function_calling
+
         if initial_context is None:
             initial_context = []  # type: ignore
@@ -142,11 +145,13 @@ async def __call__(self, *messages: Union[ChatCompletionMessageParam, str], stre
 
     async def __process_stream(
         self, resp: AsyncStream[ChatCompletionChunk]
-    ) -> Tuple[str, Optional[ToolArguments]]:
+    ) -> Tuple[str, Optional[ToolArguments], List[ToolArguments]]:
         assistant_view: AssistantMessageView = AssistantMessageView()
         function_view: Optional[ToolArguments] = None
         finish_reason = None
 
+        tool_calls: list[ToolArguments] = []
+
         async for result in resp:
             # Go through the results of the stream
             choices = result.choices
@@ -158,18 +163,50 @@ async def __process_stream(
 
             # Is stream choice?
             if choice.delta is not None:
-                if choice.delta.content is not None:
+                if choice.delta.content is not None and choice.delta.content != "":
                     assistant_view.display_once()
                     assistant_view.append(choice.delta.content)
+                elif choice.delta.tool_calls is not None:
+                    if not assistant_view.finished:
+                        assistant_view.finished = True
+
+                        if assistant_view.content != "":
+                            # Flush out the finished assistant message
+                            message = assistant_view.get_message()
+                            self.append(message)
+                    for tool_call in choice.delta.tool_calls:
+                        if tool_call.function is None:
+                            # This should not occur; we could `continue` here instead of raising.
+                            raise ValueError("Tool call without function")
+                        # If this is a continuation of a tool call, then we have to change the tool argument
+                        if tool_call.index < len(tool_calls):
+                            tool_argument = tool_calls[tool_call.index]
+                            if tool_call.function.arguments is not None:
+                                tool_argument.append_arguments(tool_call.function.arguments)
+                        elif (
+                            tool_call.function.name is not None
+                            and tool_call.function.arguments is not None
+                            and tool_call.id is not None
+                        ):
+                            # First chunk of a new tool call; start building it up
+                            tool_argument = ToolArguments(
+                                id=tool_call.id, name=tool_call.function.name, arguments=tool_call.function.arguments
+                            )
+                            tool_argument.display()
+                            tool_calls.append(tool_argument)
+
                 elif choice.delta.function_call is not None:
                     function_call = choice.delta.function_call
                     if function_call.name is not None:
                         if not assistant_view.finished:
-                            # Flush out the finished assistant message
-                            message = assistant_view.get_message()
-                            self.append(message)
+                            assistant_view.finished = True
+                            if assistant_view.content != "":
+                                # Flush out the finished assistant message
+                                message = assistant_view.get_message()
+                                self.append(message)
+
                         # IDs are for the tool calling apparatus from newer versions of the API
-                        # We will make use of it later.
+                        # Function call just uses the name. It's 1:1, whereas tools allow for multiple calls.
                         function_view = ToolArguments(id="TBD", name=function_call.name)
                         function_view.display()
                         if function_call.arguments is not None:
@@ -189,15 +226,19 @@
 
         if finish_reason is None:
             raise ValueError("No finish reason provided by OpenAI")
 
-        return (finish_reason, function_view)
+        return (finish_reason, function_view, tool_calls)
 
-    async def __process_full_completion(self, resp: ChatCompletion) -> Tuple[str, Optional[ToolArguments]]:
+    async def __process_full_completion(
+        self, resp: ChatCompletion
+    ) -> Tuple[str, Optional[ToolArguments], List[ToolArguments]]:
         assistant_view: AssistantMessageView = AssistantMessageView()
         function_view: Optional[ToolArguments] = None
 
+        tool_calls: list[ToolArguments] = []
+
         if len(resp.choices) == 0:
             logger.warning(f"Result has no choices: {resp}")
-            return ("stop", None, tool_calls)  # TODO
+            return ("stop", None, tool_calls)  # TODO
 
         choice = resp.choices[0]
 
@@ -211,8 +252,18 @@ async def __process_full_completion(self, resp: ChatCompletion) -> Tuple[str, Op
             function_call = message.function_call
             function_view = ToolArguments(id="TBD", name=function_call.name, arguments=function_call.arguments)
             function_view.display()
+        if message.tool_calls is not None:
+            for tool_call in message.tool_calls:
+                tool_argument = ToolArguments(
+                    id=tool_call.id, name=tool_call.function.name, arguments=tool_call.function.arguments
+                )
+                tool_argument.display()
+                tool_calls.append(tool_argument)
+
+            # TODO: verify that appending the full assistant message (tool calls included) belongs here
+            self.append(message.model_dump())  # type: ignore
 
-        return choice.finish_reason, function_view
+        return choice.finish_reason, function_view, tool_calls
 
     async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream=True, **kwargs):
         """Send messages to the chat model and display the response.
@@ -231,6 +282,10 @@ async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream
         """
         full_messages: List[ChatCompletionMessageParam] = []
         full_messages.extend(self.messages)
+
+        # TODO: holds tool calls from both the stream and non-stream paths; consolidate later
+        tool_arguments: List[ToolArguments] = []
+
         for message in messages:
             if isinstance(message, str):
                 full_messages.append(human(message))
@@ -243,37 +298,39 @@
                 base_url=self.base_url,
             )
 
-            api_manifest = self.function_registry.api_manifest()
+            chat_create_kwargs = {
+                "model": self.model,
+                "messages": full_messages,
+                "temperature": kwargs.get("temperature", 0),
+            }
 
             # Due to the strict response typing based on `Literal` typing on `stream`, we have to process these
             # two cases separately
             if stream:
+                if self.legacy_function_calling:
+                    chat_create_kwargs.update(self.function_registry.api_manifest())
+                else:
+                    chat_create_kwargs["tools"] = self.function_registry.tools
+
                 streaming_response = await client.chat.completions.create(
-                    model=self.model,
-                    messages=full_messages,
-                    **api_manifest,
+                    **chat_create_kwargs,
                     stream=True,
-                    temperature=kwargs.get("temperature", 0),
                 )
 
                 self.append(*messages)
 
-                finish_reason, function_call_request = await self.__process_stream(streaming_response)
+                finish_reason, function_call_request, tool_arguments = await self.__process_stream(streaming_response)
             else:
                 full_response = await client.chat.completions.create(
-                    model=self.model,
-                    messages=full_messages,
-                    **api_manifest,
+                    **chat_create_kwargs,
                     stream=False,
-                    temperature=kwargs.get("temperature", 0),
                 )
 
                 self.append(*messages)
 
-                (
-                    finish_reason,
-                    function_call_request,
-                ) = await self.__process_full_completion(full_response)
+                (finish_reason, function_call_request, tool_arguments) = await self.__process_full_completion(
+                    full_response
+                )
 
         except openai.RateLimitError as e:
             logger.error(f"Rate limited: {e}. Waiting 5 seconds and trying again.")
@@ -299,6 +356,17 @@ async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream
             await self.submit(stream=stream, **kwargs)
             return
 
+        if finish_reason == "tool_calls":
+            self.append(assistant_tool_calls(tool_arguments))
+            for tool_argument in tool_arguments:
+                # The assistant message carrying these tool calls must come first; it was appended just above.
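+                # Each result is sent back as a `tool` role message, matched to its
+                # originating call by tool_call_id (see ToolCalled.get_tool_called_message).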
+                function_called = await tool_argument.call(self.function_registry)
+                # TODO: Format the tool message
+                self.append(function_called.get_tool_called_message())
+
+            await self.submit(stream=stream, **kwargs)
+            return
+
         # All other finish reasons are valid for regular assistant messages
         if finish_reason == "stop":
             return
 
diff --git a/chatlab/messaging.py b/chatlab/messaging.py
index 547e658..1c17dcb 100644
--- a/chatlab/messaging.py
+++ b/chatlab/messaging.py
@@ -10,9 +10,13 @@
 """
 
-from typing import Optional
+from typing import Optional, Iterable, Protocol, List
 
-from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam
+from openai.types.chat import (
+    ChatCompletionMessageParam,
+    ChatCompletionToolMessageParam,
+    ChatCompletionMessageToolCallParam,
+)
 
 
 def assistant(content: str) -> ChatCompletionMessageParam:
@@ -100,7 +104,29 @@ def function_result(name: str, content: str) -> ChatCompletionMessageParam:
     }
 
 
-def tool_result(tool_call_id: str, content: str) -> ChatCompletionToolMessageParam:
+class HasGetToolArgumentsParameter(Protocol):
+    def get_tool_arguments_parameter(self) -> ChatCompletionMessageToolCallParam:
+        ...
+
+
+def assistant_tool_calls(tool_calls: Iterable[HasGetToolArgumentsParameter]) -> ChatCompletionMessageParam:
+    converted_tool_calls: List[ChatCompletionMessageToolCallParam] = []
+
+    for tool_call in tool_calls:
+        converted_tool_calls.append(tool_call.get_tool_arguments_parameter())
+
+    return {
+        "role": "assistant",
+        "tool_calls": converted_tool_calls,
+    }
+
+
+class ChatCompletionToolMessageParamWithName(ChatCompletionToolMessageParam):
+    name: Optional[str]
+    """The name of the tool."""
+
+
+def tool_result(tool_call_id: str, content: str, name: str) -> ChatCompletionToolMessageParamWithName:
     """Create a tool result message.
     Args:
@@ -112,10 +138,12 @@ def tool_result(tool_call_id: str, content: str) -> ChatCompletionToolMessagePar
     """
     return {
         "role": "tool",
+        "name": name,
         "content": content,
         "tool_call_id": tool_call_id,
     }
 
+
 # Aliases
 narrate = system
 human = user
 
diff --git a/chatlab/registry.py b/chatlab/registry.py
index 6331460..3631160 100644
--- a/chatlab/registry.py
+++ b/chatlab/registry.py
@@ -460,7 +460,7 @@ async def call(self, name: str, arguments: Optional[str] = None) -> Any:
 
         parameters: dict = {}
 
-        if arguments is not None:
+        if arguments is not None and arguments != "":
             try:
                 parameters = json.loads(arguments)
             except json.JSONDecodeError:
 
diff --git a/chatlab/views/assistant.py b/chatlab/views/assistant.py
index 436fad4..bb3cd76 100644
--- a/chatlab/views/assistant.py
+++ b/chatlab/views/assistant.py
@@ -2,8 +2,9 @@
 
 from ..messaging import assistant
 
+
 class AssistantMessageView(Markdown):
-    content: str= ""
+    content: str = ""
     finished: bool = False
     has_displayed: bool = False
 
@@ -17,5 +18,3 @@ def display(self):
     def display_once(self):
         if not self.has_displayed:
             self.display()
-
-
diff --git a/chatlab/views/tools.py b/chatlab/views/tools.py
index 2be94d2..02cfb0b 100644
--- a/chatlab/views/tools.py
+++ b/chatlab/views/tools.py
@@ -4,10 +4,14 @@
 
 from ..registry import FunctionRegistry, FunctionArgumentError, UnknownFunctionError
 
-from ..messaging import assistant_function_call, function_result
-
+from ..messaging import assistant_function_call, function_result, tool_result
+
+from openai.types.chat import ChatCompletionMessageToolCallParam
+
+
 class ToolCalled(AutoUpdate):
     """Once a tool has finished up, this is the view."""
+
     id: str
     name: str
     arguments: str = ""
@@ -15,17 +19,16 @@ class ToolCalled(AutoUpdate):
     result: str = ""
 
     def render(self):
-        return ChatFunctionComponent(
-            name=self.name,
-            verbage="ok",
-            input=self.arguments,
-            output=self.result
-        )
-
+        return ChatFunctionComponent(name=self.name, verbage="ok", input=self.arguments, output=self.result)
+
     # TODO: This is only here for legacy function calling
     def get_function_called_message(self):
         return function_result(self.name, self.result)
 
+    def get_tool_called_message(self):
+        # NOTE: OpenAI has mismatched types where it doesn't include the `name`
+        # xref: https://github.com/openai/openai-python/issues/1078
+        return tool_result(tool_call_id=self.id, content=self.result, name=self.name)
 
 class ToolArguments(AutoUpdate):
@@ -39,26 +42,21 @@ class ToolArguments(AutoUpdate):
 
     def get_function_message(self):
         return assistant_function_call(self.name, self.arguments)
 
+    def get_tool_arguments_parameter(self) -> ChatCompletionMessageToolCallParam:
+        return {"id": self.id, "function": {"name": self.name, "arguments": self.arguments}, "type": "function"}
+
     def render(self):
-        return ChatFunctionComponent(
-            name=self.name,
-            verbage=self.verbage,
-            input=self.arguments
-        )
-
+        return ChatFunctionComponent(name=self.name, verbage=self.verbage, input=self.arguments)
+
     def append_arguments(self, arguments: str):
         self.arguments += arguments
-
+
     def apply_result(self, result: str):
         """Replaces the existing display with a new one that shows the result of the tool being called."""
         return ToolCalled(
-            id=self.id,
-            name=self.name,
-            arguments=self.arguments,
-            result=result,
-            display_id=self.display_id
+            id=self.id, name=self.name, arguments=self.arguments, result=result, display_id=self.display_id
         )
-
+
     async def call(self, function_registry: FunctionRegistry) -> ToolCalled:
         """Call the function and return a stack of messages for LLM and human
consumption.""" function_name = self.name @@ -113,4 +111,3 @@ async def call(self, function_registry: FunctionRegistry) -> ToolCalled: self.verbage = "Ran" return self.apply_result(repr_llm) - diff --git a/notebooks/knowledge-graph.ipynb b/notebooks/knowledge-graph.ipynb index 66d4298..02418c0 100644 --- a/notebooks/knowledge-graph.ipynb +++ b/notebooks/knowledge-graph.ipynb @@ -2,136 +2,141 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install graphviz -q" + ] + }, + { + "cell_type": "code", + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
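For review context, a minimal usage sketch of what this diff enables. It assumes `Chat` is exported from the package root and that callables passed as `chat_functions` are registered as tools; both are suggested by the constructor signature above but not shown in this diff.

    # Hypothetical walkthrough -- not part of the diff. Assumes `Chat` is exported
    # at the package root and that `chat_functions` registers plain callables.
    import asyncio

    from chatlab import Chat

    def current_time() -> str:
        """Report the current UTC time."""
        from datetime import datetime, timezone

        return datetime.now(timezone.utc).isoformat()

    async def main():
        # New default path: the model may request several tool calls in one turn.
        # Each runs via ToolArguments.call() and is appended back as a `tool` message.
        chat = Chat(chat_functions=[current_time])
        await chat("What time is it right now?")

        # Pre-2.0 behavior: a single `function_call` per turn, opted into explicitly.
        legacy = Chat(chat_functions=[current_time], legacy_function_calling=True)
        await legacy("What time is it right now?")

    asyncio.run(main())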