Skip to content

Commit 671ef4f

Browse files
Feat/remove flip constraint (#370)
Co-authored-by: klemen1999 <[email protected]>
1 parent 8a168f0 commit 671ef4f

File tree

3 files changed

+24
-39
lines changed

3 files changed

+24
-39
lines changed

luxonis_ml/data/augmentations/albumentations_engine.py

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -241,36 +241,23 @@ def apply_to_array(
241241
...
242242
"""
243243

244-
def _should_skip_augmentation(
244+
def _check_augmentation_warnings(
245245
self, config_item: dict[str, Any], available_target_types: set
246-
) -> bool:
247-
skip_rules = {
248-
"keypoints": [
249-
"HorizontalFlip",
250-
"VerticalFlip",
251-
"Flip",
252-
],
253-
}
246+
) -> None:
254247
augmentation_name = config_item["name"]
255-
skipped_for = [
256-
target_type
257-
for target_type, skip_list in skip_rules.items()
258-
if target_type in available_target_types
259-
and augmentation_name in skip_list
260-
]
261-
if skipped_for:
262-
extra_msg = ""
263-
if "keypoints" in skipped_for and augmentation_name in [
264-
"HorizontalFlip",
265-
"VerticalFlip",
266-
"Flip",
267-
]:
268-
extra_msg = " For keypoints, please use 'HorizontalSymetricKeypointsFlip' or 'VerticalSymetricKeypointsFlip'."
248+
249+
if "keypoints" in available_target_types and augmentation_name in [
250+
"HorizontalFlip",
251+
"VerticalFlip",
252+
"Transpose",
253+
]:
269254
logger.warning(
270-
f"Skipping augmentation '{augmentation_name}' due to known issues for {skipped_for} target types. {extra_msg}"
255+
f"Using '{augmentation_name}' with keypoints."
256+
"If your dataset contains symmetric keypoints (e.g. left/right arms),"
257+
"you should use our custom HorizontalSymetricKeypointsFlip,"
258+
"VerticalSymetricKeypointsFlip, or TransposeSymmetricKeypoints"
259+
"to ensure keypoints are correctly reordered."
271260
)
272-
return True
273-
return False
274261

275262
@override
276263
def __init__(
@@ -372,13 +359,10 @@ def __init__(
372359
available_target_types = set(self.targets.values())
373360

374361
for config_item in config:
375-
if self._should_skip_augmentation(
362+
self._check_augmentation_warnings(
376363
config_item, available_target_types
377-
):
378-
continue
379-
364+
)
380365
cfg = AlbumentationConfigItem(**config_item) # type: ignore
381-
382366
transform = self.create_transformation(cfg)
383367

384368
if cfg.use_for_resizing:
@@ -431,9 +415,11 @@ def get_params(is_custom: bool = False) -> dict[str, Any]:
431415
"keypoint_params": A.KeypointParams(
432416
format="xy", remove_invisible=False
433417
),
434-
"additional_targets": self.targets
435-
if is_custom
436-
else targets_without_instance_mask,
418+
"additional_targets": (
419+
self.targets
420+
if is_custom
421+
else targets_without_instance_mask
422+
),
437423
"seed": seed,
438424
}
439425

tests/conftest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ def augmentation_config() -> list[Params]:
125125
{"name": "MixUp", "params": {"p": 1.0}},
126126
{"name": "Defocus", "params": {"p": 1.0}},
127127
{"name": "Sharpen", "params": {"p": 1.0}},
128-
{"name": "Flip", "params": {"p": 1.0}},
129128
{"name": "RandomRotate90", "params": {"p": 1.0}},
130129
]
131130

tests/test_data/test_augmentations/test_special.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,6 @@ def test_skip_augmentations():
3939
{
4040
"name": "Perspective",
4141
},
42-
{
43-
"name": "Flip",
44-
},
4542
{
4643
"name": "HorizontalFlip",
4744
},
@@ -87,10 +84,13 @@ def test_skip_augmentations():
8784
batched_transform_names = [
8885
t.__class__.__name__ for t in augmentations.batch_transform.transforms
8986
]
90-
9187
assert spatial_transform_names == [
9288
"Perspective",
9389
"Lambda",
90+
"HorizontalFlip",
91+
"Lambda",
92+
"VerticalFlip",
93+
"Lambda",
9494
"Rotate",
9595
"Lambda",
9696
]

0 commit comments

Comments
 (0)