Merge pull request #96 from rgbkrk/openai-migrate-api
rgbkrk authored Oct 14, 2023
2 parents 3f0193b + 09cd981 commit 334a92b
Showing 9 changed files with 176 additions and 695 deletions.
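At its core, this PR swaps ChatLab from the pre-1.0 module-level `openai.ChatCompletion.create` call to the openai>=1.0 client API. For orientation, a minimal sketch of the two calling conventions — not taken from the diff itself; the prompt text is illustrative and the client reads OPENAI_API_KEY from the environment:

import openai  # openai < 1.0: module-level, synchronous call (the old code path)

old_response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo-0613",
    messages=[{"role": "user", "content": "Hello!"}],
)

from openai import AsyncOpenAI  # openai >= 1.0: explicit client object (what this PR adopts)

client = AsyncOpenAI()

async def ask() -> str:
    response = await client.chat.completions.create(
        model="gpt-3.5-turbo-0613",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    return response.choices[0].message.content or ""  # typed attribute access replaces dict lookups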
195 changes: 89 additions & 106 deletions chatlab/conversation.py
@@ -14,69 +14,27 @@
import asyncio
import logging
import os
-from dataclasses import dataclass
-from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, Union, cast, overload
+from typing import AsyncIterator, Callable, List, Optional, Tuple, Type, Union, overload

import openai
-import openai.error
from deprecation import deprecated
from IPython.core.async_helpers import get_asyncio_loop
+from openai import AsyncOpenAI
+from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam
from pydantic import BaseModel

from chatlab.views.assistant_function_call import AssistantFunctionCallView

from ._version import __version__
from .display import ChatFunctionCall
from .errors import ChatLabError
-from .messaging import (
-    ChatCompletion,
-    Message,
-    StreamCompletion,
-    human,
-    is_full_choice,
-    is_function_call,
-    is_stream_choice,
-)
-from .registry import FunctionRegistry, PythonHallucinationFunction
+from .messaging import human
+from .registry import FunctionRegistry, FunctionSchema, PythonHallucinationFunction
from .views.assistant import AssistantMessageView

logger = logging.getLogger(__name__)


-@dataclass
-class ContentDelta:
-    """A delta that contains markdown."""
-
-    content: str
-
-
-@dataclass
-class FunctionCallArgumentsDelta:
-    """A delta that contains function call arguments."""
-
-    arguments: str
-
-
-@dataclass
-class FunctionCallNameDelta:
-    """A delta that contains function call name."""
-
-    name: str
-
-
-def process_delta(delta):
-    """Process a delta."""
-    if 'content' in delta and delta['content'] is not None:
-        yield ContentDelta(delta['content'])
-
-    elif 'function_call' in delta:  # If the delta contains a function call
-        if 'name' in delta['function_call']:
-            yield FunctionCallNameDelta(delta['function_call']['name'])
-
-        if 'arguments' in delta['function_call']:
-            yield FunctionCallArgumentsDelta(delta['function_call']['arguments'])

class Chat:
"""Interactive chats inside of computational notebooks, relying on OpenAI's API.
@@ -85,7 +43,7 @@ class Chat:
History is tracked and can be used to continue a conversation.
Args:
-        initial_context (str | Message): The initial context for the conversation.
+        initial_context (str | ChatCompletionMessageParam): The initial context for the conversation.
model (str): The model to use for the conversation.
@@ -102,14 +60,14 @@ class Chat:
"""

-    messages: List[Message]
+    messages: List[ChatCompletionMessageParam]
model: str
function_registry: FunctionRegistry
allow_hallucinated_python: bool

def __init__(
self,
-        *initial_context: Union[Message, str],
+        *initial_context: Union[ChatCompletionMessageParam, str],
model="gpt-3.5-turbo-0613",
function_registry: Optional[FunctionRegistry] = None,
chat_functions: Optional[List[Callable]] = None,
@@ -141,7 +99,7 @@ def __init__(
if initial_context is None:
initial_context = [] # type: ignore

-        self.messages: List[Message] = []
+        self.messages: List[ChatCompletionMessageParam] = []

self.append(*initial_context)
self.model = model
@@ -164,65 +122,52 @@ def __init__(
)
def chat(
self,
-        *messages: Union[Message, str],
+        *messages: Union[ChatCompletionMessageParam, str],
):
"""Send messages to the chat model and display the response.
Deprecated in 0.13.0, removed in 1.0.0. Use `submit` instead.
"""
raise Exception("This method is deprecated. Use `submit` instead.")

-    async def __call__(self, *messages: Union[Message, str], stream: bool = True, **kwargs):
+    async def __call__(self, *messages: Union[ChatCompletionMessageParam, str], stream=True, **kwargs):
"""Send messages to the chat model and display the response."""
return await self.submit(*messages, stream=stream, **kwargs)

async def __process_stream(
-        self, resp: Iterable[Union[StreamCompletion, ChatCompletion]]
+        self, resp: AsyncIterator[ChatCompletionChunk]
) -> Tuple[str, Optional[AssistantFunctionCallView]]:
assistant_view: AssistantMessageView = AssistantMessageView()
function_view: Optional[AssistantFunctionCallView] = None
finish_reason = None

-        for result in resp:  # Go through the results of the stream
-            if not isinstance(result, dict):
-                logger.warning(f"Unknown result type: {type(result)}: {result}")
-                continue
-
-            choices = result.get('choices', [])
+        async for result in resp:  # Go through the results of the stream
+            choices = result.choices

if len(choices) == 0:
logger.warning(f"Result has no choices: {result}")
continue

choice = choices[0]

-            if is_stream_choice(choice):  # If there is a delta in the result
-                delta = choice['delta']
-
-                for event in process_delta(delta):
-                    if isinstance(event, ContentDelta):
-                        assistant_view.append(event.content)
-                    elif isinstance(event, FunctionCallNameDelta):
+            # Is stream choice?
+            if choice.delta is not None:
+                if choice.delta.content is not None:
+                    assistant_view.append(choice.delta.content)
+                elif choice.delta.function_call is not None:
+                    function_call = choice.delta.function_call
+                    if function_call.name is not None:
if assistant_view.in_progress():
# Flush out the finished assistant message
message = assistant_view.flush()
self.append(message)
-                        function_view = AssistantFunctionCallView(event.name)
-                    elif isinstance(event, FunctionCallArgumentsDelta):
+                        function_view = AssistantFunctionCallView(function_call.name)
+                    if function_call.arguments is not None:
if function_view is None:
raise ValueError("Function arguments provided without function name")
-                        function_view.append(event.arguments)
-            elif is_full_choice(choice):
-                message = choice['message']
-
-                if is_function_call(message):
-                    function_view = AssistantFunctionCallView(message['function_call']['name'])
-                    function_view.append(message['function_call']['arguments'])
-                elif 'content' in message and message['content'] is not None:
-                    assistant_view.append(message['content'])
-
-            if 'finish_reason' in choice and choice['finish_reason'] is not None:
-                finish_reason = choice['finish_reason']
+                        function_view.append(function_call.arguments)
+            if choice.finish_reason is not None:
+                finish_reason = choice.finish_reason
+                break

# Wrap up the previous assistant
@@ -236,7 +181,29 @@ async def __process_stream(

return (finish_reason, function_view)

-    async def submit(self, *messages: Union[Message, str], stream: bool = True, **kwargs):
+    async def __process_full_completion(self, resp: ChatCompletion) -> Tuple[str, Optional[AssistantFunctionCallView]]:
+        assistant_view: AssistantMessageView = AssistantMessageView()
+        function_view: Optional[AssistantFunctionCallView] = None
+
+        if len(resp.choices) == 0:
+            logger.warning(f"Result has no choices: {resp}")
+            return ("stop", None)  # TODO
+
+        choice = resp.choices[0]
+
+        message = choice.message
+
+        if message.content is not None:
+            assistant_view.append(message.content)
+            self.append(assistant_view.flush())
+        if message.function_call is not None:
+            function_call = message.function_call
+            function_view = AssistantFunctionCallView(function_name=function_call.name)
+            function_view.append(function_call.arguments)
+
+        return choice.finish_reason, function_view
+
+    async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream=True, **kwargs):
"""Send messages to the chat model and display the response.
Side effects:
@@ -245,13 +212,13 @@ async def submit(self, *messages: Union[Message, str], stream: bool = True, **kwargs):
- chat.messages are updated with response(s).
Args:
-            messages (str | Message): One or more messages to send to the chat, can be strings or Message objects.
+            messages (str | ChatCompletionMessageParam): One or more messages to send to the chat, can be strings or
+                ChatCompletionMessageParam objects.
-            stream (bool): Whether to stream chat into markdown or not. If False, the entire chat will be sent once.
+            stream: Whether to stream chat into markdown or not. If False, the entire chat will be sent once.
"""

-        full_messages: List[Message] = []
+        full_messages: List[ChatCompletionMessageParam] = []
full_messages.extend(self.messages)
for message in messages:
if isinstance(message, str):
@@ -260,14 +227,33 @@ async def submit(self, *messages: Union[Message, str], stream: bool = True, **kwargs):
full_messages.append(message)

try:
-            resp = openai.ChatCompletion.create(
-                model=self.model,
-                messages=full_messages,
-                **self.function_registry.api_manifest(),
-                stream=stream,
-                temperature=kwargs.get("temperature", 0),
-            )
-        except openai.error.RateLimitError as e:
+            client = AsyncOpenAI()
+
+            manifest = self.function_registry.api_manifest()
+
+            # Due to the strict response typing based on `Literal` typing on `stream`, we have to process these two cases separately
+            if stream:
+                streaming_response = await client.chat.completions.create(
+                    model=self.model,
+                    messages=full_messages,
+                    **manifest,
+                    stream=True,
+                    temperature=kwargs.get("temperature", 0),
+                )
+
+                finish_reason, function_call_request = await self.__process_stream(streaming_response)
+            else:
+                full_response = await client.chat.completions.create(
+                    model=self.model,
+                    messages=full_messages,
+                    **manifest,
+                    stream=False,
+                    temperature=kwargs.get("temperature", 0),
+                )
+
+                finish_reason, function_call_request = await self.__process_full_completion(full_response)
+
+        except openai.RateLimitError as e:
logger.error(f"Rate limited: {e}. Waiting 5 seconds and trying again.")
await asyncio.sleep(5)
await self.submit(*messages, stream=stream, **kwargs)
@@ -276,13 +262,6 @@ async def submit(self, *messages: Union[Message, str], stream: bool = True, **kwargs):

self.append(*messages)

-        if not stream:
-            resp = [resp]
-
-        resp = cast(Iterable[Union[StreamCompletion, ChatCompletion]], resp)
-
-        finish_reason, function_call_request = await self.__process_stream(resp)
-
if finish_reason == "function_call":
if function_call_request is None:
raise ValueError(
@@ -317,13 +296,13 @@ async def submit(self, *messages: Union[Message, str], stream: bool = True, **kwargs):
f"UNKNOWN FINISH REASON: '{finish_reason}'. If you see this message, report it as an issue to https://github.com/rgbkrk/chatlab/issues" # noqa: E501
)

-    def append(self, *messages: Union[Message, str]):
+    def append(self, *messages: Union[ChatCompletionMessageParam, str]):
"""Append messages to the conversation history.
Note: this does not send the messages on until `chat` is called.
Args:
-            messages (str | Message): One or more messages to append to the conversation.
+            messages (str | ChatCompletionMessageParam): One or more messages to append to the conversation.
"""
# Messages are either a dict respecting the {role, content} format or a str that we convert to a human message
@@ -340,13 +319,17 @@ def register(
...

@overload
-    def register(self, function: Callable, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None) -> Dict:
+    def register(
+        self, function: Callable, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None
+    ) -> FunctionSchema:
...

def register(
self, function: Optional[Callable] = None, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None
-    ) -> Union[Callable, Dict]:
-        """Register a function with the ChatLab instance. This can be used as a decorator like so:
+    ) -> Union[Callable, FunctionSchema]:
+        """Register a function with the ChatLab instance.
+
+        This can be used as a decorator like so:
>>> from chatlab import Chat
>>> chat = Chat()
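
The rewritten `__process_stream` above consumes a typed `AsyncIterator[ChatCompletionChunk]` with attribute access in place of the old dict probing. A standalone sketch of that consumption pattern, assuming OPENAI_API_KEY is configured; the print call stands in for ChatLab's view updates:

import asyncio

from openai import AsyncOpenAI

async def stream_demo() -> None:
    client = AsyncOpenAI()
    stream = await client.chat.completions.create(
        model="gpt-3.5-turbo-0613",
        messages=[{"role": "user", "content": "Stream a short reply."}],
        stream=True,
    )
    async for chunk in stream:  # each chunk is a ChatCompletionChunk
        if not chunk.choices:
            continue  # same guard as the "Result has no choices" warning above
        choice = chunk.choices[0]
        if choice.delta.content is not None:
            print(choice.delta.content, end="")
        if choice.finish_reason is not None:
            break  # mirrors the early exit the new code adds

asyncio.run(stream_demo())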
6 changes: 4 additions & 2 deletions chatlab/display.py
@@ -2,8 +2,10 @@

from typing import Optional

+from openai.types.chat import ChatCompletionMessageParam
+
from .components.function_details import ChatFunctionComponent
-from .messaging import Message, function_result, system
+from .messaging import function_result, system
from .registry import FunctionArgumentError, FunctionRegistry, UnknownFunctionError
from .views.abstracts import AutoDisplayer

@@ -35,7 +37,7 @@ def __init__(
self._display_id = display_id
self.update_displays()

-    async def call(self) -> Message:
+    async def call(self) -> ChatCompletionMessageParam:
"""Call the function and return a stack of messages for LLM and human consumption."""
function_name = self.function_name
function_args = self.function_args
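`ChatCompletionMessageParam`, now the return annotation on `call`, is a union of TypedDicts, so the plain role/content dicts ChatLab passes around still satisfy it. A hedged illustration; `make_function_result` below is a hypothetical stand-in, not ChatLab's actual `function_result` helper:

from openai.types.chat import ChatCompletionMessageParam

def make_function_result(name: str, content: str) -> ChatCompletionMessageParam:
    # Hypothetical stand-in: a plain dict in the {role, name, content}
    # shape type-checks against the ChatCompletionMessageParam union.
    return {"role": "function", "name": name, "content": content}

msg = make_function_result("get_weather", '{"temp_f": 72}')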
