24
24
from openai .types .chat import ChatCompletion , ChatCompletionChunk , ChatCompletionMessageParam
25
25
from pydantic import BaseModel
26
26
27
- from chatlab .views .assistant_function_call import AssistantFunctionCallView
28
-
29
27
from ._version import __version__
30
- from .display import ChatFunctionCall
31
28
from .errors import ChatLabError
32
29
from .messaging import human
33
30
from .registry import FunctionRegistry , PythonHallucinationFunction
34
- from .views . assistant import AssistantMessageView
31
+ from .views import ToolArguments , AssistantMessageView
35
32
36
33
logger = logging .getLogger (__name__ )
37
34
@@ -145,9 +142,9 @@ async def __call__(self, *messages: Union[ChatCompletionMessageParam, str], stre
145
142
146
143
async def __process_stream (
147
144
self , resp : AsyncStream [ChatCompletionChunk ]
148
- ) -> Tuple [str , Optional [AssistantFunctionCallView ]]:
149
- assistant_view : AssistantMessageView = AssistantMessageView ()
150
- function_view : Optional [AssistantFunctionCallView ] = None
145
+ ) -> Tuple [str , Optional [ToolArguments ]]:
146
+ assistant_view : AssistantMessageView = AssistantMessageView (content = "" )
147
+ function_view : Optional [ToolArguments ] = None
151
148
finish_reason = None
152
149
153
150
async for result in resp : # Go through the results of the stream
@@ -162,37 +159,41 @@ async def __process_stream(
162
159
# Is stream choice?
163
160
if choice .delta is not None :
164
161
if choice .delta .content is not None :
162
+ assistant_view .display_once ()
165
163
assistant_view .append (choice .delta .content )
166
164
elif choice .delta .function_call is not None :
167
165
function_call = choice .delta .function_call
168
166
if function_call .name is not None :
169
- if assistant_view .in_progress () :
167
+ if not assistant_view .finished :
170
168
# Flush out the finished assistant message
171
- message = assistant_view .flush ()
169
+ message = assistant_view .get_message ()
172
170
self .append (message )
173
- function_view = AssistantFunctionCallView (function_call .name )
171
+ # IDs are for the tool calling apparatus from newer versions of the API
172
+ # We will make use of it later.
173
+ function_view = ToolArguments (id = "TBD" , name = function_call .name )
174
+ function_view .display ()
174
175
if function_call .arguments is not None :
175
176
if function_view is None :
176
177
raise ValueError ("Function arguments provided without function name" )
177
- function_view .append (function_call .arguments )
178
+ function_view .append_arguments (function_call .arguments )
178
179
if choice .finish_reason is not None :
179
180
finish_reason = choice .finish_reason
180
181
break
181
182
182
183
# Wrap up the previous assistant
183
184
# Note: This will also wrap up the assistant's message when it ran out of tokens
184
- if assistant_view .in_progress () :
185
- message = assistant_view .flush ()
185
+ if not assistant_view .finished :
186
+ message = assistant_view .get_message ()
186
187
self .append (message )
187
188
188
189
if finish_reason is None :
189
190
raise ValueError ("No finish reason provided by OpenAI" )
190
191
191
192
return (finish_reason , function_view )
192
193
193
- async def __process_full_completion (self , resp : ChatCompletion ) -> Tuple [str , Optional [AssistantFunctionCallView ]]:
194
- assistant_view : AssistantMessageView = AssistantMessageView ()
195
- function_view : Optional [AssistantFunctionCallView ] = None
194
+ async def __process_full_completion (self , resp : ChatCompletion ) -> Tuple [str , Optional [ToolArguments ]]:
195
+ assistant_view : AssistantMessageView = AssistantMessageView (content = "" )
196
+ function_view : Optional [ToolArguments ] = None
196
197
197
198
if len (resp .choices ) == 0 :
198
199
logger .warning (f"Result has no choices: { resp } " )
@@ -203,12 +204,13 @@ async def __process_full_completion(self, resp: ChatCompletion) -> Tuple[str, Op
203
204
message = choice .message
204
205
205
206
if message .content is not None :
207
+ assistant_view .display_once ()
206
208
assistant_view .append (message .content )
207
- self .append (assistant_view .flush ())
209
+ self .append (assistant_view .get_message ())
208
210
if message .function_call is not None :
209
211
function_call = message .function_call
210
- function_view = AssistantFunctionCallView ( function_name = function_call .name )
211
- function_view .append ( function_call . arguments )
212
+ function_view = ToolArguments ( id = "TBD" , name = function_call .name , arguments = function_call . arguments )
213
+ function_view .display ( )
212
214
213
215
return choice .finish_reason , function_view
214
216
@@ -286,17 +288,12 @@ async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream
286
288
"Function call was the stated function_call reason without having a complete function call. If you see this, report it as an issue to https://github.com/rgbkrk/chatlab/issues" # noqa: E501
287
289
)
288
290
# Record the attempted call from the LLM
289
- self .append (function_call_request .get_message ())
291
+ self .append (function_call_request .get_function_message ())
290
292
291
- chat_function = ChatFunctionCall (
292
- ** function_call_request .finalize (),
293
- function_registry = self .function_registry ,
294
- )
293
+ function_called = await function_call_request .call (function_registry = self .function_registry )
295
294
296
- # Make the call
297
- fn_message = await chat_function .call ()
298
295
# Include the response (or error) for the model
299
- self .append (fn_message )
296
+ self .append (function_called . get_function_called_message () )
300
297
301
298
# Reply back to the LLM with the result of the function call, allow it to continue
302
299
await self .submit (stream = stream , ** kwargs )
0 commit comments