@@ -241,36 +241,23 @@ def apply_to_array(
241241 ...
242242 """
243243
244- def _should_skip_augmentation (
244+ def _check_augmentation_warnings (
245245 self , config_item : dict [str , Any ], available_target_types : set
246- ) -> bool :
247- skip_rules = {
248- "keypoints" : [
249- "HorizontalFlip" ,
250- "VerticalFlip" ,
251- "Flip" ,
252- ],
253- }
246+ ) -> None :
254247 augmentation_name = config_item ["name" ]
255- skipped_for = [
256- target_type
257- for target_type , skip_list in skip_rules .items ()
258- if target_type in available_target_types
259- and augmentation_name in skip_list
260- ]
261- if skipped_for :
262- extra_msg = ""
263- if "keypoints" in skipped_for and augmentation_name in [
264- "HorizontalFlip" ,
265- "VerticalFlip" ,
266- "Flip" ,
267- ]:
268- extra_msg = " For keypoints, please use 'HorizontalSymetricKeypointsFlip' or 'VerticalSymetricKeypointsFlip'."
248+
249+ if "keypoints" in available_target_types and augmentation_name in [
250+ "HorizontalFlip" ,
251+ "VerticalFlip" ,
252+ "Transpose" ,
253+ ]:
269254 logger .warning (
270- f"Skipping augmentation '{ augmentation_name } ' due to known issues for { skipped_for } target types. { extra_msg } "
255+ f"Using '{ augmentation_name } ' with keypoints."
256+ "If your dataset contains symmetric keypoints (e.g. left/right arms),"
257+ "you should use our custom HorizontalSymetricKeypointsFlip,"
258+ "VerticalSymetricKeypointsFlip, or TransposeSymmetricKeypoints"
259+ "to ensure keypoints are correctly reordered."
271260 )
272- return True
273- return False
274261
275262 @override
276263 def __init__ (
@@ -372,13 +359,10 @@ def __init__(
372359 available_target_types = set (self .targets .values ())
373360
374361 for config_item in config :
375- if self ._should_skip_augmentation (
362+ self ._check_augmentation_warnings (
376363 config_item , available_target_types
377- ):
378- continue
379-
364+ )
380365 cfg = AlbumentationConfigItem (** config_item ) # type: ignore
381-
382366 transform = self .create_transformation (cfg )
383367
384368 if cfg .use_for_resizing :
@@ -431,9 +415,11 @@ def get_params(is_custom: bool = False) -> dict[str, Any]:
431415 "keypoint_params" : A .KeypointParams (
432416 format = "xy" , remove_invisible = False
433417 ),
434- "additional_targets" : self .targets
435- if is_custom
436- else targets_without_instance_mask ,
418+ "additional_targets" : (
419+ self .targets
420+ if is_custom
421+ else targets_without_instance_mask
422+ ),
437423 "seed" : seed ,
438424 }
439425
0 commit comments