14
14
import asyncio
15
15
import logging
16
16
import os
17
- from dataclasses import dataclass
18
- from typing import Callable , Dict , Iterable , List , Optional , Tuple , Type , Union , cast , overload
17
+ from typing import AsyncIterator , Callable , List , Optional , Tuple , Type , Union , overload
19
18
20
19
import openai
21
- import openai .error
22
20
from deprecation import deprecated
23
21
from IPython .core .async_helpers import get_asyncio_loop
22
+ from openai import AsyncOpenAI
23
+ from openai .types .chat import ChatCompletion , ChatCompletionChunk , ChatCompletionMessageParam
24
24
from pydantic import BaseModel
25
25
26
26
from chatlab .views .assistant_function_call import AssistantFunctionCallView
27
27
28
28
from ._version import __version__
29
29
from .display import ChatFunctionCall
30
30
from .errors import ChatLabError
31
- from .messaging import (
32
- ChatCompletion ,
33
- Message ,
34
- StreamCompletion ,
35
- human ,
36
- is_full_choice ,
37
- is_function_call ,
38
- is_stream_choice ,
39
- )
40
- from .registry import FunctionRegistry , PythonHallucinationFunction
31
+ from .messaging import human
32
+ from .registry import FunctionRegistry , FunctionSchema , PythonHallucinationFunction
41
33
from .views .assistant import AssistantMessageView
42
34
43
35
logger = logging .getLogger (__name__ )
44
36
45
37
46
@dataclass
class ContentDelta:
    """A streamed chunk of assistant markdown content."""

    content: str


@dataclass
class FunctionCallArgumentsDelta:
    """A streamed chunk of a function call's JSON arguments."""

    arguments: str


@dataclass
class FunctionCallNameDelta:
    """The name of a function the model wants to call."""

    name: str


def process_delta(delta):
    """Translate a raw streaming ``delta`` dict into typed delta events.

    Content takes precedence: when the delta carries non-null ``content``,
    only a ``ContentDelta`` is yielded. Otherwise a ``function_call`` delta
    may yield a name event and/or an arguments event, in that order.
    """
    if delta.get('content') is not None:
        yield ContentDelta(delta['content'])
        return

    if 'function_call' not in delta:
        return

    function_call = delta['function_call']
    if 'name' in function_call:
        yield FunctionCallNameDelta(function_call['name'])
    if 'arguments' in function_call:
        yield FunctionCallArgumentsDelta(function_call['arguments'])
80
38
class Chat :
81
39
"""Interactive chats inside of computational notebooks, relying on OpenAI's API.
82
40
@@ -85,7 +43,7 @@ class Chat:
85
43
History is tracked and can be used to continue a conversation.
86
44
87
45
Args:
88
- initial_context (str | Message ): The initial context for the conversation.
46
+ initial_context (str | ChatCompletionMessageParam ): The initial context for the conversation.
89
47
90
48
model (str): The model to use for the conversation.
91
49
@@ -102,14 +60,14 @@ class Chat:
102
60
103
61
"""
104
62
105
- messages : List [Message ]
63
+ messages : List [ChatCompletionMessageParam ]
106
64
model : str
107
65
function_registry : FunctionRegistry
108
66
allow_hallucinated_python : bool
109
67
110
68
def __init__ (
111
69
self ,
112
- * initial_context : Union [Message , str ],
70
+ * initial_context : Union [ChatCompletionMessageParam , str ],
113
71
model = "gpt-3.5-turbo-0613" ,
114
72
function_registry : Optional [FunctionRegistry ] = None ,
115
73
chat_functions : Optional [List [Callable ]] = None ,
@@ -141,7 +99,7 @@ def __init__(
141
99
if initial_context is None :
142
100
initial_context = [] # type: ignore
143
101
144
- self .messages : List [Message ] = []
102
+ self .messages : List [ChatCompletionMessageParam ] = []
145
103
146
104
self .append (* initial_context )
147
105
self .model = model
@@ -164,65 +122,52 @@ def __init__(
164
122
)
165
123
def chat (
166
124
self ,
167
- * messages : Union [Message , str ],
125
+ * messages : Union [ChatCompletionMessageParam , str ],
168
126
):
169
127
"""Send messages to the chat model and display the response.
170
128
171
129
Deprecated in 0.13.0, removed in 1.0.0. Use `submit` instead.
172
130
"""
173
131
raise Exception ("This method is deprecated. Use `submit` instead." )
174
132
175
async def __call__(self, *messages: Union[ChatCompletionMessageParam, str], stream: bool = True, **kwargs):
    """Send messages to the chat model and display the response.

    Thin awaitable alias for `submit`; all arguments are forwarded unchanged.

    Args:
        messages: Messages (or plain strings) to send to the chat.
        stream: Whether to stream the response into markdown as it arrives.
            Restored the `bool` annotation that was dropped from the old
            signature, keeping it consistent with the documented contract.
    """
    return await self.submit(*messages, stream=stream, **kwargs)
178
136
179
137
async def __process_stream (
180
- self , resp : Iterable [ Union [ StreamCompletion , ChatCompletion ] ]
138
+ self , resp : AsyncIterator [ ChatCompletionChunk ]
181
139
) -> Tuple [str , Optional [AssistantFunctionCallView ]]:
182
140
assistant_view : AssistantMessageView = AssistantMessageView ()
183
141
function_view : Optional [AssistantFunctionCallView ] = None
184
142
finish_reason = None
185
143
186
- for result in resp : # Go through the results of the stream
187
- if not isinstance (result , dict ):
188
- logger .warning (f"Unknown result type: { type (result )} : { result } " )
189
- continue
190
-
191
- choices = result .get ('choices' , [])
144
+ async for result in resp : # Go through the results of the stream
145
+ choices = result .choices
192
146
193
147
if len (choices ) == 0 :
194
148
logger .warning (f"Result has no choices: { result } " )
195
149
continue
196
150
197
151
choice = choices [0 ]
198
152
199
- if is_stream_choice ( choice ): # If there is a delta in the result
200
- delta = choice [ 'delta' ]
201
-
202
- for event in process_delta ( delta ):
203
- if isinstance ( event , ContentDelta ) :
204
- assistant_view . append ( event . content )
205
- elif isinstance ( event , FunctionCallNameDelta ) :
153
+ # Is stream choice?
154
+ if choice . delta is not None :
155
+ if choice . delta . content is not None :
156
+ assistant_view . append ( choice . delta . content )
157
+ elif choice . delta . function_call is not None :
158
+ function_call = choice . delta . function_call
159
+ if function_call . name is not None :
206
160
if assistant_view .in_progress ():
207
161
# Flush out the finished assistant message
208
162
message = assistant_view .flush ()
209
163
self .append (message )
210
- function_view = AssistantFunctionCallView (event .name )
211
- elif isinstance ( event , FunctionCallArgumentsDelta ) :
164
+ function_view = AssistantFunctionCallView (function_call .name )
165
+ if function_call . arguments is not None :
212
166
if function_view is None :
213
167
raise ValueError ("Function arguments provided without function name" )
214
- function_view .append (event .arguments )
215
- elif is_full_choice (choice ):
216
- message = choice ['message' ]
217
-
218
- if is_function_call (message ):
219
- function_view = AssistantFunctionCallView (message ['function_call' ]['name' ])
220
- function_view .append (message ['function_call' ]['arguments' ])
221
- elif 'content' in message and message ['content' ] is not None :
222
- assistant_view .append (message ['content' ])
223
-
224
- if 'finish_reason' in choice and choice ['finish_reason' ] is not None :
225
- finish_reason = choice ['finish_reason' ]
168
+ function_view .append (function_call .arguments )
169
+ if choice .finish_reason is not None :
170
+ finish_reason = choice .finish_reason
226
171
break
227
172
228
173
# Wrap up the previous assistant
@@ -236,7 +181,29 @@ async def __process_stream(
236
181
237
182
return (finish_reason , function_view )
238
183
239
- async def submit (self , * messages : Union [Message , str ], stream : bool = True , ** kwargs ):
184
async def __process_full_completion(self, resp: ChatCompletion) -> Tuple[str, Optional[AssistantFunctionCallView]]:
    """Render a non-streaming chat completion.

    Appends any assistant text content to the conversation, and collects a
    pending function call view when the model requested one.

    Args:
        resp: The full (non-streamed) completion returned by the API.

    Returns:
        A ``(finish_reason, function_view)`` pair; ``function_view`` is
        ``None`` when the model did not request a function call.
    """
    assistant_view: AssistantMessageView = AssistantMessageView()
    function_view: Optional[AssistantFunctionCallView] = None

    if not resp.choices:
        logger.warning(f"Result has no choices: {resp}")
        return ("stop", None)  # TODO

    first_choice = resp.choices[0]
    message = first_choice.message

    if message.content is not None:
        assistant_view.append(message.content)
        self.append(assistant_view.flush())

    call = message.function_call
    if call is not None:
        function_view = AssistantFunctionCallView(function_name=call.name)
        function_view.append(call.arguments)

    return first_choice.finish_reason, function_view
205
+
206
+ async def submit (self , * messages : Union [ChatCompletionMessageParam , str ], stream = True , ** kwargs ):
240
207
"""Send messages to the chat model and display the response.
241
208
242
209
Side effects:
@@ -245,13 +212,13 @@ async def submit(self, *messages: Union[Message, str], stream: bool = True, **kw
245
212
- chat.messages are updated with response(s).
246
213
247
214
Args:
248
- messages (str | Message): One or more messages to send to the chat, can be strings or Message objects.
215
+ messages (str | ChatCompletionMessageParam): One or more messages to send to the chat, can be strings or
216
+ ChatCompletionMessageParam objects.
249
217
250
- stream (bool) : Whether to stream chat into markdown or not. If False, the entire chat will be sent once.
218
+ stream: Whether to stream chat into markdown or not. If False, the entire chat will be sent once.
251
219
252
220
"""
253
-
254
- full_messages : List [Message ] = []
221
+ full_messages : List [ChatCompletionMessageParam ] = []
255
222
full_messages .extend (self .messages )
256
223
for message in messages :
257
224
if isinstance (message , str ):
@@ -260,14 +227,33 @@ async def submit(self, *messages: Union[Message, str], stream: bool = True, **kw
260
227
full_messages .append (message )
261
228
262
229
try :
263
- resp = openai .ChatCompletion .create (
264
- model = self .model ,
265
- messages = full_messages ,
266
- ** self .function_registry .api_manifest (),
267
- stream = stream ,
268
- temperature = kwargs .get ("temperature" , 0 ),
269
- )
270
- except openai .error .RateLimitError as e :
230
+ client = AsyncOpenAI ()
231
+
232
+ manifest = self .function_registry .api_manifest ()
233
+
234
+ # Due to the strict response typing based on `Literal` typing on `stream`, we have to process these two cases separately
235
+ if stream :
236
+ streaming_response = await client .chat .completions .create (
237
+ model = self .model ,
238
+ messages = full_messages ,
239
+ ** manifest ,
240
+ stream = True ,
241
+ temperature = kwargs .get ("temperature" , 0 ),
242
+ )
243
+
244
+ finish_reason , function_call_request = await self .__process_stream (streaming_response )
245
+ else :
246
+ full_response = await client .chat .completions .create (
247
+ model = self .model ,
248
+ messages = full_messages ,
249
+ ** manifest ,
250
+ stream = False ,
251
+ temperature = kwargs .get ("temperature" , 0 ),
252
+ )
253
+
254
+ finish_reason , function_call_request = await self .__process_full_completion (full_response )
255
+
256
+ except openai .RateLimitError as e :
271
257
logger .error (f"Rate limited: { e } . Waiting 5 seconds and trying again." )
272
258
await asyncio .sleep (5 )
273
259
await self .submit (* messages , stream = stream , ** kwargs )
@@ -276,13 +262,6 @@ async def submit(self, *messages: Union[Message, str], stream: bool = True, **kw
276
262
277
263
self .append (* messages )
278
264
279
- if not stream :
280
- resp = [resp ]
281
-
282
- resp = cast (Iterable [Union [StreamCompletion , ChatCompletion ]], resp )
283
-
284
- finish_reason , function_call_request = await self .__process_stream (resp )
285
-
286
265
if finish_reason == "function_call" :
287
266
if function_call_request is None :
288
267
raise ValueError (
@@ -317,13 +296,13 @@ async def submit(self, *messages: Union[Message, str], stream: bool = True, **kw
317
296
f"UNKNOWN FINISH REASON: '{ finish_reason } '. If you see this message, report it as an issue to https://github.com/rgbkrk/chatlab/issues" # noqa: E501
318
297
)
319
298
320
- def append (self , * messages : Union [Message , str ]):
299
+ def append (self , * messages : Union [ChatCompletionMessageParam , str ]):
321
300
"""Append messages to the conversation history.
322
301
323
302
Note: this does not send the messages on until `chat` is called.
324
303
325
304
Args:
326
- messages (str | Message ): One or more messages to append to the conversation.
305
+ messages (str | ChatCompletionMessageParam ): One or more messages to append to the conversation.
327
306
328
307
"""
329
308
# Messages are either a dict respecting the {role, content} format or a str that we convert to a human message
@@ -340,13 +319,17 @@ def register(
340
319
...
341
320
342
321
@overload
def register(
    self,
    function: Callable,
    parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
) -> FunctionSchema:
    """Overload: register *function* directly, returning its generated schema."""
    ...
345
326
346
327
def register (
347
328
self , function : Optional [Callable ] = None , parameter_schema : Optional [Union [Type ["BaseModel" ], dict ]] = None
348
- ) -> Union [Callable , Dict ]:
349
- """Register a function with the ChatLab instance. This can be used as a decorator like so:
329
+ ) -> Union [Callable , FunctionSchema ]:
330
+ """Register a function with the ChatLab instance.
331
+
332
+ This can be used as a decorator like so:
350
333
351
334
>>> from chatlab import Chat
352
335
>>> chat = Chat()
0 commit comments