Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion chromadb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@
DEFAULT_TENANT = "default_tenant"
DEFAULT_DATABASE = "default_database"


class Settings(BaseSettings): # type: ignore
environment: str = ""

Expand Down Expand Up @@ -116,6 +115,9 @@ class Settings(BaseSettings): # type: ignore
is_persistent: bool = False
persist_directory: str = "./chroma"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would update the documentation.

chroma_memory_limit_bytes: int = 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do I turn this capability on and off? Is 0 implicitly off?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. 0 is unlimited

chroma_segment_cache_policy: Optional[str] = None

chroma_server_host: Optional[str] = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we introduce a config - called segment_manager_cache_policy and make this one of many types?

chroma_server_headers: Optional[Dict[str, str]] = None
chroma_server_http_port: Optional[str] = None
Expand Down Expand Up @@ -313,6 +315,15 @@ def __init__(self, settings: Settings):
if settings[key] is not None:
raise ValueError(LEGACY_ERROR)

if settings["chroma_segment_cache_policy"] is not None and settings["chroma_segment_cache_policy"] != "LRU":
logger.error(
f"Failed to set chroma_segment_cache_policy: Only LRU is available."
)
if settings["chroma_memory_limit_bytes"] == 0:
logger.error(
f"Failed to set chroma_segment_cache_policy: chroma_memory_limit_bytes is require."
)

# Apply the nofile limit if set
if settings["chroma_server_nofile"] is not None:
if platform.system() != "Windows":
Expand Down
Empty file.
Empty file.
Empty file.
104 changes: 104 additions & 0 deletions chromadb/segment/impl/manager/cache/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import uuid
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional

from chromadb.types import Segment
from overrides import override

class SegmentCache(ABC):
    """Interface for caches that map a collection id to its loaded Segment."""

    @abstractmethod
    def get(self, key: uuid.UUID) -> Optional[Segment]:
        """Return the Segment cached under *key*, or None when absent."""

    @abstractmethod
    def pop(self, key: uuid.UUID) -> Optional[Segment]:
        """Remove and return the Segment cached under *key*, or None when absent."""

    @abstractmethod
    def set(self, key: uuid.UUID, value: Segment) -> None:
        """Store *value* under *key*."""

    @abstractmethod
    def reset(self) -> None:
        """Discard every cached entry."""


class BasicCache(SegmentCache):
    """Unbounded dict-backed SegmentCache with no eviction policy."""

    def __init__(self) -> None:
        # Plain dict; entries live until pop() or reset().
        self.cache: Dict[uuid.UUID, Segment] = {}

    @override
    def get(self, key: uuid.UUID) -> Optional[Segment]:
        """Return the entry for *key*, or None when absent."""
        return self.cache.get(key, None)

    @override
    def pop(self, key: uuid.UUID) -> Optional[Segment]:
        """Remove and return the entry for *key*, or None when absent."""
        return self.cache.pop(key, None)

    @override
    def set(self, key: uuid.UUID, value: Segment) -> None:
        """Insert or overwrite the entry for *key*."""
        self.cache[key] = value

    @override
    def reset(self) -> None:
        """Discard all entries."""
        self.cache = {}


class SegmentLRUCache(BasicCache):
    """An LRU cache for Segments whose entries have dynamic sizes.

    The size of each entry is obtained from a user-provided size function,
    and least-recently-used entries are evicted whenever the combined size
    would exceed *capacity*.
    """

    def __init__(
        self,
        capacity: int,
        size_func: Callable[[uuid.UUID], int],
        callback: Optional[Callable[[uuid.UUID, Segment], Any]] = None,
    ):
        """
        Args:
            capacity: Maximum total size (as measured by *size_func*) the
                cache may hold before evicting.
            size_func: Returns the current size of the entry stored under a
                given key. Re-invoked on every insert, so sizes may change
                over the entry's lifetime.
            callback: Optional hook invoked as ``callback(key, segment)``
                just before an entry is evicted.
        """
        self.capacity = capacity
        self.size_func = size_func
        self.cache: Dict[uuid.UUID, Segment] = {}
        # Usage order, least recently used first.
        self.history: List[uuid.UUID] = []
        self.callback = callback

    def _upsert_key(self, key: uuid.UUID) -> None:
        # Move *key* to the most-recently-used end of the history.
        if key in self.history:
            self.history.remove(key)
        self.history.append(key)

    @override
    def get(self, key: uuid.UUID) -> Optional[Segment]:
        if key in self.cache:
            # Record usage only on hits: recording misses would let keys
            # that were never cached pollute the eviction order.
            self._upsert_key(key)
            return self.cache[key]
        return None

    @override
    def pop(self, key: uuid.UUID) -> Optional[Segment]:
        if key in self.history:
            self.history.remove(key)
        return self.cache.pop(key, None)

    @override
    def set(self, key: uuid.UUID, value: Segment) -> None:
        if key in self.cache:
            return
        item_size = self.size_func(key)
        # Re-measure every resident entry: sizes are dynamic (e.g. on-disk
        # segment directories can grow between inserts).
        key_sizes = {k: self.size_func(k) for k in self.cache}
        total_size = sum(key_sizes.values())
        # Evict least-recently-used entries until the new item fits,
        # dropping each evicted key from the history as well so the
        # history cannot grow without bound.
        while total_size + item_size > self.capacity and self.history:
            key_delete = self.history.pop(0)
            if key_delete in self.cache:
                # callback is optional — guard against None (the original
                # called it unconditionally and crashed on eviction).
                if self.callback is not None:
                    self.callback(key_delete, self.cache[key_delete])
                del self.cache[key_delete]
                total_size -= key_sizes[key_delete]

        self.cache[key] = value
        self._upsert_key(key)

    @override
    def reset(self) -> None:
        self.cache = {}
        self.history = []
74 changes: 52 additions & 22 deletions chromadb/segment/impl/manager/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
VectorReader,
S,
)
import logging
from chromadb.segment.impl.manager.cache.cache import SegmentLRUCache, BasicCache,SegmentCache
import os

from chromadb.config import System, get_class
from chromadb.db.system import SysDB
from overrides import override
Expand All @@ -21,24 +25,23 @@
from chromadb.types import Collection, Operation, Segment, SegmentScope, Metadata
from typing import Dict, Type, Sequence, Optional, cast
from uuid import UUID, uuid4
from collections import defaultdict
import platform

from chromadb.utils.lru_cache import LRUCache
from chromadb.utils.directory import get_directory_size


if platform.system() != "Windows":
import resource
elif platform.system() == "Windows":
import ctypes


SEGMENT_TYPE_IMPLS = {
SegmentType.SQLITE: "chromadb.segment.impl.metadata.sqlite.SqliteMetadataSegment",
SegmentType.HNSW_LOCAL_MEMORY: "chromadb.segment.impl.vector.local_hnsw.LocalHnswSegment",
SegmentType.HNSW_LOCAL_PERSISTED: "chromadb.segment.impl.vector.local_persistent_hnsw.PersistentLocalHnswSegment",
}


class LocalSegmentManager(SegmentManager):
_sysdb: SysDB
_system: System
Expand All @@ -47,9 +50,6 @@ class LocalSegmentManager(SegmentManager):
_vector_instances_file_handle_cache: LRUCache[
UUID, PersistentLocalHnswSegment
] # LRU cache to manage file handles across vector segment instances
_segment_cache: Dict[
UUID, Dict[SegmentScope, Segment]
] # Tracks which segments are loaded for a given collection
_vector_segment_type: SegmentType = SegmentType.HNSW_LOCAL_MEMORY
_lock: Lock
_max_file_handles: int
Expand All @@ -59,8 +59,17 @@ def __init__(self, system: System):
self._sysdb = self.require(SysDB)
self._system = system
self._opentelemetry_client = system.require(OpenTelemetryClient)
self.logger = logging.getLogger(__name__)
self._instances = {}
self._segment_cache = defaultdict(dict)
self.segment_cache: Dict[SegmentScope, SegmentCache] = {SegmentScope.METADATA: BasicCache()}
if system.settings.chroma_segment_cache_policy == "LRU" and system.settings.chroma_memory_limit_bytes > 0:
self.segment_cache[SegmentScope.VECTOR] = SegmentLRUCache(capacity=system.settings.chroma_memory_limit_bytes,callback=lambda k, v: self.callback_cache_evict(v), size_func=lambda k: self._get_segment_disk_size(k))
else:
self.segment_cache[SegmentScope.VECTOR] = BasicCache()




self._lock = Lock()

# TODO: prototyping with distributed segment for now, but this should be a configurable option
Expand All @@ -72,13 +81,21 @@ def __init__(self, system: System):
else:
self._max_file_handles = ctypes.windll.msvcrt._getmaxstdio() # type: ignore
segment_limit = (
self._max_file_handles
// PersistentLocalHnswSegment.get_file_handle_count()
self._max_file_handles
// PersistentLocalHnswSegment.get_file_handle_count()
)
self._vector_instances_file_handle_cache = LRUCache(
segment_limit, callback=lambda _, v: v.close_persistent_index()
)

def callback_cache_evict(self, segment: Segment) -> None:
    # Eviction hook wired into the vector SegmentLRUCache in __init__:
    # stop the evicted segment's instance and forget it so it can be
    # re-created on the next access.
    collection_id = segment["collection"]
    self.logger.info(f"LRU cache evict collection {collection_id}")
    instance = self._instance(segment)
    instance.stop()
    del self._instances[segment["id"]]


@override
def start(self) -> None:
for instance in self._instances.values():
Expand All @@ -97,7 +114,7 @@ def reset_state(self) -> None:
instance.stop()
instance.reset_state()
self._instances = {}
self._segment_cache = defaultdict(dict)
self.segment_cache[SegmentScope.VECTOR].reset()
super().reset_state()

@trace_method(
Expand Down Expand Up @@ -130,16 +147,31 @@ def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
instance = self.get_segment(collection_id, MetadataReader)
instance.delete()
del self._instances[segment["id"]]
if collection_id in self._segment_cache:
if segment["scope"] in self._segment_cache[collection_id]:
del self._segment_cache[collection_id][segment["scope"]]
del self._segment_cache[collection_id]
if segment["scope"] is SegmentScope.VECTOR:
self.segment_cache[SegmentScope.VECTOR].pop(collection_id)
if segment["scope"] is SegmentScope.METADATA:
self.segment_cache[SegmentScope.METADATA].pop(collection_id)
return [s["id"] for s in segments]

@trace_method(
"LocalSegmentManager.get_segment",
OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
)
def _get_segment_disk_size(self, collection_id: UUID) -> int:
    """Return the on-disk size in bytes of the vector segment for *collection_id*.

    Used as the size function of the vector segment LRU cache. Returns 0
    when the collection has no vector segment.
    """
    segments = self._sysdb.get_segments(
        collection=collection_id, scope=SegmentScope.VECTOR
    )
    if len(segments) == 0:
        return 0
    # With the local segment manager (single-server Chroma), a collection
    # always has exactly one vector segment, so segments[0] is it.
    size = get_directory_size(
        os.path.join(
            self._system.settings.require("persist_directory"),
            str(segments[0]["id"]),
        )
    )
    return size

def _get_segment_sysdb(self, collection_id: UUID, scope: SegmentScope):
    """Fetch from the sysdb the first segment of a known type for the given collection and scope."""
    candidates = self._sysdb.get_segments(collection=collection_id, scope=scope)
    known_types = {impl.value for impl in SEGMENT_TYPE_IMPLS.keys()}
    # Keep only segments whose type we know how to instantiate; take the first.
    return next(s for s in candidates if s["type"] in known_types)
@override
def get_segment(self, collection_id: UUID, type: Type[S]) -> S:
if type == MetadataReader:
Expand All @@ -149,17 +181,15 @@ def get_segment(self, collection_id: UUID, type: Type[S]) -> S:
else:
raise ValueError(f"Invalid segment type: {type}")

if scope not in self._segment_cache[collection_id]:
segments = self._sysdb.get_segments(collection=collection_id, scope=scope)
known_types = set([k.value for k in SEGMENT_TYPE_IMPLS.keys()])
# Get the first segment of a known type
segment = next(filter(lambda s: s["type"] in known_types, segments))
self._segment_cache[collection_id][scope] = segment
segment = self.segment_cache[scope].get(collection_id)
if segment is None:
segment = self._get_segment_sysdb(collection_id, scope)
self.segment_cache[scope].set(collection_id, segment)

# Instances must be atomically created, so we use a lock to ensure that only one thread
# creates the instance.
with self._lock:
instance = self._instance(self._segment_cache[collection_id][scope])
instance = self._instance(segment)
return cast(S, instance)

@trace_method(
Expand Down Expand Up @@ -208,5 +238,5 @@ def _segment(type: SegmentType, scope: SegmentScope, collection: Collection) ->
scope=scope,
topic=collection["topic"],
collection=collection["id"],
metadata=metadata,
metadata=metadata
)
2 changes: 1 addition & 1 deletion chromadb/test/db/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,7 @@ def test_update_segment(sysdb: SysDB) -> None:
scope=SegmentScope.VECTOR,
topic="test_topic_a",
collection=sample_collections[0]["id"],
metadata=metadata,
metadata=metadata
)

sysdb.reset_state()
Expand Down
6 changes: 5 additions & 1 deletion chromadb/test/property/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import hypothesis.strategies as st
from typing import Any, Optional, List, Dict, Union, cast
from typing_extensions import TypedDict
import uuid
import numpy as np
import numpy.typing as npt
import chromadb.api.types as types
Expand Down Expand Up @@ -237,16 +238,17 @@ def embedding_function_strategy(
@dataclass
class Collection:
name: str
id: uuid.UUID
metadata: Optional[types.Metadata]
dimension: int
dtype: npt.DTypeLike
topic: str
known_metadata_keys: types.Metadata
known_document_keywords: List[str]
has_documents: bool = False
has_embeddings: bool = False
embedding_function: Optional[types.EmbeddingFunction[Embeddable]] = None


@st.composite
def collections(
draw: st.DrawFn,
Expand Down Expand Up @@ -309,7 +311,9 @@ def collections(
embedding_function = draw(embedding_function_strategy(dimension, dtype))

return Collection(
id=uuid.uuid4(),
name=name,
topic="topic",
metadata=metadata,
dimension=dimension,
dtype=dtype,
Expand Down
Loading