Skip to content

feat: Support Sequence[float] as query_vector in FindNearest #908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
Iterator,
Iterable,
NoReturn,
Sequence,
Tuple,
Union,
TYPE_CHECKING,
Expand Down Expand Up @@ -549,7 +550,7 @@ def avg(self, field_ref: str | FieldPath, alias=None):
def find_nearest(
self,
vector_field: str,
query_vector: Vector,
query_vector: Union[Vector, Sequence[float]],
limit: int,
distance_measure: DistanceMeasure,
) -> VectorQuery:
Expand All @@ -559,7 +560,7 @@ def find_nearest(
Args:
vector_field(str): An indexed vector field to search upon. Only documents which contain
vectors whose dimensionality match the query_vector can be returned.
query_vector(Vector): The query vector that we are searching on. Must be a vector of no more
query_vector(Union[Vector, Sequence[float]]): The query vector that we are searching on. Must be a vector of no more
than 2048 dimensions.
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure(:class:`DistanceMeasure`): The Distance Measure to use.
Expand Down
3 changes: 2 additions & 1 deletion google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
Iterable,
NoReturn,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Expand Down Expand Up @@ -978,7 +979,7 @@ def _to_protobuf(self) -> StructuredQuery:
def find_nearest(
self,
vector_field: str,
queryVector: Vector,
query_vector: Union[Vector, Sequence[float]],
limit: int,
distance_measure: DistanceMeasure,
) -> BaseVectorQuery:
Expand Down
9 changes: 6 additions & 3 deletions google/cloud/firestore_v1/base_vector_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from abc import ABC
from enum import Enum
from typing import Iterable, Optional, Tuple, Union
from typing import Iterable, Optional, Sequence, Tuple, Union
from google.api_core import gapic_v1
from google.api_core import retry as retries
from google.cloud.firestore_v1.base_document import DocumentSnapshot
Expand Down Expand Up @@ -107,13 +107,16 @@ def get(
def find_nearest(
    self,
    vector_field: str,
    query_vector: Union[Vector, Sequence[float]],
    limit: int,
    distance_measure: DistanceMeasure,
):
    """Finds the closest vector embeddings to the given query vector.

    Args:
        vector_field (str): Name of the vector field to search upon.
        query_vector (Union[Vector, Sequence[float]]): The query vector to
            search with. Anything that is not already a ``Vector`` is
            wrapped in one, so callers may pass a plain sequence of floats.
        limit (int): The number of nearest neighbors to return.
        distance_measure (DistanceMeasure): The distance measure to use.

    Returns:
        This query object, so calls can be chained.
    """
    # Normalize once here so the rest of the query only ever sees a Vector.
    # The coerced value must not be overwritten by the raw argument afterwards.
    self._query_vector = (
        query_vector if isinstance(query_vector, Vector) else Vector(query_vector)
    )
    self._vector_field = vector_field
    self._limit = limit
    self._distance_measure = distance_measure
    return self
16 changes: 13 additions & 3 deletions google/cloud/firestore_v1/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,17 @@

from google.cloud.firestore_v1 import document
from google.cloud.firestore_v1.watch import Watch
from typing import Any, Callable, Generator, List, Optional, Type, TYPE_CHECKING
from typing import (
Any,
Callable,
Generator,
List,
Optional,
Sequence,
Type,
TYPE_CHECKING,
Union,
)

if TYPE_CHECKING: # pragma: NO COVER
from google.cloud.firestore_v1.field_path import FieldPath
Expand Down Expand Up @@ -245,7 +255,7 @@ def _retry_query_after_exception(self, exc, retry, transaction):
def find_nearest(
self,
vector_field: str,
query_vector: Vector,
query_vector: Union[Vector, Sequence[float]],
limit: int,
distance_measure: DistanceMeasure,
) -> Type["firestore_v1.vector_query.VectorQuery"]:
Expand All @@ -255,7 +265,7 @@ def find_nearest(
Args:
vector_field(str): An indexed vector field to search upon. Only documents which contain
vectors whose dimensionality match the query_vector can be returned.
query_vector(Vector): The query vector that we are searching on. Must be a vector of no more
query_vector(Vector | Sequence[float]): The query vector that we are searching on. Must be a vector of no more
than 2048 dimensions.
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure(:class:`DistanceMeasure`): The Distance Measure to use.
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/v1/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from unittest import mock


def _make_commit_repsonse():
def _make_commit_response():
response = mock.create_autospec(firestore.CommitResponse)
response.write_results = [mock.sentinel.write_result]
response.commit_time = mock.sentinel.commit_time
Expand All @@ -34,7 +34,7 @@ def _make_commit_repsonse():
def _make_firestore_api():
    """Build a mock Firestore API whose ``commit`` returns a canned response."""
    firestore_api = mock.Mock()
    firestore_api.commit.mock_add_spec(spec=["commit"])
    # Use only the correctly spelled helper; the old ``_make_commit_repsonse``
    # call was redundant because its result was overwritten immediately.
    firestore_api.commit.return_value = _make_commit_response()
    return firestore_api


Expand Down
62 changes: 62 additions & 0 deletions tests/unit/v1/test_vector_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,68 @@ def test_vector_query_collection_group(distance_measure, expected_distance):
)


def test_vector_query_list_as_query_vector():
    """A plain list passed as ``query_vector`` produces the same request as a ``Vector``."""
    # Stub the GAPIC layer so that only ``run_query`` is available.
    firestore_api = mock.Mock(spec=["run_query"])
    client = make_client()
    client._firestore_api_internal = firestore_api

    # The parent is a genuine collection reference, not a mock.
    parent = client.collection("dee")
    query = make_query(parent)
    parent_path, expected_prefix = parent._parent_info()

    data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])}
    doc_name = "{}/test_doc".format(expected_prefix)
    responses = [_make_query_response(name=doc_name, data=data) for _ in range(2)]

    kwargs = make_retry_timeout_kwargs(retry=None, timeout=None)

    # The fake API streams back both canned responses.
    firestore_api.run_query.return_value = iter(responses)

    # Pass the query vector as a bare list rather than a Vector instance.
    filtered_query = query.where("snooze", "==", 10)
    vector_query = filtered_query.find_nearest(
        vector_field="embedding",
        query_vector=[1.0, 2.0, 3.0],
        distance_measure=DistanceMeasure.EUCLIDEAN,
        limit=5,
    )

    returned = vector_query.get(transaction=_transaction(client), **kwargs)
    assert isinstance(returned, list)
    assert len(returned) == 2
    assert returned[0].to_dict() == data

    # The expected request carries a Vector even though a list was supplied.
    expected_pb = _expected_pb(
        parent=parent,
        vector_field="embedding",
        vector=Vector([1.0, 2.0, 3.0]),
        distance_type=StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN,
        limit=5,
    )
    snooze_filter = StructuredQuery.FieldFilter(
        field=StructuredQuery.FieldReference(field_path="snooze"),
        op=StructuredQuery.FieldFilter.Operator.EQUAL,
        value=encode_value(10),
    )
    expected_pb.where = StructuredQuery.Filter(field_filter=snooze_filter)

    expected_request = {
        "parent": parent_path,
        "structured_query": expected_pb,
        "transaction": _TXN_ID,
    }
    firestore_api.run_query.assert_called_once_with(
        request=expected_request,
        metadata=client._rpc_metadata,
        **kwargs,
    )


def test_query_stream_multiple_empty_response_in_stream():
# Create a minimal fake GAPIC with a dummy response.
firestore_api = mock.Mock(spec=["run_query"])
Expand Down
Loading