Commit
fix(embeddings): only cache transformed image features
Don't cache the full transformed image in the embedding app.
The app was running out of memory.
PaulHax committed Feb 3, 2025
1 parent 54c0fed commit 430a8e4
Showing 3 changed files with 38 additions and 33 deletions.
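To make the change concrete, here is a minimal, runnable sketch of the caching pattern the commit moves to. ToyExtractor and its mean-RGB "features" are invented stand-ins for illustration, not the project's EmbeddingsExtractor; the point is that only small feature vectors stay resident, while the PIL images can be garbage-collected.

from typing import Dict, Iterable, List

import numpy as np
from PIL import Image


class ToyExtractor:
    """Stand-in for the real feature extractor (illustrative only)."""

    def extract(self, images: Iterable[Image.Image]) -> List[np.ndarray]:
        # One small vector per image, e.g. the mean RGB value.
        return [np.asarray(img).reshape(-1, 3).mean(axis=0) for img in images]


class TransformedFeatureCache:
    """Cache per-image feature vectors instead of the full transformed images."""

    def __init__(self, extractor):
        self.extractor = extractor
        self.id_to_feature: Dict[str, np.ndarray] = {}

    def add_images(self, id_to_image: Dict[str, Image.Image]) -> None:
        # Extract immediately and keep only the vectors; dropping the image
        # references is what saves memory.
        features = self.extractor.extract(id_to_image.values())
        self.id_to_feature.update(zip(id_to_image.keys(), features))


cache = TransformedFeatureCache(ToyExtractor())
cache.add_images({"img_0": Image.new("RGB", (64, 64), (255, 0, 0))})
print(cache.id_to_feature["img_0"].shape)  # (3,)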
62 changes: 36 additions & 26 deletions src/nrtk_explorer/app/embeddings.py
@@ -1,4 +1,5 @@
from typing import Dict
import numpy as np
from trame.decorators import TrameApp, change
from PIL import Image
from nrtk_explorer.widgets.nrtk_explorer import ScatterPlot
@@ -22,38 +23,46 @@
from trame.app import get_server, asynchronous


IdToImage = Dict[str, Image.Image]
IdToFeatures = Dict[str, Image.Image]


@TrameApp()
class TransformedImages:
def __init__(self, server):
self.server = server
self.transformed_images: IdToImage = {}
self.transformed_features: IdToFeatures = {}
self.extractor = None

def set_extractor(self, extractor):
self.extractor = extractor

def emit_update(self):
self.server.controller.update_transformed_images(self.transformed_images)
self.server.controller.update_transformed_images(self.transformed_features)

def add_images(self, dataset_id_to_image: IdToFeatures):
features = self.extractor.extract(dataset_id_to_image.values())

id_to_feature = {id: features for id, features in zip(dataset_id_to_image, features)}

def add_images(self, dataset_id_to_image: IdToImage):
self.transformed_images.update(dataset_id_to_image)
self.transformed_features.update(id_to_feature)
self.emit_update()

@change("dataset_ids")
def on_dataset_ids(self, **kwargs):
self.transformed_images = {
self.transformed_features = {
k: v
for k, v in self.transformed_images.items()
for k, v in self.transformed_features.items()
if image_id_to_dataset_id(k) in self.server.state.dataset_ids
}
self.emit_update()

@change("current_dataset")
def on_dataset(self, **kwargs):
self.transformed_images = {}
self.transformed_features = {}
self.emit_update()

def clear(self, **kwargs):
self.transformed_images = {}
self.transformed_features = {}
self.emit_update()


@@ -91,11 +100,11 @@ def __init__(
"is_transformed": True,
}

self.transformed_images = TransformedImages(server)
self.clear_points_transformations() # init vars
self.on_feature_extraction_model_change()

self.transformed_images = TransformedImages(server)
self.server.controller.update_transformed_images.add(self.update_transformed_images)
self.server.controller.update_transformed_images.add(self.update_transformed_points)

def on_server_ready(self, *args, **kwargs):
self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)
@@ -110,6 +119,7 @@ def on_feature_extraction_model_change(self, **kwargs):
self.extractor = embeddings_extractor.EmbeddingsExtractor(
model_name=self.state.feature_extraction_model
)
self.transformed_images.set_extractor(self.extractor)

def compute_points(self, fit_features, features):
if len(features) == 0:
@@ -180,7 +190,7 @@ async def _update_points(self):

with self.state:
self.compute_source_points()
self.update_transformed_images(self.transformed_images.transformed_images)
self.update_transformed_points(self.transformed_images.transformed_features)
self.state.is_loading = False

def update_points(self, **kwargs):
@@ -207,24 +217,24 @@ def on_run_clicked(self):
def on_run_transformations(self, id_to_image):
self.transformed_images.add_images(id_to_image)

def update_transformed_images(self, id_to_image):
def update_transformed_points(self, id_to_features):
ids_to_plot = [
id
for id in id_to_image.keys()
for id in id_to_features.keys()
if image_id_to_dataset_id(id) not in self._stashed_points_transformations
]
images_to_plot = (id_to_image[id] for id in ids_to_plot)

transformation_features = self.extractor.extract(images_to_plot)
points = self.compute_points(self.features, transformation_features)

updated_points = {
image_id_to_dataset_id(id): point for id, point in zip(ids_to_plot, points)
}
self._stashed_points_transformations = {
**self._stashed_points_transformations,
**updated_points,
}
features = [id_to_features[id] for id in ids_to_plot]
if len(features) > 0:
features_to_plot = np.vstack(features)
points = self.compute_points(self.features, features_to_plot)

updated_points = {
image_id_to_dataset_id(id): point for id, point in zip(ids_to_plot, points)
}
self._stashed_points_transformations = {
**self._stashed_points_transformations,
**updated_points,
}
self.update_points_transformations_state()

# called by category filter
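For reference, the new shape handling in update_transformed_points amounts to stacking the cached per-image vectors into one matrix before projecting them. The ids and the 768-dimensional feature size below are made up for illustration:

import numpy as np

# Hypothetical cached features keyed by transformed-image id.
id_to_features = {
    "transformed_img_0": np.random.rand(768),
    "transformed_img_1": np.random.rand(768),
}

ids_to_plot = list(id_to_features.keys())
features = [id_to_features[id] for id in ids_to_plot]

if len(features) > 0:
    # vstack turns the list of 1-D vectors into the (n_samples, n_features)
    # matrix that a dimensionality-reduction step expects.
    features_to_plot = np.vstack(features)
    print(features_to_plot.shape)  # (2, 768)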
2 changes: 1 addition & 1 deletion src/nrtk_explorer/app/images/images.py
@@ -9,7 +9,7 @@


IMAGE_CACHE_SIZE_DEFAULT = 50
AVALIBLE_MEMORY_TO_TAKE_FACTOR = 0.2
AVALIBLE_MEMORY_TO_TAKE_FACTOR = 0.4


@TrameApp()
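The images.py change bumps AVALIBLE_MEMORY_TO_TAKE_FACTOR from 0.2 to 0.4, presumably letting the image cache claim a larger share of free memory now that transformed images are no longer cached elsewhere. As a rough illustration of how such a factor could translate into a cache capacity (this sizing logic, the use of psutil, and the 5 MB average image size are assumptions, not the project's code):

import psutil

IMAGE_CACHE_SIZE_DEFAULT = 50
AVALIBLE_MEMORY_TO_TAKE_FACTOR = 0.4  # was 0.2 before this commit


def estimated_cache_size(avg_image_bytes: int = 5 * 1024 * 1024) -> int:
    """Rough capacity: a fraction of currently available RAM divided by an assumed image size."""
    budget = psutil.virtual_memory().available * AVALIBLE_MEMORY_TO_TAKE_FACTOR
    return max(IMAGE_CACHE_SIZE_DEFAULT, int(budget // avg_image_bytes))


print(estimated_cache_size())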
7 changes: 1 addition & 6 deletions src/nrtk_explorer/library/multiprocess_predictor.py
@@ -80,13 +80,8 @@ def __init__(self, model_name="facebook/detr-resnet-50", force_cpu=False):

self._start_process()

# Instead of a response thread, schedule an async task:
asyncio.ensure_future(self._poll_responses())

def handle_shutdown(signum, frame):
self.shutdown()

signal.signal(signal.SIGINT, handle_shutdown)
self.loop.add_signal_handler(signal.SIGINT, self.shutdown)

def _start_process(self):
with self._lock:
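The multiprocess_predictor.py hunk swaps the interpreter-global signal.signal() handler for loop.add_signal_handler(), which registers the callback with the asyncio event loop itself (Unix-only). A standalone illustration of that pattern, with a stand-in shutdown callback:

import asyncio
import signal


async def main() -> None:
    loop = asyncio.get_running_loop()
    stop = asyncio.Event()

    def shutdown() -> None:
        # Runs inside the event loop, so it can safely touch loop-owned state.
        print("SIGINT received, shutting down")
        stop.set()

    # Unlike signal.signal(), this ties the handler to the running loop.
    loop.add_signal_handler(signal.SIGINT, shutdown)
    await stop.wait()


asyncio.run(main())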
