Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import Dict, List, Optional, Sequence, Tuple, Union, Any

from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base
Expand Down Expand Up @@ -208,7 +208,8 @@ class MatchNeighbor:
For example, values [1,2,3] with dimensions [4,5,6] means value 1 is
of the 4th dimension, value 2 is of the 4th dimension, and value 3 is
of the 6th dimension.

embedding_metadata (Dict[str,Any]):
Optional. The embedding metadata of the matching datapoint.
"""

id: str
Expand All @@ -220,6 +221,7 @@ class MatchNeighbor:
numeric_restricts: Optional[List[NumericNamespace]] = None
sparse_embedding_values: Optional[List[float]] = None
sparse_embedding_dimensions: Optional[List[int]] = None
embedding_metadata: Optional[Dict[str,Any]] = None

def from_index_datapoint(
self, index_datapoint: gca_index_v1beta1.IndexDatapoint
Expand Down Expand Up @@ -276,6 +278,8 @@ def from_index_datapoint(
self.sparse_embedding_dimensions = (
index_datapoint.sparse_embedding.dimensions
)
if index_datapoint.embedding_metadata is not None:
self.embedding_metadata = dict(index_datapoint.embedding_metadata)
return self

def from_embedding(self, embedding: match_service_pb2.Embedding) -> "MatchNeighbor":
Expand Down