From 5d84c76b894cc911f0321d97d467777aeca63869 Mon Sep 17 00:00:00 2001 From: Nikhil Thorat Date: Thu, 29 Feb 2024 12:55:38 -0500 Subject: [PATCH] Fix a bug with excess RAM usage during vector computes. (#1195) --- lilac/data/dataset_select_rows_sort_test.py | 4 ++- lilac/embeddings/vector_store.py | 9 +++---- lilac/embeddings/vector_store_hnsw.py | 19 +++++++++----- lilac/embeddings/vector_store_numpy.py | 14 ++++++---- lilac/embeddings/vector_store_test.py | 29 ++++++++++++++------- lilac/signals/semantic_similarity_test.py | 6 +++-- 6 files changed, 51 insertions(+), 30 deletions(-) diff --git a/lilac/data/dataset_select_rows_sort_test.py b/lilac/data/dataset_select_rows_sort_test.py index 3498cd88f..3c414610a 100644 --- a/lilac/data/dataset_select_rows_sort_test.py +++ b/lilac/data/dataset_select_rows_sort_test.py @@ -473,7 +473,9 @@ def vector_compute( self, all_vector_spans: Iterable[list[SpanVector]] ) -> Iterator[Optional[Item]]: for vector_spans in all_vector_spans: - embeddings = np.array([vector_span['vector'] for vector_span in vector_spans]) + embeddings = np.array([vector_span['vector'] for vector_span in vector_spans]).reshape( + len(vector_spans), -1 + ) scores = embeddings.dot(self._query).reshape(-1) res: Item = [] for vector_span, score in zip(vector_spans, scores): diff --git a/lilac/embeddings/vector_store.py b/lilac/embeddings/vector_store.py index cb5058e28..efa8c54ee 100644 --- a/lilac/embeddings/vector_store.py +++ b/lilac/embeddings/vector_store.py @@ -3,7 +3,7 @@ import abc import os import pickle -from typing import Iterable, Optional, Sequence, Type, cast +from typing import Iterable, Iterator, Optional, Sequence, Type, cast import numpy as np @@ -50,7 +50,7 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None: pass @abc.abstractmethod - def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray: + def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]: """Return the embeddings for given keys. Args: @@ -159,13 +159,10 @@ def get(self, keys: Iterable[PathKey]) -> Iterable[list[SpanVector]]: all_spans.append(spans) all_vector_keys.append([(*path_key, i) for i in range(len(spans))]) - offset = 0 flat_vector_keys = [key for vector_keys in all_vector_keys for key in (vector_keys or [])] all_vectors = self._vector_store.get(flat_vector_keys) for spans in all_spans: - vectors = all_vectors[offset : offset + len(spans)] - yield [{'span': span, 'vector': vector} for span, vector in zip(spans, vectors)] - offset += len(spans) + yield [{'span': span, 'vector': next(all_vectors)} for span in spans] def topk( self, query: np.ndarray, k: int, rowids: Optional[Iterable[str]] = None diff --git a/lilac/embeddings/vector_store_hnsw.py b/lilac/embeddings/vector_store_hnsw.py index 0eadccfb9..db8613c51 100644 --- a/lilac/embeddings/vector_store_hnsw.py +++ b/lilac/embeddings/vector_store_hnsw.py @@ -3,7 +3,7 @@ import multiprocessing import os import threading -from typing import Iterable, Optional, Set, cast +from typing import Iterable, Iterator, Optional, Set, cast import hnswlib import numpy as np @@ -11,7 +11,7 @@ from typing_extensions import override from ..schema import VectorKey -from ..utils import DebugTimer +from ..utils import DebugTimer, chunks from .vector_store import VectorStore _HNSW_SUFFIX = '.hnswlib.bin' @@ -22,6 +22,8 @@ CONSTRUCTION_EF = 100 M = 16 SPACE = 'ip' +# The number of items to retrieve at a time given a query of keys. +HNSW_RETRIEVAL_BATCH_SIZE = 1024 class HNSWVectorStore(VectorStore): @@ -105,15 +107,20 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None: self._index.set_ef(min(QUERY_EF, self.size())) @override - def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray: + def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]: assert ( self._index is not None and self._key_to_label is not None ), 'No embeddings exist in this store.' with self._lock: if not keys: - return np.array(self._index.get_items(self._key_to_label.values), dtype=np.float32) - locs = self._key_to_label.loc[cast(list[str], keys)].values - return np.array(self._index.get_items(locs), dtype=np.float32) + locs = self._key_to_label.values + else: + locs = self._key_to_label.loc[cast(list[str], keys)].values + + for loc_chunk in chunks(locs, HNSW_RETRIEVAL_BATCH_SIZE): + chunk_items = np.array(self._index.get_items(loc_chunk), dtype=np.float32) + for vector in np.split(chunk_items, chunk_items.shape[0]): + yield np.squeeze(vector) @override def topk( diff --git a/lilac/embeddings/vector_store_numpy.py b/lilac/embeddings/vector_store_numpy.py index 95e59b31e..ce0696cb7 100644 --- a/lilac/embeddings/vector_store_numpy.py +++ b/lilac/embeddings/vector_store_numpy.py @@ -1,7 +1,7 @@ """NumpyVectorStore class for storing vectors in numpy arrays.""" import os -from typing import Iterable, Optional, cast +from typing import Iterable, Iterator, Optional, cast import numpy as np import pandas as pd @@ -73,14 +73,18 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None: self._key_to_index = new_key_to_label @override - def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray: + def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]: assert ( self._embeddings is not None and self._key_to_index is not None ), 'The vector store has no embeddings. Call load() or add() first.' if not keys: - return self._embeddings - locs = self._key_to_index.loc[cast(list[str], keys)] - return self._embeddings.take(locs, axis=0) + embeddings = self._embeddings + else: + locs = self._key_to_index.loc[cast(list[str], keys)] + embeddings = self._embeddings.take(locs, axis=0) + + for vector in np.split(embeddings, embeddings.shape[0]): + yield np.squeeze(vector) @override def topk( diff --git a/lilac/embeddings/vector_store_test.py b/lilac/embeddings/vector_store_test.py index a59b867a4..5f68ea1f1 100644 --- a/lilac/embeddings/vector_store_test.py +++ b/lilac/embeddings/vector_store_test.py @@ -22,25 +22,32 @@ def test_add_chunks(self, store_cls: Type[VectorStore]) -> None: store.add([('a',), ('b',)], np.array([[1, 2], [3, 4]])) store.add([('c',)], np.array([[5, 6]])) - np.testing.assert_array_equal( - store.get([('a',), ('b',), ('c',)]), np.array([[1, 2], [3, 4], [5, 6]]) - ) + vectors = list(store.get([('a',), ('b',), ('c',)])) + assert len(vectors) == 3 + np.testing.assert_array_equal(vectors[0], [1, 2]) + np.testing.assert_array_equal(vectors[1], [3, 4]) + np.testing.assert_array_equal(vectors[2], [5, 6]) def test_get_all(self, store_cls: Type[VectorStore]) -> None: store = store_cls() store.add([('a',), ('b',), ('c',)], np.array([[1, 2], [3, 4], [5, 6]])) - np.testing.assert_array_equal( - store.get([('a',), ('b',), ('c',)]), np.array([[1, 2], [3, 4], [5, 6]]) - ) + vectors = list(store.get([('a',), ('b',), ('c',)])) + assert len(vectors) == 3 + np.testing.assert_array_equal(vectors[0], [1, 2]) + np.testing.assert_array_equal(vectors[1], [3, 4]) + np.testing.assert_array_equal(vectors[2], [5, 6]) def test_get_subset(self, store_cls: Type[VectorStore]) -> None: store = store_cls() store.add([('a',), ('b',), ('c',)], np.array([[1, 2], [3, 4], [5, 6]])) - np.testing.assert_array_equal(store.get([('b',), ('c',)]), np.array([[3, 4], [5, 6]])) + vectors = list(store.get([('b',), ('c',)])) + assert len(vectors) == 2 + np.testing.assert_array_equal(vectors[0], [3, 4]) + np.testing.assert_array_equal(vectors[1], [5, 6]) def test_save_load(self, store_cls: Type[VectorStore], tmp_path: pathlib.Path) -> None: store = store_cls() @@ -54,9 +61,11 @@ def test_save_load(self, store_cls: Type[VectorStore], tmp_path: pathlib.Path) - store = store_cls() store.load((str(tmp_path))) - np.testing.assert_array_equal( - store.get([('a',), ('b',), ('c',)]), np.array([[1, 2], [3, 4], [5, 6]]) - ) + vectors = list(store.get([('a',), ('b',), ('c',)])) + assert len(vectors) == 3 + np.testing.assert_array_equal(vectors[0], [1, 2]) + np.testing.assert_array_equal(vectors[1], [3, 4]) + np.testing.assert_array_equal(vectors[2], [5, 6]) def test_topk(self, store_cls: Type[VectorStore]) -> None: store = store_cls() diff --git a/lilac/signals/semantic_similarity_test.py b/lilac/signals/semantic_similarity_test.py index 2a54a9d1c..cc7c45dd7 100644 --- a/lilac/signals/semantic_similarity_test.py +++ b/lilac/signals/semantic_similarity_test.py @@ -49,9 +49,11 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None: pass @override - def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray: + def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]: keys = keys or [] - return np.array([EMBEDDINGS[tuple(path_key)][cast(int, index)] for *path_key, index in keys]) + yield from [ + np.array(EMBEDDINGS[tuple(path_key)][cast(int, index)]) for *path_key, index in keys + ] @override def delete(self, base_path: str) -> None: