8
8
from abc import abstractmethod
9
9
from collections .abc import Iterable
10
10
from contextlib import contextmanager
11
- from typing import TYPE_CHECKING , Callable , Generic , Iterator , List , Union
11
+ from typing import TYPE_CHECKING , Any , Callable , Generic , Iterator , List , Union
12
12
13
13
import cv2
14
14
import numpy as np
@@ -92,6 +92,7 @@ def __init__(
92
92
self .image_color_channel = image_color_channel
93
93
self .stack_images = stack_images
94
94
self .to_tv_image = to_tv_image
95
+
95
96
if self .dm_subset .categories ():
96
97
self .label_info = LabelInfo .from_dm_label_groups (self .dm_subset .categories ()[AnnotationType .label ])
97
98
else :
@@ -141,11 +142,31 @@ def __getitem__(self, index: int) -> T_OTXDataEntity:
141
142
msg = f"Reach the maximum refetch number ({ self .max_refetch } )"
142
143
raise RuntimeError (msg )
143
144
144
- def _get_img_data_and_shape (self , img : Image ) -> tuple [np .ndarray , tuple [int , int ]]:
145
+ def _get_img_data_and_shape (
146
+ self ,
147
+ img : Image ,
148
+ roi : dict [str , Any ] | None = None ,
149
+ ) -> tuple [np .ndarray , tuple [int , int ], dict [str , Any ] | None ]:
150
+ """Get image data and shape.
151
+
152
+ This method is used to get image data and shape from Datumaro image object.
153
+ If ROI is provided, the image data is extracted from the ROI.
154
+
155
+ Args:
156
+ img (Image): Image object from Datumaro.
157
+ roi (dict[str, Any] | None, Optional): Region of interest.
158
+ Represented by dict with coordinates and some meta information.
159
+
160
+ Returns:
161
+ The image data, shape, and ROI meta information
162
+ """
145
163
key = img .path if isinstance (img , ImageFromFile ) else id (img )
164
+ roi_meta = None
146
165
147
- if (img_data := self .mem_cache_handler .get (key = key )[0 ]) is not None :
148
- return img_data , img_data .shape [:2 ]
166
+ # check if the image is already in the cache
167
+ img_data , roi_meta = self .mem_cache_handler .get (key = key )
168
+ if img_data is not None :
169
+ return img_data , img_data .shape [:2 ], roi_meta
149
170
150
171
with image_decode_context ():
151
172
img_data = (
@@ -158,11 +179,28 @@ def _get_img_data_and_shape(self, img: Image) -> tuple[np.ndarray, tuple[int, in
158
179
msg = "Cannot get image data"
159
180
raise RuntimeError (msg )
160
181
161
- img_data = self ._cache_img (key = key , img_data = img_data .astype (np .uint8 ))
182
+ if roi :
183
+ # extract ROI from image
184
+ shape = roi ["shape" ]
185
+ h , w = img_data .shape [:2 ]
186
+ x1 , y1 , x2 , y2 = (
187
+ int (np .clip (np .trunc (shape ["x1" ] * w ), 0 , w )),
188
+ int (np .clip (np .trunc (shape ["y1" ] * h ), 0 , h )),
189
+ int (np .clip (np .ceil (shape ["x2" ] * w ), 0 , w )),
190
+ int (np .clip (np .ceil (shape ["y2" ] * h ), 0 , h )),
191
+ )
192
+ if (x2 - x1 ) * (y2 - y1 ) <= 0 :
193
+ msg = f"ROI has zero or negative area. ROI coordinates: { x1 } , { y1 } , { x2 } , { y2 } "
194
+ raise ValueError (msg )
195
+
196
+ img_data = img_data [y1 :y2 , x1 :x2 ]
197
+ roi_meta = {"x1" : x1 , "y1" : y1 , "x2" : x2 , "y2" : y2 , "orig_image_shape" : (h , w )}
198
+
199
+ img_data = self ._cache_img (key = key , img_data = img_data .astype (np .uint8 ), meta = roi_meta )
162
200
163
- return img_data , img_data .shape [:2 ]
201
+ return img_data , img_data .shape [:2 ], roi_meta
164
202
165
- def _cache_img (self , key : str | int , img_data : np .ndarray ) -> np .ndarray :
203
+ def _cache_img (self , key : str | int , img_data : np .ndarray , meta : dict [ str , Any ] | None = None ) -> np .ndarray :
166
204
"""Cache an image after resizing.
167
205
168
206
If there is available space in the memory pool, the input image is cached.
@@ -182,14 +220,14 @@ def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray:
182
220
return img_data
183
221
184
222
if self .mem_cache_img_max_size is None :
185
- self .mem_cache_handler .put (key = key , data = img_data , meta = None )
223
+ self .mem_cache_handler .put (key = key , data = img_data , meta = meta )
186
224
return img_data
187
225
188
226
height , width = img_data .shape [:2 ]
189
227
max_height , max_width = self .mem_cache_img_max_size
190
228
191
229
if height <= max_height and width <= max_width :
192
- self .mem_cache_handler .put (key = key , data = img_data , meta = None )
230
+ self .mem_cache_handler .put (key = key , data = img_data , meta = meta )
193
231
return img_data
194
232
195
233
# Preserve the image size ratio and fit to max_height or max_width
@@ -206,7 +244,7 @@ def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray:
206
244
self .mem_cache_handler .put (
207
245
key = key ,
208
246
data = resized_img ,
209
- meta = None ,
247
+ meta = meta ,
210
248
)
211
249
return resized_img
212
250
0 commit comments