Skip to content

Commit 659de3e

Browse files
authored
Merge pull request #149 from rgbkrk/update
Update
2 parents 9873fc9 + 0f7c479 commit 659de3e

26 files changed

+2105
-1749
lines changed

chatlab/chat.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -390,16 +390,14 @@ def register(
390390
self,
391391
function: None = None,
392392
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
393-
) -> Callable:
394-
...
393+
) -> Callable: ...
395394

396395
@overload
397396
def register(
398397
self,
399398
function: Callable,
400399
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
401-
) -> FunctionDefinition:
402-
...
400+
) -> FunctionDefinition: ...
403401

404402
def register(
405403
self,

chatlab/decorators.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,19 @@
2929
3030
"""
3131

32-
3332
from typing import Callable, Optional
3433

3534
from pydantic import BaseModel
3635

36+
3737
class ChatlabMetadata(BaseModel):
3838
"""ChatLab metadata for a function."""
39+
3940
expose_exception_to_llm: bool = True
4041
render: Optional[Callable] = None
4142
bubble_exceptions: bool = False
4243

44+
4345
def bubble_exceptions(func):
4446
if not hasattr(func, "chatlab_metadata"):
4547
func.chatlab_metadata = ChatlabMetadata()
@@ -51,6 +53,7 @@ def bubble_exceptions(func):
5153
func.chatlab_metadata.bubble_exceptions = True
5254
return func
5355

56+
5457
def expose_exception_to_llm(func):
5558
"""Expose exceptions from calling the function to the LLM.
5659
@@ -107,6 +110,7 @@ def store_knowledge_graph(kg: KnowledgeGraph, comment: str = "Knowledge Graph"):
107110
chat.register(store_knowledge_graph)
108111
'''
109112

113+
110114
def incremental_display(render_func: Callable):
111115
def decorator(func):
112116
if not hasattr(func, "chatlab_metadata"):
@@ -118,5 +122,5 @@ def decorator(func):
118122

119123
func.chatlab_metadata.render = render_func
120124
return func
121-
return decorator
122125

126+
return decorator

chatlab/messaging.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,7 @@ def function_result(name: str, content: str) -> ChatCompletionMessageParam:
105105

106106

107107
class HasGetToolArgumentsParameter(Protocol):
108-
def get_tool_arguments_parameter(self) -> ChatCompletionMessageToolCallParam:
109-
...
108+
def get_tool_arguments_parameter(self) -> ChatCompletionMessageToolCallParam: ...
110109

111110

112111
def assistant_tool_calls(tool_calls: Iterable[HasGetToolArgumentsParameter]) -> ChatCompletionMessageParam:

chatlab/models.py

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
class ChatModel(Enum):
77
"""Models available for use with chatlab."""
8+
89
GPT_4_TURBO_PREVIEW = "gpt-4-turbo-preview"
910
GPT_4_0125_PREVIEW = "gpt-4-0125-preview"
1011
GPT_4_1106_PREVIEW = "gpt-4-1106-preview"

chatlab/registry.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -288,16 +288,14 @@ def register(
288288
self,
289289
function: None = None,
290290
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
291-
) -> Callable:
292-
...
291+
) -> Callable: ...
293292

294293
@overload
295294
def register(
296295
self,
297296
function: Callable,
298297
parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
299-
) -> FunctionDefinition:
300-
...
298+
) -> FunctionDefinition: ...
301299

302300
def register(
303301
self,
@@ -438,7 +436,13 @@ def api_manifest(self, function_call_option: FunctionCall = "auto") -> APIManife
438436

439437
@property
440438
def tools(self) -> Iterable[ChatCompletionToolParam]:
441-
return [{"type": "function", "function": adapt_function_definition(f)} for f in self.__schemas.values()]
439+
return [
440+
ChatCompletionToolParam(
441+
type="function",
442+
function=adapt_function_definition(f), # type: ignore
443+
)
444+
for f in self.__schemas.values()
445+
]
442446

443447
async def call(self, name: str, arguments: Optional[str] = None) -> Any:
444448
"""Call a function by name with the given parameters."""

chatlab/tools/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,3 @@
88
"run_python",
99
"shell_functions",
1010
]
11-

chatlab/tools/_mediatypes.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Media types for rich output for LLMs and in-notebook."""
2+
23
import json
34
from typing import Optional
45

chatlab/tools/colors.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Let models pick and show color palettes to you."""
2+
23
import hashlib
34
from typing import List, Optional
45
from pydantic import BaseModel, validator, Field

chatlab/tools/files.py

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
88
You've been warned. Have fun and be safe!
99
"""
10+
1011
import asyncio
1112
import os
1213

chatlab/tools/python.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""The in-IPython python code runner for ChatLab."""
2+
23
from traceback import TracebackException
34
from typing import Optional
45

chatlab/tools/shell.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Shell commands for ChatLab."""
2+
23
import asyncio
34
import subprocess
45

chatlab/views/__init__.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
"""Views for ChatLab."""
2+
23
from .assistant import AssistantMessageView
34
from .tools import ToolArguments, ToolCalled
45

5-
__all__ = [
6-
"AssistantMessageView",
7-
"ToolArguments",
8-
"ToolCalled"
9-
]
6+
__all__ = ["AssistantMessageView", "ToolArguments", "ToolCalled"]

chatlab/views/tools.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515
from IPython.display import display
1616
from IPython.core.getipython import get_ipython
1717

18-
from instructor.dsl.partialjson import JSONParser
19-
18+
from jiter import from_json
2019

2120

2221
class ToolArguments(AutoUpdate):
@@ -75,11 +74,11 @@ def update(self) -> None:
7574

7675
def render(self):
7776
if self.custom_render is not None:
78-
# We use the same definition as was in the original function
7977
try:
80-
parser = JSONParser()
81-
possible_args = parser.parse(self.arguments)
82-
78+
possible_args = from_json(self.arguments.encode("utf-8"), partial_mode="trailing-strings")
79+
except Exception:
80+
return None
81+
try:
8382
Model = extract_model_from_function(self.name, self.custom_render)
8483
# model = Model.model_validate(possible_args)
8584
model = Model(**possible_args)
@@ -110,13 +109,17 @@ def append_arguments(self, arguments: str):
110109
def apply_result(self, result: str):
111110
"""Replaces the existing display with a new one that shows the result of the tool being called."""
112111
tc = ToolCalled(
113-
id=self.id, name=self.name, arguments=self.arguments, result=result, display_id=self.display_id,
114-
custom_render=self.custom_render
112+
id=self.id,
113+
name=self.name,
114+
arguments=self.arguments,
115+
result=result,
116+
display_id=self.display_id,
117+
custom_render=self.custom_render,
115118
)
116119
tc.update()
117120
return tc
118121

119-
async def call(self, function_registry: FunctionRegistry) -> 'ToolCalled':
122+
async def call(self, function_registry: FunctionRegistry) -> "ToolCalled":
120123
"""Call the function and return a stack of messages for LLM and human consumption."""
121124
function_name = self.name
122125
function_args = self.arguments
@@ -185,9 +188,11 @@ def render(self):
185188
if self.custom_render is not None:
186189
# We use the same definition as was in the original function
187190
try:
188-
parser = JSONParser()
189-
possible_args = parser.parse(self.arguments)
191+
possible_args = from_json(self.arguments.encode("utf-8"), partial_mode="trailing-strings")
192+
except Exception:
193+
return None
190194

195+
try:
191196
Model = extract_model_from_function(self.name, self.custom_render)
192197
# model = Model.model_validate(possible_args)
193198
model = Model(**possible_args)

notebooks/basics.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@
102102
"outputs": [],
103103
"source": [
104104
"from datetime import datetime\n",
105-
"from pytz import timezone, all_timezones, utc\n",
105+
"from pytz import timezone, all_timezones\n",
106106
"from typing import Optional\n",
107107
"from pydantic import BaseModel\n",
108108
"\n",

0 commit comments

Comments
 (0)