diff --git a/chromadb/__init__.py b/chromadb/__init__.py index 732a17920d5..ba0cc333ef3 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -22,6 +22,8 @@ WhereDocument, UpdateCollectionMetadata, ) +from chromadb.errors import InvalidArgumentError + # Re-export types from chromadb.types __all__ = [ @@ -189,12 +191,12 @@ def HttpClient( settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI" if settings.chroma_server_host and settings.chroma_server_host != host: - raise ValueError( + raise InvalidArgumentError( f"Chroma server host provided in settings[{settings.chroma_server_host}] is different to the one provided in HttpClient: [{host}]" ) settings.chroma_server_host = host if settings.chroma_server_http_port and settings.chroma_server_http_port != port: - raise ValueError( + raise InvalidArgumentError( f"Chroma server http port provided in settings[{settings.chroma_server_http_port}] is different to the one provided in HttpClient: [{port}]" ) settings.chroma_server_http_port = port @@ -240,12 +242,12 @@ async def AsyncHttpClient( settings.chroma_api_impl = "chromadb.api.async_fastapi.AsyncFastAPI" if settings.chroma_server_host and settings.chroma_server_host != host: - raise ValueError( + raise InvalidArgumentError( f"Chroma server host provided in settings[{settings.chroma_server_host}] is different to the one provided in HttpClient: [{host}]" ) settings.chroma_server_host = host if settings.chroma_server_http_port and settings.chroma_server_http_port != port: - raise ValueError( + raise InvalidArgumentError( f"Chroma server http port provided in settings[{settings.chroma_server_http_port}] is different to the one provided in HttpClient: [{port}]" ) settings.chroma_server_http_port = port diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index d443f9ab005..25ea2ae6a20 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -100,7 +100,7 @@ def delete_collection( name: The name of the collection to delete. 
Raises: - ValueError: If the collection does not exist. + InvalidArgumentError: If the collection does not exist. Examples: ```python @@ -389,8 +389,8 @@ def create_collection( Collection: The newly created collection. Raises: - ValueError: If the collection already exists and get_or_create is False. - ValueError: If the collection name is invalid. + InvalidArgumentError: If the collection already exists and get_or_create is False. + InvalidArgumentError: If the collection name is invalid. Examples: ```python @@ -423,7 +423,7 @@ def get_collection( Collection: The collection Raises: - ValueError: If the collection does not exist + InvalidArgumentError: If the collection does not exist Examples: ```python diff --git a/chromadb/api/async_api.py b/chromadb/api/async_api.py index 8396d1e9a97..ea5e061b419 100644 --- a/chromadb/api/async_api.py +++ b/chromadb/api/async_api.py @@ -91,7 +91,7 @@ async def delete_collection( name: The name of the collection to delete. Raises: - ValueError: If the collection does not exist. + InvalidArgumentError: If the collection does not exist. Examples: ```python @@ -380,8 +380,8 @@ async def create_collection( Collection: The newly created collection. Raises: - ValueError: If the collection already exists and get_or_create is False. - ValueError: If the collection name is invalid. + InvalidArgumentError: If the collection already exists and get_or_create is False. + InvalidArgumentError: If the collection name is invalid. 
Examples: ```python @@ -414,7 +414,7 @@ async def get_collection( Collection: The collection Raises: - ValueError: If the collection does not exist + InvalidArgumentError: If the collection does not exist Examples: ```python diff --git a/chromadb/api/async_client.py b/chromadb/api/async_client.py index 56491e5ca81..fe71f4ca579 100644 --- a/chromadb/api/async_client.py +++ b/chromadb/api/async_client.py @@ -24,7 +24,10 @@ URIs, ) from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System -from chromadb.errors import ChromaError +from chromadb.errors import ( + ChromaError, + InvalidArgumentError +) from chromadb.types import Database, Tenant, Where, WhereDocument import chromadb.utils.embedding_functions as ef @@ -125,21 +128,21 @@ async def _validate_tenant_database(self, tenant: str, database: str) -> None: try: await self._admin_client.get_tenant(name=tenant) except httpx.ConnectError: - raise ValueError( + raise InvalidArgumentError( "Could not connect to a Chroma server. Are you sure it is running?" ) # Propagate ChromaErrors except ChromaError as e: raise e except Exception: - raise ValueError( + raise InvalidArgumentError( f"Could not connect to tenant {tenant}. Are you sure it exists?" ) try: await self._admin_client.get_database(name=database, tenant=tenant) except httpx.ConnectError: - raise ValueError( + raise InvalidArgumentError( "Could not connect to a Chroma server. Are you sure it is running?" 
) diff --git a/chromadb/api/base_http_client.py b/chromadb/api/base_http_client.py index 5a1dd25aae6..827884b3813 100644 --- a/chromadb/api/base_http_client.py +++ b/chromadb/api/base_http_client.py @@ -5,6 +5,7 @@ import httpx import chromadb.errors as errors +from chromadb.errors import InvalidArgumentError from chromadb.config import Settings logger = logging.getLogger(__name__) @@ -18,11 +19,11 @@ class BaseHTTPClient: def _validate_host(host: str) -> None: parsed = urlparse(host) if "/" in host and parsed.scheme not in {"http", "https"}: - raise ValueError( + raise InvalidArgumentError( "Invalid URL. " f"Unrecognized protocol - {parsed.scheme}." ) if "/" in host and (not host.startswith("http")): - raise ValueError( + raise InvalidArgumentError( "Invalid URL. " "Seems that you are trying to pass URL as a host but without \ specifying the protocol. " diff --git a/chromadb/api/client.py b/chromadb/api/client.py index de9ca1e7115..bfea8f17fff 100644 --- a/chromadb/api/client.py +++ b/chromadb/api/client.py @@ -26,7 +26,10 @@ from chromadb.config import Settings, System from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE from chromadb.api.models.Collection import Collection -from chromadb.errors import ChromaError +from chromadb.errors import ( + ChromaError, + InvalidArgumentError +) from chromadb.types import Database, Tenant, Where, WhereDocument import chromadb.utils.embedding_functions as ef @@ -100,14 +103,14 @@ def get_user_identity(self) -> UserIdentity: try: return self._server.get_user_identity() except httpx.ConnectError: - raise ValueError( + raise InvalidArgumentError( "Could not connect to a Chroma server. Are you sure it is running?" 
) # Propagate ChromaErrors except ChromaError as e: raise e except Exception as e: - raise ValueError(str(e)) + raise InvalidArgumentError(str(e)) # region BaseAPI Methods # Note - we could do this in less verbose ways, but they break type checking @@ -416,21 +419,21 @@ def _validate_tenant_database(self, tenant: str, database: str) -> None: try: self._admin_client.get_tenant(name=tenant) except httpx.ConnectError: - raise ValueError( + raise InvalidArgumentError( "Could not connect to a Chroma server. Are you sure it is running?" ) # Propagate ChromaErrors except ChromaError as e: raise e except Exception: - raise ValueError( + raise InvalidArgumentError( f"Could not connect to tenant {tenant}. Are you sure it exists?" ) try: self._admin_client.get_database(name=database, tenant=tenant) except httpx.ConnectError: - raise ValueError( + raise InvalidArgumentError( "Could not connect to a Chroma server. Are you sure it is running?" ) diff --git a/chromadb/api/configuration.py b/chromadb/api/configuration.py index 7a8e1b04896..50b0a4b2fc3 100644 --- a/chromadb/api/configuration.py +++ b/chromadb/api/configuration.py @@ -1,6 +1,7 @@ from abc import abstractmethod import json from overrides import override +from chromadb.errors import InvalidArgumentError from typing import ( Any, ClassVar, @@ -26,7 +27,7 @@ class StaticParameterError(Exception): pass -class InvalidConfigurationError(ValueError): +class InvalidConfigurationError(InvalidArgumentError): """Represents an error that occurs when a configuration is invalid.""" pass @@ -102,23 +103,23 @@ def __init__(self, parameters: Optional[List[ConfigurationParameter]] = None): if parameters is not None: for parameter in parameters: if parameter.name not in self.definitions: - raise ValueError(f"Invalid parameter name: {parameter.name}") + raise InvalidArgumentError(f"Invalid parameter name: {parameter.name}") definition = self.definitions[parameter.name] # Handle the case where we have a recursive configuration definition 
if isinstance(parameter.value, dict): child_type = globals().get(parameter.value.get("_type", None)) if child_type is None: - raise ValueError( + raise InvalidArgumentError( f"Invalid configuration type: {parameter.value}" ) parameter.value = child_type.from_json(parameter.value) if not isinstance(parameter.value, type(definition.default_value)): - raise ValueError(f"Invalid parameter value: {parameter.value}") + raise InvalidArgumentError(f"Invalid parameter value: {parameter.value}") parameter_validator = definition.validator if not parameter_validator(parameter.value): - raise ValueError(f"Invalid parameter value: {parameter.value}") + raise InvalidArgumentError(f"Invalid parameter value: {parameter.value}") self.parameter_map[parameter.name] = parameter # Apply the defaults for any missing parameters for name, definition in self.definitions.items(): @@ -152,7 +153,7 @@ def get_parameters(self) -> List[ConfigurationParameter]: def get_parameter(self, name: str) -> ConfigurationParameter: """Returns the parameter with the given name, or except if it doesn't exist.""" if name not in self.parameter_map: - raise ValueError( + raise InvalidArgumentError( f"Invalid parameter name: {name} for configuration {self.__class__.__name__}" ) param_value = cast(ConfigurationParameter, self.parameter_map.get(name)) @@ -161,13 +162,13 @@ def get_parameter(self, name: str) -> ConfigurationParameter: def set_parameter(self, name: str, value: Union[str, int, float, bool]) -> None: """Sets the parameter with the given name to the given value.""" if name not in self.definitions: - raise ValueError(f"Invalid parameter name: {name}") + raise InvalidArgumentError(f"Invalid parameter name: {name}") definition = self.definitions[name] parameter = self.parameter_map[name] if definition.is_static: raise StaticParameterError(f"Cannot set static parameter: {name}") if not definition.validator(value): - raise ValueError(f"Invalid value for parameter {name}: {value}") + raise 
InvalidArgumentError(f"Invalid value for parameter {name}: {value}") parameter.value = value @override @@ -182,7 +183,7 @@ def from_json_str(cls, json_str: str) -> Self: try: config_json = json.loads(json_str) except json.JSONDecodeError: - raise ValueError( + raise InvalidArgumentError( f"Unable to decode configuration from JSON string: {json_str}" ) return cls.from_json(config_json) @@ -205,7 +206,7 @@ def to_json(self) -> Dict[str, Any]: def from_json(cls, json_map: Dict[str, Any]) -> Self: """Returns a configuration from the given JSON string.""" if cls.__name__ != json_map.get("_type", None): - raise ValueError( + raise InvalidArgumentError( f"Trying to instantiate configuration of type {cls.__name__} from JSON with type {json_map['_type']}" ) parameters = [] @@ -308,7 +309,7 @@ def from_legacy_params(cls, params: Dict[str, Any]) -> Self: parameters = [] for name, value in params.items(): if name not in old_to_new: - raise ValueError(f"Invalid legacy HNSW parameter name: {name}") + raise InvalidArgumentError(f"Invalid legacy HNSW parameter name: {name}") parameters.append( ConfigurationParameter(name=old_to_new[name], value=value) ) diff --git a/chromadb/api/models/AsyncCollection.py b/chromadb/api/models/AsyncCollection.py index bbae1ce46e9..4237f129c21 100644 --- a/chromadb/api/models/AsyncCollection.py +++ b/chromadb/api/models/AsyncCollection.py @@ -53,11 +53,11 @@ async def add( None Raises: - ValueError: If you don't provide either embeddings or documents - ValueError: If the length of ids, embeddings, metadatas, or documents don't match - ValueError: If you don't provide an embedding function and don't provide embeddings - ValueError: If you provide both embeddings and documents - ValueError: If you provide an id that already exists + InvalidArgumentError: If you don't provide either embeddings or documents + InvalidArgumentError: If the length of ids, embeddings, metadatas, or documents don't match + InvalidArgumentError: If you don't provide an 
embedding function and don't provide embeddings + InvalidArgumentError: If you provide both embeddings and documents + InvalidArgumentError: If you provide an id that already exists """ add_request = self._validate_and_prepare_add_request( @@ -194,10 +194,10 @@ async def query( QueryResult: A QueryResult object containing the results. Raises: - ValueError: If you don't provide either query_embeddings, query_texts, or query_images - ValueError: If you provide both query_embeddings and query_texts - ValueError: If you provide both query_embeddings and query_images - ValueError: If you provide both query_texts and query_images + InvalidArgumentError: If you don't provide either query_embeddings, query_texts, or query_images + InvalidArgumentError: If you provide both query_embeddings and query_texts + InvalidArgumentError: If you provide both query_embeddings and query_images + InvalidArgumentError: If you provide both query_texts and query_images """ @@ -356,7 +356,7 @@ async def delete( None Raises: - ValueError: If you don't provide either ids, where, or where_document + InvalidArgumentError: If you don't provide either ids, where, or where_document """ delete_request = self._validate_and_prepare_delete_request( ids, where, where_document diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index 4194f20e2ba..8a6052b8062 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -70,11 +70,11 @@ def add( None Raises: - ValueError: If you don't provide either embeddings or documents - ValueError: If the length of ids, embeddings, metadatas, or documents don't match - ValueError: If you don't provide an embedding function and don't provide embeddings - ValueError: If you provide both embeddings and documents - ValueError: If you provide an id that already exists + InvalidArgumentError: If you don't provide either embeddings or documents + InvalidArgumentError: If the length of ids, embeddings, metadatas, or 
documents don't match + InvalidArgumentError: If you don't provide an embedding function and don't provide embeddings + InvalidArgumentError: If you provide both embeddings and documents + InvalidArgumentError: If you provide an id that already exists """ @@ -200,10 +200,10 @@ def query( QueryResult: A QueryResult object containing the results. Raises: - ValueError: If you don't provide either query_embeddings, query_texts, or query_images - ValueError: If you provide both query_embeddings and query_texts - ValueError: If you provide both query_embeddings and query_images - ValueError: If you provide both query_texts and query_images + InvalidArgumentError: If you don't provide either query_embeddings, query_texts, or query_images + InvalidArgumentError: If you provide both query_embeddings and query_texts + InvalidArgumentError: If you provide both query_embeddings and query_images + InvalidArgumentError: If you provide both query_texts and query_images """ @@ -366,7 +366,7 @@ def delete( None Raises: - ValueError: If you don't provide either ids, where, or where_document + InvalidArgumentError: If you don't provide either ids, where, or where_document """ delete_request = self._validate_and_prepare_delete_request( ids, where, where_document diff --git a/chromadb/api/models/CollectionCommon.py b/chromadb/api/models/CollectionCommon.py index d2b3cd3789e..a8246864736 100644 --- a/chromadb/api/models/CollectionCommon.py +++ b/chromadb/api/models/CollectionCommon.py @@ -65,6 +65,7 @@ # which are essentially API views. And the actual data models which are # stored / retrieved / transmitted. 
from chromadb.types import Collection as CollectionModel, Where, WhereDocument +from chromadb.errors import InvalidArgumentError import logging logger = logging.getLogger(__name__) @@ -242,7 +243,7 @@ def _validate_and_prepare_get_request( validate_include(include=include, dissalowed=[IncludeEnum.distances]) if IncludeEnum.data in include and self._data_loader is None: - raise ValueError( + raise InvalidArgumentError( "You must set a data loader on the collection if loading from URIs." ) @@ -423,7 +424,7 @@ def _validate_and_prepare_delete_request( where_document: Optional[WhereDocument], ) -> DeleteRequest: if ids is None and where is None and where_document is None: - raise ValueError( + raise InvalidArgumentError( "At least one of ids, where, or where_document must be provided" ) @@ -493,7 +494,7 @@ def _validate_modify_request(self, metadata: Optional[CollectionMetadata]) -> No if metadata is not None: validate_metadata(metadata) if "hnsw:space" in metadata: - raise ValueError( + raise InvalidArgumentError( "Changing the distance function of a collection once it is created is not supported currently." ) @@ -516,7 +517,7 @@ def _embed_record_set( # uris require special handling if field == "uris": if self._data_loader is None: - raise ValueError( + raise InvalidArgumentError( "You must set a data loader on the collection if loading from URIs." ) return self._embed( @@ -524,7 +525,7 @@ def _embed_record_set( ) else: return self._embed(input=record_set[field]) # type: ignore[literal-required] - raise ValueError( + raise InvalidArgumentError( "Record does not contain any non-None fields that can be embedded." f"Embeddable Fields: {embeddable_fields}" f"Record Fields: {record_set}" @@ -532,7 +533,7 @@ def _embed_record_set( def _embed(self, input: Any) -> Embeddings: if self._embedding_function is None: - raise ValueError( + raise InvalidArgumentError( "You must provide an embedding function to compute embeddings." 
"https://docs.trychroma.com/guides/embeddings" ) diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 045f16507f6..2248f4ddb89 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -22,6 +22,7 @@ from chromadb import __version__ from chromadb.errors import ( InvalidDimensionException, + InvalidArgumentError, InvalidCollectionException, VersionMismatchError, ) @@ -87,13 +88,13 @@ def check_index_name(index_name: str) -> None: f"got {index_name}" ) if len(index_name) < 3 or len(index_name) > 63: - raise ValueError(msg) + raise InvalidArgumentError(msg) if not re.match("^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$", index_name): - raise ValueError(msg) + raise InvalidArgumentError(msg) if ".." in index_name: - raise ValueError(msg) + raise InvalidArgumentError(msg) if re.match("^[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}$", index_name): - raise ValueError(msg) + raise InvalidArgumentError(msg) def rate_limit(func: T) -> T: @@ -138,7 +139,7 @@ def heartbeat(self) -> int: @override def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None: if len(name) < 3: - raise ValueError("Database name must be at least 3 characters long") + raise InvalidArgumentError("Database name must be at least 3 characters long") self._quota_enforcer.enforce( action=Action.CREATE_DATABASE, @@ -161,7 +162,7 @@ def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> t.Database: @override def create_tenant(self, name: str) -> None: if len(name) < 3: - raise ValueError("Tenant name must be at least 3 characters long") + raise InvalidArgumentError("Tenant name must be at least 3 characters long") self._sysdb.create_tenant( name=name, @@ -389,7 +390,7 @@ def delete_collection( ) self._manager.delete_segments(existing[0].id) else: - raise ValueError(f"Collection {name} does not exist.") + raise InvalidArgumentError(f"Collection {name} does not exist.") @trace_method("SegmentAPI._add", OpenTelemetryGranularity.OPERATION) @override @@ -667,7 
+668,7 @@ def _delete( or (where_document is not None and len(where_document) == 0) ) ): - raise ValueError( + raise InvalidArgumentError( """ You must provide either ids, where, or where_document to delete. If you want to delete all data in a collection you can delete the diff --git a/chromadb/api/shared_system_client.py b/chromadb/api/shared_system_client.py index 9e2f7f90052..3be8728e384 100644 --- a/chromadb/api/shared_system_client.py +++ b/chromadb/api/shared_system_client.py @@ -3,6 +3,7 @@ from chromadb.api import ServerAPI from chromadb.config import Settings, System +from chromadb.errors import InvalidArgumentError from chromadb.telemetry.product import ProductTelemetryClient from chromadb.telemetry.product.events import ClientStartEvent @@ -35,7 +36,7 @@ def _create_system_if_not_exists( # For now, the settings must match if previous_system.settings != settings: - raise ValueError( + raise InvalidArgumentError( f"An instance of Chroma already exists for {identifier} with different settings" ) @@ -47,7 +48,7 @@ def _get_identifier_from_settings(settings: Settings) -> str: api_impl = settings.chroma_api_impl if api_impl is None: - raise ValueError("Chroma API implementation must be set in settings") + raise InvalidArgumentError("Chroma API implementation must be set in settings") elif api_impl == "chromadb.api.segment.SegmentAPI": if settings.is_persistent: identifier = settings.persist_directory @@ -62,7 +63,7 @@ def _get_identifier_from_settings(settings: Settings) -> str: # FastAPI clients can all use unique system identifiers since their configurations can be independent, e.g. 
different auth tokens identifier = str(uuid.uuid4()) else: - raise ValueError(f"Unsupported Chroma API implementation {api_impl}") + raise InvalidArgumentError(f"Unsupported Chroma API implementation {api_impl}") return identifier diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 7ea2a130aee..76b020bc610 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -5,6 +5,7 @@ from enum import Enum from pydantic import Field import chromadb.errors as errors +from chromadb.errors import InvalidArgumentError from chromadb.types import ( Metadata, UpdateMetadata, @@ -79,7 +80,7 @@ def normalize_embeddings( if target.ndim == 2: return list(target) - raise ValueError( + raise InvalidArgumentError( f"Expected embeddings to be a list of floats or ints, a list of lists, a numpy array, or a list of numpy arrays, got {target}" ) @@ -228,7 +229,7 @@ def _validate_record_set_length_consistency(record_set: BaseRecordSet) -> None: lengths = [len(lst) for lst in record_set.values() if lst is not None] # type: ignore[arg-type] if not lengths: - raise ValueError( + raise InvalidArgumentError( f"At least one of one of {', '.join(record_set.keys())} must be provided" ) @@ -237,13 +238,13 @@ def _validate_record_set_length_consistency(record_set: BaseRecordSet) -> None: ] if zero_lengths: - raise ValueError(f"Non-empty lists are required for {zero_lengths}") + raise InvalidArgumentError(f"Non-empty lists are required for {zero_lengths}") if len(set(lengths)) > 1: error_str = ", ".join( f"{key}: {len(lst)}" for key, lst in record_set.items() if lst is not None # type: ignore[arg-type] ) - raise ValueError(f"Unequal lengths for fields: {error_str}") + raise InvalidArgumentError(f"Unequal lengths for fields: {error_str}") def validate_record_set_for_embedding( @@ -253,7 +254,7 @@ def validate_record_set_for_embedding( Validates that the Record is ready to be embedded, i.e. that it contains exactly one of the embeddable fields. 
""" if record_set["embeddings"] is not None: - raise ValueError("Attempting to embed a record that already has embeddings.") + raise InvalidArgumentError("Attempting to embed a record that already has embeddings.") if embeddable_fields is None: embeddable_fields = get_default_embeddable_record_set_fields() validate_record_set_contains_one(record_set, embeddable_fields) @@ -268,7 +269,7 @@ def validate_record_set_contains_any( _validate_record_set_contains(record_set, contains_any) if not any(record_set[field] is not None for field in contains_any): # type: ignore[literal-required] - raise ValueError(f"At least one of {', '.join(contains_any)} must be provided") + raise InvalidArgumentError(f"At least one of {', '.join(contains_any)} must be provided") def validate_record_set_contains_one( @@ -279,7 +280,7 @@ def validate_record_set_contains_one( """ _validate_record_set_contains(record_set, contains_one) if sum(record_set[field] is not None for field in contains_one) != 1: # type: ignore[literal-required] - raise ValueError(f"Exactly one of {', '.join(contains_one)} must be provided") + raise InvalidArgumentError(f"Exactly one of {', '.join(contains_one)} must be provided") def _validate_record_set_contains( @@ -289,7 +290,7 @@ def _validate_record_set_contains( Validates that all fields in contains are valid fields of the Record. 
""" if any(field not in record_set for field in contains): - raise ValueError( + raise InvalidArgumentError( f"Invalid field in contains: {', '.join(contains)}, available fields: {', '.join(record_set.keys())}" ) @@ -478,7 +479,7 @@ def validate_embedding_function( protocol_signature = signature(EmbeddingFunction.__call__).parameters.keys() if not function_signature == protocol_signature: - raise ValueError( + raise InvalidArgumentError( f"Expected EmbeddingFunction.__call__ to have the following signature: {protocol_signature}, got {function_signature}\n" "Please see https://docs.trychroma.com/guides/embeddings for details of the EmbeddingFunction interface.\n" "Please note the recent change to the EmbeddingFunction interface: https://docs.trychroma.com/deployment/migration#migration-to-0.4.16---november-7,-2023 \n" @@ -493,14 +494,14 @@ def __call__(self, uris: URIs) -> L: def validate_ids(ids: IDs) -> IDs: """Validates ids to ensure it is a list of strings""" if not isinstance(ids, list): - raise ValueError(f"Expected IDs to be a list, got {type(ids).__name__} as IDs") + raise InvalidArgumentError(f"Expected IDs to be a list, got {type(ids).__name__} as IDs") if len(ids) == 0: - raise ValueError(f"Expected IDs to be a non-empty list, got {len(ids)} IDs") + raise InvalidArgumentError(f"Expected IDs to be a non-empty list, got {len(ids)} IDs") seen = set() dups = set() for id_ in ids: if not isinstance(id_, str): - raise ValueError(f"Expected ID to be a str, got {id_}") + raise InvalidArgumentError(f"Expected ID to be a str, got {id_}") if id_ in seen: dups.add(id_) else: @@ -529,27 +530,27 @@ def validate_ids(ids: IDs) -> IDs: def validate_metadata(metadata: Metadata) -> Metadata: """Validates metadata to ensure it is a dictionary of strings to strings, ints, floats or bools""" if not isinstance(metadata, dict) and metadata is not None: - raise ValueError( + raise InvalidArgumentError( f"Expected metadata to be a dict or None, got {type(metadata).__name__} as 
metadata" ) if metadata is None: return metadata if len(metadata) == 0: - raise ValueError( + raise InvalidArgumentError( f"Expected metadata to be a non-empty dict, got {len(metadata)} metadata attributes" ) for key, value in metadata.items(): if key == META_KEY_CHROMA_DOCUMENT: - raise ValueError( + raise InvalidArgumentError( f"Expected metadata to not contain the reserved key {META_KEY_CHROMA_DOCUMENT}" ) if not isinstance(key, str): - raise TypeError( + raise InvalidArgumentError( f"Expected metadata key to be a str, got {key} which is a {type(key).__name__}" ) # isinstance(True, int) evaluates to True, so we need to check for bools separately if not isinstance(value, bool) and not isinstance(value, (str, int, float)): - raise ValueError( + raise InvalidArgumentError( f"Expected metadata value to be a str, int, float or bool, got {value} which is a {type(value).__name__}" ) return metadata @@ -558,21 +559,21 @@ def validate_metadata(metadata: Metadata) -> Metadata: def validate_update_metadata(metadata: UpdateMetadata) -> UpdateMetadata: """Validates metadata to ensure it is a dictionary of strings to strings, ints, floats or bools""" if not isinstance(metadata, dict) and metadata is not None: - raise ValueError( + raise InvalidArgumentError( f"Expected metadata to be a dict or None, got {type(metadata)}" ) if metadata is None: return metadata if len(metadata) == 0: - raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}") + raise InvalidArgumentError(f"Expected metadata to be a non-empty dict, got {metadata}") for key, value in metadata.items(): if not isinstance(key, str): - raise ValueError(f"Expected metadata key to be a str, got {key}") + raise InvalidArgumentError(f"Expected metadata key to be a str, got {key}") # isinstance(True, int) evaluates to True, so we need to check for bools separately if not isinstance(value, bool) and not isinstance( value, (str, int, float, type(None)) ): - raise ValueError( + raise InvalidArgumentError( 
f"Expected metadata value to be a str, int, or float, got {value}" ) return metadata @@ -581,7 +582,7 @@ def validate_update_metadata(metadata: UpdateMetadata) -> UpdateMetadata: def validate_metadatas(metadatas: Metadatas) -> Metadatas: """Validates metadatas to ensure it is a list of dictionaries of strings to strings, ints, floats or bools""" if not isinstance(metadatas, list): - raise ValueError(f"Expected metadatas to be a list, got {metadatas}") + raise InvalidArgumentError(f"Expected metadatas to be a list, got {metadatas}") for metadata in metadatas: validate_metadata(metadata) return metadatas @@ -593,12 +594,12 @@ def validate_where(where: Where) -> None: or in the case of $and and $or, a list of where expressions """ if not isinstance(where, dict): - raise ValueError(f"Expected where to be a dict, got {where}") + raise InvalidArgumentError(f"Expected where to be a dict, got {where}") if len(where) != 1: - raise ValueError(f"Expected where to have exactly one operator, got {where}") + raise InvalidArgumentError(f"Expected where to have exactly one operator, got {where}") for key, value in where.items(): if not isinstance(key, str): - raise ValueError(f"Expected where key to be a str, got {key}") + raise InvalidArgumentError(f"Expected where key to be a str, got {key}") if ( key != "$and" and key != "$or" @@ -606,16 +607,16 @@ def validate_where(where: Where) -> None: and key != "$nin" and not isinstance(value, (str, int, float, dict)) ): - raise ValueError( + raise InvalidArgumentError( f"Expected where value to be a str, int, float, or operator expression, got {value}" ) if key == "$and" or key == "$or": if not isinstance(value, list): - raise ValueError( + raise InvalidArgumentError( f"Expected where value for $and or $or to be a list of where expressions, got {value}" ) if len(value) <= 1: - raise ValueError( + raise InvalidArgumentError( f"Expected where value for $and or $or to be a list with at least two where expressions, got {value}" ) for 
where_expression in value: @@ -624,7 +625,7 @@ def validate_where(where: Where) -> None: if isinstance(value, dict): # Ensure there is only one operator if len(value) != 1: - raise ValueError( + raise InvalidArgumentError( f"Expected operator expression to have exactly one operator, got {value}" ) @@ -632,12 +633,12 @@ def validate_where(where: Where) -> None: # Only numbers can be compared with gt, gte, lt, lte if operator in ["$gt", "$gte", "$lt", "$lte"]: if not isinstance(operand, (int, float)): - raise ValueError( + raise InvalidArgumentError( f"Expected operand value to be an int or a float for operator {operator}, got {operand}" ) if operator in ["$in", "$nin"]: if not isinstance(operand, list): - raise ValueError( + raise InvalidArgumentError( f"Expected operand value to be an list for operator {operator}, got {operand}" ) if operator not in [ @@ -650,20 +651,20 @@ def validate_where(where: Where) -> None: "$in", "$nin", ]: - raise ValueError( + raise InvalidArgumentError( f"Expected where operator to be one of $gt, $gte, $lt, $lte, $ne, $eq, $in, $nin, " f"got {operator}" ) if not isinstance(operand, (str, int, float, list)): - raise ValueError( + raise InvalidArgumentError( f"Expected where operand value to be a str, int, float, or list of those type, got {operand}" ) if isinstance(operand, list) and ( len(operand) == 0 or not all(isinstance(x, type(operand[0])) for x in operand) ): - raise ValueError( + raise InvalidArgumentError( f"Expected where operand value to be a non-empty list, and all values to be of the same type " f"got {operand}" ) @@ -675,36 +676,36 @@ def validate_where_document(where_document: WhereDocument) -> None: a list of where_document expressions """ if not isinstance(where_document, dict): - raise ValueError( + raise InvalidArgumentError( f"Expected where document to be a dictionary, got {where_document}" ) if len(where_document) != 1: - raise ValueError( + raise InvalidArgumentError( f"Expected where document to have exactly one 
operator, got {where_document}" ) for operator, operand in where_document.items(): if operator not in ["$contains", "$not_contains", "$and", "$or"]: - raise ValueError( + raise InvalidArgumentError( f"Expected where document operator to be one of $contains, $and, $or, got {operator}" ) if operator == "$and" or operator == "$or": if not isinstance(operand, list): - raise ValueError( + raise InvalidArgumentError( f"Expected document value for $and or $or to be a list of where document expressions, got {operand}" ) if len(operand) <= 1: - raise ValueError( + raise InvalidArgumentError( f"Expected document value for $and or $or to be a list with at least two where document expressions, got {operand}" ) for where_document_expression in operand: validate_where_document(where_document_expression) # Value is a $contains operator elif not isinstance(operand, str): - raise ValueError( + raise InvalidArgumentError( f"Expected where document operand value for operator $contains to be a str, got {operand}" ) elif len(operand) == 0: - raise ValueError( + raise InvalidArgumentError( "Expected where document operand value for operator $contains to be a non-empty str" ) @@ -714,18 +715,18 @@ def validate_include(include: Include, dissalowed: Optional[Include] = None) -> to control if distances is allowed""" if not isinstance(include, list): - raise ValueError(f"Expected include to be a list, got {include}") + raise InvalidArgumentError(f"Expected include to be a list, got {include}") for item in include: if not isinstance(item, str): - raise ValueError(f"Expected include item to be a str, got {item}") + raise InvalidArgumentError(f"Expected include item to be a str, got {item}") if not any(item == e for e in IncludeEnum): - raise ValueError( + raise InvalidArgumentError( f"Expected include item to be one of {', '.join(IncludeEnum)}, got {item}" ) if dissalowed is not None and any(item == e for e in dissalowed): - raise ValueError( + raise InvalidArgumentError( f"Include item cannot 
be one of {', '.join(dissalowed)}, got {item}" ) @@ -734,11 +735,11 @@ def validate_n_results(n_results: int) -> int: """Validates n_results to ensure it is a positive Integer. Since hnswlib does not allow n_results to be negative.""" # Check Number of requested results if not isinstance(n_results, int): - raise ValueError( + raise InvalidArgumentError( f"Expected requested number of results to be a int, got {n_results}" ) if n_results <= 0: - raise TypeError( + raise InvalidArgumentError( f"Number of requested results {n_results}, cannot be negative, or zero." ) return n_results @@ -747,25 +748,25 @@ def validate_n_results(n_results: int) -> int: def validate_embeddings(embeddings: Embeddings) -> Embeddings: """Validates embeddings to ensure it is a list of numpy arrays of ints, or floats""" if not isinstance(embeddings, (list, np.ndarray)): - raise ValueError( + raise InvalidArgumentError( f"Expected embeddings to be a list, got {type(embeddings).__name__}" ) if len(embeddings) == 0: - raise ValueError( + raise InvalidArgumentError( f"Expected embeddings to be a list with at least one item, got {len(embeddings)} embeddings" ) if not all([isinstance(e, np.ndarray) for e in embeddings]): - raise ValueError( + raise InvalidArgumentError( "Expected each embedding in the embeddings to be a numpy array, got " f"{list(set([type(e).__name__ for e in embeddings]))}" ) for i, embedding in enumerate(embeddings): if embedding.ndim == 0: - raise ValueError( + raise InvalidArgumentError( f"Expected a 1-dimensional array, got a 0-dimensional array {embedding}" ) if embedding.size == 0: - raise ValueError( + raise InvalidArgumentError( f"Expected each embedding in the embeddings to be a 1-dimensional numpy array with at least 1 int/float value. 
Got a 1-dimensional numpy array with no values at pos {i}" ) if not all( @@ -775,7 +776,7 @@ def validate_embeddings(embeddings: Embeddings) -> Embeddings: for value in embedding ] ): - raise ValueError( + raise InvalidArgumentError( "Expected each value in the embedding to be a int or float, got an embedding with " f"{list(set([type(value).__name__ for value in embedding]))} - {embedding}" ) @@ -785,11 +786,11 @@ def validate_embeddings(embeddings: Embeddings) -> Embeddings: def validate_documents(documents: Documents, nullable: bool = False) -> None: """Validates documents to ensure it is a list of strings""" if not isinstance(documents, list): - raise ValueError( + raise InvalidArgumentError( f"Expected documents to be a list, got {type(documents).__name__}" ) if len(documents) == 0: - raise ValueError( + raise InvalidArgumentError( f"Expected documents to be a non-empty list, got {len(documents)} documents" ) for document in documents: @@ -797,20 +798,20 @@ def validate_documents(documents: Documents, nullable: bool = False) -> None: if document is None and nullable: continue if not is_document(document): - raise ValueError(f"Expected document to be a str, got {document}") + raise InvalidArgumentError(f"Expected document to be a str, got {document}") def validate_images(images: Images) -> None: """Validates images to ensure it is a list of numpy arrays""" if not isinstance(images, list): - raise ValueError(f"Expected images to be a list, got {type(images).__name__}") + raise InvalidArgumentError(f"Expected images to be a list, got {type(images).__name__}") if len(images) == 0: - raise ValueError( + raise InvalidArgumentError( f"Expected images to be a non-empty list, got {len(images)} images" ) for image in images: if not is_image(image): - raise ValueError(f"Expected image to be a numpy array, got {image}") + raise InvalidArgumentError(f"Expected image to be a numpy array, got {image}") def validate_batch( @@ -824,7 +825,7 @@ def validate_batch( limits: 
Dict[str, Any], ) -> None: if len(batch[0]) > limits["max_batch_size"]: - raise ValueError( + raise InvalidArgumentError( f"Batch size {len(batch[0])} exceeds maximum batch size {limits['max_batch_size']}" ) diff --git a/chromadb/auth/__init__.py b/chromadb/auth/__init__.py index ad0f7e3c99f..e0a7dd6d12d 100644 --- a/chromadb/auth/__init__.py +++ b/chromadb/auth/__init__.py @@ -10,6 +10,7 @@ Tuple, TypeVar, ) +from chromadb.errors import InvalidArgumentError from dataclasses import dataclass from pydantic import SecretStr @@ -109,12 +110,12 @@ def read_creds_or_creds_file(self) -> List[str]: if self._system.settings.chroma_server_authn_credentials: _creds = str(self._system.settings["chroma_server_authn_credentials"]) if not _creds_file and not _creds: - raise ValueError( + raise InvalidArgumentError( "No credentials file or credentials found in " "[chroma_server_authn_credentials]." ) if _creds_file and _creds: - raise ValueError( + raise InvalidArgumentError( "Both credentials file and credentials found." "Please provide only one." ) @@ -123,7 +124,7 @@ def read_creds_or_creds_file(self) -> List[str]: elif _creds_file: with open(_creds_file, "r") as f: return f.readlines() - raise ValueError("Should never happen") + raise InvalidArgumentError("Should never happen") def singleton_tenant_database_if_applicable( self, user: Optional[UserIdentity] @@ -219,11 +220,11 @@ def read_config_or_config_file(self) -> List[str]: if self._system.settings.chroma_server_authz_config: _config = str(self._system.settings["chroma_server_authz_config"]) if not _config_file and not _config: - raise ValueError( + raise InvalidArgumentError( "No authz configuration file or authz configuration found." ) if _config_file and _config: - raise ValueError( + raise InvalidArgumentError( "Both authz configuration file and authz configuration found." "Please provide only one." 
) @@ -232,4 +233,4 @@ def read_config_or_config_file(self) -> List[str]: elif _config_file: with open(_config_file, "r") as f: return f.readlines() - raise ValueError("Should never happen") + raise InvalidArgumentError("Should never happen") diff --git a/chromadb/auth/basic_authn/__init__.py b/chromadb/auth/basic_authn/__init__.py index 66af698513b..fa7932b92fb 100644 --- a/chromadb/auth/basic_authn/__init__.py +++ b/chromadb/auth/basic_authn/__init__.py @@ -18,7 +18,10 @@ AuthError, ) from chromadb.config import System -from chromadb.errors import ChromaAuthError +from chromadb.errors import ( + ChromaAuthError, + InvalidArgumentError +) from chromadb.telemetry.opentelemetry import ( OpenTelemetryGranularity, trace_method, @@ -84,14 +87,14 @@ def __init__(self, system: System) -> None: and len(_raw_creds) != 2 or not all(_raw_creds) ): - raise ValueError( + raise InvalidArgumentError( f"Invalid htpasswd credentials found: {_raw_creds}. " "Lines must be exactly :." ) username = _raw_creds[0] password = _raw_creds[1] if username in self._creds: - raise ValueError( + raise InvalidArgumentError( "Duplicate username found in " "[chroma_server_authn_credentials]. " "Usernames must be unique." diff --git a/chromadb/auth/token_authn/__init__.py b/chromadb/auth/token_authn/__init__.py index 00c6ae3a449..4d9fc9b3ddc 100644 --- a/chromadb/auth/token_authn/__init__.py +++ b/chromadb/auth/token_authn/__init__.py @@ -20,7 +20,10 @@ AuthError, ) from chromadb.config import System -from chromadb.errors import ChromaAuthError +from chromadb.errors import ( + ChromaAuthError, + InvalidArgumentError +) from chromadb.telemetry.opentelemetry import ( OpenTelemetryGranularity, trace_method, @@ -56,7 +59,7 @@ class TokenTransportHeader(str, Enum): def _check_token(token: str) -> None: token_str = str(token) if not all(c in valid_token_chars for c in token_str): - raise ValueError( + raise InvalidArgumentError( "Invalid token. Must contain only ASCII letters, digits, and punctuation." 
) @@ -69,7 +72,7 @@ def _check_token(token: str) -> None: def _check_allowed_token_headers(token_header: str) -> None: if token_header not in allowed_token_headers: - raise ValueError( + raise InvalidArgumentError( f"Invalid token transport header: {token_header}. " f"Must be one of {allowed_token_headers}" ) @@ -169,7 +172,7 @@ def __init__(self, system: System) -> None: self._users = cast(List[User], yaml.safe_load("\n".join(creds))["users"]) for user in self._users: if "tokens" not in user: - raise ValueError("User missing tokens") + raise InvalidArgumentError("User missing tokens") if "tenant" not in user: user["tenant"] = "*" if "databases" not in user: @@ -180,7 +183,7 @@ def __init__(self, system: System) -> None: token in self._token_user_mapping and self._token_user_mapping[token] != user ): - raise ValueError( + raise InvalidArgumentError( f"Token {token} already in use: wanted to use it for " f"user {user['id']} but it's already in use by " f"user {self._token_user_mapping[token]}" diff --git a/chromadb/config.py b/chromadb/config.py index acb63b98b96..65f3d1944cd 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -2,6 +2,7 @@ import inspect import logging from abc import ABC +from chromadb.errors import InvalidArgumentError from enum import Enum from graphlib import TopologicalSorter from typing import Optional, List, Any, Dict, Set, Iterable, Union @@ -282,14 +283,14 @@ def require(self, key: str) -> Any: set""" val = self[key] if val is None: - raise ValueError(f"Missing required config value '{key}'") + raise InvalidArgumentError(f"Missing required config value '{key}'") return val def __getitem__(self, key: str) -> Any: val = getattr(self, key) # Error on legacy config values if isinstance(val, str) and val in _legacy_config_values: - raise ValueError(LEGACY_ERROR) + raise InvalidArgumentError(LEGACY_ERROR) return val class Config: @@ -356,7 +357,7 @@ def __init__(self, settings: Settings): # Validate settings don't contain any legacy 
config values for key in _legacy_config_keys: if settings[key] is not None: - raise ValueError(LEGACY_ERROR) + raise InvalidArgumentError(LEGACY_ERROR) if ( settings["chroma_segment_cache_policy"] is not None @@ -414,7 +415,7 @@ def instance(self, type: Type[T]) -> T: if inspect.isabstract(type): type_fqn = get_fqn(type) if type_fqn not in _abstract_type_keys: - raise ValueError(f"Cannot instantiate abstract type: {type}") + raise InvalidArgumentError(f"Cannot instantiate abstract type: {type}") key = _abstract_type_keys[type_fqn] fqn = self.settings.require(key) type = get_class(fqn, type) @@ -453,7 +454,7 @@ def stop(self) -> None: def reset_state(self) -> None: """Reset the state of this system and all constituents in reverse dependency order""" if not self.settings.allow_reset: - raise ValueError( + raise InvalidArgumentError( "Resetting is not allowed by this configuration (to enable it, set `allow_reset` to `True` in your Settings() or include `ALLOW_RESET=TRUE` in your environment variables)" ) for component in reversed(list(self.components())): diff --git a/chromadb/db/base.py b/chromadb/db/base.py index 0abe62208e9..3d1672aeb4e 100644 --- a/chromadb/db/base.py +++ b/chromadb/db/base.py @@ -7,6 +7,7 @@ import pypika import pypika.queries from chromadb.config import System, Component +from chromadb.errors import InvalidArgumentError from uuid import UUID from itertools import islice, count from chromadb.types import SeqId @@ -117,7 +118,7 @@ def decode_seq_id(seq_id_bytes: Union[bytes, int]) -> SeqId: elif len(seq_id_bytes) == 24: return int.from_bytes(seq_id_bytes, "big") else: - raise ValueError(f"Unknown SeqID type with length {len(seq_id_bytes)}") + raise InvalidArgumentError(f"Unknown SeqID type with length {len(seq_id_bytes)}") @staticmethod def encode_seq_id(seq_id: SeqId) -> bytes: @@ -127,7 +128,7 @@ def encode_seq_id(seq_id: SeqId) -> bytes: elif seq_id.bit_length() <= 192: return int.to_bytes(seq_id, 24, "big") else: - raise 
ValueError(f"Unsupported SeqID: {seq_id}") + raise InvalidArgumentError(f"Unsupported SeqID: {seq_id}") _context = local() diff --git a/chromadb/db/impl/grpc/client.py b/chromadb/db/impl/grpc/client.py index b663f873f8a..30984eb446f 100644 --- a/chromadb/db/impl/grpc/client.py +++ b/chromadb/db/impl/grpc/client.py @@ -5,7 +5,12 @@ from chromadb.api.configuration import CollectionConfigurationInternal from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System, logger from chromadb.db.system import SysDB -from chromadb.errors import NotFoundError, UniqueConstraintError, InternalError +from chromadb.errors import ( + NotFoundError, + UniqueConstraintError, + InternalError, + InvalidArgumentError +) from chromadb.proto.convert import ( from_proto_collection, from_proto_segment, @@ -333,7 +338,7 @@ def get_collections( ) if name is not None: if tenant is None and database is None: - raise ValueError( + raise InvalidArgumentError( "If name is specified, tenant and database must also be specified in order to uniquely identify the collection" ) request = GetCollectionsRequest( diff --git a/chromadb/db/impl/sqlite.py b/chromadb/db/impl/sqlite.py index 9cc0ca50802..6cd24f83323 100644 --- a/chromadb/db/impl/sqlite.py +++ b/chromadb/db/impl/sqlite.py @@ -3,6 +3,7 @@ from chromadb.db.migrations import MigratableDB, Migration from chromadb.config import System, Settings import chromadb.db.base as base +from chromadb.errors import InvalidArgumentError from chromadb.db.mixins.embeddings_queue import SqlEmbeddingsQueue from chromadb.db.mixins.sysdb import SqlSysDB from chromadb.telemetry.opentelemetry import ( @@ -147,7 +148,7 @@ def tx(self) -> TxWrapper: @override def reset_state(self) -> None: if not self._settings.require("allow_reset"): - raise ValueError( + raise InvalidArgumentError( "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted." 
) with self.tx() as cur: diff --git a/chromadb/db/mixins/embeddings_queue.py b/chromadb/db/mixins/embeddings_queue.py index d072c852e1a..deb3f4348f3 100644 --- a/chromadb/db/mixins/embeddings_queue.py +++ b/chromadb/db/mixins/embeddings_queue.py @@ -5,7 +5,10 @@ EmbeddingsQueueConfigurationInternal, ) from chromadb.db.base import SqlDB, ParameterValue, get_sql -from chromadb.errors import BatchSizeExceededError +from chromadb.errors import ( + BatchSizeExceededError, + InvalidArgumentError +) from chromadb.ingest import ( Producer, Consumer, @@ -408,9 +411,9 @@ def _validate_range( start = start or self._next_seq_id() end = end or self.max_seqid() if not isinstance(start, int) or not isinstance(end, int): - raise TypeError("SeqIDs must be integers for sql-based EmbeddingsDB") + raise InvalidArgumentError("SeqIDs must be integers for sql-based EmbeddingsDB") if start >= end: - raise ValueError(f"Invalid SeqID range: {start} to {end}") + raise InvalidArgumentError(f"Invalid SeqID range: {start} to {end}") return start, end @trace_method("SqlEmbeddingsQueue._next_seq_id", OpenTelemetryGranularity.ALL) diff --git a/chromadb/db/mixins/sysdb.py b/chromadb/db/mixins/sysdb.py index a10042cefe9..23b47ad80cd 100644 --- a/chromadb/db/mixins/sysdb.py +++ b/chromadb/db/mixins/sysdb.py @@ -10,6 +10,7 @@ CollectionConfigurationInternal, ConfigurationParameter, HNSWConfigurationInternal, + InvalidArgumentError, InvalidConfigurationError, ) from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System @@ -218,7 +219,7 @@ def create_collection( database: str = DEFAULT_DATABASE, ) -> Tuple[Collection, bool]: if id is None and not get_or_create: - raise ValueError("id must be specified if get_or_create is False") + raise InvalidArgumentError("id must be specified if get_or_create is False") add_attributes_to_current_span( { @@ -386,7 +387,7 @@ def get_collections( """Get collections by name, embedding function and/or metadata""" if name is not None and (tenant is None or 
database is None): - raise ValueError( + raise InvalidArgumentError( "If name is specified, tenant and database must also be specified in order to uniquely identify the collection" ) @@ -799,7 +800,7 @@ def _load_config_from_json_str_and_migrate( try: config_json = json.loads(json_str) except json.JSONDecodeError: - raise ValueError( + raise InvalidArgumentError( f"Unable to decode configuration from JSON string: {json_str}" ) diff --git a/chromadb/ingest/__init__.py b/chromadb/ingest/__init__.py index f40ba2f50c9..a47fa3c4b74 100644 --- a/chromadb/ingest/__init__.py +++ b/chromadb/ingest/__init__.py @@ -7,6 +7,7 @@ Vector, ScalarEncoding, ) +from chromadb.errors import InvalidArgumentError from chromadb.config import Component from uuid import UUID import numpy as np @@ -20,7 +21,7 @@ def encode_vector(vector: Vector, encoding: ScalarEncoding) -> bytes: elif encoding == ScalarEncoding.INT32: return np.array(vector, dtype=np.int32).tobytes() else: - raise ValueError(f"Unsupported encoding: {encoding.value}") + raise InvalidArgumentError(f"Unsupported encoding: {encoding.value}") def decode_vector(vector: bytes, encoding: ScalarEncoding) -> Vector: @@ -31,7 +32,7 @@ def decode_vector(vector: bytes, encoding: ScalarEncoding) -> Vector: elif encoding == ScalarEncoding.INT32: return np.frombuffer(vector, dtype=np.float32) else: - raise ValueError(f"Unsupported encoding: {encoding.value}") + raise InvalidArgumentError(f"Unsupported encoding: {encoding.value}") class Producer(Component): diff --git a/chromadb/ingest/impl/utils.py b/chromadb/ingest/impl/utils.py index 4ad92df6bc3..c74a8f5b108 100644 --- a/chromadb/ingest/impl/utils.py +++ b/chromadb/ingest/impl/utils.py @@ -3,6 +3,7 @@ from uuid import UUID from chromadb.db.base import SqlDB +from chromadb.errors import InvalidArgumentError from chromadb.segment import SegmentManager, VectorReader topic_regex = r"persistent:\/\/(?P.+)\/(?P.+)\/(?P.+)" @@ -12,7 +13,7 @@ def parse_topic_name(topic_name: str) -> Tuple[str, 
str, str]: """Parse the topic name into the tenant, namespace and topic name""" match = re.match(topic_regex, topic_name) if not match: - raise ValueError(f"Invalid topic name: {topic_name}") + raise InvalidArgumentError(f"Invalid topic name: {topic_name}") return match.group("tenant"), match.group("namespace"), match.group("topic") diff --git a/chromadb/proto/convert.py b/chromadb/proto/convert.py index 47f5e0e08b9..047024a7024 100644 --- a/chromadb/proto/convert.py +++ b/chromadb/proto/convert.py @@ -16,6 +16,7 @@ SegmentScan, ) from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan +from chromadb.errors import InvalidArgumentError from chromadb.types import ( Collection, LogRecord, @@ -55,7 +56,7 @@ def to_proto_vector(vector: Vector, encoding: ScalarEncoding) -> chroma_pb.Vecto as_bytes = np.array(vector, dtype=np.int32).tobytes() proto_encoding = chroma_pb.ScalarEncoding.INT32 else: - raise ValueError( + raise InvalidArgumentError( f"Unknown encoding {encoding}, expected one of {ScalarEncoding.FLOAT32} \ or {ScalarEncoding.INT32}" ) @@ -73,7 +74,7 @@ def from_proto_vector(vector: chroma_pb.Vector) -> Tuple[Embedding, ScalarEncodi as_array = np.frombuffer(vector.vector, dtype=np.int32) out_encoding = ScalarEncoding.INT32 else: - raise ValueError( + raise InvalidArgumentError( f"Unknown encoding {encoding}, expected one of \ {chroma_pb.ScalarEncoding.FLOAT32} or {chroma_pb.ScalarEncoding.INT32}" ) @@ -125,7 +126,7 @@ def _from_proto_metadata_handle_none( elif is_update: out_metadata[key] = None else: - raise ValueError(f"Metadata key {key} value cannot be None") + raise InvalidArgumentError(f"Metadata key {key} value cannot be None") return out_metadata @@ -216,7 +217,7 @@ def to_proto_metadata_update_value( elif value is None: return chroma_pb.UpdateMetadataValue() else: - raise ValueError( + raise InvalidArgumentError( f"Unknown metadata value type {type(value)}, expected one of str, int, \ float, or None" ) @@ -268,7 +269,7 @@ def 
to_proto_operation(operation: Operation) -> chroma_pb.Operation: elif operation == Operation.DELETE: return chroma_pb.Operation.DELETE else: - raise ValueError( + raise InvalidArgumentError( f"Unknown operation {operation}, expected one of {Operation.ADD}, \ {Operation.UPDATE}, {Operation.UPDATE}, or {Operation.DELETE}" ) @@ -343,15 +344,15 @@ def to_proto_request_version_context( def to_proto_where(where: Where) -> chroma_pb.Where: response = chroma_pb.Where() if len(where) != 1: - raise ValueError(f"Expected where to have exactly one operator, got {where}") + raise InvalidArgumentError(f"Expected where to have exactly one operator, got {where}") for key, value in where.items(): if not isinstance(key, str): - raise ValueError(f"Expected where key to be a str, got {key}") + raise InvalidArgumentError(f"Expected where key to be a str, got {key}") if key == "$and" or key == "$or": if not isinstance(value, list): - raise ValueError( + raise InvalidArgumentError( f"Expected where value for $and or $or to be a list of where expressions, got {value}" ) children: chroma_pb.WhereChildren = chroma_pb.WhereChildren( @@ -394,20 +395,20 @@ def to_proto_where(where: Where) -> chroma_pb.Where: sdc.generic_comparator = chroma_pb.GenericComparator.EQ dc.single_double_operand.CopyFrom(sdc) else: - raise ValueError( + raise InvalidArgumentError( f"Expected where value to be a string, int, or float, got {value}" ) else: for operator, operand in value.items(): if operator in ["$in", "$nin"]: if not isinstance(operand, list): - raise ValueError( + raise InvalidArgumentError( f"Expected where value for $in or $nin to be a list of values, got {value}" ) if len(operand) == 0 or not all( isinstance(x, type(operand[0])) for x in operand ): - raise ValueError( + raise InvalidArgumentError( f"Expected where operand value to be a non-empty list, and all values to be of the same type " f"got {operand}" ) @@ -441,7 +442,7 @@ def to_proto_where(where: Where) -> chroma_pb.Where: dlo.list_operator 
= list_operator dc.double_list_operand.CopyFrom(dlo) else: - raise ValueError( + raise InvalidArgumentError( f"Expected where operand value to be a list of strings, ints, or floats, got {operand}" ) elif operator in ["$eq", "$ne", "$gt", "$lt", "$gte", "$lte"]: @@ -454,7 +455,7 @@ def to_proto_where(where: Where) -> chroma_pb.Where: elif operator == "$ne": ssc.comparator = chroma_pb.GenericComparator.NE else: - raise ValueError( + raise InvalidArgumentError( f"Expected where operator to be $eq or $ne, got {operator}" ) dc.single_string_operand.CopyFrom(ssc) @@ -466,7 +467,7 @@ def to_proto_where(where: Where) -> chroma_pb.Where: elif operator == "$ne": sbc.comparator = chroma_pb.GenericComparator.NE else: - raise ValueError( + raise InvalidArgumentError( f"Expected where operator to be $eq or $ne, got {operator}" ) dc.single_bool_operand.CopyFrom(sbc) @@ -486,7 +487,7 @@ def to_proto_where(where: Where) -> chroma_pb.Where: elif operator == "$lte": sic.number_comparator = chroma_pb.NumberComparator.LTE else: - raise ValueError( + raise InvalidArgumentError( f"Expected where operator to be one of $eq, $ne, $gt, $lt, $gte, $lte, got {operator}" ) dc.single_int_operand.CopyFrom(sic) @@ -506,12 +507,12 @@ def to_proto_where(where: Where) -> chroma_pb.Where: elif operator == "$lte": sfc.number_comparator = chroma_pb.NumberComparator.LTE else: - raise ValueError( + raise InvalidArgumentError( f"Expected where operator to be one of $eq, $ne, $gt, $lt, $gte, $lte, got {operator}" ) dc.single_double_operand.CopyFrom(sfc) else: - raise ValueError( + raise InvalidArgumentError( f"Expected where operand value to be a string, int, or float, got {operand}" ) else: @@ -526,7 +527,7 @@ def to_proto_where(where: Where) -> chroma_pb.Where: def to_proto_where_document(where_document: WhereDocument) -> chroma_pb.WhereDocument: response = chroma_pb.WhereDocument() if len(where_document) != 1: - raise ValueError( + raise InvalidArgumentError( f"Expected where_document to have exactly one 
operator, got {where_document}" ) @@ -534,7 +535,7 @@ def to_proto_where_document(where_document: WhereDocument) -> chroma_pb.WhereDoc if operator == "$and" or operator == "$or": # Nested "$and" or "$or" expression. if not isinstance(operand, list): - raise ValueError( + raise InvalidArgumentError( f"Expected where_document value for $and or $or to be a list of where_document expressions, got {operand}" ) children: chroma_pb.WhereDocumentChildren = chroma_pb.WhereDocumentChildren( @@ -550,7 +551,7 @@ def to_proto_where_document(where_document: WhereDocument) -> chroma_pb.WhereDoc # Direct "$contains" or "$not_contains" comparison to a single # value. if not isinstance(operand, str): - raise ValueError( + raise InvalidArgumentError( f"Expected where_document operand to be a string, got {operand}" ) dwd = chroma_pb.DirectWhereDocument() @@ -560,7 +561,7 @@ def to_proto_where_document(where_document: WhereDocument) -> chroma_pb.WhereDoc elif operator == "$not_contains": dwd.operator = chroma_pb.WhereDocumentOperator.NOT_CONTAINS else: - raise ValueError( + raise InvalidArgumentError( f"Expected where_document operator to be one of $contains, $not_contains, got {operator}" ) response.direct.CopyFrom(dwd) diff --git a/chromadb/segment/impl/distributed/segment_directory.py b/chromadb/segment/impl/distributed/segment_directory.py index 12a6b35fa7d..09d9348fc53 100644 --- a/chromadb/segment/impl/distributed/segment_directory.py +++ b/chromadb/segment/impl/distributed/segment_directory.py @@ -7,6 +7,7 @@ from overrides import EnforceOverrides, override from chromadb.config import System +from chromadb.errors import InvalidArgumentError from chromadb.segment.distributed import ( Memberlist, MemberlistProvider, @@ -77,7 +78,7 @@ def __init__(self, system: System): @override def start(self) -> None: if self._memberlist_name is None: - raise ValueError("Memberlist name must be set before starting") + raise InvalidArgumentError("Memberlist name must be set before starting") 
self.get_memberlist() self._done_waiting_for_reset.clear() self._watch_worker_memberlist() @@ -103,7 +104,7 @@ def reset_state(self) -> None: # get propagated back again # Note that the component must be running in order to reset the state if not self._system.settings.require("allow_reset"): - raise ValueError( + raise InvalidArgumentError( "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted." ) if self._memberlist_name: @@ -242,7 +243,7 @@ def stop(self) -> None: @override def get_segment_endpoint(self, segment: Segment) -> str: if self._curr_memberlist is None or len(self._curr_memberlist) == 0: - raise ValueError("Memberlist is not initialized") + raise InvalidArgumentError("Memberlist is not initialized") # Query to the same collection should end up on the same endpoint assignment = assign( segment["collection"].hex, self._curr_memberlist, murmur3hasher, 1 @@ -273,3 +274,4 @@ def extract_service_name(self, pod_name: str) -> Optional[str]: if len(parts) > 1: return "-".join(parts[:-1]) return None + \ No newline at end of file diff --git a/chromadb/segment/impl/manager/distributed.py b/chromadb/segment/impl/manager/distributed.py index ba10155216e..1abfc474061 100644 --- a/chromadb/segment/impl/manager/distributed.py +++ b/chromadb/segment/impl/manager/distributed.py @@ -12,6 +12,10 @@ SegmentManager, SegmentType, ) +from chromadb.config import System, get_class +from chromadb.db.system import SysDB +from chromadb.errors import InvalidArgumentError +from overrides import override from chromadb.segment.distributed import SegmentDirectory from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams from chromadb.telemetry.opentelemetry import ( @@ -84,7 +88,14 @@ def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: "DistributedSegmentManager.get_segment", OpenTelemetryGranularity.OPERATION_AND_SEGMENT, ) - def get_segment(self, 
collection_id: UUID, scope: SegmentScope) -> Segment: + def get_segment(self, collection_id: UUID, type: Type[S]) -> S: + if type == MetadataReader: + scope = SegmentScope.METADATA + elif type == VectorReader: + scope = SegmentScope.VECTOR + else: + raise InvalidArgumentError(f"Invalid segment type: {type}") + if scope not in self._segment_cache[collection_id]: # For now, there is exactly one segment per scope for a given collection segment = self._sysdb.get_segments(collection=collection_id, scope=scope)[0] diff --git a/chromadb/segment/impl/manager/local.py b/chromadb/segment/impl/manager/local.py index 296ace7f9e7..e500f493706 100644 --- a/chromadb/segment/impl/manager/local.py +++ b/chromadb/segment/impl/manager/local.py @@ -17,6 +17,7 @@ from chromadb.config import System, get_class from chromadb.db.system import SysDB +from chromadb.errors import InvalidArgumentError from overrides import override from chromadb.segment.impl.vector.local_persistent_hnsw import ( PersistentLocalHnswSegment, @@ -204,7 +205,7 @@ def get_segment(self, collection_id: UUID, type: Type[S]) -> S: elif type == VectorReader: scope = SegmentScope.VECTOR else: - raise ValueError(f"Invalid segment type: {type}") + raise InvalidArgumentError(f"Invalid segment type: {type}") segment = self.segment_cache[scope].get(collection_id) if segment is None: diff --git a/chromadb/segment/impl/metadata/grpc_segment.py b/chromadb/segment/impl/metadata/grpc_segment.py index 53ffdc72734..de57311c50b 100644 --- a/chromadb/segment/impl/metadata/grpc_segment.py +++ b/chromadb/segment/impl/metadata/grpc_segment.py @@ -143,17 +143,17 @@ def _where_to_proto(self, where: Optional[Where]) -> pb.Where: if where is None: return response if len(where) != 1: - raise ValueError( + raise InvalidArgumentError( f"Expected where to have exactly one operator, got {where}" ) for key, value in where.items(): if not isinstance(key, str): - raise ValueError(f"Expected where key to be a str, got {key}") + raise 
InvalidArgumentError(f"Expected where key to be a str, got {key}") if key == "$and" or key == "$or": if not isinstance(value, list): - raise ValueError( + raise InvalidArgumentError( f"Expected where value for $and or $or to be a list of where expressions, got {value}" ) children: pb.WhereChildren = pb.WhereChildren( @@ -196,20 +196,20 @@ def _where_to_proto(self, where: Optional[Where]) -> pb.Where: sdc.generic_comparator = pb.GenericComparator.EQ dc.single_double_operand.CopyFrom(sdc) else: - raise ValueError( + raise InvalidArgumentError( f"Expected where value to be a string, int, or float, got {value}" ) else: for operator, operand in value.items(): if operator in ["$in", "$nin"]: if not isinstance(operand, list): - raise ValueError( + raise InvalidArgumentError( f"Expected where value for $in or $nin to be a list of values, got {value}" ) if len(operand) == 0 or not all( isinstance(x, type(operand[0])) for x in operand ): - raise ValueError( + raise InvalidArgumentError( f"Expected where operand value to be a non-empty list, and all values to be of the same type " f"got {operand}" ) @@ -243,7 +243,7 @@ def _where_to_proto(self, where: Optional[Where]) -> pb.Where: dlo.list_operator = list_operator dc.double_list_operand.CopyFrom(dlo) else: - raise ValueError( + raise InvalidArgumentError( f"Expected where operand value to be a list of strings, ints, or floats, got {operand}" ) elif operator in ["$eq", "$ne", "$gt", "$lt", "$gte", "$lte"]: @@ -256,7 +256,7 @@ def _where_to_proto(self, where: Optional[Where]) -> pb.Where: elif operator == "$ne": ssc.comparator = pb.GenericComparator.NE else: - raise ValueError( + raise InvalidArgumentError( f"Expected where operator to be $eq or $ne, got {operator}" ) dc.single_string_operand.CopyFrom(ssc) @@ -268,7 +268,7 @@ def _where_to_proto(self, where: Optional[Where]) -> pb.Where: elif operator == "$ne": sbc.comparator = pb.GenericComparator.NE else: - raise ValueError( + raise InvalidArgumentError( f"Expected where 
operator to be $eq or $ne, got {operator}" ) dc.single_bool_operand.CopyFrom(sbc) @@ -288,7 +288,7 @@ def _where_to_proto(self, where: Optional[Where]) -> pb.Where: elif operator == "$lte": sic.number_comparator = pb.NumberComparator.LTE else: - raise ValueError( + raise InvalidArgumentError( f"Expected where operator to be one of $eq, $ne, $gt, $lt, $gte, $lte, got {operator}" ) dc.single_int_operand.CopyFrom(sic) @@ -308,12 +308,12 @@ def _where_to_proto(self, where: Optional[Where]) -> pb.Where: elif operator == "$lte": sfc.number_comparator = pb.NumberComparator.LTE else: - raise ValueError( + raise InvalidArgumentError( f"Expected where operator to be one of $eq, $ne, $gt, $lt, $gte, $lte, got {operator}" ) dc.single_double_operand.CopyFrom(sfc) else: - raise ValueError( + raise InvalidArgumentError( f"Expected where operand value to be a string, int, or float, got {operand}" ) else: @@ -331,7 +331,7 @@ def _where_document_to_proto( if where_document is None: return response if len(where_document) != 1: - raise ValueError( + raise InvalidArgumentError( f"Expected where_document to have exactly one operator, got {where_document}" ) @@ -339,7 +339,7 @@ def _where_document_to_proto( if operator == "$and" or operator == "$or": # Nested "$and" or "$or" expression. if not isinstance(operand, list): - raise ValueError( + raise InvalidArgumentError( f"Expected where_document value for $and or $or to be a list of where_document expressions, got {operand}" ) children: pb.WhereDocumentChildren = pb.WhereDocumentChildren( @@ -355,7 +355,7 @@ def _where_document_to_proto( # Direct "$contains" or "$not_contains" comparison to a single # value. 
if not isinstance(operand, str): - raise ValueError( + raise InvalidArgumentError( f"Expected where_document operand to be a string, got {operand}" ) dwd = pb.DirectWhereDocument() @@ -365,7 +365,7 @@ def _where_document_to_proto( elif operator == "$not_contains": dwd.operator = pb.WhereDocumentOperator.NOT_CONTAINS else: - raise ValueError( + raise InvalidArgumentError( f"Expected where_document operator to be one of $contains, $not_contains, got {operator}" ) response.direct.CopyFrom(dwd) @@ -387,7 +387,7 @@ def _from_proto( elif value.HasField("float_value"): translated_metadata[key] = value.float_value else: - raise ValueError(f"Unknown metadata value type: {value}") + raise InvalidArgumentError(f"Unknown metadata value type: {value}") mer = MetadataEmbeddingRecord(id=record.id, metadata=translated_metadata) diff --git a/chromadb/segment/impl/metadata/sqlite.py b/chromadb/segment/impl/metadata/sqlite.py index d54dc95f342..e94f83fdde1 100644 --- a/chromadb/segment/impl/metadata/sqlite.py +++ b/chromadb/segment/impl/metadata/sqlite.py @@ -10,6 +10,7 @@ ParameterValue, get_sql, ) +from chromadb.errors import InvalidArgumentError from chromadb.telemetry.opentelemetry import ( OpenTelemetryClient, OpenTelemetryGranularity, @@ -129,7 +130,7 @@ def get_metadata( offset = offset or 0 if limit < 0: - raise ValueError("Limit cannot be negative") + raise InvalidArgumentError("Limit cannot be negative") select_clause = [ embeddings_t.id, @@ -587,8 +588,8 @@ def _where_doc_criterion( else embeddings_t.id.notin(sq) ) else: - raise ValueError(f"Unknown where_doc operator {k}") - raise ValueError("Empty where_doc") + raise InvalidArgumentError(f"Unknown where_doc operator {k}") + raise InvalidArgumentError("Empty where_doc") @trace_method("SqliteMetadataSegment.delete", OpenTelemetryGranularity.ALL) @override diff --git a/chromadb/segment/impl/vector/hnsw_params.py b/chromadb/segment/impl/vector/hnsw_params.py index b12c4281508..3b07218651f 100644 --- 
a/chromadb/segment/impl/vector/hnsw_params.py +++ b/chromadb/segment/impl/vector/hnsw_params.py @@ -3,6 +3,8 @@ from typing import Any, Callable, Dict, Union from chromadb.types import Metadata +from chromadb.errors import InvalidArgumentError + Validator = Callable[[Union[str, int, float]], bool] @@ -38,9 +40,9 @@ def _validate(metadata: Dict[str, Any], validators: Dict[str, Validator]) -> Non # Validate it for param, value in metadata.items(): if param not in validators: - raise ValueError(f"Unknown HNSW parameter: {param}") + raise InvalidArgumentError(f"Unknown HNSW parameter: {param}") if not validators[param](value): - raise ValueError(f"Invalid value for HNSW parameter: {param} = {value}") + raise InvalidArgumentError(f"Invalid value for HNSW parameter: {param} = {value}") class HnswParams(Params): diff --git a/chromadb/telemetry/product/events.py b/chromadb/telemetry/product/events.py index 568a84ca7c7..eccce4d737f 100644 --- a/chromadb/telemetry/product/events.py +++ b/chromadb/telemetry/product/events.py @@ -1,8 +1,8 @@ import os from typing import cast, ClassVar +from chromadb.errors import InvalidArgumentError from chromadb.telemetry.product import ProductTelemetryEvent - class ClientStartEvent(ProductTelemetryEvent): def __init__(self) -> None: super().__init__() @@ -70,7 +70,7 @@ def batch_key(self) -> str: def batch(self, other: "ProductTelemetryEvent") -> "CollectionAddEvent": if not self.batch_key == other.batch_key: - raise ValueError("Cannot batch events") + raise InvalidArgumentError("Cannot batch events") other = cast(CollectionAddEvent, other) total_amount = self.add_amount + other.add_amount return CollectionAddEvent( @@ -118,7 +118,7 @@ def batch_key(self) -> str: def batch(self, other: "ProductTelemetryEvent") -> "CollectionUpdateEvent": if not self.batch_key == other.batch_key: - raise ValueError("Cannot batch events") + raise InvalidArgumentError("Cannot batch events") other = cast(CollectionUpdateEvent, other) total_amount = 
self.update_amount + other.update_amount return CollectionUpdateEvent( @@ -176,7 +176,7 @@ def batch_key(self) -> str: def batch(self, other: "ProductTelemetryEvent") -> "CollectionQueryEvent": if not self.batch_key == other.batch_key: - raise ValueError("Cannot batch events") + raise InvalidArgumentError("Cannot batch events") other = cast(CollectionQueryEvent, other) total_amount = self.query_amount + other.query_amount return CollectionQueryEvent( @@ -228,7 +228,7 @@ def batch_key(self) -> str: def batch(self, other: "ProductTelemetryEvent") -> "CollectionGetEvent": if not self.batch_key == other.batch_key: - raise ValueError("Cannot batch events") + raise InvalidArgumentError("Cannot batch events") other = cast(CollectionGetEvent, other) total_amount = self.ids_count + other.ids_count return CollectionGetEvent( diff --git a/chromadb/test/api/test_types.py b/chromadb/test/api/test_types.py index d2c6ebfb4fc..fe25100a339 100644 --- a/chromadb/test/api/test_types.py +++ b/chromadb/test/api/test_types.py @@ -1,6 +1,7 @@ import pytest from typing import List, cast from chromadb.api.types import EmbeddingFunction, Documents, Image, Document, Embeddings +from chromadb.errors import InvalidArgumentError import numpy as np @@ -40,6 +41,6 @@ def __call__(self, input: Documents) -> Embeddings: return cast(Embeddings, invalid_embedding) ef = TestEmbeddingFunction() - with pytest.raises(ValueError) as e: + with pytest.raises(InvalidArgumentError) as e: ef(random_documents()) - assert e.type is ValueError + assert e.type is InvalidArgumentError diff --git a/chromadb/test/client/create_http_client_with_basic_auth.py b/chromadb/test/client/create_http_client_with_basic_auth.py index 86be4727c5c..5f9395da4dc 100644 --- a/chromadb/test/client/create_http_client_with_basic_auth.py +++ b/chromadb/test/client/create_http_client_with_basic_auth.py @@ -5,6 +5,7 @@ import chromadb from chromadb.config import Settings +from chromadb.errors import InvalidArgumentError import sys @@ 
-18,7 +19,7 @@ def main() -> None: chroma_client_auth_credentials="admin:testDb@home2", ), ) - except ValueError: + except InvalidArgumentError: # We don't expect to be able to connect to Chroma. We just want to make sure # there isn't an ImportError. sys.exit(0) diff --git a/chromadb/test/configurations/test_configurations.py b/chromadb/test/configurations/test_configurations.py index 0d952957a73..7ac246a96f7 100644 --- a/chromadb/test/configurations/test_configurations.py +++ b/chromadb/test/configurations/test_configurations.py @@ -3,6 +3,7 @@ from chromadb.api.configuration import ( ConfigurationInternal, ConfigurationDefinition, + InvalidArgumentError, InvalidConfigurationError, StaticParameterError, ConfigurationParameter, @@ -56,7 +57,7 @@ def test_set_values() -> None: def test_get_invalid_parameter() -> None: test_configuration = TestConfiguration() - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): test_configuration.get_parameter("invalid_name") @@ -75,13 +76,13 @@ def test_validation() -> None: invalid_parameter_values = [ ConfigurationParameter(name="static_str_value", value=1.0) ] - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): TestConfiguration(parameters=invalid_parameter_values) invalid_parameter_names = [ ConfigurationParameter(name="invalid_name", value="some_value") ] - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): TestConfiguration(parameters=invalid_parameter_names) @@ -101,10 +102,10 @@ def configuration_validator(self) -> None: if self.parameter_map.get("foo") != "bar": raise InvalidConfigurationError("foo must be 'bar'") - with pytest.raises(ValueError, match="foo must be 'bar'"): + with pytest.raises(InvalidArgumentError, match="foo must be 'bar'"): FooConfiguration(parameters=[ConfigurationParameter(name="foo", value="baz")]) def test_hnsw_validation() -> None: - with pytest.raises(ValueError, match="must be less than or equal"): + with 
pytest.raises(InvalidArgumentError, match="must be less than or equal"): HNSWConfiguration(batch_size=500, sync_threshold=100) diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py index 47e03c11fe3..6edb9d122a0 100644 --- a/chromadb/test/conftest.py +++ b/chromadb/test/conftest.py @@ -29,6 +29,7 @@ from chromadb.api import ClientAPI, ServerAPI, BaseAPI from chromadb.config import Settings, System from chromadb.db.mixins import embeddings_queue +from chromadb.errors import InvalidArgumentError from chromadb.ingest import Producer from chromadb.types import SeqId, OperationRecord from chromadb.api.client import Client as ClientCreator, AdminClient @@ -45,7 +46,7 @@ CURRENT_PRESET = os.getenv("PROPERTY_TESTING_PRESET", "fast") if CURRENT_PRESET not in VALID_PRESETS: - raise ValueError( + raise InvalidArgumentError( f"Invalid property testing preset: {CURRENT_PRESET}. Must be one of {VALID_PRESETS}." ) diff --git a/chromadb/test/data_loader/test_data_loader.py b/chromadb/test/data_loader/test_data_loader.py index a19b519a6fb..5fd5e534071 100644 --- a/chromadb/test/data_loader/test_data_loader.py +++ b/chromadb/test/data_loader/test_data_loader.py @@ -6,6 +6,7 @@ import chromadb from chromadb.api.types import URI, DataLoader, Documents, IDs, Image, URIs from chromadb.api import ClientAPI +from chromadb.errors import InvalidArgumentError from chromadb.test.ef.test_multimodal_ef import hashing_multimodal_ef @@ -59,14 +60,14 @@ def test_without_data_loader( record_set = record_set_with_uris(n=n_examples) # Can't embed data in URIs without a data loader - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): collection_without_data_loader.add( ids=record_set["ids"], uris=record_set["uris"], ) # Can't get data from URIs without a data loader - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): collection_without_data_loader.get(include=["data"]) diff --git a/chromadb/test/ef/test_default_ef.py 
b/chromadb/test/ef/test_default_ef.py index a80ccd2813b..d1c8e0d33de 100644 --- a/chromadb/test/ef/test_default_ef.py +++ b/chromadb/test/ef/test_default_ef.py @@ -7,6 +7,7 @@ import pytest from hypothesis import given, settings +from chromadb.errors import InvalidArgumentError from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ( ONNXMiniLM_L6_V2, _verify_sha256, @@ -28,7 +29,7 @@ def unique_by(x: Hashable) -> Hashable: ) ) def test_unavailable_provider_multiple(providers: List[str]) -> None: - with pytest.raises(ValueError) as e: + with pytest.raises(InvalidArgumentError) as e: ef = ONNXMiniLM_L6_V2(preferred_providers=providers) ef(["test"]) assert "Preferred providers must be subset of available providers" in str(e.value) @@ -62,7 +63,7 @@ def test_warning_no_providers_supplied() -> None: ).filter(lambda x: len(x) > len(set(x))) ) def test_provider_repeating(providers: List[str]) -> None: - with pytest.raises(ValueError) as e: + with pytest.raises(InvalidArgumentError) as e: ef = ONNXMiniLM_L6_V2(preferred_providers=providers) ef(["test"]) assert "Preferred providers must be unique" in str(e.value) @@ -71,7 +72,7 @@ def test_provider_repeating(providers: List[str]) -> None: def test_invalid_sha256() -> None: ef = ONNXMiniLM_L6_V2() shutil.rmtree(ef.DOWNLOAD_PATH) # clean up any existing models - with pytest.raises(ValueError) as e: + with pytest.raises(InvalidArgumentError) as e: ef._MODEL_SHA256 = "invalid" ef(["test"]) assert "does not match expected SHA256 hash" in str(e.value) diff --git a/chromadb/test/ef/test_multimodal_ef.py b/chromadb/test/ef/test_multimodal_ef.py index 44ef4ad9f5b..8018afc4bc4 100644 --- a/chromadb/test/ef/test_multimodal_ef.py +++ b/chromadb/test/ef/test_multimodal_ef.py @@ -2,6 +2,7 @@ import numpy as np import pytest import chromadb +from chromadb.errors import InvalidArgumentError from chromadb.api.types import ( Embeddable, EmbeddingFunction, @@ -69,7 +70,7 @@ def test_multimodal( # Trying to add a document and an 
image at the same time should fail with pytest.raises( - ValueError, + InvalidArgumentError, # This error string may be in any order match=r"Exactly one of (images|documents|uris)(?:, (images|documents|uris))?(?:, (images|documents|uris))? must be provided in add\.", ): @@ -119,7 +120,7 @@ def test_multimodal( ] # Querying with both images and documents should fail - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): multimodal_collection.query( query_images=[query_image], query_texts=[query_document] ) diff --git a/chromadb/test/property/strategies.py b/chromadb/test/property/strategies.py index 710de47d4c1..6cc381d5f07 100644 --- a/chromadb/test/property/strategies.py +++ b/chromadb/test/property/strategies.py @@ -18,6 +18,7 @@ Embeddings, Metadata, ) +from chromadb.errors import InvalidArgumentError from chromadb.types import LiteralValue, WhereOperator, LogicalOperator # Set the random seed for reproducibility @@ -300,7 +301,7 @@ def collections( use_persistent_hnsw_params = draw(with_persistent_hnsw_params) if use_persistent_hnsw_params and not with_hnsw_params: - raise ValueError( + raise InvalidArgumentError( "with_persistent_hnsw_params requires with_hnsw_params to be true" ) diff --git a/chromadb/test/property/test_client_url.py b/chromadb/test/property/test_client_url.py index cc5df1e0514..fc9a4d961dc 100644 --- a/chromadb/test/property/test_client_url.py +++ b/chromadb/test/property/test_client_url.py @@ -4,6 +4,7 @@ import pytest from hypothesis import given, strategies as st +from chromadb.errors import InvalidArgumentError from chromadb.api.fastapi import FastAPI @@ -124,7 +125,7 @@ def test_resolve_invalid( ssl_enabled: bool, default_api_path: Optional[str], ) -> None: - with pytest.raises(ValueError) as e: + with pytest.raises(InvalidArgumentError) as e: FastAPI.resolve_url( chroma_server_host=hostname, chroma_server_http_port=port, diff --git a/chromadb/test/property/test_collections_with_database_tenant_overwrite.py 
b/chromadb/test/property/test_collections_with_database_tenant_overwrite.py index e4844dcca2f..8b63338b593 100644 --- a/chromadb/test/property/test_collections_with_database_tenant_overwrite.py +++ b/chromadb/test/property/test_collections_with_database_tenant_overwrite.py @@ -13,6 +13,7 @@ from chromadb.api import AdminAPI from chromadb.api.client import AdminClient, Client from chromadb.config import Settings, System +from chromadb.errors import InvalidArgumentError from chromadb.test.conftest import ( ClientFactories, fastapi_fixture_admin_and_singleton_tenant_db_user, @@ -105,7 +106,7 @@ def set_tenant_model( # thanks to the above overriding of get_tenant_model(), # the underlying state machine test should always expect an error # when it sends the request, so shouldn't try to update the model. - raise ValueError("trying to overwrite the model for singleton??") + raise InvalidArgumentError("trying to overwrite the model for singleton??") self.tenant_to_database_to_model[tenant] = model @overrides @@ -121,7 +122,7 @@ def set_database_model_for_tenant( # thanks to the above overriding of has_database_for_tenant(), # the underlying state machine test should always expect an error # when it sends the request, so shouldn't try to update the model. 
- raise ValueError("trying to overwrite the model for singleton??") + raise InvalidArgumentError("trying to overwrite the model for singleton??") self.tenant_to_database_to_model[tenant][database] = database_model @overrides diff --git a/chromadb/test/property/test_embeddings.py b/chromadb/test/property/test_embeddings.py index 507cd20eb19..6d8699021bc 100644 --- a/chromadb/test/property/test_embeddings.py +++ b/chromadb/test/property/test_embeddings.py @@ -18,6 +18,7 @@ ) from chromadb.config import System import chromadb.errors as errors +from chromadb.errors import InvalidArgumentError from chromadb.api import ClientAPI from chromadb.api.models.Collection import Collection import chromadb.test.property.strategies as strategies @@ -1157,7 +1158,7 @@ def test_escape_chars_in_ids(client: ClientAPI) -> None: def test_delete_empty_fails(client: ClientAPI) -> None: reset(client) coll = client.create_collection(name="foo") - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): coll.delete() @@ -1235,7 +1236,7 @@ def test_autocasting_validate_embeddings_incompatible_types( unsupported_types: List[Any], ) -> None: embds = strategies.create_embeddings(10, 10, unsupported_types) - with pytest.raises(ValueError) as e: + with pytest.raises(InvalidArgumentError) as e: validate_embeddings(cast(Embeddings, normalize_embeddings(embds))) assert ( @@ -1246,7 +1247,7 @@ def test_autocasting_validate_embeddings_incompatible_types( def test_0dim_embedding_validation() -> None: embds: Embeddings = [np.array([])] - with pytest.raises(ValueError) as e: + with pytest.raises(InvalidArgumentError) as e: validate_embeddings(embds) assert ( "Expected each embedding in the embeddings to be a 1-dimensional numpy array with at least 1 int/float value. 
Got a 1-dimensional numpy array with no values at pos" diff --git a/chromadb/test/property/test_filtering.py b/chromadb/test/property/test_filtering.py index df94ed59fe8..c8da3fdb35a 100644 --- a/chromadb/test/property/test_filtering.py +++ b/chromadb/test/property/test_filtering.py @@ -14,6 +14,7 @@ Where, WhereDocument, ) +from chromadb.errors import InvalidArgumentError from chromadb.test.conftest import reset, NOT_CLUSTER_ONLY import chromadb.test.property.strategies as strategies import hypothesis.strategies as st @@ -76,7 +77,7 @@ def _filter_where_clause(clause: Where, metadata: Optional[Metadata]) -> bool: elif op == "$lte": return key in metadata and metadata[key] <= val else: - raise ValueError("Unknown operator: {}".format(key)) + raise InvalidArgumentError("Unknown operator: {}".format(key)) def _filter_where_doc_clause(clause: WhereDocument, doc: Document) -> bool: @@ -110,7 +111,7 @@ def _filter_where_doc_clause(clause: WhereDocument, doc: Document) -> bool: return re.search(expr, doc) is None return expr not in doc else: - raise ValueError("Unknown operator: {}".format(key)) + raise InvalidArgumentError("Unknown operator: {}".format(key)) EMPTY_DICT: Dict[Any, Any] = {} diff --git a/chromadb/test/segment/test_metadata.py b/chromadb/test/segment/test_metadata.py index 50bab861800..2d7027abce4 100644 --- a/chromadb/test/segment/test_metadata.py +++ b/chromadb/test/segment/test_metadata.py @@ -17,6 +17,7 @@ from chromadb.config import System, Settings from chromadb.db.base import ParameterValue, get_sql from chromadb.db.impl.sqlite import SqliteDB +from chromadb.errors import InvalidArgumentError from chromadb.test.conftest import ProducerFn from chromadb.types import ( OperationRecord, @@ -809,7 +810,7 @@ def test_limit( assert len(res) == 3 # if limit is negative, throw error - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): segment.get_metadata(limit=-1, request_version_context=request_version_context) # if offset is more 
than number of results, return empty list @@ -996,7 +997,7 @@ def test_include_metadata( def test_metadata_validation_forbidden_key() -> None: - with pytest.raises(ValueError, match="chroma:document"): + with pytest.raises(InvalidArgumentError, match="chroma:document"): validate_metadata( {"chroma:document": "this is not the document you are looking for"} ) diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index ab91408c992..4d9d2b879b8 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -3,7 +3,10 @@ import httpx import chromadb -from chromadb.errors import ChromaError +from chromadb.errors import ( + ChromaError, + InvalidArgumentError +) from chromadb.api.fastapi import FastAPI from chromadb.api.types import QueryResult, EmbeddingFunction, Document from chromadb.config import Settings @@ -709,7 +712,7 @@ def test_metadata_update_get_int_float(client): def test_metadata_validation_add(client): client.reset() collection = client.create_collection("test_metadata_validation") - with pytest.raises(ValueError, match="metadata"): + with pytest.raises(InvalidArgumentError, match="metadata"): collection.add(**bad_metadata_records) @@ -717,21 +720,21 @@ def test_metadata_validation_update(client): client.reset() collection = client.create_collection("test_metadata_validation") collection.add(**metadata_records) - with pytest.raises(ValueError, match="metadata"): + with pytest.raises(InvalidArgumentError, match="metadata"): collection.update(ids=["id1"], metadatas={"value": {"nested": "5"}}) def test_where_validation_get(client): client.reset() collection = client.create_collection("test_where_validation") - with pytest.raises(ValueError, match="where"): + with pytest.raises(InvalidArgumentError, match="where"): collection.get(where={"value": {"nested": "5"}}) def test_where_validation_query(client): client.reset() collection = client.create_collection("test_where_validation") - with pytest.raises(ValueError, match="where"): + with 
pytest.raises(InvalidArgumentError, match="where"): collection.query(query_embeddings=[0, 0, 0], where={"value": {"nested": "5"}}) @@ -799,39 +802,39 @@ def test_where_valid_operators(client): client.reset() collection = client.create_collection("test_where_valid_operators") collection.add(**operator_records) - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): collection.get(where={"int_value": {"$invalid": 2}}) - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): collection.get(where={"int_value": {"$lt": "2"}}) - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): collection.get(where={"int_value": {"$lt": 2, "$gt": 1}}) # Test invalid $and, $or - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): collection.get(where={"$and": {"int_value": {"$lt": 2}}}) - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): collection.get( where={"int_value": {"$lt": 2}, "$or": {"int_value": {"$gt": 1}}} ) - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): collection.get( where={"$gt": [{"int_value": {"$lt": 2}}, {"int_value": {"$gt": 1}}]} ) - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): collection.get(where={"$or": [{"int_value": {"$lt": 2}}]}) - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): collection.get(where={"$or": []}) - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): collection.get(where={"a": {"$contains": "test"}}) - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): collection.get( where={ "$or": [ @@ -882,31 +885,31 @@ def test_query_document_valid_operators(client): client.reset() collection = client.create_collection("test_where_valid_operators") collection.add(**operator_records) - with pytest.raises(ValueError, match="where document"): + with pytest.raises(InvalidArgumentError, match="where document"): 
collection.get(where_document={"$lt": {"$nested": 2}}) - with pytest.raises(ValueError, match="where document"): + with pytest.raises(InvalidArgumentError, match="where document"): collection.query(query_embeddings=[0, 0, 0], where_document={"$contains": 2}) - with pytest.raises(ValueError, match="where document"): + with pytest.raises(InvalidArgumentError, match="where document"): collection.get(where_document={"$contains": []}) # Test invalid $and, $or - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): collection.get(where_document={"$and": {"$unsupported": "doc"}}) - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): collection.get( where_document={"$or": [{"$unsupported": "doc"}, {"$unsupported": "doc"}]} ) - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): collection.get(where_document={"$or": [{"$contains": "doc"}]}) - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): collection.get(where_document={"$or": []}) - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): collection.get( where_document={ "$or": [{"$and": [{"$contains": "doc"}]}, {"$contains": "doc"}] @@ -1174,10 +1177,10 @@ def test_get_include(client): assert items["ids"][0] == "id1" assert items["included"] == [] - with pytest.raises(ValueError, match="include"): + with pytest.raises(InvalidArgumentError, match="include"): items = collection.get(include=["metadatas", "undefined"]) - with pytest.raises(ValueError, match="include"): + with pytest.raises(InvalidArgumentError, match="include"): items = collection.get(include=None) @@ -1206,17 +1209,17 @@ def test_invalid_id(client): client.reset() collection = client.create_collection("test_invalid_id") # Add with non-string id - with pytest.raises(ValueError) as e: + with pytest.raises(InvalidArgumentError) as e: collection.add(embeddings=[0, 0, 0], ids=[1], metadatas=[{}]) assert "ID" in str(e.value) # Get with non-list id - 
with pytest.raises(ValueError) as e: + with pytest.raises(InvalidArgumentError) as e: collection.get(ids=1) assert "ID" in str(e.value) # Delete with malformed ids - with pytest.raises(ValueError) as e: + with pytest.raises(InvalidArgumentError) as e: collection.delete(ids=["valid", 0]) assert "ID" in str(e.value) @@ -1429,7 +1432,7 @@ def test_invalid_n_results_param(client): client.reset() collection = client.create_collection("testspace") collection.add(**records) - with pytest.raises(TypeError) as exc: + with pytest.raises(InvalidArgumentError) as exc: collection.query( query_embeddings=[[1.1, 2.3, 3.2]], n_results=-1, @@ -1438,16 +1441,16 @@ def test_invalid_n_results_param(client): assert "Number of requested results -1, cannot be negative, or zero." in str( exc.value ) - assert exc.type == TypeError + assert exc.type == InvalidArgumentError - with pytest.raises(ValueError) as exc: + with pytest.raises(InvalidArgumentError) as exc: collection.query( query_embeddings=[[1.1, 2.3, 3.2]], n_results="one", include=["embeddings", "documents", "metadatas", "distances"], ) assert "int" in str(exc.value) - assert exc.type == ValueError + assert exc.type == InvalidArgumentError initial_records = { @@ -1548,12 +1551,12 @@ def test_invalid_embeddings(client): "embeddings": [["0", "0", "0"], ["1.2", "2.24", "3.2"]], "ids": ["id1", "id2"], } - with pytest.raises(ValueError) as e: + with pytest.raises(InvalidArgumentError) as e: collection.add(**invalid_records) assert "embedding" in str(e.value) # Query with invalid embeddings - with pytest.raises(ValueError) as e: + with pytest.raises(InvalidArgumentError) as e: collection.query( query_embeddings=[["1.1", "2.3", "3.2"]], n_results=1, @@ -1565,7 +1568,7 @@ def test_invalid_embeddings(client): "embeddings": [[[0], [0], [0]], [[1.2], [2.24], [3.2]]], "ids": ["id1", "id2"], } - with pytest.raises(ValueError) as e: + with pytest.raises(InvalidArgumentError) as e: collection.update(**invalid_records) assert "embedding" in 
str(e.value) @@ -1574,7 +1577,7 @@ def test_invalid_embeddings(client): "embeddings": [[[1.1, 2.3, 3.2]], [[1.2, 2.24, 3.2]]], "ids": ["id1", "id2"], } - with pytest.raises(ValueError) as e: + with pytest.raises(InvalidArgumentError) as e: collection.upsert(**invalid_records) assert "embedding" in str(e.value) @@ -1616,7 +1619,7 @@ def test_ssl_self_signed_without_ssl_verify(client_ssl): pytest.skip("Skipping test for integration test") client_ssl.heartbeat() _port = client_ssl._server._settings.chroma_server_http_port - with pytest.raises(ValueError) as e: + with pytest.raises(InvalidArgumentError) as e: chromadb.HttpClient(ssl=True, port=_port) stack_trace = traceback.format_exception( type(e.value), e.value, e.value.__traceback__ diff --git a/chromadb/test/test_chroma.py b/chromadb/test/test_chroma.py index 89b4ae924eb..0b73fe9efae 100644 --- a/chromadb/test/test_chroma.py +++ b/chromadb/test/test_chroma.py @@ -6,6 +6,7 @@ import chromadb.config from chromadb.db.system import SysDB from chromadb.ingest import Consumer, Producer +from chromadb.errors import InvalidArgumentError class GetDBTest(unittest.TestCase): @@ -100,7 +101,7 @@ def test_settings_pass_to_fastapi(self, mock: Mock) -> None: def test_legacy_values() -> None: - with pytest.raises(ValueError): + with pytest.raises(InvalidArgumentError): client = chromadb.Client( chromadb.config.Settings( chroma_api_impl="chromadb.api.local.LocalAPI", diff --git a/chromadb/test/test_client.py b/chromadb/test/test_client.py index f8180f9ca26..e468cc02fad 100644 --- a/chromadb/test/test_client.py +++ b/chromadb/test/test_client.py @@ -4,6 +4,7 @@ import chromadb from chromadb.config import Settings from chromadb.api import ClientAPI +from chromadb.errors import InvalidArgumentError import chromadb.server.fastapi import pytest import tempfile @@ -79,7 +80,7 @@ def test_http_client_with_inconsistent_host_settings( ) -> None: try: http_api_factory(settings=Settings(chroma_server_host="127.0.0.1")) - except ValueError as 
e: + except InvalidArgumentError as e: assert ( str(e) == "Chroma server host provided in settings[127.0.0.1] is different to the one provided in HttpClient: [localhost]" @@ -96,7 +97,7 @@ def test_http_client_with_inconsistent_port_settings( chroma_server_http_port=8001, ), ) - except ValueError as e: + except InvalidArgumentError as e: assert ( str(e) == "Chroma server http port provided in settings[8001] is different to the one provided in HttpClient: [8002]" diff --git a/chromadb/utils/data_loaders.py b/chromadb/utils/data_loaders.py index 82ea894aa9a..f0971d78a7b 100644 --- a/chromadb/utils/data_loaders.py +++ b/chromadb/utils/data_loaders.py @@ -3,6 +3,7 @@ from typing import Optional, Sequence, List, Tuple import numpy as np from chromadb.api.types import URI, DataLoader, Image, URIs +from chromadb.errors import InvalidArgumentError from concurrent.futures import ThreadPoolExecutor @@ -12,7 +13,7 @@ def __init__(self, max_workers: int = multiprocessing.cpu_count()) -> None: self._PILImage = importlib.import_module("PIL.Image") self._max_workers = max_workers except ImportError: - raise ValueError( + raise InvalidArgumentError( "The PIL python package is not installed. 
Please install it with `pip install pillow`" ) diff --git a/chromadb/utils/embedding_functions/chroma_langchain_embedding_function.py b/chromadb/utils/embedding_functions/chroma_langchain_embedding_function.py index 445cca5b128..1c665760e8b 100644 --- a/chromadb/utils/embedding_functions/chroma_langchain_embedding_function.py +++ b/chromadb/utils/embedding_functions/chroma_langchain_embedding_function.py @@ -2,6 +2,7 @@ from typing import Any, List, Union from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Images +from chromadb.errors import InvalidArgumentError logger = logging.getLogger(__name__) @@ -10,7 +11,7 @@ def create_langchain_embedding(langchain_embdding_fn: Any): # type: ignore try: from langchain_core.embeddings import Embeddings as LangchainEmbeddings except ImportError: - raise ValueError( + raise InvalidArgumentError( "The langchain_core python package is not installed. Please install it with `pip install langchain-core`" ) @@ -40,7 +41,7 @@ def embed_image(self, uris: List[str]) -> List[List[float]]: if hasattr(self.embedding_function, "embed_image"): return self.embedding_function.embed_image(uris) # type: ignore else: - raise ValueError( + raise InvalidArgumentError( "The provided embedding function does not support image embeddings." 
) diff --git a/chromadb/utils/embedding_functions/cohere_embedding_function.py b/chromadb/utils/embedding_functions/cohere_embedding_function.py index ef9c33e24b9..0f54428daea 100644 --- a/chromadb/utils/embedding_functions/cohere_embedding_function.py +++ b/chromadb/utils/embedding_functions/cohere_embedding_function.py @@ -1,5 +1,6 @@ import logging +from chromadb.errors import InvalidArgumentError from chromadb.api.types import Documents, EmbeddingFunction, Embeddings logger = logging.getLogger(__name__) @@ -10,7 +11,7 @@ def __init__(self, api_key: str, model_name: str = "large"): try: import cohere except ImportError: - raise ValueError( + raise InvalidArgumentError( "The cohere python package is not installed. Please install it with `pip install cohere`" ) diff --git a/chromadb/utils/embedding_functions/google_embedding_function.py b/chromadb/utils/embedding_functions/google_embedding_function.py index 0534d790674..2dda52f10b6 100644 --- a/chromadb/utils/embedding_functions/google_embedding_function.py +++ b/chromadb/utils/embedding_functions/google_embedding_function.py @@ -3,6 +3,8 @@ import httpx from chromadb.api.types import Documents, EmbeddingFunction, Embeddings +from chromadb.errors import InvalidArgumentError + logger = logging.getLogger(__name__) @@ -12,15 +14,15 @@ class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]): def __init__(self, api_key: str, model_name: str = "models/embedding-gecko-001"): if not api_key: - raise ValueError("Please provide a PaLM API key.") + raise InvalidArgumentError("Please provide a PaLM API key.") if not model_name: - raise ValueError("Please provide the model name.") + raise InvalidArgumentError("Please provide the model name.") try: import google.generativeai as palm except ImportError: - raise ValueError( + raise InvalidArgumentError( "The Google Generative AI python package is not installed. 
Please install it with `pip install google-generativeai`" ) @@ -49,15 +51,15 @@ def __init__( task_type: str = "RETRIEVAL_DOCUMENT", ): if not api_key: - raise ValueError("Please provide a Google API key.") + raise InvalidArgumentError("Please provide a Google API key.") if not model_name: - raise ValueError("Please provide the model name.") + raise InvalidArgumentError("Please provide the model name.") try: import google.generativeai as genai except ImportError: - raise ValueError( + raise InvalidArgumentError( "The Google Generative AI python package is not installed. Please install it with `pip install google-generativeai`" ) diff --git a/chromadb/utils/embedding_functions/instructor_embedding_function.py b/chromadb/utils/embedding_functions/instructor_embedding_function.py index 18d13d8ec0b..425fd1b6ac4 100644 --- a/chromadb/utils/embedding_functions/instructor_embedding_function.py +++ b/chromadb/utils/embedding_functions/instructor_embedding_function.py @@ -2,6 +2,7 @@ from typing import Optional, cast from chromadb.api.types import Documents, EmbeddingFunction, Embeddings +from chromadb.errors import InvalidArgumentError logger = logging.getLogger(__name__) @@ -18,7 +19,7 @@ def __init__( try: from InstructorEmbedding import INSTRUCTOR except ImportError: - raise ValueError( + raise InvalidArgumentError( "The InstructorEmbedding python package is not installed. 
Please install it with `pip install InstructorEmbedding`" ) self._model = INSTRUCTOR(model_name, device=device) diff --git a/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py b/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py index 2bef446116c..d74e17e8eb7 100644 --- a/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py +++ b/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py @@ -12,6 +12,7 @@ import httpx from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random +from chromadb.errors import InvalidArgumentError from chromadb.api.types import Documents, EmbeddingFunction, Embeddings logger = logging.getLogger(__name__) @@ -51,32 +52,32 @@ def __init__(self, preferred_providers: Optional[List[str]] = None) -> None: if preferred_providers and not all( [isinstance(i, str) for i in preferred_providers] ): - raise ValueError("Preferred providers must be a list of strings") + raise InvalidArgumentError("Preferred providers must be a list of strings") # check for duplicate providers if preferred_providers and len(preferred_providers) != len( set(preferred_providers) ): - raise ValueError("Preferred providers must be unique") + raise InvalidArgumentError("Preferred providers must be unique") self._preferred_providers = preferred_providers try: # Equivalent to import onnxruntime self.ort = importlib.import_module("onnxruntime") except ImportError: - raise ValueError( + raise InvalidArgumentError( "The onnxruntime python package is not installed. Please install it with `pip install onnxruntime`" ) try: # Equivalent to from tokenizers import Tokenizer self.Tokenizer = importlib.import_module("tokenizers").Tokenizer except ImportError: - raise ValueError( + raise InvalidArgumentError( "The tokenizers python package is not installed. 
Please install it with `pip install tokenizers`" ) try: # Equivalent to from tqdm import tqdm self.tqdm = importlib.import_module("tqdm").tqdm except ImportError: - raise ValueError( + raise InvalidArgumentError( "The tqdm python package is not installed. Please install it with `pip install tqdm`" ) @@ -112,7 +113,7 @@ def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None: if not _verify_sha256(fname, self._MODEL_SHA256): # if the integrity of the file is not verified, remove it os.remove(fname) - raise ValueError( + raise InvalidArgumentError( f"Downloaded file {fname} does not match expected SHA256 hash. Corrupted download or malicious file." ) @@ -178,7 +179,7 @@ def model(self) -> "InferenceSession": # noqa F821 elif not set(self._preferred_providers).issubset( set(self.ort.get_available_providers()) ): - raise ValueError( + raise InvalidArgumentError( f"Preferred providers must be subset of available providers: {self.ort.get_available_providers()}" ) diff --git a/chromadb/utils/embedding_functions/open_clip_embedding_function.py b/chromadb/utils/embedding_functions/open_clip_embedding_function.py index 0d05b6c27b6..98bba14f1bf 100644 --- a/chromadb/utils/embedding_functions/open_clip_embedding_function.py +++ b/chromadb/utils/embedding_functions/open_clip_embedding_function.py @@ -13,6 +13,8 @@ is_document, is_image, ) +from chromadb.errors import InvalidArgumentError + logger = logging.getLogger(__name__) @@ -27,20 +29,20 @@ def __init__( try: import open_clip except ImportError: - raise ValueError( + raise InvalidArgumentError( "The open_clip python package is not installed. Please install it with `pip install open-clip-torch`. https://github.com/mlfoundations/open_clip" ) try: self._torch = importlib.import_module("torch") except ImportError: - raise ValueError( + raise InvalidArgumentError( "The torch python package is not installed. 
Please install it with `pip install torch`" ) try: self._PILImage = importlib.import_module("PIL.Image") except ImportError: - raise ValueError( + raise InvalidArgumentError( "The PIL python package is not installed. Please install it with `pip install pillow`" ) diff --git a/chromadb/utils/embedding_functions/openai_embedding_function.py b/chromadb/utils/embedding_functions/openai_embedding_function.py index ce333cc3bba..3efb7cb50de 100644 --- a/chromadb/utils/embedding_functions/openai_embedding_function.py +++ b/chromadb/utils/embedding_functions/openai_embedding_function.py @@ -2,6 +2,7 @@ from typing import Mapping, Optional, cast from chromadb.api.types import Documents, EmbeddingFunction, Embeddings +from chromadb.errors import InvalidArgumentError logger = logging.getLogger(__name__) @@ -46,14 +47,14 @@ def __init__( try: import openai except ImportError: - raise ValueError( + raise InvalidArgumentError( "The openai python package is not installed. Please install it with `pip install openai`" ) self._api_key = api_key or openai.api_key # If the api key is still not set, raise an error if self._api_key is None: - raise ValueError( + raise InvalidArgumentError( "Please provide an OpenAI API key. 
You can get one at https://platform.openai.com/account/api-keys" ) @@ -142,3 +143,4 @@ def __call__(self, input: Documents) -> Embeddings: return cast( Embeddings, [result["embedding"] for result in sorted_embeddings] ) + \ No newline at end of file diff --git a/chromadb/utils/embedding_functions/roboflow_embedding_function.py b/chromadb/utils/embedding_functions/roboflow_embedding_function.py index b118aa01c64..d6edbb3498c 100644 --- a/chromadb/utils/embedding_functions/roboflow_embedding_function.py +++ b/chromadb/utils/embedding_functions/roboflow_embedding_function.py @@ -15,6 +15,8 @@ is_document, is_image, ) +from chromadb.errors import InvalidArgumentError + logger = logging.getLogger(__name__) @@ -39,7 +41,7 @@ def __init__( try: self._PILImage = importlib.import_module("PIL.Image") except ImportError: - raise ValueError( + raise InvalidArgumentError( "The PIL python package is not installed. Please install it with `pip install pillow`" ) diff --git a/chromadb/utils/embedding_functions/sentence_transformer_embedding_function.py b/chromadb/utils/embedding_functions/sentence_transformer_embedding_function.py index 7a2b57c6ae9..5626c1c505a 100644 --- a/chromadb/utils/embedding_functions/sentence_transformer_embedding_function.py +++ b/chromadb/utils/embedding_functions/sentence_transformer_embedding_function.py @@ -2,6 +2,7 @@ from typing import Any, Dict, cast from chromadb.api.types import Documents, EmbeddingFunction, Embeddings +from chromadb.errors import InvalidArgumentError logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ def __init__( try: from sentence_transformers import SentenceTransformer except ImportError: - raise ValueError( + raise InvalidArgumentError( "The sentence_transformers python package is not installed. 
Please install it with `pip install sentence_transformers`" ) self.models[model_name] = SentenceTransformer( diff --git a/chromadb/utils/embedding_functions/text2vec_embedding_function.py b/chromadb/utils/embedding_functions/text2vec_embedding_function.py index 9e4639c23d0..4bc0eace66e 100644 --- a/chromadb/utils/embedding_functions/text2vec_embedding_function.py +++ b/chromadb/utils/embedding_functions/text2vec_embedding_function.py @@ -2,6 +2,8 @@ from typing import cast from chromadb.api.types import Documents, EmbeddingFunction, Embeddings +from chromadb.errors import InvalidArgumentError + logger = logging.getLogger(__name__) @@ -11,7 +13,7 @@ def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"): try: from text2vec import SentenceModel except ImportError: - raise ValueError( + raise InvalidArgumentError( "The text2vec python package is not installed. Please install it with `pip install text2vec`" ) self._model = SentenceModel(model_name_or_path=model_name) diff --git a/chromadb/utils/fastapi.py b/chromadb/utils/fastapi.py index 8300880e402..a6efc4483a3 100644 --- a/chromadb/utils/fastapi.py +++ b/chromadb/utils/fastapi.py @@ -1,7 +1,12 @@ from uuid import UUID from starlette.responses import JSONResponse -from chromadb.errors import ChromaError, InvalidUUIDError +from chromadb.errors import ( + ChromaError, + InvalidUUIDError, + InvalidArgumentError +) + def fastapi_json_response(error: ChromaError) -> JSONResponse: @@ -14,5 +19,5 @@ def fastapi_json_response(error: ChromaError) -> JSONResponse: def string_to_uuid(uuid_str: str) -> UUID: try: return UUID(uuid_str) - except ValueError: + except (ValueError, InvalidArgumentError): raise InvalidUUIDError(f"Could not parse {uuid_str} as a UUID") diff --git a/chromadb/utils/rendezvous_hash.py b/chromadb/utils/rendezvous_hash.py index f21e2863225..08f853e1945 100644 --- a/chromadb/utils/rendezvous_hash.py +++ b/chromadb/utils/rendezvous_hash.py @@ -1,4 +1,5 @@ # An implementation of
https://en.wikipedia.org/wiki/Rendezvous_hashing +from chromadb.errors import InvalidArgumentError from typing import Callable, List, Tuple import mmh3 import heapq @@ -23,16 +24,16 @@ def assign( """ if replication > len(members): - raise ValueError( + raise InvalidArgumentError( "Replication factor cannot be greater than the number of members" ) if len(members) == 0: - raise ValueError("Cannot assign key to empty memberlist") + raise InvalidArgumentError("Cannot assign key to empty memberlist") if len(members) == 1: # Don't copy the input list for some safety return [members[0]] if key == "": - raise ValueError("Cannot assign empty key") + raise InvalidArgumentError("Cannot assign empty key") member_score_heap: List[Tuple[int, Member]] = [] for member in members: diff --git a/clients/js/src/ChromaFetch.ts b/clients/js/src/ChromaFetch.ts index b7373f9c14f..32962287f11 100644 --- a/clients/js/src/ChromaFetch.ts +++ b/clients/js/src/ChromaFetch.ts @@ -15,9 +15,9 @@ import { FetchAPI } from "./generated"; function isOfflineError(error: any): boolean { return Boolean( (error?.name === "TypeError" || error?.name === "FetchError") && - (error.message?.includes("fetch failed") || - error.message?.includes("Failed to fetch") || - error.message?.includes("ENOTFOUND")), + (error.message?.includes("fetch failed") || + error.message?.includes("Failed to fetch") || + error.message?.includes("ENOTFOUND")), ); } diff --git a/docs/docs.trychroma.com/pages/reference/py-client.md b/docs/docs.trychroma.com/pages/reference/py-client.md index 2ab4bac6318..61ffa2317ab 100644 --- a/docs/docs.trychroma.com/pages/reference/py-client.md +++ b/docs/docs.trychroma.com/pages/reference/py-client.md @@ -194,7 +194,7 @@ Delete a collection with the given name. **Raises**: -- `ValueError` - If the collection does not exist. +- `InvalidArgumentError` - If the collection does not exist. **Examples**: @@ -311,8 +311,8 @@ Create a new collection with the given name and metadata. 
**Raises**: -- `ValueError` - If the collection already exists and get_or_create is False. -- `ValueError` - If the collection name is invalid. +- `InvalidArgumentError` - If the collection already exists and get_or_create is False. +- `InvalidArgumentError` - If the collection name is invalid. **Examples**: @@ -353,7 +353,7 @@ Get a collection with the given name. **Raises**: -- `ValueError` - If the collection does not exist +- `InvalidArgumentError` - If the collection does not exist **Examples**: diff --git a/docs/docs.trychroma.com/pages/reference/py-collection.md b/docs/docs.trychroma.com/pages/reference/py-collection.md index 3202689e504..87977d89169 100644 --- a/docs/docs.trychroma.com/pages/reference/py-collection.md +++ b/docs/docs.trychroma.com/pages/reference/py-collection.md @@ -46,9 +46,9 @@ Add embeddings to the data store. **Raises**: -- `ValueError` - If you don't provide either embeddings or documents -- `ValueError` - If the length of ids, embeddings, metadatas, or documents don't match -- `ValueError` - If you don't provide an embedding function and don't provide embeddings +- `InvalidArgumentError` - If you don't provide either embeddings or documents +- `InvalidArgumentError` - If the length of ids, embeddings, metadatas, or documents don't match +- `InvalidArgumentError` - If you don't provide an embedding function and don't provide embeddings - `DuplicateIDError` - If you provide an id that already exists # get @@ -128,8 +128,8 @@ Get the n_results nearest neighbor embeddings for provided query_embeddings or q **Raises**: -- `ValueError` - If you don't provide either query_embeddings or query_texts -- `ValueError` - If you provide both query_embeddings and query_texts +- `InvalidArgumentError` - If you don't provide either query_embeddings or query_texts +- `InvalidArgumentError` - If you provide both query_embeddings and query_texts # modify