@@ -743,9 +743,7 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:
743743
744744
745745class RoboflowEmbeddingFunction (EmbeddingFunction [Union [Documents , Images ]]):
746- def __init__ (
747- self , api_key : str = "" , api_url = "https://infer.roboflow.com"
748- ) -> None :
746+ def __init__ (self , api_key : str = "" , api_url = "https://infer.roboflow.com" ) -> None :
749747 """
750748 Create a RoboflowEmbeddingFunction.
751749
@@ -757,7 +755,7 @@ def __init__(
757755 api_key = os .environ .get ("ROBOFLOW_API_KEY" )
758756
759757 self ._api_url = api_url
760- self ._api_key = api_key
758+ self ._api_key = api_key
761759
762760 try :
763761 self ._PILImage = importlib .import_module ("PIL.Image" )
@@ -789,10 +787,10 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:
789787 json = infer_clip_payload ,
790788 )
791789
792- result = res .json ()[' embeddings' ]
790+ result = res .json ()[" embeddings" ]
793791
794792 embeddings .append (result [0 ])
795-
793+
796794 elif is_document (item ):
797795 infer_clip_payload = {
798796 "text" : input ,
@@ -803,13 +801,13 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:
803801 json = infer_clip_payload ,
804802 )
805803
806- result = res .json ()[' embeddings' ]
804+ result = res .json ()[" embeddings" ]
807805
808806 embeddings .append (result [0 ])
809807
810808 return embeddings
811809
812-
810+
813811class AmazonBedrockEmbeddingFunction (EmbeddingFunction [Documents ]):
814812 def __init__ (
815813 self ,
@@ -900,6 +898,22 @@ def __call__(self, input: Documents) -> Embeddings:
900898 )
901899
902900
901+ class UniversalSentenceEncoderEmbeddingFunction (EmbeddingFunction [Documents ]):
902+ def __init__ (
903+ self , model_name : str = "https://tfhub.dev/google/universal-sentence-encoder/4"
904+ ):
905+ try :
906+ import tensorflow_hub as hub
907+ except ImportError :
908+ raise ValueError (
909+ "The tensorflow_hub python package is not installed. Please install it with `pip install tensorflow_hub`"
910+ )
911+ self ._model = hub .load (model_name )
912+
913+ def __call__ (self , input : Documents ) -> Embeddings :
914+ return cast (Embeddings , self ._model (input ).numpy ().tolist ())
915+
916+
903917def create_langchain_embedding (langchain_embdding_fn : Any ): # type: ignore
904918 try :
905919 from langchain_core .embeddings import Embeddings as LangchainEmbeddings
@@ -962,7 +976,7 @@ def __call__(self, input: Documents) -> Embeddings: # type: ignore
962976
963977 return ChromaLangchainEmbeddingFunction (embedding_function = langchain_embdding_fn )
964978
965-
979+
966980class OllamaEmbeddingFunction (EmbeddingFunction [Documents ]):
967981 """
968982 This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings).
@@ -1018,7 +1032,7 @@ def __call__(self, input: Documents) -> Embeddings:
10181032 ],
10191033 )
10201034
1021-
1035+
10221036# List of all classes in this module
10231037_classes = [
10241038 name
0 commit comments