Commit

push tools rather than function calls
rgbkrk committed Jan 23, 2024
1 parent 8d91afc commit 86f516c
Showing 5 changed files with 192 additions and 102 deletions.
64 changes: 50 additions & 14 deletions chatlab/chat.py
@@ -21,7 +21,12 @@
from IPython.core.async_helpers import get_asyncio_loop
from openai import AsyncOpenAI, AsyncStream
from openai.types import FunctionDefinition
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam
from openai.types.chat import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessageParam,
ChatCompletionAssistantMessageParam,
)
from pydantic import BaseModel

from chatlab.views.assistant_function_call import AssistantFunctionCallView
@@ -33,6 +38,9 @@
from .registry import FunctionRegistry, PythonHallucinationFunction
from .views.assistant import AssistantMessageView

from .tool_call import ToolCallBuilder


logger = logging.getLogger(__name__)


@@ -145,11 +153,13 @@ async def __call__(self, *messages: Union[ChatCompletionMessageParam, str], stre

async def __process_stream(
self, resp: AsyncStream[ChatCompletionChunk]
) -> Tuple[str, Optional[AssistantFunctionCallView]]:
) -> Tuple[str, Optional[AssistantFunctionCallView], ToolCallBuilder]:
assistant_view: AssistantMessageView = AssistantMessageView()
function_view: Optional[AssistantFunctionCallView] = None
finish_reason = None

tool_call_builder = ToolCallBuilder()

async for result in resp: # Go through the results of the stream
choices = result.choices

@@ -175,6 +185,10 @@ async def __process_stream(
if function_view is None:
raise ValueError("Function arguments provided without function name")
function_view.append(function_call.arguments)

elif choice.delta.tool_calls is not None:
tool_call_builder.update(*choice.delta.tool_calls)

if choice.finish_reason is not None:
finish_reason = choice.finish_reason
break
@@ -188,9 +202,11 @@ async def __process_stream(
if finish_reason is None:
raise ValueError("No finish reason provided by OpenAI")

return (finish_reason, function_view)
return (finish_reason, function_view, tool_call_builder)

async def __process_full_completion(self, resp: ChatCompletion) -> Tuple[str, Optional[AssistantFunctionCallView]]:
async def __process_full_completion(
self, resp: ChatCompletion
) -> Tuple[str, Optional[AssistantFunctionCallView], ToolCallBuilder]:
assistant_view: AssistantMessageView = AssistantMessageView()
function_view: Optional[AssistantFunctionCallView] = None

@@ -202,15 +218,19 @@ async def __process_full_completion(self, resp: ChatCompletion) -> Tuple[str, Op

message = choice.message

tool_call_builder = ToolCallBuilder()

if message.content is not None:
assistant_view.append(message.content)
self.append(assistant_view.flush())
if message.function_call is not None:
function_call = message.function_call
function_view = AssistantFunctionCallView(function_name=function_call.name)
function_view.append(function_call.arguments)
if message.tool_calls is not None:
raise NotImplementedError("Tool calls are not yet supported")

return choice.finish_reason, function_view
return choice.finish_reason, function_view, tool_call_builder

async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream=True, **kwargs):
"""Send messages to the chat model and display the response.
@@ -241,37 +261,36 @@ async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream
base_url=self.base_url,
)

api_manifest = self.function_registry.api_manifest()

# Due to the strict response typing based on `Literal` typing on `stream`, we have to process these
# two cases separately
if stream:
streaming_response = await client.chat.completions.create(
model=self.model,
messages=full_messages,
**api_manifest,
tools=self.function_registry.tools,
stream=True,
temperature=kwargs.get("temperature", 0),
)

self.append(*messages)

finish_reason, function_call_request = await self.__process_stream(streaming_response)
finish_reason, function_call_request, tool_call_builder = await self.__process_stream(
streaming_response
)
else:
full_response = await client.chat.completions.create(
model=self.model,
messages=full_messages,
**api_manifest,
tools=self.function_registry.tools,
stream=False,
temperature=kwargs.get("temperature", 0),
)

self.append(*messages)

(
finish_reason,
function_call_request,
) = await self.__process_full_completion(full_response)
(finish_reason, function_call_request, tool_call_builder) = await self.__process_full_completion(
full_response
)

except openai.RateLimitError as e:
logger.error(f"Rate limited: {e}. Waiting 5 seconds and trying again.")
@@ -280,6 +299,23 @@

return

if finish_reason == "tool_calls":
tool_calls = tool_call_builder.finalize()
self.append(
ChatCompletionAssistantMessageParam(
role="assistant",
tool_calls=tool_calls,
# TODO: content and name may have been specified. I did not collect it above during streaming.
)
)

async for message in tool_call_builder.run(self.function_registry):
# TODO: Reintroduce the visual display
self.append(message)

await self.submit(stream=stream, **kwargs)
return

if finish_reason == "function_call":
if function_call_request is None:
raise ValueError(
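Taken together, the new `tool_calls` branch appends an assistant message carrying the accumulated calls, appends one `tool` message per result, and then re-submits so the model can compose a final answer. A hedged sketch of the transcript this produces (ids, names, and values are illustrative):

    history = [
        {"role": "user", "content": "What time is it in UTC?"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_abc123",  # echoed by the tool message below
                    "type": "function",
                    "function": {"name": "what_time", "arguments": '{"tz": "UTC"}'},
                }
            ],
        },
        # appended by ToolCallBuilder.run() via FunctionRegistry.run_tool()
        {"role": "tool", "tool_call_id": "call_abc123", "content": "2024-01-23T17:00:00+00:00"},
    ]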
1 change: 0 additions & 1 deletion chatlab/messaging.py
@@ -82,7 +82,6 @@ def assistant_function_call(name: str, arguments: Optional[str] = None) -> ChatC
},
}


def function_result(name: str, content: str) -> ChatCompletionMessageParam:
"""Create a function result message.
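The registry change below imports a `tool_result` helper from this module that the shown hunk does not include. Judging from its call site, `tool_result(tool.id, content=str(result))`, it plausibly mirrors `function_result` above; this is a sketch, not the committed code:

    from openai.types.chat import ChatCompletionToolMessageParam

    def tool_result(tool_call_id: str, content: str) -> ChatCompletionToolMessageParam:
        """Create a tool result message pairing content with the call id it answers."""
        return {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "content": content,
        }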
12 changes: 12 additions & 0 deletions chatlab/registry.py
@@ -62,8 +62,13 @@ class WhatTime(BaseModel):
from openai.types.chat.completion_create_params import Function, FunctionCall
from pydantic import BaseModel, create_model

from chatlab.messaging import tool_result

from .decorators import ChatlabMetadata

from openai.types.chat import ChatCompletionToolMessageParam
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall


class APIManifest(TypedDict, total=False):
"""The schema for the API."""
@@ -482,6 +487,13 @@ async def call(self, name: str, arguments: Optional[str] = None) -> Any:
result = function(**prepared_arguments)
return result

async def run_tool(self, tool: ChatCompletionMessageToolCall) -> ChatCompletionToolMessageParam:
result = await self.call(tool.function.name, tool.function.arguments)
# TODO: Bring over the special result formatter from chatlab.Chat
tool_call_response = tool_result(tool.id, content=str(result))
# tool_call_response["name"] = tool.function.name
return tool_call_response

def __contains__(self, name) -> bool:
"""Check if a function is registered by name."""
if name == "python" and self.python_hallucination_function:
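`run_tool` bridges a finalized OpenAI tool call to the existing `call` machinery and wraps the result as a `tool` message. A usage sketch, assuming the registry's usual `register` helper (the registration API is outside this diff; names and ids are illustrative):

    import asyncio

    from openai.types.chat.chat_completion_message_tool_call import (
        ChatCompletionMessageToolCall,
        Function,
    )

    from chatlab.registry import FunctionRegistry

    def what_time(tz: str = "UTC") -> str:
        """Illustrative registered function."""
        from datetime import datetime, timezone
        return datetime.now(timezone.utc).isoformat()

    async def main():
        registry = FunctionRegistry()
        registry.register(what_time)  # assumed registration helper
        call = ChatCompletionMessageToolCall(
            id="call_abc123",  # ids normally come from the model
            type="function",
            function=Function(name="what_time", arguments='{"tz": "UTC"}'),
        )
        tool_message = await registry.run_tool(call)
        # -> {"role": "tool", "tool_call_id": "call_abc123", "content": "..."}
        print(tool_message)

    asyncio.run(main())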
51 changes: 51 additions & 0 deletions chatlab/tool_call.py
@@ -0,0 +1,51 @@
import asyncio
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall,
)

from openai.types.chat import ChatCompletionToolMessageParam

from pydantic import BaseModel

from typing import Any, Coroutine, Dict, List

from chatlab.registry import FunctionRegistry

# We need a model that can receive updates to tool_calls, either as a collection of deltas or as the whole call.


class ToolCallBuilder(BaseModel):
# Since all the tool calls may not arrive at once, this model receives updates to tool_calls (as a collection of deltas or the whole call)
# TODO: Declare this as a partial of ChatCompletionMessageToolCall
tool_calls: Dict[int, ChoiceDeltaToolCall] = {}

def update(self, *tool_calls: ChoiceDeltaToolCall):
for tool_call in tool_calls:
in_progress_call = self.tool_calls.get(tool_call.index)

if in_progress_call:
in_progress_call.function.arguments += tool_call.function.arguments
else:
self.tool_calls[tool_call.index] = tool_call.model_copy()

def finalize(self) -> List[ChatCompletionMessageToolCall]:
return [
ChatCompletionMessageToolCall(
id=tool_call.id,
function=Function(name=tool_call.function.name, arguments=tool_call.function.arguments),
type=tool_call.type,
)
for tool_call in self.tool_calls.values()
]

async def run(self, function_registry: FunctionRegistry):
tool_calls = self.finalize()

# run_tool returns coroutines, not Futures; asyncio.as_completed awaits them as they finish
tasks: List[Coroutine[Any, Any, ChatCompletionToolMessageParam]] = [
function_registry.run_tool(tool) for tool in tool_calls
]

for future in asyncio.as_completed(tasks):
response = await future
yield response
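The builder exists because the streaming API sends each tool call in fragments: the first delta for a given `index` carries the id, type, and function name, and later deltas append slices of the JSON arguments. A small sketch of that accumulation (all values illustrative):

    from openai.types.chat.chat_completion_chunk import (
        ChoiceDeltaToolCall,
        ChoiceDeltaToolCallFunction,
    )

    builder = ToolCallBuilder()

    # first delta for index 0: id, type, name, and the opening of the arguments
    builder.update(
        ChoiceDeltaToolCall(
            index=0,
            id="call_abc123",
            type="function",
            function=ChoiceDeltaToolCallFunction(name="what_time", arguments='{"tz"'),
        )
    )
    # later deltas carry only the index and the next argument fragment
    builder.update(
        ChoiceDeltaToolCall(
            index=0,
            function=ChoiceDeltaToolCallFunction(arguments=': "UTC"}'),
        )
    )

    (call,) = builder.finalize()
    assert call.function.arguments == '{"tz": "UTC"}'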
