Skip to content

Commit 539228b

Browse files
authored
Ensure Keypoint Visibility Is Updated After Spatial Transforms (#346)
1 parent be79c8e commit 539228b

File tree

5 files changed

+58
-8
lines changed

5 files changed

+58
-8
lines changed

luxonis_ml/data/__main__.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
import random
21
import shutil
32
from collections.abc import Iterator
43
from pathlib import Path
@@ -334,9 +333,6 @@ def inspect(
334333
):
335334
"""Inspects images and annotations in a dataset."""
336335
check_exists(name, bucket_storage)
337-
if deterministic:
338-
np.random.seed(42)
339-
random.seed(42)
340336

341337
view = view or ["train"]
342338
dataset = LuxonisDataset(name, bucket_storage=bucket_storage)
@@ -349,7 +345,12 @@ def inspect(
349345
if aug_config is not None:
350346
h, w, _ = loader[0][0].shape
351347
loader.augmentations = loader._init_augmentations(
352-
"albumentations", aug_config, h, w, not ignore_aspect_ratio
348+
augmentation_engine="albumentations",
349+
augmentation_config=aug_config,
350+
height=h,
351+
width=w,
352+
keep_aspect_ratio=not ignore_aspect_ratio,
353+
seed=42 if deterministic else None,
353354
)
354355

355356
if len(dataset) == 0:

luxonis_ml/data/augmentations/albumentations_engine.py

Lines changed: 36 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66

77
import albumentations as A
88
import numpy as np
9+
from albumentations.core.composition import TransformsSeqType
910
from loguru import logger
1011
from typing_extensions import override
1112

@@ -396,6 +397,20 @@ def __init__(
396397
f"Only subclasses of `A.BasicTransform` are allowed. "
397398
)
398399

400+
wrapped_spatial_ops: TransformsSeqType = []
401+
if "keypoints" in targets.values():
402+
for op in spatial_transforms:
403+
wrapped_spatial_ops.append(op)
404+
wrapped_spatial_ops.append(
405+
A.Lambda(
406+
image=lambda img, **kw: img,
407+
keypoints=self._mark_invisible_keypoints,
408+
p=1.0,
409+
)
410+
)
411+
else:
412+
wrapped_spatial_ops = spatial_transforms
413+
399414
if resize_transform is None:
400415
if keep_aspect_ratio:
401416
resize_transform = LetterboxResize(height=height, width=width)
@@ -424,7 +439,7 @@ def get_params(is_custom: bool = False) -> dict[str, Any]:
424439
batch_transforms, **get_params(is_custom=True)
425440
)
426441
self.spatial_transform = wrap_transform(
427-
A.Compose(spatial_transforms, **get_params())
442+
A.Compose(wrapped_spatial_ops, **get_params())
428443
)
429444
self.pixel_transform = wrap_transform(
430445
A.Compose(pixel_transforms), is_pixel=True
@@ -625,6 +640,26 @@ def postprocess(
625640

626641
return out_image, out_labels
627642

643+
@staticmethod
644+
def _mark_invisible_keypoints(
645+
keypoints: np.ndarray, **kwargs
646+
) -> np.ndarray:
647+
"""
648+
keypoints: np.ndarray of shape (N,6) columns = [x, y, z, a, s, v]
649+
Zeroes out the visibility (last) column if (x,y) is out of image bounds.
650+
"""
651+
shape = kwargs.get("shape")
652+
if shape is None:
653+
raise ValueError(
654+
"Shape must be provided in kwargs to mark invisible keypoints."
655+
)
656+
h, w = shape[:2]
657+
kps = keypoints.copy()
658+
xs, ys = kps[:, 0], kps[:, 1]
659+
oob = (xs < 0) | (ys < 0) | (xs >= w) | (ys >= h)
660+
kps[oob, -1] = 0.0
661+
return kps
662+
628663
@staticmethod
629664
def create_transformation(
630665
config: AlbumentationConfigItem,

tests/test_data/test_augmentations/test_special.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -77,5 +77,10 @@ def test_skip_augmentations():
7777
t.__class__.__name__ for t in augmentations.batch_transform.transforms
7878
]
7979

80-
assert spatial_transform_names == ["Perspective", "Rotate"]
80+
assert spatial_transform_names == [
81+
"Perspective",
82+
"Lambda",
83+
"Rotate",
84+
"Lambda",
85+
]
8186
assert batched_transform_names == ["Mosaic4"]

tests/test_data/test_export.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,9 @@ def test_dir_parser(
4141
.to_dict(as_series=False)
4242
)
4343
anns = {k: sorted(v) for k, v in anns.items()}
44+
splits = dataset.get_splits()
45+
assert splits is not None
46+
splits = {split: sorted(files) for split, files in splits.items()}
4447

4548
zip_result = dataset.export(tempdir / "exported")
4649
zip_path = zip_result[0] if isinstance(zip_result, list) else zip_result
@@ -67,6 +70,12 @@ def test_dir_parser(
6770
.to_dict(as_series=False)
6871
)
6972
imported_anns = {k: sorted(v) for k, v in imported_anns.items()}
73+
imported_splits = exported_dataset.get_splits()
74+
assert imported_splits is not None
75+
imported_splits = {
76+
split: sorted(files) for split, files in imported_splits.items()
77+
}
78+
assert imported_splits == splits
7079
del imported_metadata["tasks"]
7180
del imported_metadata["skeletons"]
7281
assert imported_metadata == metadata

tests/test_data/test_loader.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -545,7 +545,7 @@ def round_nested_list(
545545
new_aug_annotations = [convert_annotation(ann) for _, ann in loader_aug]
546546

547547
original_aug_annotations = load_annotations(
548-
"test_augmentation_reproducibility_labels.json"
548+
"test_augmentation_reproducibility_annotations.json"
549549
)
550550

551551
for orig_ann, new_ann in zip(

0 commit comments

Comments
 (0)