Skip to content
54 changes: 20 additions & 34 deletions luxonis_ml/data/augmentations/albumentations_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,36 +241,23 @@ def apply_to_array(
...
"""

def _should_skip_augmentation(
def _check_augmentation_warnings(
self, config_item: dict[str, Any], available_target_types: set
) -> bool:
skip_rules = {
"keypoints": [
"HorizontalFlip",
"VerticalFlip",
"Flip",
],
}
) -> None:
augmentation_name = config_item["name"]
skipped_for = [
target_type
for target_type, skip_list in skip_rules.items()
if target_type in available_target_types
and augmentation_name in skip_list
]
if skipped_for:
extra_msg = ""
if "keypoints" in skipped_for and augmentation_name in [
"HorizontalFlip",
"VerticalFlip",
"Flip",
]:
extra_msg = " For keypoints, please use 'HorizontalSymetricKeypointsFlip' or 'VerticalSymetricKeypointsFlip'."

if "keypoints" in available_target_types and augmentation_name in [
"HorizontalFlip",
"VerticalFlip",
"Transpose",
]:
logger.warning(
f"Skipping augmentation '{augmentation_name}' due to known issues for {skipped_for} target types. {extra_msg}"
f"Using '{augmentation_name}' with keypoints."
"If your dataset contains symmetric keypoints (e.g. left/right arms),"
"you should use our custom HorizontalSymetricKeypointsFlip,"
"VerticalSymetricKeypointsFlip, or TransposeSymmetricKeypoints"
"to ensure keypoints are correctly reordered."
)
return True
return False

@override
def __init__(
Expand Down Expand Up @@ -372,13 +359,10 @@ def __init__(
available_target_types = set(self.targets.values())

for config_item in config:
if self._should_skip_augmentation(
self._check_augmentation_warnings(
config_item, available_target_types
):
continue

)
cfg = AlbumentationConfigItem(**config_item) # type: ignore

transform = self.create_transformation(cfg)

if cfg.use_for_resizing:
Expand Down Expand Up @@ -431,9 +415,11 @@ def get_params(is_custom: bool = False) -> dict[str, Any]:
"keypoint_params": A.KeypointParams(
format="xy", remove_invisible=False
),
"additional_targets": self.targets
if is_custom
else targets_without_instance_mask,
"additional_targets": (
self.targets
if is_custom
else targets_without_instance_mask
),
"seed": seed,
}

Expand Down
1 change: 0 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ def augmentation_config() -> list[Params]:
{"name": "MixUp", "params": {"p": 1.0}},
{"name": "Defocus", "params": {"p": 1.0}},
{"name": "Sharpen", "params": {"p": 1.0}},
{"name": "Flip", "params": {"p": 1.0}},
{"name": "RandomRotate90", "params": {"p": 1.0}},
]

Expand Down
8 changes: 4 additions & 4 deletions tests/test_data/test_augmentations/test_special.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@ def test_skip_augmentations():
{
"name": "Perspective",
},
{
"name": "Flip",
},
{
"name": "HorizontalFlip",
},
Expand Down Expand Up @@ -87,10 +84,13 @@ def test_skip_augmentations():
batched_transform_names = [
t.__class__.__name__ for t in augmentations.batch_transform.transforms
]

assert spatial_transform_names == [
"Perspective",
"Lambda",
"HorizontalFlip",
"Lambda",
"VerticalFlip",
"Lambda",
"Rotate",
"Lambda",
]
Expand Down