Skip to content

Commit 62e6cb9

Browse files
committed
Fix loading logic for datasets with multiple tasks
1 parent 3716ba7 commit 62e6cb9

File tree

2 files changed

+66
-4
lines changed

2 files changed

+66
-4
lines changed

luxonis_ml/data/loaders/luxonis_loader.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -323,11 +323,13 @@ def _load_image_with_annotations(
323323

324324
labels[task] = (array, anns[0]._label_type)
325325

326-
if not labels:
327-
for task in self.classes_by_task.keys():
326+
missing_tasks = set(self.classes_by_task) - set(labels_by_task)
327+
if missing_tasks:
328+
for task in missing_tasks:
329+
class_mapping_len = len(self.class_mappings[task])
328330
if task == LabelType.SEGMENTATION:
329331
empty_array = np.zeros(
330-
(len(self.class_mappings[task]), height, width),
332+
(class_mapping_len, height, width),
331333
dtype=np.uint8,
332334
)
333335
elif task == LabelType.BOUNDINGBOX:
@@ -336,7 +338,7 @@ def _load_image_with_annotations(
336338
empty_array = np.zeros((0, 3), dtype=np.float32)
337339
elif task == LabelType.CLASSIFICATION:
338340
empty_array = np.zeros(
339-
(0, len(self.class_mappings[task])), dtype=np.float32
341+
(0, class_mapping_len), dtype=np.float32
340342
)
341343
labels[task] = (empty_array, task)
342344

tests/test_data/test_dataset.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,3 +595,63 @@ def generator():
595595
loader = LuxonisLoader(dataset, augmentations=augments)
596596
for _, labels in loader:
597597
assert labels == {}
598+
599+
600+
def test_partial_labels():
601+
dataset = LuxonisDataset("__partial_labels", delete_existing=True)
602+
603+
def generator():
604+
for i in range(8):
605+
img = make_image(i)
606+
if i < 2:
607+
yield {
608+
"file": img,
609+
}
610+
elif i < 4:
611+
yield {
612+
"file": img,
613+
"annotation": {
614+
"type": "classification",
615+
"class": "dog",
616+
},
617+
}
618+
elif i < 6:
619+
yield {
620+
"file": img,
621+
"annotation": {
622+
"type": "boundingbox",
623+
"class": "dog",
624+
"x": 0.1,
625+
"y": 0.1,
626+
"w": 0.1,
627+
"h": 0.1,
628+
},
629+
}
630+
yield {
631+
"file": img,
632+
"annotation": {
633+
"type": "keypoints",
634+
"class": "dog",
635+
"keypoints": [[0.1, 0.1, 0], [0.2, 0.2, 1]],
636+
},
637+
}
638+
elif i < 8:
639+
yield {
640+
"file": img,
641+
"annotation": {
642+
"type": "mask",
643+
"class": "dog",
644+
"mask": np.random.rand(512, 512) > 0.5,
645+
},
646+
}
647+
648+
dataset.add(generator())
649+
dataset.make_splits([1, 0, 0])
650+
651+
augments = Augmentations([512, 512], [{"name": "Rotate", "params": {}}])
652+
loader = LuxonisLoader(dataset, augmentations=augments, view="train")
653+
for _, labels in loader:
654+
assert labels.get("boundingbox") is not None
655+
assert labels.get("classification") is not None
656+
assert labels.get("segmentation") is not None
657+
assert labels.get("keypoints") is not None

0 commit comments

Comments
 (0)