Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a bug with excess RAM usage during vector computes. #1195

Merged
merged 3 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion lilac/data/dataset_select_rows_sort_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,9 @@ def vector_compute(
self, all_vector_spans: Iterable[list[SpanVector]]
) -> Iterator[Optional[Item]]:
for vector_spans in all_vector_spans:
embeddings = np.array([vector_span['vector'] for vector_span in vector_spans])
embeddings = np.array([vector_span['vector'] for vector_span in vector_spans]).reshape(
len(vector_spans), -1
)
scores = embeddings.dot(self._query).reshape(-1)
res: Item = []
for vector_span, score in zip(vector_spans, scores):
Expand Down
9 changes: 3 additions & 6 deletions lilac/embeddings/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import abc
import os
import pickle
from typing import Iterable, Optional, Sequence, Type, cast
from typing import Iterable, Iterator, Optional, Sequence, Type, cast

import numpy as np

Expand Down Expand Up @@ -50,7 +50,7 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
pass

@abc.abstractmethod
def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]:
"""Return the embeddings for given keys.

Args:
Expand Down Expand Up @@ -159,13 +159,10 @@ def get(self, keys: Iterable[PathKey]) -> Iterable[list[SpanVector]]:
all_spans.append(spans)
all_vector_keys.append([(*path_key, i) for i in range(len(spans))])

offset = 0
flat_vector_keys = [key for vector_keys in all_vector_keys for key in (vector_keys or [])]
all_vectors = self._vector_store.get(flat_vector_keys)
for spans in all_spans:
vectors = all_vectors[offset : offset + len(spans)]
yield [{'span': span, 'vector': vector} for span, vector in zip(spans, vectors)]
offset += len(spans)
yield [{'span': span, 'vector': next(all_vectors)} for span in spans]

def topk(
self, query: np.ndarray, k: int, rowids: Optional[Iterable[str]] = None
Expand Down
19 changes: 13 additions & 6 deletions lilac/embeddings/vector_store_hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import multiprocessing
import os
import threading
from typing import Iterable, Optional, Set, cast
from typing import Iterable, Iterator, Optional, Set, cast

import hnswlib
import numpy as np
import pandas as pd
from typing_extensions import override

from ..schema import VectorKey
from ..utils import DebugTimer
from ..utils import DebugTimer, chunks
from .vector_store import VectorStore

_HNSW_SUFFIX = '.hnswlib.bin'
Expand All @@ -22,6 +22,8 @@
CONSTRUCTION_EF = 100
M = 16
SPACE = 'ip'
# The number of items to retrieve at a time given a query of keys.
HNSW_RETRIEVAL_BATCH_SIZE = 1024


class HNSWVectorStore(VectorStore):
Expand Down Expand Up @@ -105,15 +107,20 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
self._index.set_ef(min(QUERY_EF, self.size()))

@override
def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]:
assert (
self._index is not None and self._key_to_label is not None
), 'No embeddings exist in this store.'
with self._lock:
if not keys:
return np.array(self._index.get_items(self._key_to_label.values), dtype=np.float32)
locs = self._key_to_label.loc[cast(list[str], keys)].values
return np.array(self._index.get_items(locs), dtype=np.float32)
locs = self._key_to_label.values
else:
locs = self._key_to_label.loc[cast(list[str], keys)].values

for loc_chunk in chunks(locs, HNSW_RETRIEVAL_BATCH_SIZE):
chunk_items = np.array(self._index.get_items(loc_chunk), dtype=np.float32)
for vector in np.split(chunk_items, chunk_items.shape[0]):
[Review thread on this line]
Reviewer (Contributor): probably simpler/shorter:

for i in range(chunk_items.shape[0]):
  yield chunk_items[i]

Author (Contributor): need to use split here since it doesn't make a copy

yield np.squeeze(vector)

@override
def topk(
Expand Down
14 changes: 9 additions & 5 deletions lilac/embeddings/vector_store_numpy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""NumpyVectorStore class for storing vectors in numpy arrays."""

import os
from typing import Iterable, Optional, cast
from typing import Iterable, Iterator, Optional, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -73,14 +73,18 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
self._key_to_index = new_key_to_label

@override
def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]:
assert (
self._embeddings is not None and self._key_to_index is not None
), 'The vector store has no embeddings. Call load() or add() first.'
if not keys:
return self._embeddings
locs = self._key_to_index.loc[cast(list[str], keys)]
return self._embeddings.take(locs, axis=0)
embeddings = self._embeddings
else:
locs = self._key_to_index.loc[cast(list[str], keys)]
embeddings = self._embeddings.take(locs, axis=0)

for vector in np.split(embeddings, embeddings.shape[0]):
yield np.squeeze(vector)

@override
def topk(
Expand Down
29 changes: 19 additions & 10 deletions lilac/embeddings/vector_store_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,32 @@ def test_add_chunks(self, store_cls: Type[VectorStore]) -> None:
store.add([('a',), ('b',)], np.array([[1, 2], [3, 4]]))
store.add([('c',)], np.array([[5, 6]]))

np.testing.assert_array_equal(
store.get([('a',), ('b',), ('c',)]), np.array([[1, 2], [3, 4], [5, 6]])
)
vectors = list(store.get([('a',), ('b',), ('c',)]))
assert len(vectors) == 3
np.testing.assert_array_equal(vectors[0], [1, 2])
np.testing.assert_array_equal(vectors[1], [3, 4])
np.testing.assert_array_equal(vectors[2], [5, 6])

def test_get_all(self, store_cls: Type[VectorStore]) -> None:
store = store_cls()

store.add([('a',), ('b',), ('c',)], np.array([[1, 2], [3, 4], [5, 6]]))

np.testing.assert_array_equal(
store.get([('a',), ('b',), ('c',)]), np.array([[1, 2], [3, 4], [5, 6]])
)
vectors = list(store.get([('a',), ('b',), ('c',)]))
assert len(vectors) == 3
np.testing.assert_array_equal(vectors[0], [1, 2])
np.testing.assert_array_equal(vectors[1], [3, 4])
np.testing.assert_array_equal(vectors[2], [5, 6])

def test_get_subset(self, store_cls: Type[VectorStore]) -> None:
store = store_cls()

store.add([('a',), ('b',), ('c',)], np.array([[1, 2], [3, 4], [5, 6]]))

np.testing.assert_array_equal(store.get([('b',), ('c',)]), np.array([[3, 4], [5, 6]]))
vectors = list(store.get([('b',), ('c',)]))
assert len(vectors) == 2
np.testing.assert_array_equal(vectors[0], [3, 4])
np.testing.assert_array_equal(vectors[1], [5, 6])

def test_save_load(self, store_cls: Type[VectorStore], tmp_path: pathlib.Path) -> None:
store = store_cls()
Expand All @@ -54,9 +61,11 @@ def test_save_load(self, store_cls: Type[VectorStore], tmp_path: pathlib.Path) -
store = store_cls()
store.load((str(tmp_path)))

np.testing.assert_array_equal(
store.get([('a',), ('b',), ('c',)]), np.array([[1, 2], [3, 4], [5, 6]])
)
vectors = list(store.get([('a',), ('b',), ('c',)]))
assert len(vectors) == 3
np.testing.assert_array_equal(vectors[0], [1, 2])
np.testing.assert_array_equal(vectors[1], [3, 4])
np.testing.assert_array_equal(vectors[2], [5, 6])

def test_topk(self, store_cls: Type[VectorStore]) -> None:
store = store_cls()
Expand Down
6 changes: 4 additions & 2 deletions lilac/signals/semantic_similarity_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
pass

@override
def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]:
keys = keys or []
return np.array([EMBEDDINGS[tuple(path_key)][cast(int, index)] for *path_key, index in keys])
yield from [
np.array(EMBEDDINGS[tuple(path_key)][cast(int, index)]) for *path_key, index in keys
]

@override
def delete(self, base_path: str) -> None:
Expand Down
Loading