Skip to content

Commit e7730dc

Browse files
csbasilatroyn
authored andcommitted
feat: add universal sentence encoder embedding function
1 parent 193988d commit e7730dc

File tree

1 file changed

+24
-10
lines changed

1 file changed

+24
-10
lines changed

chromadb/utils/embedding_functions.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -743,9 +743,7 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:
743743

744744

745745
class RoboflowEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]):
746-
def __init__(
747-
self, api_key: str = "", api_url = "https://infer.roboflow.com"
748-
) -> None:
746+
def __init__(self, api_key: str = "", api_url="https://infer.roboflow.com") -> None:
749747
"""
750748
Create a RoboflowEmbeddingFunction.
751749
@@ -757,7 +755,7 @@ def __init__(
757755
api_key = os.environ.get("ROBOFLOW_API_KEY")
758756

759757
self._api_url = api_url
760-
self._api_key = api_key
758+
self._api_key = api_key
761759

762760
try:
763761
self._PILImage = importlib.import_module("PIL.Image")
@@ -789,10 +787,10 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:
789787
json=infer_clip_payload,
790788
)
791789

792-
result = res.json()['embeddings']
790+
result = res.json()["embeddings"]
793791

794792
embeddings.append(result[0])
795-
793+
796794
elif is_document(item):
797795
infer_clip_payload = {
798796
"text": input,
@@ -803,13 +801,13 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:
803801
json=infer_clip_payload,
804802
)
805803

806-
result = res.json()['embeddings']
804+
result = res.json()["embeddings"]
807805

808806
embeddings.append(result[0])
809807

810808
return embeddings
811809

812-
810+
813811
class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]):
814812
def __init__(
815813
self,
@@ -900,6 +898,22 @@ def __call__(self, input: Documents) -> Embeddings:
900898
)
901899

902900

901+
class UniversalSentenceEncoderEmbeddingFunction(EmbeddingFunction[Documents]):
902+
def __init__(
903+
self, model_name: str = "https://tfhub.dev/google/universal-sentence-encoder/4"
904+
):
905+
try:
906+
import tensorflow_hub as hub
907+
except ImportError:
908+
raise ValueError(
909+
"The tensorflow_hub python package is not installed. Please install it with `pip install tensorflow_hub`"
910+
)
911+
self._model = hub.load(model_name)
912+
913+
def __call__(self, input: Documents) -> Embeddings:
914+
return cast(Embeddings, self._model(input).numpy().tolist())
915+
916+
903917
def create_langchain_embedding(langchain_embdding_fn: Any): # type: ignore
904918
try:
905919
from langchain_core.embeddings import Embeddings as LangchainEmbeddings
@@ -962,7 +976,7 @@ def __call__(self, input: Documents) -> Embeddings: # type: ignore
962976

963977
return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn)
964978

965-
979+
966980
class OllamaEmbeddingFunction(EmbeddingFunction[Documents]):
967981
"""
968982
This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings).
@@ -1018,7 +1032,7 @@ def __call__(self, input: Documents) -> Embeddings:
10181032
],
10191033
)
10201034

1021-
1035+
10221036
# List of all classes in this module
10231037
_classes = [
10241038
name

0 commit comments

Comments
 (0)