Skip to content

Commit

Permalink
rely directly on ChatCompletionMessageParam
Browse files Browse the repository at this point in the history
  • Loading branch information
rgbkrk committed Oct 14, 2023
1 parent be76355 commit 09cd981
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 28 deletions.
36 changes: 16 additions & 20 deletions chatlab/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
from deprecation import deprecated
from IPython.core.async_helpers import get_asyncio_loop
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam
from pydantic import BaseModel

from chatlab.views.assistant_function_call import AssistantFunctionCallView

from ._version import __version__
from .display import ChatFunctionCall
from .errors import ChatLabError
from .messaging import Message, human
from .messaging import human
from .registry import FunctionRegistry, FunctionSchema, PythonHallucinationFunction
from .views.assistant import AssistantMessageView

Expand All @@ -43,7 +43,7 @@ class Chat:
History is tracked and can be used to continue a conversation.
Args:
initial_context (str | Message): The initial context for the conversation.
initial_context (str | ChatCompletionMessageParam): The initial context for the conversation.
model (str): The model to use for the conversation.
Expand All @@ -60,14 +60,14 @@ class Chat:
"""

messages: List[Message]
messages: List[ChatCompletionMessageParam]
model: str
function_registry: FunctionRegistry
allow_hallucinated_python: bool

def __init__(
self,
*initial_context: Union[Message, str],
*initial_context: Union[ChatCompletionMessageParam, str],
model="gpt-3.5-turbo-0613",
function_registry: Optional[FunctionRegistry] = None,
chat_functions: Optional[List[Callable]] = None,
Expand Down Expand Up @@ -99,7 +99,7 @@ def __init__(
if initial_context is None:
initial_context = [] # type: ignore

self.messages: List[Message] = []
self.messages: List[ChatCompletionMessageParam] = []

self.append(*initial_context)
self.model = model
Expand All @@ -122,15 +122,15 @@ def __init__(
)
def chat(
self,
*messages: Union[Message, str],
*messages: Union[ChatCompletionMessageParam, str],
):
"""Send messages to the chat model and display the response.
Deprecated in 0.13.0, removed in 1.0.0. Use `submit` instead.
"""
raise Exception("This method is deprecated. Use `submit` instead.")

async def __call__(self, *messages: Union[Message, str], stream=True, **kwargs):
async def __call__(self, *messages: Union[ChatCompletionMessageParam, str], stream=True, **kwargs):
"""Send messages to the chat model and display the response."""
return await self.submit(*messages, stream=stream, **kwargs)

Expand Down Expand Up @@ -195,15 +195,15 @@ async def __process_full_completion(self, resp: ChatCompletion) -> Tuple[str, Op

if message.content is not None:
assistant_view.append(message.content)
assistant_view.flush()
self.append(assistant_view.flush())
if message.function_call is not None:
function_call = message.function_call
function_view = AssistantFunctionCallView(function_name=function_call.name)
function_view.append(function_call.arguments)

return choice.finish_reason, function_view

async def submit(self, *messages: Union[Message, str], stream=True, **kwargs):
async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream=True, **kwargs):
"""Send messages to the chat model and display the response.
Side effects:
Expand All @@ -212,12 +212,13 @@ async def submit(self, *messages: Union[Message, str], stream=True, **kwargs):
- chat.messages are updated with response(s).
Args:
messages (str | Message): One or more messages to send to the chat, can be strings or Message objects.
messages (str | ChatCompletionMessageParam): One or more messages to send to the chat, can be strings or
ChatCompletionMessageParam objects.
stream: Whether to stream chat into markdown or not. If False, the entire chat will be sent once.
"""
full_messages: List[Message] = []
full_messages: List[ChatCompletionMessageParam] = []
full_messages.extend(self.messages)
for message in messages:
if isinstance(message, str):
Expand All @@ -230,14 +231,12 @@ async def submit(self, *messages: Union[Message, str], stream=True, **kwargs):

manifest = self.function_registry.api_manifest()

# Due to the strict response typing based on `Literal` typing on `stream`, we have to process these two cases separately
if stream:
streaming_response = await client.chat.completions.create(
model=self.model,
messages=full_messages,
**manifest,
# Due to this openai beta migration, we're going to assume
# only streaming and drop the non-streaming case for now until
# types are working right.
stream=True,
temperature=kwargs.get("temperature", 0),
)
Expand All @@ -248,9 +247,6 @@ async def submit(self, *messages: Union[Message, str], stream=True, **kwargs):
model=self.model,
messages=full_messages,
**manifest,
# Due to this openai beta migration, we're going to assume
# only streaming and drop the non-streaming case for now until
# types are working right.
stream=False,
temperature=kwargs.get("temperature", 0),
)
Expand Down Expand Up @@ -300,13 +296,13 @@ async def submit(self, *messages: Union[Message, str], stream=True, **kwargs):
f"UNKNOWN FINISH REASON: '{finish_reason}'. If you see this message, report it as an issue to https://github.com/rgbkrk/chatlab/issues" # noqa: E501
)

def append(self, *messages: Union[Message, str]):
def append(self, *messages: Union[ChatCompletionMessageParam, str]):
"""Append messages to the conversation history.
Note: this does not send the messages on until `chat` is called.
Args:
messages (str | Message): One or more messages to append to the conversation.
messages (str | ChatCompletionMessageParam): One or more messages to append to the conversation.
"""
# Messages are either a dict respecting the {role, content} format or a str that we convert to a human message
Expand Down
6 changes: 4 additions & 2 deletions chatlab/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from typing import Optional

from openai.types.chat import ChatCompletionMessageParam

from .components.function_details import ChatFunctionComponent
from .messaging import Message, function_result, system
from .messaging import function_result, system
from .registry import FunctionArgumentError, FunctionRegistry, UnknownFunctionError
from .views.abstracts import AutoDisplayer

Expand Down Expand Up @@ -35,7 +37,7 @@ def __init__(
self._display_id = display_id
self.update_displays()

async def call(self) -> Message:
async def call(self) -> ChatCompletionMessageParam:
"""Call the function and return a stack of messages for LLM and human consumption."""
function_name = self.function_name
function_args = self.function_args
Expand Down
2 changes: 0 additions & 2 deletions chatlab/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

from openai.types.chat import ChatCompletionMessageParam

Message = ChatCompletionMessageParam


def assistant(content: str) -> ChatCompletionMessageParam:
"""Create a message from the assistant.
Expand Down
7 changes: 3 additions & 4 deletions chatlab/views/abstracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from binascii import hexlify

from IPython.core import display_functions

from ..messaging import Message
from openai.types.chat import ChatCompletionMessageParam


class AutoDisplayer(ABC):
Expand Down Expand Up @@ -85,11 +84,11 @@ def in_progress(self):
return self.active and not self.is_empty()

@abstractmethod
def get_message(self) -> Message:
def get_message(self) -> ChatCompletionMessageParam:
"""Returns the crafted message. To be overridden in subclasses."""
pass

def flush(self):
def flush(self) -> ChatCompletionMessageParam:
"""Flushes the message buffer."""
message = self.get_message()
self.buffer = self.create_buffer()
Expand Down

0 comments on commit 09cd981

Please sign in to comment.