feat(embeddings): remove batch size parameter
App should figure out the batch size automatically by
catching out-of-memory errors.
PaulHax committed Jan 21, 2025
1 parent c2448c8 commit 1b246fa
Showing 2 changed files with 18 additions and 26 deletions.
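
The idea behind the change, as a minimal standalone sketch (hypothetical names, not the repository's code): attempt the work at the current batch size, and when a CUDA out-of-memory `RuntimeError` is caught, halve the batch size and retry the whole pass.

```python
# Minimal sketch of the halve-on-OOM pattern this commit adopts.
# `model` and `tensors` are placeholders; only the retry logic matters.
import torch


def extract_with_backoff(model, tensors, batch_size=32):
    while batch_size > 0:
        try:
            outputs = []
            for i in range(0, len(tensors), batch_size):
                batch = torch.stack(tensors[i : i + batch_size])
                outputs.append(model(batch))
            return torch.cat(outputs)
        except RuntimeError as e:
            # PyTorch raises RuntimeError with "out of memory" in the
            # message on CUDA OOM; halve the batch size and retry.
            if "out of memory" in str(e) and batch_size > 1:
                batch_size //= 2
            else:
                raise
```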
src/nrtk_explorer/app/embeddings.py (20 changes: 3 additions & 17 deletions)
```diff
@@ -108,9 +108,8 @@ def on_server_ready(self, *args, **kwargs):
         self.state.change("transform_enabled_switch")(self.update_points_transformations_state)
 
     def on_feature_extraction_model_change(self, **kwargs):
-        feature_extraction_model = self.state.feature_extraction_model
         self.extractor = embeddings_extractor.EmbeddingsExtractor(
-            model_name=feature_extraction_model
+            model_name=self.state.feature_extraction_model
         )
 
     def compute_points(self, fit_features, features):
@@ -160,10 +159,7 @@ def compute_source_points(self):
         images = [
             self.images.get_image_without_cache_eviction(id) for id in self.state.dataset_ids
         ]
-        self.features = self.extractor.extract(
-            images,
-            batch_size=int(self.state.model_batch_size),
-        )
+        self.features = self.extractor.extract(images)
 
         points = self.compute_points(self.features, self.features)
 
@@ -219,10 +215,7 @@ def update_transformed_images(self, id_to_image):
             if image_id_to_dataset_id(id) not in self._stashed_points_transformations
         }
 
-        transformation_features = self.extractor.extract(
-            list(new_to_plot.values()),
-            batch_size=int(self.state.model_batch_size),
-        )
+        transformation_features = self.extractor.extract(list(new_to_plot.values()))
         points = self.compute_points(self.features, transformation_features)
         image_id_to_point = zip(new_to_plot.keys(), points)
 
@@ -312,13 +305,6 @@ def settings_widget(self):
                     emit_value=True,
                     map_options=True,
                 )
-                quasar.QInput(
-                    v_model=("model_batch_size", 32),
-                    filled=True,
-                    stack_label=True,
-                    label="Batch Size",
-                    type="number",
-                )
 
             with html.Div(classes="col"):
                 with quasar.QTabs(
```
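
On the caller side the change is just the dropped argument; a before/after illustration (variable names assumed from the diff context):

```python
# Before: batch size threaded in from the settings widget
# features = self.extractor.extract(images, batch_size=int(self.state.model_batch_size))

# After: the extractor picks and adapts the batch size itself
features = self.extractor.extract(images)
```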
src/nrtk_explorer/library/embeddings_extractor.py (24 changes: 15 additions & 9 deletions)
```diff
@@ -8,6 +8,7 @@
 from torch.utils.data import DataLoader, Dataset
 
 IMAGE_MODEL_RESOLUTION = (224, 224)
+STARTING_BATCH_SIZE = 32
 
 
 # Create a dataset for images
@@ -26,6 +27,10 @@ class EmbeddingsExtractor:
     def __init__(self, model_name="resnet50d", force_cpu=False):
         self.device = "cuda" if torch.cuda.is_available() and not force_cpu else "cpu"
         self.model = model_name
+        self.reset()
+
+    def reset(self):
+        self.batch_size = STARTING_BATCH_SIZE
 
     @property
     def device(self):
@@ -57,20 +62,21 @@ def transform_image(self, image: Image):
         img = image.resize(IMAGE_MODEL_RESOLUTION).convert("RGB")
         return self._model_transformer(img).unsqueeze(0)
 
-    def extract(self, images, batch_size=32):
+    def extract(self, images, batch_size=0):
         """Extract features from images"""
         if len(images) == 0:
             return []
 
+        if batch_size != 0:
+            self.batch_size = batch_size
+
         features = list()
         transformed_images = [self.transform_image(img) for img in images]
 
         # Extract features from images
-        adjusted_batch_size = batch_size
-        while adjusted_batch_size > 0:
+        while self.batch_size > 0:
             try:
                 for batch in DataLoader(
-                    ImagesDataset(transformed_images), batch_size=adjusted_batch_size
+                    ImagesDataset(transformed_images), batch_size=self.batch_size
                 ):
                     # Copy image to device if using device
                     if self.device.type == "cuda":
@@ -80,11 +86,11 @@ def extract(self, images, batch_size=32):
                 return np.vstack(features)
 
             except RuntimeError as e:
-                if "out of memory" in str(e) and adjusted_batch_size > 1:
-                    previous_batch_size = adjusted_batch_size
-                    adjusted_batch_size = adjusted_batch_size // 2
+                if "out of memory" in str(e) and self.batch_size > 1:
+                    previous_batch_size = self.batch_size
+                    self.batch_size //= 2
                     print(
-                        f"OOM (Pytorch exception {e}) due to batch_size={previous_batch_size}, setting batch_size={adjusted_batch_size}"
+                        f"Changing extract batch_size from {previous_batch_size} to {self.batch_size} because caught out of memory exception:\n{e}"
                     )
                 else:
                     raise
```
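
A usage sketch of the resulting behavior, assuming the class works as the diff shows (`images` and `more_images` are placeholder lists of PIL images): the reduced batch size is sticky on the instance, an explicit `batch_size` overrides it, and `reset()` restores the starting value.

```python
from nrtk_explorer.library.embeddings_extractor import EmbeddingsExtractor

extractor = EmbeddingsExtractor(model_name="resnet50d")

# May halve extractor.batch_size several times on CUDA OOM; later calls
# then start from the reduced value instead of STARTING_BATCH_SIZE (32).
features = extractor.extract(images)
more_features = extractor.extract(more_images)

# Passing a nonzero batch_size overrides and becomes the new sticky value.
features = extractor.extract(images, batch_size=16)

# reset() puts batch_size back to STARTING_BATCH_SIZE.
extractor.reset()
```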
