Skip to content

Commit

Permalink
Merge branch 'main' into renovate/langchain-google-vertexai-2.x
Browse files Browse the repository at this point in the history
  • Loading branch information
vishwarajanand authored Oct 23, 2024
2 parents aa47841 + 247b9cf commit 80c187a
Show file tree
Hide file tree
Showing 11 changed files with 1,705 additions and 6 deletions.
593 changes: 593 additions & 0 deletions docs/model_endpoint_management.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ test = [
"pytest-asyncio==0.24.0",
"pytest==8.3.3",
"pytest-cov==5.0.0",
"pytest-depends==1.0.1",
"Pillow==11.0.0"
]

Expand Down
5 changes: 5 additions & 0 deletions src/langchain_google_alloydb_pg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.

from .chat_message_history import AlloyDBChatMessageHistory
from .embeddings import AlloyDBEmbeddings
from .engine import AlloyDBEngine, Column
from .loader import AlloyDBDocumentSaver, AlloyDBLoader
from .model_manager import AlloyDBModel, AlloyDBModelManager
from .vectorstore import AlloyDBVectorStore
from .version import __version__

Expand All @@ -25,5 +27,8 @@
"AlloyDBLoader",
"AlloyDBDocumentSaver",
"AlloyDBChatMessageHistory",
"AlloyDBEmbeddings",
"AlloyDBModelManager",
"AlloyDBModel",
"__version__",
]
38 changes: 32 additions & 6 deletions src/langchain_google_alloydb_pg/async_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import json
import re
import uuid
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type, Union
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type

import numpy as np
import requests
Expand All @@ -30,6 +30,7 @@
from sqlalchemy import RowMapping, text
from sqlalchemy.ext.asyncio import AsyncEngine

from .embeddings import AlloyDBEmbeddings
from .engine import AlloyDBEngine
from .indexes import (
DEFAULT_DISTANCE_STRATEGY,
Expand Down Expand Up @@ -248,6 +249,8 @@ async def aadd_embeddings(
insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"({self.id_column}, {self.content_column}, {self.embedding_column}{metadata_col_names}'
values = {"id": id, "content": content, "embedding": str(embedding)}
values_stmt = "VALUES (:id, :content, :embedding"
if not embedding and isinstance(self.embedding_service, AlloyDBEmbeddings):
values_stmt = f"VALUES (:id, :content, {self.embedding_service.embed_query_inline(content)}"

# Add metadata
extra = metadata
Expand Down Expand Up @@ -288,7 +291,11 @@ async def aadd_texts(
Raises:
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
"""
embeddings = self.embedding_service.embed_documents(list(texts))
if isinstance(self.embedding_service, AlloyDBEmbeddings):
embeddings: List[List[float]] = [[] for _ in list(texts)]
else:
embeddings = await self.embedding_service.aembed_documents(list(texts))

ids = await self.aadd_embeddings(
texts, embeddings, metadatas=metadatas, ids=ids, **kwargs
)
Expand Down Expand Up @@ -535,7 +542,15 @@ async def __query_collection(
search_function = self.distance_strategy.search_function

filter = f"WHERE {filter}" if filter else ""
stmt = f"SELECT *, {search_function}({self.embedding_column}, '{embedding}') as distance FROM \"{self.schema_name}\".\"{self.table_name}\" {filter} ORDER BY {self.embedding_column} {operator} '{embedding}' LIMIT {k};"
if (
not embedding
and isinstance(self.embedding_service, AlloyDBEmbeddings)
and "query" in kwargs
):
query_embedding = self.embedding_service.embed_query_inline(kwargs["query"])
else:
query_embedding = f"'{embedding}'"
stmt = f'SELECT *, {search_function}({self.embedding_column}, {query_embedding}) as distance FROM "{self.schema_name}"."{self.table_name}" {filter} ORDER BY {self.embedding_column} {operator} {query_embedding} LIMIT {k};'
if self.index_query_options:
query_options_stmt = f"SET LOCAL {self.index_query_options.to_string()};"
async with self.engine.connect() as conn:
Expand All @@ -558,7 +573,12 @@ async def asimilarity_search(
**kwargs: Any,
) -> List[Document]:
"""Return docs selected by similarity search on query."""
embedding = self.embedding_service.embed_query(text=query)
embedding = (
[]
if isinstance(self.embedding_service, AlloyDBEmbeddings)
else await self.embedding_service.aembed_query(text=query)
)
kwargs["query"] = query

return await self.asimilarity_search_by_vector(
embedding=embedding, k=k, filter=filter, **kwargs
Expand Down Expand Up @@ -619,7 +639,13 @@ async def asimilarity_search_with_score(
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs and distance scores selected by similarity search on query."""
embedding = self.embedding_service.embed_query(query)
embedding = (
[]
if isinstance(self.embedding_service, AlloyDBEmbeddings)
else await self.embedding_service.aembed_query(text=query)
)
kwargs["query"] = query

docs = await self.asimilarity_search_with_score_by_vector(
embedding=embedding, k=k, filter=filter, **kwargs
)
Expand Down Expand Up @@ -682,7 +708,7 @@ async def amax_marginal_relevance_search(
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance."""
embedding = self.embedding_service.embed_query(text=query)
embedding = await self.embedding_service.aembed_query(text=query)

return await self.amax_marginal_relevance_search_by_vector(
embedding=embedding,
Expand Down
172 changes: 172 additions & 0 deletions src/langchain_google_alloydb_pg/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: Remove below import when minimum supported Python version is 3.10
from __future__ import annotations

import json
from typing import List, Type

from langchain_core.embeddings import Embeddings
from sqlalchemy import text

from .engine import AlloyDBEngine
from .model_manager import AlloyDBModelManager


class AlloyDBEmbeddings(Embeddings):
    """Google AlloyDB Embeddings available via Model Endpoint Management.

    Instances are built through :meth:`create` or :meth:`create_sync`, both of
    which check that the model exists before returning the instance.
    """

    # Private sentinel handed to __init__ by the factory methods only.
    __create_key = object()

    def __init__(self, key: object, engine: AlloyDBEngine, model_id: str):
        """AlloyDBEmbeddings constructor.

        Args:
            key (object): Prevent direct constructor usage.
            engine (AlloyDBEngine): Connection pool engine for managing connections to Postgres database.
            model_id (str): The model id used for generating embeddings.

        Raises:
            :class:`Exception`: if the constructor is not called through the
                `create` or `create_sync` methods.
        """
        # Identity check: only the class-private sentinel is accepted.
        if key is not AlloyDBEmbeddings.__create_key:
            raise Exception(
                "Only create class through 'create' or 'create_sync' methods!"
            )
        self._engine = engine
        self.model_id = model_id

    @classmethod
    async def create(
        cls: Type[AlloyDBEmbeddings], engine: AlloyDBEngine, model_id: str
    ) -> AlloyDBEmbeddings:
        """Create AlloyDBEmbeddings instance.

        Args:
            engine (AlloyDBEngine): Connection pool engine for managing connections to Postgres database.
            model_id (str): The model id used for generating embeddings.

        Returns:
            AlloyDBEmbeddings: Instance of AlloyDBEmbeddings.

        Raises:
            :class:`ValueError`: if model does not exist. Use AlloyDBModelManager to create the model.
        """
        embeddings = cls(cls.__create_key, engine, model_id)
        if not await embeddings.amodel_exists():
            raise ValueError(f"Model {model_id} does not exist.")
        return embeddings

    @classmethod
    def create_sync(
        cls: Type[AlloyDBEmbeddings], engine: AlloyDBEngine, model_id: str
    ) -> AlloyDBEmbeddings:
        """Create AlloyDBEmbeddings instance (synchronous variant of `create`).

        Args:
            engine (AlloyDBEngine): Connection pool engine for managing connections to Postgres database.
            model_id (str): The model id used for generating embeddings.

        Returns:
            AlloyDBEmbeddings: Instance of AlloyDBEmbeddings.

        Raises:
            :class:`ValueError`: if model does not exist. Use AlloyDBModelManager to create the model.
        """
        embeddings = cls(cls.__create_key, engine, model_id)
        if not embeddings.model_exists():
            raise ValueError(f"Model {model_id} does not exist.")
        return embeddings

    async def amodel_exists(self) -> bool:
        """Checks if the embedding model exists.

        Returns:
            bool: True if a model with the given name exists, False otherwise.
        """
        return await self._engine._run_as_async(self.__amodel_exists())

    def model_exists(self) -> bool:
        """Checks if the embedding model exists.

        Returns:
            bool: True if a model with the given name exists, False otherwise.
        """
        return self._engine._run_as_sync(self.__amodel_exists())

    async def __amodel_exists(self) -> bool:
        """Coroutine that looks the model up through AlloyDBModelManager.

        Returns:
            bool: True if a model with the given name exists, False otherwise.
        """
        model_manager = await AlloyDBModelManager.create(self._engine)
        model = await model_manager.aget_model(model_id=self.model_id)
        return model is not None

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Not supported by this class.

        Raises:
            :class:`NotImplementedError`: always.
        """
        raise NotImplementedError(
            "Embedding functions are not implemented. Use VertexAIEmbeddings interface instead."
        )

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Not supported by this class.

        Raises:
            :class:`NotImplementedError`: always.
        """
        raise NotImplementedError(
            "Embedding functions are not implemented. Use VertexAIEmbeddings interface instead."
        )

    @staticmethod
    def _escape_sql_literal(value: str) -> str:
        """Escape a value for use inside a single-quoted SQL string literal."""
        # Doubling the quote is the standard SQL escape for ' inside '...'.
        return value.replace("'", "''")

    def embed_query_inline(self, query: str) -> str:
        """Build the SQL expression that embeds `query` inside the database.

        Single quotes in the model id and in the query text are doubled so the
        text cannot terminate the string literal it is placed in.

        NOTE(review): if the returned snippet is later wrapped in
        `sqlalchemy.text()`, a `:name` sequence inside `query` would presumably
        still be parsed as a bind parameter -- confirm against the callers.

        Args:
            query (str): Text to embed.

        Returns:
            str: SQL snippet of the form `embedding('<model_id>', '<query>')::vector`.
        """
        model_id = self._escape_sql_literal(self.model_id)
        content = self._escape_sql_literal(query)
        return f"embedding('{model_id}', '{content}')::vector"

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous Embed query text.

        Args:
            text (str): Text to embed.

        Returns:
            List[float]: Embedding.
        """
        return await self._engine._run_as_async(self.__aembed_query(text))

    def embed_query(self, text: str) -> List[float]:
        """Embed query text.

        Args:
            text (str): Text to embed.

        Returns:
            List[float]: Embedding.
        """
        return self._engine._run_as_sync(self.__aembed_query(text))

    async def __aembed_query(self, query: str) -> List[float]:
        """Coroutine for generating embeddings for a given query.

        Args:
            query (str): Text to embed.

        Returns:
            List[float]: Embedding.
        """
        # Bound parameters keep quotes and colons in the user's text from
        # being interpreted as SQL syntax or as bind-parameter markers.
        stmt = text("SELECT embedding(:model_id, :query)::vector")
        params = {"model_id": self.model_id, "query": query}
        async with self._engine._pool.connect() as conn:
            result = await conn.execute(stmt, params)
            rows = result.mappings().fetchall()
        # The single result column is read by the name "embedding" and parsed
        # as JSON text into a list of floats.
        return json.loads(rows[0]["embedding"])
Loading

0 comments on commit 80c187a

Please sign in to comment.