|
6 | 6 |
|
7 | 7 | import albumentations as A |
8 | 8 | import numpy as np |
| 9 | +from albumentations.core.composition import TransformsSeqType |
9 | 10 | from loguru import logger |
10 | 11 | from typing_extensions import override |
11 | 12 |
|
@@ -396,6 +397,20 @@ def __init__( |
396 | 397 | f"Only subclasses of `A.BasicTransform` are allowed. " |
397 | 398 | ) |
398 | 399 |
|
| 400 | + wrapped_spatial_ops: TransformsSeqType = [] |
| 401 | + if "keypoints" in targets.values(): |
| 402 | + for op in spatial_transforms: |
| 403 | + wrapped_spatial_ops.append(op) |
| 404 | + wrapped_spatial_ops.append( |
| 405 | + A.Lambda( |
| 406 | + image=lambda img, **kw: img, |
| 407 | + keypoints=self._mark_invisible_keypoints, |
| 408 | + p=1.0, |
| 409 | + ) |
| 410 | + ) |
| 411 | + else: |
| 412 | + wrapped_spatial_ops = spatial_transforms |
| 413 | + |
399 | 414 | if resize_transform is None: |
400 | 415 | if keep_aspect_ratio: |
401 | 416 | resize_transform = LetterboxResize(height=height, width=width) |
@@ -424,7 +439,7 @@ def get_params(is_custom: bool = False) -> dict[str, Any]: |
424 | 439 | batch_transforms, **get_params(is_custom=True) |
425 | 440 | ) |
426 | 441 | self.spatial_transform = wrap_transform( |
427 | | - A.Compose(spatial_transforms, **get_params()) |
| 442 | + A.Compose(wrapped_spatial_ops, **get_params()) |
428 | 443 | ) |
429 | 444 | self.pixel_transform = wrap_transform( |
430 | 445 | A.Compose(pixel_transforms), is_pixel=True |
@@ -625,6 +640,26 @@ def postprocess( |
625 | 640 |
|
626 | 641 | return out_image, out_labels |
627 | 642 |
|
| 643 | + @staticmethod |
| 644 | + def _mark_invisible_keypoints( |
| 645 | + keypoints: np.ndarray, **kwargs |
| 646 | + ) -> np.ndarray: |
| 647 | + """ |
| 648 | + keypoints: np.ndarray of shape (N,6) columns = [x, y, z, a, s, v] |
| 649 | + Zeroes out the visibility (last) column if (x,y) is out of image bounds. |
| 650 | + """ |
| 651 | + shape = kwargs.get("shape") |
| 652 | + if shape is None: |
| 653 | + raise ValueError( |
| 654 | + "Shape must be provided in kwargs to mark invisible keypoints." |
| 655 | + ) |
| 656 | + h, w = shape[:2] |
| 657 | + kps = keypoints.copy() |
| 658 | + xs, ys = kps[:, 0], kps[:, 1] |
| 659 | + oob = (xs < 0) | (ys < 0) | (xs >= w) | (ys >= h) |
| 660 | + kps[oob, -1] = 0.0 |
| 661 | + return kps |
| 662 | + |
628 | 663 | @staticmethod |
629 | 664 | def create_transformation( |
630 | 665 | config: AlbumentationConfigItem, |
|
0 commit comments