Skip to content

Commit 334a92b

Browse files
authored
Merge pull request #96 from rgbkrk/openai-migrate-api
2 parents 3f0193b + 09cd981 commit 334a92b

File tree

9 files changed

+176
-695
lines changed

9 files changed

+176
-695
lines changed

chatlab/conversation.py

Lines changed: 89 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -14,69 +14,27 @@
1414
import asyncio
1515
import logging
1616
import os
17-
from dataclasses import dataclass
18-
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, Union, cast, overload
17+
from typing import AsyncIterator, Callable, List, Optional, Tuple, Type, Union, overload
1918

2019
import openai
21-
import openai.error
2220
from deprecation import deprecated
2321
from IPython.core.async_helpers import get_asyncio_loop
22+
from openai import AsyncOpenAI
23+
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam
2424
from pydantic import BaseModel
2525

2626
from chatlab.views.assistant_function_call import AssistantFunctionCallView
2727

2828
from ._version import __version__
2929
from .display import ChatFunctionCall
3030
from .errors import ChatLabError
31-
from .messaging import (
32-
ChatCompletion,
33-
Message,
34-
StreamCompletion,
35-
human,
36-
is_full_choice,
37-
is_function_call,
38-
is_stream_choice,
39-
)
40-
from .registry import FunctionRegistry, PythonHallucinationFunction
31+
from .messaging import human
32+
from .registry import FunctionRegistry, FunctionSchema, PythonHallucinationFunction
4133
from .views.assistant import AssistantMessageView
4234

4335
logger = logging.getLogger(__name__)
4436

4537

46-
@dataclass
47-
class ContentDelta:
48-
"""A delta that contains markdown."""
49-
50-
content: str
51-
52-
53-
@dataclass
54-
class FunctionCallArgumentsDelta:
55-
"""A delta that contains function call arguments."""
56-
57-
arguments: str
58-
59-
60-
@dataclass
61-
class FunctionCallNameDelta:
62-
"""A delta that contains function call name."""
63-
64-
name: str
65-
66-
67-
def process_delta(delta):
68-
"""Process a delta."""
69-
if 'content' in delta and delta['content'] is not None:
70-
yield ContentDelta(delta['content'])
71-
72-
elif 'function_call' in delta: # If the delta contains a function call
73-
if 'name' in delta['function_call']:
74-
yield FunctionCallNameDelta(delta['function_call']['name'])
75-
76-
if 'arguments' in delta['function_call']:
77-
yield FunctionCallArgumentsDelta(delta['function_call']['arguments'])
78-
79-
8038
class Chat:
8139
"""Interactive chats inside of computational notebooks, relying on OpenAI's API.
8240
@@ -85,7 +43,7 @@ class Chat:
8543
History is tracked and can be used to continue a conversation.
8644
8745
Args:
88-
initial_context (str | Message): The initial context for the conversation.
46+
initial_context (str | ChatCompletionMessageParam): The initial context for the conversation.
8947
9048
model (str): The model to use for the conversation.
9149
@@ -102,14 +60,14 @@ class Chat:
10260
10361
"""
10462

105-
messages: List[Message]
63+
messages: List[ChatCompletionMessageParam]
10664
model: str
10765
function_registry: FunctionRegistry
10866
allow_hallucinated_python: bool
10967

11068
def __init__(
11169
self,
112-
*initial_context: Union[Message, str],
70+
*initial_context: Union[ChatCompletionMessageParam, str],
11371
model="gpt-3.5-turbo-0613",
11472
function_registry: Optional[FunctionRegistry] = None,
11573
chat_functions: Optional[List[Callable]] = None,
@@ -141,7 +99,7 @@ def __init__(
14199
if initial_context is None:
142100
initial_context = [] # type: ignore
143101

144-
self.messages: List[Message] = []
102+
self.messages: List[ChatCompletionMessageParam] = []
145103

146104
self.append(*initial_context)
147105
self.model = model
@@ -164,65 +122,52 @@ def __init__(
164122
)
165123
def chat(
166124
self,
167-
*messages: Union[Message, str],
125+
*messages: Union[ChatCompletionMessageParam, str],
168126
):
169127
"""Send messages to the chat model and display the response.
170128
171129
Deprecated in 0.13.0, removed in 1.0.0. Use `submit` instead.
172130
"""
173131
raise Exception("This method is deprecated. Use `submit` instead.")
174132

175-
async def __call__(self, *messages: Union[Message, str], stream: bool = True, **kwargs):
133+
async def __call__(self, *messages: Union[ChatCompletionMessageParam, str], stream=True, **kwargs):
176134
"""Send messages to the chat model and display the response."""
177135
return await self.submit(*messages, stream=stream, **kwargs)
178136

179137
async def __process_stream(
180-
self, resp: Iterable[Union[StreamCompletion, ChatCompletion]]
138+
self, resp: AsyncIterator[ChatCompletionChunk]
181139
) -> Tuple[str, Optional[AssistantFunctionCallView]]:
182140
assistant_view: AssistantMessageView = AssistantMessageView()
183141
function_view: Optional[AssistantFunctionCallView] = None
184142
finish_reason = None
185143

186-
for result in resp: # Go through the results of the stream
187-
if not isinstance(result, dict):
188-
logger.warning(f"Unknown result type: {type(result)}: {result}")
189-
continue
190-
191-
choices = result.get('choices', [])
144+
async for result in resp: # Go through the results of the stream
145+
choices = result.choices
192146

193147
if len(choices) == 0:
194148
logger.warning(f"Result has no choices: {result}")
195149
continue
196150

197151
choice = choices[0]
198152

199-
if is_stream_choice(choice): # If there is a delta in the result
200-
delta = choice['delta']
201-
202-
for event in process_delta(delta):
203-
if isinstance(event, ContentDelta):
204-
assistant_view.append(event.content)
205-
elif isinstance(event, FunctionCallNameDelta):
153+
# Is stream choice?
154+
if choice.delta is not None:
155+
if choice.delta.content is not None:
156+
assistant_view.append(choice.delta.content)
157+
elif choice.delta.function_call is not None:
158+
function_call = choice.delta.function_call
159+
if function_call.name is not None:
206160
if assistant_view.in_progress():
207161
# Flush out the finished assistant message
208162
message = assistant_view.flush()
209163
self.append(message)
210-
function_view = AssistantFunctionCallView(event.name)
211-
elif isinstance(event, FunctionCallArgumentsDelta):
164+
function_view = AssistantFunctionCallView(function_call.name)
165+
if function_call.arguments is not None:
212166
if function_view is None:
213167
raise ValueError("Function arguments provided without function name")
214-
function_view.append(event.arguments)
215-
elif is_full_choice(choice):
216-
message = choice['message']
217-
218-
if is_function_call(message):
219-
function_view = AssistantFunctionCallView(message['function_call']['name'])
220-
function_view.append(message['function_call']['arguments'])
221-
elif 'content' in message and message['content'] is not None:
222-
assistant_view.append(message['content'])
223-
224-
if 'finish_reason' in choice and choice['finish_reason'] is not None:
225-
finish_reason = choice['finish_reason']
168+
function_view.append(function_call.arguments)
169+
if choice.finish_reason is not None:
170+
finish_reason = choice.finish_reason
226171
break
227172

228173
# Wrap up the previous assistant
@@ -236,7 +181,29 @@ async def __process_stream(
236181

237182
return (finish_reason, function_view)
238183

239-
async def submit(self, *messages: Union[Message, str], stream: bool = True, **kwargs):
184+
async def __process_full_completion(self, resp: ChatCompletion) -> Tuple[str, Optional[AssistantFunctionCallView]]:
185+
assistant_view: AssistantMessageView = AssistantMessageView()
186+
function_view: Optional[AssistantFunctionCallView] = None
187+
188+
if len(resp.choices) == 0:
189+
logger.warning(f"Result has no choices: {resp}")
190+
return ("stop", None) # TODO
191+
192+
choice = resp.choices[0]
193+
194+
message = choice.message
195+
196+
if message.content is not None:
197+
assistant_view.append(message.content)
198+
self.append(assistant_view.flush())
199+
if message.function_call is not None:
200+
function_call = message.function_call
201+
function_view = AssistantFunctionCallView(function_name=function_call.name)
202+
function_view.append(function_call.arguments)
203+
204+
return choice.finish_reason, function_view
205+
206+
async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream=True, **kwargs):
240207
"""Send messages to the chat model and display the response.
241208
242209
Side effects:
@@ -245,13 +212,13 @@ async def submit(self, *messages: Union[Message, str], stream: bool = True, **kw
245212
- chat.messages are updated with response(s).
246213
247214
Args:
248-
messages (str | Message): One or more messages to send to the chat, can be strings or Message objects.
215+
messages (str | ChatCompletionMessageParam): One or more messages to send to the chat, can be strings or
216+
ChatCompletionMessageParam objects.
249217
250-
stream (bool): Whether to stream chat into markdown or not. If False, the entire chat will be sent once.
218+
stream: Whether to stream chat into markdown or not. If False, the entire chat will be sent once.
251219
252220
"""
253-
254-
full_messages: List[Message] = []
221+
full_messages: List[ChatCompletionMessageParam] = []
255222
full_messages.extend(self.messages)
256223
for message in messages:
257224
if isinstance(message, str):
@@ -260,14 +227,33 @@ async def submit(self, *messages: Union[Message, str], stream: bool = True, **kw
260227
full_messages.append(message)
261228

262229
try:
263-
resp = openai.ChatCompletion.create(
264-
model=self.model,
265-
messages=full_messages,
266-
**self.function_registry.api_manifest(),
267-
stream=stream,
268-
temperature=kwargs.get("temperature", 0),
269-
)
270-
except openai.error.RateLimitError as e:
230+
client = AsyncOpenAI()
231+
232+
manifest = self.function_registry.api_manifest()
233+
234+
# Due to the strict response typing based on `Literal` typing on `stream`, we have to process these two cases separately
235+
if stream:
236+
streaming_response = await client.chat.completions.create(
237+
model=self.model,
238+
messages=full_messages,
239+
**manifest,
240+
stream=True,
241+
temperature=kwargs.get("temperature", 0),
242+
)
243+
244+
finish_reason, function_call_request = await self.__process_stream(streaming_response)
245+
else:
246+
full_response = await client.chat.completions.create(
247+
model=self.model,
248+
messages=full_messages,
249+
**manifest,
250+
stream=False,
251+
temperature=kwargs.get("temperature", 0),
252+
)
253+
254+
finish_reason, function_call_request = await self.__process_full_completion(full_response)
255+
256+
except openai.RateLimitError as e:
271257
logger.error(f"Rate limited: {e}. Waiting 5 seconds and trying again.")
272258
await asyncio.sleep(5)
273259
await self.submit(*messages, stream=stream, **kwargs)
@@ -276,13 +262,6 @@ async def submit(self, *messages: Union[Message, str], stream: bool = True, **kw
276262

277263
self.append(*messages)
278264

279-
if not stream:
280-
resp = [resp]
281-
282-
resp = cast(Iterable[Union[StreamCompletion, ChatCompletion]], resp)
283-
284-
finish_reason, function_call_request = await self.__process_stream(resp)
285-
286265
if finish_reason == "function_call":
287266
if function_call_request is None:
288267
raise ValueError(
@@ -317,13 +296,13 @@ async def submit(self, *messages: Union[Message, str], stream: bool = True, **kw
317296
f"UNKNOWN FINISH REASON: '{finish_reason}'. If you see this message, report it as an issue to https://github.com/rgbkrk/chatlab/issues" # noqa: E501
318297
)
319298

320-
def append(self, *messages: Union[Message, str]):
299+
def append(self, *messages: Union[ChatCompletionMessageParam, str]):
321300
"""Append messages to the conversation history.
322301
323302
Note: this does not send the messages on until `chat` is called.
324303
325304
Args:
326-
messages (str | Message): One or more messages to append to the conversation.
305+
messages (str | ChatCompletionMessageParam): One or more messages to append to the conversation.
327306
328307
"""
329308
# Messages are either a dict respecting the {role, content} format or a str that we convert to a human message
@@ -340,13 +319,17 @@ def register(
340319
...
341320

342321
@overload
343-
def register(self, function: Callable, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None) -> Dict:
322+
def register(
323+
self, function: Callable, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None
324+
) -> FunctionSchema:
344325
...
345326

346327
def register(
347328
self, function: Optional[Callable] = None, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None
348-
) -> Union[Callable, Dict]:
349-
"""Register a function with the ChatLab instance. This can be used as a decorator like so:
329+
) -> Union[Callable, FunctionSchema]:
330+
"""Register a function with the ChatLab instance.
331+
332+
This can be used as a decorator like so:
350333
351334
>>> from chatlab import Chat
352335
>>> chat = Chat()

chatlab/display.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
from typing import Optional
44

5+
from openai.types.chat import ChatCompletionMessageParam
6+
57
from .components.function_details import ChatFunctionComponent
6-
from .messaging import Message, function_result, system
8+
from .messaging import function_result, system
79
from .registry import FunctionArgumentError, FunctionRegistry, UnknownFunctionError
810
from .views.abstracts import AutoDisplayer
911

@@ -35,7 +37,7 @@ def __init__(
3537
self._display_id = display_id
3638
self.update_displays()
3739

38-
async def call(self) -> Message:
40+
async def call(self) -> ChatCompletionMessageParam:
3941
"""Call the function and return a stack of messages for LLM and human consumption."""
4042
function_name = self.function_name
4143
function_args = self.function_args

0 commit comments

Comments
 (0)