Skip to content

Commit 638e4ea

Browse files
committed
adding update_index() method for similarity indexes
1 parent 2e673cf commit 638e4ea

File tree

1 file changed

+135
-16
lines changed

1 file changed

+135
-16
lines changed

fiftyone/brain/similarity.py

Lines changed: 135 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,49 @@ def label_ids(self):
375375
"""
376376
return None
377377

378+
def get_index_ids(self):
379+
"""Returns the list of IDs in the full index.
380+
381+
All backends support this method. If the backend supports
382+
:meth:`sample_ids` and :meth:`label_ids`, then the appropriate primary
383+
keys are returned. For other backends, this operation can take some
384+
time as we must query the backend sequentially to retrieve these.
385+
386+
Returns:
387+
the list of sample IDs (or label IDs for patch indexes) in the full
388+
index
389+
"""
390+
if self.config.patches_field is not None:
391+
index_ids = self.label_ids
392+
else:
393+
index_ids = self.sample_ids
394+
395+
if index_ids is not None:
396+
return index_ids
397+
398+
# Unfortunately for this index, the only way to infer the available IDs
399+
# is to download all embeddings
400+
401+
logger.info("Retrieving IDs from index. This can take awhile...")
402+
403+
sample_ids, label_ids = fbu.get_ids(
404+
self._samples, patches_field=self.config.patches_field
405+
)
406+
407+
_, sample_ids, label_ids = self.get_embeddings(
408+
sample_ids=sample_ids,
409+
label_ids=label_ids,
410+
allow_missing=True,
411+
warn_missing=False,
412+
)
413+
414+
if self.config.patches_field is not None:
415+
index_ids = label_ids
416+
else:
417+
index_ids = sample_ids
418+
419+
return index_ids
420+
378421
@property
379422
def total_index_size(self):
380423
"""The total number of data points in the index.
@@ -948,22 +991,12 @@ def compute_embeddings(
948991
model = self.get_model()
949992

950993
if skip_existing:
951-
if self.config.patches_field is not None:
952-
index_ids = self.label_ids
953-
else:
954-
index_ids = self.sample_ids
955-
956-
if index_ids is not None:
957-
samples = fbu.skip_ids(
958-
samples,
959-
index_ids,
960-
patches_field=self.config.patches_field,
961-
warn_existing=warn_existing,
962-
)
963-
else:
964-
logger.warning(
965-
"This index does not support skipping existing IDs"
966-
)
994+
samples = fbu.skip_ids(
995+
samples,
996+
self.get_index_ids(),
997+
patches_field=self.config.patches_field,
998+
warn_existing=warn_existing,
999+
)
9671000

9681001
if self.config.roi_field is not None:
9691002
patches_field = self.config.roi_field
@@ -988,6 +1021,92 @@ def compute_embeddings(
9881021
progress=progress,
9891022
)
9901023

1024+
def update_index(
1025+
self,
1026+
samples=None,
1027+
model=None,
1028+
overwrite=False,
1029+
batch_size=None,
1030+
num_workers=None,
1031+
skip_failures=True,
1032+
force_square=False,
1033+
alpha=None,
1034+
progress=None,
1035+
reload=True,
1036+
):
1037+
"""Updates the index, if necessary, by adding embeddings for any
1038+
samples that are not already present in the index.
1039+
1040+
Args:
1041+
samples (None): a
1042+
:class:`fiftyone.core.collections.SampleCollection` for which
1043+
to update the index. By default, :meth:`samples` is used
1044+
model (None): a :class:`fiftyone.core.models.Model` to use to
1045+
generate embeddings. If not provided, these results must have
1046+
been created with a stored model, which will be used by default
1047+
overwrite (False): whether to regenerate embeddings for
1048+
sample/label IDs that are already in the index
1049+
batch_size (None): an optional batch size to use when computing
1050+
embeddings. Only applicable when a ``model`` is provided
1051+
num_workers (None): the number of workers to use when loading
1052+
images. Only applicable when a Torch-based model is being used
1053+
to compute embeddings
1054+
skip_failures (True): whether to gracefully continue without
1055+
raising an error if embeddings cannot be generated for a sample
1056+
force_square (False): whether to minimally manipulate the patch
1057+
bounding boxes into squares prior to extraction. Only
1058+
applicable when a ``model`` and ``patches_field`` are specified
1059+
alpha (None): an optional expansion/contraction to apply to the
1060+
patches before extracting them, in ``[-1, inf)``. If provided,
1061+
the length and width of the box are expanded (or contracted,
1062+
when ``alpha < 0``) by ``(100 * alpha)%``. For example, set
1063+
``alpha = 0.1`` to expand the boxes by 10%, and set
1064+
``alpha = -0.1`` to contract the boxes by 10%. Only applicable
1065+
when a ``model`` and ``patches_field`` are specified
1066+
progress (None): whether to render a progress bar (True/False), use
1067+
the default value ``fiftyone.config.show_progress_bars``
1068+
(None), or a progress callback function to invoke instead
1069+
reload (True): whether to call :meth:`reload` to refresh the
1070+
current view after the update
1071+
"""
1072+
if samples is None:
1073+
samples = self._samples
1074+
1075+
embeddings, sample_ids, label_ids = self.compute_embeddings(
1076+
samples,
1077+
model=model,
1078+
batch_size=batch_size,
1079+
num_workers=num_workers,
1080+
skip_failures=skip_failures,
1081+
skip_existing=not overwrite,
1082+
warn_existing=False,
1083+
force_square=force_square,
1084+
alpha=alpha,
1085+
progress=progress,
1086+
)
1087+
1088+
num_added = len(embeddings)
1089+
if num_added == 0:
1090+
logger.info("Index is already up to date")
1091+
return
1092+
1093+
logger.info(f"Adding {num_added} embeddings to the index...")
1094+
self.add_to_index(
1095+
embeddings,
1096+
sample_ids,
1097+
label_ids=label_ids,
1098+
overwrite=overwrite,
1099+
allow_existing=True,
1100+
warn_existing=False,
1101+
reload=reload,
1102+
)
1103+
1104+
if (
1105+
self.config.method == "sklearn"
1106+
and self.config.embeddings_field is None
1107+
):
1108+
self.save()
1109+
9911110
@classmethod
9921111
def _from_dict(cls, d, samples, config, brain_key):
9931112
"""Builds a :class:`SimilarityIndex` from a JSON representation of it.

0 commit comments

Comments
 (0)