Commit
refactor: store chained transform in images store
PaulHax committed Jan 30, 2025
1 parent 2696107 commit b35a2ea
Showing 3 changed files with 21 additions and 33 deletions.
7 changes: 2 additions & 5 deletions src/nrtk_explorer/app/images/image_server.py
@@ -45,11 +45,9 @@ async def original_image_endpoint(images: Images, state_id_to_dataset_id, request
     return make_response(image)
 
 
-async def transform_image_endpoint(
-    images: Images, state_id_to_dataset_id, context, request: web.Request
-):
+async def transform_image_endpoint(images: Images, state_id_to_dataset_id, request: web.Request):
     dataset_id = state_id_to_dataset_id[request.match_info["id"]]
-    image = images.get_transformed_image(context.chained_transform, dataset_id)
+    image = images.get_transformed_image(dataset_id)
     return make_response(image)
 
 
@@ -69,7 +67,6 @@ def __init__(self, server, images: Images):
             transform_image_endpoint,
            images,
             self._state_id_to_dataset_id,
-            self.server.context,
         )
 
        change_checker(self.server.state, "dataset_ids")(self.on_dataset_ids_change)
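For orientation, here is a hypothetical, self-contained sketch of wiring the slimmed-down endpoint with functools.partial. The route path, register_routes helper, and the PNG response are placeholders (not the project's actual route setup or make_response); the point is that the handler no longer needs the Trame context because the Images store now owns the chained transform.

# Hypothetical wiring sketch (placeholder route path and helper names), illustrating
# the new endpoint signature: the transform is no longer passed in via the context.
import functools
import io

from aiohttp import web


async def transform_image_endpoint(images, state_id_to_dataset_id, request: web.Request):
    dataset_id = state_id_to_dataset_id[request.match_info["id"]]
    image = images.get_transformed_image(dataset_id)  # transform already lives in the store
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")  # assumes a PIL-style image; stand-in for make_response
    return web.Response(body=buffer.getvalue(), content_type="image/png")


def register_routes(app: web.Application, images, state_id_to_dataset_id):
    # functools.partial pre-binds everything except the request, matching
    # aiohttp's handler(request) calling convention.
    handler = functools.partial(transform_image_endpoint, images, state_id_to_dataset_id)
    app.router.add_get("/transform-image/{id}", handler)  # placeholder path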
30 changes: 14 additions & 16 deletions src/nrtk_explorer/app/images/images.py
@@ -1,16 +1,15 @@
 import psutil
 from PIL import Image
-from trame.decorators import TrameApp, change, controller
+from trame.decorators import TrameApp, change
 from nrtk_explorer.app.images.image_ids import (
     dataset_id_to_image_id,
     dataset_id_to_transformed_image_id,
 )
 from nrtk_explorer.app.images.cache import LruCache
-from nrtk_explorer.library.transforms import ImageTransform
 
 
 IMAGE_CACHE_SIZE_DEFAULT = 50
-AVALIBLE_MEMORY_TO_TAKE_FACTOR = 0.3
+AVALIBLE_MEMORY_TO_TAKE_FACTOR = 0.2
 
 
 @TrameApp()
@@ -20,6 +19,7 @@ def __init__(self, server):
         self.original_images = LruCache(IMAGE_CACHE_SIZE_DEFAULT)
         self.transformed_images = LruCache(IMAGE_CACHE_SIZE_DEFAULT)
         self._should_ajust_cache_size = True
+        self._transform = None
 
     def _ajust_cache_size(self, image_example: Image.Image):
         img_size = len(image_example.tobytes())
@@ -58,39 +58,37 @@ def get_image_without_cache_eviction(self, dataset_id: str):
         self.original_images.add_if_room(image_id, image)
         return image
 
-    def _load_transformed_image(self, transform: ImageTransform, dataset_id: str):
+    def _load_transformed_image(self, dataset_id: str):
         original = self.get_image_without_cache_eviction(dataset_id)
-        transformed = transform.execute(original)
+        transformed = self._transform.execute(original)
         # So pixel-wise annotation similarity score works
         if original.size != transformed.size:
             return transformed.resize(original.size)
         return transformed
 
-    def _get_transformed_image(self, transform: ImageTransform, dataset_id: str, **kwargs):
+    def _get_transformed_image(self, dataset_id: str, **kwargs):
         image_id = dataset_id_to_transformed_image_id(dataset_id)
         image = self.transformed_images.get_item(image_id) or self._load_transformed_image(
-            transform, dataset_id
+            dataset_id
         )
         return image_id, image
 
-    def get_transformed_image(self, transform: ImageTransform, dataset_id: str, **kwargs):
-        image_id, image = self._get_transformed_image(transform, dataset_id, **kwargs)
+    def get_transformed_image(self, dataset_id: str, **kwargs):
+        image_id, image = self._get_transformed_image(dataset_id, **kwargs)
         self.transformed_images.add_item(image_id, image, **kwargs)
         return image
 
-    def get_transformed_image_without_cache_eviction(
-        self, transform: ImageTransform, dataset_id: str
-    ):
-        image_id, image = self._get_transformed_image(transform, dataset_id)
+    def get_transformed_image_without_cache_eviction(self, dataset_id: str):
+        image_id, image = self._get_transformed_image(dataset_id)
         self.transformed_images.add_if_room(image_id, image)
         return image
 
     @change("current_dataset")
     def clear_all(self, **kwargs):
         self.original_images.clear()
-        self.clear_transformed()
+        self.transformed_images.clear()
         self._should_ajust_cache_size = True
 
-    @controller.add("apply_transform")
-    def clear_transformed(self, **kwargs):
+    def set_transform(self, transform):
+        self._transform = transform
         self.transformed_images.clear()
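Taken together, the images.py changes turn the chained transform into store state. A minimal usage sketch under that assumption follows; the server object, transform list, and dataset id are hypothetical placeholders, not values from the commit.

# Hypothetical usage of the refactored Images store (sketch, not part of the commit).
from nrtk_explorer.app.images.images import Images
from nrtk_explorer.library import transforms as trans

images = Images(server)  # `server` is the Trame server, assumed to exist

# The chained transform is set once on the store...
images.set_transform(trans.ChainedImageTransform([blur, noise]))  # placeholder transform instances

# ...so consumers only pass a dataset id; the store applies self._transform,
# resizes back to the original size if needed, and caches the result under
# the transformed-image id.
image = images.get_transformed_image("0000042")  # placeholder dataset id

# Setting a new transform clears the transformed-image cache automatically.
images.set_transform(trans.ChainedImageTransform([blur]))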
17 changes: 5 additions & 12 deletions src/nrtk_explorer/app/transforms.py
@@ -268,23 +268,18 @@ def on_apply_transform(self, **kwargs):
         # Turn on switch if user clicked lower apply button
         self.state.transform_enabled_switch = True
         transforms = list(map(lambda t: t["instance"], self.context.transforms))
-        self.context.chained_transform = trans.ChainedImageTransform(transforms)
+        chained_transform = trans.ChainedImageTransform(transforms)
+        self.images.set_transform(chained_transform)
         self._start_update_images()
 
-    async def update_transformed_images(
-        self, dataset_ids, predictions_original_images, visible=False
-    ):
+    async def update_transformed_images(self, dataset_ids, predictions_original_images):
         if not self.state.transform_enabled:
             return
 
-        transform = self.context.chained_transform
-
         id_to_image = LazyDict()
         for id in dataset_ids:
             id_to_image[dataset_id_to_transformed_image_id(id)] = (
-                lambda id=id: self.images.get_transformed_image_without_cache_eviction(
-                    transform, id
-                )
+                lambda id=id: self.images.get_transformed_image_without_cache_eviction(id)
             )
 
         with self.state:
@@ -375,9 +370,7 @@ async def _update_images(self, dataset_ids, visible=False):
         # sortable score value may have changed which may have changed images that are in view
         self.server.controller.check_images_in_view()
 
-        await self.update_transformed_images(
-            dataset_ids, predictions_original_images, visible=visible
-        )
+        await self.update_transformed_images(dataset_ids, predictions_original_images)
 
     async def _chunk_update_images(self, dataset_ids, visible=False):
         ids = list(dataset_ids)
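One detail worth noting in update_transformed_images above: the `lambda id=id:` default argument is the standard guard against Python's late-binding closures inside a loop. Below is a standalone illustration of that pitfall; it is independent of the project's LazyDict, which is assumed to call the stored callable on access.

# Why `lambda id=id:` matters: closures capture variables, not values.
ids = ["a", "b", "c"]

late_bound = {}
early_bound = {}
for i in ids:
    late_bound[i] = lambda: i.upper()       # every closure shares the loop variable
    early_bound[i] = lambda i=i: i.upper()  # default arg freezes the current value

print(late_bound["a"]())   # "C" -- the loop variable ended at "c"
print(early_bound["a"]())  # "A" -- value bound when the lambda was created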
