Skip to content

Commit 46943cc

Browse files
Merge pull request #30 from weaviate/release/v0.2.x
2 parents 9a59dd4 + 0f19e97 commit 46943cc

File tree

10 files changed

+795
-432
lines changed

10 files changed

+795
-432
lines changed

docs/Examples/data_analysis.md

Lines changed: 93 additions & 412 deletions
Large diffs are not rendered by default.

docs/Examples/old_data_analysis.md

Lines changed: 463 additions & 0 deletions
Large diffs are not rendered by default.

docs/creating_tools.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ Finally, tools can interact with Elysia's environment, LMs and the Weaviate clie
123123
```python
124124
@tool
125125
async def some_tool(
126-
tree_data, base_lm, complex_lm, tree_data, # these inputs are automatically assigned as Elysia variables
126+
tree_data, base_lm, complex_lm, client_manager, # these inputs are automatically assigned as Elysia variables
127127
x: str, y: int # these inputs are not assigned automatically and get assigned by the decision agent
128128
):
129129
# do something
16.2 KB
Loading

elysia/objects.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,18 @@ def return_mapping(result, inputs: dict):
373373

374374
class ToolClass(Tool):
375375
def __init__(self, **kwargs):
376+
self._original_function = function
377+
self._original_function_args = {
378+
arg: None
379+
for arg in function.__code__.co_varnames[
380+
: function.__code__.co_argcount
381+
]
382+
}
383+
for arg in function.__annotations__:
384+
if arg in self._original_function_args:
385+
self._original_function_args[arg] = function.__annotations__[
386+
arg
387+
]
376388
super().__init__(
377389
name=function.__name__,
378390
description=function.__doc__ or "",
@@ -384,12 +396,14 @@ def __init__(self, **kwargs):
384396
inputs={
385397
input_key: {
386398
"description": "",
387-
"type": input_value,
399+
"type": (
400+
"Not specified" if input_value is None else input_value
401+
),
388402
"default": defaults_mapping.get(input_key, None),
389403
"required": defaults_mapping.get(input_key, None)
390404
is not None,
391405
}
392-
for input_key, input_value in function.__annotations__.items()
406+
for input_key, input_value in self._original_function_args.items()
393407
if input_key
394408
not in [
395409
"tree_data",
@@ -401,7 +415,6 @@ def __init__(self, **kwargs):
401415
},
402416
end=end,
403417
)
404-
self._original_function = function
405418

406419
async def __call__(
407420
self, tree_data, inputs, base_lm, complex_lm, client_manager, **kwargs
@@ -420,7 +433,7 @@ async def __call__(
420433
"client_manager": client_manager,
421434
**kwargs,
422435
}.items()
423-
if k in function.__annotations__
436+
if k in self._original_function_args
424437
},
425438
)
426439
]
@@ -437,7 +450,7 @@ async def __call__(
437450
"client_manager": client_manager,
438451
**kwargs,
439452
}.items()
440-
if k in function.__annotations__
453+
if k in self._original_function_args
441454
},
442455
):
443456
results.append(result)

elysia/preprocessing/collection.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ async def preprocess_async(
368368
collection_name: str,
369369
client_manager: ClientManager | None = None,
370370
min_sample_size: int = 10,
371-
max_sample_size: int = 20,
371+
max_sample_size: int | None = None,
372372
num_sample_tokens: int = 30000,
373373
force: bool = False,
374374
percentage_correct_threshold: float = 0.3,
@@ -443,9 +443,19 @@ async def preprocess_async(
443443
agg = await collection.aggregate.over_all(total_count=True)
444444
len_collection: int = agg.total_count # type: ignore
445445

446+
if max_sample_size is None and len_collection > 50_000:
447+
max_sample_size = 20
448+
logger.warning(
449+
f"Collection is large (greater than 50,000 objects), causing slowdown in pre-processing. "
450+
f"Reducing maximum sample size to {max_sample_size} objects. "
451+
"To override this, set `max_sample_size` as an argument to preprocess."
452+
)
453+
elif max_sample_size is None:
454+
max_sample_size = 50
455+
446456
# Randomly sample sample_size objects for the summary
447457
indices = random.sample(
448-
range(len_collection),
458+
range(min(99_999, len_collection)),
449459
max(min(max_sample_size, len_collection), 1),
450460
)
451461

@@ -455,17 +465,20 @@ async def preprocess_async(
455465
subset_objects: list[dict] = [obj.objects[0].properties] # type: ignore
456466

457467
# Get number of objects to sample to get close to num_sample_tokens
458-
num_sample_objects = max(min_sample_size, num_sample_tokens // token_count_0)
459-
460-
for index in indices[1:num_sample_objects]:
461-
obj = await collection.query.fetch_objects(limit=1, offset=index)
462-
subset_objects.append(obj.objects[0].properties) # type: ignore
468+
num_sample_objects = min(
469+
max(min_sample_size, num_sample_tokens // token_count_0),
470+
max_sample_size,
471+
)
463472

464473
# Estimate number of tokens
465474
logger.debug(
466-
f"Estimated token count of sample: {token_count_0*len(subset_objects)}"
475+
f"Estimated token count of sample: {token_count_0*num_sample_objects}"
467476
)
468-
logger.debug(f"Number of objects in sample: {len(subset_objects)}")
477+
logger.debug(f"Number of objects in sample: {num_sample_objects}")
478+
479+
for index in indices[1:num_sample_objects]:
480+
obj = await collection.query.fetch_objects(limit=1, offset=index)
481+
subset_objects.append(obj.objects[0].properties) # type: ignore
469482

470483
# Summarise the collection using LLM and the subset of the data
471484
summary, field_descriptions = await _summarise_collection(
@@ -481,7 +494,7 @@ async def preprocess_async(
481494
message="Generated summary of collection",
482495
)
483496

484-
if len_collection > max_sample_size:
497+
if len_collection > 10_000: # arbitrary cutoff for estimating field statistics
485498
full_response = subset_objects
486499
else:
487500
weaviate_resp = await collection.query.fetch_objects(limit=len_collection)
@@ -782,7 +795,7 @@ async def _preprocess_async(
782795
collection_names: list[str] | str,
783796
client_manager: ClientManager | None = None,
784797
min_sample_size: int = 10,
785-
max_sample_size: int = 20,
798+
max_sample_size: int | None = None,
786799
num_sample_tokens: int = 30000,
787800
settings: Settings = environment_settings,
788801
force: bool = False,
@@ -860,8 +873,8 @@ async def _preprocess_async(
860873
def preprocess(
861874
collection_names: str | list[str],
862875
client_manager: ClientManager | None = None,
863-
min_sample_size: int = 5,
864-
max_sample_size: int = 100,
876+
min_sample_size: int = 10,
877+
max_sample_size: int | None = None,
865878
num_sample_tokens: int = 30000,
866879
settings: Settings = environment_settings,
867880
force: bool = False,

elysia/util/client.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from logging import Logger
99

1010
import weaviate
11-
from weaviate.classes.init import Auth
11+
from weaviate.classes.init import Auth, Timeout
12+
from weaviate.config import AdditionalConfig
1213
from weaviate.client import WeaviateClient, WeaviateAsyncClient
1314
from elysia.config import settings as environment_settings, Settings
1415

@@ -67,6 +68,9 @@ def __init__(
6768
client_timeout: datetime.timedelta | int | None = None,
6869
logger: Logger | None = None,
6970
settings: Settings | None = None,
71+
query_timeout: int = 60,
72+
insert_timeout: int = 120,
73+
init_timeout: int = 5,
7074
**kwargs,
7175
) -> None:
7276
"""
@@ -76,6 +80,9 @@ def __init__(
7680
client_timeout (datetime.timedelta | int | None): how long (in minutes) before the client should be restarted. Defaults to 3 minutes.
7781
logger (Logger | None): a logger object for logging messages. Defaults to None.
7882
settings (Settings | None): a settings object for the client manager. Defaults to environment settings.
83+
query_timeout (int): the timeout for Weaviate queries. Defaults to 60 seconds (Weaviate default is 30 seconds).
84+
insert_timeout (int): the timeout for Weaviate inserts. Defaults to 120 seconds (Weaviate default is 90 seconds).
85+
init_timeout (int): the timeout for Weaviate initialisation. Defaults to 5 seconds (Weaviate default is 2 seconds).
7986
**kwargs (Any): any other api keys for third party services (formatted as e.g. OPENAI_APIKEY).
8087
8188
Example:
@@ -116,6 +123,10 @@ def __init__(
116123
else:
117124
self.wcd_api_key = wcd_api_key
118125

126+
self.query_timeout = query_timeout
127+
self.insert_timeout = insert_timeout
128+
self.init_timeout = init_timeout
129+
119130
# Set the api keys for non weaviate cluster (third parties)
120131
self.headers = {}
121132
for api_key in self.settings.API_KEYS:
@@ -244,6 +255,13 @@ def get_client(self) -> WeaviateClient:
244255
auth_credentials=Auth.api_key(self.wcd_api_key),
245256
headers=self.headers,
246257
skip_init_checks=True,
258+
additional_config=AdditionalConfig(
259+
timeout=Timeout(
260+
query=self.query_timeout,
261+
insert=self.insert_timeout,
262+
init=self.init_timeout,
263+
)
264+
),
247265
)
248266

249267
async def get_async_client(self) -> WeaviateAsyncClient:
@@ -255,6 +273,13 @@ async def get_async_client(self) -> WeaviateAsyncClient:
255273
auth_credentials=Auth.api_key(self.wcd_api_key),
256274
headers=self.headers,
257275
skip_init_checks=True,
276+
additional_config=AdditionalConfig(
277+
timeout=Timeout(
278+
query=self.query_timeout,
279+
insert=self.insert_timeout,
280+
init=self.init_timeout,
281+
)
282+
),
258283
)
259284

260285
@contextmanager

elysia/util/collection.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ async def paginated_collection(
119119
):
120120
collection = client.collections.get(collection_name)
121121

122+
if (page_size * (page_number - 1) + page_size) > 99_999:
123+
raise ValueError(
124+
"Page size exceeds Weaviate's limit of 100,000 objects for using offset."
125+
)
126+
122127
filter_type = filter_config.get("type", "all")
123128
filters_list = filter_config.get("filters", [])
124129
filters = [f["field"] for f in filters_list]

tests/no_reqs/general/test_tools_nr.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,72 @@ async def example_async_decorator_tool_from_tree():
447447
assert "example_async_decorator_tool_from_tree" in tree.tools
448448

449449

450+
def test_decorator_tool_typed_inputs():
451+
452+
tree = Tree()
453+
454+
@tool(tree=tree)
455+
async def example_decorator_tool(x: int, y: int):
456+
return x + y
457+
458+
assert "example_decorator_tool" in tree.tools
459+
assert "x" in tree.tools["example_decorator_tool"].inputs
460+
assert "y" in tree.tools["example_decorator_tool"].inputs
461+
assert tree.tools["example_decorator_tool"].inputs["x"]["type"] is int
462+
assert tree.tools["example_decorator_tool"].inputs["y"]["type"] is int
463+
464+
465+
def test_decorator_tool_typed_inputs_with_default_inputs():
466+
467+
tree = Tree()
468+
469+
@tool(tree=tree)
470+
async def example_decorator_tool(x: int = 1, y: int = 2):
471+
return x + y
472+
473+
assert "example_decorator_tool" in tree.tools
474+
475+
assert "x" in tree.tools["example_decorator_tool"].inputs
476+
assert "y" in tree.tools["example_decorator_tool"].inputs
477+
assert tree.tools["example_decorator_tool"].inputs["x"]["type"] is int
478+
assert tree.tools["example_decorator_tool"].inputs["y"]["type"] is int
479+
assert tree.tools["example_decorator_tool"].inputs["x"]["default"] == 1
480+
assert tree.tools["example_decorator_tool"].inputs["y"]["default"] == 2
481+
482+
483+
def test_decorator_tool_untyped_inputs():
484+
485+
tree = Tree()
486+
487+
@tool(tree=tree)
488+
async def example_decorator_tool(x, y):
489+
return x + y
490+
491+
assert "example_decorator_tool" in tree.tools
492+
assert "x" in tree.tools["example_decorator_tool"].inputs
493+
assert "y" in tree.tools["example_decorator_tool"].inputs
494+
assert tree.tools["example_decorator_tool"].inputs["x"]["type"] == "Not specified"
495+
assert tree.tools["example_decorator_tool"].inputs["y"]["type"] == "Not specified"
496+
497+
498+
def test_decorator_with_elysia_inputs():
499+
tree = Tree()
500+
501+
@tool(tree=tree)
502+
async def example_decorator_tool(
503+
x: int, y: int, tree_data, base_lm, complex_lm, client_manager
504+
):
505+
return x + y
506+
507+
assert "example_decorator_tool" in tree.tools
508+
assert "x" in tree.tools["example_decorator_tool"].inputs
509+
assert "y" in tree.tools["example_decorator_tool"].inputs
510+
assert "tree_data" not in tree.tools["example_decorator_tool"].inputs
511+
assert "base_lm" not in tree.tools["example_decorator_tool"].inputs
512+
assert "complex_lm" not in tree.tools["example_decorator_tool"].inputs
513+
assert "client_manager" not in tree.tools["example_decorator_tool"].inputs
514+
515+
450516
@pytest.mark.asyncio
451517
async def test_add_tool_with_stem_tool():
452518
tree = Tree(

0 commit comments

Comments
 (0)