diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py
index 98f690e6d9..72f6211dd3 100644
--- a/google/cloud/firestore_v1/base_collection.py
+++ b/google/cloud/firestore_v1/base_collection.py
@@ -38,6 +38,7 @@
     Iterator,
     Iterable,
     NoReturn,
+    Sequence,
     Tuple,
     Union,
     TYPE_CHECKING,
@@ -549,7 +550,7 @@ def avg(self, field_ref: str | FieldPath, alias=None):
     def find_nearest(
         self,
         vector_field: str,
-        query_vector: Vector,
+        query_vector: Union[Vector, Sequence[float]],
         limit: int,
         distance_measure: DistanceMeasure,
     ) -> VectorQuery:
@@ -559,7 +560,7 @@ def find_nearest(
         Args:
             vector_field(str): An indexed vector field to search upon. Only documents which contain
                 vectors whose dimensionality match the query_vector can be returned.
-            query_vector(Vector): The query vector that we are searching on. Must be a vector of no more
+            query_vector(Union[Vector, Sequence[float]]): The query vector that we are searching on. Must be a vector of no more
                 than 2048 dimensions.
             limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
             distance_measure(:class:`DistanceMeasure`): The Distance Measure to use.
diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py
index c8c2f3ceb2..15525d9901 100644
--- a/google/cloud/firestore_v1/base_query.py
+++ b/google/cloud/firestore_v1/base_query.py
@@ -46,6 +46,7 @@
     Iterable,
     NoReturn,
     Optional,
+    Sequence,
     Tuple,
     Type,
     TypeVar,
@@ -978,7 +979,7 @@ def _to_protobuf(self) -> StructuredQuery:
     def find_nearest(
         self,
         vector_field: str,
-        queryVector: Vector,
+        query_vector: Union[Vector, Sequence[float]],
         limit: int,
         distance_measure: DistanceMeasure,
     ) -> BaseVectorQuery:
diff --git a/google/cloud/firestore_v1/base_vector_query.py b/google/cloud/firestore_v1/base_vector_query.py
index e41717d2b5..7e5283b707 100644
--- a/google/cloud/firestore_v1/base_vector_query.py
+++ b/google/cloud/firestore_v1/base_vector_query.py
@@ -19,7 +19,7 @@
 
 from abc import ABC
 from enum import Enum
-from typing import Iterable, Optional, Tuple, Union
+from typing import Iterable, Optional, Sequence, Tuple, Union
 from google.api_core import gapic_v1
 from google.api_core import retry as retries
 from google.cloud.firestore_v1.base_document import DocumentSnapshot
@@ -107,13 +107,16 @@ def get(
     def find_nearest(
         self,
         vector_field: str,
-        query_vector: Vector,
+        query_vector: Union[Vector, Sequence[float]],
         limit: int,
         distance_measure: DistanceMeasure,
     ):
         """Finds the closest vector embeddings to the given query vector."""
+        if not isinstance(query_vector, Vector):
+            self._query_vector = Vector(query_vector)
+        else:
+            self._query_vector = query_vector
         self._vector_field = vector_field
-        self._query_vector = query_vector
         self._limit = limit
         self._distance_measure = distance_measure
         return self
diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py
index c46a06918a..8e71d976f6 100644
--- a/google/cloud/firestore_v1/query.py
+++ b/google/cloud/firestore_v1/query.py
@@ -41,7 +41,17 @@
 
 from google.cloud.firestore_v1 import document
 from google.cloud.firestore_v1.watch import Watch
-from typing import Any, Callable, Generator, List, Optional, Type, TYPE_CHECKING
+from typing import (
+    Any,
+    Callable,
+    Generator,
+    List,
+    Optional,
+    Sequence,
+    Type,
+    TYPE_CHECKING,
+    Union,
+)
 
 if TYPE_CHECKING:  # pragma: NO COVER
     from google.cloud.firestore_v1.field_path import FieldPath
@@ -245,7 +255,7 @@ def _retry_query_after_exception(self, exc, retry, transaction):
     def find_nearest(
         self,
         vector_field: str,
-        query_vector: Vector,
+        query_vector: Union[Vector, Sequence[float]],
         limit: int,
         distance_measure: DistanceMeasure,
     ) -> Type["firestore_v1.vector_query.VectorQuery"]:
@@ -255,7 +265,7 @@ def find_nearest(
         Args:
             vector_field(str): An indexed vector field to search upon. Only documents which contain
                 vectors whose dimensionality match the query_vector can be returned.
-            query_vector(Vector): The query vector that we are searching on. Must be a vector of no more
+            query_vector(Vector | Sequence[float]): The query vector that we are searching on. Must be a vector of no more
                 than 2048 dimensions.
             limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
             distance_measure(:class:`DistanceMeasure`): The Distance Measure to use.
diff --git a/tests/unit/v1/test_vector.py b/tests/unit/v1/test_vector.py
index 6ca1ce4134..fb8ca62af0 100644
--- a/tests/unit/v1/test_vector.py
+++ b/tests/unit/v1/test_vector.py
@@ -24,7 +24,7 @@
 from unittest import mock
 
 
-def _make_commit_repsonse():
+def _make_commit_response():
     response = mock.create_autospec(firestore.CommitResponse)
     response.write_results = [mock.sentinel.write_result]
     response.commit_time = mock.sentinel.commit_time
@@ -34,7 +34,7 @@ def _make_commit_repsonse():
 def _make_firestore_api():
     firestore_api = mock.Mock()
     firestore_api.commit.mock_add_spec(spec=["commit"])
-    firestore_api.commit.return_value = _make_commit_repsonse()
+    firestore_api.commit.return_value = _make_commit_response()
     return firestore_api
 
 
diff --git a/tests/unit/v1/test_vector_query.py b/tests/unit/v1/test_vector_query.py
index 92dca45c4d..2a5c8c5a5a 100644
--- a/tests/unit/v1/test_vector_query.py
+++ b/tests/unit/v1/test_vector_query.py
@@ -325,6 +325,68 @@ def test_vector_query_collection_group(distance_measure, expected_distance):
     )
 
 
+def test_vector_query_list_as_query_vector():
+    # Create a minimal fake GAPIC.
+    firestore_api = mock.Mock(spec=["run_query"])
+    client = make_client()
+    client._firestore_api_internal = firestore_api
+
+    # Make a **real** collection reference as parent.
+    parent = client.collection("dee")
+    query = make_query(parent)
+    parent_path, expected_prefix = parent._parent_info()
+
+    data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])}
+    response_pb1 = _make_query_response(
+        name="{}/test_doc".format(expected_prefix), data=data
+    )
+    response_pb2 = _make_query_response(
+        name="{}/test_doc".format(expected_prefix), data=data
+    )
+
+    kwargs = make_retry_timeout_kwargs(retry=None, timeout=None)
+
+    # Execute the vector query and check the response.
+    firestore_api.run_query.return_value = iter([response_pb1, response_pb2])
+
+    vector_query = query.where("snooze", "==", 10).find_nearest(
+        vector_field="embedding",
+        query_vector=[1.0, 2.0, 3.0],
+        distance_measure=DistanceMeasure.EUCLIDEAN,
+        limit=5,
+    )
+
+    returned = vector_query.get(transaction=_transaction(client), **kwargs)
+    assert isinstance(returned, list)
+    assert len(returned) == 2
+    assert returned[0].to_dict() == data
+
+    expected_pb = _expected_pb(
+        parent=parent,
+        vector_field="embedding",
+        vector=Vector([1.0, 2.0, 3.0]),
+        distance_type=StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN,
+        limit=5,
+    )
+    expected_pb.where = StructuredQuery.Filter(
+        field_filter=StructuredQuery.FieldFilter(
+            field=StructuredQuery.FieldReference(field_path="snooze"),
+            op=StructuredQuery.FieldFilter.Operator.EQUAL,
+            value=encode_value(10),
+        )
+    )
+
+    firestore_api.run_query.assert_called_once_with(
+        request={
+            "parent": parent_path,
+            "structured_query": expected_pb,
+            "transaction": _TXN_ID,
+        },
+        metadata=client._rpc_metadata,
+        **kwargs,
+    )
+
+
 def test_query_stream_multiple_empty_response_in_stream():
     # Create a minimal fake GAPIC with a dummy response.
     firestore_api = mock.Mock(spec=["run_query"])