Skip to content

Commit

Permalink
yay tool calling
Browse files Browse the repository at this point in the history
  • Loading branch information
rgbkrk committed Jan 23, 2024
1 parent 9078268 commit 17b2ca6
Showing 1 changed file with 92 additions and 0 deletions.
92 changes: 92 additions & 0 deletions chatlab/tool_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import asyncio
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
from openai.types.chat.chat_completion_chunk import (
ChoiceDelta,
ChoiceDeltaToolCall,
ChoiceDeltaToolCallFunction,
ChoiceDeltaFunctionCall,
Choice,
)

from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam

from pydantic import BaseModel

from typing import Dict, List
from chatlab.messaging import tool_result

from chatlab.registry import FunctionRegistry

"""
Example data:
[ChoiceDeltaToolCall(index=0, id='call_hpw4baOCbrhV3ImwxMXl54VC', function=ChoiceDeltaToolCallFunction(arguments='', name='show_colors'), type='function')]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='{"', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='colors', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='":["', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='#', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='1', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='A', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='1', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='A', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='1', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='D', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='","#', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='4', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='E', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='4', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='E', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='50', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='","#', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='6', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='F', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='223', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='2', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='","#', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='950', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='740', name=None), type=None)]
...
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='A', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='9', name=None), type=None)]
[ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='"]}', name=None), type=None)]
"""

# We need to build a model that will recieve updates to tool_calls (from a a collection of deltas or the whole thing)


class ToolCallBuilder(BaseModel):
# Assuming that we might not get all the tool calls at once, we need to build a model that will recieve updates to tool_calls (from a a collection of deltas or the whole thing)
# TODO: Declare this as a partial of ChatCompletionMessageToolCall
tool_calls: Dict[int, ChoiceDeltaToolCall] = {}

def update(self, *tool_calls: ChoiceDeltaToolCall):
for tool_call in tool_calls:
in_progress_call = self.tool_calls.get(tool_call.index)

if in_progress_call:
in_progress_call.function.arguments += tool_call.function.arguments
else:
self.tool_calls[tool_call.index] = tool_call.model_copy()

def finalize(self) -> List[ChatCompletionMessageToolCall]:
return [
ChatCompletionMessageToolCall(
id=tool_call.id,
function=Function(name=tool_call.function.name, arguments=tool_call.function.arguments),
type=tool_call.type,
)
for tool_call in self.tool_calls.values()
]

async def run(self, function_registry: FunctionRegistry):
tool_calls = self.finalize()

tasks: List[asyncio.Future[ChatCompletionToolMessageParam]] = [
function_registry.run_tool(tool) for tool in tool_calls
]

for future in asyncio.as_completed(tasks):
response = await future
yield response

0 comments on commit 17b2ca6

Please sign in to comment.