Skip to content

Commit 3d898ed

Browse files
authored
Merge branch 'chroma-core:main' into main
2 parents 83b2308 + da68516 commit 3d898ed

File tree

8 files changed

+35
-19
lines changed

8 files changed

+35
-19
lines changed

chromadb/api/fastapi.py

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
import json
1+
import orjson as json
22
import logging
33
from typing import Optional, cast, Tuple
44
from typing import Sequence
@@ -147,7 +147,7 @@ def heartbeat(self) -> int:
147147
"""Returns the current server time in nanoseconds to check if the server is alive"""
148148
resp = self._session.get(self._api_url)
149149
raise_chroma_error(resp)
150-
return int(resp.json()["nanosecond heartbeat"])
150+
return int(json.loads(resp.text)["nanosecond heartbeat"])
151151

152152
@trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION)
153153
@override
@@ -177,7 +177,7 @@ def get_database(
177177
params={"tenant": tenant},
178178
)
179179
raise_chroma_error(resp)
180-
resp_json = resp.json()
180+
resp_json = json.loads(resp.text)
181181
return Database(
182182
id=resp_json["id"], name=resp_json["name"], tenant=resp_json["tenant"]
183183
)
@@ -198,7 +198,7 @@ def get_tenant(self, name: str) -> Tenant:
198198
self._api_url + "/tenants/" + name,
199199
)
200200
raise_chroma_error(resp)
201-
resp_json = resp.json()
201+
resp_json = json.loads(resp.text)
202202
return Tenant(name=resp_json["name"])
203203

204204
@trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
@@ -221,7 +221,7 @@ def list_collections(
221221
},
222222
)
223223
raise_chroma_error(resp)
224-
json_collections = resp.json()
224+
json_collections = json.loads(resp.text)
225225
collections = []
226226
for json_collection in json_collections:
227227
collections.append(Collection(self, **json_collection))
@@ -239,7 +239,7 @@ def count_collections(
239239
params={"tenant": tenant, "database": database},
240240
)
241241
raise_chroma_error(resp)
242-
return cast(int, resp.json())
242+
return cast(int, json.loads(resp.text))
243243

244244
@trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION)
245245
@override
@@ -268,7 +268,7 @@ def create_collection(
268268
params={"tenant": tenant, "database": database},
269269
)
270270
raise_chroma_error(resp)
271-
resp_json = resp.json()
271+
resp_json = json.loads(resp.text)
272272
return Collection(
273273
client=self,
274274
id=resp_json["id"],
@@ -302,7 +302,7 @@ def get_collection(
302302
self._api_url + "/collections/" + name if name else str(id), params=_params
303303
)
304304
raise_chroma_error(resp)
305-
resp_json = resp.json()
305+
resp_json = json.loads(resp.text)
306306
return Collection(
307307
client=self,
308308
name=resp_json["name"],
@@ -381,7 +381,7 @@ def _count(
381381
self._api_url + "/collections/" + str(collection_id) + "/count"
382382
)
383383
raise_chroma_error(resp)
384-
return cast(int, resp.json())
384+
return cast(int, json.loads(resp.text))
385385

386386
@trace_method("FastAPI._peek", OpenTelemetryGranularity.OPERATION)
387387
@override
@@ -434,7 +434,7 @@ def _get(
434434
)
435435

436436
raise_chroma_error(resp)
437-
body = resp.json()
437+
body = json.loads(resp.text)
438438
return GetResult(
439439
ids=body["ids"],
440440
embeddings=body.get("embeddings", None),
@@ -462,7 +462,7 @@ def _delete(
462462
)
463463

464464
raise_chroma_error(resp)
465-
return cast(IDs, resp.json())
465+
return cast(IDs, json.loads(resp.text))
466466

467467
@trace_method("FastAPI._submit_batch", OpenTelemetryGranularity.ALL)
468468
def _submit_batch(
@@ -586,7 +586,7 @@ def _query(
586586
)
587587

588588
raise_chroma_error(resp)
589-
body = resp.json()
589+
body = json.loads(resp.text)
590590

591591
return QueryResult(
592592
ids=body["ids"],
@@ -604,15 +604,15 @@ def reset(self) -> bool:
604604
"""Resets the database"""
605605
resp = self._session.post(self._api_url + "/reset")
606606
raise_chroma_error(resp)
607-
return cast(bool, resp.json())
607+
return cast(bool, json.loads(resp.text))
608608

609609
@trace_method("FastAPI.get_version", OpenTelemetryGranularity.OPERATION)
610610
@override
611611
def get_version(self) -> str:
612612
"""Returns the version of the server"""
613613
resp = self._session.get(self._api_url + "/version")
614614
raise_chroma_error(resp)
615-
return cast(str, resp.json())
615+
return cast(str, json.loads(resp.text))
616616

617617
@override
618618
def get_settings(self) -> Settings:
@@ -626,7 +626,7 @@ def max_batch_size(self) -> int:
626626
if self._max_batch_size == -1:
627627
resp = self._session.get(self._api_url + "/pre-flight-checks")
628628
raise_chroma_error(resp)
629-
self._max_batch_size = cast(int, resp.json()["max_batch_size"])
629+
self._max_batch_size = cast(int, json.loads(resp.text)["max_batch_size"])
630630
return self._max_batch_size
631631

632632

@@ -637,7 +637,7 @@ def raise_chroma_error(resp: requests.Response) -> None:
637637

638638
chroma_error = None
639639
try:
640-
body = resp.json()
640+
body = json.loads(resp.text)
641641
if "error" in body:
642642
if body["error"] in errors.error_types:
643643
chroma_error = errors.error_types[body["error"]](body["message"])

chromadb/api/types.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020

2121
# Re-export types from chromadb.types
2222
__all__ = ["Metadata", "Where", "WhereDocument", "UpdateCollectionMetadata"]
23-
23+
META_KEY_CHROMA_DOCUMENT = "chroma:document"
2424
T = TypeVar("T")
2525
OneOrMany = Union[T, List[T]]
2626

@@ -265,6 +265,10 @@ def validate_metadata(metadata: Metadata) -> Metadata:
265265
if len(metadata) == 0:
266266
raise ValueError(f"Expected metadata to be a non-empty dict, got {metadata}")
267267
for key, value in metadata.items():
268+
if key == META_KEY_CHROMA_DOCUMENT:
269+
raise ValueError(
270+
f"Expected metadata to not contain the reserved key {META_KEY_CHROMA_DOCUMENT}"
271+
)
268272
if not isinstance(key, str):
269273
raise TypeError(
270274
f"Expected metadata key to be a str, got {key} which is a {type(key)}"
@@ -476,7 +480,7 @@ def validate_embeddings(embeddings: Embeddings) -> Embeddings:
476480
raise ValueError(
477481
f"Expected each embedding in the embeddings to be a list, got {embeddings}"
478482
)
479-
for i,embedding in enumerate(embeddings):
483+
for i, embedding in enumerate(embeddings):
480484
if len(embedding) == 0:
481485
raise ValueError(
482486
f"Expected each embedding in the embeddings to be a non-empty list, got empty embedding at pos {i}"

chromadb/test/segment/test_metadata.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@
33
import tempfile
44
import pytest
55
from typing import Generator, List, Callable, Iterator, Dict, Optional, Union, Sequence
6+
7+
from chromadb.api.types import validate_metadata
68
from chromadb.config import System, Settings
79
from chromadb.db.base import ParameterValue, get_sql
810
from chromadb.db.impl.sqlite import SqliteDB
@@ -677,3 +679,10 @@ def test_delete_segment(
677679
res = cur.execute(sql, params)
678680
# assert that all FTS rows are gone
679681
assert len(res.fetchall()) == 0
682+
683+
684+
def test_metadata_validation_forbidden_key() -> None:
685+
with pytest.raises(ValueError, match="chroma:document"):
686+
validate_metadata(
687+
{"chroma:document": "this is not the document you are looking for"}
688+
)

clients/python/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@ dependencies = [
2626
'typing_extensions >= 4.5.0',
2727
'tenacity>=8.2.3',
2828
'PyYAML>=6.0.0',
29+
'orjson>=3.9.12',
2930
]
3031

3132
[tool.black]

clients/python/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,3 +9,4 @@ PyYAML>=6.0.0
99
requests >= 2.28
1010
tenacity>=8.2.3
1111
typing_extensions >= 4.5.0
12+
orjson>=3.9.12

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@ dependencies = [
4343
'tenacity>=8.2.3',
4444
'PyYAML>=6.0.0',
4545
'mmh3>=4.0.1',
46+
'orjson>=3.9.12',
4647
]
4748

4849
[tool.black]

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,3 +25,4 @@ tqdm>=4.65.0
2525
typer>=0.9.0
2626
typing_extensions>=4.5.0
2727
uvicorn[standard]==0.18.3
28+
orjson>=3.9.12

server.htpasswd

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)