Skip to content

Commit d9b3c6c

Browse files
fix: better recognition of vectorisers for chunked collections
1 parent 6eac264 commit d9b3c6c

File tree

2 files changed

+182
-39
lines changed

2 files changed

+182
-39
lines changed

elysia/tools/retrieval/chunk.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,6 @@ async def get_vectoriser(
194194
self.collection_name
195195
).config.get()
196196

197-
# see if any named vectors exclusively vectorise the content field
198197
if collection_config.vector_config is not None:
199198
for (
200199
named_vector_name,
@@ -203,9 +202,15 @@ async def get_vectoriser(
203202

204203
named_vectorizer = named_vector_config.vectorizer
205204

206-
if (
205+
if ( # default weaviate naming "default" or a single named vector
206+
(
207+
named_vector_name == "default"
208+
or len(collection_config.vector_config) == 1
209+
)
210+
and named_vectorizer.source_properties is None
211+
) or ( # named vector with source properties which includes the content field
207212
named_vectorizer.source_properties is not None
208-
and named_vectorizer.source_properties == [content_field]
213+
and content_field in named_vectorizer.source_properties
209214
):
210215
try:
211216
vectorizer = getattr(
@@ -229,27 +234,20 @@ async def get_vectoriser(
229234
except AttributeError as e:
230235
pass
231236

232-
# if we haven't returned yet, try the overall vectoriser
233-
if collection_config.vectorizer_config is not None:
234-
try:
235-
vectorizer_name = (
236-
collection_config.vectorizer_config.vectorizer.replace("-", "_")
237-
)
238-
vectorizer = getattr(Configure.Vectors, vectorizer_name)
239-
valid_args = inspect.signature(vectorizer).parameters
240-
241-
return vectorizer(
242-
**{
243-
arg: collection_config.vectorizer_config.model[arg]
244-
for arg in collection_config.vectorizer_config.model
245-
if arg in valid_args and arg != "vector_index_config"
246-
},
247-
vector_index_config=Configure.VectorIndex.hnsw(
248-
quantizer=Configure.VectorIndex.Quantizer.sq() # scalar quantization
249-
),
250-
)
251-
except AttributeError as e:
252-
pass
237+
# check old vectorizer config
238+
elif collection_config.vectorizer_config is not None:
239+
vectorizer_name = collection_config.vectorizer_config.vectorizer.replace(
240+
"-", "_"
241+
)
242+
vectorizer = getattr(Configure.Vectors, vectorizer_name)
243+
valid_args = inspect.signature(vectorizer).parameters
244+
return vectorizer(
245+
**{
246+
arg: collection_config.vectorizer_config.model[arg]
247+
for arg in collection_config.vectorizer_config.model
248+
if arg in valid_args and arg != "vector_index_config"
249+
}
250+
)
253251

254252
# otherwise use default weaviate embedding service
255253
return Configure.Vectors.text2vec_weaviate(

tests/requires_env/general/test_chunking.py

Lines changed: 160 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@
1616

1717

1818
@pytest.mark.asyncio
19-
async def test_correct_vectoriser():
19+
async def test_correct_vectoriser_old():
2020
collection_name_full = "Test_ELYSIA_collection_chunking_vectoriser_full"
2121
collection_name_named = "Test_ELYSIA_collection_chunking_vectoriser_named"
2222
client_manager = ClientManager()
2323

2424
# example data
2525
data = [
2626
{
27-
"content": (
27+
"random_content_field": (
2828
"Lorem ipsum dolor sit amet consectetur adipiscing elit. "
2929
"Quisque faucibus ex sapien vitae pellentesque sem placerat. "
3030
"In id cursus mi pretium tellus duis convallis. "
@@ -37,7 +37,7 @@ async def test_correct_vectoriser():
3737
"other_field": "other_value",
3838
},
3939
{
40-
"content": (
40+
"random_content_field": (
4141
"Lorem ipsum dolor sit sit amet consectetur adipiscing elit. "
4242
"Quisque faucibus ex sapien sed vitae pellentesque sem placerat. "
4343
"In id cursus mi pretium tellus sed duis convallis. "
@@ -53,8 +53,8 @@ async def test_correct_vectoriser():
5353

5454
# create collection with vectoriser
5555
full_vectoriser = Configure.Vectorizer.text2vec_openai(
56-
model="text-embedding-3-large",
57-
dimensions=256,
56+
model="text-embedding-3-small",
57+
dimensions=512,
5858
)
5959
async with client_manager.connect_to_async_client() as client:
6060
collection_full = await client.collections.create(
@@ -64,8 +64,11 @@ async def test_correct_vectoriser():
6464
await collection_full.data.insert_many(data)
6565

6666
# create collection with named vectoriser
67-
named_vectoriser = Configure.NamedVectors.text2vec_jinaai(
68-
name="content_vector", source_properties=["content"], model="jina-embeddings-v3"
67+
named_vectoriser = Configure.NamedVectors.text2vec_openai(
68+
name="content_vector",
69+
source_properties=["random_content_field"],
70+
model="text-embedding-3-small",
71+
dimensions=512,
6972
)
7073
async with client_manager.connect_to_async_client() as client:
7174
collection_named = await client.collections.create(
@@ -79,36 +82,172 @@ async def test_correct_vectoriser():
7982
try:
8083
# do chunking on full collection
8184
collection_chunker = AsyncCollectionChunker(collection_name_full)
82-
await collection_chunker.create_chunked_reference("content", client_manager)
85+
await collection_chunker.create_chunked_reference(
86+
"random_content_field", client_manager
87+
)
8388

8489
async with client_manager.connect_to_async_client() as client:
8590
# check existence of collection
86-
assert await client.collections.exists(collection_name_full)
91+
assert await client.collections.exists(
92+
f"ELYSIA_CHUNKED_{collection_name_full.lower()}__"
93+
)
8794

8895
# check vectoriser
8996
collection_full_config = await client.collections.get(
90-
collection_name_full
97+
f"ELYSIA_CHUNKED_{collection_name_full.lower()}__"
9198
).config.get()
99+
assert "default" in collection_full_config.vector_config
92100
assert (
93-
collection_full_config.vectorizer_config.vectorizer
101+
collection_full_config.vector_config["default"].vectorizer.vectorizer
94102
== full_vectoriser.vectorizer
95103
)
96104

97105
# check named vectoriser
98106
collection_chunker = AsyncCollectionChunker(collection_name_named)
99-
await collection_chunker.create_chunked_reference("content", client_manager)
107+
await collection_chunker.create_chunked_reference(
108+
"random_content_field", client_manager
109+
)
100110

101111
async with client_manager.connect_to_async_client() as client:
102112
# check existence of collection
103-
assert await client.collections.exists(collection_name_named)
113+
assert await client.collections.exists(
114+
f"ELYSIA_CHUNKED_{collection_name_named.lower()}__"
115+
)
104116

105117
# check vectoriser
106118
collection_named_config = await client.collections.get(
107-
collection_name_named
119+
f"ELYSIA_CHUNKED_{collection_name_named.lower()}__"
120+
).config.get()
121+
assert (
122+
list(collection_named_config.vector_config.keys())[0]
123+
in collection_named_config.vector_config
124+
)
125+
assert (
126+
collection_named_config.vector_config[
127+
list(collection_named_config.vector_config.keys())[0]
128+
].vectorizer.vectorizer
129+
== named_vectoriser.vectorizer.vectorizer
130+
)
131+
132+
finally:
133+
async with client_manager.connect_to_async_client() as client:
134+
await client.collections.delete(collection_name_full)
135+
await client.collections.delete(collection_name_named)
136+
await client.collections.delete(
137+
f"ELYSIA_CHUNKED_{collection_name_full.lower()}__"
138+
)
139+
await client.collections.delete(
140+
f"ELYSIA_CHUNKED_{collection_name_named.lower()}__"
141+
)
142+
143+
await client_manager.close_clients()
144+
145+
146+
@pytest.mark.asyncio
147+
async def test_correct_vectoriser_new():
148+
collection_name_full = "Test_ELYSIA_collection_chunking_vectoriser_full_new"
149+
collection_name_named = "Test_ELYSIA_collection_chunking_vectoriser_named_new"
150+
client_manager = ClientManager()
151+
152+
# example data
153+
data = [
154+
{
155+
"random_content_field": (
156+
"Lorem ipsum dolor sit amet consectetur adipiscing elit. "
157+
"Quisque faucibus ex sapien vitae pellentesque sem placerat. "
158+
"In id cursus mi pretium tellus duis convallis. "
159+
"Tempus leo eu aenean sed diam urna tempor. "
160+
"Pulvinar vivamus fringilla lacus nec metus bibendum egestas. "
161+
"Iaculis massa nisl malesuada lacinia integer nunc posuere. "
162+
"Ut hendrerit semper vel class aptent taciti sociosqu. "
163+
"Ad litora torquent per conubia nostra inceptos himenaeos."
164+
),
165+
"other_field": "other_value",
166+
},
167+
{
168+
"random_content_field": (
169+
"Lorem ipsum dolor sit sit amet consectetur adipiscing elit. "
170+
"Quisque faucibus ex sapien sed vitae pellentesque sem placerat. "
171+
"In id cursus mi pretium tellus sed duis convallis. "
172+
"Tempus leo eu aenean sed diam urna tempor. "
173+
"Pulvinar vivamus fringilla la cus nec metus bibendum egestas. "
174+
"Iaculis massa nisl malesuada lacinia integer nunc posuere. "
175+
"Ut hendrerit semper vel class ooga aptent taciti sociosqu. "
176+
"Ad litora torquent per conubia booga nostra inceptos himenaeos."
177+
),
178+
"other_field": "other_value_2",
179+
},
180+
]
181+
182+
# create collection with vectoriser
183+
full_vectoriser = Configure.Vectors.text2vec_openai(
184+
model="text-embedding-3-large",
185+
dimensions=512,
186+
)
187+
async with client_manager.connect_to_async_client() as client:
188+
collection_full = await client.collections.create(
189+
collection_name_full,
190+
vector_config=full_vectoriser,
191+
)
192+
await collection_full.data.insert_many(data)
193+
194+
named_vectoriser = Configure.Vectors.text2vec_openai(
195+
name="content_vector",
196+
source_properties=["random_content_field"],
197+
model="text-embedding-3-small",
198+
)
199+
# create collection with named vectoriser
200+
async with client_manager.connect_to_async_client() as client:
201+
collection_named = await client.collections.create(
202+
collection_name_named,
203+
vector_config=[
204+
named_vectoriser,
205+
],
206+
)
207+
await collection_named.data.insert_many(data)
208+
209+
try:
210+
# do chunking on full collection
211+
collection_chunker = AsyncCollectionChunker(collection_name_full)
212+
await collection_chunker.create_chunked_reference(
213+
"random_content_field", client_manager
214+
)
215+
216+
async with client_manager.connect_to_async_client() as client:
217+
# check existence of collection
218+
assert await client.collections.exists(
219+
f"ELYSIA_CHUNKED_{collection_name_full.lower()}__"
220+
)
221+
222+
# check vectoriser
223+
collection_full_config = await client.collections.get(
224+
f"ELYSIA_CHUNKED_{collection_name_full.lower()}__"
225+
).config.get()
226+
assert "default" in collection_full_config.vector_config
227+
assert (
228+
collection_full_config.vector_config["default"].vectorizer.vectorizer
229+
== full_vectoriser.vectorizer.vectorizer
230+
)
231+
232+
# check named vectoriser
233+
collection_chunker = AsyncCollectionChunker(collection_name_named)
234+
await collection_chunker.create_chunked_reference(
235+
"random_content_field", client_manager
236+
)
237+
238+
async with client_manager.connect_to_async_client() as client:
239+
# check existence of collection
240+
assert await client.collections.exists(
241+
f"ELYSIA_CHUNKED_{collection_name_named.lower()}__"
242+
)
243+
244+
# check vectoriser
245+
collection_named_config = await client.collections.get(
246+
f"ELYSIA_CHUNKED_{collection_name_named.lower()}__"
108247
).config.get()
109248
assert collection_named_config.vector_config[
110249
list(collection_named_config.vector_config.keys())[0]
111-
].vectorizer.source_properties == ["content"]
250+
].vectorizer.source_properties == ["random_content_field"]
112251
assert (
113252
collection_named_config.vector_config[
114253
list(collection_named_config.vector_config.keys())[0]
@@ -120,5 +259,11 @@ async def test_correct_vectoriser():
120259
async with client_manager.connect_to_async_client() as client:
121260
await client.collections.delete(collection_name_full)
122261
await client.collections.delete(collection_name_named)
262+
await client.collections.delete(
263+
f"ELYSIA_CHUNKED_{collection_name_full.lower()}__"
264+
)
265+
await client.collections.delete(
266+
f"ELYSIA_CHUNKED_{collection_name_named.lower()}__"
267+
)
123268

124269
await client_manager.close_clients()

0 commit comments

Comments
 (0)