diff --git a/src/nrtk_explorer/app/embeddings.py b/src/nrtk_explorer/app/embeddings.py index 58be546..a4093ed 100644 --- a/src/nrtk_explorer/app/embeddings.py +++ b/src/nrtk_explorer/app/embeddings.py @@ -1,4 +1,5 @@ from typing import Dict +import numpy as np from trame.decorators import TrameApp, change from PIL import Image from nrtk_explorer.widgets.nrtk_explorer import ScatterPlot @@ -22,38 +23,46 @@ from trame.app import get_server, asynchronous -IdToImage = Dict[str, Image.Image] +IdToFeatures = Dict[str, Image.Image] @TrameApp() class TransformedImages: def __init__(self, server): self.server = server - self.transformed_images: IdToImage = {} + self.transformed_features: IdToFeatures = {} + self.extractor = None + + def set_extractor(self, extractor): + self.extractor = extractor def emit_update(self): - self.server.controller.update_transformed_images(self.transformed_images) + self.server.controller.update_transformed_images(self.transformed_features) + + def add_images(self, dataset_id_to_image: IdToFeatures): + features = self.extractor.extract(dataset_id_to_image.values()) + + id_to_feature = {id: features for id, features in zip(dataset_id_to_image, features)} - def add_images(self, dataset_id_to_image: IdToImage): - self.transformed_images.update(dataset_id_to_image) + self.transformed_features.update(id_to_feature) self.emit_update() @change("dataset_ids") def on_dataset_ids(self, **kwargs): - self.transformed_images = { + self.transformed_features = { k: v - for k, v in self.transformed_images.items() + for k, v in self.transformed_features.items() if image_id_to_dataset_id(k) in self.server.state.dataset_ids } self.emit_update() @change("current_dataset") def on_dataset(self, **kwargs): - self.transformed_images = {} + self.transformed_features = {} self.emit_update() def clear(self, **kwargs): - self.transformed_images = {} + self.transformed_features = {} self.emit_update() @@ -91,11 +100,11 @@ def __init__( "is_transformed": True, } + self.transformed_images = TransformedImages(server) self.clear_points_transformations() # init vars self.on_feature_extraction_model_change() - self.transformed_images = TransformedImages(server) - self.server.controller.update_transformed_images.add(self.update_transformed_images) + self.server.controller.update_transformed_images.add(self.update_transformed_points) def on_server_ready(self, *args, **kwargs): self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change) @@ -110,6 +119,7 @@ def on_feature_extraction_model_change(self, **kwargs): self.extractor = embeddings_extractor.EmbeddingsExtractor( model_name=self.state.feature_extraction_model ) + self.transformed_images.set_extractor(self.extractor) def compute_points(self, fit_features, features): if len(features) == 0: @@ -180,7 +190,7 @@ async def _update_points(self): with self.state: self.compute_source_points() - self.update_transformed_images(self.transformed_images.transformed_images) + self.update_transformed_points(self.transformed_images.transformed_features) self.state.is_loading = False def update_points(self, **kwargs): @@ -207,24 +217,24 @@ def on_run_clicked(self): def on_run_transformations(self, id_to_image): self.transformed_images.add_images(id_to_image) - def update_transformed_images(self, id_to_image): + def update_transformed_points(self, id_to_features): ids_to_plot = [ id - for id in id_to_image.keys() + for id in id_to_features.keys() if image_id_to_dataset_id(id) not in self._stashed_points_transformations ] - images_to_plot = (id_to_image[id] for id in ids_to_plot) - - transformation_features = self.extractor.extract(images_to_plot) - points = self.compute_points(self.features, transformation_features) - - updated_points = { - image_id_to_dataset_id(id): point for id, point in zip(ids_to_plot, points) - } - self._stashed_points_transformations = { - **self._stashed_points_transformations, - **updated_points, - } + features = [id_to_features[id] for id in ids_to_plot] + if len(features) > 0: + features_to_plot = np.vstack(features) + points = self.compute_points(self.features, features_to_plot) + + updated_points = { + image_id_to_dataset_id(id): point for id, point in zip(ids_to_plot, points) + } + self._stashed_points_transformations = { + **self._stashed_points_transformations, + **updated_points, + } self.update_points_transformations_state() # called by category filter diff --git a/src/nrtk_explorer/app/images/images.py b/src/nrtk_explorer/app/images/images.py index 1ca4824..1f11d66 100644 --- a/src/nrtk_explorer/app/images/images.py +++ b/src/nrtk_explorer/app/images/images.py @@ -9,7 +9,7 @@ IMAGE_CACHE_SIZE_DEFAULT = 50 -AVALIBLE_MEMORY_TO_TAKE_FACTOR = 0.2 +AVALIBLE_MEMORY_TO_TAKE_FACTOR = 0.4 @TrameApp() diff --git a/src/nrtk_explorer/library/multiprocess_predictor.py b/src/nrtk_explorer/library/multiprocess_predictor.py index ff14ecd..b94d54b 100644 --- a/src/nrtk_explorer/library/multiprocess_predictor.py +++ b/src/nrtk_explorer/library/multiprocess_predictor.py @@ -80,13 +80,8 @@ def __init__(self, model_name="facebook/detr-resnet-50", force_cpu=False): self._start_process() - # Instead of a response thread, schedule an async task: asyncio.ensure_future(self._poll_responses()) - - def handle_shutdown(signum, frame): - self.shutdown() - - signal.signal(signal.SIGINT, handle_shutdown) + self.loop.add_signal_handler(signal.SIGINT, self.shutdown) def _start_process(self): with self._lock: