From 1b246fa786a1c0908d423b5c9140eec93312ab2a Mon Sep 17 00:00:00 2001 From: Paul Elliott Date: Fri, 17 Jan 2025 13:02:09 -0500 Subject: [PATCH] feat(embeddings): remove batch size parameter App should figure out batch size automatically by catching out of memory errors. --- src/nrtk_explorer/app/embeddings.py | 20 +++------------- .../library/embeddings_extractor.py | 24 ++++++++++++------- 2 files changed, 18 insertions(+), 26 deletions(-) diff --git a/src/nrtk_explorer/app/embeddings.py b/src/nrtk_explorer/app/embeddings.py index 49603301..9abb60f8 100644 --- a/src/nrtk_explorer/app/embeddings.py +++ b/src/nrtk_explorer/app/embeddings.py @@ -108,9 +108,8 @@ def on_server_ready(self, *args, **kwargs): self.state.change("transform_enabled_switch")(self.update_points_transformations_state) def on_feature_extraction_model_change(self, **kwargs): - feature_extraction_model = self.state.feature_extraction_model self.extractor = embeddings_extractor.EmbeddingsExtractor( - model_name=feature_extraction_model + model_name=self.state.feature_extraction_model ) def compute_points(self, fit_features, features): @@ -160,10 +159,7 @@ def compute_source_points(self): images = [ self.images.get_image_without_cache_eviction(id) for id in self.state.dataset_ids ] - self.features = self.extractor.extract( - images, - batch_size=int(self.state.model_batch_size), - ) + self.features = self.extractor.extract(images) points = self.compute_points(self.features, self.features) @@ -219,10 +215,7 @@ def update_transformed_images(self, id_to_image): if image_id_to_dataset_id(id) not in self._stashed_points_transformations } - transformation_features = self.extractor.extract( - list(new_to_plot.values()), - batch_size=int(self.state.model_batch_size), - ) + transformation_features = self.extractor.extract(list(new_to_plot.values())) points = self.compute_points(self.features, transformation_features) image_id_to_point = zip(new_to_plot.keys(), points) @@ -312,13 +305,6 @@ def settings_widget(self): emit_value=True, map_options=True, ) - quasar.QInput( - v_model=("model_batch_size", 32), - filled=True, - stack_label=True, - label="Batch Size", - type="number", - ) with html.Div(classes="col"): with quasar.QTabs( diff --git a/src/nrtk_explorer/library/embeddings_extractor.py b/src/nrtk_explorer/library/embeddings_extractor.py index 361ffe24..edd28f0a 100644 --- a/src/nrtk_explorer/library/embeddings_extractor.py +++ b/src/nrtk_explorer/library/embeddings_extractor.py @@ -8,6 +8,7 @@ from torch.utils.data import DataLoader, Dataset IMAGE_MODEL_RESOLUTION = (224, 224) +STARTING_BATCH_SIZE = 32 # Create a dataset for images @@ -26,6 +27,10 @@ class EmbeddingsExtractor: def __init__(self, model_name="resnet50d", force_cpu=False): self.device = "cuda" if torch.cuda.is_available() and not force_cpu else "cpu" self.model = model_name + self.reset() + + def reset(self): + self.batch_size = STARTING_BATCH_SIZE @property def device(self): @@ -57,20 +62,21 @@ def transform_image(self, image: Image): img = image.resize(IMAGE_MODEL_RESOLUTION).convert("RGB") return self._model_transformer(img).unsqueeze(0) - def extract(self, images, batch_size=32): + def extract(self, images, batch_size=0): """Extract features from images""" if len(images) == 0: return [] + if batch_size != 0: + self.batch_size = batch_size + features = list() transformed_images = [self.transform_image(img) for img in images] - # Extract features from images - adjusted_batch_size = batch_size - while adjusted_batch_size > 0: + while self.batch_size > 0: try: for batch in DataLoader( - ImagesDataset(transformed_images), batch_size=adjusted_batch_size + ImagesDataset(transformed_images), batch_size=self.batch_size ): # Copy image to device if using device if self.device.type == "cuda": @@ -80,11 +86,11 @@ def extract(self, images, batch_size=32): return np.vstack(features) except RuntimeError as e: - if "out of memory" in str(e) and adjusted_batch_size > 1: - previous_batch_size = adjusted_batch_size - adjusted_batch_size = adjusted_batch_size // 2 + if "out of memory" in str(e) and self.batch_size > 1: + previous_batch_size = self.batch_size + self.batch_size //= 2 print( - f"OOM (Pytorch exception {e}) due to batch_size={previous_batch_size}, setting batch_size={adjusted_batch_size}" + f"Changing extract batch_size from {previous_batch_size} to {self.batch_size} because caught out of memory exception:\n{e}" ) else: raise