
Commit 844fc2e

Fix task chain for Det -> Cls / Seg (#4105)
* fix linter
* return recipe back
* added ROI extraction for multi-class classification dataset
* fix linter
* add same logic to semantic seg
* added test for OTXDataset
* add clip and raise an error when coordinates are invalid
* rewrite value error
1 parent 88ab4b8 commit 844fc2e
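
For orientation: when a detection task is chained into a classification or segmentation task, each dataset item carries an extra "roi" attribute describing the detected box. Its layout, as read from the diffs below, is roughly the following (normalized corner coordinates plus the labels attached to the box); the concrete values here are made up:

roi = {
    "shape": {"x1": 0.12, "y1": 0.05, "x2": 0.48, "y2": 0.60},  # normalized xyxy corners
    "labels": [
        {"label": {"_id": "cat", "domain": "CLASSIFICATION"}},  # hypothetical label id
    ],
}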

File tree

11 files changed (+182 -35 lines)


src/otx/algo/utils/xai_utils.py (+1 -1)

@@ -225,7 +225,7 @@ def _get_image_data_name(
     subset = datamodule.subsets[subset_name]
     item = subset.dm_subset[img_id]
     img = item.media_as(Image)
-    img_data, _ = subset._get_img_data_and_shape(img)  # noqa: SLF001
+    img_data, _, _ = subset._get_img_data_and_shape(img)  # noqa: SLF001
     image_save_name = "".join([char if char.isalnum() else "_" for char in item.id])
     return img_data, image_save_name

src/otx/core/data/dataset/anomaly.py (+1 -1)

@@ -79,7 +79,7 @@ def _get_item_impl(
         datumaro_item = self.dm_subset[index]
         img = datumaro_item.media_as(Image)
         # returns image in RGB format if self.image_color_channel is RGB
-        img_data, img_shape = self._get_img_data_and_shape(img)
+        img_data, img_shape, _ = self._get_img_data_and_shape(img)

         label = self._get_label(datumaro_item)

src/otx/core/data/dataset/base.py (+48 -10)

@@ -8,7 +8,7 @@
 from abc import abstractmethod
 from collections.abc import Iterable
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Callable, Generic, Iterator, List, Union
+from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, List, Union

 import cv2
 import numpy as np
@@ -92,6 +92,7 @@ def __init__(
         self.image_color_channel = image_color_channel
         self.stack_images = stack_images
         self.to_tv_image = to_tv_image
+
         if self.dm_subset.categories():
             self.label_info = LabelInfo.from_dm_label_groups(self.dm_subset.categories()[AnnotationType.label])
         else:
@@ -141,11 +142,31 @@ def __getitem__(self, index: int) -> T_OTXDataEntity:
         msg = f"Reach the maximum refetch number ({self.max_refetch})"
         raise RuntimeError(msg)

-    def _get_img_data_and_shape(self, img: Image) -> tuple[np.ndarray, tuple[int, int]]:
+    def _get_img_data_and_shape(
+        self,
+        img: Image,
+        roi: dict[str, Any] | None = None,
+    ) -> tuple[np.ndarray, tuple[int, int], dict[str, Any] | None]:
+        """Get image data and shape.
+
+        This method is used to get image data and shape from Datumaro image object.
+        If ROI is provided, the image data is extracted from the ROI.
+
+        Args:
+            img (Image): Image object from Datumaro.
+            roi (dict[str, Any] | None, Optional): Region of interest.
+                Represented by dict with coordinates and some meta information.
+
+        Returns:
+            The image data, shape, and ROI meta information
+        """
         key = img.path if isinstance(img, ImageFromFile) else id(img)
+        roi_meta = None

-        if (img_data := self.mem_cache_handler.get(key=key)[0]) is not None:
-            return img_data, img_data.shape[:2]
+        # check if the image is already in the cache
+        img_data, roi_meta = self.mem_cache_handler.get(key=key)
+        if img_data is not None:
+            return img_data, img_data.shape[:2], roi_meta

         with image_decode_context():
             img_data = (
@@ -158,11 +179,28 @@ def _get_img_data_and_shape(self, img: Image) -> tuple[np.ndarray, tuple[int, in
             msg = "Cannot get image data"
             raise RuntimeError(msg)

-        img_data = self._cache_img(key=key, img_data=img_data.astype(np.uint8))
+        if roi:
+            # extract ROI from image
+            shape = roi["shape"]
+            h, w = img_data.shape[:2]
+            x1, y1, x2, y2 = (
+                int(np.clip(np.trunc(shape["x1"] * w), 0, w)),
+                int(np.clip(np.trunc(shape["y1"] * h), 0, h)),
+                int(np.clip(np.ceil(shape["x2"] * w), 0, w)),
+                int(np.clip(np.ceil(shape["y2"] * h), 0, h)),
+            )
+            if (x2 - x1) * (y2 - y1) <= 0:
+                msg = f"ROI has zero or negative area. ROI coordinates: {x1}, {y1}, {x2}, {y2}"
+                raise ValueError(msg)
+
+            img_data = img_data[y1:y2, x1:x2]
+            roi_meta = {"x1": x1, "y1": y1, "x2": x2, "y2": y2, "orig_image_shape": (h, w)}
+
+        img_data = self._cache_img(key=key, img_data=img_data.astype(np.uint8), meta=roi_meta)

-        return img_data, img_data.shape[:2]
+        return img_data, img_data.shape[:2], roi_meta

-    def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray:
+    def _cache_img(self, key: str | int, img_data: np.ndarray, meta: dict[str, Any] | None = None) -> np.ndarray:
         """Cache an image after resizing.

         If there is available space in the memory pool, the input image is cached.
@@ -182,14 +220,14 @@ def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray:
             return img_data

         if self.mem_cache_img_max_size is None:
-            self.mem_cache_handler.put(key=key, data=img_data, meta=None)
+            self.mem_cache_handler.put(key=key, data=img_data, meta=meta)
             return img_data

         height, width = img_data.shape[:2]
         max_height, max_width = self.mem_cache_img_max_size

         if height <= max_height and width <= max_width:
-            self.mem_cache_handler.put(key=key, data=img_data, meta=None)
+            self.mem_cache_handler.put(key=key, data=img_data, meta=meta)
             return img_data

         # Preserve the image size ratio and fit to max_height or max_width
@@ -206,7 +244,7 @@ def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray:
         self.mem_cache_handler.put(
             key=key,
             data=resized_img,
-            meta=None,
+            meta=meta,
         )
         return resized_img
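
To make the new ROI branch in _get_img_data_and_shape concrete, here is a minimal standalone sketch of the coordinate conversion and crop (plain NumPy, outside any OTX class; crop_roi is our own helper name, not part of the codebase):

import numpy as np

def crop_roi(img_data: np.ndarray, roi: dict) -> tuple[np.ndarray, dict]:
    """Crop a normalized ROI out of an HWC image, mirroring the logic above."""
    shape = roi["shape"]
    h, w = img_data.shape[:2]
    x1 = int(np.clip(np.trunc(shape["x1"] * w), 0, w))
    y1 = int(np.clip(np.trunc(shape["y1"] * h), 0, h))
    x2 = int(np.clip(np.ceil(shape["x2"] * w), 0, w))
    y2 = int(np.clip(np.ceil(shape["y2"] * h), 0, h))
    if (x2 - x1) * (y2 - y1) <= 0:
        msg = f"ROI has zero or negative area. ROI coordinates: {x1}, {y1}, {x2}, {y2}"
        raise ValueError(msg)
    roi_meta = {"x1": x1, "y1": y1, "x2": x2, "y2": y2, "orig_image_shape": (h, w)}
    return img_data[y1:y2, x1:x2], roi_meta

img = np.zeros((10, 10, 3), dtype=np.uint8)
crop, meta = crop_roi(img, {"shape": {"x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9}})
assert crop.shape == (8, 8, 3)
assert meta["orig_image_shape"] == (10, 10)

Flooring the top-left corner and ceiling the bottom-right keeps the crop at least as large as the normalized box, and clipping to the image bounds guards against ROIs that spill outside the frame.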

src/otx/core/data/dataset/classification.py (+14 -14)

@@ -32,18 +32,18 @@ class OTXMulticlassClsDataset(OTXDataset[MulticlassClsDataEntity]):
     def _get_item_impl(self, index: int) -> MulticlassClsDataEntity | None:
         item = self.dm_subset[index]
         img = item.media_as(Image)
-        img_data, img_shape = self._get_img_data_and_shape(img)
+        roi = item.attributes.get("roi", None)
+        img_data, img_shape, _ = self._get_img_data_and_shape(img, roi)
+        if roi:
+            # extract labels from ROI
+            labels_ids = [
+                label["label"]["_id"] for label in roi["labels"] if label["label"]["domain"] == "CLASSIFICATION"
+            ]
+            label_anns = [self.label_info.label_names.index(label_id) for label_id in labels_ids]
+        else:
+            # extract labels from annotations
+            label_anns = [ann.label for ann in item.annotations if isinstance(ann, Label)]

-        label_anns = []
-        for ann in item.annotations:
-            if isinstance(ann, Label):
-                label_anns.append(ann)
-            else:
-                # If the annotation is not Label, it should be converted to Label.
-                # For Chained Task: Detection (Bbox) -> Classification (Label)
-                label = Label(label=ann.label)
-                if label not in label_anns:
-                    label_anns.append(label)
         if len(label_anns) > 1:
             msg = f"Multi-class Classification can't use the multi-label, currently len(labels) = {len(label_anns)}"
             raise ValueError(msg)
@@ -56,7 +56,7 @@ def _get_item_impl(self, index: int) -> MulticlassClsDataEntity | None:
                 ori_shape=img_shape,
                 image_color_channel=self.image_color_channel,
             ),
-            labels=torch.as_tensor([ann.label for ann in label_anns]),
+            labels=torch.as_tensor(label_anns),
         )

         return self._apply_transforms(entity)
@@ -78,7 +78,7 @@ def _get_item_impl(self, index: int) -> MultilabelClsDataEntity | None:
         item = self.dm_subset[index]
         img = item.media_as(Image)
         ignored_labels: list[int] = []  # This should be assigned form item
-        img_data, img_shape = self._get_img_data_and_shape(img)
+        img_data, img_shape, _ = self._get_img_data_and_shape(img)

         label_anns = []
         for ann in item.annotations:
@@ -195,7 +195,7 @@ def _get_item_impl(self, index: int) -> HlabelClsDataEntity | None:
         item = self.dm_subset[index]
         img = item.media_as(Image)
         ignored_labels: list[int] = []  # This should be assigned form item
-        img_data, img_shape = self._get_img_data_and_shape(img)
+        img_data, img_shape, _ = self._get_img_data_and_shape(img)

         label_anns = []
         for ann in item.annotations:
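
For the multi-class classification path, the labels now come from the ROI itself rather than from the item annotations. A small sketch of that lookup, with a hypothetical label list standing in for self.label_info.label_names (the ROI label ids are assumed to match entries of that list):

label_names = ["person", "car"]  # hypothetical stand-in for self.label_info.label_names
roi = {
    "labels": [
        {"label": {"_id": "car", "domain": "CLASSIFICATION"}},
        {"label": {"_id": "person", "domain": "DETECTION"}},  # skipped: not a classification label
    ],
}

labels_ids = [label["label"]["_id"] for label in roi["labels"] if label["label"]["domain"] == "CLASSIFICATION"]
label_anns = [label_names.index(label_id) for label_id in labels_ids]
assert label_anns == [1]  # index of "car"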

src/otx/core/data/dataset/detection.py (+1 -1)

@@ -26,7 +26,7 @@ def _get_item_impl(self, index: int) -> DetDataEntity | None:
         item = self.dm_subset[index]
         img = item.media_as(Image)
         ignored_labels: list[int] = []  # This should be assigned form item
-        img_data, img_shape = self._get_img_data_and_shape(img)
+        img_data, img_shape, _ = self._get_img_data_and_shape(img)

         bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)]

src/otx/core/data/dataset/instance_segmentation.py (+1 -1)

@@ -40,7 +40,7 @@ def _get_item_impl(self, index: int) -> InstanceSegDataEntity | None:
         item = self.dm_subset[index]
         img = item.media_as(Image)
         ignored_labels: list[int] = []
-        img_data, img_shape = self._get_img_data_and_shape(img)
+        img_data, img_shape, _ = self._get_img_data_and_shape(img)

         gt_bboxes, gt_labels, gt_masks, gt_polygons = [], [], [], []

src/otx/core/data/dataset/keypoint_detection.py (+1 -1)

@@ -86,7 +86,7 @@ def _get_item_impl(self, index: int) -> KeypointDetDataEntity | None:
         item = self.dm_subset[index]
         img = item.media_as(Image)
         ignored_labels: list[int] = []  # This should be assigned form item
-        img_data, img_shape = self._get_img_data_and_shape(img)
+        img_data, img_shape, _ = self._get_img_data_and_shape(img)

         bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)]
         bboxes = (

src/otx/core/data/dataset/segmentation.py (+7 -2)

@@ -202,9 +202,14 @@ def _get_item_impl(self, index: int) -> SegDataEntity | None:
         item = self.dm_subset[index]
         img = item.media_as(Image)
         ignored_labels: list[int] = []
-        img_data, img_shape = self._get_img_data_and_shape(img)
+        roi = item.attributes.get("roi", None)
+        img_data, img_shape, roi_meta = self._get_img_data_and_shape(img, roi)
         if item.annotations:
-            extracted_mask = _extract_class_mask(item=item, img_shape=img_shape, ignore_index=self.ignore_index)
+            ori_shape = roi_meta["orig_image_shape"] if roi_meta else img_shape
+            extracted_mask = _extract_class_mask(item=item, img_shape=ori_shape, ignore_index=self.ignore_index)
+            if roi_meta:
+                extracted_mask = extracted_mask[roi_meta["y1"] : roi_meta["y2"], roi_meta["x1"] : roi_meta["x2"]]
+
             masks = tv_tensors.Mask(extracted_mask[None])
         else:
             # semi-supervised learning, unlabeled dataset
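
On the segmentation side, the class mask is still rasterized at the original image size and only then cropped with the same ROI box, which keeps image and mask aligned. A self-contained NumPy sketch of that cropping step (the mask values are synthetic):

import numpy as np

roi_meta = {"x1": 1, "y1": 1, "x2": 9, "y2": 9, "orig_image_shape": (10, 10)}

# full-size class mask, as _extract_class_mask would produce for the original image
extracted_mask = np.zeros(roi_meta["orig_image_shape"], dtype=np.uint8)
extracted_mask[2:5, 2:5] = 3  # pretend class 3 covers a small region

cropped = extracted_mask[roi_meta["y1"] : roi_meta["y2"], roi_meta["x1"] : roi_meta["x2"]]
assert cropped.shape == (8, 8)  # same spatial size as the ROI-cropped image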

src/otx/core/data/dataset/tile.py (+2 -2)

@@ -370,7 +370,7 @@ def _get_item_impl(self, index: int) -> TileDetDataEntity: # type: ignore[overr
         """
         item = self.dm_subset[index]
         img = item.media_as(Image)
-        img_data, img_shape = self._get_img_data_and_shape(img)
+        img_data, img_shape, _ = self._get_img_data_and_shape(img)

         bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)]

@@ -461,7 +461,7 @@ def _get_item_impl(self, index: int) -> TileInstSegDataEntity: # type: ignore[o
         """
         item = self.dm_subset[index]
         img = item.media_as(Image)
-        img_data, img_shape = self._get_img_data_and_shape(img)
+        img_data, img_shape, _ = self._get_img_data_and_shape(img)

         gt_bboxes, gt_labels, gt_masks, gt_polygons = [], [], [], []

src/otx/core/data/dataset/visual_prompting.py (+2 -2)

@@ -79,7 +79,7 @@ def __init__(
     def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None:
         item = self.dm_subset[index]
         img = item.media_as(dmImage)
-        img_data, img_shape = self._get_img_data_and_shape(img)
+        img_data, img_shape, _ = self._get_img_data_and_shape(img)

         gt_bboxes, gt_points = [], []
         gt_masks = defaultdict(list)
@@ -214,7 +214,7 @@ def __init__(
     def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None:
         item = self.dm_subset[index]
         img = item.media_as(dmImage)
-        img_data, img_shape = self._get_img_data_and_shape(img)
+        img_data, img_shape, _ = self._get_img_data_and_shape(img)

         gt_prompts: list[tvBoundingBoxes | Points] = []
         gt_masks: list[tvMask] = []
New file: unit tests for OTXDataset (+104)

@@ -0,0 +1,104 @@
+from unittest import mock
+
+import numpy as np
+import pytest
+from datumaro.components.media import Image
+from otx.core.data.dataset.base import OTXDataset
+
+
+class TestOTXDataset:
+    @pytest.fixture()
+    def mock_image(self) -> Image:
+        img = mock.Mock(spec=Image)
+        img.data = np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8)
+        img.path = "test_path"
+        return img
+
+    @pytest.fixture()
+    def mock_mem_cache_handler(self):
+        mem_cache_handler = mock.MagicMock()
+        mem_cache_handler.frozen = False
+        return mem_cache_handler
+
+    @pytest.fixture()
+    def otx_dataset(self, mock_mem_cache_handler):
+        class MockOTXDataset(OTXDataset):
+            def _get_item_impl(self, idx: int) -> None:
+                return None
+
+            @property
+            def collate_fn(self) -> None:
+                return None
+
+        dm_subset = mock.Mock()
+        dm_subset.categories = mock.MagicMock()
+        dm_subset.categories.return_value = None
+
+        return MockOTXDataset(
+            dm_subset=dm_subset,
+            transforms=None,
+            mem_cache_handler=mock_mem_cache_handler,
+            mem_cache_img_max_size=None,
+        )
+
+    def test_get_img_data_and_shape_no_cache(self, otx_dataset, mock_image, mock_mem_cache_handler):
+        mock_mem_cache_handler.get.return_value = (None, None)
+        img_data, img_shape, roi_meta = otx_dataset._get_img_data_and_shape(mock_image)
+        assert img_data.shape == (10, 10, 3)
+        assert img_shape == (10, 10)
+        assert roi_meta is None
+
+    def test_get_img_data_and_shape_with_cache(self, otx_dataset, mock_image, mock_mem_cache_handler):
+        mock_mem_cache_handler.get.return_value = (np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8), None)
+        img_data, img_shape, roi_meta = otx_dataset._get_img_data_and_shape(mock_image)
+        assert img_data.shape == (10, 10, 3)
+        assert img_shape == (10, 10)
+        assert roi_meta is None
+
+    def test_get_img_data_and_shape_with_roi(self, otx_dataset, mock_image, mock_mem_cache_handler):
+        roi = {"shape": {"x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9}}
+        mock_mem_cache_handler.get.return_value = (None, None)
+        img_data, img_shape, roi_meta = otx_dataset._get_img_data_and_shape(mock_image, roi)
+        assert img_data.shape == (8, 8, 3)
+        assert img_shape == (8, 8)
+        assert roi_meta == {"x1": 1, "y1": 1, "x2": 9, "y2": 9, "orig_image_shape": (10, 10)}
+
+    def test_cache_img_no_resize(self, otx_dataset):
+        img_data = np.random.randint(0, 256, (50, 50, 3), dtype=np.uint8)
+        key = "test_key"
+
+        cached_img = otx_dataset._cache_img(key, img_data)
+
+        assert np.array_equal(cached_img, img_data)
+        otx_dataset.mem_cache_handler.put.assert_called_once_with(key=key, data=img_data, meta=None)
+
+    def test_cache_img_with_resize(self, otx_dataset, mock_mem_cache_handler):
+        otx_dataset.mem_cache_img_max_size = (100, 100)
+        img_data = np.random.randint(0, 256, (200, 200, 3), dtype=np.uint8)
+        key = "test_key"
+
+        cached_img = otx_dataset._cache_img(key, img_data)
+
+        assert cached_img.shape == (100, 100, 3)
+        mock_mem_cache_handler.put.assert_called_once()
+        assert mock_mem_cache_handler.put.call_args[1]["data"].shape == (100, 100, 3)
+
+    def test_cache_img_no_max_size(self, otx_dataset, mock_mem_cache_handler):
+        otx_dataset.mem_cache_img_max_size = None
+        img_data = np.random.randint(0, 256, (200, 200, 3), dtype=np.uint8)
+        key = "test_key"
+
+        cached_img = otx_dataset._cache_img(key, img_data)
+
+        assert np.array_equal(cached_img, img_data)
+        mock_mem_cache_handler.put.assert_called_once_with(key=key, data=img_data, meta=None)
+
+    def test_cache_img_frozen_handler(self, otx_dataset, mock_mem_cache_handler):
+        mock_mem_cache_handler.frozen = True
+        img_data = np.random.randint(0, 256, (200, 200, 3), dtype=np.uint8)
+        key = "test_key"
+
+        cached_img = otx_dataset._cache_img(key, img_data)
+
+        assert np.array_equal(cached_img, img_data)
+        mock_mem_cache_handler.put.assert_not_called()
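
One detail the tests above rely on: mem_cache_handler.get returns a (data, meta) pair and the ROI meta is stored as that meta, so a cache hit hands back the already-cropped image together with its ROI information. A toy dict-backed cache illustrates the round trip (only a sketch of the put/get contract used above, not the real MemCacheHandler):

import numpy as np

cache: dict[str, tuple[np.ndarray, dict | None]] = {}  # toy stand-in for the memory cache

def put(key: str, data: np.ndarray, meta: dict | None = None) -> None:
    cache[key] = (data, meta)

def get(key: str) -> tuple[np.ndarray | None, dict | None]:
    return cache.get(key, (None, None))

roi_meta = {"x1": 1, "y1": 1, "x2": 9, "y2": 9, "orig_image_shape": (10, 10)}
put("img.jpg", np.zeros((8, 8, 3), dtype=np.uint8), meta=roi_meta)

img_data, meta = get("img.jpg")  # later epochs hit the cache and get both pieces back
assert img_data.shape == (8, 8, 3)
assert meta == roi_meta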
