Commit

fix: avoid out of memory errors with larger images
- Use generator when computing embeddings.
- Adjust image cache size based on image size and available memory.
PaulHax committed Jan 27, 2025
1 parent 1c99cfc commit 37672e9
Showing 3 changed files with 25 additions and 12 deletions.
4 changes: 2 additions & 2 deletions src/nrtk_explorer/app/embeddings.py

@@ -155,9 +155,9 @@ def update_points_transformations_state(self, **kwargs):
         self.state.points_transformations = {}

     def compute_source_points(self):
-        images = [
+        images = (
             self.images.get_image_without_cache_eviction(id) for id in self.state.dataset_ids
-        ]
+        )
         self.features = self.extractor.extract(images)

         points = self.compute_points(self.features, self.features)
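The change above swaps a list comprehension for a generator expression, so decoded images are produced one at a time as the extractor iterates over them instead of all being held in memory at once. A minimal sketch of the difference, using a hypothetical load_image helper (not part of this repository):

from PIL import Image

def load_image(index: int) -> Image.Image:
    # Hypothetical loader standing in for get_image_without_cache_eviction;
    # here it just synthesizes a large blank image.
    return Image.new("RGB", (4000, 3000))

ids = range(100)

# List comprehension: all 100 images are decoded up front and stay alive together
# (~36 MB of pixels each, ~3.6 GB total).
# images = [load_image(i) for i in ids]

# Generator expression: images are produced lazily, one per iteration of the consumer,
# so only one decoded image needs to be resident at a time.
images = (load_image(i) for i in ids)

for img in images:
    img.thumbnail((32, 32))  # the consumer handles one image, then it can be freed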
28 changes: 21 additions & 7 deletions src/nrtk_explorer/app/images/images.py

@@ -1,4 +1,5 @@
 import base64
+import psutil
 from io import BytesIO
 from PIL import Image
 from trame.decorators import TrameApp, change, controller
@@ -18,23 +19,35 @@ def convert_to_base64(img: Image.Image) -> str:
     return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()


-IMAGE_CACHE_SIZE = 500
+IMAGE_CACHE_SIZE_DEFAULT = 50
+AVALIBLE_MEMORY_TO_TAKE_FACTOR = 0.5


 @TrameApp()
 class Images:
     def __init__(self, server):
         self.server = server
-        self.original_images = LruCache(
-            IMAGE_CACHE_SIZE,
-        )
-        self.transformed_images = LruCache(
-            IMAGE_CACHE_SIZE,
-        )
+        self.original_images = LruCache(IMAGE_CACHE_SIZE_DEFAULT)
+        self.transformed_images = LruCache(IMAGE_CACHE_SIZE_DEFAULT)
+        self._should_reset_cache = True
+
+    def _ajust_cache_size(self, image_example: Image.Image):
+        img_size = len(image_example.tobytes())
+        system_memory = psutil.virtual_memory().available
+        mem_for_cache = round(system_memory * AVALIBLE_MEMORY_TO_TAKE_FACTOR)
+        images_that_fit = max(min(mem_for_cache // img_size, 500), 50)
+        cache_size = images_that_fit // 2
+        self.original_images = LruCache(cache_size)
+        self.transformed_images = LruCache(cache_size)

     def _load_image(self, dataset_id: str):
         img = self.server.context.dataset.get_image(int(dataset_id))
         img.load()  # Avoid OSError(24, 'Too many open files')
+
+        if self._should_reset_cache:
+            self._should_reset_cache = False
+            self._ajust_cache_size(img)  # assuming images in dataset are similar size
+
         # transforms and base64 encoding require RGB mode
         return img.convert("RGB") if img.mode != "RGB" else img

@@ -105,6 +118,7 @@ def get_transformed_image_without_cache_eviction(
     def clear_all(self, **kwargs):
         self.original_images.clear()
         self.clear_transformed()
+        self._should_reset_cache = True

     @controller.add("apply_transform")
     def clear_transformed(self, **kwargs):
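The cache-sizing logic above budgets half of the currently available system memory for the two LRU caches, clamps the image count to between 50 and 500, and splits that budget evenly between the original-image and transformed-image caches. A standalone sketch of the same arithmetic follows; the example image size and memory figures are made up, and the constant name is a cleaned-up stand-in for the one in the diff:

import psutil
from PIL import Image

AVAILABLE_MEMORY_FACTOR = 0.5  # mirrors AVALIBLE_MEMORY_TO_TAKE_FACTOR in the commit

def cache_size_for(example: Image.Image) -> int:
    """Return the capacity for one of the two LRU caches."""
    img_bytes = len(example.tobytes())  # raw decoded size of one image
    budget = round(psutil.virtual_memory().available * AVAILABLE_MEMORY_FACTOR)
    images_that_fit = max(min(budget // img_bytes, 500), 50)  # clamp to [50, 500]
    return images_that_fit // 2  # split between original and transformed caches

# Worked example: a 4000x3000 RGB image is 36 MB decoded; if 8 GB were available,
# the budget would be 4 GB -> 111 images fit (within the [50, 500] clamp) -> 55 per cache.
example = Image.new("RGB", (4000, 3000))
print(cache_size_for(example))  # actual value depends on this machine's free memory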
5 changes: 2 additions & 3 deletions src/nrtk_explorer/library/embeddings_extractor.py

@@ -64,14 +64,13 @@ def transform_image(self, image: Image):

     def extract(self, images, batch_size=0):
         """Extract features from images"""
-        if len(images) == 0:
-            return []
-
         if batch_size != 0:
             self.batch_size = batch_size

         features = list()
         transformed_images = [self.transform_image(img) for img in images]
+        if len(transformed_images) == 0:
+            return []

         while self.batch_size > 0:
             try:
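Moving the emptiness check is what makes the generator from embeddings.py usable here: a generator expression has no len(), so the old guard on images would raise a TypeError. The check now runs on transformed_images, a real list built by iterating the (possibly lazy) input. A small illustration of the constraint:

# Generators support iteration but not len(); lists support both.
images = (n * n for n in range(3))  # stand-in for the lazy image input

try:
    len(images)
except TypeError:
    print("len() is not defined for generator expressions")

transformed = [n + 1 for n in images]  # consuming the generator yields a list
print(len(transformed))  # 3 -- safe to check emptiness here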
