Skip to content

Commit

Permalink
Merge pull request #103 from rgbkrk/fix-up-ruff-mypy-precommits
Browse files Browse the repository at this point in the history
Fix up ruff mypy precommits
  • Loading branch information
rgbkrk authored Nov 5, 2023
2 parents 41928d8 + ef777ab commit 7dce82c
Show file tree
Hide file tree
Showing 14 changed files with 83 additions and 127 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.4.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-merge-conflict
- id: check-yaml
Expand Down
11 changes: 10 additions & 1 deletion chatlab/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,16 @@
from ._version import __version__
from .chat import Chat
from .decorators import ChatlabMetadata, expose_exception_to_llm
from .messaging import ai, assistant, assistant_function_call, function_result, human, narrate, system, user
from .messaging import (
ai,
assistant,
assistant_function_call,
function_result,
human,
narrate,
system,
user,
)
from .registry import FunctionRegistry
from .views.markdown import Markdown

Expand Down
32 changes: 8 additions & 24 deletions chatlab/builtins/noteable.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@ async def create(cls, file_name=None, token=None, file_id=None, project_id=None)
# We have to track the kernel_session for now
kernel_session = await api_client.launch_kernel(file_id)

cn = NotebookClient(
api_client, rtu_client, file_id=file_id, kernel_session=kernel_session
)
cn = NotebookClient(api_client, rtu_client, file_id=file_id, kernel_session=kernel_session)

return cn

Expand Down Expand Up @@ -167,9 +165,7 @@ async def create_cell(
)

if cell is None:
return (
f"Unknown cell type {cell_type}. Valid types are: markdown, code, sql."
)
return f"Unknown cell type {cell_type}. Valid types are: markdown, code, sql."

logger.info(f"Adding cell {cell_id} to notebook")
cell = await rtu_client.add_cell(cell=cell, after_id=after_cell_id)
Expand Down Expand Up @@ -230,9 +226,7 @@ async def update_cell(

async def _get_llm_friendly_outputs(self, output_collection_id: uuid.UUID):
"""Get the outputs for a given output collection ID."""
output_collection = await self.api_client.get_output_collection(
output_collection_id
)
output_collection = await self.api_client.get_output_collection(output_collection_id)

outputs = output_collection.outputs

Expand All @@ -252,9 +246,7 @@ async def _get_llm_friendly_outputs(self, output_collection_id: uuid.UUID):
return llm_friendly_outputs

async def _extract_llm_plain(self, output: KernelOutput):
resp = await self.api_client.client.get(
f"/outputs/{output.id}", params={"mimetype": "text/llm+plain"}
)
resp = await self.api_client.client.get(f"/outputs/{output.id}", params={"mimetype": "text/llm+plain"})
resp.raise_for_status()

output_for_llm = KernelOutput.parse_obj(resp.json())
Expand All @@ -265,9 +257,7 @@ async def _extract_llm_plain(self, output: KernelOutput):
return output_for_llm.content.raw

async def _extract_specific_mediatype(self, output: KernelOutput, mimetype: str):
resp = await self.api_client.client.get(
f"/outputs/{output.id}", params={"mimetype": mimetype}
)
resp = await self.api_client.client.get(f"/outputs/{output.id}", params={"mimetype": mimetype})
resp.raise_for_status()

output_for_llm = KernelOutput.parse_obj(resp.json())
Expand Down Expand Up @@ -330,9 +320,7 @@ async def _get_llm_friendly_output(self, output: KernelOutput):

for format in formats_for_llm:
if format in mimetypes:
resp = await self.api_client.client.get(
f"/outputs/{output.id}?mimetype={format}"
)
resp = await self.api_client.client.get(f"/outputs/{output.id}?mimetype={format}")
resp.raise_for_status()
if resp.status_code == 200:
return
Expand Down Expand Up @@ -421,9 +409,7 @@ async def get_cell(self, cell_id: str, with_outputs: bool = True):
response += cell.source
return response

source_type = rtu_client.builder.nb.metadata.get("kernelspec", {}).get(
"language", ""
)
source_type = rtu_client.builder.nb.metadata.get("kernelspec", {}).get("language", "")

if cell.metadata.get("noteable", {}).get("cell_type") == "sql":
source_type = "sql"
Expand Down Expand Up @@ -543,9 +529,7 @@ def chat_functions(self):
]


def provide_notebook_creation(
registry: FunctionRegistry, project_id: Optional[str] = None
):
def provide_notebook_creation(registry: FunctionRegistry, project_id: Optional[str] = None):
"""Register the notebook client with the registry.
>>> from chatlab import FunctionRegistry, Chat
Expand Down
12 changes: 3 additions & 9 deletions chatlab/builtins/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,8 @@ def apply_llm_formatter(shell: InteractiveShell):
"""Apply the LLM formatter to the given shell."""
llm_formatter = register_llm_formatter(shell)

llm_formatter.for_type_by_name(
"pandas.core.frame", "DataFrame", format_dataframe_for_llm
)
llm_formatter.for_type_by_name(
"pandas.core.series", "Series", format_series_for_llm
)
llm_formatter.for_type_by_name("pandas.core.frame", "DataFrame", format_dataframe_for_llm)
llm_formatter.for_type_by_name("pandas.core.series", "Series", format_series_for_llm)


def get_or_create_ipython() -> InteractiveShell:
Expand Down Expand Up @@ -80,9 +76,7 @@ def run_cell(self, code: str):

# Create a formatted traceback that includes the last 3 frames
# and the exception message
formatted = TracebackException.from_exception(exception, limit=3).format(
chain=True
)
formatted = TracebackException.from_exception(exception, limit=3).format(chain=True)
plaintext_traceback = "\n".join(formatted)

return plaintext_traceback
Expand Down
4 changes: 1 addition & 3 deletions chatlab/builtins/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ async def run_shell_command(command: str):
Returns:
- str: the output of the shell command
"""
process = await asyncio.create_subprocess_shell(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
process = await asyncio.create_subprocess_shell(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = await process.communicate()

resp = f"Return Code: {process.returncode}\n"
Expand Down
41 changes: 11 additions & 30 deletions chatlab/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@
from deprecation import deprecated
from IPython.core.async_helpers import get_asyncio_loop
from openai import AsyncOpenAI
from openai.types.chat import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessageParam,
)
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam
from pydantic import BaseModel

from chatlab.views.assistant_function_call import AssistantFunctionCallView
Expand Down Expand Up @@ -114,9 +110,7 @@ def __init__(

python_hallucination_function = run_cell

self.function_registry = FunctionRegistry(
python_hallucination_function=python_hallucination_function
)
self.function_registry = FunctionRegistry(python_hallucination_function=python_hallucination_function)
else:
self.function_registry = function_registry

Expand All @@ -139,9 +133,7 @@ def chat(
"""
raise Exception("This method is deprecated. Use `submit` instead.")

async def __call__(
self, *messages: Union[ChatCompletionMessageParam, str], stream=True, **kwargs
):
async def __call__(self, *messages: Union[ChatCompletionMessageParam, str], stream=True, **kwargs):
"""Send messages to the chat model and display the response."""
return await self.submit(*messages, stream=stream, **kwargs)

Expand Down Expand Up @@ -175,9 +167,7 @@ async def __process_stream(
function_view = AssistantFunctionCallView(function_call.name)
if function_call.arguments is not None:
if function_view is None:
raise ValueError(
"Function arguments provided without function name"
)
raise ValueError("Function arguments provided without function name")
function_view.append(function_call.arguments)
if choice.finish_reason is not None:
finish_reason = choice.finish_reason
Expand All @@ -194,9 +184,7 @@ async def __process_stream(

return (finish_reason, function_view)

async def __process_full_completion(
self, resp: ChatCompletion
) -> Tuple[str, Optional[AssistantFunctionCallView]]:
async def __process_full_completion(self, resp: ChatCompletion) -> Tuple[str, Optional[AssistantFunctionCallView]]:
assistant_view: AssistantMessageView = AssistantMessageView()
function_view: Optional[AssistantFunctionCallView] = None

Expand All @@ -218,9 +206,7 @@ async def __process_full_completion(

return choice.finish_reason, function_view

async def submit(
self, *messages: Union[ChatCompletionMessageParam, str], stream=True, **kwargs
):
async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream=True, **kwargs):
"""Send messages to the chat model and display the response.
Side effects:
Expand Down Expand Up @@ -248,7 +234,8 @@ async def submit(

manifest = self.function_registry.api_manifest()

# Due to the strict response typing based on `Literal` typing on `stream`, we have to process these two cases separately
# Due to the strict response typing based on `Literal` typing on `stream`, we have to process these
# two cases separately
if stream:
streaming_response = await client.chat.completions.create(
model=self.model,
Expand All @@ -258,9 +245,7 @@ async def submit(
temperature=kwargs.get("temperature", 0),
)

finish_reason, function_call_request = await self.__process_stream(
streaming_response
)
finish_reason, function_call_request = await self.__process_stream(streaming_response)
else:
full_response = await client.chat.completions.create(
model=self.model,
Expand Down Expand Up @@ -413,9 +398,7 @@ def ipython_magic_submit(self, line, cell: Optional[str] = None, **kwargs):
return
cell = cell.strip()

asyncio.run_coroutine_threadsafe(
self.submit(cell, **kwargs), get_asyncio_loop()
)
asyncio.run_coroutine_threadsafe(self.submit(cell, **kwargs), get_asyncio_loop())

def make_magic(self, name):
"""Register the chat as an IPython magic with the given name.
Expand All @@ -435,6 +418,4 @@ def make_magic(self, name):
if ip is None:
raise Exception("IPython is not available.")

ip.register_magic_function(
self.ipython_magic_submit, magic_kind="line_cell", magic_name=name
)
ip.register_magic_function(self.ipython_magic_submit, magic_kind="line_cell", magic_name=name)
16 changes: 4 additions & 12 deletions chatlab/components/function_details.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@

def function_logo():
"""Styled 𝑓 logo component for use in the chat function component."""
return span(
"𝑓", style=dict(color=colors["light"], paddingRight="5px", paddingLeft="5px")
)
return span("𝑓", style=dict(color=colors["light"], paddingRight="5px", paddingLeft="5px"))


def function_verbage(state: str):
Expand All @@ -42,9 +40,7 @@ def function_verbage(state: str):

def inline_pre(text: str):
"""A simple preformatted monospace component that works in all Jupyter frontends."""
return span(
text, style=dict(unicodeBidi="embed", fontFamily="monospace", whiteSpace="pre")
)
return span(text, style=dict(unicodeBidi="embed", fontFamily="monospace", whiteSpace="pre"))


def raw_function_interface_heading(text: str):
Expand Down Expand Up @@ -87,9 +83,7 @@ def ChatFunctionComponent(
input_element = div()
if input is not None:
input = input.strip()
input_element = div(
raw_function_interface_heading("Input:"), raw_function_interface(input)
)
input_element = div(raw_function_interface_heading("Input:"), raw_function_interface(input))

output_element = div()
if output is not None:
Expand All @@ -100,9 +94,7 @@ def ChatFunctionComponent(
)

return div(
style(
".chatlab-chat-details summary > * { display: inline; color: #27374D; }"
),
style(".chatlab-chat-details summary > * { display: inline; color: #27374D; }"),
details(
summary(
function_logo(),
Expand Down
8 changes: 2 additions & 6 deletions chatlab/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,11 @@ async def call(self) -> ChatCompletionMessageParam:
self.finished = True
self.set_state("No function named")
self.function_result = repr(e)
return system(
f"Function {function_name} not found in function registry: {e}"
)
return system(f"Function {function_name} not found in function registry: {e}")
except Exception as e:
# Check to see if the user has requested that the exception be exposed to LLM.
# If not, then we just raise it and let the user handle it.
chatlab_metadata = self.function_registry.get_chatlab_metadata(
function_name
)
chatlab_metadata = self.function_registry.get_chatlab_metadata(function_name)

if not chatlab_metadata.expose_exception_to_llm:
# Bubble up the exception to the user
Expand Down
4 changes: 1 addition & 3 deletions chatlab/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ def system(content: str) -> ChatCompletionMessageParam:
}


def assistant_function_call(
name: str, arguments: Optional[str] = None
) -> ChatCompletionMessageParam:
def assistant_function_call(name: str, arguments: Optional[str] = None) -> ChatCompletionMessageParam:
"""Create a function call message from the assistant.
Args:
Expand Down
24 changes: 6 additions & 18 deletions chatlab/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,7 @@ class UnknownFunctionError(Exception):

def is_optional_type(t):
"""Check if a type is Optional."""
return (
get_origin(t) is Union and len(get_args(t)) == 2 and type(None) in get_args(t)
)
return get_origin(t) is Union and len(get_args(t)) == 2 and type(None) in get_args(t)


def is_union_type(t):
Expand Down Expand Up @@ -162,9 +160,7 @@ def generate_function_schema(
# determine type annotation
if param.annotation == inspect.Parameter.empty:
# no annotation, raise instead of falling back to Any
raise Exception(
f"`{name}` parameter of {func_name} must have a JSON-serializable type annotation"
)
raise Exception(f"`{name}` parameter of {func_name} must have a JSON-serializable type annotation")
type_annotation = param.annotation

# get the default value, otherwise set as required
Expand Down Expand Up @@ -253,9 +249,7 @@ def __init__(

self.python_hallucination_function = python_hallucination_function

def decorator(
self, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None
) -> Callable:
def decorator(self, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None) -> Callable:
"""Create a decorator for registering functions with a schema."""

def decorator(function):
Expand Down Expand Up @@ -324,9 +318,7 @@ def register_function(

return final_schema

def register_functions(
self, functions: Union[Iterable[Callable], dict[str, Callable]]
):
def register_functions(self, functions: Union[Iterable[Callable], dict[str, Callable]]):
"""Register a dictionary of functions."""
if isinstance(functions, dict):
functions = functions.values()
Expand Down Expand Up @@ -355,9 +347,7 @@ def get_chatlab_metadata(self, function_name) -> ChatlabMetadata:
chatlab_metadata = getattr(function, "chatlab_metadata", ChatlabMetadata())
return chatlab_metadata

def api_manifest(
self, function_call_option: FunctionCallOption = "auto"
) -> APIManifest:
def api_manifest(self, function_call_option: FunctionCallOption = "auto") -> APIManifest:
"""Get a dictionary containing function definitions and calling options.
This is designed to be used with OpenAI's Chat Completion API, where the
Expand Down Expand Up @@ -447,9 +437,7 @@ async def call(self, name: str, arguments: Optional[str] = None) -> Any:
parameters = json.loads(arguments)
# TODO: Validate parameters against schema
except json.JSONDecodeError:
raise FunctionArgumentError(
f"Invalid Function call on {name}. Arguments must be a valid JSON object"
)
raise FunctionArgumentError(f"Invalid Function call on {name}. Arguments must be a valid JSON object")

if function is None:
raise UnknownFunctionError(f"Function {name} is not registered")
Expand Down
Loading

0 comments on commit 7dce82c

Please sign in to comment.