diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e62b5ec..f6245e2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v3.4.0 hooks: - - id: trailing-whitespace - id: end-of-file-fixer - id: check-merge-conflict - id: check-yaml diff --git a/chatlab/__init__.py b/chatlab/__init__.py index 44aa6fe..6305159 100644 --- a/chatlab/__init__.py +++ b/chatlab/__init__.py @@ -21,7 +21,16 @@ from ._version import __version__ from .chat import Chat from .decorators import ChatlabMetadata, expose_exception_to_llm -from .messaging import ai, assistant, assistant_function_call, function_result, human, narrate, system, user +from .messaging import ( + ai, + assistant, + assistant_function_call, + function_result, + human, + narrate, + system, + user, +) from .registry import FunctionRegistry from .views.markdown import Markdown diff --git a/chatlab/builtins/noteable.py b/chatlab/builtins/noteable.py index 95b5c1c..2d7d57c 100644 --- a/chatlab/builtins/noteable.py +++ b/chatlab/builtins/noteable.py @@ -102,9 +102,7 @@ async def create(cls, file_name=None, token=None, file_id=None, project_id=None) # We have to track the kernel_session for now kernel_session = await api_client.launch_kernel(file_id) - cn = NotebookClient( - api_client, rtu_client, file_id=file_id, kernel_session=kernel_session - ) + cn = NotebookClient(api_client, rtu_client, file_id=file_id, kernel_session=kernel_session) return cn @@ -167,9 +165,7 @@ async def create_cell( ) if cell is None: - return ( - f"Unknown cell type {cell_type}. Valid types are: markdown, code, sql." - ) + return f"Unknown cell type {cell_type}. Valid types are: markdown, code, sql." logger.info(f"Adding cell {cell_id} to notebook") cell = await rtu_client.add_cell(cell=cell, after_id=after_cell_id) @@ -230,9 +226,7 @@ async def update_cell( async def _get_llm_friendly_outputs(self, output_collection_id: uuid.UUID): """Get the outputs for a given output collection ID.""" - output_collection = await self.api_client.get_output_collection( - output_collection_id - ) + output_collection = await self.api_client.get_output_collection(output_collection_id) outputs = output_collection.outputs @@ -252,9 +246,7 @@ async def _get_llm_friendly_outputs(self, output_collection_id: uuid.UUID): return llm_friendly_outputs async def _extract_llm_plain(self, output: KernelOutput): - resp = await self.api_client.client.get( - f"/outputs/{output.id}", params={"mimetype": "text/llm+plain"} - ) + resp = await self.api_client.client.get(f"/outputs/{output.id}", params={"mimetype": "text/llm+plain"}) resp.raise_for_status() output_for_llm = KernelOutput.parse_obj(resp.json()) @@ -265,9 +257,7 @@ async def _extract_llm_plain(self, output: KernelOutput): return output_for_llm.content.raw async def _extract_specific_mediatype(self, output: KernelOutput, mimetype: str): - resp = await self.api_client.client.get( - f"/outputs/{output.id}", params={"mimetype": mimetype} - ) + resp = await self.api_client.client.get(f"/outputs/{output.id}", params={"mimetype": mimetype}) resp.raise_for_status() output_for_llm = KernelOutput.parse_obj(resp.json()) @@ -330,9 +320,7 @@ async def _get_llm_friendly_output(self, output: KernelOutput): for format in formats_for_llm: if format in mimetypes: - resp = await self.api_client.client.get( - f"/outputs/{output.id}?mimetype={format}" - ) + resp = await self.api_client.client.get(f"/outputs/{output.id}?mimetype={format}") resp.raise_for_status() if resp.status_code == 200: return @@ -421,9 +409,7 @@ async def get_cell(self, cell_id: str, with_outputs: bool = True): response += cell.source return response - source_type = rtu_client.builder.nb.metadata.get("kernelspec", {}).get( - "language", "" - ) + source_type = rtu_client.builder.nb.metadata.get("kernelspec", {}).get("language", "") if cell.metadata.get("noteable", {}).get("cell_type") == "sql": source_type = "sql" @@ -543,9 +529,7 @@ def chat_functions(self): ] -def provide_notebook_creation( - registry: FunctionRegistry, project_id: Optional[str] = None -): +def provide_notebook_creation(registry: FunctionRegistry, project_id: Optional[str] = None): """Register the notebook client with the registry. >>> from chatlab import FunctionRegistry, Chat diff --git a/chatlab/builtins/python.py b/chatlab/builtins/python.py index f4cba37..a7b355a 100644 --- a/chatlab/builtins/python.py +++ b/chatlab/builtins/python.py @@ -16,12 +16,8 @@ def apply_llm_formatter(shell: InteractiveShell): """Apply the LLM formatter to the given shell.""" llm_formatter = register_llm_formatter(shell) - llm_formatter.for_type_by_name( - "pandas.core.frame", "DataFrame", format_dataframe_for_llm - ) - llm_formatter.for_type_by_name( - "pandas.core.series", "Series", format_series_for_llm - ) + llm_formatter.for_type_by_name("pandas.core.frame", "DataFrame", format_dataframe_for_llm) + llm_formatter.for_type_by_name("pandas.core.series", "Series", format_series_for_llm) def get_or_create_ipython() -> InteractiveShell: @@ -80,9 +76,7 @@ def run_cell(self, code: str): # Create a formatted traceback that includes the last 3 frames # and the exception message - formatted = TracebackException.from_exception(exception, limit=3).format( - chain=True - ) + formatted = TracebackException.from_exception(exception, limit=3).format(chain=True) plaintext_traceback = "\n".join(formatted) return plaintext_traceback diff --git a/chatlab/builtins/shell.py b/chatlab/builtins/shell.py index fc9d28e..daea730 100644 --- a/chatlab/builtins/shell.py +++ b/chatlab/builtins/shell.py @@ -15,9 +15,7 @@ async def run_shell_command(command: str): Returns: - str: the output of the shell command """ - process = await asyncio.create_subprocess_shell( - command, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) + process = await asyncio.create_subprocess_shell(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout, stderr = await process.communicate() resp = f"Return Code: {process.returncode}\n" diff --git a/chatlab/chat.py b/chatlab/chat.py index 2cfebe7..4231275 100644 --- a/chatlab/chat.py +++ b/chatlab/chat.py @@ -20,11 +20,7 @@ from deprecation import deprecated from IPython.core.async_helpers import get_asyncio_loop from openai import AsyncOpenAI -from openai.types.chat import ( - ChatCompletion, - ChatCompletionChunk, - ChatCompletionMessageParam, -) +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam from pydantic import BaseModel from chatlab.views.assistant_function_call import AssistantFunctionCallView @@ -114,9 +110,7 @@ def __init__( python_hallucination_function = run_cell - self.function_registry = FunctionRegistry( - python_hallucination_function=python_hallucination_function - ) + self.function_registry = FunctionRegistry(python_hallucination_function=python_hallucination_function) else: self.function_registry = function_registry @@ -139,9 +133,7 @@ def chat( """ raise Exception("This method is deprecated. Use `submit` instead.") - async def __call__( - self, *messages: Union[ChatCompletionMessageParam, str], stream=True, **kwargs - ): + async def __call__(self, *messages: Union[ChatCompletionMessageParam, str], stream=True, **kwargs): """Send messages to the chat model and display the response.""" return await self.submit(*messages, stream=stream, **kwargs) @@ -175,9 +167,7 @@ async def __process_stream( function_view = AssistantFunctionCallView(function_call.name) if function_call.arguments is not None: if function_view is None: - raise ValueError( - "Function arguments provided without function name" - ) + raise ValueError("Function arguments provided without function name") function_view.append(function_call.arguments) if choice.finish_reason is not None: finish_reason = choice.finish_reason @@ -194,9 +184,7 @@ async def __process_stream( return (finish_reason, function_view) - async def __process_full_completion( - self, resp: ChatCompletion - ) -> Tuple[str, Optional[AssistantFunctionCallView]]: + async def __process_full_completion(self, resp: ChatCompletion) -> Tuple[str, Optional[AssistantFunctionCallView]]: assistant_view: AssistantMessageView = AssistantMessageView() function_view: Optional[AssistantFunctionCallView] = None @@ -218,9 +206,7 @@ async def __process_full_completion( return choice.finish_reason, function_view - async def submit( - self, *messages: Union[ChatCompletionMessageParam, str], stream=True, **kwargs - ): + async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream=True, **kwargs): """Send messages to the chat model and display the response. Side effects: @@ -248,7 +234,8 @@ async def submit( manifest = self.function_registry.api_manifest() - # Due to the strict response typing based on `Literal` typing on `stream`, we have to process these two cases separately + # Due to the strict response typing based on `Literal` typing on `stream`, we have to process these + # two cases separately if stream: streaming_response = await client.chat.completions.create( model=self.model, @@ -258,9 +245,7 @@ async def submit( temperature=kwargs.get("temperature", 0), ) - finish_reason, function_call_request = await self.__process_stream( - streaming_response - ) + finish_reason, function_call_request = await self.__process_stream(streaming_response) else: full_response = await client.chat.completions.create( model=self.model, @@ -413,9 +398,7 @@ def ipython_magic_submit(self, line, cell: Optional[str] = None, **kwargs): return cell = cell.strip() - asyncio.run_coroutine_threadsafe( - self.submit(cell, **kwargs), get_asyncio_loop() - ) + asyncio.run_coroutine_threadsafe(self.submit(cell, **kwargs), get_asyncio_loop()) def make_magic(self, name): """Register the chat as an IPython magic with the given name. @@ -435,6 +418,4 @@ def make_magic(self, name): if ip is None: raise Exception("IPython is not available.") - ip.register_magic_function( - self.ipython_magic_submit, magic_kind="line_cell", magic_name=name - ) + ip.register_magic_function(self.ipython_magic_submit, magic_kind="line_cell", magic_name=name) diff --git a/chatlab/components/function_details.py b/chatlab/components/function_details.py index af468ce..a17fa6c 100644 --- a/chatlab/components/function_details.py +++ b/chatlab/components/function_details.py @@ -27,9 +27,7 @@ def function_logo(): """Styled 𝑓 logo component for use in the chat function component.""" - return span( - "𝑓", style=dict(color=colors["light"], paddingRight="5px", paddingLeft="5px") - ) + return span("𝑓", style=dict(color=colors["light"], paddingRight="5px", paddingLeft="5px")) def function_verbage(state: str): @@ -42,9 +40,7 @@ def function_verbage(state: str): def inline_pre(text: str): """A simple preformatted monospace component that works in all Jupyter frontends.""" - return span( - text, style=dict(unicodeBidi="embed", fontFamily="monospace", whiteSpace="pre") - ) + return span(text, style=dict(unicodeBidi="embed", fontFamily="monospace", whiteSpace="pre")) def raw_function_interface_heading(text: str): @@ -87,9 +83,7 @@ def ChatFunctionComponent( input_element = div() if input is not None: input = input.strip() - input_element = div( - raw_function_interface_heading("Input:"), raw_function_interface(input) - ) + input_element = div(raw_function_interface_heading("Input:"), raw_function_interface(input)) output_element = div() if output is not None: @@ -100,9 +94,7 @@ def ChatFunctionComponent( ) return div( - style( - ".chatlab-chat-details summary > * { display: inline; color: #27374D; }" - ), + style(".chatlab-chat-details summary > * { display: inline; color: #27374D; }"), details( summary( function_logo(), diff --git a/chatlab/display.py b/chatlab/display.py index bce5db6..174da18 100644 --- a/chatlab/display.py +++ b/chatlab/display.py @@ -56,15 +56,11 @@ async def call(self) -> ChatCompletionMessageParam: self.finished = True self.set_state("No function named") self.function_result = repr(e) - return system( - f"Function {function_name} not found in function registry: {e}" - ) + return system(f"Function {function_name} not found in function registry: {e}") except Exception as e: # Check to see if the user has requested that the exception be exposed to LLM. # If not, then we just raise it and let the user handle it. - chatlab_metadata = self.function_registry.get_chatlab_metadata( - function_name - ) + chatlab_metadata = self.function_registry.get_chatlab_metadata(function_name) if not chatlab_metadata.expose_exception_to_llm: # Bubble up the exception to the user diff --git a/chatlab/messaging.py b/chatlab/messaging.py index 7f8ccd0..142dfa7 100644 --- a/chatlab/messaging.py +++ b/chatlab/messaging.py @@ -60,9 +60,7 @@ def system(content: str) -> ChatCompletionMessageParam: } -def assistant_function_call( - name: str, arguments: Optional[str] = None -) -> ChatCompletionMessageParam: +def assistant_function_call(name: str, arguments: Optional[str] = None) -> ChatCompletionMessageParam: """Create a function call message from the assistant. Args: diff --git a/chatlab/registry.py b/chatlab/registry.py index ff93135..8323dd0 100644 --- a/chatlab/registry.py +++ b/chatlab/registry.py @@ -108,9 +108,7 @@ class UnknownFunctionError(Exception): def is_optional_type(t): """Check if a type is Optional.""" - return ( - get_origin(t) is Union and len(get_args(t)) == 2 and type(None) in get_args(t) - ) + return get_origin(t) is Union and len(get_args(t)) == 2 and type(None) in get_args(t) def is_union_type(t): @@ -162,9 +160,7 @@ def generate_function_schema( # determine type annotation if param.annotation == inspect.Parameter.empty: # no annotation, raise instead of falling back to Any - raise Exception( - f"`{name}` parameter of {func_name} must have a JSON-serializable type annotation" - ) + raise Exception(f"`{name}` parameter of {func_name} must have a JSON-serializable type annotation") type_annotation = param.annotation # get the default value, otherwise set as required @@ -253,9 +249,7 @@ def __init__( self.python_hallucination_function = python_hallucination_function - def decorator( - self, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None - ) -> Callable: + def decorator(self, parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None) -> Callable: """Create a decorator for registering functions with a schema.""" def decorator(function): @@ -324,9 +318,7 @@ def register_function( return final_schema - def register_functions( - self, functions: Union[Iterable[Callable], dict[str, Callable]] - ): + def register_functions(self, functions: Union[Iterable[Callable], dict[str, Callable]]): """Register a dictionary of functions.""" if isinstance(functions, dict): functions = functions.values() @@ -355,9 +347,7 @@ def get_chatlab_metadata(self, function_name) -> ChatlabMetadata: chatlab_metadata = getattr(function, "chatlab_metadata", ChatlabMetadata()) return chatlab_metadata - def api_manifest( - self, function_call_option: FunctionCallOption = "auto" - ) -> APIManifest: + def api_manifest(self, function_call_option: FunctionCallOption = "auto") -> APIManifest: """Get a dictionary containing function definitions and calling options. This is designed to be used with OpenAI's Chat Completion API, where the @@ -447,9 +437,7 @@ async def call(self, name: str, arguments: Optional[str] = None) -> Any: parameters = json.loads(arguments) # TODO: Validate parameters against schema except json.JSONDecodeError: - raise FunctionArgumentError( - f"Invalid Function call on {name}. Arguments must be a valid JSON object" - ) + raise FunctionArgumentError(f"Invalid Function call on {name}. Arguments must be a valid JSON object") if function is None: raise UnknownFunctionError(f"Function {name} is not registered") diff --git a/poetry.lock b/poetry.lock index a17d6c8..b482279 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2488,6 +2488,17 @@ files = [ {file = "types_aiofiles-23.2.0.0-py3-none-any.whl", hash = "sha256:5d6719e8148cb2a9c4ea46dad86d50d3b675c46a940adca698533a8d2216d53d"}, ] +[[package]] +name = "types-orjson" +version = "3.6.2" +description = "Typing stubs for orjson" +optional = false +python-versions = "*" +files = [ + {file = "types-orjson-3.6.2.tar.gz", hash = "sha256:cf9afcc79a86325c7aff251790338109ed6f6b1bab09d2d4262dd18c85a3c638"}, + {file = "types_orjson-3.6.2-py3-none-any.whl", hash = "sha256:22ee9a79236b6b0bfb35a0684eded62ad930a88a56797fa3c449b026cf7dbfe4"}, +] + [[package]] name = "types-toml" version = "0.10.8.7" @@ -2688,4 +2699,4 @@ test = [] [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<3.13" -content-hash = "6eb7545220b44cb8a043f26d27899433e2f9bd420fad69bcb3dca6d621537569" +content-hash = "e3ab2f754cd2050de0bd4135ae10f031287ae3be33d974d5f8b2b391dce22392" diff --git a/pyproject.toml b/pyproject.toml index 748996c..f338af8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,8 @@ python-ulid = { version = "^1.1.0", optional = true } noteable-origami = { version = "^1.0.0a5", allow-prereleases = true, optional = true } typing-extensions = "^4.7.1" aiofiles = "^23.1.0" +types-aiofiles = "^23.2.0.0" +mypy = "^1.6.1" [tool.poetry.group.dev.dependencies] tox = "^4.4.11" @@ -63,6 +65,7 @@ types-toml = "^0.10.8.6" pandas = "^2.0.2" pytest-asyncio = "^0.21.1" types-aiofiles = "^23.1.0.5" +types-orjson = "^3.6.2" [tool.poetry.extras] noteable = ["noteable-origami", "python-ulid"] @@ -76,9 +79,24 @@ test = [ dev = ["tox", "pre-commit", "virtualenv", "pip", "twine", "toml", "bump2version"] +[[tool.mypy.overrides]] +module = [ + "aiofiles", + "orjson", + "vdom", + "repr_llm.*", + "origami.*", + "deprecation" +] + +ignore_missing_imports = true + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] addopts = "--cov --cov-report=lcov:lcov.info --cov-report=term --cov-report=html" + +[tool.ruff] +line-length = 120 diff --git a/setup.cfg b/setup.cfg index 2f6cef5..c17252b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -90,4 +90,4 @@ commands = [bdist_wheel] -universal = 1 \ No newline at end of file +universal = 1 diff --git a/tests/test_registry.py b/tests/test_registry.py index 3945c1b..cb2586f 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -73,9 +73,7 @@ def simple_func_with_uuid_arg( # Test the function generation schema def test_generate_function_schema_lambda(): - with pytest.raises( - Exception, match="Lambdas cannot be registered. Use `def` instead." - ): + with pytest.raises(Exception, match="Lambdas cannot be registered. Use `def` instead."): generate_function_schema(lambda x: x) @@ -83,9 +81,7 @@ def test_generate_function_schema_no_docstring(): def no_docstring(x: int): return x - with pytest.raises( - Exception, match="Only functions with docstrings can be registered" - ): + with pytest.raises(Exception, match="Only functions with docstrings can be registered"): generate_function_schema(no_docstring) @@ -285,9 +281,7 @@ def test_generate_function_schema_with_uuid_argument(): @pytest.mark.asyncio async def test_function_registry_unknown_function(): registry = FunctionRegistry() - with pytest.raises( - UnknownFunctionError, match="Function unknown is not registered" - ): + with pytest.raises(UnknownFunctionError, match="Function unknown is not registered"): await registry.call("unknown") @@ -306,18 +300,14 @@ async def test_function_registry_function_argument_error(): async def test_function_registry_call(): registry = FunctionRegistry() registry.register(simple_func, SimpleModel) - result = await registry.call( - "simple_func", arguments='{"x": 1, "y": "str", "z": true}' - ) + result = await registry.call("simple_func", arguments='{"x": 1, "y": "str", "z": true}') assert result == "1, str, True" # Testing for registry's register method with an invalid function def test_function_registry_register_invalid_function(): registry = FunctionRegistry() - with pytest.raises( - Exception, match="Lambdas cannot be registered. Use `def` instead." - ): + with pytest.raises(Exception, match="Lambdas cannot be registered. Use `def` instead."): registry.register(lambda x: x) @@ -387,9 +377,7 @@ def func_no_args(): async def test_function_registry_call_edge_cases(): registry = FunctionRegistry() with pytest.raises(UnknownFunctionError): - await registry.call( - "totes_not_real", arguments='{"x": 1, "y": "str", "z": true}' - ) + await registry.call("totes_not_real", arguments='{"x": 1, "y": "str", "z": true}') with pytest.raises(UnknownFunctionError): await registry.call(None) # type: ignore