Skip to content

Commit

Permalink
convert conversation.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rgbkrk committed Oct 14, 2023
1 parent 4771273 commit ee299ab
Showing 1 changed file with 31 additions and 70 deletions.
101 changes: 31 additions & 70 deletions chatlab/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
import asyncio
import logging
import os
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Type, Union, cast, overload
from typing import Callable, List, Optional, Tuple, Type, Union, overload

import openai
from deprecation import deprecated
from IPython.core.async_helpers import get_asyncio_loop
from openai import AsyncOpenAI
from openai._streaming import AsyncStream
from openai.types.chat import ChatCompletionChunk
from pydantic import BaseModel

from chatlab.views.assistant_function_call import AssistantFunctionCallView
Expand All @@ -29,46 +30,12 @@
from .display import ChatFunctionCall
from .errors import ChatLabError
from .messaging import Message, human
from .registry import FunctionRegistry, PythonHallucinationFunction
from .registry import FunctionRegistry, FunctionSchema, PythonHallucinationFunction
from .views.assistant import AssistantMessageView

logger = logging.getLogger(__name__)


@dataclass
class ContentDelta:
"""A delta that contains markdown."""

content: str


@dataclass
class FunctionCallArgumentsDelta:
"""A delta that contains function call arguments."""

arguments: str


@dataclass
class FunctionCallNameDelta:
"""A delta that contains function call name."""

name: str


def process_delta(delta):
"""Process a delta."""
if 'content' in delta and delta['content'] is not None:
yield ContentDelta(delta['content'])

elif 'function_call' in delta: # If the delta contains a function call
if 'name' in delta['function_call']:
yield FunctionCallNameDelta(delta['function_call']['name'])

if 'arguments' in delta['function_call']:
yield FunctionCallArgumentsDelta(delta['function_call']['arguments'])


class Chat:
"""Interactive chats inside of computational notebooks, relying on OpenAI's API.
Expand Down Expand Up @@ -169,52 +136,39 @@ async def __call__(self, *messages: Union[Message, str], stream=True, **kwargs):
return await self.submit(*messages, stream=stream, **kwargs)

async def __process_stream(
self, resp: Iterable[Union[StreamCompletion, ChatCompletion]]
self, resp: AsyncStream[ChatCompletionChunk]
) -> Tuple[str, Optional[AssistantFunctionCallView]]:
assistant_view: AssistantMessageView = AssistantMessageView()
function_view: Optional[AssistantFunctionCallView] = None
finish_reason = None

for result in resp: # Go through the results of the stream
if not isinstance(result, dict):
logger.warning(f"Unknown result type: {type(result)}: {result}")
continue

choices = result.get('choices', [])
async for result in resp: # Go through the results of the stream
choices = result.choices

if len(choices) == 0:
logger.warning(f"Result has no choices: {result}")
continue

choice = choices[0]

if is_stream_choice(choice): # If there is a delta in the result
delta = choice['delta']

for event in process_delta(delta):
if isinstance(event, ContentDelta):
assistant_view.append(event.content)
elif isinstance(event, FunctionCallNameDelta):
# Is stream choice?
if choice.delta is not None:
if choice.delta.content is not None:
assistant_view.append(choice.delta.content)
elif choice.delta.function_call is not None:
function_call = choice.delta.function_call
if function_call.name is not None:
if assistant_view.in_progress():
# Flush out the finished assistant message
message = assistant_view.flush()
self.append(message)
function_view = AssistantFunctionCallView(event.name)
elif isinstance(event, FunctionCallArgumentsDelta):
function_view = AssistantFunctionCallView(function_call.name)
if function_call.arguments is not None:
if function_view is None:
raise ValueError("Function arguments provided without function name")
function_view.append(event.arguments)
elif is_full_choice(choice):
message = choice['message']

if is_function_call(message):
function_view = AssistantFunctionCallView(message['function_call']['name'])
function_view.append(message['function_call']['arguments'])
elif 'content' in message and message['content'] is not None:
assistant_view.append(message['content'])

if 'finish_reason' in choice and choice['finish_reason'] is not None:
finish_reason = choice['finish_reason']
function_view.append(function_call.arguments)
if choice.finish_reason is not None:
finish_reason = choice.finish_reason
break

# Wrap up the previous assistant
Expand Down Expand Up @@ -259,6 +213,9 @@ async def submit(self, *messages: Union[Message, str], stream=True, **kwargs):
model=self.model,
messages=full_messages,
**manifest,
# Due to this openai beta migration, we're going to assume
# only streaming and drop the non-streaming case for now until
# types are working right.
stream=True,
temperature=kwargs.get("temperature", 0),
)
Expand All @@ -272,8 +229,8 @@ async def submit(self, *messages: Union[Message, str], stream=True, **kwargs):

self.append(*messages)

if not stream:
resp = [resp]
# if not stream:
# resp = [resp]

finish_reason, function_call_request = await self.__process_stream(resp)

Expand Down Expand Up @@ -334,13 +291,17 @@ def register(
...

@overload
def register(self, function: Callable, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None) -> Dict:
def register(
self, function: Callable, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None
) -> FunctionSchema:
...

def register(
self, function: Optional[Callable] = None, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None
) -> Union[Callable, Dict]:
"""Register a function with the ChatLab instance. This can be used as a decorator like so:
) -> Union[Callable, FunctionSchema]:
"""Register a function with the ChatLab instance.
This can be used as a decorator like so:
>>> from chatlab import Chat
>>> chat = Chat()
Expand Down

0 comments on commit ee299ab

Please sign in to comment.