Skip to content
Closed
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,59 @@ def __call__(self, texts: Documents) -> Embeddings:
return embeddings


class VoyageAIEmbeddingFunction(EmbeddingFunction):
    """Embedding function that delegates to `voyageai.get_embeddings`.

    Documents are sent to the Voyage AI client in slices of at most
    `batch_size` items and the per-slice results are concatenated in order.
    """

    def __init__(
        self, api_key: str, model_name: str = "voyage-01", batch_size: int = 8
    ):
        """
        Initialize the VoyageAIEmbeddingFunction.

        Args:
            api_key (str): Your API key for the Voyage AI API.
            model_name (str, optional): The name of the model to use for text
                embeddings. Defaults to "voyage-01".
            batch_size (int, optional): The number of documents to send at a
                time. Defaults to 8 (the max supported as of 3rd Nov 2023).

        Raises:
            ValueError: If `batch_size` is smaller than 1, if `api_key` is
                empty, or if the `voyageai` package is not installed.
        """
        # A zero step would make range() fail at call time, and a negative one
        # would silently produce no embeddings at all, so reject both up front.
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

        # Larger batches are only warned about, not rejected.
        if batch_size > 8:
            print("Voyage AI as of (3rd Nov 2023) has a batch size of max 8")

        if not api_key:
            raise ValueError("Please provide a VoyageAI API key.")

        try:
            import voyageai
            from voyageai import get_embeddings
        except ImportError as err:
            # ValueError is kept as the raised type for existing callers; the
            # original ImportError is attached as the explicit cause.
            raise ValueError(
                "The VoyageAI python package is not installed. "
                "Please install it with `pip install voyageai`"
            ) from err

        # The key is stored on the voyageai module itself, not on this instance.
        voyageai.api_key = api_key
        self.batch_size = batch_size
        self.model = model_name
        self.get_embeddings = get_embeddings

    def __call__(self, texts: Documents) -> Embeddings:
        """
        Get the embeddings for a list of texts.

        Args:
            texts (Documents): A list of texts to get embeddings for.

        Returns:
            Embeddings: The embeddings for the texts, in the order the
                batches were sent. Empty when `texts` is empty.

        Example:
            >>> voyage_ef = VoyageAIEmbeddingFunction(api_key="your_api_key")
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = voyage_ef(texts)
        """
        embeddings: Embeddings = []
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start : start + self.batch_size]
            embeddings.extend(
                self.get_embeddings(
                    batch,
                    batch_size=self.batch_size,
                    model=self.model,
                )
            )
        return embeddings

# List of all classes in this module
_classes = [
name
Expand Down