Skip to content

Commit 091f8bd

Browse files
authored
[CHORE] Disallow invalid fields in pydantic schema (#5838)
## Description of changes _Summarize the changes made by this PR._ - Improvements & Bug fixes - Fixes #5825 - Disallows invalid fields in schema types - New functionality - ... ## Test plan added tests to ensure incorrect keys are caught _How are these changes tested?_ - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Migration plan _Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_ ## Observability plan _What is the plan to instrument and monitor this change?_ ## Documentation Changes _Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent cc7cd4b commit 091f8bd

File tree

2 files changed

+1066
-263
lines changed

2 files changed

+1066
-263
lines changed

chromadb/api/types.py

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
get_args,
1313
TYPE_CHECKING,
1414
Final,
15+
Type,
1516
)
1617
from copy import deepcopy
1718
from typing_extensions import TypeAlias
@@ -20,7 +21,8 @@
2021
import numpy as np
2122
import warnings
2223
from typing_extensions import TypedDict, Protocol, runtime_checkable
23-
from pydantic import BaseModel, field_validator
24+
from pydantic import BaseModel, field_validator, model_validator
25+
from pydantic_core import PydanticCustomError
2426

2527
import chromadb.errors as errors
2628
from chromadb.base_types import (
@@ -1493,15 +1495,57 @@ def validate_sparse_embedding_function(
14931495

14941496

14951497
# Index Configuration Types for Collection Schema
1498+
def _create_extra_fields_validator(valid_fields: list[str]) -> Any:
1499+
"""Create a model validator that provides helpful error messages for invalid fields."""
1500+
1501+
@model_validator(mode="before")
1502+
def validate_extra_fields(cls: Type[BaseModel], data: Any) -> Any:
1503+
if isinstance(data, dict):
1504+
invalid_fields = [k for k in data.keys() if k not in valid_fields]
1505+
if invalid_fields:
1506+
invalid_fields_str = ", ".join(f"'{f}'" for f in invalid_fields)
1507+
class_name = cls.__name__
1508+
# Create a clear, actionable error message
1509+
if len(invalid_fields) == 1:
1510+
msg = (
1511+
f"'{invalid_fields[0]}' is not a valid field for {class_name}. "
1512+
)
1513+
else:
1514+
msg = f"Invalid fields for {class_name}: {invalid_fields_str}. "
1515+
1516+
raise PydanticCustomError(
1517+
"invalid_field",
1518+
msg,
1519+
{"invalid_fields": invalid_fields},
1520+
)
1521+
return data
1522+
1523+
return validate_extra_fields
1524+
1525+
14961526
class FtsIndexConfig(BaseModel):
14971527
"""Configuration for Full-Text Search index. No parameters required."""
14981528

1529+
model_config = {"extra": "forbid"}
1530+
14991531
pass
15001532

15011533

15021534
class HnswIndexConfig(BaseModel):
15031535
"""Configuration for HNSW vector index."""
15041536

1537+
_validate_extra_fields = _create_extra_fields_validator(
1538+
[
1539+
"ef_construction",
1540+
"max_neighbors",
1541+
"ef_search",
1542+
"num_threads",
1543+
"batch_size",
1544+
"sync_threshold",
1545+
"resize_factor",
1546+
]
1547+
)
1548+
15051549
ef_construction: Optional[int] = None
15061550
max_neighbors: Optional[int] = None
15071551
ef_search: Optional[int] = None
@@ -1514,6 +1558,27 @@ class HnswIndexConfig(BaseModel):
15141558
class SpannIndexConfig(BaseModel):
15151559
"""Configuration for SPANN vector index."""
15161560

1561+
_validate_extra_fields = _create_extra_fields_validator(
1562+
[
1563+
"search_nprobe",
1564+
"search_rng_factor",
1565+
"search_rng_epsilon",
1566+
"nreplica_count",
1567+
"write_nprobe",
1568+
"write_rng_factor",
1569+
"write_rng_epsilon",
1570+
"split_threshold",
1571+
"num_samples_kmeans",
1572+
"initial_lambda",
1573+
"reassign_neighbor_count",
1574+
"merge_threshold",
1575+
"num_centers_to_merge_to",
1576+
"ef_construction",
1577+
"ef_search",
1578+
"max_neighbors",
1579+
]
1580+
)
1581+
15171582
search_nprobe: Optional[int] = None
15181583
write_nprobe: Optional[int] = None
15191584
ef_construction: Optional[int] = None
@@ -1527,7 +1592,8 @@ class SpannIndexConfig(BaseModel):
15271592
class VectorIndexConfig(BaseModel):
15281593
"""Configuration for vector index with space, embedding function, and algorithm config."""
15291594

1530-
model_config = {"arbitrary_types_allowed": True}
1595+
model_config = {"arbitrary_types_allowed": True, "extra": "forbid"}
1596+
15311597
space: Optional[Space] = None
15321598
embedding_function: Optional[Any] = DefaultEmbeddingFunction()
15331599
source_key: Optional[
@@ -1577,7 +1643,8 @@ def validate_embedding_function_field(cls, v: Any) -> Any:
15771643
class SparseVectorIndexConfig(BaseModel):
15781644
"""Configuration for sparse vector index."""
15791645

1580-
model_config = {"arbitrary_types_allowed": True}
1646+
model_config = {"arbitrary_types_allowed": True, "extra": "forbid"}
1647+
15811648
# TODO(Sanket): Change this to the appropriate sparse ef and use a default here.
15821649
embedding_function: Optional[Any] = None
15831650
source_key: Optional[
@@ -1628,24 +1695,32 @@ def validate_embedding_function_field(cls, v: Any) -> Any:
16281695
class StringInvertedIndexConfig(BaseModel):
16291696
"""Configuration for string inverted index."""
16301697

1698+
model_config = {"extra": "forbid"}
1699+
16311700
pass
16321701

16331702

16341703
class IntInvertedIndexConfig(BaseModel):
16351704
"""Configuration for integer inverted index."""
16361705

1706+
model_config = {"extra": "forbid"}
1707+
16371708
pass
16381709

16391710

16401711
class FloatInvertedIndexConfig(BaseModel):
16411712
"""Configuration for float inverted index."""
16421713

1714+
model_config = {"extra": "forbid"}
1715+
16431716
pass
16441717

16451718

16461719
class BoolInvertedIndexConfig(BaseModel):
16471720
"""Configuration for boolean inverted index."""
16481721

1722+
model_config = {"extra": "forbid"}
1723+
16491724
pass
16501725

16511726

0 commit comments

Comments
 (0)