[Improvements] Manage segment cache and memory #1670
Changes from all commits
dd2331d
877d712
cb15ae4
48b9cfe
db78fa1
06e7a84
1f30cbc
9e909ef
8a6f537
097cc51
42fbc6d
2386cdd
ba45af9
9724918
034940c
15cc717
@@ -80,7 +80,6 @@
DEFAULT_TENANT = "default_tenant"
DEFAULT_DATABASE = "default_database"


class Settings(BaseSettings):  # type: ignore
    environment: str = ""
@@ -116,6 +115,9 @@ class Settings(BaseSettings):  # type: ignore
    is_persistent: bool = False
    persist_directory: str = "./chroma"

    chroma_memory_limit_bytes: int = 0
Collaborator: How do I turn this capability on and off? Is 0 implicitly off?

Contributor (Author): Yes, 0 is unlimited.
    chroma_segment_cache_policy: Optional[str] = None

    chroma_server_host: Optional[str] = None
Collaborator: Can we introduce a config called segment_manager_cache_policy and make this one of many types?
    chroma_server_headers: Optional[Dict[str, str]] = None
    chroma_server_http_port: Optional[str] = None
@@ -313,6 +315,15 @@ def __init__(self, settings: Settings):
            if settings[key] is not None:
                raise ValueError(LEGACY_ERROR)
        if settings["chroma_segment_cache_policy"] is not None and settings["chroma_segment_cache_policy"] != "LRU":
            logger.error(
                "Failed to set chroma_segment_cache_policy: Only LRU is available."
            )
            if settings["chroma_memory_limit_bytes"] == 0:
                logger.error(
                    "Failed to set chroma_segment_cache_policy: chroma_memory_limit_bytes is required."
                )
        # Apply the nofile limit if set
        if settings["chroma_server_nofile"] is not None:
            if platform.system() != "Windows":
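With these settings in place, a deployment opts into segment eviction by setting chroma_segment_cache_policy to "LRU" together with a non-zero chroma_memory_limit_bytes; leaving the limit at 0 means unlimited, which effectively disables eviction. A minimal sketch of such a configuration (the byte budget and the client call below are illustrative, not taken from this PR):

```python
import chromadb
from chromadb.config import Settings

# Illustrative values: evict the least recently used vector segments once the
# on-disk size of the loaded segments exceeds roughly 8 GiB.
settings = Settings(
    is_persistent=True,
    persist_directory="./chroma",
    chroma_segment_cache_policy="LRU",
    chroma_memory_limit_bytes=8 * 1024**3,
)

client = chromadb.PersistentClient(path="./chroma", settings=settings)
```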
@@ -0,0 +1,104 @@
import uuid
from typing import Any, Callable
from chromadb.types import Segment
from overrides import override
from typing import Dict, Optional
from abc import ABC, abstractmethod

class SegmentCache(ABC):
    @abstractmethod
    def get(self, key: uuid.UUID) -> Optional[Segment]:
        pass

    @abstractmethod
    def pop(self, key: uuid.UUID) -> Optional[Segment]:
        pass

    @abstractmethod
    def set(self, key: uuid.UUID, value: Segment) -> None:
        pass

    @abstractmethod
    def reset(self) -> None:
        pass

class BasicCache(SegmentCache):
    def __init__(self):
        self.cache: Dict[uuid.UUID, Segment] = {}

    @override
    def get(self, key: uuid.UUID) -> Optional[Segment]:
        return self.cache.get(key)

    @override
    def pop(self, key: uuid.UUID) -> Optional[Segment]:
        return self.cache.pop(key, None)

    @override
    def set(self, key: uuid.UUID, value: Segment) -> None:
        self.cache[key] = value

    @override
    def reset(self) -> None:
        self.cache = {}

class SegmentLRUCache(BasicCache):
    """A simple LRU cache implementation that handles objects with dynamic sizes.
    The size of each object is determined by a user-provided size function."""

    def __init__(self, capacity: int, size_func: Callable[[uuid.UUID], int],
                 callback: Optional[Callable[[uuid.UUID, Segment], Any]] = None):
        self.capacity = capacity
        self.size_func = size_func
        self.cache: Dict[uuid.UUID, Segment] = {}
        self.history = []
        self.callback = callback

    def _upsert_key(self, key: uuid.UUID):
        if key in self.history:
            self.history.remove(key)
            self.history.append(key)
        else:
            self.history.append(key)

    @override
    def get(self, key: uuid.UUID) -> Optional[Segment]:
        self._upsert_key(key)
        if key in self.cache:
            return self.cache[key]
        else:
            return None

    @override
    def pop(self, key: uuid.UUID) -> Optional[Segment]:
        if key in self.history:
            self.history.remove(key)
        return self.cache.pop(key, None)

    @override
    def set(self, key: uuid.UUID, value: Segment) -> None:
        if key in self.cache:
            return
        item_size = self.size_func(key)
        key_sizes = {key: self.size_func(key) for key in self.cache}
        total_size = sum(key_sizes.values())
        index = 0
        # Evict items if capacity is exceeded
        while total_size + item_size > self.capacity and len(self.history) > index:
            key_delete = self.history[index]
            if key_delete in self.cache:
                self.callback(key_delete, self.cache[key_delete])
                del self.cache[key_delete]
                total_size -= key_sizes[key_delete]
            index += 1

        self.cache[key] = value
        self._upsert_key(key)

    @override
    def reset(self):
        self.cache = {}
        self.history = []
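For a sense of how the new SegmentLRUCache behaves in isolation, here is a hedged, self-contained sketch. The UUIDs, byte sizes, and dummy segment payloads are made up for illustration; in the PR the size function measures a segment's persist directory on disk and the callback stops the evicted segment instance.

```python
import uuid
from typing import Dict

from chromadb.segment.impl.manager.cache.cache import SegmentLRUCache

# Pretend sizes (in bytes) for each cached segment, keyed by collection id.
sizes: Dict[uuid.UUID, int] = {}


def size_func(key: uuid.UUID) -> int:
    return sizes[key]


def on_evict(key: uuid.UUID, segment) -> None:
    # In the PR, eviction stops the segment instance and releases its file handles.
    print(f"evicting segment for collection {key}")


cache = SegmentLRUCache(capacity=100, size_func=size_func, callback=on_evict)

a, b, c = uuid.uuid4(), uuid.uuid4(), uuid.uuid4()
sizes[a], sizes[b], sizes[c] = 60, 30, 30

cache.set(a, {"id": a})  # dummy Segment-like payload for illustration
cache.set(b, {"id": b})
cache.get(a)             # touch `a` so it becomes the most recently used entry
cache.set(c, {"id": c})  # 60 + 30 + 30 > 100, so the least recently used entry (`b`) is evicted
```

Note that capacity is expressed in the same units that size_func returns (bytes in the segment manager), not in number of entries.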
@@ -7,6 +7,10 @@
    VectorReader,
    S,
)
import logging
from chromadb.segment.impl.manager.cache.cache import SegmentLRUCache, BasicCache, SegmentCache
import os

from chromadb.config import System, get_class
from chromadb.db.system import SysDB
from overrides import override
@@ -21,24 +25,23 @@
from chromadb.types import Collection, Operation, Segment, SegmentScope, Metadata
from typing import Dict, Type, Sequence, Optional, cast
from uuid import UUID, uuid4
from collections import defaultdict
import platform

from chromadb.utils.lru_cache import LRUCache
from chromadb.utils.directory import get_directory_size


if platform.system() != "Windows":
    import resource
elif platform.system() == "Windows":
    import ctypes


SEGMENT_TYPE_IMPLS = {
    SegmentType.SQLITE: "chromadb.segment.impl.metadata.sqlite.SqliteMetadataSegment",
    SegmentType.HNSW_LOCAL_MEMORY: "chromadb.segment.impl.vector.local_hnsw.LocalHnswSegment",
    SegmentType.HNSW_LOCAL_PERSISTED: "chromadb.segment.impl.vector.local_persistent_hnsw.PersistentLocalHnswSegment",
}

class LocalSegmentManager(SegmentManager):
    _sysdb: SysDB
    _system: System

@@ -47,9 +50,6 @@ class LocalSegmentManager(SegmentManager):
    _vector_instances_file_handle_cache: LRUCache[
        UUID, PersistentLocalHnswSegment
    ]  # LRU cache to manage file handles across vector segment instances
    _segment_cache: Dict[
        UUID, Dict[SegmentScope, Segment]
    ]  # Tracks which segments are loaded for a given collection
    _vector_segment_type: SegmentType = SegmentType.HNSW_LOCAL_MEMORY
    _lock: Lock
    _max_file_handles: int
@@ -59,8 +59,17 @@ def __init__(self, system: System):
        self._sysdb = self.require(SysDB)
        self._system = system
        self._opentelemetry_client = system.require(OpenTelemetryClient)
        self.logger = logging.getLogger(__name__)
        self._instances = {}
        self._segment_cache = defaultdict(dict)
        self.segment_cache: Dict[SegmentScope, SegmentCache] = {SegmentScope.METADATA: BasicCache()}
        if system.settings.chroma_segment_cache_policy == "LRU" and system.settings.chroma_memory_limit_bytes > 0:
            self.segment_cache[SegmentScope.VECTOR] = SegmentLRUCache(
                capacity=system.settings.chroma_memory_limit_bytes,
                callback=lambda k, v: self.callback_cache_evict(v),
                size_func=lambda k: self._get_segment_disk_size(k),
            )
        else:
            self.segment_cache[SegmentScope.VECTOR] = BasicCache()

        self._lock = Lock()

        # TODO: prototyping with distributed segment for now, but this should be a configurable option
@@ -72,13 +81,21 @@ def __init__(self, system: System):
        else:
            self._max_file_handles = ctypes.windll.msvcrt._getmaxstdio()  # type: ignore
        segment_limit = (
            self._max_file_handles
            // PersistentLocalHnswSegment.get_file_handle_count()
        )
        self._vector_instances_file_handle_cache = LRUCache(
            segment_limit, callback=lambda _, v: v.close_persistent_index()
        )

    def callback_cache_evict(self, segment: Segment):
        collection_id = segment["collection"]
        self.logger.info(f"LRU cache evict collection {collection_id}")
        instance = self._instance(segment)
        instance.stop()
        del self._instances[segment["id"]]
    @override
    def start(self) -> None:
        for instance in self._instances.values():

@@ -97,7 +114,7 @@ def reset_state(self) -> None:
            instance.stop()
            instance.reset_state()
        self._instances = {}
        self._segment_cache = defaultdict(dict)
        self.segment_cache[SegmentScope.VECTOR].reset()
        super().reset_state()

    @trace_method(
@@ -130,16 +147,31 @@ def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
                instance = self.get_segment(collection_id, MetadataReader)
                instance.delete()
                del self._instances[segment["id"]]
            if collection_id in self._segment_cache:
                if segment["scope"] in self._segment_cache[collection_id]:
                    del self._segment_cache[collection_id][segment["scope"]]
                del self._segment_cache[collection_id]
            if segment["scope"] is SegmentScope.VECTOR:
                self.segment_cache[SegmentScope.VECTOR].pop(collection_id)
            if segment["scope"] is SegmentScope.METADATA:
                self.segment_cache[SegmentScope.METADATA].pop(collection_id)
        return [s["id"] for s in segments]

    @trace_method(
        "LocalSegmentManager.get_segment",
        OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
    )
    def _get_segment_disk_size(self, collection_id: UUID) -> int:
        segments = self._sysdb.get_segments(collection=collection_id, scope=SegmentScope.VECTOR)
        if len(segments) == 0:
            return 0
        # With the local segment manager (single-server Chroma), a collection always has one vector segment.
        size = get_directory_size(
            os.path.join(self._system.settings.require("persist_directory"), str(segments[0]["id"])))
Collaborator: nit: leave a comment stating the assumption of one vector segment for this line - otherwise the hardcoded [0] is confusing.
        return size
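The size function above leans on get_directory_size from chromadb.utils.directory to measure a vector segment's footprint on disk. A minimal sketch of what such a helper could look like (an assumption for illustration; the actual utility in the repository may differ):

```python
import os


def get_directory_size(directory: str) -> int:
    """Return the total size in bytes of all files under `directory`, recursively."""
    total_size = 0
    for dirpath, _dirnames, filenames in os.walk(directory):
        for name in filenames:
            file_path = os.path.join(dirpath, name)
            if os.path.isfile(file_path):
                total_size += os.path.getsize(file_path)
    return total_size
```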
    def _get_segment_sysdb(self, collection_id: UUID, scope: SegmentScope):
        segments = self._sysdb.get_segments(collection=collection_id, scope=scope)
        known_types = set([k.value for k in SEGMENT_TYPE_IMPLS.keys()])
        # Get the first segment of a known type
        segment = next(filter(lambda s: s["type"] in known_types, segments))
        return segment

    @override
    def get_segment(self, collection_id: UUID, type: Type[S]) -> S:
        if type == MetadataReader:
@@ -149,17 +181,15 @@ def get_segment(self, collection_id: UUID, type: Type[S]) -> S:
        else:
            raise ValueError(f"Invalid segment type: {type}")

        if scope not in self._segment_cache[collection_id]:
            segments = self._sysdb.get_segments(collection=collection_id, scope=scope)
            known_types = set([k.value for k in SEGMENT_TYPE_IMPLS.keys()])
            # Get the first segment of a known type
            segment = next(filter(lambda s: s["type"] in known_types, segments))
            self._segment_cache[collection_id][scope] = segment
        segment = self.segment_cache[scope].get(collection_id)
        if segment is None:
            segment = self._get_segment_sysdb(collection_id, scope)
            self.segment_cache[scope].set(collection_id, segment)

        # Instances must be atomically created, so we use a lock to ensure that only one thread
        # creates the instance.
        with self._lock:
            instance = self._instance(self._segment_cache[collection_id][scope])
            instance = self._instance(segment)
        return cast(S, instance)

    @trace_method(
@@ -208,5 +238,5 @@ def _segment(type: SegmentType, scope: SegmentScope, collection: Collection) ->
        scope=scope,
        topic=collection["topic"],
        collection=collection["id"],
        metadata=metadata,
        metadata=metadata
    )
I would update the documentation.