[WIP] Tool calling #125

Closed · wants to merge 1 commit
chatlab/chat.py: 50 additions & 14 deletions
@@ -21,7 +21,12 @@
 from IPython.core.async_helpers import get_asyncio_loop
 from openai import AsyncOpenAI, AsyncStream
 from openai.types import FunctionDefinition
-from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessageParam,
+    ChatCompletionAssistantMessageParam,
+)
 from pydantic import BaseModel

 from chatlab.views.assistant_function_call import AssistantFunctionCallView
@@ -33,6 +38,9 @@
 from .registry import FunctionRegistry, PythonHallucinationFunction
 from .views.assistant import AssistantMessageView

+from .tool_call import ToolCallBuilder
+
+
 logger = logging.getLogger(__name__)


@@ -145,11 +153,13 @@ async def __call__(self, *messages: Union[ChatCompletionMessageParam, str], stre

     async def __process_stream(
         self, resp: AsyncStream[ChatCompletionChunk]
-    ) -> Tuple[str, Optional[AssistantFunctionCallView]]:
+    ) -> Tuple[str, Optional[AssistantFunctionCallView], ToolCallBuilder]:
         assistant_view: AssistantMessageView = AssistantMessageView()
         function_view: Optional[AssistantFunctionCallView] = None
         finish_reason = None

+        tool_call_builder = ToolCallBuilder()
+
         async for result in resp:  # Go through the results of the stream
             choices = result.choices

@@ -175,6 +185,10 @@
                     if function_view is None:
                         raise ValueError("Function arguments provided without function name")
                     function_view.append(function_call.arguments)
+
+            elif choice.delta.tool_calls is not None:
+                tool_call_builder.update(*choice.delta.tool_calls)
+
             if choice.finish_reason is not None:
                 finish_reason = choice.finish_reason
                 break
@@ -188,9 +202,11 @@
         if finish_reason is None:
             raise ValueError("No finish reason provided by OpenAI")

-        return (finish_reason, function_view)
+        return (finish_reason, function_view, tool_call_builder)

-    async def __process_full_completion(self, resp: ChatCompletion) -> Tuple[str, Optional[AssistantFunctionCallView]]:
+    async def __process_full_completion(
+        self, resp: ChatCompletion
+    ) -> Tuple[str, Optional[AssistantFunctionCallView], ToolCallBuilder]:
         assistant_view: AssistantMessageView = AssistantMessageView()
         function_view: Optional[AssistantFunctionCallView] = None

@@ -202,15 +218,19 @@ async def __process_full_completion(self, resp: ChatCompletion) -> Tuple[str, Op

         message = choice.message

+        tool_call_builder = ToolCallBuilder()
+
         if message.content is not None:
             assistant_view.append(message.content)
             self.append(assistant_view.flush())
         if message.function_call is not None:
             function_call = message.function_call
             function_view = AssistantFunctionCallView(function_name=function_call.name)
             function_view.append(function_call.arguments)
+        if message.tool_calls is not None:
+            raise NotImplementedError("Tool calls are not yet supported")

-        return choice.finish_reason, function_view
+        return choice.finish_reason, function_view, tool_call_builder

     async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream=True, **kwargs):
         """Send messages to the chat model and display the response.
@@ -241,37 +261,36 @@ async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream
                 base_url=self.base_url,
             )

-            api_manifest = self.function_registry.api_manifest()
-
             # Due to the strict response typing based on `Literal` typing on `stream`, we have to process these
             # two cases separately
             if stream:
                 streaming_response = await client.chat.completions.create(
                     model=self.model,
                     messages=full_messages,
-                    **api_manifest,
+                    tools=self.function_registry.tools,
                     stream=True,
                     temperature=kwargs.get("temperature", 0),
                 )

                 self.append(*messages)

-                finish_reason, function_call_request = await self.__process_stream(streaming_response)
+                finish_reason, function_call_request, tool_call_builder = await self.__process_stream(
+                    streaming_response
+                )
             else:
                 full_response = await client.chat.completions.create(
                     model=self.model,
                     messages=full_messages,
-                    **api_manifest,
+                    tools=self.function_registry.tools,
                     stream=False,
                     temperature=kwargs.get("temperature", 0),
                 )

                 self.append(*messages)

-                (
-                    finish_reason,
-                    function_call_request,
-                ) = await self.__process_full_completion(full_response)
+                (finish_reason, function_call_request, tool_call_builder) = await self.__process_full_completion(
+                    full_response
+                )

         except openai.RateLimitError as e:
             logger.error(f"Rate limited: {e}. Waiting 5 seconds and trying again.")
@@ -280,6 +299,23 @@

             return

+        if finish_reason == "tool_calls":
+            tool_calls = tool_call_builder.finalize()
+            self.append(
+                ChatCompletionAssistantMessageParam(
+                    role="assistant",
+                    tool_calls=tool_calls,
+                    # TODO: content and name may have been specified. I did not collect it above during streaming.
+                )
+            )
+
+            async for message in tool_call_builder.run(self.function_registry):
+                # TODO: Reintroduce the visual display
+                self.append(message)
+
+            await self.submit(stream=stream, **kwargs)
+            return
+
         if finish_reason == "function_call":
             if function_call_request is None:
                 raise ValueError(
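
Reviewer note: the new `finish_reason == "tool_calls"` branch makes `submit` re-enter itself after every round of tool execution, so a conversation only settles once the model stops requesting tools. Here is a minimal standalone sketch of that loop, assuming a plain `AsyncOpenAI` client, a `registry` exposing chatlab's `tools`/`call` interface, and a `messages` list; none of these names (nor the non-streaming simplification) are part of this diff:

```python
# Minimal sketch of the tool-call loop the branch above implements
# (illustrative only; streaming and display handling are omitted).
async def converse(client, registry, messages, model="gpt-3.5-turbo-1106"):
    while True:
        resp = await client.chat.completions.create(
            model=model,
            messages=messages,
            tools=registry.tools,  # same manifest the PR now passes instead of **api_manifest
        )
        choice = resp.choices[0]

        if choice.finish_reason != "tool_calls":
            return choice.message.content  # a plain assistant reply ends the loop

        # Echo the assistant's tool_calls back into history, answer each one
        # with a role="tool" message keyed by tool_call_id, then resubmit.
        messages.append(choice.message.model_dump(exclude_none=True))
        for tool_call in choice.message.tool_calls:
            result = await registry.call(tool_call.function.name, tool_call.function.arguments)
            messages.append(
                {"role": "tool", "tool_call_id": tool_call.id, "content": str(result)}
            )
```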
chatlab/messaging.py: 0 additions & 1 deletion
@@ -82,7 +82,6 @@ def assistant_function_call(name: str, arguments: Optional[str] = None) -> ChatC
         },
     }

-
 def function_result(name: str, content: str) -> ChatCompletionMessageParam:
     """Create a function result message.
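
Reviewer note: the registry changes below import `tool_result` from `chatlab.messaging`, but that helper does not appear in this diff view. A sketch of its presumed shape, mirroring `function_result` above and the `ChatCompletionToolMessageParam` schema (`role="tool"` plus `tool_call_id`):

```python
# Presumed shape of the `tool_result` helper imported by chatlab/registry.py;
# it is not visible in this diff, so treat this as an assumption.
from openai.types.chat import ChatCompletionToolMessageParam


def tool_result(tool_call_id: str, content: str) -> ChatCompletionToolMessageParam:
    """Create a tool result message tied to the originating tool call."""
    return {
        "role": "tool",
        "tool_call_id": tool_call_id,
        "content": content,
    }
```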
chatlab/registry.py: 12 additions & 0 deletions
@@ -62,8 +62,13 @@ class WhatTime(BaseModel):
 from openai.types.chat.completion_create_params import Function, FunctionCall
 from pydantic import BaseModel, create_model

+from chatlab.messaging import tool_result
+
 from .decorators import ChatlabMetadata

+from openai.types.chat import ChatCompletionToolMessageParam
+from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
+

 class APIManifest(TypedDict, total=False):
     """The schema for the API."""
@@ -482,6 +487,13 @@ async def call(self, name: str, arguments: Optional[str] = None) -> Any:
         result = function(**prepared_arguments)
         return result

+    async def run_tool(self, tool: ChatCompletionMessageToolCall) -> ChatCompletionToolMessageParam:
+        result = await self.call(tool.function.name, tool.function.arguments)
+        # TODO: Bring over the special result formatter from chatlab.Chat
+        tool_call_response = tool_result(tool.id, content=str(result))
+        # tool_call_response["name"] = tool.function.name
+        return tool_call_response
+
     def __contains__(self, name) -> bool:
         """Check if a function is registered by name."""
         if name == "python" and self.python_hallucination_function:
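
Reviewer note: a quick usage sketch for the new `FunctionRegistry.run_tool`. Only `run_tool` comes from this PR; the registry setup, the `add` function, and the hand-built tool call are illustrative, and `register()` is assumed from the existing chatlab API:

```python
# Usage sketch for FunctionRegistry.run_tool (illustrative setup).
import asyncio

from chatlab.registry import FunctionRegistry
from openai.types.chat.chat_completion_message_tool_call import (
    ChatCompletionMessageToolCall,
    Function,
)


def add(x: int, y: int) -> int:
    """Add two numbers."""
    return x + y


async def main():
    registry = FunctionRegistry()
    registry.register(add)

    call = ChatCompletionMessageToolCall(
        id="call_123",
        type="function",
        function=Function(name="add", arguments='{"x": 2, "y": 3}'),
    )
    message = await registry.run_tool(call)
    print(message)  # expected: {'role': 'tool', 'tool_call_id': 'call_123', 'content': '5'}


asyncio.run(main())
```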
chatlab/tool_call.py: 51 additions & 0 deletions (new file)
@@ -0,0 +1,51 @@
+import asyncio
+from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCall,
+)
+
+from openai.types.chat import ChatCompletionToolMessageParam
+
+from pydantic import BaseModel
+
+from typing import Dict, List
+
+from chatlab.registry import FunctionRegistry
+
+# We need to build a model that will receive updates to tool_calls (from a collection of deltas or the whole thing)
+
+
+class ToolCallBuilder(BaseModel):
+    # Assuming that we might not get all the tool calls at once, we need to build a model that will receive updates to tool_calls (from a collection of deltas or the whole thing)
+    # TODO: Declare this as a partial of ChatCompletionMessageToolCall
+    tool_calls: Dict[int, ChoiceDeltaToolCall] = {}
+
+    def update(self, *tool_calls: ChoiceDeltaToolCall):
+        for tool_call in tool_calls:
+            in_progress_call = self.tool_calls.get(tool_call.index)
+
+            if in_progress_call:
+                in_progress_call.function.arguments += tool_call.function.arguments
+            else:
+                self.tool_calls[tool_call.index] = tool_call.model_copy()
+
+    def finalize(self) -> List[ChatCompletionMessageToolCall]:
+        return [
+            ChatCompletionMessageToolCall(
+                id=tool_call.id,
+                function=Function(name=tool_call.function.name, arguments=tool_call.function.arguments),
+                type=tool_call.type,
+            )
+            for tool_call in self.tool_calls.values()
+        ]
+
+    async def run(self, function_registry: FunctionRegistry):
+        tool_calls = self.finalize()
+
+        tasks: List[asyncio.Future[ChatCompletionToolMessageParam]] = [
+            function_registry.run_tool(tool) for tool in tool_calls
+        ]
+
+        for future in asyncio.as_completed(tasks):
+            response = await future
+            yield response
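
Reviewer note: how `ToolCallBuilder` reassembles a streamed call. The first delta for an index carries the `id`, `type`, and function name; later deltas only append argument fragments, which `update` concatenates. A small sketch with hand-built deltas (assuming every delta carries a `function` payload, which is what `update` currently relies on):

```python
# Hand-built deltas showing ToolCallBuilder's merge behavior.
from chatlab.tool_call import ToolCallBuilder
from openai.types.chat.chat_completion_chunk import (
    ChoiceDeltaToolCall,
    ChoiceDeltaToolCallFunction,
)

builder = ToolCallBuilder()
builder.update(
    ChoiceDeltaToolCall(
        index=0,
        id="call_123",
        type="function",
        function=ChoiceDeltaToolCallFunction(name="add", arguments='{"x"'),
    )
)
builder.update(
    ChoiceDeltaToolCall(
        index=0,
        function=ChoiceDeltaToolCallFunction(arguments=': 2, "y": 3}'),
    )
)

[call] = builder.finalize()
print(call.function.name, call.function.arguments)  # add {"x": 2, "y": 3}
```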