@@ -375,6 +375,49 @@ def label_ids(self):
375
375
"""
376
376
return None
377
377
378
def get_index_ids(self):
    """Returns the IDs of everything stored in the full index.

    Every backend supports this method. When the backend exposes
    :meth:`sample_ids` and :meth:`label_ids`, the relevant one is returned
    directly. Otherwise the IDs must be recovered by requesting embeddings
    from the backend, which can take some time.

    Returns:
        the list of sample IDs (or label IDs for patch indexes) in the full
        index
    """
    patches_field = self.config.patches_field
    is_patch_index = patches_field is not None

    # Fast path: the backend already tracks its own IDs
    known_ids = self.label_ids if is_patch_index else self.sample_ids
    if known_ids is not None:
        return known_ids

    # Slow path: this backend cannot report its IDs, so the only way to
    # infer them is to download all embeddings and keep the IDs that the
    # backend actually returned
    logger.info("Retrieving IDs from index. This can take awhile...")

    candidate_sample_ids, candidate_label_ids = fbu.get_ids(
        self._samples, patches_field=patches_field
    )

    # Missing IDs are expected here, so allow them and stay quiet
    _, found_sample_ids, found_label_ids = self.get_embeddings(
        sample_ids=candidate_sample_ids,
        label_ids=candidate_label_ids,
        allow_missing=True,
        warn_missing=False,
    )

    return found_label_ids if is_patch_index else found_sample_ids
420
+
378
421
@property
379
422
def total_index_size (self ):
380
423
"""The total number of data points in the index.
@@ -948,22 +991,12 @@ def compute_embeddings(
948
991
model = self .get_model ()
949
992
950
993
if skip_existing :
951
- if self .config .patches_field is not None :
952
- index_ids = self .label_ids
953
- else :
954
- index_ids = self .sample_ids
955
-
956
- if index_ids is not None :
957
- samples = fbu .skip_ids (
958
- samples ,
959
- index_ids ,
960
- patches_field = self .config .patches_field ,
961
- warn_existing = warn_existing ,
962
- )
963
- else :
964
- logger .warning (
965
- "This index does not support skipping existing IDs"
966
- )
994
+ samples = fbu .skip_ids (
995
+ samples ,
996
+ self .get_index_ids (),
997
+ patches_field = self .config .patches_field ,
998
+ warn_existing = warn_existing ,
999
+ )
967
1000
968
1001
if self .config .roi_field is not None :
969
1002
patches_field = self .config .roi_field
@@ -988,6 +1021,92 @@ def compute_embeddings(
988
1021
progress = progress ,
989
1022
)
990
1023
1024
def update_index(
    self,
    samples=None,
    model=None,
    overwrite=False,
    batch_size=None,
    num_workers=None,
    skip_failures=True,
    force_square=False,
    alpha=None,
    progress=None,
    reload=True,
):
    """Brings the index up to date by computing and adding embeddings for
    any samples that it does not yet contain.

    Args:
        samples (None): a
            :class:`fiftyone.core.collections.SampleCollection` for which
            to update the index. By default, the collection stored on this
            index is used
        model (None): a :class:`fiftyone.core.models.Model` to use to
            generate embeddings. If not provided, these results must have
            been created with a stored model, which will be used by default
        overwrite (False): whether to regenerate embeddings for
            sample/label IDs that are already in the index
        batch_size (None): an optional batch size to use when computing
            embeddings. Only applicable when a ``model`` is provided
        num_workers (None): the number of workers to use when loading
            images. Only applicable when a Torch-based model is being used
            to compute embeddings
        skip_failures (True): whether to gracefully continue without
            raising an error if embeddings cannot be generated for a sample
        force_square (False): whether to minimally manipulate the patch
            bounding boxes into squares prior to extraction. Only
            applicable when a ``model`` and ``patches_field`` are specified
        alpha (None): an optional expansion/contraction to apply to the
            patches before extracting them, in ``[-1, inf)``. If provided,
            the length and width of the box are expanded (or contracted,
            when ``alpha < 0``) by ``(100 * alpha)%``. For example, set
            ``alpha = 0.1`` to expand the boxes by 10%, and set
            ``alpha = -0.1`` to contract the boxes by 10%. Only applicable
            when a ``model`` and ``patches_field`` are specified
        progress (None): whether to render a progress bar (True/False), use
            the default value ``fiftyone.config.show_progress_bars``
            (None), or a progress callback function to invoke instead
        reload (True): whether to call :meth:`reload` to refresh the
            current view after the update
    """
    target = self._samples if samples is None else samples

    # Unless overwriting, IDs already in the index are skipped so that
    # only the missing embeddings are computed
    new_embeddings, new_sample_ids, new_label_ids = self.compute_embeddings(
        target,
        model=model,
        batch_size=batch_size,
        num_workers=num_workers,
        skip_failures=skip_failures,
        skip_existing=not overwrite,
        warn_existing=False,
        force_square=force_square,
        alpha=alpha,
        progress=progress,
    )

    count = len(new_embeddings)
    if not count:
        logger.info("Index is already up to date")
        return

    logger.info(f"Adding {count} embeddings to the index...")
    self.add_to_index(
        new_embeddings,
        new_sample_ids,
        label_ids=new_label_ids,
        overwrite=overwrite,
        allow_existing=True,
        warn_existing=False,
        reload=reload,
    )

    # Only the "sklearn" method with no embeddings field triggers an
    # explicit save after the update
    cfg = self.config
    if cfg.method == "sklearn" and cfg.embeddings_field is None:
        self.save()
1109
+
991
1110
@classmethod
992
1111
def _from_dict (cls , d , samples , config , brain_key ):
993
1112
"""Builds a :class:`SimilarityIndex` from a JSON representation of it.
0 commit comments