Skip to content

migrate to openai beta #96

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 14, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 45 additions & 90 deletions chatlab/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,69 +14,28 @@
import asyncio
import logging
import os
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, Union, cast, overload
from typing import Callable, List, Optional, Tuple, Type, Union, overload

import openai
import openai.error
from deprecation import deprecated
from IPython.core.async_helpers import get_asyncio_loop
from openai import AsyncOpenAI
from openai._streaming import AsyncStream
from openai.types.chat import ChatCompletionChunk
from pydantic import BaseModel

from chatlab.views.assistant_function_call import AssistantFunctionCallView

from ._version import __version__
from .display import ChatFunctionCall
from .errors import ChatLabError
from .messaging import (
ChatCompletion,
Message,
StreamCompletion,
human,
is_full_choice,
is_function_call,
is_stream_choice,
)
from .registry import FunctionRegistry, PythonHallucinationFunction
from .messaging import Message, human
from .registry import FunctionRegistry, FunctionSchema, PythonHallucinationFunction
from .views.assistant import AssistantMessageView

logger = logging.getLogger(__name__)


@dataclass
class ContentDelta:
"""A delta that contains markdown."""

content: str


@dataclass
class FunctionCallArgumentsDelta:
"""A delta that contains function call arguments."""

arguments: str


@dataclass
class FunctionCallNameDelta:
"""A delta that contains function call name."""

name: str


def process_delta(delta):
"""Process a delta."""
if 'content' in delta and delta['content'] is not None:
yield ContentDelta(delta['content'])

elif 'function_call' in delta: # If the delta contains a function call
if 'name' in delta['function_call']:
yield FunctionCallNameDelta(delta['function_call']['name'])

if 'arguments' in delta['function_call']:
yield FunctionCallArgumentsDelta(delta['function_call']['arguments'])


class Chat:
"""Interactive chats inside of computational notebooks, relying on OpenAI's API.

Expand Down Expand Up @@ -172,57 +131,44 @@ def chat(
"""
raise Exception("This method is deprecated. Use `submit` instead.")

async def __call__(self, *messages: Union[Message, str], stream: bool = True, **kwargs):
async def __call__(self, *messages: Union[Message, str], stream=True, **kwargs):
"""Send messages to the chat model and display the response."""
return await self.submit(*messages, stream=stream, **kwargs)

async def __process_stream(
self, resp: Iterable[Union[StreamCompletion, ChatCompletion]]
self, resp: AsyncStream[ChatCompletionChunk]
) -> Tuple[str, Optional[AssistantFunctionCallView]]:
assistant_view: AssistantMessageView = AssistantMessageView()
function_view: Optional[AssistantFunctionCallView] = None
finish_reason = None

for result in resp: # Go through the results of the stream
if not isinstance(result, dict):
logger.warning(f"Unknown result type: {type(result)}: {result}")
continue

choices = result.get('choices', [])
async for result in resp: # Go through the results of the stream
choices = result.choices

if len(choices) == 0:
logger.warning(f"Result has no choices: {result}")
continue

choice = choices[0]

if is_stream_choice(choice): # If there is a delta in the result
delta = choice['delta']

for event in process_delta(delta):
if isinstance(event, ContentDelta):
assistant_view.append(event.content)
elif isinstance(event, FunctionCallNameDelta):
# Is stream choice?
if choice.delta is not None:
if choice.delta.content is not None:
assistant_view.append(choice.delta.content)
elif choice.delta.function_call is not None:
function_call = choice.delta.function_call
if function_call.name is not None:
if assistant_view.in_progress():
# Flush out the finished assistant message
message = assistant_view.flush()
self.append(message)
function_view = AssistantFunctionCallView(event.name)
elif isinstance(event, FunctionCallArgumentsDelta):
function_view = AssistantFunctionCallView(function_call.name)
if function_call.arguments is not None:
if function_view is None:
raise ValueError("Function arguments provided without function name")
function_view.append(event.arguments)
elif is_full_choice(choice):
message = choice['message']

if is_function_call(message):
function_view = AssistantFunctionCallView(message['function_call']['name'])
function_view.append(message['function_call']['arguments'])
elif 'content' in message and message['content'] is not None:
assistant_view.append(message['content'])

if 'finish_reason' in choice and choice['finish_reason'] is not None:
finish_reason = choice['finish_reason']
function_view.append(function_call.arguments)
if choice.finish_reason is not None:
finish_reason = choice.finish_reason
break

# Wrap up the previous assistant
Expand All @@ -236,7 +182,7 @@ async def __process_stream(

return (finish_reason, function_view)

async def submit(self, *messages: Union[Message, str], stream: bool = True, **kwargs):
async def submit(self, *messages: Union[Message, str], stream=True, **kwargs):
"""Send messages to the chat model and display the response.

Side effects:
Expand All @@ -247,10 +193,9 @@ async def submit(self, *messages: Union[Message, str], stream: bool = True, **kw
Args:
messages (str | Message): One or more messages to send to the chat, can be strings or Message objects.

stream (bool): Whether to stream chat into markdown or not. If False, the entire chat will be sent once.
stream: Whether to stream chat into markdown or not. If False, the entire chat will be sent once.

"""

full_messages: List[Message] = []
full_messages.extend(self.messages)
for message in messages:
Expand All @@ -260,14 +205,22 @@ async def submit(self, *messages: Union[Message, str], stream: bool = True, **kw
full_messages.append(message)

try:
resp = openai.ChatCompletion.create(
client = AsyncOpenAI()

manifest = self.function_registry.api_manifest()

resp = await client.chat.completions.create(
model=self.model,
messages=full_messages,
**self.function_registry.api_manifest(),
stream=stream,
**manifest,
# Due to this openai beta migration, we're going to assume
# only streaming and drop the non-streaming case for now until
# types are working right.
stream=True,
temperature=kwargs.get("temperature", 0),
)
except openai.error.RateLimitError as e:

except openai.RateLimitError as e:
logger.error(f"Rate limited: {e}. Waiting 5 seconds and trying again.")
await asyncio.sleep(5)
await self.submit(*messages, stream=stream, **kwargs)
Expand All @@ -276,10 +229,8 @@ async def submit(self, *messages: Union[Message, str], stream: bool = True, **kw

self.append(*messages)

if not stream:
resp = [resp]

resp = cast(Iterable[Union[StreamCompletion, ChatCompletion]], resp)
# if not stream:
# resp = [resp]

finish_reason, function_call_request = await self.__process_stream(resp)

Expand Down Expand Up @@ -340,13 +291,17 @@ def register(
...

@overload
def register(self, function: Callable, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None) -> Dict:
def register(
self, function: Callable, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None
) -> FunctionSchema:
...

def register(
self, function: Optional[Callable] = None, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None
) -> Union[Callable, Dict]:
"""Register a function with the ChatLab instance. This can be used as a decorator like so:
) -> Union[Callable, FunctionSchema]:
"""Register a function with the ChatLab instance.

This can be used as a decorator like so:

>>> from chatlab import Chat
>>> chat = Chat()
Expand Down
119 changes: 8 additions & 111 deletions chatlab/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,117 +10,14 @@

"""

from typing import List, Optional, TypedDict, Union

from typing_extensions import TypeGuard

BasicMessage = TypedDict(
"BasicMessage",
{
"role": str,
"content": str,
},
)
from typing import Optional

from openai.types.chat import ChatCompletionMessageParam

FunctionCall = TypedDict(
"FunctionCall",
{
"name": str,
"arguments": str,
},
)
Message = ChatCompletionMessageParam

FunctionCallMessage = TypedDict(
"FunctionCallMessage",
{
"role": str,
"content": Optional[str],
"function_call": FunctionCall,
},
)

FunctionResultMessage = TypedDict(
"FunctionResultMessage",
{
"role": str,
"content": str,
"name": str,
},
)


Message = Union[BasicMessage, FunctionCallMessage, FunctionResultMessage]


def is_function_call(message: Message) -> TypeGuard[FunctionCallMessage]:
"""Check if a message is a function call message."""
return 'function_call' in message


def is_basic_message(message: Message) -> TypeGuard[BasicMessage]:
"""Check if a message is a basic message."""
return 'content' in message and 'role' in message and 'function_call' not in message


#### STREAMING ####

Delta = TypedDict(
"Delta",
{
"function_call": FunctionCall,
"content": Optional[str],
},
total=False,
)


StreamChoice = TypedDict(
"StreamChoice",
{
"finish_reason": Optional[str],
"delta": Delta,
},
)

StreamCompletion = TypedDict(
"StreamCompletion",
{
"choices": List[StreamChoice],
},
total=False,
)

#### NON STREAMING ####

FullChoice = TypedDict(
"FullChoice",
{
"finish_reason": Optional[str],
"message": Message,
},
)

ChatCompletion = TypedDict(
"ChatCompletion",
{
"choices": List[FullChoice],
},
total=False,
)


def is_stream_choice(choice: Union[StreamChoice, FullChoice]) -> TypeGuard[StreamChoice]:
"""Check if a choice is a stream choice."""
return 'delta' in choice


def is_full_choice(choice: Union[StreamChoice, FullChoice]) -> TypeGuard[FullChoice]:
"""Check if a choice is a regular choice."""
return 'message' in choice


def assistant(content: str) -> BasicMessage:
def assistant(content: str) -> ChatCompletionMessageParam:
"""Create a message from the assistant.

Args:
Expand All @@ -135,7 +32,7 @@ def assistant(content: str) -> BasicMessage:
}


def user(content: str) -> BasicMessage:
def user(content: str) -> ChatCompletionMessageParam:
"""Create a message from the user.

Args:
Expand All @@ -150,7 +47,7 @@ def user(content: str) -> BasicMessage:
}


def system(content: str) -> BasicMessage:
def system(content: str) -> ChatCompletionMessageParam:
"""Create a message from the system.

Args:
Expand All @@ -165,7 +62,7 @@ def system(content: str) -> BasicMessage:
}


def assistant_function_call(name: str, arguments: Optional[str] = None) -> FunctionCallMessage:
def assistant_function_call(name: str, arguments: Optional[str] = None) -> ChatCompletionMessageParam:
"""Create a function call message from the assistant.

Args:
Expand All @@ -188,7 +85,7 @@ def assistant_function_call(name: str, arguments: Optional[str] = None) -> Funct
}


def function_result(name: str, content: str) -> FunctionResultMessage:
def function_result(name: str, content: str) -> ChatCompletionMessageParam:
"""Create a function result message.

Args:
Expand Down
Loading