diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 7de577db..6f0a4c41 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -46,9 +46,6 @@ jobs:
         run: npm run eslint
 
   unit_tests:
-    needs:
-      - linters_python
-      - linters_vue
     runs-on: ubuntu-latest
     name: ubuntu-latest-tests-python${{ matrix.python-version}}
    strategy:
@@ -81,7 +78,7 @@ jobs:
       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip setuptools wheel "build<0.10.0" python-semantic-release
+          python -m pip install --upgrade pip setuptools wheel "build<1.3.0" python-semantic-release
 
       - name: Python Semantic Release
         id: release
diff --git a/.github/workflows/create_release.yaml b/.github/workflows/create_release.yaml
index 795a3b66..d1c8d85a 100644
--- a/.github/workflows/create_release.yaml
+++ b/.github/workflows/create_release.yaml
@@ -25,7 +25,7 @@ jobs:
       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip setuptools wheel "build<0.10.0" python-semantic-release
+          python -m pip install --upgrade pip setuptools wheel "build<1.3.0" python-semantic-release
 
       - name: Python Semantic Release
         id: release
diff --git a/README.md b/README.md
index 184656b0..e007998a 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,8 @@ insights of a image dataset in [COCO][3] format and it
 evaluate image transformation and perturbation resilience of object
 recognition DL models. It is built using [trame][1] by the [kitware][2]
 team.
 
-![nrtk explorer](https://raw.githubusercontent.com/Kitware/nrtk-explorer/main/screenshot.png)
+![nrtk explorer screenshot](https://github.com/user-attachments/assets/85c95836-3490-40ec-813d-e6841c540d51)
+
 
 Features
 --------
diff --git a/captain-definition b/captain-definition
new file mode 100644
index 00000000..d593db3e
--- /dev/null
+++ b/captain-definition
@@ -0,0 +1,4 @@
+{
+  "schemaVersion": 2,
+  "dockerfilePath": "./docker/Dockerfile"
+}
\ No newline at end of file
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 00000000..12e00118
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,3 @@
+FROM kitware/trame:py3.10-glvnd-2024-09
+COPY --chown=trame-user:trame-user ./docker /deploy
+RUN /opt/trame/entrypoint.sh build
diff --git a/docker/nrtk_explorer-0.3.1-py2.py3-none-any.whl b/docker/nrtk_explorer-0.3.1-py2.py3-none-any.whl
new file mode 100644
index 00000000..256fdb65
Binary files /dev/null and b/docker/nrtk_explorer-0.3.1-py2.py3-none-any.whl differ
diff --git a/docker/setup/apps.yml b/docker/setup/apps.yml
new file mode 100644
index 00000000..261d220b
--- /dev/null
+++ b/docker/setup/apps.yml
@@ -0,0 +1,17 @@
+trame:
+  www_modules:
+    - nrtk_explorer.module
+  cmd:
+    - python
+    - -m
+    - nrtk_explorer.app.main
+    - --host
+    - ${host}
+    - --port
+    - ${port}
+    - --authKey
+    - ${secret}
+    - --server
+    - --banner
+    - --dataset
+    - /data/OIRDS_v1_0/oirds.json
diff --git a/docker/setup/initialize.sh b/docker/setup/initialize.sh
new file mode 100644
index 00000000..a8e16e28
--- /dev/null
+++ b/docker/setup/initialize.sh
@@ -0,0 +1 @@
+pip install /deploy/nrtk_explorer-0.3.1-py2.py3-none-any.whl
diff --git a/docker/setup/requirements.txt b/docker/setup/requirements.txt
new file mode 100644
index 00000000..831c5c2d
--- /dev/null
+++ b/docker/setup/requirements.txt
@@ -0,0 +1,2 @@
+nrtk[headless]
+# nrtk-explorer
diff --git a/pyproject.toml b/pyproject.toml
index 2888d97e..31a16da8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,16 +32,15 @@ dependencies = [
     "numpy",
     "Pillow",
     "pybsm>=0.6",
-    "scikit-learn==1.5.1",
"scikit-learn==1.5.2", "smqtk_image_io", "tabulate", "timm>=1.0.3", "torch", "torchvision", - "trame", - "trame-client>=2.15.0", + "trame>=3.6", "trame-quasar", - "trame-server>=2.15.0", + "trame-server>=3.2", "transformers", "umap-learn", ] @@ -60,7 +59,7 @@ dev = [ ] package = [ - "build<0.10.0", + "build<1.3.0", "python-semantic-release", "setuptools", "wheel", @@ -78,10 +77,10 @@ build-backend = "hatchling.build" [tool.hatch.build.hooks.custom] [project.scripts] -nrtk-explorer = "nrtk_explorer.app:main" -nrtk-explorer-embeddings = "nrtk_explorer.app.embeddings:embeddings" -nrtk-explorer-transforms = "nrtk_explorer.app.transforms:transforms" -nrtk-explorer-filtering = "nrtk_explorer.app.filtering:filtering" +nrtk-explorer = "nrtk_explorer.app.main:main" +nrtk-explorer-embeddings = "nrtk_explorer.app.embeddings:main" +nrtk-explorer-transforms = "nrtk_explorer.app.transforms:main" +nrtk-explorer-filtering = "nrtk_explorer.app.filtering:main" [tool.black] line-length = 99 @@ -126,7 +125,7 @@ build_command = """ python -m venv .venv source .venv/bin/activate pip install -U pip - python -m pip install "build<0.10.0" python-semantic-release setuptools wheel + python -m pip install "build<1.3.0" python-semantic-release setuptools wheel python -m build . """ diff --git a/src/nrtk_explorer/__main__.py b/src/nrtk_explorer/__main__.py new file mode 100644 index 00000000..56864a5e --- /dev/null +++ b/src/nrtk_explorer/__main__.py @@ -0,0 +1,4 @@ +from .app.main import main + +if __name__ == "__main__": + main() diff --git a/src/nrtk_explorer/app/__init__.py b/src/nrtk_explorer/app/__init__.py index 2779e152..e69de29b 100644 --- a/src/nrtk_explorer/app/__init__.py +++ b/src/nrtk_explorer/app/__init__.py @@ -1,5 +0,0 @@ -from .main import main - -__all__ = [ - "main", -] diff --git a/src/nrtk_explorer/app/core.py b/src/nrtk_explorer/app/core.py index 156622df..32769ba3 100644 --- a/src/nrtk_explorer/app/core.py +++ b/src/nrtk_explorer/app/core.py @@ -1,13 +1,13 @@ import logging from typing import Iterable -from pathlib import Path from trame.widgets import html from trame_server.utils.namespace import Translator -from nrtk_explorer.library import images_manager from nrtk_explorer.library.filtering import FilterProtocol from nrtk_explorer.library.dataset import get_dataset, get_image_fpath +from nrtk_explorer.library.debounce import debounce +from nrtk_explorer.app.images.images import Images from nrtk_explorer.app.embeddings import EmbeddingsApp from nrtk_explorer.app.transforms import TransformsApp from nrtk_explorer.app.filtering import FilteringApp @@ -25,13 +25,13 @@ html.Template.slot_names.add("before") html.Template.slot_names.add("after") -HORIZONTAL_SPLIT_DEFAULT_VALUE = 17 -VERTICAL_SPLIT_DEFAULT_VALUE = 40 DIR_NAME = os.path.dirname(nrtk_explorer.test_data.__file__) DEFAULT_DATASETS = [ f"{DIR_NAME}/coco-od-2017/test_val2017.json", ] +NUM_IMAGES_DEFAULT = 500 +NUM_IMAGES_DEBOUNCE_TIME = 0.3 # seconds # --------------------------------------------------------- @@ -52,41 +52,18 @@ def __init__(self, server=None): known_args, _ = self.server.cli.parse_known_args() self.input_paths = known_args.dataset - self.state.current_dataset = str(Path(self.input_paths[0]).resolve()) + self.state.current_dataset = self.input_paths[0] self.ctrl.get_image_fpath = lambda i: get_image_fpath(i, self.state.current_dataset) - - self.context["image_objects"] = {} - self.context["images_manager"] = images_manager.ImagesManager() - - self.state.collapse_dataset = False - self.state.collapse_embeddings = False - 
self.state.collapse_filter = False - self.state.collapse_transforms = False - self.state.client_only( - "collapse_dataset", "collapse_embeddings", "collapse_filter", "collapse_transforms" - ) - - self.state.horizontal_split = HORIZONTAL_SPLIT_DEFAULT_VALUE - self.state.vertical_split = VERTICAL_SPLIT_DEFAULT_VALUE - self.state.client_only("horizontal_split", "vertical_split") - - transforms_translator = Translator() - transforms_translator.add_translation( - "feature_extraction_model", "current_transforms_model" - ) + images = Images(server=self.server) self._transforms_app = TransformsApp( - server=self.server.create_child_server(translator=transforms_translator) - ) - - embeddings_translator = Translator() - embeddings_translator.add_translation( - "feature_extraction_model", "current_embeddings_model" + server=self.server.create_child_server(), images=images ) self._embeddings_app = EmbeddingsApp( - server=self.server.create_child_server(translator=embeddings_translator), + server=self.server.create_child_server(), + images=images, ) filtering_translator = Translator() @@ -95,15 +72,11 @@ def __init__(self, server=None): server=self.server.create_child_server(translator=filtering_translator), ) - self._embeddings_app.set_on_select(self._transforms_app.set_selected_dataset_ids) self._transforms_app.set_on_transform(self._embeddings_app.on_run_transformations) self._embeddings_app.set_on_hover(self._transforms_app.on_image_hovered) self._transforms_app.set_on_hover(self._embeddings_app.on_image_hovered) self._filtering_app.set_on_apply_filter(self.on_filter_apply) - # Set state variable - self.state.trame__title = "nrtk_explorer" - # Bind instance methods to controller self.ctrl.on_server_reload = self._build_ui self.ctrl.add("on_server_ready")(self.on_server_ready) @@ -112,76 +85,66 @@ def __init__(self, server=None): self.state.num_images_disabled = True self.state.random_sampling = False self.state.random_sampling_disabled = True - self.state.images_id = [] + self.state.dataset_ids = [] + self.state.hovered_id = None + + def clear_hovered(**kwargs): + self.state.hovered_id = None + + self.state.change("dataset_ids")(clear_hovered) self._build_ui() def on_server_ready(self, *args, **kwargs): # Bind instance methods to state change self.state.change("current_dataset")(self.on_dataset_change) - self.state.change("num_images")(self.on_num_images_change) - self.state.change("random_sampling")(self.on_random_sampling_change) + self.state.change("num_images")( + debounce(NUM_IMAGES_DEBOUNCE_TIME, self.state)(self.resample_images) + ) + self.state.change("random_sampling")(self.resample_images) self.on_dataset_change() def on_dataset_change(self, **kwargs): - # Reset cache - self.context.images_manager = images_manager.ImagesManager() + self.state.dataset_ids = [] # sampled images self.context.dataset = get_dataset(self.state.current_dataset, force_reload=True) self.state.num_images_max = len(self.context.dataset.imgs) + self.state.num_images = min(self.state.num_images_max, NUM_IMAGES_DEFAULT) + self.state.dirty("num_images") # Trigger resample_images() self.state.random_sampling_disabled = False self.state.num_images_disabled = False - self.reload_images() + self.state.annotation_categories = { + category["id"]: category for category in self.context.dataset.cats.values() + } def on_filter_apply(self, filter: FilterProtocol[Iterable[int]], **kwargs): - selected_indices = [] - for index, image_id in enumerate(self.state.images_ids): + selected_ids = [] + for dataset_id in 
self.state.dataset_ids: image_annotations_categories = [ annotation["category_id"] for annotation in self.context.dataset.anns.values() - if annotation["image_id"] == image_id + if annotation["image_id"] == int(dataset_id) ] include = filter.evaluate(image_annotations_categories) if include: - selected_indices.append(index) + selected_ids.append(dataset_id) - self._embeddings_app.on_select(selected_indices) - - def on_num_images_change(self, **kwargs): - self.reload_images() - - def on_random_sampling_change(self, **kwargs): - self.reload_images() - - def reload_images(self): - categories = {} - for category in self.context.dataset.cats.values(): - categories[category["id"]] = category + self._embeddings_app.on_select(selected_ids) + def resample_images(self, **kwargs): images = list(self.context.dataset.imgs.values()) selected_images = [] if self.state.num_images: if self.state.random_sampling: - selected_images = random.sample(images, self.state.num_images) + selected_images = random.sample(images, min(len(images), self.state.num_images)) else: selected_images = images[: self.state.num_images] else: selected_images = images - paths = list() - for image in selected_images: - paths.append( - os.path.join( - os.path.dirname(self.state.current_dataset), - image["file_name"], - ) - ) - - self.context.paths = paths - self.state.annotation_categories = categories - self.state.images_ids = [img["id"] for img in selected_images] + self.state.dataset_ids = [str(img["id"]) for img in selected_images] def _build_ui(self): extra_args = {} @@ -189,7 +152,7 @@ def _build_ui(self): ui.reload(ui) extra_args["reload"] = self._build_ui - self.ui = ui.build_layout( + self.ui = ui.NrtkExplorerLayout( server=self.server, dataset_paths=self.input_paths, embeddings_app=self._embeddings_app, diff --git a/src/nrtk_explorer/app/embeddings.py b/src/nrtk_explorer/app/embeddings.py index 8a6bdd08..40d59c35 100644 --- a/src/nrtk_explorer/app/embeddings.py +++ b/src/nrtk_explorer/app/embeddings.py @@ -1,42 +1,46 @@ from nrtk_explorer.widgets.nrtk_explorer import ScatterPlot from nrtk_explorer.library import embeddings_extractor from nrtk_explorer.library import dimension_reducers -from nrtk_explorer.library import images_manager from nrtk_explorer.library.dataset import get_dataset from nrtk_explorer.app.applet import Applet -import nrtk_explorer.test_data -import asyncio -import os +from nrtk_explorer.app.images.image_ids import ( + image_id_to_dataset_id, + dataset_id_to_transformed_image_id, + dataset_id_to_image_id, + is_transformed, +) +from nrtk_explorer.app.images.images import Images + +from pathlib import Path from trame.widgets import quasar, html from trame.ui.quasar import QLayout from trame.app import get_server, asynchronous -os.environ["TRAME_DISABLE_V3_WARNING"] = "1" - -DIR_NAME = os.path.dirname(nrtk_explorer.test_data.__file__) -DATASET_DIRS = [ - f"{DIR_NAME}/OIRDS_v1_0/oirds.json", - f"{DIR_NAME}/OIRDS_v1_0/oirds_test.json", - f"{DIR_NAME}/OIRDS_v1_0/oirds_train.json", -] - - class EmbeddingsApp(Applet): - def __init__(self, server): + def __init__( + self, + server, + datasets=None, + images=None, + ): super().__init__(server) + self._dataset_paths = datasets + self.images = images or Images(server) + self._on_hover_fn = None self._ui = None - self._on_select_fn = None self.reducer = dimension_reducers.DimReducerManager() - self.is_standalone_app = self.server.state.parent is None - if self.is_standalone_app: - self.context.images_manager = images_manager.ImagesManager() - if 
-        if self.state.current_dataset is None:
-            self.state.current_dataset = DATASET_DIRS[0]
+        # Local initialization if standalone
+        self.is_standalone_app = self.server.root_server == self.server
+        if self.is_standalone_app and datasets:
+            self.state.dataset_ids = []
+            self.state.current_dataset = datasets[0]
+            self.on_current_dataset_change()
+
         self.features = None
 
         self.state.client_only("camera_position")
@@ -44,18 +48,31 @@ def __init__(self, server):
         self.server.controller.add("on_server_ready")(self.on_server_ready)
 
         self.transformed_images_cache = {}
+        self.state.highlighted_image = {
+            "id": "",
+            "is_transformed": True,
+        }
 
     def on_server_ready(self, *args, **kwargs):
         # Bind instance methods to state change
         self.on_current_dataset_change()
-        self.on_feature_extraction_model_change()
         self.state.change("current_dataset")(self.on_current_dataset_change)
+
+        self.on_feature_extraction_model_change()
         self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)
 
+        self.update_points()
+        self.state.change("dataset_ids")(self.update_points)
+
+        self.server.controller.apply_transform.add(self.clear_points_transformations)
+        self.state.change("transform_enabled_switch")(
+            self.update_points_transformations_visibility
+        )
+
     def on_feature_extraction_model_change(self, **kwargs):
         feature_extraction_model = self.state.feature_extraction_model
         self.extractor = embeddings_extractor.EmbeddingsExtractor(
-            model_name=feature_extraction_model, manager=self.context.images_manager
+            model_name=feature_extraction_model
         )
 
     def on_current_dataset_change(self, **kwargs):
@@ -63,125 +80,102 @@ def on_current_dataset_change(self, **kwargs):
         if self.context.dataset is None:
             self.context.dataset = get_dataset(self.state.current_dataset, force_reload=True)
 
-        self.images = list(self.context.dataset.imgs.values())
-        self.state.num_elements_max = len(self.images)
+        self.state.num_elements_max = len(list(self.context.dataset.imgs))
         self.state.num_elements_disabled = False
 
-        if self.is_standalone_app:
-            self.context.images_manager = images_manager.ImagesManager()
-
-    def on_run_clicked(self):
-        self.state.is_loading = True
-        asynchronous.create_task(self.compute(self.compute_source_points))
-
-    async def compute(self, method):
-        # We need to yield twice for the is_loading=True to commit to the trame state
-        # before this routine ends
-        await asyncio.sleep(0)
-        await asyncio.sleep(0)
-        await method()
-        with self.state:
-            self.state.is_loading = False
-
-    async def compute_source_points(self):
-        self.features = self.extractor.extract(
-            paths=self.context.paths,
-            batch_size=int(self.state.model_batch_size),
-        )
+    def compute_points(self, fit_features, features):
+        if len(features) == 0:
+            # reduce will fail if no features
+            return []
 
         if self.state.tab == "PCA":
-            self.state.points_sources = self.reducer.reduce(
+            return self.reducer.reduce(
                 name="PCA",
-                fit_features=self.features,
-                features=self.features,
+                fit_features=fit_features,
+                features=features,
                 dims=self.state.dimensionality,
                 whiten=self.state.pca_whiten,
                 solver=self.state.pca_solver,
             )
-        elif self.state.tab == "UMAP":
-            args = {}
-            if self.state.umap_random_seed:
-                args["random_state"] = int(self.state.umap_random_seed_value)
+        # must be UMAP
+        args = {}
+        if self.state.umap_random_seed:
+            args["random_state"] = int(self.state.umap_random_seed_value)
 
-            if self.state.umap_n_neighbors:
-                args["n_neighbors"] = int(self.state.umap_n_neighbors_number)
+        if self.state.umap_n_neighbors:
+            args["n_neighbors"] = int(self.state.umap_n_neighbors_number)
 
-            self.state.points_sources = self.reducer.reduce(
-                name="UMAP",
-                dims=self.state.dimensionality,
-                fit_features=self.features,
-                features=self.features,
-                **args,
-            )
+        return self.reducer.reduce(
+            name="UMAP",
+            fit_features=fit_features,
+            features=features,
+            dims=self.state.dimensionality,
+            **args,
+        )
 
-        # Unselect current selection of images
-        if self._on_select_fn:
-            self._on_select_fn([])
+    def clear_points_transformations(self, **kwargs):
+        self.state.points_transformations = {}  # ID to point
+        self._stashed_points_transformations = {}
 
-        self.state.points_transformations = []
-        self.state.user_selected_points_indices = []
-        self.state.camera_position = []
+    def update_points_transformations_visibility(self, **kwargs):
+        if self.state.transform_enabled_switch:
+            self.state.points_transformations = self._stashed_points_transformations
+        else:
+            self._stashed_points_transformations = self.state.points_transformations
+            self.state.points_transformations = {}
 
-    def on_run_transformations(self, transformed_image_ids):
-        # Fillup the cache with the transformed images
-        for img_id in transformed_image_ids:
-            img = self.context.image_objects[img_id]
-            img = self.context.images_manager.prepare_for_model(img)
-            self.transformed_images_cache[img_id] = img
+    async def compute_source_points(self):
+        with self.state:
+            self.state.is_loading = True
 
-        transformation_features = self.extractor.extract(
-            paths=transformed_image_ids,
-            content=self.transformed_images_cache,
+        # Don't lock server before enabling the spinner on client
+        await self.server.network_completion
+
+        images = [
+            self.images.get_image_without_cache_eviction(id) for id in self.state.dataset_ids
+        ]
+        self.features = self.extractor.extract(
+            images,
             batch_size=int(self.state.model_batch_size),
         )
 
-        if self.state.tab == "PCA":
-            self.state.points_transformations = self.reducer.reduce(
-                name="PCA",
-                fit_features=self.features,
-                features=transformation_features,
-                dims=self.state.dimensionality,
-                whiten=self.state.pca_whiten,
-                solver=self.state.pca_solver,
-            )
+        points = self.compute_points(self.features, self.features)
 
-        elif self.state.tab == "UMAP":
-            args = {}
-            if self.state.umap_random_seed:
-                args["random_state"] = int(self.state.umap_random_seed_value)
+        self.state.points_sources = {
+            id: point for id, point in zip(self.state.dataset_ids, points)
+        }
 
-            if self.state.umap_n_neighbors:
-                args["n_neighbors"] = int(self.state.umap_n_neighbors_number)
+        self.clear_points_transformations()
 
-            self.state.points_transformations = self.reducer.reduce(
-                name="UMAP",
-                dims=self.state.dimensionality,
-                fit_features=self.features,
-                features=transformation_features,
-                **args,
-            )
+        self.state.user_selected_ids = []
+        self.state.camera_position = []
+
+        with self.state:
+            self.state.is_loading = False
 
-    def set_on_select(self, fn):
-        self._on_select_fn = fn
-
-    def on_select(self, indices):
-        # remap transformed indices to original indices
-        original_indices = set()
-        for point_index in indices:
-            original_image_point_index = point_index
-            if point_index >= len(self.state.points_sources):
-                original_image_point_index = self.state.user_selected_points_indices[
-                    point_index - len(self.state.points_sources)
-                ]
-            original_indices.add(original_image_point_index)
-        original_indices = list(original_indices)
-
-        self.state.user_selected_points_indices = original_indices
-        self.state.points_transformations = []
-        ids = [self.state.images_ids[i] for i in original_indices]
-        if self._on_select_fn:
-            self._on_select_fn(ids)
+    def update_points(self, **kwargs):
+        if hasattr(self, "_update_task"):
+            self._update_task.cancel()
+        self._update_task = asynchronous.create_task(self.compute_source_points())
+
+    def on_run_clicked(self):
+        self.update_points()
+
+    def on_run_transformations(self, id_to_image):
+        transformation_features = self.extractor.extract(
+            id_to_image.values(),
+            batch_size=int(self.state.model_batch_size),
+        )
+
+        points = self.compute_points(self.features, transformation_features)
+
+        ids = id_to_image.keys()
+        updated_points = {image_id_to_dataset_id(id): point for id, point in zip(ids, points)}
+        self.state.points_transformations = {**self.state.points_transformations, **updated_points}
+
+    def on_select(self, image_ids):
+        self.state.user_selected_ids = image_ids
 
     def on_move(self, camera_position):
         self.state.camera_position = camera_position
@@ -189,53 +183,37 @@ def on_move(self, camera_position):
     def set_on_hover(self, fn):
         self._on_hover_fn = fn
 
-    def on_point_hover(self, point_index):
-        self.state.highlighted_point = point_index
-        image_id = ""
-        if point_index is not None:
-            original_image_point_index = point_index
-            if point_index >= len(self.state.points_sources):
-                image_kind = "transformed_img_"
-                original_image_point_index = self.state.user_selected_points_indices[
-                    point_index - len(self.state.points_sources)
-                ]
-            else:
-                image_kind = "img_"
-            dataset_id = self.state.images_ids[original_image_point_index]
-            image_id = f"{image_kind}{dataset_id}"
-
-        if self._on_hover_fn:
-            self._on_hover_fn(image_id)
-
-    def on_image_hovered(self, id_):
-        # If the point is in the list of selected points, we set it as the highlighted point
-        is_transformation = id_.startswith("transformed_img_")
-        try:
-            dataset_id = int(id_.split("_")[-1])  # img_123 or transformed_img_123 -> 123
-        except ValueError:
-            # id_ probably is an empty string
-            dataset_id = id_
-        if dataset_id in self.state.images_ids:
-            index = self.state.images_ids.index(dataset_id)
-            if is_transformation:
-                index_selected = self.state.user_selected_points_indices.index(index)
-                self.state.highlighted_point = len(self.state.points_sources) + index_selected
-            else:
-                self.state.highlighted_point = index
+    def get_dataset_id_index(self, point_index):
+        if point_index < len(self.state.dataset_ids):
+            return point_index
+        return point_index - len(self.state.dataset_ids)
+
+    def on_point_hover(self, event):
+        self.state.highlighted_image = event
+        if not self._on_hover_fn:
+            return
+        if event["is_transformed"]:
+            image_id = dataset_id_to_transformed_image_id(event["id"])
         else:
-            # If the point is not in the list of selected points, we set it to a negative point
-            self.state.highlighted_point = -1
+            image_id = dataset_id_to_image_id(event["id"])
+        self._on_hover_fn(image_id)
+
+    def on_image_hovered(self, image_id):
+        self.state.highlighted_image = {
+            "id": image_id_to_dataset_id(image_id),
+            "is_transformed": is_transformed(image_id),
+        }
 
     def visualization_widget(self):
         ScatterPlot(
             cameraMove="camera_position=$event",
             cameraPosition=("camera_position",),
-            highlightedPoint=("highlighted_point", -1),
+            highlightedPoint=("highlighted_image",),
             hover=(self.on_point_hover, "[$event]"),
-            points=("points_sources", []),
-            transformedPoints=("points_transformations", []),
+            points=("points_sources", {}),
+            transformedPoints=("points_transformations", {}),
             select=(self.on_select, "[$event]"),
-            selectedPoints=("user_selected_points_indices", []),
+            selectedPoints=("user_selected_ids", []),
         )
 
     def settings_widget(self):
@@ -371,9 +349,10 @@ def ui(self):
                                         label="Dataset",
                                         v_model=("current_dataset",),
                                         options=(
+                                            "dataset_options",
                                             [
-                                                {"label": "oirds_test", "value": DATASET_DIRS[0]},
-                                                {"label": "oirds_train", "value": DATASET_DIRS[1]},
+                                                {"label": Path(p).name, "value": p}
+                                                for p in self._dataset_paths
                                             ],
                                         ),
                                         filled=True,
@@ -394,15 +373,22 @@ def ui(self):
         return self._ui
 
 
-def embeddings(server=None, *args, **kwargs):
-    server = get_server()
-    server.client_type = "vue3"
+def main(server=None, *args, **kwargs):
+    server = get_server(client_type="vue3")
+    server.cli.add_argument(
+        "--dataset",
+        nargs="+",
+        required=True,
+        help="Path of the json file describing the image dataset",
+    )
+
+    known_args, _ = server.cli.parse_known_args()
 
-    embeddings_app = EmbeddingsApp(server)
+    embeddings_app = EmbeddingsApp(server, known_args.dataset)
     embeddings_app.ui
 
     server.start(**kwargs)
 
 
 if __name__ == "__main__":
-    embeddings()
+    main()
diff --git a/src/nrtk_explorer/app/filtering.py b/src/nrtk_explorer/app/filtering.py
index e70a8340..fdad02aa 100644
--- a/src/nrtk_explorer/app/filtering.py
+++ b/src/nrtk_explorer/app/filtering.py
@@ -50,7 +50,7 @@ def on_server_ready(self, *args, **kwargs):
         self.state.change("filter_categories")(self.on_filter_categories_change)
         self.state.change("filter_operator")(self.on_filter_categories_change)
 
-    def on_apply_click(self):
+    def on_select_click(self):
         if self.state.filter_not:
             self._on_apply_filter(self._not_filter)
         else:
@@ -86,8 +86,8 @@ def filter_operator_ui(self):
     def filter_apply_ui(self):
         with html.Div(trame_server=self.server):
             quasar.QBtn(
-                "Apply",
-                click=(self.on_apply_click,),
+                "Select Images",
+                click=(self.on_select_click,),
                 flat=True,
             )
 
@@ -117,9 +117,8 @@ def ui(self):
         return self._ui
 
 
-def filtering(server=None, *args, **kwargs):
-    server = get_server()
-    server.client_type = "vue3"
+def main(server=None, *args, **kwargs):
+    server = get_server(client_type="vue3")
 
     app = FilteringApp(server)
 
@@ -150,4 +149,4 @@ def on_apply_filter(filter: FilterProtocol):
 
 
 if __name__ == "__main__":
-    filtering()
+    main()
diff --git a/src/nrtk_explorer/app/image_ids.py b/src/nrtk_explorer/app/image_ids.py
deleted file mode 100644
index 49af67e9..00000000
--- a/src/nrtk_explorer/app/image_ids.py
+++ /dev/null
@@ -1,14 +0,0 @@
-def image_id_to_dataset_id(image_id: str):
-    return image_id.split("_")[-1]
-
-
-def dataset_id_to_image_id(dataset_id: str):
-    return f"img_{dataset_id}"
-
-
-def dataset_id_to_transformed_image_id(dataset_id: str):
-    return f"transformed_img_{dataset_id}"
-
-
-def image_id_to_result_id(image_id: str):
-    return f"result_{image_id}"
diff --git a/src/nrtk_explorer/app/image_server.py b/src/nrtk_explorer/app/image_server.py
deleted file mode 100644
index 871216d4..00000000
--- a/src/nrtk_explorer/app/image_server.py
+++ /dev/null
@@ -1,54 +0,0 @@
-from aiohttp import web
-from PIL import Image
-import io
-from trame.app import get_server
-
-
-ORIGINAL_IMAGE_ENDPOINT = "original-image"
-
-server = get_server()
-
-
-def is_browser_compatible_image(file_path):
-    # Check if the image format is compatible with web browsers
-    compatible_formats = {"jpg", "jpeg", "png", "gif", "webp"}
-    return file_path.split(".")[-1].lower() in compatible_formats
-
-
-def make_response(image, format):
-    bytes_io = io.BytesIO()
-    image.save(bytes_io, format=format)
-    bytes_io.seek(0)
-    return web.Response(body=bytes_io.read(), content_type=f"image/{format.lower()}")
-
-
-async def original_image_endpoint(request: web.Request):
request.match_info["id"] - image_path = server.controller.get_image_fpath(int(id)) - - if image_path in server.context.images_manager.images: - image = server.context.images_manager.images[image_path] - send_format = "PNG" - if is_browser_compatible_image(image.format): - send_format = image.format.upper() - return make_response(image, send_format) - - if is_browser_compatible_image(image_path): - return web.FileResponse(image_path) - else: - image = Image.open(image_path) - return make_response(image, "PNG") - - -image_routes = [ - web.get(f"/{ORIGINAL_IMAGE_ENDPOINT}/{{id}}", original_image_endpoint), -] - - -def app_available(wslink_server): - """Add our custom REST endpoints to the trame server.""" - wslink_server.app.add_routes(image_routes) - - -# --hot-reload does not work if this is configured as decorator on the function -server.controller.add("on_server_bind")(app_available) diff --git a/src/nrtk_explorer/app/images/annotations.py b/src/nrtk_explorer/app/images/annotations.py new file mode 100644 index 00000000..866a8ae3 --- /dev/null +++ b/src/nrtk_explorer/app/images/annotations.py @@ -0,0 +1,88 @@ +from typing import Dict, Sequence +from functools import lru_cache, partial +from PIL import Image +from nrtk_explorer.app.images.cache import LruCache +from nrtk_explorer.library.object_detector import ObjectDetector +from nrtk_explorer.library.coco_utils import partition + + +ANNOTATION_CACHE_SIZE = 1000 + + +class DeleteCallbackRef: + def __init__(self, del_callback, value): + self.del_callback = del_callback + self.value = value + + def __del__(self): + self.del_callback() + + +def get_annotations_from_dataset( + context, add_to_cache_callback, delete_from_cache_callback, dataset_id: str +): + dataset = context.dataset + annotations = [ + annotation + for annotation in dataset.anns.values() + if str(annotation["image_id"]) == dataset_id + ] + add_to_cache_callback(dataset_id, annotations) + with_id = partial(delete_from_cache_callback, dataset_id) + return DeleteCallbackRef(with_id, annotations) + + +class GroundTruthAnnotations: + def __init__( + self, + context, # for dataset + add_to_cache_callback, + delete_from_cache_callback, + ): + with_callbacks = partial( + get_annotations_from_dataset, + context, + add_to_cache_callback, + delete_from_cache_callback, + ) + self.get_annotations_for_image = lru_cache(maxsize=ANNOTATION_CACHE_SIZE)(with_callbacks) + + def get_annotations(self, dataset_ids: Sequence[str]): + return { + dataset_id: self.get_annotations_for_image(dataset_id).value + for dataset_id in dataset_ids + } + + def cache_clear(self): + self.get_annotations_for_image.cache_clear() + + +class DetectionAnnotations: + def __init__( + self, + add_to_cache_callback, + delete_from_cache_callback, + ): + self.cache = LruCache(ANNOTATION_CACHE_SIZE) + self.add_to_cache_callback = add_to_cache_callback + self.delete_from_cache_callback = delete_from_cache_callback + + def get_annotations(self, detector: ObjectDetector, id_to_image: Dict[str, Image.Image]): + hits, misses = partition(self.cache.get_item, id_to_image.keys()) + cached_predictions = {id: self.cache.get_item(id) for id in hits} + + to_detect = {id: id_to_image[id] for id in misses} + predictions = detector.eval( + to_detect, + ) + for id, annotations in predictions.items(): + self.cache.add_item( + id, annotations, self.add_to_cache_callback, self.delete_from_cache_callback + ) + + predictions.update(**cached_predictions) + # match input order because of scoring code assumptions + return {id: predictions[id] for 
id in id_to_image.keys()} + + def cache_clear(self): + self.cache.clear() diff --git a/src/nrtk_explorer/app/images/cache.py b/src/nrtk_explorer/app/images/cache.py new file mode 100644 index 00000000..d8c97aa2 --- /dev/null +++ b/src/nrtk_explorer/app/images/cache.py @@ -0,0 +1,93 @@ +from typing import Any, Callable, List, NamedTuple +from collections import OrderedDict + +Item = Any + + +class CacheItem(NamedTuple): + item: Item + on_add_item_callbacks: List[Callable[[str, Item], None]] + on_clear_item_callbacks: List[Callable[[str], None]] + + +def noop(*args, **kwargs): + pass + + +class LruCache: + """ + Least recently accessed item is removed when the cache is full. + Per item callbacks are called when an item is added or cleared. + Useful for side effects like updating the trame state. + """ + + def __init__(self, max_size: int): + self.cache: OrderedDict[str, CacheItem] = OrderedDict() + self.max_size = max_size + + def _cache_full(self): + return len(self.cache) >= self.max_size + + def add_item( + self, + key: str, + item: Item, + on_add_item: Callable[[str, Any], None] = noop, + on_clear_item: Callable[[str], None] = noop, + ): + """ + Add an item to the cache. + Runs on_add_item callback if callback does not exist in current item callbacks list or item is new + """ + cache_item = self.cache.get(key) + if cache_item and cache_item.item != item: + # stale cached item, clear it + self._clear_item(key) + cache_item = None + + if self._cache_full(): + oldest = next(iter(self.cache)) + self._clear_item(oldest) + + if cache_item: + # Update callbacks list only if they are not already present + if on_add_item not in cache_item.on_add_item_callbacks: + cache_item.on_add_item_callbacks.append(on_add_item) + on_add_item(key, item) + if on_clear_item not in cache_item.on_clear_item_callbacks: + cache_item.on_clear_item_callbacks.append(on_clear_item) + else: + # Create a new CacheItem and add it to the cache + cache_item = CacheItem( + item=item, + on_add_item_callbacks=[on_add_item], + on_clear_item_callbacks=[on_clear_item], + ) + self.cache[key] = cache_item + on_add_item(key, item) + + self.cache.move_to_end(key) + + def add_if_room(self, key: str, item: Item, **kwargs): + """Does not remove items from cache, only adds.""" + if not self._cache_full(): + self.add_item(key, item, **kwargs) + + def get_item(self, key: str): + """Retrieve an item from the cache.""" + if key in self.cache: + self.cache.move_to_end(key) + return self.cache[key].item + return None + + def _clear_item(self, key: str): + """Remove a specific item from the cache.""" + if key in self.cache: + for callback in self.cache[key].on_clear_item_callbacks: + callback(key) + del self.cache[key] + + def clear(self): + """Clear the cache.""" + for key in list(self.cache.keys()): + self._clear_item(key) diff --git a/src/nrtk_explorer/app/images/image_ids.py b/src/nrtk_explorer/app/images/image_ids.py new file mode 100644 index 00000000..2c8844dc --- /dev/null +++ b/src/nrtk_explorer/app/images/image_ids.py @@ -0,0 +1,34 @@ +from nrtk_explorer.app.images.image_meta import dataset_id_to_meta + + +def image_id_to_dataset_id(image_id: str): + return image_id.split("_")[-1] + + +def dataset_id_to_image_id(dataset_id: str): + return f"img_{dataset_id}" + + +def dataset_id_to_transformed_image_id(dataset_id: str): + return f"transformed_img_{dataset_id}" + + +def image_id_to_result_id(image_id: str): + return f"result_{image_id}" + + +def is_transformed(image_id: str): + return image_id.startswith("transformed_img_") + + +def 
get_image_state_keys(dataset_id: str): + return { + "original_image": dataset_id_to_image_id(dataset_id), + "ground_truth": image_id_to_result_id(dataset_id), + "original_image_detection": image_id_to_result_id(dataset_id_to_image_id(dataset_id)), + "transformed_image": dataset_id_to_transformed_image_id(dataset_id), + "transformed_image_detection": image_id_to_result_id( + dataset_id_to_transformed_image_id(dataset_id) + ), + "meta_id": dataset_id_to_meta(dataset_id), + } diff --git a/src/nrtk_explorer/app/image_meta.py b/src/nrtk_explorer/app/images/image_meta.py similarity index 87% rename from src/nrtk_explorer/app/image_meta.py rename to src/nrtk_explorer/app/images/image_meta.py index 96bb3100..bd2a4673 100644 --- a/src/nrtk_explorer/app/image_meta.py +++ b/src/nrtk_explorer/app/images/image_meta.py @@ -4,7 +4,7 @@ ImageMetaId = str -def image_id_to_meta(image_id: str) -> ImageMetaId: +def dataset_id_to_meta(image_id: str) -> ImageMetaId: return f"meta_{image_id}" @@ -26,7 +26,7 @@ class DatasetImageMeta(TypedDict): def update_image_meta(state, dataset_id: str, meta_patch: PartialDatasetImageMeta): - meta_key = image_id_to_meta(dataset_id) + meta_key = dataset_id_to_meta(dataset_id) current_meta = {} if state.has(meta_key) and state[meta_key] is not None: current_meta = state[meta_key] @@ -34,5 +34,5 @@ def update_image_meta(state, dataset_id: str, meta_patch: PartialDatasetImageMet def delete_image_meta(state, dataset_id: str): - meta_key = image_id_to_meta(dataset_id) + meta_key = dataset_id_to_meta(dataset_id) delete_state(state, meta_key) diff --git a/src/nrtk_explorer/app/images/images.py b/src/nrtk_explorer/app/images/images.py new file mode 100644 index 00000000..0f9102a1 --- /dev/null +++ b/src/nrtk_explorer/app/images/images.py @@ -0,0 +1,98 @@ +import base64 +import io +from PIL import Image +from trame.decorators import TrameApp, change, controller +from nrtk_explorer.app.images.image_ids import ( + dataset_id_to_image_id, + dataset_id_to_transformed_image_id, +) +from nrtk_explorer.app.trame_utils import delete_state +from nrtk_explorer.app.images.cache import LruCache +from nrtk_explorer.library.transforms import ImageTransform + + +def convert_to_base64(img: Image.Image) -> str: + """Convert image to base64 string""" + buf = io.BytesIO() + img.save(buf, format="png") + return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode() + + +IMAGE_CACHE_SIZE = 200 + + +@TrameApp() +class Images: + def __init__(self, server): + self.server = server + self.original_images = LruCache( + IMAGE_CACHE_SIZE, + ) + self.transformed_images = LruCache( + IMAGE_CACHE_SIZE, + ) + + def _load_image(self, dataset_id: str): + image_path = self.server.controller.get_image_fpath(int(dataset_id)) + return Image.open(image_path) + + def get_image(self, dataset_id: str, **kwargs): + """For cache side effects pass on_add_item and on_clear_item callbacks as kwargs""" + image_id = dataset_id_to_image_id(dataset_id) + image = self.original_images.get_item(image_id) or self._load_image(dataset_id) + self.original_images.add_item(image_id, image, **kwargs) + return image + + def get_stateful_image(self, dataset_id: str): + return self.get_image( + dataset_id, on_add_item=self._add_image_to_state, on_clear_item=self._delete_from_state + ) + + def _add_image_to_state(self, image_id: str, image: Image.Image): + self.server.state[image_id] = convert_to_base64(image) + + def _delete_from_state(self, state_key: str): + delete_state(self.server.state, state_key) + + def 
get_image_without_cache_eviction(self, dataset_id: str): + """ + Does not remove items from cache, only adds. + For computing metrics on all images. + """ + image_id = dataset_id_to_image_id(dataset_id) + image = self.original_images.get_item(image_id) or self._load_image(dataset_id) + self.original_images.add_if_room(image_id, image) + return image + + def _load_transformed_image(self, transform: ImageTransform, dataset_id: str): + original = self.get_image_without_cache_eviction(dataset_id) + transformed = transform.execute(original) + # So pixel-wise annotation similarity score works + if original.size != transformed.size: + return transformed.resize(original.size) + return transformed + + def get_transformed_image(self, transform: ImageTransform, dataset_id: str, **kwargs): + image_id = dataset_id_to_transformed_image_id(dataset_id) + image = self.transformed_images.get_item(image_id) or self._load_transformed_image( + transform, dataset_id + ) + self.transformed_images.add_item(image_id, image, **kwargs) + return image + + def get_stateful_transformed_image(self, transform: ImageTransform, dataset_id: str): + return self.get_transformed_image( + transform, + dataset_id, + on_add_item=self._add_image_to_state, + on_clear_item=self._delete_from_state, + ) + + @change("current_dataset") + def clear_all(self, **kwargs): + self.original_images.clear() + self.clear_transformed() + + @controller.add("apply_transform") + def clear_transformed(self, **kwargs): + self.transformed_images.clear() diff --git a/src/nrtk_explorer/app/images/stateful_annotations.py b/src/nrtk_explorer/app/images/stateful_annotations.py new file mode 100644 index 00000000..55ea3080 --- /dev/null +++ b/src/nrtk_explorer/app/images/stateful_annotations.py @@ -0,0 +1,87 @@ +from functools import partial +from typing import Any, Union, Callable +from .annotations import GroundTruthAnnotations, DetectionAnnotations +from trame.decorators import TrameApp, change +from nrtk_explorer.app.images.image_ids import ( + image_id_to_result_id, +) +from nrtk_explorer.app.trame_utils import delete_state + + +def add_annotation_to_state(state: Any, image_id: str, annotations: Any): + state[image_id_to_result_id(image_id)] = annotations + + +def delete_annotation_from_state(state: Any, image_id: str): + delete_state(state, image_id_to_result_id(image_id)) + + +def prediction_to_annotations(state, predictions): + annotations = [] + for prediction in predictions: + # if no matching category in dataset JSON, category_id will be None + category_id = None + for cat_id, cat in state.annotation_categories.items(): + if cat["name"] == prediction["label"]: + category_id = cat_id + + bbox = prediction["box"] + annotations.append( + { + "category_id": category_id, + "label": prediction["label"], + "bbox": [ + bbox["xmin"], + bbox["ymin"], + bbox["xmax"] - bbox["xmin"], + bbox["ymax"] - bbox["ymin"], + ], + } + ) + return annotations + + +def add_prediction_to_state(state: Any, image_id: str, prediction: Any): + state[image_id_to_result_id(image_id)] = prediction_to_annotations(state, prediction) + + +AnnotationsFactoryConstructorType = Union[ + Callable[[Callable, Callable], GroundTruthAnnotations], + Callable[[Callable, Callable], DetectionAnnotations], +] + + +@TrameApp() +class StatefulAnnotations: + def __init__( + self, + annotations_factory_constructor: AnnotationsFactoryConstructorType, + server, + add_to_cache_callback=None, + ): + self.server = server + state = self.server.state + add_to_cache_callback = add_to_cache_callback or 
partial(add_annotation_to_state, state) + delete_from_cache_callback = partial(delete_annotation_from_state, state) + self.annotations_factory = annotations_factory_constructor( + add_to_cache_callback, delete_from_cache_callback + ) + + @change("current_dataset", "object_detection_model") + def _cache_clear(self, **kwargs): + self.annotations_factory.cache_clear() + + +def make_stateful_annotations(server): + return StatefulAnnotations( + partial(GroundTruthAnnotations, server.context), + server, + ) + + +def make_stateful_predictor(server): + return StatefulAnnotations( + DetectionAnnotations, + server, + add_to_cache_callback=partial(add_prediction_to_state, server.state), + ) diff --git a/src/nrtk_explorer/app/parameters.py b/src/nrtk_explorer/app/parameters.py index afda379b..52a178bd 100644 --- a/src/nrtk_explorer/app/parameters.py +++ b/src/nrtk_explorer/app/parameters.py @@ -32,8 +32,6 @@ def __init__(self, server): self.state.transforms = [k for k in self._transforms.keys()] self.state.current_transform = self.state.transforms[0] - self.on_apply_transform = lambda: None - self.server.controller.add("on_server_ready")(self.on_server_ready) self._ui = None @@ -76,7 +74,7 @@ def transform_apply_ui(self): with html.Div(trame_server=self.server): quasar.QBtn( "Apply", - click=(self.on_apply_transform,), + click=(self.server.controller.apply_transform), classes="full-width", flat=True, ) diff --git a/src/nrtk_explorer/app/trame_utils.py b/src/nrtk_explorer/app/trame_utils.py index a57f8e68..b5769323 100644 --- a/src/nrtk_explorer/app/trame_utils.py +++ b/src/nrtk_explorer/app/trame_utils.py @@ -1,4 +1,3 @@ -import asyncio from typing import Hashable, Callable from trame_server.state import State @@ -8,44 +7,27 @@ def delete_state(state: State, key: Hashable): state[key] = None -class SetStateAsync: +def change_checker(state: State, key: str, trigger_check=lambda a, b: a != b): """ Usage:: - async with SetStateAsync(state): - state["key"] = value + @change_checker(self.state, "visible_columns", transformed_became_visible) + def on_apply_transform(old_value, new_value): """ - def __init__(self, state: State): - self.state = state + def decorator(callback: Callable): + old_value = state[key] - async def __aenter__(self): - return self.state + def on_change(): + nonlocal old_value + new_value = state[key] + if trigger_check(old_value, new_value): + callback(old_value, new_value) + old_value = new_value - async def __aexit__(self, exc_type, exc, tb): - self.state.flush() - await asyncio.sleep(0) - await asyncio.sleep(0) - await asyncio.sleep(0) - await asyncio.sleep(0) + def on_state(**kwargs): + on_change() + state.change(key)(on_state) + return callback -def change_checker(state: State, key: str, callback: Callable, trigger_check=lambda a, b: a != b): - """ - Usage:: - change_checker( - self.state, "visible_columns", self.on_apply_transform, tranformed_became_visible - ) - """ - old_value = state[key] - - def on_change(): - nonlocal old_value - new_value = state[key] - if trigger_check(old_value, new_value): - callback() - old_value = new_value - - def on_state(**kwargs): - on_change() - - state.change(key)(on_state) + return decorator diff --git a/src/nrtk_explorer/app/transforms.py b/src/nrtk_explorer/app/transforms.py index 7e6e3bae..a8dcfcbf 100644 --- a/src/nrtk_explorer/app/transforms.py +++ b/src/nrtk_explorer/app/transforms.py @@ -3,25 +3,21 @@ """ import logging -from typing import Dict, Sequence -from functools import partial -import os +from typing import Dict, Callable from 
trame.ui.quasar import QLayout from trame.widgets import quasar from trame.widgets import html from trame.app import get_server, asynchronous +from trame_server import Server import nrtk_explorer.library.transforms as trans import nrtk_explorer.library.nrtk_transforms as nrtk_trans -from nrtk_explorer.library import images_manager, object_detector -from nrtk_explorer.app.ui import ImageList, init_state as init_state_image_list +from nrtk_explorer.library import object_detector +from nrtk_explorer.app.ui import ImageList from nrtk_explorer.app.applet import Applet from nrtk_explorer.app.parameters import ParametersApp -from nrtk_explorer.app.image_meta import ( - update_image_meta, - delete_image_meta, -) +from nrtk_explorer.app.images.image_meta import update_image_meta, dataset_id_to_meta from nrtk_explorer.library.coco_utils import ( convert_from_ground_truth_to_first_arg, convert_from_ground_truth_to_second_arg, @@ -29,64 +25,128 @@ convert_from_predictions_to_first_arg, compute_score, ) -import nrtk_explorer.test_data -from nrtk_explorer.app.trame_utils import delete_state, SetStateAsync, change_checker -from nrtk_explorer.app.image_ids import ( - image_id_to_dataset_id, - image_id_to_result_id, +from nrtk_explorer.app.trame_utils import change_checker, delete_state + +from nrtk_explorer.app.images.image_ids import ( dataset_id_to_image_id, dataset_id_to_transformed_image_id, ) from nrtk_explorer.library.dataset import get_dataset -import nrtk_explorer.app.image_server +from nrtk_explorer.app.images.images import Images +from nrtk_explorer.app.images.stateful_annotations import ( + make_stateful_annotations, + make_stateful_predictor, +) +from nrtk_explorer.app.ui.image_list import TRANSFORM_COLUMNS, ORIGINAL_COLUMNS logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -DIR_NAME = os.path.dirname(nrtk_explorer.test_data.__file__) -DATASET_DIRS = [ - f"{DIR_NAME}/OIRDS_v1_0/oirds.json", - f"{DIR_NAME}/OIRDS_v1_0/oirds_test.json", - f"{DIR_NAME}/OIRDS_v1_0/oirds_train.json", -] +class ProcessingStep: + def __init__( + self, + server: Server, + feature_enabled_state_key: str, + gui_switch_key: str, + column_name: str, + enabled_callback: Callable, + ): + self.state = server.state + self.feature_enabled_state_key = feature_enabled_state_key + self.gui_switch_key = gui_switch_key + self.enabled_callback = enabled_callback + self.column_name = column_name + self.state.change(self.gui_switch_key)(self.on_gui_switch) + self.update_feature_enabled_state() + self.state.change("visible_columns", self.gui_switch_key)( + self.update_feature_enabled_state + ) + self.state.change(self.feature_enabled_state_key)(self.on_change_feature_enabled) + + def on_gui_switch(self, **kwargs): + if self.state[self.gui_switch_key]: + self.state.visible_columns = list(set([*self.state.visible_columns, self.column_name])) + else: + self.state.visible_columns = [ + col for col in self.state.visible_columns if col != self.column_name + ] + + def update_feature_enabled_state(self, **kwargs): + self.state[self.feature_enabled_state_key] = ( + self.column_name in self.state.visible_columns and self.state[self.gui_switch_key] + ) + + def on_change_feature_enabled(self, **kwargs): + if self.state[self.feature_enabled_state_key]: + self.enabled_callback() class TransformsApp(Applet): - def __init__(self, server): + def __init__( + self, + server, + images=None, + ground_truth_annotations=None, + original_detection_annotations=None, + transformed_detection_annotations=None, + ): super().__init__(server) - 
self.update_image_meta = partial(update_image_meta, self.server.state) + self.images = images or Images(server) - self._parameters_app = ParametersApp( - server=server, + ground_truth_annotations = ground_truth_annotations or make_stateful_annotations(server) + self.ground_truth_annotations = ground_truth_annotations.annotations_factory + + original_detection_annotations = original_detection_annotations or make_stateful_predictor( + server ) + self.original_detection_annotations = original_detection_annotations.annotations_factory - self._parameters_app.on_apply_transform = self.on_apply_transform + transformed_detection_annotations = ( + transformed_detection_annotations or make_stateful_predictor(server) + ) + self.transformed_detection_annotations = ( + transformed_detection_annotations.annotations_factory + ) - self._ui = None + def clear_transformed(**kwargs): + self.transformed_detection_annotations.cache_clear() + for id in self.state.dataset_ids: + update_image_meta( + self.state, + id, + { + "original_detection_to_transformed_detection_score": 0, + "ground_truth_to_transformed_detection_score": 0, + }, + ) - self.is_standalone_app = self.server.state.parent is None - if self.is_standalone_app: - self.context.images_manager = images_manager.ImagesManager() + server.controller.apply_transform.add(clear_transformed) + + # delete score from state of old ids that are not in new + def delete_meta_state(old_ids, new_ids): + if old_ids is not None: + to_clean = set(old_ids) - set(new_ids) + for id in to_clean: + delete_state(self.state, dataset_id_to_meta(id)) + + change_checker(self.state, "dataset_ids")(delete_meta_state) + + self._parameters_app = ParametersApp( + server=server, + ) - if self.context["image_objects"] is None: - self.context["image_objects"] = {} + self._ui = None self._on_transform_fn = None - self.state.models = [ - "ClassificationResNet50", - "ClassificationAlexNet", - "ClassificationVgg16", - ] - self.state.feature_extraction_model = self.state.models[0] self._transforms: Dict[str, trans.ImageTransform] = { - "identity": trans.IdentityTransform(), "blur": trans.GaussianBlurTransform(), "invert": trans.InvertTransform(), "downsample": trans.DownSampleTransform(), + "identity": trans.IdentityTransform(), } if nrtk_trans.nrtk_transforms_available(): @@ -95,47 +155,51 @@ def __init__(self, server): self._parameters_app._transforms = self._transforms - self.state.annotation_categories = {} - - self.context.selected_dataset_ids = [] - self.state.source_image_ids = [] - self.state.transformed_image_ids = [] - self.state.transforms = [k for k in self._transforms.keys()] self.state.current_transform = self.state.transforms[0] - if self.state.current_dataset is None: - self.state.current_dataset = DATASET_DIRS[0] - - self.state.current_num_elements = 15 - - init_state_image_list(self.state) - - def tranformed_became_visible(old, new): - return "transformed" not in old and "transformed" in new + # On annotations enabled, run whole pipeline to possibly compute transforms. Why? 
Transforms compute scores are based on original images + self.annotations_enable_control = ProcessingStep( + server, + feature_enabled_state_key="predictions_original_images_enabled", + gui_switch_key="annotations_enabled_switch", + column_name=ORIGINAL_COLUMNS[0], + enabled_callback=self._start_update_images, + ) - change_checker( - self.state, "visible_columns", self.on_apply_transform, tranformed_became_visible + self.transform_enable_control = ProcessingStep( + server, + feature_enabled_state_key="transform_enabled", + gui_switch_key="transform_enabled_switch", + column_name=TRANSFORM_COLUMNS[0], + enabled_callback=self._start_transformed_images, ) - self.server.controller.add("on_server_ready")(self.on_server_ready) + self.server.controller.on_server_ready.add(self.on_server_ready) + self.server.controller.apply_transform.add(self.on_apply_transform) self._on_hover_fn = None + self.visible_dataset_ids = [] # set by ImageList via self.on_scroll callback + @property def get_image_fpath(self): return self.server.controller.get_image_fpath def on_server_ready(self, *args, **kwargs): - # Bind instance methods to state change - self.state.change("current_dataset")(self.on_current_dataset_change) - self.state.change("current_num_elements")(self.on_current_num_elements_change) self.state.change("object_detection_model")(self.on_object_detection_model_change) + self.on_object_detection_model_change() + self.state.change("current_dataset")(self.reset_detector) + + def on_object_detection_model_change(self, **kwargs): + self.original_detection_annotations.cache_clear() + self.transformed_detection_annotations.cache_clear() + self.detector = object_detector.ObjectDetector( + model_name=self.state.object_detection_model + ) + self._start_update_images() - self.on_object_detection_model_change(self.state.object_detection_model) - self.on_current_dataset_change(self.state.current_dataset) - - def on_object_detection_model_change(self, model_name, **kwargs): - self.detector = object_detector.ObjectDetector(model_name=model_name) + def reset_detector(self, **kwargs): + self.detector.reset() def set_on_transform(self, fn): self._on_transform_fn = fn @@ -144,146 +208,118 @@ def on_transform(self, *args, **kwargs): if self._on_transform_fn: self._on_transform_fn(*args, **kwargs) - def on_apply_transform(self, *args, **kwargs): - logger.debug("on_apply_transform") + def on_apply_transform(self, **kwargs): + # Turn on switch if user clicked lower apply button + self.state.transform_enabled_switch = True + self._start_transformed_images() + + def _start_transformed_images(self, *args, **kwargs): + logger.debug("_start_transformed_images") if self._updating_images(): - return # update_images will call update_transformed_images() at the end - self.update_transformed_images() + if self._updating_transformed_images: + # computing stale transformed images, restart task + self._update_task.cancel() + else: + return # update_images will call update_transformed_images() at the end + self._update_task = asynchronous.create_task( + self.update_transformed_images(self.visible_dataset_ids) + ) - def update_transformed_images(self): - if not ("transformed" in self.state.visible_columns): - return + async def update_transformed_images(self, dataset_ids): + self._updating_transformed_images = True + try: + await self._update_transformed_images(dataset_ids) + finally: + self._updating_transformed_images = False - transform = self._transforms[self.state.current_transform] - transformed_image_ids = [] - for image_id in 
self.state.source_image_ids: - image = self.context["image_objects"][image_id] - transformed_image_id = f"transformed_{image_id}" - transformed_img = transform.execute(image) - if image.size != transformed_img.size: - # Resize so pixel-wise annotation similarity score works - transformed_img = transformed_img.resize(image.size) - self.context["image_objects"][transformed_image_id] = transformed_img - transformed_image_ids.append(transformed_image_id) - self.state[transformed_image_id] = images_manager.convert_to_base64(transformed_img) - - self.state.transformed_image_ids = transformed_image_ids - if len(self.state.source_image_ids) > 0: - self.state.hovered_id = "" - - if len(transformed_image_ids) == 0: + async def _update_transformed_images(self, dataset_ids): + if not self.state.transform_enabled: return - result_ids = [image_id_to_result_id(id) for id in transformed_image_ids] - for id in result_ids: - delete_state(self.state, id) + transform = self._transforms[self.state.current_transform] - annotations = self.compute_annotations(transformed_image_ids) + id_to_matching_size_img = {} + for id in dataset_ids: + with self.state: + transformed = self.images.get_stateful_transformed_image(transform, id) + id_to_matching_size_img[dataset_id_to_transformed_image_id(id)] = transformed + await self.server.network_completion - dataset_ids = [image_id_to_dataset_id(id) for id in transformed_image_ids] - predictions = convert_from_predictions_to_second_arg(annotations) - scores = compute_score( - dataset_ids, - self.predictions_source_images, - predictions, - ) - for dataset_id, score in scores: - update_image_meta( - self.state, - dataset_id, - {"original_detection_to_transformed_detection_score": score}, + with self.state: + annotations = self.transformed_detection_annotations.get_annotations( + self.detector, id_to_matching_size_img ) - - ground_truth_annotations = [self.state[image_id_to_result_id(id)] for id in dataset_ids] - ground_truth_predictions = convert_from_ground_truth_to_first_arg(ground_truth_annotations) - scores = compute_score( - dataset_ids, - ground_truth_predictions, - predictions, - ) - for dataset_id, score in scores: - update_image_meta( - self.state, dataset_id, {"ground_truth_to_transformed_detection_score": score} + await self.server.network_completion + + # depends on original images predictions + if self.state.predictions_original_images_enabled: + predictions = convert_from_predictions_to_second_arg(annotations) + scores = compute_score( + dataset_ids, + self.predictions_original_images, + predictions, ) - - # Only invoke callbacks when we transform images - self.on_transform(transformed_image_ids) - - def compute_annotations(self, ids): - """Compute annotations for the given image ids using the object detector model.""" - if len(ids) == 0: - return - - predictions = self.detector.eval( - image_ids=ids, - content=self.context.image_objects, - batch_size=int(self.state.object_detection_batch_size), - ) - - for id_, annotations in predictions.items(): - image_annotations = [] - for prediction in annotations: - category_id = None - # if no matching category in dataset JSON, category_id will be None - for cat_id, cat in self.state.annotation_categories.items(): - if cat["name"] == prediction["label"]: - category_id = cat_id - - bbox = prediction["box"] - image_annotations.append( - { - "category_id": category_id, - "label": prediction["label"], - "bbox": [ - bbox["xmin"], - bbox["ymin"], - bbox["xmax"] - bbox["xmin"], - bbox["ymax"] - bbox["ymin"], - ], - } + for id, 
score in scores:
+                update_image_meta(
+                    self.state,
+                    id,
+                    {"original_detection_to_transformed_detection_score": score},
                )
-        self.state[image_id_to_result_id(id_)] = image_annotations
-
-        return predictions

-    def on_current_num_elements_change(self, current_num_elements, **kwargs):
-        ids = [img["id"] for img in self.context.dataset.imgs.values()]
-        return self.set_source_images(ids[:current_num_elements])
+            ground_truth_annotations = self.ground_truth_annotations.get_annotations(
+                dataset_ids
+            ).values()
+            ground_truth_predictions = convert_from_ground_truth_to_first_arg(
+                ground_truth_annotations
+            )
+            scores = compute_score(
+                dataset_ids,
+                ground_truth_predictions,
+                predictions,
+            )
+            for id, score in scores:
+                update_image_meta(
+                    self.state, id, {"ground_truth_to_transformed_detection_score": score}
+                )

-    def load_ground_truth_annotations(self, dataset_ids):
-        # collect annotations for each dataset_id
-        annotations = {
-            image_id_to_result_id(dataset_id): [
-                annotation
-                for annotation in self.context.dataset.anns.values()
-                if str(annotation["image_id"]) == dataset_id
-            ]
-            for dataset_id in dataset_ids
+        id_to_image = {
+            dataset_id_to_transformed_image_id(id): self.images.get_transformed_image(
+                transform, id
+            )
+            for id in dataset_ids
         }
-        self.state.update(annotations)

-    def compute_predictions_source_images(self, ids):
-        """Compute the predictions for the source images."""
+        self.on_transform(id_to_image)
+
+        self.state.flush()  # we modify state inside an async function, so flush or the UI does not update

-        if len(ids) == 0:
+    def compute_predictions_original_images(self, dataset_ids):
+        if not self.state.predictions_original_images_enabled:
             return

-        annotations = self.compute_annotations(ids)
+        image_id_to_image = {
+            dataset_id_to_image_id(id): self.images.get_image_without_cache_eviction(id)
+            for id in dataset_ids
+        }
+        annotations = self.original_detection_annotations.get_annotations(
+            self.detector, image_id_to_image
+        )
         dataset = get_dataset(self.state.current_dataset)
-        self.predictions_source_images = convert_from_predictions_to_first_arg(
+        self.predictions_original_images = convert_from_predictions_to_first_arg(
             annotations,
             dataset,
-            ids,
+            dataset_ids,
         )

-        dataset_ids = [image_id_to_dataset_id(id) for id in ids]
-        ground_truth_annotations = [self.state[image_id_to_result_id(id)] for id in dataset_ids]
+        ground_truth_annotations = self.ground_truth_annotations.get_annotations(
+            dataset_ids
+        ).values()
         ground_truth_predictions = convert_from_ground_truth_to_second_arg(
             ground_truth_annotations, self.context.dataset
         )
         scores = compute_score(
             dataset_ids,
-            self.predictions_source_images,
+            self.predictions_original_images,
             ground_truth_predictions,
         )
         for dataset_id, score in scores:
@@ -291,90 +327,36 @@ def compute_predictions_source_images(self, ids):
             self.state, dataset_id, {"original_ground_to_original_detection_score": score}
         )

-    async def _update_images(self):
-        selected_ids = self.context.selected_dataset_ids
-        loading = len(selected_ids) > 0
-        async with SetStateAsync(self.state):
-            self.state.loading_images = loading
-            self.state.hovered_id = ""
-
-        for selected_id in selected_ids:
-            filename = self.get_image_fpath(int(selected_id))
-            img = self.context.images_manager.load_image(filename)
-            image_id = dataset_id_to_image_id(selected_id)
-            self.context.image_objects[image_id] = img
-
-        async with SetStateAsync(self.state):
-            # create reactive annotation variables so ImageDetection component has live Ref
-            for id in selected_ids:
-                
self.state[image_id_to_result_id(id)] = None - self.state[image_id_to_result_id(dataset_id_to_image_id(id))] = None - self.state[image_id_to_result_id(dataset_id_to_transformed_image_id(id))] = None - self.state.source_image_ids = [dataset_id_to_image_id(id) for id in selected_ids] - self.state.loading_images = False # remove big spinner and show table - - async with SetStateAsync(self.state): - self.load_ground_truth_annotations(selected_ids) - - async with SetStateAsync(self.state): - self.compute_predictions_source_images(self.state.source_image_ids) - - async with SetStateAsync(self.state): - self.update_transformed_images() + async def _update_images(self, dataset_ids): + # load images on state for ImageList + for id in dataset_ids: + with self.state: + self.images.get_stateful_image(id) + await self.server.network_completion - def _start_update_images(self): - if hasattr(self, "_update_images_task"): - self._update_images_task.cancel() - self._update_images_task = asynchronous.create_task(self._update_images()) - - def _updating_images(self): - return hasattr(self, "_update_images_task") and not self._update_images_task.done() - - def set_selected_dataset_ids(self, selected_dataset_ids: Sequence[int]): - self.delete_computed_image_data() - self.context.selected_dataset_ids = [str(id) for id in selected_dataset_ids] - self._start_update_images() + with self.state: + self.ground_truth_annotations.get_annotations(dataset_ids) + await self.server.network_completion - def delete_computed_image_data(self): - source_and_transformed = self.state.source_image_ids + self.state.transformed_image_ids - for image_id in source_and_transformed: - delete_state(self.state, image_id) - if image_id in self.context["image_objects"]: - del self.context["image_objects"][image_id] + with self.state: + self.compute_predictions_original_images(dataset_ids) + await self.server.network_completion - for dataset_id in self.context.selected_dataset_ids: - delete_image_meta(self.server.state, dataset_id) + with self.state: + await self.update_transformed_images(dataset_ids) + await self.server.network_completion - ids_with_annotations = ( - self.context.selected_dataset_ids - + self.state.source_image_ids - + self.state.transformed_image_ids - ) - for id in ids_with_annotations: - delete_state(self.state, image_id_to_result_id(id)) - - self.state.source_image_ids = [] - self.state.transformed_image_ids = [] - - def reset_data(self): - self.delete_computed_image_data() - self.state.annotation_categories = {} - - def on_current_dataset_change(self, current_dataset, **kwargs): - logger.debug(f"on_current_dataset_change change {self.state}") - self.reset_data() - - categories = {} - if self.context.dataset is None: - self.context.dataset = get_dataset(current_dataset, force_reload=True) - - for category in self.context.dataset.cats.values(): - categories[category["id"]] = category + def _start_update_images(self): + if hasattr(self, "_update_task"): + self._update_task.cancel() + self._update_task = asynchronous.create_task(self._update_images(self.visible_dataset_ids)) - self.state.annotation_categories = categories + def _updating_images(self): + return hasattr(self, "_update_task") and not self._update_task.done() - if self.is_standalone_app: - self.context.images_manager = images_manager.ImagesManager() + def on_scroll(self, visible_ids): + self.visible_dataset_ids = visible_ids + self._start_update_images() def on_image_hovered(self, id): self.state.hovered_id = id @@ -389,22 +371,17 @@ def on_hover(self, 
hover_event): self._on_hover_fn(id_) def settings_widget(self): - with html.Div(trame_server=self.server): - with html.Div(classes="col"): - self._parameters_app.transform_select_ui() - - with html.Div( - classes="q-pa-md q-ma-md", - style="border-style: solid; border-width: thin; border-radius: 0.5rem; border-color: lightgray;", - ): - self._parameters_app.transform_params_ui() + with html.Div(classes="col"): + self._parameters_app.transform_select_ui() + with html.Div(classes="q-pa-md q-ma-md"): + self._parameters_app.transform_params_ui() def apply_ui(self): - with html.Div(trame_server=self.server): + with html.Div(): self._parameters_app.transform_apply_ui() def dataset_widget(self): - ImageList(self.on_hover) + ImageList(self.on_scroll, self.on_hover) # This is only used within when this module (file) is executed as an Standalone app. @property @@ -437,7 +414,7 @@ def ui(self): quasar.QSelect( label="Dataset", v_model=("current_dataset",), - options=(DATASET_DIRS,), + options=([],), filled=True, emit_value=True, map_options=True, @@ -461,9 +438,8 @@ def ui(self): return self._ui -def transforms(server=None, *args, **kwargs): - server = get_server() - server.client_type = "vue3" +def main(server=None, *args, **kwargs): + server = get_server(client_type="vue3") transforms_app = TransformsApp(server) transforms_app.ui @@ -472,4 +448,4 @@ def transforms(server=None, *args, **kwargs): if __name__ == "__main__": - transforms() + main() diff --git a/src/nrtk_explorer/app/ui/__init__.py b/src/nrtk_explorer/app/ui/__init__.py index 1b5e0431..7aa1a7c6 100644 --- a/src/nrtk_explorer/app/ui/__init__.py +++ b/src/nrtk_explorer/app/ui/__init__.py @@ -1,6 +1,6 @@ -from .layout import build_layout -from .image_list import ImageList, init_state -from .collapsible_card import card +from .layout import NrtkExplorerLayout +from .image_list import ImageList +from .collapsible_card import CollapsibleCard def reload(m=None): @@ -14,8 +14,7 @@ def reload(m=None): __all__ = [ - "build_layout", + "NrtkExplorerLayout", "ImageList", - "init_state", - "card", + "CollapsibleCard", ] diff --git a/src/nrtk_explorer/app/ui/collapsible_card.py b/src/nrtk_explorer/app/ui/collapsible_card.py index 6203a8cc..6a2bc53f 100644 --- a/src/nrtk_explorer/app/ui/collapsible_card.py +++ b/src/nrtk_explorer/app/ui/collapsible_card.py @@ -2,23 +2,30 @@ from trame.widgets import html -def card(collapse_key): - with quasar.QCard(): - with quasar.QCardSection(): - with html.Div(classes="row items-center no-wrap"): - title_slot = html.Div(classes="col") +class CollapsibleCard(quasar.QCard): + id_count = 0 - with html.Div(classes="col-auto"): - quasar.QBtn( - round=True, - flat=True, - dense=True, - click=f"{collapse_key} = !{collapse_key}", - icon=(f"{collapse_key} ? 
'keyboard_arrow_down' : 'keyboard_arrow_up'",),
-                )
-        with quasar.QSlideTransition():
-            with html.Div(v_show=f"!{collapse_key}"):
-                content_slot = quasar.QCardSection()
-        actions_slot = quasar.QCardActions(align="right")
+    def __init__(self, name=None, collapsed=False, **kwargs):
+        super().__init__(**kwargs)

-    return title_slot, content_slot, actions_slot
+
+        if name is None:
+            CollapsibleCard.id_count += 1
+            name = f"is_card_open_{CollapsibleCard.id_count}"
+        self.state.client_only(name)  # keep it local if not provided
+
+        with self:
+            with quasar.QCardSection():
+                with html.Div(classes="row items-center no-wrap"):
+                    self.slot_title = html.Div(classes="col")
+                    with html.Div(classes="col-auto"):
+                        quasar.QBtn(
+                            round=True,
+                            flat=True,
+                            dense=True,
+                            click=f"{name} = !{name}",
+                            icon=(f"{name} ? 'keyboard_arrow_up' : 'keyboard_arrow_down'",),
+                        )
+            with quasar.QSlideTransition():
+                with html.Div(v_show=(name, not collapsed)):
+                    self.slot_content = quasar.QCardSection()
+            self.slot_actions = quasar.QCardActions(align="right")
diff --git a/src/nrtk_explorer/app/ui/image_list.css b/src/nrtk_explorer/app/ui/image_list.css
new file mode 100644
index 00000000..4d660a76
--- /dev/null
+++ b/src/nrtk_explorer/app/ui/image_list.css
@@ -0,0 +1,11 @@
+.sticky-header {
+  thead tr th {
+    position: sticky;
+    z-index: 1;
+    background-color: white;
+  }
+
+  thead tr:first-child th {
+    top: 0;
+  }
+}
diff --git a/src/nrtk_explorer/app/ui/image_list.py b/src/nrtk_explorer/app/ui/image_list.py
index 8355c780..5393a81b 100644
--- a/src/nrtk_explorer/app/ui/image_list.py
+++ b/src/nrtk_explorer/app/ui/image_list.py
@@ -1,6 +1,11 @@
-from trame.widgets import html, quasar
-
+from pathlib import Path
+from trame.widgets import html, quasar, client
+from trame.app import get_server
+from nrtk_explorer.app.trame_utils import change_checker
 from nrtk_explorer.widgets.nrtk_explorer import ImageDetection
+from nrtk_explorer.app.images.image_ids import get_image_state_keys
+
+CSS_FILE = Path(__file__).with_name("image_list.css")

 COLUMNS = [
     {"name": "id", "label": "Dataset ID", "field": "id", "sortable": True},
@@ -32,67 +37,211 @@
 ]

-def init_state(state):
-    state.client_only("columns")
-    state.columns = COLUMNS
-    state.visible_columns = [col["name"] for col in COLUMNS]
+server = get_server()
+state, context, ctrl = server.state, server.context, server.controller

+state.client_only("columns")
+state.columns = COLUMNS
+state.visible_columns = [col["name"] for col in COLUMNS]

-class ImageList(html.Div):
-    def __init__(self, hover_fn=None):
-        super().__init__(classes="col full-height")
+
+def make_dependent_columns_handler(state, columns):
+    toggle_column = columns[0]
+    dependent_columns = columns[1:]
+
+    def column_toggler(old_columns, new_columns):
+        dependent_columns_visible = any(col in state.visible_columns for col in dependent_columns)
+        if toggle_column not in state.visible_columns and dependent_columns_visible:
+            state.visible_columns = [
+                col for col in state.visible_columns if col not in dependent_columns
+            ]
+            return
+
+        toggle_column_turned_on = toggle_column in new_columns and toggle_column not in old_columns
+        if toggle_column_turned_on:
+            state.visible_columns = list(set([*state.visible_columns, *dependent_columns]))
+
+    change_checker(state, "visible_columns")(column_toggler)
+
+
+ORIGINAL_COLUMNS = [
+    "original",
+    "original_ground_to_original_detection_score",
+]
+
+
+make_dependent_columns_handler(state, ORIGINAL_COLUMNS)
+
+
+TRANSFORM_COLUMNS = [
+    "transformed",
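+    # the score columns below are shown/hidden together with "transformed" (see make_dependent_columns_handler)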
"ground_truth_to_transformed_detection_score", + "original_detection_to_transformed_detection_score", +] + +make_dependent_columns_handler(state, TRANSFORM_COLUMNS) + + +state.client_only("image_size_image_list") + + +def set_image_list_ids(dataset_ids): + # create reactive variables so ImageDetection components have live Refs + for id in dataset_ids: + keys = get_image_state_keys(id) + for key in keys.values(): + if not state.has(key): + state[key] = None + state.image_list_ids = dataset_ids + + +@state.change("dataset_ids", "user_selected_ids") +def update_image_list_ids(**kwargs): + if len(state.user_selected_ids) > 0: + set_image_list_ids(state.user_selected_ids) + else: + set_image_list_ids(state.dataset_ids) + + +state.pagination = {} + + +@state.change("image_list_ids") +def reset_virtual_scroll(**kwargs): + ImageList.reset_view_range() + if state.image_list_view_mode == "grid": + ctrl.get_visible_ids() + + +@state.change("image_list_view_mode") +def update_pagination(**kwargs): + if state.image_list_view_mode == "grid": + state.pagination = {**state.pagination, "rowsPerPage": 12} + ctrl.get_visible_ids() + else: + state.pagination = {**state.pagination, "rowsPerPage": 0} # show all rows + + +class ImageWithSpinner(html.Div): + def __init__( + self, + identifier=None, + src=None, + annotations=None, + categories=None, + selected=None, + hover=None, + containerSelector=None, + **kwargs, + ): + super().__init__( + classes="relative-position", + **kwargs, + ) with self: - ImageTable( - v_if="source_image_ids.length > 0", hover_fn=hover_fn, classes="full-height" - ) - html.Div( - "No images selected", - v_if="source_image_ids.length === 0 && !loading_images", - classes="text-h5 row flex-center q-my-md", + ImageDetection( + identifier=identifier, + src=src, + annotations=(f"show_annotations_on_images ? 
{annotations[0]} : []",),
+                categories=categories,
+                selected=selected,
+                hover=hover,
+                containerSelector=containerSelector,
+            )
+            quasar.QInnerLoading(
+                showing=(f"!{src[0]} || (show_annotations_on_images && !{annotations[0]}.value)",)
+            )


-class ImageTable(html.Div):
-    def __init__(self, hover_fn=None, **kwargs):
-        super().__init__(**kwargs)
+class ImageList(html.Div):
+    instances = []
+
+    @staticmethod
+    def reset_view_range():
+        for instance in ImageList.instances:
+            instance.visible_ids = set()
+        server.js_call(ref="image-list", method="resetVirtualScroll")
+
+    def set_in_view_ids(self, ids):
+        visible = set(ids)
+        if self.visible_ids != visible:
+            self.visible_ids = visible
+            self.scroll_callback(self.visible_ids)
+
+    def __init__(self, on_scroll, on_hover, **kwargs):
+        super().__init__(classes="full-height", **kwargs)
+        ImageList.instances.append(self)
+        self.visible_ids = set()
+        self.scroll_callback = on_scroll
         with self:
+            client.Style(CSS_FILE.read_text())
+            get_visible_ids = client.JSEval(
+                exec=f'''
+            ;const list = trame.refs['image-list']
+            if (!list) return
+            // wait a tick so the pagination prop is applied to computedRows
+            window.setTimeout(() => {{
+                const ids = list.computedRows.map(i => i.id)
+                trigger('{ ctrl.trigger_name(self.set_in_view_ids) }', [ids])
+            }}, 0)
+            ''',
+            )
+            ctrl.get_visible_ids = get_visible_ids.exec
             with quasar.QTable(
+                ref=("image-list"),
                 classes="full-height sticky-header",
                 flat=True,
-                hide_bottom=True,
-                title="Selected Images",
+                hide_bottom=("image_list_view_mode !== 'grid'", True),
+                title="Sampled Images",
                 grid=("image_list_view_mode === 'grid'", False),
                 filter=("image_list_search", ""),
                 id="image-list",  # set id so that the ImageDetection component can select the container for tooltip positioning
                 visible_columns=("visible_columns",),
                 columns=("columns",),
                 rows=(
-                    r"""source_image_ids.map((id) =>
+                    r"""image_list_ids.map((id) =>
                        {
-                        const datasetId = id.split('_').at(-1)
-                        const meta = get(`meta_${datasetId}`)?.value ?? {original_ground_to_original_detection_score: 0, ground_truth_to_transformed_detection_score: 0, original_detection_to_transformed_detection_score: 0}
+                        const meta = get(`meta_${id}`)?.value ?? 
{original_ground_to_original_detection_score: 0, ground_truth_to_transformed_detection_score: 0, original_detection_to_transformed_detection_score: 0} + const original_id = `img_${id}` + const transformed_id = `transformed_img_${id}` return { ...meta, original_ground_to_original_detection_score: meta.original_ground_to_original_detection_score.toFixed(2), ground_truth_to_transformed_detection_score: meta.ground_truth_to_transformed_detection_score.toFixed(2), original_detection_to_transformed_detection_score: meta.original_detection_to_transformed_detection_score.toFixed(2), - id: datasetId, - original: id, - original_src: `original-image/${datasetId}`, - transformed: `transformed_${id}`, - groundTruthAnnotations: get(`result_${datasetId}`), - originalAnnotations: get(`result_${id}`), - transformedAnnotations: get(`result_transformed_${id}`), + id, + original: original_id, + original_src: get(original_id).value, + transformed: transformed_id, + transformed_src: get(transformed_id).value, + groundTruthAnnotations: get(`result_${id}`), + originalAnnotations: get(`result_img_${id}`), + transformedAnnotations: get(`result_transformed_img_${id}`), } }) """, ), row_key="id", - rows_per_page_options=("[0]",), # [0] means show all rows + rows_per_page_options=( + "image_list_view_mode === 'table' ? [0] : [6, 12, 24]", + "[0]", + ), + raw_attrs=[ + "virtual-scroll", + "virtual-scroll-slice-size='2'", + "virtual-scroll-item-size='200'", + # e.ref._.props.items is sorted+filtered rows like the QTable.computedRows computed prop + f'''@virtual-scroll="(e) => {{ + const ids = e.ref._.props.items.map(i => i.id).slice(e.from, e.to + 1) + trigger('{ self.server.controller.trigger_name(self.set_in_view_ids) }', [ids]) + }}"''', + "virtual-scroll-sticky-size-start='48'", + r"v-model:pagination='pagination'", + f'''@update:pagination="() => {{ + if(get('image_list_view_mode').value !== 'grid') return; + trigger('{ self.server.controller.trigger_name(ctrl.get_visible_ids) }') + }}"''', + ], ): # ImageDetection component for image columns with html.Template( @@ -100,14 +249,14 @@ def __init__(self, hover_fn=None, **kwargs): __properties=[("v_slot_body_cell_truth", "v-slot:body-cell-truth='props'")], ): with quasar.QTd(): - ImageDetection( - style="max-width: 10rem; float: inline-end;", + ImageWithSpinner( + style=("`width: ${image_size_image_list}rem; float: inline-end;`",), identifier=("props.row.original",), src=("props.row.original_src",), annotations=("props.row.groundTruthAnnotations",), categories=("annotation_categories",), selected=("(props.row.original == hovered_id)",), - hover=(hover_fn, "[$event]"), + hover=(on_hover, "[$event]"), containerSelector="#image-list .q-table__middle", ) with html.Template( @@ -117,14 +266,14 @@ def __init__(self, hover_fn=None, **kwargs): ], ): with quasar.QTd(): - ImageDetection( - style="max-width: 10rem; float: inline-end;", + ImageWithSpinner( + style=("`width: ${image_size_image_list}rem; float: inline-end;`",), identifier=("props.row.original",), src=("props.row.original_src",), annotations=("props.row.originalAnnotations",), categories=("annotation_categories",), selected=("(props.row.original == hovered_id)",), - hover=(hover_fn, "[$event]"), + hover=(on_hover, "[$event]"), containerSelector="#image-list .q-table__middle", ) with html.Template( @@ -137,14 +286,14 @@ def __init__(self, hover_fn=None, **kwargs): ], ): with quasar.QTd(): - ImageDetection( - style="max-width: 10rem; float: inline-end;", + ImageWithSpinner( + style=("`width: 
${image_size_image_list}rem; float: inline-end;`",), identifier=("props.row.transformed",), - src=("get(props.row.transformed)",), + src=("props.row.transformed_src",), annotations=("props.row.transformedAnnotations",), categories=("annotation_categories",), selected=("(props.row.transformed == hovered_id)",), - hover=(hover_fn, "[$event]"), + hover=(on_hover, "[$event]"), containerSelector="#image-list .q-table__middle", ) # Grid Mode template for each row/grid-item @@ -163,13 +312,13 @@ def __init__(self, hover_fn=None, **kwargs): "Original: Ground Truth Annotations", classes="text-center", ) - ImageDetection( + ImageWithSpinner( identifier=("props.row.original",), src=("props.row.original_src",), annotations=("props.row.groundTruthAnnotations",), categories=("annotation_categories",), selected=("(props.row.original == hovered_id)",), - hover=(hover_fn, "[$event]"), + hover=(on_hover, "[$event]"), ) with html.Div( classes="col-4 q-pa-sm", @@ -182,13 +331,13 @@ def __init__(self, hover_fn=None, **kwargs): "Original: Detection Annotations", classes="text-center", ) - ImageDetection( + ImageWithSpinner( identifier=("props.row.original",), src=("props.row.original_src",), annotations=("props.row.originalAnnotations",), categories=("annotation_categories",), selected=("(props.row.original == hovered_id)",), - hover=(hover_fn, "[$event]"), + hover=(on_hover, "[$event]"), ) with html.Div( classes="col-4 q-pa-sm", @@ -201,13 +350,13 @@ def __init__(self, hover_fn=None, **kwargs): "Transformed: Detection Annotations", classes="text-center", ) - ImageDetection( + ImageWithSpinner( identifier=("props.row.transformed",), - src=("get(props.row.transformed)",), + src=("props.row.transformed_src",), annotations=("props.row.transformedAnnotations",), categories=("annotation_categories",), selected=("(props.row.transformed == hovered_id)",), - hover=(hover_fn, "[$event]"), + hover=(on_hover, "[$event]"), ) with quasar.QList( dense=True, @@ -231,8 +380,29 @@ def __init__(self, hover_fn=None, **kwargs): v_slot_top=True, __properties=[("v_slot_top", "v-slot:top='props'")], ): - html.Span("Selected Images", classes="col q-table__title") + html.Span("Sampled Images", classes="col q-table__title") + # Image size + quasar.QIcon(name="zoom_in", size="1.2rem", classes="q-px-sm") + html.Span("Image Size") + quasar.QSlider( + classes="q-pl-sm q-pr-lg", + v_model=("image_size_image_list", 12), + raw_attrs=[ + ":min='5'", + ":max='40'", + ], + style="width: 12rem;", + ) + # Annotations visible switch + quasar.QIcon( + name="picture_in_picture", size="1.2rem", classes="q-pl-lg q-pr-sm" + ) + html.Span("Show Annotations") + quasar.QToggle( + v_model=("show_annotations_on_images", True), + ) quasar.QSelect( + classes="q-pl-xl q-pr-lg", v_model=("visible_columns"), multiple=True, dense=True, @@ -247,11 +417,11 @@ def __init__(self, hover_fn=None, **kwargs): ], ) quasar.QBtn( + classes="q-pl-lg q-pr-xl", icon="fullscreen", dense=True, flat=True, click="props.toggleFullscreen", - classes="q-mx-md", ) quasar.QBtnToggle( v_model=("image_list_view_mode", "table"), @@ -260,8 +430,9 @@ def __init__(self, hover_fn=None, **kwargs): ], ) quasar.QInput( + classes="q-pl-xl", v_model=("image_list_search", ""), + debounce="300", label="Search", dense=True, - classes="col-3 q-pl-md", ) diff --git a/src/nrtk_explorer/app/ui/layout.py b/src/nrtk_explorer/app/ui/layout.py index fd6bb3f1..cea6d57e 100644 --- a/src/nrtk_explorer/app/ui/layout.py +++ b/src/nrtk_explorer/app/ui/layout.py @@ -4,187 +4,185 @@ from trame.widgets import html from 
nrtk_explorer.app import ui - -def toolbar(reload=None): - with quasar.QHeader(): - with quasar.QToolbar(classes="shadow-4"): - quasar.QToolbarTitle("NRTK_EXPLORER") - if reload: - quasar.QBtn( - "Reload", - click=(reload,), - flat=True, - ) +HORIZONTAL_SPLIT_DEFAULT_VALUE = 17 +VERTICAL_SPLIT_DEFAULT_VALUE = 40 def parse_dataset_dirs(datasets): return [{"label": Path(ds).name, "value": ds} for ds in datasets] -def parameters(dataset_paths=[], embeddings_app=None, filtering_app=None, transforms_app=None): - with html.Div(classes="q-pa-md q-gutter-md"): - ( - dataset_title_slot, - dataset_content_slot, - _, - ) = ui.card("collapse_dataset") - - with dataset_title_slot: - html.Span("Dataset Selection", classes="text-h6") - - with dataset_content_slot: - quasar.QSelect( - label="Dataset", - v_model=("current_dataset",), - options=(parse_dataset_dirs(dataset_paths),), - filled=True, - emit_value=True, - map_options=True, - dense=True, - ) - quasar.QSlider( - v_model=("num_images", 15), - min=(0,), - max=("num_images_max", 25), - disable=("num_images_disabled", True), - step=(1,), - ) - html.P( - "{{num_images}}/{{num_images_max}} images", - classes="text-caption text-center", - ) - - quasar.QToggle( - v_model=("random_sampling", False), - dense=False, - label="Random selection", - ) - - ( - embeddings_title_slot, - embeddings_content_slot, - embeddings_actions_slot, - ) = ui.card("collapse_embeddings") - - with embeddings_title_slot: - html.Span("Embeddings", classes="text-h6") - - with embeddings_content_slot: - embeddings_app.settings_widget() - - with embeddings_actions_slot: - embeddings_app.compute_ui() - - (annotations_title_slot, annotations_content_slot, _) = ui.card("collapse_annotations") - - with annotations_title_slot: - html.Span("Annotations settings", classes="text-h6") - - with annotations_content_slot: - quasar.QSelect( - label="Object detection Model", - v_model=("object_detection_model", "facebook/detr-resnet-50"), - options=( - [ - { - "label": "facebook/detr-resnet-50", - "value": "facebook/detr-resnet-50", - }, - ], - ), - filled=True, - emit_value=True, - map_options=True, - ) - quasar.QInput( - v_model=("object_detection_batch_size", 32), - filled=True, - stack_label=True, - label="Batch Size", - type="number", - ) - - filter_title_slot, filter_content_slot, filter_actions_slot = ui.card("collapse_filter") - - with filter_title_slot: - html.Span("Category Filter", classes="text-h6") - - with filter_content_slot: - filtering_app.filter_operator_ui() - filtering_app.filter_options_ui() - - with filter_actions_slot: - filtering_app.filter_apply_ui() - - ( - transforms_title_slot, - transforms_content_slot, - transforms_actions_slot, - ) = ui.card("collapse_transforms") - - with transforms_title_slot: - html.Span("Transform Settings", classes="text-h6") - - with transforms_content_slot: - transforms_app.settings_widget() - - with transforms_actions_slot: - transforms_app.apply_ui() - - -def dataset_view( - embeddings_app=None, - transforms_app=None, -): - with quasar.QSplitter( - v_model=("vertical_split",), - limits=("[0,100]",), - horizontal=True, - classes="inherit-height zero-height", +class NrtkDrawer(html.Div): + def __init__( + self, dataset_paths=[], embeddings_app=None, filtering_app=None, transforms_app=None ): - with html.Template(v_slot_before=True): - embeddings_app.visualization_widget() - - with html.Template(v_slot_after=True): - transforms_app.dataset_widget() - - -def explorer( - dataset_paths=[], - embeddings_app=None, - filtering_app=None, - 
transforms_app=None, -): - with quasar.QSplitter( - model_value=("horizontal_split",), - classes="inherit-height", - before_class="inherit-height zero-height scroll", - after_class="inherit-height zero-height", + super().__init__(classes="q-pa-md q-gutter-md") + + with self: + # DataSet card + with ui.CollapsibleCard() as card: + with card.slot_title: + html.Span("Dataset", classes="text-h6") + with card.slot_content: + quasar.QSelect( + label="Dataset", + v_model=("current_dataset",), + options=(parse_dataset_dirs(dataset_paths),), + filled=True, + emit_value=True, + map_options=True, + dense=True, + ) + quasar.QSlider( + v_model=("num_images", 15), + min=(0,), + max=("num_images_max", 25), + disable=("num_images_disabled", True), + step=(1,), + ) + html.P( + "{{num_images}}/{{num_images_max}} images", + classes="text-caption text-center", + ) + quasar.QToggle( + v_model=("random_sampling", False), + dense=False, + label="Random sampling", + ) + + # Embeddings + with ui.CollapsibleCard() as card: + with card.slot_title: + html.Span("Embeddings", classes="text-h6") + with card.slot_content: + embeddings_app.settings_widget() + with card.slot_actions: + embeddings_app.compute_ui() + + # Annotations + with ui.CollapsibleCard() as card: + with card.slot_title: + quasar.QToggle(v_model=("annotations_enabled_switch", False)) + html.Span("Model Inference", classes="text-h6") + with card.slot_content: + quasar.QSelect( + label="Object detection Model", + v_model=("object_detection_model", "facebook/detr-resnet-50"), + options=( + [ + { + "label": "facebook/detr-resnet-50", + "value": "facebook/detr-resnet-50", + }, + { + "label": "facebook/detr-resnet-50-dc5", + "value": "facebook/detr-resnet-50-dc5", + }, + { + "label": "hustvl/yolos-tiny", + "value": "hustvl/yolos-tiny", + }, + { + "label": "valentinafeve/yolos-fashionpedia", + "value": "valentinafeve/yolos-fashionpedia", + }, + ], + ), + filled=True, + emit_value=True, + map_options=True, + ) + + # Transforms + with ui.CollapsibleCard() as card: + with card.slot_title: + quasar.QToggle(v_model=("transform_enabled_switch", False)) + html.Span("Transform", classes="text-h6") + with card.slot_content: + transforms_app.settings_widget() + with card.slot_actions: + transforms_app.apply_ui() + + # Filters + with ui.CollapsibleCard() as card: + with card.slot_title: + html.Span("Category Filter", classes="text-h6") + with card.slot_content: + filtering_app.filter_operator_ui() + filtering_app.filter_options_ui() + with card.slot_actions: + filtering_app.filter_apply_ui() + + +class Splitter(quasar.QSplitter): + def __init__(self, **kwargs): + super().__init__(**kwargs) + with self: + self.slot_before = html.Template(raw_attrs=["v-slot:before"]) + self.slot_after = html.Template(raw_attrs=["v-slot:after"]) + + +class NrtkToolbar(quasar.QHeader): + def __init__(self, reload=None): + super().__init__() + with self: + with quasar.QToolbar(classes="shadow-4"): + quasar.QToolbarTitle("NRTK Explorer") + if reload: + quasar.QBtn( + "Reload", + click=(reload,), + flat=True, + ) + quasar.QSpinnerBox( + v_show="trame__busy", + size="2rem", + ) + + +class NrtkExplorerLayout(QLayout): + def __init__( + self, + server, + reload=None, + dataset_paths=None, + embeddings_app=None, + filtering_app=None, + transforms_app=None, + **kwargs, ): - with html.Template(v_slot_before=True): - parameters( - dataset_paths=dataset_paths, - embeddings_app=embeddings_app, - filtering_app=filtering_app, - transforms_app=transforms_app, - ) - - with html.Template(v_slot_after=True): 
- dataset_view(embeddings_app=embeddings_app, transforms_app=transforms_app) - - -def build_layout( - server=None, - reload=None, - **kwargs, -): - with QLayout( - server, view="lhh LpR lff", classes="shadow-2 rounded-borders bg-grey-2" - ) as layout: - toolbar(reload=reload) - - with quasar.QPageContainer(): - with quasar.QPage(): - explorer(**kwargs) - - return layout + super().__init__(server, view="lhh LpR lff", classes="shadow-2 rounded-borders bg-grey-2") + + # Make local variables on state + self.state.client_only("horizontal_split", "vertical_split") + self.state.trame__title = "NRTK Explorer" + + with self: + NrtkToolbar(reload=reload) + with quasar.QPageContainer(): + with quasar.QPage(): + with Splitter( + model_value=("horizontal_split", HORIZONTAL_SPLIT_DEFAULT_VALUE), + classes="inherit-height", + before_class="inherit-height zero-height scroll", + after_class="inherit-height zero-height", + ) as split_drawer_main: + with split_drawer_main.slot_before: + NrtkDrawer( + dataset_paths=dataset_paths, + embeddings_app=embeddings_app, + filtering_app=filtering_app, + transforms_app=transforms_app, + ) + with split_drawer_main.slot_after: + with Splitter( + v_model=("vertical_split", VERTICAL_SPLIT_DEFAULT_VALUE), + limits=("[0,100]",), + horizontal=True, + classes="inherit-height zero-height", + ) as split_scatter_table: + with split_scatter_table.slot_before: + embeddings_app.visualization_widget() + + with split_scatter_table.slot_after: + transforms_app.dataset_widget() diff --git a/src/nrtk_explorer/library/dataset.py b/src/nrtk_explorer/library/dataset.py index f53a338f..ecb62e76 100644 --- a/src/nrtk_explorer/library/dataset.py +++ b/src/nrtk_explorer/library/dataset.py @@ -51,7 +51,8 @@ def get_dataset(path: str, force_reload: bool = False): """ if force_reload: __load_dataset.cache_clear() - return __load_dataset(path) + absolute_path = str(Path(path).resolve()) + return __load_dataset(absolute_path) def get_image_fpath(selected_id: int, path: str): diff --git a/src/nrtk_explorer/library/debounce.py b/src/nrtk_explorer/library/debounce.py new file mode 100644 index 00000000..19411415 --- /dev/null +++ b/src/nrtk_explorer/library/debounce.py @@ -0,0 +1,32 @@ +import asyncio +from functools import wraps + + +def debounce(wait, state=None): + """Pass Trame state as arg if function modifies state""" + + def decorator(func): + task = None + + @wraps(func) + async def wrapper(*args, **kwargs): + nonlocal task + if task: + task.cancel() + + async def debounced(): + try: + await asyncio.sleep(wait) + if state: + with state: + await func(*args, **kwargs) + else: + await func(*args, **kwargs) + except asyncio.CancelledError: + pass + + task = asyncio.create_task(debounced()) + + return wrapper + + return decorator diff --git a/src/nrtk_explorer/library/dimension_reducers.py b/src/nrtk_explorer/library/dimension_reducers.py index bb59867d..cbb20c14 100644 --- a/src/nrtk_explorer/library/dimension_reducers.py +++ b/src/nrtk_explorer/library/dimension_reducers.py @@ -33,7 +33,7 @@ def reduce(self, name, fit_features, features=None, cache=True, **kwargs): self.cached_reducers[reducer_id] = reducer if features is None or len(features) == 0: - return None + return [] else: if cache is False or reduction_id not in self.cached_reductions: # Perform reduction without modifying the model diff --git a/src/nrtk_explorer/library/embeddings_extractor.py b/src/nrtk_explorer/library/embeddings_extractor.py index 2c3eedd9..361ffe24 100644 --- a/src/nrtk_explorer/library/embeddings_extractor.py +++ 
b/src/nrtk_explorer/library/embeddings_extractor.py @@ -3,10 +3,12 @@ import numpy as np import timm import torch +from PIL.Image import Image -from nrtk_explorer.library import images_manager from torch.utils.data import DataLoader, Dataset +IMAGE_MODEL_RESOLUTION = (224, 224) + # Create a dataset for images class ImagesDataset(Dataset): @@ -21,10 +23,7 @@ def __getitem__(self, i): class EmbeddingsExtractor: - def __init__( - self, model_name="resnet50d", manager=images_manager.ImagesManager(), force_cpu=False - ): - self.manager = manager + def __init__(self, model_name="resnet50d", force_cpu=False): self.device = "cuda" if torch.cuda.is_available() and not force_cpu else "cpu" self.model = model_name @@ -53,27 +52,18 @@ def model(self, model_name): **timm.data.resolve_model_data_config(self._model.pretrained_cfg) ) - def transform_image(self, img): + def transform_image(self, image: Image): """Transform image to fit model input size and format""" + img = image.resize(IMAGE_MODEL_RESOLUTION).convert("RGB") return self._model_transformer(img).unsqueeze(0) - def extract(self, paths, content=None, batch_size=32): - """Extract features from images in paths""" - if len(paths) == 0: - return None + def extract(self, images, batch_size=32): + """Extract features from images""" + if len(images) == 0: + return [] features = list() - transformed_images = list() - - # Load images and transform them - for path in paths: - img = None - if content and path in content: - img = content[path] - else: - img = self.manager.load_image_for_model(path) - - transformed_images.append(self.transform_image(img)) + transformed_images = [self.transform_image(img) for img in images] # Extract features from images adjusted_batch_size = batch_size @@ -105,4 +95,4 @@ def extract(self, paths, content=None, batch_size=32): torch.cuda.empty_cache() # We should never reach here - return None + return [] diff --git a/src/nrtk_explorer/library/images_manager.py b/src/nrtk_explorer/library/images_manager.py deleted file mode 100644 index fbe7440f..00000000 --- a/src/nrtk_explorer/library/images_manager.py +++ /dev/null @@ -1,55 +0,0 @@ -from PIL.Image import Image -from PIL import Image as ImageModule - -import base64 -import copy -import io - -# Resolution for images to be used in model -IMAGE_MODEL_RESOLUTION = (224, 224) -THUMBNAIL_RESOLUTION = (250, 250) - - -def convert_to_base64(img: Image) -> str: - """Convert image to base64 string""" - buf = io.BytesIO() - img.save(buf, format="png") - return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode() - - -class ImagesManager: - """Class to manage images and thumbnails""" - - def __init__(self): - self.images = {} - self.thumbnails = {} - self.images_for_model = {} - - def prepare_for_model(self, img): - """Prepare image for model input""" - img = img.resize(IMAGE_MODEL_RESOLUTION) - return img.convert("RGB") - - def load_image(self, path): - """Load image from path and store it in cache if not already loaded""" - if path not in self.images: - self.images[path] = ImageModule.open(path) - - return self.images[path] - - def load_image_for_model(self, path): - """Load image for model from path and store it in cache if not already loaded""" - if path not in self.images_for_model: - img = copy.copy(self.load_image(path)) - self.images_for_model[path] = self.prepare_for_model(img) - - return self.images_for_model[path] - - def load_thumbnail(self, path): - """Load thumbnail from path and store it in cache if not already loaded""" - if path not in self.thumbnails: 
- img = copy.copy(self.load_image(path)) - img.thumbnail(THUMBNAIL_RESOLUTION) - self.thumbnails[path] = img - - return self.thumbnails[path] diff --git a/src/nrtk_explorer/library/object_detector.py b/src/nrtk_explorer/library/object_detector.py index c64875da..281ffa0d 100644 --- a/src/nrtk_explorer/library/object_detector.py +++ b/src/nrtk_explorer/library/object_detector.py @@ -2,12 +2,19 @@ import logging import torch import transformers +from typing import Optional, Sequence, Dict, NamedTuple +from PIL.Image import Image -from typing import Optional, Sequence -from nrtk_explorer.library import images_manager +ImageIdToAnnotations = dict[str, Sequence[dict]] -ImageIdToAnnotations = Optional[dict[str, Sequence[dict]]] + +class ImageWithId(NamedTuple): + id: str + image: Image + + +STARTING_BATCH_SIZE = 32 class ObjectDetector: @@ -17,16 +24,12 @@ def __init__( self, model_name: str = "facebook/detr-resnet-50", task: Optional[str] = None, - manager: Optional[images_manager.ImagesManager] = None, force_cpu: bool = False, ): - if manager is None: - manager = images_manager.ImagesManager() - self.task = task - self.manager = manager self.device = "cuda" if torch.cuda.is_available() and not force_cpu else "cpu" self.pipeline = model_name + self.reset() @property def device(self) -> str: @@ -56,37 +59,39 @@ def pipeline(self, model_name: str): # Do not display warnings transformers.utils.logging.set_verbosity_error() + def reset(self): + self.batch_size = STARTING_BATCH_SIZE + def eval( self, - image_ids: list[str], - content: Optional[dict] = None, - batch_size: int = 32, + images: Dict[str, Image], + batch_size: int = 0, # 0 means use last successful batch size ) -> ImageIdToAnnotations: """Compute object recognition. Returns Annotations grouped by input image paths.""" + images_with_ids = [ImageWithId(id, img) for id, img in images.items()] + # Some models require all the images in a batch to be the same size, # otherwise crash or UB. 
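+        # batches maps image size -> images of that size; each bucket is run
+        # through the pipeline separately so a batch never mixes resolutions.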
batches: dict = {}
-        for path in image_ids:
-            img = None
-            if content and path in content:
-                img = content[path]
-            else:
-                img = self.manager.load_image(path)
-
-            batches.setdefault(img.size, [[], []])
-            batches[img.size][0].append(path)
-            batches[img.size][1].append(img)
-
-        adjusted_batch_size = batch_size
-        while adjusted_batch_size > 0:
+        for image in images_with_ids:
+            size = image.image.size
+            batches.setdefault(size, [])
+            batches[size].append(image)
+
+        if batch_size != 0:
+            self.batch_size = batch_size
+        while self.batch_size > 0:
             try:
                 predictions_in_baches = [
                     zip(
-                        image_ids,
-                        self.pipeline(images, batch_size=adjusted_batch_size),
+                        [image.id for image in imagesInBatch],
+                        self.pipeline(
+                            [image.image for image in imagesInBatch],
+                            batch_size=self.batch_size,
+                        ),
                     )
-                    for image_ids, images in batches.values()
+                    for imagesInBatch in batches.values()
                 ]

                 predictions_by_image_id = {
@@ -97,11 +102,11 @@ def eval(
                     return predictions_by_image_id

             except RuntimeError as e:
-                if "out of memory" in str(e) and adjusted_batch_size > 1:
-                    previous_batch_size = adjusted_batch_size
-                    adjusted_batch_size = adjusted_batch_size // 2
+                if "out of memory" in str(e) and self.batch_size > 1:
+                    previous_batch_size = self.batch_size
+                    self.batch_size = self.batch_size // 2
                     print(
-                        f"OOM (Pytorch exception {e}) due to batch_size={previous_batch_size}, setting batch_size={adjusted_batch_size}"
+                        f"OOM (Pytorch exception {e}) due to batch_size={previous_batch_size}, setting batch_size={self.batch_size}"
                     )
                 else:
                     raise
@@ -112,4 +117,4 @@ def eval(
                 torch.cuda.empty_cache()

         # We should never reach here
-        return None
+        return {}
diff --git a/tests/test_cache.py b/tests/test_cache.py
new file mode 100644
index 00000000..2be7ad7c
--- /dev/null
+++ b/tests/test_cache.py
@@ -0,0 +1,72 @@
+import unittest
+from unittest.mock import Mock
+
+from nrtk_explorer.app.images.cache import LruCache
+
+
+class TestLruCache(unittest.TestCase):
+
+    def test_add_item(self):
+        cache = LruCache(max_size=2)
+        cache.add_item("key1", "value1")
+        self.assertEqual(cache.get_item("key1"), "value1")
+
+    def test_get_item(self):
+        cache = LruCache(max_size=2)
+        cache.add_item("key1", "value1")
+        self.assertEqual(cache.get_item("key1"), "value1")
+        self.assertIsNone(cache.get_item("key2"))
+
+    def test_cache_max_size(self):
+        cache = LruCache(max_size=2)
+        cache.add_item("key1", "value1")
+        cache.add_item("key2", "value2")
+        cache.add_item("key3", "value3")
+        self.assertIsNone(cache.get_item("key1"))
+        self.assertEqual(cache.get_item("key2"), "value2")
+        self.assertEqual(cache.get_item("key3"), "value3")
+
+    def test_callbacks(self):
+        cache = LruCache(max_size=2)
+        on_add = Mock()
+        on_clear = Mock()
+        cache.add_item("key1", "value1", on_add_item=on_add, on_clear_item=on_clear)
+        on_add.assert_called_once_with("key1", "value1")
+        cache.clear()
+        on_clear.assert_called_once_with("key1")
+
+    def test_callback_called_once(self):
+        cache = LruCache(max_size=2)
+        on_add = Mock()
+        on_clear = Mock()
+
+        cache.add_item("key1", "value1", on_add_item=on_add, on_clear_item=on_clear)
+        cache.add_item("key1", "value1", on_add_item=on_add, on_clear_item=on_clear)
+
+        on_add.assert_called_once_with("key1", "value1")
+
+        cache.clear()
+
+        on_clear.assert_called_once_with("key1")
+
+    def test_multiple_callbacks(self):
+        cache = LruCache(max_size=2)
+        on_add_1 = Mock()
+        on_add_2 = Mock()
+        on_clear_1 = Mock()
+        on_clear_2 = Mock()
+
+        cache.add_item("key1", "value1", on_add_item=on_add_1, 
on_clear_item=on_clear_1) + cache.add_item("key1", "value1", on_add_item=on_add_2, on_clear_item=on_clear_2) + + on_add_1.assert_called_once_with("key1", "value1") + on_add_2.assert_called_once_with("key1", "value1") + + cache.clear() + + on_clear_1.assert_called_once_with("key1") + on_clear_2.assert_called_once_with("key1") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index 85f766f9..dd43440b 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -1,9 +1,9 @@ from nrtk_explorer.library import embeddings_extractor from nrtk_explorer.library import dimension_reducers -from nrtk_explorer.library import images_manager from nrtk_explorer.library.dataset import get_dataset import nrtk_explorer.test_data +from PIL import Image from tabulate import tabulate from itertools import product from pathlib import Path @@ -16,7 +16,7 @@ inc_ds_path = Path(f"{CURRENT_DIR_NAME}/coco-od-2017/test_val2017.json") -def image_paths_impl(file_name): +def images_impl(file_name): dataset = get_dataset(file_name) images = dataset.imgs.values() @@ -24,44 +24,44 @@ def image_paths_impl(file_name): for image_metadata in images: paths.append(os.path.join(os.path.dirname(file_name), image_metadata["file_name"])) - return paths + return [Image.open(path) for path in paths] @pytest.fixture -def image_paths(request): - return image_paths_impl(inc_ds_path) +def images(request): + return images_impl(inc_ds_path) @pytest.fixture -def image_paths_external(request): - return image_paths_impl(request.config.getoption("--benchmark-dataset-file")) +def images_external(request): + return images_impl(request.config.getoption("--benchmark-dataset-file")) -def test_features_small(image_paths): +def test_features_small(images): extractor = embeddings_extractor.EmbeddingsExtractor() - features = extractor.extract(image_paths[:10]) + features = extractor.extract(images[:10]) assert len(features) == 10 print(features) -def test_features_zero(image_paths): +def test_features_zero(images): extractor = embeddings_extractor.EmbeddingsExtractor() features = extractor.extract([]) - assert features is None + assert len(features) == 0 print(features) @pytest.mark.benchmark -def test_features_all(image_paths_external): +def test_features_all(images_external): extractor = embeddings_extractor.EmbeddingsExtractor() - features = extractor.extract(image_paths_external) - assert len(features) == len(image_paths_external) + features = extractor.extract(images_external) + assert len(features) == len(images_external) print(f"Number of features: {len(features)}") -def test_pca_2d(image_paths): +def test_pca_2d(images): extractor = embeddings_extractor.EmbeddingsExtractor() - features = extractor.extract(image_paths[:10]) + features = extractor.extract(images[:10]) model = dimension_reducers.PCAReducer(2) points = model.fit(features) points = model.reduce(features) @@ -70,9 +70,9 @@ def test_pca_2d(image_paths): print(points) -def test_pca_3d(image_paths): +def test_pca_3d(images): extractor = embeddings_extractor.EmbeddingsExtractor() - features = extractor.extract(image_paths[:10]) + features = extractor.extract(images[:10]) model = dimension_reducers.PCAReducer(3) points = model.fit(features) points = model.reduce(features) @@ -81,9 +81,9 @@ def test_pca_3d(image_paths): print(points) -def test_umap_2d(image_paths): +def test_umap_2d(images): extractor = embeddings_extractor.EmbeddingsExtractor() - features = extractor.extract(image_paths[:10]) + features = 
extractor.extract(images[:10]) model = dimension_reducers.UMAPReducer(2, n_neighbors=8) points = model.fit(features) points = model.reduce(features) @@ -92,9 +92,9 @@ def test_umap_2d(image_paths): print(points) -def test_umap_3d(image_paths): +def test_umap_3d(images): extractor = embeddings_extractor.EmbeddingsExtractor() - features = extractor.extract(image_paths[:10]) + features = extractor.extract(images[:10]) model = dimension_reducers.UMAPReducer(3, n_neighbors=8) points = model.fit(features) points = model.reduce(features) @@ -103,9 +103,9 @@ def test_umap_3d(image_paths): print(points) -def test_reducer_manager(image_paths): +def test_reducer_manager(images): extractor = embeddings_extractor.EmbeddingsExtractor() - features = extractor.extract(image_paths[:10]) + features = extractor.extract(images[:10]) mgr = dimension_reducers.DimReducerManager() old_points = mgr.reduce(fit_features=features, features=features, name="PCA", dims=3) assert len(old_points) > 0 @@ -119,7 +119,7 @@ def test_reducer_manager(image_paths): @pytest.mark.benchmark -def test_features_extractor_benchmark(image_paths_external): +def test_features_extractor_benchmark(images_external): repetitions = 3 sampling = [10, 100] batch_size = [1, 8, 16, 32] @@ -128,15 +128,10 @@ def test_features_extractor_benchmark(image_paths_external): setups.append([500, 64]) table = list() - # Pre-load images - manager = images_manager.ImagesManager() - for path in image_paths_external[: max(sampling)]: - manager.load_image_for_model(path) - for n, batch_size in setups: - extractor = embeddings_extractor.EmbeddingsExtractor(manager=manager) + extractor = embeddings_extractor.EmbeddingsExtractor() output = timeit.repeat( - stmt=lambda: extractor.extract(image_paths_external[:n], batch_size=batch_size), + stmt=lambda: extractor.extract(images_external[:n], batch_size=batch_size), number=repetitions, repeat=5, ) @@ -146,7 +141,7 @@ def test_features_extractor_benchmark(image_paths_external): @pytest.mark.benchmark -def test_reducer_manager_benchmark(image_paths_external): +def test_reducer_manager_benchmark(images_external): setups = [ ("PCA", 10, True, 100), ("PCA", 10, False, 100), @@ -160,7 +155,7 @@ def test_reducer_manager_benchmark(image_paths_external): mgr = dimension_reducers.DimReducerManager() extractor = embeddings_extractor.EmbeddingsExtractor() - features = extractor.extract(image_paths_external) + features = extractor.extract(images_external) # Short benchmarks cached for name, n, cache, iterations in setups: @@ -177,9 +172,9 @@ def test_reducer_manager_benchmark(image_paths_external): @pytest.mark.benchmark -def test_pca_3d_large(image_paths_external): +def test_pca_3d_large(images_external): extractor = embeddings_extractor.EmbeddingsExtractor() - features = extractor.extract(image_paths_external) + features = extractor.extract(images_external) model = dimension_reducers.PCAReducer(3) points = model.reduce(features) assert len(points) > 0 @@ -189,9 +184,9 @@ def test_pca_3d_large(image_paths_external): @pytest.mark.benchmark -def test_umap_3d_large(image_paths_external): +def test_umap_3d_large(images_external): extractor = embeddings_extractor.EmbeddingsExtractor() - features = extractor.extract(image_paths_external) + features = extractor.extract(images_external) model = dimension_reducers.UMAPReducer(3) points = model.reduce(features) assert len(points) > 0 diff --git a/tests/test_object_detector.py b/tests/test_object_detector.py index ab2e30f7..c10dc9f3 100644 --- a/tests/test_object_detector.py +++ 
b/tests/test_object_detector.py @@ -7,6 +7,7 @@ import json import os +from PIL import Image import nrtk_explorer.test_data DIR_PATH = os.path.dirname(nrtk_explorer.test_data.__file__) @@ -16,17 +17,20 @@ def test_detector_small(): ds = json.load(open(DATASET)) - sample = [f"{DATASET_PATH}/{img['file_name']}" for img in ds["images"]][:15] + sample = { + id: Image.open(f"{DATASET_PATH}/{img['file_name']}") + for id, img in enumerate(ds["images"][:15]) + } detector = object_detector.ObjectDetector(model_name="hustvl/yolos-tiny") - img = detector.eval(image_ids=sample) + img = detector.eval(sample) assert len(img) == 15 def test_nrkt_scorer(): ds = json.load(open(DATASET)) - sample = [f"{DATASET_PATH}/{img['file_name']}" for img in ds["images"]] + sample = {img["id"]: Image.open(f"{DATASET_PATH}/{img['file_name']}") for img in ds["images"]} detector = object_detector.ObjectDetector(model_name="facebook/detr-resnet-50") - predictions = detector.eval(image_ids=sample) + predictions = detector.eval(sample) dataset_annotations = dict() for annotation in ds["annotations"]: diff --git a/vue-components/package-lock.json b/vue-components/package-lock.json index b98c8772..8c75a90d 100644 --- a/vue-components/package-lock.json +++ b/vue-components/package-lock.json @@ -26,7 +26,7 @@ "scatter-gl": "^0.0.13", "semantic-release": "^24.0.0", "typescript": "~5.1.6", - "vite": "^4.5.3", + "vite": "^4.5.5", "vue": "^3.0.0", "vue-tsc": "^1.8.6" }, @@ -8410,9 +8410,9 @@ } }, "node_modules/rollup": { - "version": "3.29.4", - "resolved": "https://registry.npmjs.org/rollup/-/rollup-3.29.4.tgz", - "integrity": "sha512-oWzmBZwvYrU0iJHtDmhsm662rC15FRXmcjCk1xD771dFDx5jJ02ufAQQTn0etB2emNk4J9EZg/yWKpsn9BWGRw==", + "version": "3.29.5", + "resolved": "https://registry.npmjs.org/rollup/-/rollup-3.29.5.tgz", + "integrity": "sha512-GVsDdsbJzzy4S/v3dqWPJ7EfvZJfCHiDqe80IyrF59LYuP+e6U1LJoUqeuqRbwAWoMNoXivMNeNAOf5E22VA1w==", "dev": true, "bin": { "rollup": "dist/bin/rollup" @@ -9590,9 +9590,9 @@ } }, "node_modules/vite": { - "version": "4.5.3", - "resolved": "https://registry.npmjs.org/vite/-/vite-4.5.3.tgz", - "integrity": "sha512-kQL23kMeX92v3ph7IauVkXkikdDRsYMGTVl5KY2E9OY4ONLvkHf04MDTbnfo6NKxZiDLWzVpP5oTa8hQD8U3dg==", + "version": "4.5.5", + "resolved": "https://registry.npmjs.org/vite/-/vite-4.5.5.tgz", + "integrity": "sha512-ifW3Lb2sMdX+WU91s3R0FyQlAyLxOzCSCP37ujw0+r5POeHPwe6udWVIElKQq8gk3t7b8rkmvqC6IHBpCff4GQ==", "dev": true, "dependencies": { "esbuild": "^0.18.10", diff --git a/vue-components/package.json b/vue-components/package.json index f514416e..8a4256bc 100644 --- a/vue-components/package.json +++ b/vue-components/package.json @@ -47,7 +47,7 @@ "scatter-gl": "^0.0.13", "semantic-release": "^24.0.0", "typescript": "~5.1.6", - "vite": "^4.5.3", + "vite": "^4.5.5", "vue": "^3.0.0", "vue-tsc": "^1.8.6" }, diff --git a/vue-components/src/components/ImageDetection.vue b/vue-components/src/components/ImageDetection.vue index 109de26e..d784f86b 100644 --- a/vue-components/src/components/ImageDetection.vue +++ b/vue-components/src/components/ImageDetection.vue @@ -15,7 +15,7 @@ const CATEGORY_COLORS: Vector3[] = [ ] const TOOLTIP_OFFSET = [8, 8] -const TOOLTIP_HEIGHT_PADDING = 12 // fudge to keep bottom border from clipping. In pixels +const TOOLTIP_PADDING = 12 // fudge to keep tooltip from clipping/overflowing. 
In pixels let annotationsTree: Quadtree> | undefined = undefined @@ -220,8 +220,8 @@ function mouseMove(e: MouseEvent) { const toolTipInContainer = { left: parentRect.left + posX - containerRect.left, top: parentRect.top + posY - containerRect.top, - width: tooltipRect.width, - height: tooltipRect.height + TOOLTIP_HEIGHT_PADDING + width: tooltipRect.width + TOOLTIP_PADDING, + height: tooltipRect.height + TOOLTIP_PADDING } // if text goes off the edge, move up and/or left @@ -239,8 +239,6 @@ function mouseMove(e: MouseEvent) { const borderSize = computed(() => (props.selected ? '4' : '0')) const src = computed(() => unref(props.src)) - -const showSpinner = computed(() => !src.value || unref(props.annotations) == undefined) diff --git a/vue-components/src/components/ScatterPlot.vue b/vue-components/src/components/ScatterPlot.vue index b2ccccc1..d7746417 100644 --- a/vue-components/src/components/ScatterPlot.vue +++ b/vue-components/src/components/ScatterPlot.vue @@ -1,6 +1,6 @@
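
Note on the new `debounce` helper added in `src/nrtk_explorer/library/debounce.py`: each call cancels the pending task and re-arms a timer, so the wrapped coroutine runs only once calls have been quiet for `wait` seconds, and the optional `state` argument wraps the call in `with state:` so trame flushes any state writes to connected clients. A minimal, self-contained sketch of the behavior (the `recompute` handler below is illustrative, not app code; it assumes nrtk-explorer is installed):

```python
import asyncio

from nrtk_explorer.library.debounce import debounce


@debounce(0.3)  # pass the trame state as the second argument if the handler mutates state
async def recompute():
    print("recomputing once")


async def main():
    for _ in range(5):
        await recompute()        # each call cancels the previously scheduled run
        await asyncio.sleep(0.05)
    await asyncio.sleep(0.5)     # only the last call fires, 0.3 s after the burst ends


asyncio.run(main())
```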