Skip to content
Closed
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
4bd82f9
VoyageAI Embedding function
fodizoltan Mar 4, 2024
f095270
VoyageAI Embedding function, TS
fodizoltan Mar 7, 2024
8e4a553
VoyageAI Embedding function, TS
fodizoltan Mar 7, 2024
1199e70
VoyageAI Embedding function, TS
fodizoltan Mar 8, 2024
7a9abf0
Remove default value
fodizoltan Mar 8, 2024
71b4df8
Export the VoyageAI embedded function
fodizoltan Mar 8, 2024
cbebe89
Change Cohere 2 VoyageAI in comments
fodizoltan Mar 11, 2024
8630b7c
Change Cohere 2 VoyageAI in comments
fodizoltan Mar 12, 2024
298493f
Merge pull request #1 from voyage-ai/voyageai_embedding_function
Liuhong99 Mar 12, 2024
6f3050e
input_type is None
fodizoltan Mar 15, 2024
2c04b1c
input_type enum
fodizoltan Mar 19, 2024
e5320f8
Merge pull request #2 from voyage-ai/voyageai_embedding_function
Liuhong99 Mar 19, 2024
048d8d0
a commit 2 ping...
fzowl Mar 27, 2024
f712cb5
a commit 2 ping...
fzowl Mar 27, 2024
8c6fc6f
Merge pull request #3 from voyage-ai/voyageai_embedding_function
fzowl Mar 27, 2024
cb31e6b
Merge branch 'main' into main
fzowl Mar 27, 2024
3693ccf
Corrections due to the comments
fzowl Apr 10, 2024
acab2fa
Merge branch 'main' into voyageai_embedding_function
fzowl Apr 10, 2024
a7e5d55
Merge pull request #4 from voyage-ai/voyageai_embedding_function
fzowl Apr 11, 2024
31172dd
Corrections due to the comments
fzowl Apr 11, 2024
5b75e4c
Merge pull request #5 from voyage-ai/voyageai_embedding_function
fzowl Apr 11, 2024
cbeb89b
Corrections due to the comments: removing the loop, raising exception
fzowl Apr 11, 2024
aef5d92
Merge pull request #6 from voyage-ai/voyageai_embedding_function
fzowl Apr 11, 2024
c1fc7e4
Merge branch 'main' into main
fzowl Apr 11, 2024
85c78f5
Corrections due to the comments: removing the loop, raising exception
fzowl Apr 18, 2024
f143007
Merge branch 'main' into voyageai_embedding_function
fzowl Apr 18, 2024
008eb0e
Merge pull request #7 from voyage-ai/voyageai_embedding_function
fzowl Apr 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import hashlib
import logging
from enum import Enum
from functools import cached_property

from tenacity import stop_after_attempt, wait_random, retry, retry_if_exception
Expand Down Expand Up @@ -244,6 +245,67 @@ def __call__(self, input: Documents) -> Embeddings:
]


class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    Embedding function that obtains embeddings for a list of texts from the
    VoyageAI API through the `voyageai` Python client.
    """

    class InputType(Enum):
        """Values accepted for the `input_type` argument of the embed call."""

        DOCUMENT = "document"
        QUERY = "query"

    def __init__(
        self,
        api_key: str,
        model_name: str,
        embed_batch_size: Optional[int] = None,
        truncation: Optional[bool] = None,
        input_type: Optional[InputType] = None,
        show_progress_bar: Optional[bool] = False,
    ):
        """
        Initialize the VoyageAIEmbeddingFunction.

        Args:
            api_key (str): API key handed to `voyageai.Client`.
            model_name (str): Name of the VoyageAI model used for every request.
            embed_batch_size (int, optional): Number of texts sent per request.
                When omitted it defaults to 72 for "voyage-2"/"voyage-02" and
                to 7 for any other model.
            truncation (bool, optional): Forwarded unchanged to the client's
                `embed` call.
            input_type (InputType, optional): Whether the texts are documents
                or queries. `None` (the default) forwards no input type.
            show_progress_bar (bool, optional): Wrap the batch loop in a tqdm
                progress bar. Requires `tqdm` to be installed.

        Raises:
            ValueError: If the `voyageai` package is not installed, or if
                `embed_batch_size` is not a positive number.
        """
        try:
            import voyageai
        except ImportError:
            raise ValueError(
                "The voyageai python package is not installed. Please install it with `pip install -U voyageai`"
            )

        if embed_batch_size is None:
            embed_batch_size = 72 if model_name in ["voyage-2", "voyage-02"] else 7
        elif embed_batch_size <= 0:
            # A zero step makes range() raise an opaque error and a negative
            # step silently yields no batches, so reject both up front.
            raise ValueError("embed_batch_size must be a positive integer")

        self._client = voyageai.Client(api_key=api_key)
        self._model_name = model_name
        self._batch_size = embed_batch_size
        self._truncation = truncation
        self._show_progress_bar = show_progress_bar
        # input_type is optional: only unwrap the enum when one was supplied,
        # otherwise `.value` on None raises AttributeError.
        self._input_type = input_type.value if input_type is not None else None

    def __call__(self, input: Documents) -> Embeddings:
        """
        Embed `input` by calling the VoyageAI API once per batch.

        Args:
            input (Documents): The texts to embed.

        Returns:
            Embeddings: The embeddings returned by each batch request,
            concatenated in batch order.

        Raises:
            ImportError: If `show_progress_bar` is set but `tqdm` is missing.
        """
        embeddings: List[List[float]] = []

        batch_starts = range(0, len(input), self._batch_size)
        if self._show_progress_bar:
            try:
                from tqdm.auto import tqdm
            except ImportError as e:
                raise ImportError(
                    "Must have tqdm installed if `show_progress_bar` is set to True. "
                    "Please install with `pip install tqdm`."
                ) from e

            _iter = tqdm(batch_starts)
        else:
            _iter = batch_starts

        for i in _iter:
            embeddings.extend(
                self._client.embed(
                    input[i : i + self._batch_size],
                    model=self._model_name,
                    truncation=self._truncation,
                    input_type=self._input_type,
                ).embeddings
            )

        return embeddings


class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]):
"""
This class is used to get embeddings for a list of texts using the HuggingFace API.
Expand Down
87 changes: 87 additions & 0 deletions clients/js/src/embeddings/VoyageAIEmbeddingFunction.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import { IEmbeddingFunction } from "./IEmbeddingFunction";

/**
 * Kind of text being embedded. The string value is what
 * `VoyageAIEmbeddingFunction` places in the request's `input_type` field.
 */
export enum InputType {
  /** Serialized as "document". */
  DOCUMENT = "document",
  /** Serialized as "query". */
  QUERY = "query",
}

/**
 * Embedding function that calls the VoyageAI REST embeddings endpoint.
 *
 * Texts are sent in batches of `batch_size`. Each request also carries the
 * optional `truncation` flag and `input_type` (document or query), the same
 * options the Python embedding function forwards to its client.
 */
export class VoyageAIEmbeddingFunction implements IEmbeddingFunction {
  private model_name: string;
  private api_url: string;
  private batch_size: number;
  private truncation?: boolean;
  private input_type?: InputType;
  private headers: { [key: string]: string };

  /**
   * @param voyageai_api_key - API key sent as a Bearer token.
   * @param model_name - Model name sent with every request.
   * @param batch_size - Number of texts per request. When omitted (or 0) it
   *   defaults to 72 for "voyage-2"/"voyage-02" and to 7 for any other model.
   * @param truncation - Forwarded unchanged in the request body.
   * @param input_type - Whether the texts are documents or queries.
   * @throws Error if `batch_size` is negative or not an integer.
   */
  constructor({
    voyageai_api_key,
    model_name,
    batch_size,
    truncation,
    input_type,
  }: {
    voyageai_api_key: string;
    model_name: string;
    batch_size?: number;
    truncation?: boolean;
    input_type?: InputType;
  }) {
    this.api_url = "https://api.voyageai.com/v1/embeddings";
    this.headers = {
      Authorization: `Bearer ${voyageai_api_key}`,
      "Content-Type": "application/json",
    };

    this.model_name = model_name;
    this.truncation = truncation;
    this.input_type = input_type;

    if (batch_size) {
      // A negative size would make generate() loop forever and a fractional
      // one would resend overlapping slices, so reject both here.
      if (!Number.isInteger(batch_size) || batch_size < 0) {
        throw new Error(
          `batch_size must be a positive integer, got ${batch_size}`,
        );
      }
      this.batch_size = batch_size;
    } else {
      // `includes` checks the array's values; the `in` operator would check
      // its keys ("0", "1", "length") and never match a model name.
      this.batch_size = ["voyage-2", "voyage-02"].includes(model_name) ? 72 : 7;
    }
  }

  /**
   * Embeds `texts`, issuing one POST request per batch.
   *
   * @param texts - The texts to embed.
   * @returns One embedding per input text, in input order. Items within each
   *   response are ordered by their `index` field.
   * @throws Error prefixed with "Error calling VoyageAI API: " when a request
   *   fails or a response contains no `data` array.
   */
  public async generate(texts: string[]): Promise<number[][]> {
    try {
      const result: number[][] = [];

      for (let index = 0; index < texts.length; index += this.batch_size) {
        const response = await fetch(this.api_url, {
          method: "POST",
          headers: this.headers,
          body: JSON.stringify({
            input: texts.slice(index, index + this.batch_size),
            model: this.model_name,
            truncation: this.truncation,
            input_type: this.input_type,
          }),
        });

        const data = (await response.json()) as {
          data?: { index: number; embedding: number[] }[];
          detail?: string;
        };
        if (!data || !data.data) {
          // Fall back to the HTTP status so the error never reads "undefined".
          throw new Error(
            data?.detail ?? `unexpected response (HTTP ${response.status})`,
          );
        }

        // Sort a copy so the parsed response itself is left untouched.
        const ordered = [...data.data].sort((a, b) => a.index - b.index);
        for (const item of ordered) {
          result.push(item.embedding);
        }
      }
      return result;
    } catch (error) {
      if (error instanceof Error) {
        throw new Error(`Error calling VoyageAI API: ${error.message}`);
      } else {
        throw new Error(`Error calling VoyageAI API: ${error}`);
      }
    }
  }
}
17 changes: 9 additions & 8 deletions clients/js/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ export { AdminClient } from "./AdminClient";
export { CloudClient } from "./CloudClient";
export { Collection } from "./Collection";

export { IEmbeddingFunction } from "./embeddings/IEmbeddingFunction";
export { OpenAIEmbeddingFunction } from "./embeddings/OpenAIEmbeddingFunction";
export { CohereEmbeddingFunction } from "./embeddings/CohereEmbeddingFunction";
export { TransformersEmbeddingFunction } from "./embeddings/TransformersEmbeddingFunction";
export { DefaultEmbeddingFunction } from "./embeddings/DefaultEmbeddingFunction";
export { HuggingFaceEmbeddingServerFunction } from "./embeddings/HuggingFaceEmbeddingServerFunction";
export { JinaEmbeddingFunction } from "./embeddings/JinaEmbeddingFunction";
export { GoogleGenerativeAiEmbeddingFunction } from "./embeddings/GoogleGeminiEmbeddingFunction";
export { IEmbeddingFunction } from './embeddings/IEmbeddingFunction';
export { OpenAIEmbeddingFunction } from './embeddings/OpenAIEmbeddingFunction';
export { CohereEmbeddingFunction } from './embeddings/CohereEmbeddingFunction';
export { TransformersEmbeddingFunction } from './embeddings/TransformersEmbeddingFunction';
export { DefaultEmbeddingFunction } from './embeddings/DefaultEmbeddingFunction';
export { HuggingFaceEmbeddingServerFunction } from './embeddings/HuggingFaceEmbeddingServerFunction';
export { JinaEmbeddingFunction } from './embeddings/JinaEmbeddingFunction';
export { GoogleGenerativeAiEmbeddingFunction } from './embeddings/GoogleGeminiEmbeddingFunction';
export { VoyageAIEmbeddingFunction, InputType } from './embeddings/VoyageAIEmbeddingFunction';

export {
IncludeEnum,
Expand Down
31 changes: 29 additions & 2 deletions clients/js/test/add.collections.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ import chroma from "./initClient";
import { DOCUMENTS, EMBEDDINGS, IDS } from "./data";
import { METADATAS } from "./data";
import { IncludeEnum } from "../src/types";
import { OpenAIEmbeddingFunction } from "../src/embeddings/OpenAIEmbeddingFunction";
import { CohereEmbeddingFunction } from "../src/embeddings/CohereEmbeddingFunction";
import {OpenAIEmbeddingFunction} from "../src/embeddings/OpenAIEmbeddingFunction";
import {CohereEmbeddingFunction} from "../src/embeddings/CohereEmbeddingFunction";
import {VoyageAIEmbeddingFunction, InputType} from "../src/embeddings/VoyageAIEmbeddingFunction";

test("it should add single embeddings to a collection", async () => {
await chroma.reset();
const collection = await chroma.createCollection({ name: "test" });
Expand Down Expand Up @@ -82,6 +84,31 @@ if (!process.env.COHERE_API_KEY) {
});
}

// The VoyageAI round-trip only runs when VOYAGE_API_KEY is set; without a key
// the test is registered as skipped.
if (process.env.VOYAGE_API_KEY) {
  test("it should add VoyageAI embeddings", async () => {
    await chroma.reset();

    const voyageEmbedder = new VoyageAIEmbeddingFunction({
      voyageai_api_key: process.env.VOYAGE_API_KEY || "",
      model_name: "voyage-2",
      batch_size: 2,
      input_type: InputType.DOCUMENT,
    });
    const collection = await chroma.createCollection({
      name: "test",
      embeddingFunction: voyageEmbedder,
    });

    // Embed explicitly so the stored vectors can be compared afterwards.
    const vectors = await voyageEmbedder.generate(DOCUMENTS);
    await collection.add({ ids: IDS, embeddings: vectors });

    const count = await collection.count();
    expect(count).toBe(3);
    expect(vectors.length).toBe(3);
    for (const vector of vectors) {
      expect(vector.length).toBe(1024);
    }

    // The collection must return the generated embeddings, in id order.
    const stored = await collection.get({
      ids: IDS,
      include: [IncludeEnum.Embeddings],
    });
    expect(stored.embeddings).toEqual(vectors);
  });
} else {
  test.skip("it should add VoyageAI embeddings", async () => {});
}

test("add documents", async () => {
await chroma.reset();
const collection = await chroma.createCollection({ name: "test" });
Expand Down