Skip to content

Commit 05b36c3

Browse files
authored
fix kpts and bbox edge cases (#243)
1 parent 6b059e4 commit 05b36c3

File tree

3 files changed

+349
-20
lines changed

3 files changed

+349
-20
lines changed

luxonis_ml/data/augmentations/base_engine.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(
2424
config: Iterable[Params],
2525
keep_aspect_ratio: bool,
2626
is_validation_pipeline: bool,
27-
min_bbox_visibility: float = 0,
27+
min_bbox_visibility: float = 0.0,
2828
):
2929
"""Initialize augmentation pipeline from configuration.
3030
@@ -52,8 +52,7 @@ def __init__(
5252
@param is_validation_pipeline: Whether this is a validation
5353
pipeline (in which case some augmentations are skipped)
5454
@type min_bbox_visibility: float
55-
@param min_bbox_visibility: Minimum area of a bounding box to be
56-
considered visible.
55+
@param min_bbox_visibility: Minimum fraction of the original bounding box that must remain visible after augmentation.
5756
"""
5857
...
5958

luxonis_ml/data/augmentations/utils.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,24 @@ def postprocess_mask(mask: np.ndarray) -> np.ndarray:
3939

4040

4141
def postprocess_bboxes(
    bboxes: np.ndarray, area_threshold: float = 0.0004
) -> Tuple[np.ndarray, np.ndarray]:
    """Convert augmented bounding boxes to width/height form and drop
    degenerate ones.

    The last column of each input row is treated as the ordering index
    of the box. Of the remaining columns, 0-3 are read as
    C{x_min, y_min, x_max, y_max} and are turned into
    C{x_min, y_min, width, height}; column 4 is moved to the front of
    each output row (presumably the class label - confirm against the
    caller).

    The input array is left untouched; all arithmetic is done on a
    copy.

    @type bboxes: np.ndarray
    @param bboxes: Array with one row per bounding box, laid out as
        described above.
    @type area_threshold: float
    @param area_threshold: Minimum C{width * height} a box must have
        to be kept. The default C{0.0004} equals C{0.02 * 0.02}
        (assumes normalized coordinates - confirm against the caller).
    @rtype: Tuple[np.ndarray, np.ndarray]
    @return: The kept boxes as rows of
        C{[col4, x_min, y_min, width, height]} and the ordering
        indices of the kept boxes as C{uint8}.
    """
    if bboxes.size == 0:
        return np.zeros((0, 5)), np.zeros((0,), dtype=np.uint8)

    ordering = bboxes[:, -1]
    # Slicing returns a view; copy so the in-place subtraction below
    # does not overwrite the caller's array.
    boxes = bboxes[:, :-1].copy()
    boxes[:, 2] -= boxes[:, 0]
    boxes[:, 3] -= boxes[:, 1]
    widths = boxes[:, 2]
    heights = boxes[:, 3]

    # Require positive sides explicitly: two negative sides would
    # otherwise multiply to a positive "area" and pass the threshold.
    keep = (widths > 0) & (heights > 0) & (widths * heights >= area_threshold)

    out_bboxes = boxes[keep][:, [4, 0, 1, 2, 3]]

    # NOTE(review): uint8 can only represent ordering indices up to 255;
    # kept for compatibility with existing callers.
    return out_bboxes, ordering[keep].astype(np.uint8)
5160

5261

5362
def postprocess_keypoints(
@@ -57,12 +66,31 @@ def postprocess_keypoints(
5766
image_width: int,
5867
n_keypoints: int,
5968
) -> np.ndarray:
60-
keypoints = np.reshape(keypoints[:, :3], (-1, n_keypoints * 3))[
61-
bboxes_ordering
62-
]
63-
np.maximum(keypoints, 0, out=keypoints)
64-
keypoints[..., ::3] /= image_width
65-
keypoints[..., 1::3] /= image_height
69+
keypoints = keypoints[:, : (n_keypoints * 3)]
70+
keypoints = keypoints.reshape(-1, n_keypoints, 3)
71+
72+
keypoints = keypoints[bboxes_ordering]
73+
74+
x = keypoints[..., 0]
75+
y = keypoints[..., 1]
76+
v = keypoints[..., 2]
77+
78+
in_bounds = (x >= 0) & (x < image_width) & (y >= 0) & (y < image_height)
79+
80+
v[~in_bounds] = 0
81+
82+
x = np.clip(x, 0, image_width)
83+
y = np.clip(y, 0, image_height)
84+
85+
x /= image_width
86+
y /= image_height
87+
88+
keypoints[..., 0] = x
89+
keypoints[..., 1] = y
90+
keypoints[..., 2] = v
91+
92+
keypoints = keypoints.reshape(-1, n_keypoints * 3)
93+
6694
return keypoints
6795

6896

0 commit comments

Comments
 (0)