diff --git a/src/nrtk_explorer/app/transforms.py b/src/nrtk_explorer/app/transforms.py
index d3c0da36..fe2292c7 100644
--- a/src/nrtk_explorer/app/transforms.py
+++ b/src/nrtk_explorer/app/transforms.py
@@ -111,15 +111,19 @@ def __init__(self, server):
         self.server.controller.add("on_server_ready")(self.on_server_ready)
         self._on_hover_fn = None
 
-        self.detector = object_detector.ObjectDetector(model_name="facebook/detr-resnet-50")
-
     def on_server_ready(self, *args, **kwargs):
         # Bind instance methods to state change
         self.state.change("current_dataset")(self.on_current_dataset_change)
         self.state.change("current_num_elements")(self.on_current_num_elements_change)
+        self.state.change("object_detection_model")(self.on_object_detection_model_change)
 
+        self.on_object_detection_model_change(self.state.object_detection_model)
         self.on_current_dataset_change(self.state.current_dataset)
 
+    def on_object_detection_model_change(self, model_name, **kwargs):
+        self.detector = object_detector.ObjectDetector(model_name=model_name)
+
     def set_on_transform(self, fn):
         self._on_transform_fn = fn
 
@@ -191,7 +195,11 @@ def compute_annotations(self, ids):
         if len(ids) == 0:
             return
 
-        predictions = self.detector.eval(image_ids=ids, content=self.context.image_objects)
+        predictions = self.detector.eval(
+            image_ids=ids,
+            content=self.context.image_objects,
+            batch_size=int(self.state.object_detection_batch_size),
+        )
 
         for id_, annotations in predictions.items():
             image_annotations = []
diff --git a/src/nrtk_explorer/app/ui/layout.py b/src/nrtk_explorer/app/ui/layout.py
index 51999273..fd6bb3f1 100644
--- a/src/nrtk_explorer/app/ui/layout.py
+++ b/src/nrtk_explorer/app/ui/layout.py
@@ -96,6 +96,13 @@ def parameters(dataset_paths=[], embeddings_app=None, filtering_app=None, transf
                     emit_value=True,
                     map_options=True,
                 )
+                quasar.QInput(
+                    v_model=("object_detection_batch_size", 32),
+                    filled=True,
+                    stack_label=True,
+                    label="Batch Size",
+                    type="number",
+                )
 
     filter_title_slot, filter_content_slot, filter_actions_slot = ui.card("collapse_filter")
 
diff --git a/src/nrtk_explorer/library/embeddings_extractor.py b/src/nrtk_explorer/library/embeddings_extractor.py
index 1cc3e700..2c3eedd9 100644
--- a/src/nrtk_explorer/library/embeddings_extractor.py
+++ b/src/nrtk_explorer/library/embeddings_extractor.py
@@ -1,7 +1,8 @@
-import torch
+import gc
 import logging
 import numpy as np
 import timm
+import torch
 
 from nrtk_explorer.library import images_manager
 from torch.utils.data import DataLoader, Dataset
@@ -75,11 +76,35 @@ def extract(self, paths, content=None, batch_size=32):
             transformed_images.append(self.transform_image(img))
 
         # Extract features from images
-        for batch in DataLoader(ImagesDataset(transformed_images), batch_size=batch_size):
-            # Copy image to device if using device
-            if self.device.type == "cuda":
-                batch = batch.cuda()
-
-            features.append(self.model(batch).numpy(force=True))
-
-        return np.vstack(features)
+        adjusted_batch_size = batch_size
+        while adjusted_batch_size > 0:
+            try:
+                # Reset partial results so a retried pass does not duplicate them
+                features = []
+                for batch in DataLoader(
+                    ImagesDataset(transformed_images), batch_size=adjusted_batch_size
+                ):
+                    # Copy the batch to the device when running on CUDA
+                    if self.device.type == "cuda":
+                        batch = batch.cuda()
+
+                    features.append(self.model(batch).numpy(force=True))
+                return np.vstack(features)
+
+            except RuntimeError as e:
+                if "out of memory" in str(e) and adjusted_batch_size > 1:
+                    previous_batch_size = adjusted_batch_size
+                    adjusted_batch_size = adjusted_batch_size // 2
+                    print(
+                        f"OOM (PyTorch exception: {e}) due to batch_size={previous_batch_size}, setting batch_size={adjusted_batch_size}"
+                    )
+                else:
+                    raise
+
+            finally:
+                # PyTorch needs to free its allocations outside of the exception context
+                gc.collect()
+                torch.cuda.empty_cache()
+
+        # We should never reach here
+        return None
diff --git a/src/nrtk_explorer/library/object_detector.py b/src/nrtk_explorer/library/object_detector.py
index b0171002..c64875da 100644
--- a/src/nrtk_explorer/library/object_detector.py
+++ b/src/nrtk_explorer/library/object_detector.py
@@ -1,3 +1,4 @@
+import gc
 import logging
 import torch
 import transformers
@@ -6,7 +7,7 @@
 from nrtk_explorer.library import images_manager
 
 
-ImageIdToAnnotations = dict[str, Sequence[dict]]
+ImageIdToAnnotations = Optional[dict[str, Sequence[dict]]]
 
 
 class ObjectDetector:
@@ -77,17 +78,38 @@ def eval(
             batches[img.size][0].append(path)
             batches[img.size][1].append(img)
 
-        predictions_in_baches = [
-            zip(
-                image_ids,
-                self.pipeline(images, batch_size=batch_size),
-            )
-            for image_ids, images in batches.values()
-        ]
-
-        predictions_by_image_id = {
-            image_id: predictions
-            for batch in predictions_in_baches
-            for image_id, predictions in batch
-        }
-        return predictions_by_image_id
+        adjusted_batch_size = batch_size
+        while adjusted_batch_size > 0:
+            try:
+                predictions_in_batches = [
+                    zip(
+                        image_ids,
+                        self.pipeline(images, batch_size=adjusted_batch_size),
+                    )
+                    for image_ids, images in batches.values()
+                ]
+
+                predictions_by_image_id = {
+                    image_id: predictions
+                    for batch in predictions_in_batches
+                    for image_id, predictions in batch
+                }
+                return predictions_by_image_id
+
+            except RuntimeError as e:
+                if "out of memory" in str(e) and adjusted_batch_size > 1:
+                    previous_batch_size = adjusted_batch_size
+                    adjusted_batch_size = adjusted_batch_size // 2
+                    print(
+                        f"OOM (PyTorch exception: {e}) due to batch_size={previous_batch_size}, setting batch_size={adjusted_batch_size}"
+                    )
+                else:
+                    raise
+
+            finally:
+                # PyTorch needs to free its allocations outside of the exception context
+                gc.collect()
+                torch.cuda.empty_cache()
+
+        # We should never reach here
+        return None
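
Reviewer note: extract() and eval() now share the same halve-on-OOM retry loop, so a minimal standalone sketch of the pattern may help review. This sketch is not part of the PR: run_inference is a hypothetical callable standing in for the DataLoader/pipeline passes above, and the retry assumes a pass is idempotent (it restarts from the first batch after a failure).

    import gc

    import torch


    def infer_with_oom_backoff(run_inference, batch_size):
        """Call run_inference(batch_size), halving the batch size on CUDA OOM."""
        adjusted_batch_size = batch_size
        while adjusted_batch_size > 0:
            try:
                # One full, restartable pass over the data
                return run_inference(adjusted_batch_size)
            except RuntimeError as e:
                # PyTorch reports CUDA OOM as a RuntimeError whose message
                # contains "out of memory"
                if "out of memory" in str(e) and adjusted_batch_size > 1:
                    adjusted_batch_size //= 2
                else:
                    raise
            finally:
                # Release cached allocations outside the exception context so a
                # retry starts with as much free GPU memory as possible
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        return None  # unreachable: the loop either returns or raises

For example, infer_with_oom_backoff(lambda bs: run_pass(data, bs), 32) would retry at 16, 8, 4, 2, and 1 before re-raising; this is why the diff resets the features accumulator inside the try block, so a retried pass in extract() does not duplicate results from a partially completed one.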