From 247b9cfdbc758c68b976183831cdff7d41f11848 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Wed, 23 Oct 2024 15:00:00 +0530 Subject: [PATCH] feat: enable model endpoint management for embeddings (#233) --- docs/model_endpoint_management.ipynb | 593 ++++++++++++++++++ pyproject.toml | 1 + src/langchain_google_alloydb_pg/__init__.py | 5 + .../async_vectorstore.py | 38 +- src/langchain_google_alloydb_pg/embeddings.py | 172 +++++ .../model_manager.py | 304 +++++++++ tests/test_async_vectorstore_search.py | 4 + tests/test_embeddings.py | 109 ++++ tests/test_model_manager.py | 103 +++ tests/test_vectorstore_embeddings.py | 378 +++++++++++ tests/test_vectorstore_search.py | 4 + 11 files changed, 1705 insertions(+), 6 deletions(-) create mode 100644 docs/model_endpoint_management.ipynb create mode 100644 src/langchain_google_alloydb_pg/embeddings.py create mode 100644 src/langchain_google_alloydb_pg/model_manager.py create mode 100644 tests/test_embeddings.py create mode 100644 tests/test_model_manager.py create mode 100644 tests/test_vectorstore_embeddings.py diff --git a/docs/model_endpoint_management.ipynb b/docs/model_endpoint_management.ipynb new file mode 100644 index 00000000..b616450d --- /dev/null +++ b/docs/model_endpoint_management.ipynb @@ -0,0 +1,593 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Model Endpoint Management in AlloyDB for PostgreSQL\n", + "\n", + "> [AlloyDB](https://cloud.google.com/alloydb) is a fully managed PostgreSQL compatible database service for your most demanding enterprise workloads.\n", + "AlloyDB combines the best of Google with PostgreSQL, for superior performance, scale, and availability. Extend your database application to build AI-powered\n", + "experiences leveraging AlloyDB Langchain integrations.\n", + "\n", + "This notebook goes over how to use Model endpoint management in AlloyDB using the `AlloyDBModelManager` and `AlloyDBEmbeddings` classes.\n", + "\n", + "Model Endpoint Management allows Google Cloud Databases, such as AlloyDB, Cloud SQL, Spanner, etc. to directly invoke Large Language Models (LLMs) within SQL queries, facilitating seamless integration of AI capabilities into data workflows. This feature enables developers to leverage LLM-powered insights in real time, improving the efficiency of data processing tasks.\n", + "\n", + "Learn more about the package on [GitHub](https://github.com/googleapis/langchain-google-alloydb-pg-python/).\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/langchain-google-alloydb-pg-python/blob/main/docs/model_endpoint_management.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Before You Begin\n", + "\n", + "To run this notebook, you will need to do the following:\n", + "\n", + " * [Create a Google Cloud Project](https://developers.google.com/workspace/guides/create-project)\n", + " * [Enable the AlloyDB API](https://console.cloud.google.com/flows/enableapi?apiid=alloydb.googleapis.com)\n", + " * [Create a AlloyDB instance](https://cloud.google.com/alloydb/docs/instance-primary-create)\n", + " * [Create a AlloyDB database](https://cloud.google.com/alloydb/docs/database-create)\n", + " * [Set the google_ml_integration.enable_model_support database flag to on for an instance](https://cloud.google.com/alloydb/docs/instance-configure-database-flags)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lfBhGVEM_97X" + }, + "source": [ + "### (Optional) Set up non-default database users\n", + "\n", + "First, [Add an IAM database user to the database](https://cloud.google.com/alloydb/docs/manage-iam-authn) or a [custom database user](https://cloud.google.com/alloydb/docs/database-users/about#create).\n", + "\n", + "Second, set up the required user permissions by running the following commands on [AlloyDBStudio](https://cloud.google.com/alloydb/docs/manage-data-using-studio) or any `psql` terminal.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2H9-FzGN_97X" + }, + "source": [ + "The `google_ml_integration` extension must first be installed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Mn0Pt7Gf_97X", + "vscode": { + "languageId": "sql" + } + }, + "outputs": [], + "source": [ + "CREATE EXTENSION google_ml_integration VERSION '1.3';" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5U0LwwqFGTqP" + }, + "source": [ + "Grant permissions for the user to access the tables in the `google_ml_extension`. Replace the `` with your user.\n", + "\n", + "For more information, see [Enabling extension](https://cloud.google.com/alloydb/docs/ai/model-endpoint-register-model#enable-extension).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vulrB_51_97X", + "vscode": { + "languageId": "sql" + } + }, + "outputs": [], + "source": [ + "GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA google_ml TO ;" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YWlx7lMz_97Y" + }, + "source": [ + "Grant permission for the user to access the `embedding` function of `google_ml_extension`. Replace the `` with your user.\n", + "\n", + "For more information about the above permission, see [Generate embeddings](https://cloud.google.com/alloydb/docs/ai/work-with-embeddings)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Aqa2pPZ__97Y", + "vscode": { + "languageId": "sql" + } + }, + "outputs": [], + "source": [ + "GRANT EXECUTE ON FUNCTION embedding TO ;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 🦜🔗 Library Installation\n", + "Install the integration library, `langchain-google-alloydb-pg`. The library must be version v0.8.0 or higher." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install --upgrade --quiet langchain-google-alloydb-pg langchain-core" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Colab only:** Uncomment the following cell to restart the kernel or use the button to restart the kernel.\n", + "For Vertex AI Workbench you can restart the terminal using the button on top." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# # Automatically restart kernel after installs so that your environment can access the new packages\n", + "# import IPython\n", + "\n", + "# app = IPython.Application.instance()\n", + "# app.kernel.do_shutdown(True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 🔐 Authentication\n", + "Authenticate to Google Cloud as the IAM user logged into this notebook in order to access your Google Cloud Project.\n", + "\n", + "* If you are using Colab to run this notebook, use the cell below and continue.\n", + "* If you are using Vertex AI Workbench, check out the setup instructions [here](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/setup-env)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from google.colab import auth\n", + "\n", + "auth.authenticate_user()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### ☁ Set Your Google Cloud Project\n", + "Set your Google Cloud project so that you can leverage Google Cloud resources within this notebook.\n", + "\n", + "If you don't know your project ID, try the following:\n", + "\n", + "* Run `gcloud config list`.\n", + "* Run `gcloud projects list`.\n", + "* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @markdown Please fill in the value below with your Google Cloud project ID and then run the cell.\n", + "PROJECT_ID = \"my-project-id\" # @param {type:\"string\"}\n", + "\n", + "# Set the project id\n", + "! gcloud config set project {PROJECT_ID}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DxC_TgYUDTUK" + }, + "source": [ + "### Enable database integration with Vertex AI\n", + "\n", + "To enable database integration with Vertex AI, the AlloyDB service agent (`service-@gcp-sa-alloydb.iam.gserviceaccount.com`) must be granted the Vertex AI User role. For more information on authentication for Vertex AI, see [this](https://cloud.google.com/alloydb/docs/ai/model-endpoint-register-model#vertex-provider)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "hwClEF1g_97X", + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "PROJECT_NUMBER=!gcloud projects describe {PROJECT_ID} --format=\"value(projectNumber)\"\n", + "\n", + "!gcloud projects add-iam-policy-binding {PROJECT_ID} \\\n", + "--member=\"serviceAccount:service-{PROJECT_NUMBER[0]}@gcp-sa-alloydb.iam.gserviceaccount.com\" \\\n", + "--role=\"roles/aiplatform.user\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set up connection pool" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Set AlloyDB database values\n", + "Find your database values, in the [AlloyDB cluster page](https://console.cloud.google.com/alloydb?_ga=2.223735448.2062268965.1707700487-2088871159.1707257687)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title Set Your Values Here { display-mode: \"form\" }\n", + "REGION = \"us-central1\" # @param {type: \"string\"}\n", + "CLUSTER = \"my-alloydb-cluster\" # @param {type: \"string\"}\n", + "INSTANCE = \"my-alloydb-instance\" # @param {type: \"string\"}\n", + "DATABASE = \"my-database\" # @param {type: \"string\"}\n", + "TABLE_NAME = \"vector_store\" # @param {type: \"string\"}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### AlloyDBEngine Connection Pool\n", + "\n", + "To connect to AlloyDB and use Model endpoint management is an `AlloyDBEngine` object is required. The `AlloyDBEngine` configures a connection pool to your AlloyDB database, enabling successful connections from your application and following industry best practices.\n", + "\n", + "To create a `AlloyDBEngine` using `AlloyDBEngine.from_instance()` you need to provide only 5 things:\n", + "\n", + "1. `project_id` : Project ID of the Google Cloud Project where the AlloyDB instance is located.\n", + "1. `region` : Region where the AlloyDB instance is located.\n", + "1. `cluster`: The name of the AlloyDB cluster.\n", + "1. `instance` : The name of the AlloyDB instance.\n", + "1. `database` : The name of the database to connect to on the AlloyDB instance.\n", + "\n", + "By default, [IAM database authentication](https://cloud.google.com/alloydb/docs/connect-iam) will be used as the method of database authentication. This library uses the IAM principal belonging to the [Application Default Credentials (ADC)](https://cloud.google.com/docs/authentication/application-default-credentials) sourced from the environment.\n", + "\n", + "Optionally, [built-in database authentication](https://cloud.google.com/alloydb/docs/database-users/about) using a username and password to access the AlloyDB database can also be used. Just provide the optional `user` and `password` arguments to `AlloyDBEngine.from_instance()`:\n", + "\n", + "* `user` : Database user to use for built-in database authentication and login.\n", + "* `password` : Database password to use for built-in database authentication and login.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_google_alloydb_pg import AlloyDBEngine\n", + "\n", + "engine = await AlloyDBEngine.afrom_instance(\n", + " project_id=PROJECT_ID,\n", + " region=REGION,\n", + " cluster=CLUSTER,\n", + " instance=INSTANCE,\n", + " database=DATABASE,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Register a model with `AlloyDBModelManager`\n", + "The `AlloyDBModelManager` class allows the user to create, get, list, and drop models. A model is required by the `AlloyDBEmbeddings` class to be used to embed documents on insertion into the vector store and during similarity searches.\n", + "\n", + "Initialize an instance of `AlloyDBModelManager` with the connection pool through the `AlloyDBEngine` object." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_google_alloydb_pg import AlloyDBModelManager, AlloyDBModel\n", + "\n", + "model_manager = await AlloyDBModelManager.create(engine)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "On creating the `AlloyDBModelManager` object, it will run a prerequisite check to ensure:\n", + "* The extension is up to date: `google_ml_integration` extension is installed and the version is greater than 1.3\n", + "* The database flag is set: `google_ml_integration.enable_model_support` is set to on." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### List all models available\n", + "This list includes the [pre built models](https://cloud.google.com/alloydb/docs/ai/model-endpoint-register-model#add-vertex) and any other model you may have created.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results = await model_manager.alist_models()\n", + "print(results)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Get a specific model\n", + "To retrieve a specific model you will have to provide the `model_id` to the `aget_model()` function.\n", + "\n", + "If the model with the specified model_id exists, then the AlloyDBModel dataclass of it is returned.\n", + "Otherwise None is returned." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result = await model_manager.aget_model(model_id=\"textembedding-gecko\")\n", + "print(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Create a custom text embedding model\n", + "\n", + "To create a custom textembedding model you need to pass these parameters to the `acreate_model()` function :\n", + "\n", + "* model_id: A unique ID for the model endpoint that you define.\n", + "* model_provider: The provider of the model endpoint (`google` for vertexAI and `custom` for custom hosted models).\n", + "* model_type: The model type (set this value to `text_embedding` for text embedding model endpoints or `generic` for all other model endpoints).\n", + "* model_qualified_name: The fully qualified name in case the model endpoint has multiple versions or if the model endpoint defines it.\n", + "\n", + "You can customize your model further with some optional parameters. For all the details and possibilities, check out the [reference doc](https://cloud.google.com/alloydb/docs/reference/model-endpoint-reference#google_mlcreate_model).\n", + "\n", + "\n", + "**Note**: The `acreate_model()` function doesn't return any value directly.\n", + "You'll need to use `alist_models()` or `aget_model()` to verify if your model was created successfully." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "await model_manager.acreate_model(\n", + " model_id=\"textembedding-gecko@003\",\n", + " model_provider=\"google\",\n", + " model_qualified_name=\"textembedding-gecko@003\",\n", + " model_type=\"text_embedding\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Note**: A model once created can also be dropped." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Create custom third-party models\n", + "You can also create a third-party custom text embedding model using these steps. You can also create a third-party custom text embedding model, such as Hugging Face models.\n", + "\n", + "For all models except Vertex AI model endpoints, you can store your API keys or bearer tokens in Secret Manager. This step is optional if your model endpoint doesn't handle authentication through Secret Manager.\n", + "\n", + "For information, see [Authentication for custom hosted models](https://cloud.google.com/alloydb/docs/ai/model-endpoint-register-model#set_up_authentication_for_custom-hosted_models).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Removing a Model\n", + "If you no longer need a specific model, you can easily remove it using the adrop_model function by providing the model_id.\n", + "\n", + "To make sure the model has been deleted, you can use the alist_models function to list all your remaining models. The deleted model should no longer appear in the list." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "await model_manager.adrop_model(model_id=\"textembedding-gecko@003\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Understanding the AlloyDBModel Dataclass\n", + "\n", + "When you retrieve a model using the function `aget_model()`, you'll receive an `AlloyDBModel` object.\n", + "\n", + "Here's a breakdown of what's inside:\n", + "* model_id (str) : A unique ID for the model endpoint that you define.\n", + "* model_request_url (Optional[str]) : The model-specific endpoint when adding other text embedding and generic model endpoints.\n", + "* model_provider (str) : The provider of the model endpoint. Set to google for Vertex AI model endpoints and custom for custom-hosted model endpoints.\n", + "* model_type (str) : The model type. You can set this value to text_embedding for text embedding model endpoints or generic for all other model endpoints.\n", + "* model_qualified_name (str) : The fully qualified name in case the model endpoint has multiple versions or if the model endpoint defines it.\n", + "* model_auth_type (Optional[str]) : The authentication type used by the model endpoint. You can set it to either alloydb_service_agent_iam for Vertex AI models or secret_manager for other providers.\n", + "* model_auth_id (Optional[str]) : The secret ID that you set and is subsequently used when registering a model endpoint.\n", + "* input_transform_fn (Optional[str]) : The SQL function name to transform input of the corresponding prediction function to the model-specific input.\n", + "* output_transform_fn (Optional[str]) : The SQL function name to transform model specific output to the prediction function output.\n", + "\n", + "See below for an example of AlloyDBModel instance on using `aget_model(model_id=\"textembedding-gecko@001\")`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "AlloyDBModel(\n", + " model_id=\"textembedding-gecko@001\",\n", + " model_request_url=\"publishers/google/models/textembedding-gecko@001\",\n", + " model_provider=\"google\",\n", + " model_type=\"text_embedding\",\n", + " model_qualified_name=\"textembedding-gecko@001\",\n", + " model_auth_type=\"alloydb_service_agent_iam\",\n", + " model_auth_id=None,\n", + " input_transform_fn=\"google_ml.vertexai_text_embedding_input_transform\",\n", + " output_transform_fn=\"google_ml.vertexai_text_embedding_output_transform\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generate Vector Embeddings with `AlloyDBEmbeddings`\n", + "The `AlloyDBEmbeddings` class allows users to utilize the in database embedding generation functions available via Model Endpoint Management.\n", + "\n", + "In the below example, we are using the `textembedding-gecko@003` model that we created using the Model Manager.\n", + "\n", + "**Note**: If you have dropped the above model, you can recreate it or use `textembedding-gecko@001`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_google_alloydb_pg import AlloyDBEmbeddings\n", + "\n", + "model_id = \"textembedding-gecko@003\"\n", + "embedding_service = await AlloyDBEmbeddings.create(engine=engine, model_id=model_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Note**: This tutorial demonstrates the async interface. All async methods have corresponding sync methods." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "On creating an instance of the `AlloyDBEmbeddings` class, it checks if the model exists.\n", + "If the model does not exist with that model_id, the class throws a `ValueError`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Using `AlloyDBEmbeddings` as an embedding service" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `AlloyDBEmbeddings` class can be used as the embedding service with an `AlloyDBVectorStore` to generate embeddings on document insertion and for similarity searches.\n", + "\n", + "Learn more about getting started with the [`AlloyDBVectorStore`](https://github.com/googleapis/langchain-google-alloydb-pg-python/blob/main/docs/vector_store.ipynb)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import uuid\n", + "from langchain_core.documents import Document\n", + "from langchain_google_alloydb_pg import AlloyDBVectorStore\n", + "\n", + "# (Optional) Create a new vector store table\n", + "VECTOR_SIZE = 768 # For textembeddding-gecko@003 model\n", + "await engine.ainit_vectorstore_table(\n", + " table_name=\"vector_store_table\",\n", + " vector_size=VECTOR_SIZE,\n", + " overwrite_existing=True,\n", + ")\n", + "\n", + "# Initialize the vector store instance with AlloyDBEmbeddings\n", + "vs = await AlloyDBVectorStore.create(\n", + " engine,\n", + " embedding_service=embedding_service,\n", + " table_name=\"vector_store_table\",\n", + ")\n", + "\n", + "# Add documents\n", + "texts = [\"foo\", \"bar\", \"baz\", \"boo\"]\n", + "ids = [str(uuid.uuid4()) for i in range(len(texts))]\n", + "docs = [Document(page_content=texts[i]) for i in range(len(texts))]\n", + "await vs.aadd_documents(docs, ids=ids)\n", + "\n", + "# Search documents\n", + "results = await vs.asimilarity_search(\"foo\")" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index 07d91abc..49b2a228 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ test = [ "pytest-asyncio==0.24.0", "pytest==8.3.3", "pytest-cov==5.0.0", + "pytest-depends==1.0.1", "Pillow==11.0.0" ] diff --git a/src/langchain_google_alloydb_pg/__init__.py b/src/langchain_google_alloydb_pg/__init__.py index 370e8d32..1c68f361 100644 --- a/src/langchain_google_alloydb_pg/__init__.py +++ b/src/langchain_google_alloydb_pg/__init__.py @@ -13,8 +13,10 @@ # limitations under the License. from .chat_message_history import AlloyDBChatMessageHistory +from .embeddings import AlloyDBEmbeddings from .engine import AlloyDBEngine, Column from .loader import AlloyDBDocumentSaver, AlloyDBLoader +from .model_manager import AlloyDBModel, AlloyDBModelManager from .vectorstore import AlloyDBVectorStore from .version import __version__ @@ -25,5 +27,8 @@ "AlloyDBLoader", "AlloyDBDocumentSaver", "AlloyDBChatMessageHistory", + "AlloyDBEmbeddings", + "AlloyDBModelManager", + "AlloyDBModel", "__version__", ] diff --git a/src/langchain_google_alloydb_pg/async_vectorstore.py b/src/langchain_google_alloydb_pg/async_vectorstore.py index b33193aa..fb3bfc1e 100644 --- a/src/langchain_google_alloydb_pg/async_vectorstore.py +++ b/src/langchain_google_alloydb_pg/async_vectorstore.py @@ -19,7 +19,7 @@ import json import re import uuid -from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type import numpy as np import requests @@ -30,6 +30,7 @@ from sqlalchemy import RowMapping, text from sqlalchemy.ext.asyncio import AsyncEngine +from .embeddings import AlloyDBEmbeddings from .engine import AlloyDBEngine from .indexes import ( DEFAULT_DISTANCE_STRATEGY, @@ -248,6 +249,8 @@ async def aadd_embeddings( insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"({self.id_column}, {self.content_column}, {self.embedding_column}{metadata_col_names}' values = {"id": id, "content": content, "embedding": str(embedding)} values_stmt = "VALUES (:id, :content, :embedding" + if not embedding and isinstance(self.embedding_service, AlloyDBEmbeddings): + values_stmt = f"VALUES (:id, :content, {self.embedding_service.embed_query_inline(content)}" # Add metadata extra = metadata @@ -288,7 +291,11 @@ async def aadd_texts( Raises: :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. """ - embeddings = self.embedding_service.embed_documents(list(texts)) + if isinstance(self.embedding_service, AlloyDBEmbeddings): + embeddings: List[List[float]] = [[] for _ in list(texts)] + else: + embeddings = await self.embedding_service.aembed_documents(list(texts)) + ids = await self.aadd_embeddings( texts, embeddings, metadatas=metadatas, ids=ids, **kwargs ) @@ -535,7 +542,15 @@ async def __query_collection( search_function = self.distance_strategy.search_function filter = f"WHERE {filter}" if filter else "" - stmt = f"SELECT *, {search_function}({self.embedding_column}, '{embedding}') as distance FROM \"{self.schema_name}\".\"{self.table_name}\" {filter} ORDER BY {self.embedding_column} {operator} '{embedding}' LIMIT {k};" + if ( + not embedding + and isinstance(self.embedding_service, AlloyDBEmbeddings) + and "query" in kwargs + ): + query_embedding = self.embedding_service.embed_query_inline(kwargs["query"]) + else: + query_embedding = f"'{embedding}'" + stmt = f'SELECT *, {search_function}({self.embedding_column}, {query_embedding}) as distance FROM "{self.schema_name}"."{self.table_name}" {filter} ORDER BY {self.embedding_column} {operator} {query_embedding} LIMIT {k};' if self.index_query_options: query_options_stmt = f"SET LOCAL {self.index_query_options.to_string()};" async with self.engine.connect() as conn: @@ -558,7 +573,12 @@ async def asimilarity_search( **kwargs: Any, ) -> List[Document]: """Return docs selected by similarity search on query.""" - embedding = self.embedding_service.embed_query(text=query) + embedding = ( + [] + if isinstance(self.embedding_service, AlloyDBEmbeddings) + else await self.embedding_service.aembed_query(text=query) + ) + kwargs["query"] = query return await self.asimilarity_search_by_vector( embedding=embedding, k=k, filter=filter, **kwargs @@ -619,7 +639,13 @@ async def asimilarity_search_with_score( **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs and distance scores selected by similarity search on query.""" - embedding = self.embedding_service.embed_query(query) + embedding = ( + [] + if isinstance(self.embedding_service, AlloyDBEmbeddings) + else await self.embedding_service.aembed_query(text=query) + ) + kwargs["query"] = query + docs = await self.asimilarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter, **kwargs ) @@ -682,7 +708,7 @@ async def amax_marginal_relevance_search( **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance.""" - embedding = self.embedding_service.embed_query(text=query) + embedding = await self.embedding_service.aembed_query(text=query) return await self.amax_marginal_relevance_search_by_vector( embedding=embedding, diff --git a/src/langchain_google_alloydb_pg/embeddings.py b/src/langchain_google_alloydb_pg/embeddings.py new file mode 100644 index 00000000..a50799cd --- /dev/null +++ b/src/langchain_google_alloydb_pg/embeddings.py @@ -0,0 +1,172 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: Remove below import when minimum supported Python version is 3.10 +from __future__ import annotations + +import json +from typing import List, Type + +from langchain_core.embeddings import Embeddings +from sqlalchemy import text + +from .engine import AlloyDBEngine +from .model_manager import AlloyDBModelManager + + +class AlloyDBEmbeddings(Embeddings): + """Google AlloyDB Embeddings available via Model Endpoint Management.""" + + __create_key = object() + + def __init__(self, key: object, engine: AlloyDBEngine, model_id: str): + """AlloyDBEmbeddings constructor. + Args: + key (object): Prevent direct constructor usage. + engine (AlloyDBEngine): Connection pool engine for managing connections to Postgres database. + model_id (str): The model id used for generating embeddings. + + Raises: + :class:`ValueError`: if model does not exist. Use AlloyDBModelManager to create the model. + + """ + if key != AlloyDBEmbeddings.__create_key: + raise Exception( + "Only create class through 'create' or 'create_sync' methods!" + ) + self._engine = engine + self.model_id = model_id + + @classmethod + async def create( + cls: Type[AlloyDBEmbeddings], engine: AlloyDBEngine, model_id: str + ) -> AlloyDBEmbeddings: + """Create AlloyDBEmbeddings instance. + + Args: + key (object): Prevent direct constructor usage. + engine (AlloyDBEngine): Connection pool engine for managing connections to Postgres database. + model_id (str): The model id used for generating embeddings. + + Returns: + AlloyDBEmbeddings: Instance of AlloyDBEmbeddings. + """ + + embeddings = cls(cls.__create_key, engine, model_id) + model_exists = await embeddings.amodel_exists() + if not model_exists: + raise ValueError(f"Model {model_id} does not exist.") + + return embeddings + + @classmethod + def create_sync( + cls: Type[AlloyDBEmbeddings], engine: AlloyDBEngine, model_id: str + ) -> AlloyDBEmbeddings: + """Create AlloyDBEmbeddings instance. + + Args: + key (object): Prevent direct constructor usage. + engine (AlloyDBEngine): Connection pool engine for managing connections to Postgres database. + model_id (str): The model id used for generating embeddings. + + Returns: + AlloyDBEmbeddings: Instance of AlloyDBEmbeddings. + """ + + embeddings = cls(cls.__create_key, engine, model_id) + if not embeddings.model_exists(): + raise ValueError(f"Model {model_id} does not exist.") + + return embeddings + + async def amodel_exists(self) -> bool: + """Checks if the embedding model exists. + + Return: + `Bool`: True if a model with the given name exists, False otherwise. + """ + return await self._engine._run_as_async(self.__amodel_exists()) + + def model_exists(self) -> bool: + """Checks if the embedding model exists. + + Return: + `Bool`: True if a model with the given name exists, False otherwise. + """ + return self._engine._run_as_sync(self.__amodel_exists()) + + async def __amodel_exists(self) -> bool: + """Checks if the embedding model exists. + + Return: + `Bool`: True if a model with the given name exists, False otherwise. + """ + model_manager = await AlloyDBModelManager.create(self._engine) + model = await model_manager.aget_model(model_id=self.model_id) + if model is not None: + return True + return False + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + raise NotImplementedError( + "Embedding functions are not implemented. Use VertexAIEmbeddings interface instead." + ) + + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + raise NotImplementedError( + "Embedding functions are not implemented. Use VertexAIEmbeddings interface instead." + ) + + def embed_query_inline(self, query: str) -> str: + return f"embedding('{self.model_id}', '{query}')::vector" + + async def aembed_query(self, text: str) -> List[float]: + """Asynchronous Embed query text. + + Args: + query (str): Text to embed. + + Returns: + List[float]: Embedding. + """ + embeddings = await self._engine._run_as_async(self.__aembed_query(text)) + return embeddings + + def embed_query(self, text: str) -> List[float]: + """Embed query text. + + Args: + query (str): Text to embed. + + Returns: + List[float]: Embedding. + """ + return self._engine._run_as_sync(self.__aembed_query(text)) + + async def __aembed_query(self, query: str) -> List[float]: + """Coroutine for generating embeddings for a given query. + + Args: + query (str): Text to embed. + + Returns: + List[float]: Embedding. + """ + query = f" SELECT embedding('{self.model_id}', '{query}')::vector " + async with self._engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + results = result_map.fetchall() + return json.loads(results[0]["embedding"]) diff --git a/src/langchain_google_alloydb_pg/model_manager.py b/src/langchain_google_alloydb_pg/model_manager.py new file mode 100644 index 00000000..336a0548 --- /dev/null +++ b/src/langchain_google_alloydb_pg/model_manager.py @@ -0,0 +1,304 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: Remove below import when minimum supported Python version is 3.10 +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Optional, Sequence, Type + +from sqlalchemy import text +from sqlalchemy.engine.row import RowMapping + +from .engine import AlloyDBEngine + + +@dataclass +class AlloyDBModel: + model_id: str + model_request_url: Optional[str] + model_provider: str + model_type: str + model_qualified_name: str + model_auth_type: Optional[str] + model_auth_id: Optional[str] + input_transform_fn: Optional[str] + output_transform_fn: Optional[str] + + +class AlloyDBModelManager: + """Manage models to be used with google_ml_integration Extension. + Refer to [Model Endpoint Management](https://cloud.google.com/alloydb/docs/ai/model-endpoint-overview). + """ + + __create_key = object() + + def __init__( + self, + key: object, + engine: AlloyDBEngine, + ): + """AlloyDBModelManager constructor. + Args: + engine (AlloyDBEngine): Connection pool engine for managing connections to Postgres database. + """ + if key != AlloyDBModelManager.__create_key: + raise Exception( + "Only create class through 'create' or 'create_sync' methods!" + ) + + self._engine = engine + + @classmethod + async def create( + cls: Type[AlloyDBModelManager], + engine: AlloyDBEngine, + ) -> AlloyDBModelManager: + manager = AlloyDBModelManager(cls.__create_key, engine) + coro = manager.__avalidate() + await engine._run_as_async(coro) + return manager + + @classmethod + def create_sync( + cls: Type[AlloyDBModelManager], + engine: AlloyDBEngine, + ) -> AlloyDBModelManager: + manager = AlloyDBModelManager(cls.__create_key, engine) + coro = manager.__avalidate() + engine._run_as_sync(coro) + return manager + + async def aget_model(self, model_id: str) -> Optional[AlloyDBModel]: + """Lists the model details for a specific model_id. + + Args: + model_id (str): A unique ID for the model endpoint that you have defined. + + Returns: + :class: `AlloyDBModel` object of the specified model if it exists otherwise `None`. + + """ + result = await self._engine._run_as_async(self.__aget_model(model_id=model_id)) + return result + + async def alist_models(self) -> List[AlloyDBModel]: + """Lists all the models and its details. + + Returns: + List[`AlloyDBModel`] of all available model.. + """ + results = await self._engine._run_as_async(self.__alist_models()) + return results + + async def acreate_model( + self, + model_id: str, + model_provider: str, + model_type: str, + model_qualified_name: str, + **kwargs: dict[str, str], + ) -> None: + """Creates a registration for custom text model. + + Args: + model_id (str): A unique ID for the model endpoint that you define. + model_provider (str): The provider of the model endpoint. + model_type (str): The model type. Either text_embedding or generic. + model_qualified_name (str): The fully qualified name in case the model endpoint has multiple versions + **kwargs : + model_request_url (str): The model-specific endpoint when adding other text embedding and generic model endpoints + model_auth_type (str): The authentication type used by the model endpoint. + model_auth_id (str): The secret ID that you set and is subsequently used when registering a model endpoint. + generate_headers_fn (str): The SQL function name you set to generate custom headers. + model_in_transform_fn (str): The SQL function name to transform input of the corresponding prediction function to the model-specific input. + model_out_transform_fn (str): The SQL function name to transform model specific output to the prediction function output. + + Returns: + None + + Raises: + :class:`DBAPIError `: if argument names mismatch create_model function specification. + """ + await self._engine._run_as_async( + self.__acreate_model( + model_id, model_provider, model_type, model_qualified_name, **kwargs + ) + ) + + async def adrop_model(self, model_id: str) -> None: + """Removes an already registered model. + + Args: + model_id (str): A unique ID for the model endpoint that you have defined. + + Returns: + None + """ + await self._engine._run_as_async(self.__adrop_model(model_id)) + + async def __avalidate(self) -> None: + """Private async function to validate prerequisites. + + Raises: + Exception if google_ml_integration EXTENSION is not 1.3. + Exception if google_ml_integration.enable_model_support DB Flag not set. + """ + extension_version = await self.__fetch_google_ml_extension() + db_flag = await self.__fetch_db_flag() + if extension_version < 1.3: + raise Exception( + "Please upgrade google_ml_integration EXTENSION to version 1.3 or above." + ) + if db_flag != "on": + raise Exception( + "google_ml_integration.enable_model_support DB Flag not set." + ) + + async def __query_db(self, query: str) -> Sequence[RowMapping]: + """Queries the Postgres database through the engine. + + Args: + query (str): Query to execute on the DB. + + Raises: + Exception if the query is not a returning type.""" + async with self._engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + results = result_map.fetchall() + return results + + async def __aget_model(self, model_id: str) -> Optional[AlloyDBModel]: + """Lists the model details for a specific model_id. Returns None if it doesn't exist. + + Args: + model_id (str): A unique ID for the model endpoint that you have defined. + + """ + query = f"""SELECT * FROM + google_ml.list_model('{model_id}') + AS t(model_id VARCHAR, + model_request_url VARCHAR, + model_provider google_ml.model_provider, + model_type google_ml.model_type, + model_qualified_name VARCHAR, + model_auth_type google_ml.auth_type, + model_auth_id VARCHAR, + input_transform_fn VARCHAR, + output_transform_fn VARCHAR)""" + + try: + result = await self.__query_db(query) + except Exception: + return None + data_class = self.__convert_dict_to_dataclass(result)[0] + return data_class + + async def __alist_models(self) -> List[AlloyDBModel]: + """Lists all the models and its details.""" + query = "SELECT * FROM google_ml.model_info_view;" + result = await self.__query_db(query) + list_of_data_classes = self.__convert_dict_to_dataclass(result) + return list_of_data_classes + + async def __acreate_model( + self, + model_id: str, + model_provider: str, + model_type: str, + model_qualified_name: str, + **kwargs: dict[str, str], + ) -> None: + """Creates a registration for custom text model. + + Args: + model_id (str): A unique ID for the model endpoint that you define. + model_provider (str): The provider of the model endpoint. + model_type (str): The model type. Either text_embedding or generic. + model_qualified_name (str): The fully qualified name in case the model endpoint has multiple versions. + **kwargs : + model_request_url (str): The model-specific endpoint when adding other text embedding and generic model endpoints + model_auth_type (str): The authentication type used by the model endpoint. + model_auth_id (str): The secret ID that you set and is subsequently used when registering a model endpoint. + generate_headers_fn (str): The SQL function name you set to generate custom headers. + model_in_transform_fn (str): The SQL function name to transform input of the corresponding prediction function to the model-specific input. + model_out_transform_fn (str): The SQL function name to transform model specific output to the prediction function output. + + Raises: + :class:`DBAPIError `: if argument names mismatch create_model function specification. + """ + query = f""" + CALL + google_ml.create_model( + model_id => '{model_id}', + model_provider => '{model_provider}', + model_type => '{model_type}', + model_qualified_name => '{model_qualified_name}',""" + for key, value in kwargs.items(): + query = query + f" {key} => '{value}'," + query = query.strip(",") + query = query + ");" + async with self._engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + async def __adrop_model(self, model_id: str) -> None: + """Removes an already registered model. + + Args: + model_id (str): A unique ID for the model endpoint that you have defined. + """ + query = f"CALL google_ml.drop_model('{model_id}');" + async with self._engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + async def __fetch_google_ml_extension(self) -> float: + """Creates the Google ML Extension if it does not exist and returns the version number (Default creates version 1.3).""" + create_extension_query = """ + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_extension WHERE extname = 'google_ml_integration' ) + THEN CREATE EXTENSION google_ml_integration VERSION '1.3' CASCADE; + END IF; + END + $$; + """ + async with self._engine._pool.connect() as conn: + await conn.execute(text(create_extension_query)) + await conn.commit() + extension_version_query = "SELECT extversion FROM pg_extension WHERE extname = 'google_ml_integration';" + result = await self.__query_db(extension_version_query) + version = result[0]["extversion"] + return float(version) + + async def __fetch_db_flag(self) -> str: + """Fetches the enable_model_support DB flag.""" + db_flag_query = "SELECT setting FROM pg_settings where name = 'google_ml_integration.enable_model_support';" + result = await self.__query_db(db_flag_query) + flag = result[0]["setting"] + return flag + + def __convert_dict_to_dataclass( + self, list_of_rows: Sequence[RowMapping] + ) -> List[AlloyDBModel]: + """Converts a list of DB rows to list of AlloyDBModel dataclass. + + Args: + list_of_rows (Sequence[RowMapping]): A unique ID for the model endpoint that you define. + """ + list_of_dataclass = [AlloyDBModel(**row) for row in list_of_rows] + return list_of_dataclass diff --git a/tests/test_async_vectorstore_search.py b/tests/test_async_vectorstore_search.py index 6bb60ac2..4684ac3e 100644 --- a/tests/test_async_vectorstore_search.py +++ b/tests/test_async_vectorstore_search.py @@ -34,6 +34,10 @@ embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) +# Note: The following texts are chosen to produce diverse +# similarity scores when using the DeterministicFakeEmbedding service. This ensures +# that the test cases can effectively validate the filtering and scoring logic. +# The scoring might be different if using a different embedding service. texts = ["foo", "bar", "baz", "boo"] ids = [str(uuid.uuid4()) for i in range(len(texts))] metadatas = [{"page": str(i), "source": "google.com"} for i in range(len(texts))] diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py new file mode 100644 index 00000000..0b0c5220 --- /dev/null +++ b/tests/test_embeddings.py @@ -0,0 +1,109 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import uuid + +import pytest +import pytest_asyncio +from langchain_core.documents import Document + +from langchain_google_alloydb_pg import AlloyDBEmbeddings, AlloyDBEngine + +project_id = os.environ["PROJECT_ID"] +region = os.environ["REGION"] +cluster_id = os.environ["CLUSTER_ID"] +instance_id = os.environ["INSTANCE_ID"] +db_name = os.environ["DATABASE_ID"] +table_name = "test-table" + str(uuid.uuid4()) + + +@pytest.mark.asyncio +class TestAlloyDBEmbeddings: + + @pytest_asyncio.fixture + async def engine(self): + AlloyDBEngine._connector = None + engine = await AlloyDBEngine.afrom_instance( + project_id=project_id, + cluster=cluster_id, + instance=instance_id, + region=region, + database=db_name, + ) + yield engine + + await engine.close() + + @pytest_asyncio.fixture + async def sync_engine(self): + AlloyDBEngine._connector = None + engine = AlloyDBEngine.from_instance( + project_id=project_id, + cluster=cluster_id, + instance=instance_id, + region=region, + database=db_name, + ) + yield engine + + await engine.close() + + @pytest.fixture(scope="module") + def model_id(self) -> str: + return "textembedding-gecko@001" + + @pytest_asyncio.fixture + def embeddings(self, engine, model_id): + return AlloyDBEmbeddings.create_sync(engine=engine, model_id=model_id) + + async def test_model_exists(self, sync_engine): + test_model_id = "test_sample_text_embedding_model" + error_message = f"Model {test_model_id} does not exist." + with pytest.raises(Exception, match=error_message): + AlloyDBEmbeddings.create_sync(engine=sync_engine, model_id=test_model_id) + + async def test_amodel_exists(self, engine): + test_model_id = "test_sample_text_embedding_model" + error_message = f"Model {test_model_id} does not exist." + with pytest.raises(Exception, match=error_message): + await AlloyDBEmbeddings.create(engine=engine, model_id=test_model_id) + + async def test_aembed_documents(self, embeddings): + with pytest.raises(NotImplementedError): + await embeddings.aembed_documents([Document(page_content="test document")]) + + async def test_embed_documents(self, embeddings): + with pytest.raises(NotImplementedError): + embeddings.embed_documents([Document(page_content="test document")]) + + async def test_embed_query(self, embeddings): + embedding = embeddings.embed_query("test document") + assert isinstance(embedding, list) + assert len(embedding) > 0 + for embedding_field in embedding: + assert isinstance(embedding_field, float) + assert -1 <= embedding_field <= 1 + + async def test_embed_query_inline(self, embeddings, model_id): + embedding_query = embeddings.embed_query_inline("test document") + assert embedding_query == f"embedding('{model_id}', 'test document')::vector" + + async def test_aembed_query(self, embeddings): + embedding = await embeddings.aembed_query("test document") + assert isinstance(embedding, list) + assert len(embedding) > 0 + for embedding_field in embedding: + assert isinstance(embedding_field, float) + assert -1 <= embedding_field <= 1 diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py new file mode 100644 index 00000000..3153e26d --- /dev/null +++ b/tests/test_model_manager.py @@ -0,0 +1,103 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import uuid + +import pytest +import pytest_asyncio + +from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBModelManager + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +EMBEDDING_MODEL_NAME = "textembedding-gecko@003" + str(uuid.uuid4()).replace("-", "_") + + +@pytest.mark.asyncio +class TestAlloyDBModelManager: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for AlloyDB instance") + + @pytest.fixture(scope="module") + def db_cluster(self) -> str: + return get_env_var("CLUSTER_ID", "cluster for AlloyDB") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for AlloyDB") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on AlloyDB instance") + + @pytest_asyncio.fixture(scope="module") + async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): + engine = await AlloyDBEngine.afrom_instance( + project_id=db_project, + cluster=db_cluster, + instance=db_instance, + region=db_region, + database=db_name, + ) + yield engine + await engine.close() + + @pytest_asyncio.fixture(scope="module") + async def model_manager(self, engine): + model_manager = await AlloyDBModelManager.create(engine) + yield model_manager + + async def test_model_manager_constructor(self, engine): + with pytest.raises(Exception): + AlloyDBModelManager(engine=engine) + + async def test_acreate_model(self, model_manager): + await model_manager.acreate_model( + model_id=EMBEDDING_MODEL_NAME, + model_provider="google", + model_qualified_name="textembedding-gecko@003", + model_type="text_embedding", + ) + + @pytest.mark.depends(on=["test_acreate_model"]) + async def test_aget_model(self, model_manager): + model_info = await model_manager.aget_model(model_id=EMBEDDING_MODEL_NAME) + assert model_info.model_id == EMBEDDING_MODEL_NAME + + async def test_non_existent_model(self, model_manager): + model_info = await model_manager.aget_model(model_id="Non_existent_model") + assert model_info is None + + @pytest.mark.depends(on=["test_aget_model"]) + async def test_alist_models(self, model_manager): + models_list = await model_manager.alist_models() + assert len(models_list) >= 3 + model_ids = [model_info.model_id for model_info in models_list] + assert EMBEDDING_MODEL_NAME in model_ids + + @pytest.mark.depends(on=["test_alist_models"]) + async def test_adrop_model(self, model_manager): + await model_manager.adrop_model(model_id=EMBEDDING_MODEL_NAME) diff --git a/tests/test_vectorstore_embeddings.py b/tests/test_vectorstore_embeddings.py new file mode 100644 index 00000000..c215bc4e --- /dev/null +++ b/tests/test_vectorstore_embeddings.py @@ -0,0 +1,378 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import uuid + +import pytest +import pytest_asyncio +from langchain_core.documents import Document +from sqlalchemy import text + +from langchain_google_alloydb_pg import ( + AlloyDBEmbeddings, + AlloyDBEngine, + AlloyDBVectorStore, + Column, +) +from langchain_google_alloydb_pg.indexes import DistanceStrategy, HNSWQueryOptions + +DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") +DEFAULT_TABLE_SYNC = "test_table" + str(uuid.uuid4()).replace("-", "_") +CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") +DEFAULT_EMBEDDING_MODEL = "textembedding-gecko@001" +VECTOR_SIZE = 768 + + +texts = ["foo", "bar", "baz", "boo"] +ids = [str(uuid.uuid4()) for i in range(len(texts))] +metadatas = [{"page": str(i), "source": "google.com"} for i in range(len(texts))] +docs = [ + Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts)) +] + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +async def aexecute( + engine: AlloyDBEngine, + query: str, +) -> None: + async def run(engine, query): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + +@pytest.mark.asyncio(loop_scope="class") +class TestVectorStoreEmbeddings: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for AlloyDB instance") + + @pytest.fixture(scope="module") + def db_cluster(self) -> str: + return get_env_var("CLUSTER_ID", "cluster for AlloyDB") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for AlloyDB") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "instance for AlloyDB") + + @pytest_asyncio.fixture(scope="class") + async def engine( + self, + db_project, + db_region, + db_cluster, + db_instance, + db_name, + ): + engine = await AlloyDBEngine.afrom_instance( + project_id=db_project, + cluster=db_cluster, + instance=db_instance, + region=db_region, + database=db_name, + ) + yield engine + await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") + await engine.close() + + @pytest_asyncio.fixture(scope="class") + async def embeddings_service(self, engine): + return await AlloyDBEmbeddings.create(engine, DEFAULT_EMBEDDING_MODEL) + + @pytest_asyncio.fixture(scope="class") + async def vs(self, engine, embeddings_service): + await engine.ainit_vectorstore_table( + DEFAULT_TABLE, VECTOR_SIZE, store_metadata=False, overwrite_existing=True + ) + vs = await AlloyDBVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=DEFAULT_TABLE, + ) + ids = [str(uuid.uuid4()) for i in range(len(texts))] + await vs.aadd_documents(docs, ids=ids) + yield vs + + @pytest_asyncio.fixture(scope="class") + async def engine_sync( + self, + db_project, + db_region, + db_cluster, + db_instance, + db_name, + ): + engine = AlloyDBEngine.from_instance( + project_id=db_project, + cluster=db_cluster, + instance=db_instance, + region=db_region, + database=db_name, + ) + yield engine + await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE}") + await engine.close() + + @pytest_asyncio.fixture(scope="class") + async def vs_custom(self, engine_sync, embeddings_service): + engine_sync.init_vectorstore_table( + CUSTOM_TABLE, + VECTOR_SIZE, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[ + Column("page", "TEXT"), + Column("source", "TEXT"), + ], + store_metadata=False, + ) + + vs_custom = AlloyDBVectorStore.create_sync( + engine_sync, + embedding_service=embeddings_service, + table_name=CUSTOM_TABLE, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + index_query_options=HNSWQueryOptions(ef_search=1), + ) + vs_custom.add_documents(docs, ids=ids) + yield vs_custom + + async def test_asimilarity_search(self, vs): + results = await vs.asimilarity_search("foo", k=1) + assert len(results) == 1 + assert results == [Document(page_content="foo")] + results = await vs.asimilarity_search("foo", k=1, filter="content = 'bar'") + assert results == [Document(page_content="bar")] + + async def test_asimilarity_search_score(self, vs): + results = await vs.asimilarity_search_with_score("foo") + assert len(results) == 4 + assert results[0][0] == Document(page_content="foo") + assert results[0][1] == 0 + + async def test_asimilarity_search_by_vector(self, vs, embeddings_service): + search_embedding = embeddings_service.embed_query("foo") + results = await vs.asimilarity_search_by_vector(search_embedding) + assert len(results) == 4 + assert results[0] == Document(page_content="foo") + results = await vs.asimilarity_search_with_score_by_vector(search_embedding) + assert results[0][0] == Document(page_content="foo") + assert results[0][1] == 0 + + async def test_similarity_search_with_relevance_scores_threshold_cosine(self, vs): + score_threshold = {"score_threshold": 0} + results = await vs.asimilarity_search_with_relevance_scores( + "foo", **score_threshold + ) + assert len(results) == 4 + + score_threshold = {"score_threshold": 0.73} + results = await vs.asimilarity_search_with_relevance_scores( + "foo", **score_threshold + ) + assert len(results) == 2 + + score_threshold = {"score_threshold": 0.8} + results = await vs.asimilarity_search_with_relevance_scores( + "foo", **score_threshold + ) + assert len(results) == 1 + assert results[0][0] == Document(page_content="foo") + + async def test_similarity_search_with_relevance_scores_threshold_euclidean( + self, engine, embeddings_service + ): + vs = await AlloyDBVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=DEFAULT_TABLE, + distance_strategy=DistanceStrategy.EUCLIDEAN, + ) + + score_threshold = {"score_threshold": 0.9} + results = await vs.asimilarity_search_with_relevance_scores( + "foo", **score_threshold + ) + assert len(results) == 1 + assert results[0][0] == Document(page_content="foo") + + async def test_amax_marginal_relevance_search(self, vs): + results = await vs.amax_marginal_relevance_search("bar") + assert results[0] == Document(page_content="bar") + results = await vs.amax_marginal_relevance_search( + "bar", filter="content = 'boo'" + ) + assert results[0] == Document(page_content="boo") + + async def test_amax_marginal_relevance_search_vector(self, vs, embeddings_service): + embedding = embeddings_service.embed_query("bar") + results = await vs.amax_marginal_relevance_search_by_vector(embedding) + assert results[0] == Document(page_content="bar") + + async def test_amax_marginal_relevance_search_vector_score( + self, vs, embeddings_service + ): + embedding = embeddings_service.embed_query("bar") + results = await vs.amax_marginal_relevance_search_with_score_by_vector( + embedding + ) + assert results[0][0] == Document(page_content="bar") + + results = await vs.amax_marginal_relevance_search_with_score_by_vector( + embedding, lambda_mult=0.75, fetch_k=10 + ) + assert results[0][0] == Document(page_content="bar") + + +class TestVectorStoreEmbeddingsSync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for AlloyDB instance") + + @pytest.fixture(scope="module") + def db_cluster(self) -> str: + return get_env_var("CLUSTER_ID", "cluster for AlloyDB") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for AlloyDB") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "instance for AlloyDB") + + @pytest_asyncio.fixture(scope="class") + async def engine_sync( + self, + db_project, + db_region, + db_cluster, + db_instance, + db_name, + ): + engine = await AlloyDBEngine.afrom_instance( + project_id=db_project, + cluster=db_cluster, + instance=db_instance, + region=db_region, + database=db_name, + ) + yield engine + await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_SYNC}") + await engine.close() + + @pytest_asyncio.fixture(scope="class") + def embeddings_service(self, engine_sync): + return AlloyDBEmbeddings.create_sync(engine_sync, DEFAULT_EMBEDDING_MODEL) + + @pytest_asyncio.fixture(scope="class") + async def vs_custom(self, engine_sync, embeddings_service): + engine_sync.init_vectorstore_table( + DEFAULT_TABLE_SYNC, + VECTOR_SIZE, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[ + Column("page", "TEXT"), + Column("source", "TEXT"), + ], + store_metadata=False, + ) + + vs_custom = await AlloyDBVectorStore.create( + engine_sync, + embedding_service=embeddings_service, + table_name=DEFAULT_TABLE_SYNC, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + index_query_options=HNSWQueryOptions(ef_search=1), + ) + vs_custom.add_documents(docs, ids=ids) + yield vs_custom + + def test_similarity_search(self, vs_custom): + results = vs_custom.similarity_search("foo", k=1) + assert len(results) == 1 + assert results == [Document(page_content="foo")] + results = vs_custom.similarity_search("foo", k=1, filter="mycontent = 'bar'") + assert results == [Document(page_content="bar")] + + def test_similarity_search_score(self, vs_custom): + results = vs_custom.similarity_search_with_score("foo") + assert len(results) == 4 + assert results[0][0] == Document(page_content="foo") + assert results[0][1] == 0 + + def test_similarity_search_by_vector(self, vs_custom, embeddings_service): + embedding = embeddings_service.embed_query("foo") + results = vs_custom.similarity_search_by_vector(embedding) + assert len(results) == 4 + assert results[0] == Document(page_content="foo") + results = vs_custom.similarity_search_with_score_by_vector(embedding) + assert results[0][0] == Document(page_content="foo") + assert results[0][1] == 0 + + def test_max_marginal_relevance_search(self, vs_custom): + results = vs_custom.max_marginal_relevance_search("bar") + assert results[0] == Document(page_content="bar") + results = vs_custom.max_marginal_relevance_search( + "bar", filter="mycontent = 'boo'" + ) + assert results[0] == Document(page_content="boo") + + def test_max_marginal_relevance_search_vector(self, vs_custom, embeddings_service): + embedding = embeddings_service.embed_query("bar") + results = vs_custom.max_marginal_relevance_search_by_vector(embedding) + assert results[0] == Document(page_content="bar") + + def test_max_marginal_relevance_search_vector_score( + self, vs_custom, embeddings_service + ): + embedding = embeddings_service.embed_query("bar") + results = vs_custom.max_marginal_relevance_search_with_score_by_vector( + embedding + ) + assert results[0][0] == Document(page_content="bar") + + results = vs_custom.max_marginal_relevance_search_with_score_by_vector( + embedding, lambda_mult=0.75, fetch_k=10 + ) + assert results[0][0] == Document(page_content="bar") diff --git a/tests/test_vectorstore_search.py b/tests/test_vectorstore_search.py index 50838f2f..e17860dd 100644 --- a/tests/test_vectorstore_search.py +++ b/tests/test_vectorstore_search.py @@ -35,6 +35,10 @@ embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) +# Note: The following texts are chosen to produce diverse +# similarity scores when using the DeterministicFakeEmbedding service. This ensures +# that the test cases can effectively validate the filtering and scoring logic. +# The scoring might be different if using a different embedding service. texts = ["foo", "bar", "baz", "boo"] ids = [str(uuid.uuid4()) for i in range(len(texts))] metadatas = [{"page": str(i), "source": "google.com"} for i in range(len(texts))]