Skip to content

Commit df68b55

Browse files
authored
Feature: Grayscale Images Support (#320)
1 parent 0c6b120 commit df68b55

File tree

2 files changed

+42
-12
lines changed

2 files changed

+42
-12
lines changed

luxonis_ml/data/loaders/luxonis_loader.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(
4646
width: int | None = None,
4747
keep_aspect_ratio: bool = True,
4848
exclude_empty_annotations: bool = False,
49-
color_space: Literal["RGB", "BGR"] = "RGB",
49+
color_space: Literal["RGB", "BGR", "GRAY"] = "RGB",
5050
seed: int | None = None,
5151
*,
5252
keep_categorical_as_strings: bool = False,
@@ -90,7 +90,7 @@ def __init__(
9090
@type keep_aspect_ratio: bool
9191
@param keep_aspect_ratio: Whether to keep the aspect ratio of the
9292
images. Defaults to C{True}.
93-
@type color_space: Literal["RGB", "BGR"]
93+
@type color_space: Literal["RGB", "BGR", "GRAY"]
9494
@param color_space: The color space of the output images. Defaults
9595
to C{"RGB"}.
9696
@type seed: Optional[int]
@@ -117,7 +117,7 @@ def __init__(
117117
"""
118118

119119
self.exclude_empty_annotations = exclude_empty_annotations
120-
self.color_space: Literal["RGB", "BGR"] = color_space
120+
self.color_space: Literal["RGB", "BGR", "GRAY"] = color_space
121121
self.height = height
122122
self.width = width
123123

@@ -247,13 +247,15 @@ def __getitem__(self, idx: int) -> LoaderOutput:
247247
else:
248248
img, labels = self._load_with_augmentations(idx)
249249

250+
if not self.exclude_empty_annotations:
251+
img, labels = self._add_empty_annotations(img, labels)
252+
250253
if self.color_space == "BGR":
251254
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
255+
elif self.color_space == "GRAY":
256+
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)[..., np.newaxis]
252257

253-
if self.exclude_empty_annotations:
254-
return img, labels
255-
256-
return self._add_empty_annotations(img, labels)
258+
return img, labels
257259

258260
def _add_empty_annotations(
259261
self, img: np.ndarray, labels: Labels

tests/test_data/test_loader.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pathlib import Path
55
from typing import Any
66

7+
import cv2
78
import numpy as np
89

910
from luxonis_ml.data import (
@@ -129,7 +130,9 @@ def set_seed(seed: int):
129130

130131

131132
def create_loader(
132-
storage_url: str, tempdir: Path, augmentation_config: list[Params]
133+
storage_url: str,
134+
tempdir: Path,
135+
**kwargs,
133136
) -> LuxonisLoader:
134137
with set_seed(42):
135138
dataset = LuxonisParser(
@@ -145,7 +148,7 @@ def create_loader(
145148
width=512,
146149
view="train",
147150
seed=42,
148-
augmentation_config=augmentation_config,
151+
**kwargs,
149152
)
150153

151154

@@ -462,10 +465,14 @@ def generator() -> DatasetIterator:
462465

463466

464467
def test_dataset_reproducibility(storage_url: str, tempdir: Path):
465-
loader1 = create_loader(storage_url, tempdir, AUGMENTATIONS_CONFIG)
468+
loader1 = create_loader(
469+
storage_url, tempdir, augmentation_config=AUGMENTATIONS_CONFIG
470+
)
466471
run1 = [ann for _, ann in loader1]
467472

468-
loader2 = create_loader(storage_url, tempdir, AUGMENTATIONS_CONFIG)
473+
loader2 = create_loader(
474+
storage_url, tempdir, augmentation_config=AUGMENTATIONS_CONFIG
475+
)
469476
run2 = [ann for _, ann in loader2]
470477

471478
assert all(
@@ -533,7 +540,9 @@ def round_nested_list(
533540
),
534541
}
535542

536-
loader_aug = create_loader(storage_url, tempdir, AUGMENTATIONS_CONFIG)
543+
loader_aug = create_loader(
544+
storage_url, tempdir, augmentation_config=AUGMENTATIONS_CONFIG
545+
)
537546
new_aug_annotations = [convert_annotation(ann) for _, ann in loader_aug]
538547

539548
original_aug_annotations = load_annotations(
@@ -550,3 +559,22 @@ def round_nested_list(
550559
new_mask = rle_to_mask(new_ann["segmentation"], 512, 512)
551560
diff = np.count_nonzero(orig_mask != new_mask)
552561
assert diff <= 50
562+
563+
564+
def test_colorspace(storage_url: str, tempdir: Path):
565+
loader = create_loader(storage_url, tempdir)
566+
rgb_img, _ = next(iter(loader))
567+
assert len(rgb_img.shape) == 3
568+
assert rgb_img.shape[2] == 3
569+
loader = create_loader(storage_url, tempdir, color_space="BGR")
570+
bgr_img, _ = next(iter(loader))
571+
assert len(bgr_img.shape) == 3
572+
assert bgr_img.shape[2] == 3
573+
assert np.array_equal(rgb_img, bgr_img[:, :, ::-1])
574+
loader = create_loader(storage_url, tempdir, color_space="GRAY")
575+
gray_img, _ = next(iter(loader))
576+
assert len(gray_img.shape) == 3
577+
assert gray_img.shape[2] == 1
578+
assert np.array_equal(
579+
cv2.cvtColor(rgb_img, cv2.COLOR_RGB2GRAY), gray_img[:, :, 0]
580+
)

0 commit comments

Comments
 (0)