Skip to content

Commit 7da9620

Browse files
mihirchdev and atroyn committed
[ENH] Support langchain embedding functions with chroma (#1880)
*Summarize the changes made by this PR.* - New functionality - Adding a function to create a chroma langchain embedding interface. This interface acts as a bridge between the langchain embedding function and the chroma custom embedding function. - Native Langchain multimodal support: The PR adds a Passthrough data loader that lets langchain users use OpenClip and other multi-modal embedding functions from langchain with chroma without having to handle storing images themselves. *How are these changes tested?* - installing chroma as an editable package locally and passing langchain integration tests - pytest test_api.py test_client.py succeeds *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?* Co-authored-by: Anton Troynikov <[email protected]>
1 parent 70ed520 commit 7da9620

File tree

2 files changed

+75
-2
lines changed

2 files changed

+75
-2
lines changed

chromadb/utils/data_loaders.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import importlib
22
import multiprocessing
3-
from typing import Optional, Sequence, List
3+
from typing import Optional, Sequence, List, Tuple
44
import numpy as np
5-
from chromadb.api.types import URI, DataLoader, Image
5+
from chromadb.api.types import URI, DataLoader, Image, URIs
66
from concurrent.futures import ThreadPoolExecutor
77

88

@@ -22,3 +22,10 @@ def _load_image(self, uri: Optional[URI]) -> Optional[Image]:
2222
def __call__(self, uris: Sequence[Optional[URI]]) -> List[Optional[Image]]:
    """Load each URI concurrently and return images in the input order.

    A ``None`` entry in *uris* yields a ``None`` in the result
    (``_load_image`` is presumed to pass it through — confirm in caller).
    """
    with ThreadPoolExecutor(max_workers=self._max_workers) as pool:
        loaded = pool.map(self._load_image, uris)
        # Materialize inside the `with` so the pool is still alive
        # while the lazy map iterator is consumed.
        return list(loaded)
25+
26+
27+
class ChromaLangchainPassthroughDataLoader(DataLoader[List[Optional[Image]]]):
    """Pass-through data loader for langchain multimodal embeddings.

    Rather than fetching image bytes, this loader hands the URIs straight
    back, tagged with an ``"images"`` marker so the langchain embedding
    bridge knows the payload is a list of image URIs (not documents).
    """

    def __call__(self, uris: URIs) -> Tuple[str, URIs]:  # type: ignore
        return ("images", uris)

chromadb/utils/embedding_functions.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,71 @@ def __init__(
914914

915915
def __call__(self, input: Documents) -> Embeddings:
    """Embed *input* with the wrapped model, returning plain Python lists."""
    raw = self._model(input)
    # Model output exposes .numpy() (tensor-like); convert to nested lists.
    return cast(Embeddings, raw.numpy().tolist())
917+
918+
919+
def create_langchain_embedding(langchain_embdding_fn: Any) -> Any:  # type: ignore
    """Wrap a langchain embedding function as a chroma embedding function.

    Args:
        langchain_embdding_fn: An object implementing the
            ``langchain_core.embeddings.Embeddings`` interface.
            NOTE(review): the parameter name keeps its historical
            (misspelled) form for backward compatibility with keyword
            callers.

    Returns:
        A ``ChromaLangchainEmbeddingFunction`` instance bridging the
        langchain embedding function into chroma's ``EmbeddingFunction``
        protocol.

    Raises:
        ValueError: If the ``langchain-core`` package is not installed.
    """
    try:
        from langchain_core.embeddings import Embeddings as LangchainEmbeddings
    except ImportError as e:
        # Chain the original ImportError so the root cause stays visible.
        raise ValueError(
            "The langchain_core python package is not installed. Please install it with `pip install langchain-core`"
        ) from e

    class ChromaLangchainEmbeddingFunction(
        LangchainEmbeddings, EmbeddingFunction[Union[Documents, Images]]  # type: ignore
    ):
        """
        This class is used as bridge between langchain embedding functions and custom chroma embedding functions.
        """

        def __init__(self, embedding_function: LangchainEmbeddings) -> None:
            """
            Initialize the ChromaLangchainEmbeddingFunction

            Args:
                embedding_function : The embedding function implementing Embeddings from langchain_core.
            """
            self.embedding_function = embedding_function

        def embed_documents(self, documents: Documents) -> List[List[float]]:
            # Delegate directly to the wrapped langchain implementation.
            return self.embedding_function.embed_documents(documents)  # type: ignore

        def embed_query(self, query: str) -> List[float]:
            return self.embedding_function.embed_query(query)  # type: ignore

        def embed_image(self, uris: List[str]) -> List[List[float]]:
            # Not every langchain embedding is multimodal; probe for support
            # instead of assuming the method exists.
            if hasattr(self.embedding_function, "embed_image"):
                return self.embedding_function.embed_image(uris)  # type: ignore
            raise ValueError(
                "The provided embedding function does not support image embeddings."
            )

        def __call__(self, input: Documents) -> Embeddings:  # type: ignore
            """
            Get the embeddings for a list of texts or images.

            Args:
                input (Documents | Images): A list of texts or images to get embeddings for.
                    Images should be provided as a list of URIs passed through the langchain data loader.

            Returns:
                Embeddings: The embeddings for the texts or images.

            Example:
                >>> langchain_embedding = create_langchain_embedding(OpenAIEmbeddings(model="text-embedding-3-large"))
                >>> texts = ["Hello, world!", "How are you?"]
                >>> embeddings = langchain_embedding(texts)
            """
            # Due to langchain quirks, the pass-through data loader returns a
            # ("images", uris) tuple for image input, so a leading "images"
            # marker signals that the payload is a list of image URIs.
            if input[0] == "images":
                return self.embed_image(list(input[1]))  # type: ignore

            return self.embed_documents(list(input))  # type: ignore

    return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn)
980+
981+
917982
class OllamaEmbeddingFunction(EmbeddingFunction[Documents]):
918983
"""
919984
This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings).
@@ -969,6 +1034,7 @@ def __call__(self, input: Documents) -> Embeddings:
9691034
],
9701035
)
9711036

1037+
9721038
# List of all classes in this module
9731039
_classes = [
9741040
name

0 commit comments

Comments
 (0)