Skip to content

Commit f308a4c

Browse files
authored
Fix failing keypoints loader (#284)
1 parent ba75c07 commit f308a4c

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

luxonis_ml/data/augmentations/albumentations_engine.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,10 @@ def postprocess(
546546
if task_name not in bboxes_indices:
547547
if "bboxes" in self.targets.values():
548548
bbox_ordering = np.array([], dtype=int)
549+
elif target_type == "keypoints":
550+
bbox_ordering = np.arange(
551+
array.shape[0] // n_keypoints[target_name]
552+
)
549553
else:
550554
bbox_ordering = np.arange(array.shape[0])
551555
else:

tests/test_data/test_dataset.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,3 +700,32 @@ def generator() -> DatasetIterator:
700700
dataset = create_dataset(dataset_name, generator())
701701

702702
assert dataset.get_classes() == {"": {"person": 0}}
703+
704+
705+
@pytest.mark.dependency(name="test_dataset[BucketStorage.LOCAL]")
706+
def test_keypoints_solo(dataset_name: str, tempdir: Path):
707+
def generator() -> DatasetIterator:
708+
for i in range(4):
709+
img = create_image(i, tempdir)
710+
yield {
711+
"file": img,
712+
"annotation": {
713+
"class": "person",
714+
"keypoints": {"keypoints": [[0.1, 0.1, 0], [0.2, 0.2, 1]]},
715+
},
716+
}
717+
718+
augs = [
719+
{"name": "Normalize"},
720+
{"name": "Defocus", "params": {"p": 1}},
721+
{
722+
"name": "Mosaic4",
723+
"params": {"out_width": 512, "out_height": 512, "p": 1},
724+
},
725+
]
726+
dataset = create_dataset(dataset_name, generator())
727+
loader = LuxonisLoader(
728+
dataset, height=512, width=512, augmentation_config=augs
729+
)
730+
for _ in loader:
731+
pass

0 commit comments

Comments
 (0)