-
Notifications
You must be signed in to change notification settings - Fork 1.9k
[ENH] VoyageAI embedding function #1871
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
4bd82f9
f095270
8e4a553
1199e70
7a9abf0
71b4df8
cbebe89
8630b7c
298493f
6f3050e
2c04b1c
e5320f8
048d8d0
f712cb5
8c6fc6f
cb31e6b
3693ccf
acab2fa
a7e5d55
31172dd
5b75e4c
cbeb89b
aef5d92
c1fc7e4
85c78f5
f143007
008eb0e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -235,6 +235,61 @@ def __call__(self, input: Documents) -> Embeddings: | |
| ] | ||
|
|
||
|
|
||
class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
    """Embedding function that obtains embeddings from the VoyageAI client.

    Input documents are sent to the client in consecutive batches and the
    returned embeddings are concatenated in the same order.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str,
        embed_batch_size: Optional[int] = None,
        truncation: Optional[bool] = None,
        show_progress_bar: Optional[bool] = False,
    ):
        """
        Args:
            api_key: Key handed to ``voyageai.Client``.
            model_name: Model identifier forwarded with every embed call.
            embed_batch_size: Number of documents per embed call. When omitted,
                72 is used for "voyage-2" / "voyage-02" and 7 for any other model.
            truncation: Forwarded unchanged as the ``truncation`` argument of
                the client's ``embed`` call.
            show_progress_bar: When truthy, batches are iterated through a
                ``tqdm`` progress bar (requires ``tqdm``).

        Raises:
            ValueError: If the ``voyageai`` package cannot be imported.
        """
        try:
            import voyageai
        except ImportError:
            raise ValueError(
                "The voyageai python package is not installed. Please install it with `pip install -U voyageai`"
            )

        # Pick the per-model default only when the caller gave no batch size.
        if embed_batch_size is None:
            if model_name in ["voyage-2", "voyage-02"]:
                embed_batch_size = 72
            else:
                embed_batch_size = 7

        self._client = voyageai.Client(api_key=api_key)
        self._model_name = model_name
        self._batch_size = embed_batch_size
        self._truncation = truncation
        self._show_progress_bar = show_progress_bar

    def __call__(self, input: Documents) -> Embeddings:
        """Embed ``input`` batch by batch and return all embeddings in order.

        Raises:
            ImportError: If ``show_progress_bar`` is set but ``tqdm`` is missing.
        """
        if self._show_progress_bar:
            try:
                from tqdm.auto import tqdm
            except ImportError as e:
                raise ImportError(
                    "Must have tqdm installed if `show_progress_bar` is set to True. "
                    "Please install with `pip install tqdm`."
                ) from e

        # Start offsets of each batch; wrapped in tqdm only when requested.
        batch_starts = range(0, len(input), self._batch_size)
        if self._show_progress_bar:
            batch_starts = tqdm(batch_starts)

        collected: List[List[float]] = []
        for start in batch_starts:
            response = self._client.embed(
                input[start : start + self._batch_size],
                model=self._model_name,
                input_type="document",
                truncation=self._truncation,
            )
            collected.extend(response.embeddings)

        return collected
|
|
||
|
|
||
| class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]): | ||
| """ | ||
| This class is used to get embeddings for a list of texts using the HuggingFace API. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| import { IEmbeddingFunction } from "./IEmbeddingFunction"; | ||
|
|
||
/**
 * Embedding function that calls the VoyageAI REST embeddings endpoint.
 *
 * Texts are posted in consecutive batches of `batch_size` items and the
 * resulting embeddings are returned in the same order as the input texts.
 */
export class VoyageAIEmbeddingFunction implements IEmbeddingFunction {
  private model_name: string;
  private api_url: string;
  private batch_size: number;
  private truncation?: boolean;
  private headers: { [key: string]: string };

  /**
   * @param voyageai_api_key - Key sent as the `Authorization: Bearer` header.
   * @param model_name - Model identifier sent with every request.
   * @param batch_size - Number of texts per request. When omitted (or 0) the
   *   default is 72 for "voyage-2" / "voyage-02" and 7 for any other model.
   * @param truncation - Sent unchanged as the `truncation` field of the
   *   request body (omitted from the JSON when undefined).
   * @throws Error if `batch_size` is negative.
   */
  constructor({
    voyageai_api_key,
    model_name,
    batch_size,
    truncation,
  }: {
    voyageai_api_key: string;
    model_name: string;
    batch_size?: number;
    truncation?: boolean;
  }) {
    this.api_url = "https://api.voyageai.com/v1/embeddings";
    this.headers = {
      Authorization: `Bearer ${voyageai_api_key}`,
      "Content-Type": "application/json",
    };

    this.model_name = model_name;
    this.truncation = truncation;

    // A negative step would make the batching loop in generate() never end.
    if (batch_size !== undefined && batch_size < 0) {
      throw new Error(`batch_size must be a positive number, got ${batch_size}`);
    }

    if (batch_size) {
      this.batch_size = batch_size;
    } else {
      // `includes` tests membership; the `in` operator would test array
      // index keys ("0", "1") and never match a model name.
      this.batch_size = ["voyage-2", "voyage-02"].includes(model_name) ? 72 : 7;
    }
  }

  /**
   * Embeds `texts`, one request per batch, sequentially.
   *
   * @param texts - Texts to embed.
   * @returns One embedding per input text, in input order.
   * @throws Error prefixed with "Error calling VoyageAI API: " when a request
   *   fails or the response carries no `data` field.
   */
  public async generate(texts: string[]): Promise<number[][]> {
    try {
      const result: number[][] = [];

      for (let index = 0; index < texts.length; index += this.batch_size) {
        const response = await fetch(this.api_url, {
          method: "POST",
          headers: this.headers,
          body: JSON.stringify({
            input: texts.slice(index, index + this.batch_size),
            model: this.model_name,
            truncation: this.truncation,
          }),
        });

        const data = (await response.json()) as {
          data?: { index: number; embedding: number[] }[];
          detail?: string;
        } | null;
        if (!data || !data.data) {
          throw new Error(
            data?.detail ?? `request failed with HTTP status ${response.status}`
          );
        }

        // Order by the `index` field of each item so the embeddings line up
        // with the texts of this batch regardless of response ordering.
        const ordered = [...data.data].sort((a, b) => a.index - b.index);
        for (const item of ordered) {
          result.push(item.embedding);
        }
      }
      return result;
    } catch (error) {
      if (error instanceof Error) {
        throw new Error(`Error calling VoyageAI API: ${error.message}`);
      } else {
        throw new Error(`Error calling VoyageAI API: ${error}`);
      }
    }
  }
}
Uh oh!
There was an error while loading. Please reload this page.