diff --git a/lilac/data/dataset_select_rows_sort_test.py b/lilac/data/dataset_select_rows_sort_test.py
index 3498cd88..3c414610 100644
--- a/lilac/data/dataset_select_rows_sort_test.py
+++ b/lilac/data/dataset_select_rows_sort_test.py
@@ -473,7 +473,9 @@ def vector_compute(
     self, all_vector_spans: Iterable[list[SpanVector]]
   ) -> Iterator[Optional[Item]]:
     for vector_spans in all_vector_spans:
-      embeddings = np.array([vector_span['vector'] for vector_span in vector_spans])
+      embeddings = np.array([vector_span['vector'] for vector_span in vector_spans]).reshape(
+        len(vector_spans), -1
+      )
       scores = embeddings.dot(self._query).reshape(-1)
       res: Item = []
       for vector_span, score in zip(vector_spans, scores):
diff --git a/lilac/embeddings/vector_store.py b/lilac/embeddings/vector_store.py
index cb5058e2..efa8c54e 100644
--- a/lilac/embeddings/vector_store.py
+++ b/lilac/embeddings/vector_store.py
@@ -3,7 +3,7 @@
 import abc
 import os
 import pickle
-from typing import Iterable, Optional, Sequence, Type, cast
+from typing import Iterable, Iterator, Optional, Sequence, Type, cast
 
 import numpy as np
 
@@ -50,7 +50,7 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
     pass
 
   @abc.abstractmethod
-  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
+  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]:
     """Return the embeddings for given keys.
 
     Args:
@@ -159,13 +159,10 @@ def get(self, keys: Iterable[PathKey]) -> Iterable[list[SpanVector]]:
       all_spans.append(spans)
       all_vector_keys.append([(*path_key, i) for i in range(len(spans))])
 
-    offset = 0
     flat_vector_keys = [key for vector_keys in all_vector_keys for key in (vector_keys or [])]
     all_vectors = self._vector_store.get(flat_vector_keys)
     for spans in all_spans:
-      vectors = all_vectors[offset : offset + len(spans)]
-      yield [{'span': span, 'vector': vector} for span, vector in zip(spans, vectors)]
-      offset += len(spans)
+      yield [{'span': span, 'vector': next(all_vectors)} for span in spans]
 
   def topk(
     self, query: np.ndarray, k: int, rowids: Optional[Iterable[str]] = None
diff --git a/lilac/embeddings/vector_store_hnsw.py b/lilac/embeddings/vector_store_hnsw.py
index 0eadccfb..db8613c5 100644
--- a/lilac/embeddings/vector_store_hnsw.py
+++ b/lilac/embeddings/vector_store_hnsw.py
@@ -3,7 +3,7 @@
 import multiprocessing
 import os
 import threading
-from typing import Iterable, Optional, Set, cast
+from typing import Iterable, Iterator, Optional, Set, cast
 
 import hnswlib
 import numpy as np
@@ -11,7 +11,7 @@
 from typing_extensions import override
 
 from ..schema import VectorKey
-from ..utils import DebugTimer
+from ..utils import DebugTimer, chunks
 from .vector_store import VectorStore
 
 _HNSW_SUFFIX = '.hnswlib.bin'
@@ -22,6 +22,8 @@
 CONSTRUCTION_EF = 100
 M = 16
 SPACE = 'ip'
+# The number of items to retrieve at a time given a query of keys.
+HNSW_RETRIEVAL_BATCH_SIZE = 1024
 
 
 class HNSWVectorStore(VectorStore):
@@ -105,15 +107,20 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
     self._index.set_ef(min(QUERY_EF, self.size()))
 
   @override
-  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
+  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]:
     assert (
       self._index is not None and self._key_to_label is not None
     ), 'No embeddings exist in this store.'
     with self._lock:
       if not keys:
-        return np.array(self._index.get_items(self._key_to_label.values), dtype=np.float32)
-      locs = self._key_to_label.loc[cast(list[str], keys)].values
-      return np.array(self._index.get_items(locs), dtype=np.float32)
+        locs = self._key_to_label.values
+      else:
+        locs = self._key_to_label.loc[cast(list[str], keys)].values
+
+      for loc_chunk in chunks(locs, HNSW_RETRIEVAL_BATCH_SIZE):
+        chunk_items = np.array(self._index.get_items(loc_chunk), dtype=np.float32)
+        for vector in np.split(chunk_items, chunk_items.shape[0]):
+          yield np.squeeze(vector)
 
   @override
   def topk(
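Note: the new `HNSWVectorStore.get` batches index lookups through a `chunks` helper imported from `..utils`, whose implementation is not part of this diff. A minimal sketch of the batching contract it appears to rely on, assuming the helper simply yields successive fixed-size lists:

```python
# Sketch only: the real `chunks` lives in lilac's utils module and may differ.
# The contract assumed by `get` above: split an iterable of index labels into
# lists of at most `size` items, so each hnswlib `get_items` call materializes
# a bounded number of vectors at a time.
from typing import Iterable, Iterator, TypeVar

T = TypeVar('T')


def chunks(items: Iterable[T], size: int) -> Iterator[list[T]]:
  """Yield successive batches of at most `size` items from `items`."""
  batch: list[T] = []
  for item in items:
    batch.append(item)
    if len(batch) == size:
      yield batch
      batch = []
  if batch:
    yield batch
```

With `HNSW_RETRIEVAL_BATCH_SIZE = 1024`, a full scan holds at most one 1024-row chunk in memory at a time instead of the entire embedding matrix.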
diff --git a/lilac/embeddings/vector_store_numpy.py b/lilac/embeddings/vector_store_numpy.py
index 95e59b31..ce0696cb 100644
--- a/lilac/embeddings/vector_store_numpy.py
+++ b/lilac/embeddings/vector_store_numpy.py
@@ -1,7 +1,7 @@
 """NumpyVectorStore class for storing vectors in numpy arrays."""
 import os
-from typing import Iterable, Optional, cast
+from typing import Iterable, Iterator, Optional, cast
 
 import numpy as np
 import pandas as pd
@@ -73,14 +73,18 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
     self._key_to_index = new_key_to_label
 
   @override
-  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
+  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]:
     assert (
       self._embeddings is not None and self._key_to_index is not None
     ), 'The vector store has no embeddings. Call load() or add() first.'
     if not keys:
-      return self._embeddings
-    locs = self._key_to_index.loc[cast(list[str], keys)]
-    return self._embeddings.take(locs, axis=0)
+      embeddings = self._embeddings
+    else:
+      locs = self._key_to_index.loc[cast(list[str], keys)]
+      embeddings = self._embeddings.take(locs, axis=0)
+
+    for vector in np.split(embeddings, embeddings.shape[0]):
+      yield np.squeeze(vector)
 
   @override
   def topk(
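Both stores now turn a 2-D embedding matrix into a stream of 1-D vectors with the same `np.split`/`np.squeeze` idiom. A self-contained illustration of what that pair of calls produces:

```python
# Demonstrates the row-splitting idiom used by NumpyVectorStore.get above.
# np.split(arr, arr.shape[0]) yields one (1, dim) view per row (no copy);
# np.squeeze drops the leading axis so consumers receive flat (dim,) vectors.
import numpy as np

embeddings = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
rows = [np.squeeze(v) for v in np.split(embeddings, embeddings.shape[0])]

assert all(row.shape == (2,) for row in rows)
np.testing.assert_array_equal(rows[1], [3.0, 4.0])
```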
diff --git a/lilac/embeddings/vector_store_test.py b/lilac/embeddings/vector_store_test.py
index a59b867a..5f68ea1f 100644
--- a/lilac/embeddings/vector_store_test.py
+++ b/lilac/embeddings/vector_store_test.py
@@ -22,25 +22,32 @@ def test_add_chunks(self, store_cls: Type[VectorStore]) -> None:
     store.add([('a',), ('b',)], np.array([[1, 2], [3, 4]]))
     store.add([('c',)], np.array([[5, 6]]))
 
-    np.testing.assert_array_equal(
-      store.get([('a',), ('b',), ('c',)]), np.array([[1, 2], [3, 4], [5, 6]])
-    )
+    vectors = list(store.get([('a',), ('b',), ('c',)]))
+    assert len(vectors) == 3
+    np.testing.assert_array_equal(vectors[0], [1, 2])
+    np.testing.assert_array_equal(vectors[1], [3, 4])
+    np.testing.assert_array_equal(vectors[2], [5, 6])
 
   def test_get_all(self, store_cls: Type[VectorStore]) -> None:
     store = store_cls()
     store.add([('a',), ('b',), ('c',)], np.array([[1, 2], [3, 4], [5, 6]]))
 
-    np.testing.assert_array_equal(
-      store.get([('a',), ('b',), ('c',)]), np.array([[1, 2], [3, 4], [5, 6]])
-    )
+    vectors = list(store.get([('a',), ('b',), ('c',)]))
+    assert len(vectors) == 3
+    np.testing.assert_array_equal(vectors[0], [1, 2])
+    np.testing.assert_array_equal(vectors[1], [3, 4])
+    np.testing.assert_array_equal(vectors[2], [5, 6])
 
   def test_get_subset(self, store_cls: Type[VectorStore]) -> None:
     store = store_cls()
     store.add([('a',), ('b',), ('c',)], np.array([[1, 2], [3, 4], [5, 6]]))
 
-    np.testing.assert_array_equal(store.get([('b',), ('c',)]), np.array([[3, 4], [5, 6]]))
+    vectors = list(store.get([('b',), ('c',)]))
+    assert len(vectors) == 2
+    np.testing.assert_array_equal(vectors[0], [3, 4])
+    np.testing.assert_array_equal(vectors[1], [5, 6])
 
   def test_save_load(self, store_cls: Type[VectorStore], tmp_path: pathlib.Path) -> None:
     store = store_cls()
@@ -54,9 +61,11 @@ def test_save_load(self, store_cls: Type[VectorStore], tmp_path: pathlib.Path)
     store = store_cls()
     store.load((str(tmp_path)))
 
-    np.testing.assert_array_equal(
-      store.get([('a',), ('b',), ('c',)]), np.array([[1, 2], [3, 4], [5, 6]])
-    )
+    vectors = list(store.get([('a',), ('b',), ('c',)]))
+    assert len(vectors) == 3
+    np.testing.assert_array_equal(vectors[0], [1, 2])
+    np.testing.assert_array_equal(vectors[1], [3, 4])
+    np.testing.assert_array_equal(vectors[2], [5, 6])
 
   def test_topk(self, store_cls: Type[VectorStore]) -> None:
     store = store_cls()
diff --git a/lilac/signals/semantic_similarity_test.py b/lilac/signals/semantic_similarity_test.py
index 2a54a9d1..cc7c45dd 100644
--- a/lilac/signals/semantic_similarity_test.py
+++ b/lilac/signals/semantic_similarity_test.py
@@ -49,9 +49,11 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
     pass
 
   @override
-  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
+  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]:
     keys = keys or []
-    return np.array([EMBEDDINGS[tuple(path_key)][cast(int, index)] for *path_key, index in keys])
+    yield from [
+      np.array(EMBEDDINGS[tuple(path_key)][cast(int, index)]) for *path_key, index in keys
+    ]
 
   @override
   def delete(self, base_path: str) -> None:
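For context, a hypothetical caller-side sketch (not part of the diff) of the changed contract: `get` now returns a lazy `Iterator[np.ndarray]` rather than one stacked matrix, so callers pair keys with vectors as they stream. Module and class names are assumed to match the file paths above:

```python
import numpy as np

# Assumed import path, inferred from lilac/embeddings/vector_store_numpy.py.
from lilac.embeddings.vector_store_numpy import NumpyVectorStore

store = NumpyVectorStore()
store.add([('a',), ('b',), ('c',)], np.array([[1, 2], [3, 4], [5, 6]]))

keys = [('a',), ('c',)]
# `get` is now a generator: vectors arrive one at a time instead of as a
# single (n, dim) array, so large retrievals never fully materialize.
for key, vector in zip(keys, store.get(keys)):
  print(key, vector)  # e.g. ('a',) [1 2]
```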