44from pathlib import Path
55from typing import Any
66
7+ import cv2
78import numpy as np
89
910from luxonis_ml .data import (
@@ -129,7 +130,9 @@ def set_seed(seed: int):
129130
130131
131132def create_loader (
132- storage_url : str , tempdir : Path , augmentation_config : list [Params ]
133+ storage_url : str ,
134+ tempdir : Path ,
135+ ** kwargs ,
133136) -> LuxonisLoader :
134137 with set_seed (42 ):
135138 dataset = LuxonisParser (
@@ -145,7 +148,7 @@ def create_loader(
145148 width = 512 ,
146149 view = "train" ,
147150 seed = 42 ,
148- augmentation_config = augmentation_config ,
151+ ** kwargs ,
149152 )
150153
151154
@@ -462,10 +465,14 @@ def generator() -> DatasetIterator:
462465
463466
464467def test_dataset_reproducibility (storage_url : str , tempdir : Path ):
465- loader1 = create_loader (storage_url , tempdir , AUGMENTATIONS_CONFIG )
468+ loader1 = create_loader (
469+ storage_url , tempdir , augmentation_config = AUGMENTATIONS_CONFIG
470+ )
466471 run1 = [ann for _ , ann in loader1 ]
467472
468- loader2 = create_loader (storage_url , tempdir , AUGMENTATIONS_CONFIG )
473+ loader2 = create_loader (
474+ storage_url , tempdir , augmentation_config = AUGMENTATIONS_CONFIG
475+ )
469476 run2 = [ann for _ , ann in loader2 ]
470477
471478 assert all (
@@ -533,7 +540,9 @@ def round_nested_list(
533540 ),
534541 }
535542
536- loader_aug = create_loader (storage_url , tempdir , AUGMENTATIONS_CONFIG )
543+ loader_aug = create_loader (
544+ storage_url , tempdir , augmentation_config = AUGMENTATIONS_CONFIG
545+ )
537546 new_aug_annotations = [convert_annotation (ann ) for _ , ann in loader_aug ]
538547
539548 original_aug_annotations = load_annotations (
@@ -550,3 +559,22 @@ def round_nested_list(
550559 new_mask = rle_to_mask (new_ann ["segmentation" ], 512 , 512 )
551560 diff = np .count_nonzero (orig_mask != new_mask )
552561 assert diff <= 50
562+
563+
564+ def test_colorspace (storage_url : str , tempdir : Path ):
565+ loader = create_loader (storage_url , tempdir )
566+ rgb_img , _ = next (iter (loader ))
567+ assert len (rgb_img .shape ) == 3
568+ assert rgb_img .shape [2 ] == 3
569+ loader = create_loader (storage_url , tempdir , color_space = "BGR" )
570+ bgr_img , _ = next (iter (loader ))
571+ assert len (bgr_img .shape ) == 3
572+ assert bgr_img .shape [2 ] == 3
573+ assert np .array_equal (rgb_img , bgr_img [:, :, ::- 1 ])
574+ loader = create_loader (storage_url , tempdir , color_space = "GRAY" )
575+ gray_img , _ = next (iter (loader ))
576+ assert len (gray_img .shape ) == 3
577+ assert gray_img .shape [2 ] == 1
578+ assert np .array_equal (
579+ cv2 .cvtColor (rgb_img , cv2 .COLOR_RGB2GRAY ), gray_img [:, :, 0 ]
580+ )
0 commit comments