@@ -39,15 +39,24 @@ def postprocess_mask(mask: np.ndarray) -> np.ndarray:
3939
4040
4141def postprocess_bboxes (bboxes : np .ndarray ) -> Tuple [np .ndarray , np .ndarray ]:
42+ area_threshold = 0.0004 # 0.02 * 0.02 Small area threshold to remove invalid bboxes and respective keypoints.
4243 if bboxes .size == 0 :
43- return np .zeros ((0 , 6 )), np .zeros ((0 , 1 ), dtype = np .uint8 )
44-
44+ return np .zeros ((0 , 5 )), np .zeros ((0 ,), dtype = np .uint8 )
4545 ordering = bboxes [:, - 1 ]
46- out_bboxes = bboxes [:, :- 1 ]
47- out_bboxes [:, 2 ] -= out_bboxes [:, 0 ]
48- out_bboxes [:, 3 ] -= out_bboxes [:, 1 ]
46+ raw_bboxes = bboxes [:, :- 1 ]
47+ raw_bboxes [:, 2 ] -= raw_bboxes [:, 0 ]
48+ raw_bboxes [:, 3 ] -= raw_bboxes [:, 1 ]
49+ widths = raw_bboxes [:, 2 ]
50+ heights = raw_bboxes [:, 3 ]
51+ areas = widths * heights
52+
53+ valid_mask = areas >= area_threshold
54+ raw_bboxes = raw_bboxes [valid_mask ]
55+ refined_ordering = ordering [valid_mask ]
56+
57+ out_bboxes = raw_bboxes [:, [4 , 0 , 1 , 2 , 3 ]]
4958
50- return out_bboxes [:, [ 4 , 0 , 1 , 2 , 3 ]], ordering .astype (np .uint8 )
59+ return out_bboxes , refined_ordering .astype (np .uint8 )
5160
5261
5362def postprocess_keypoints (
@@ -57,12 +66,31 @@ def postprocess_keypoints(
5766 image_width : int ,
5867 n_keypoints : int ,
5968) -> np .ndarray :
60- keypoints = np .reshape (keypoints [:, :3 ], (- 1 , n_keypoints * 3 ))[
61- bboxes_ordering
62- ]
63- np .maximum (keypoints , 0 , out = keypoints )
64- keypoints [..., ::3 ] /= image_width
65- keypoints [..., 1 ::3 ] /= image_height
69+ keypoints = keypoints [:, : (n_keypoints * 3 )]
70+ keypoints = keypoints .reshape (- 1 , n_keypoints , 3 )
71+
72+ keypoints = keypoints [bboxes_ordering ]
73+
74+ x = keypoints [..., 0 ]
75+ y = keypoints [..., 1 ]
76+ v = keypoints [..., 2 ]
77+
78+ in_bounds = (x >= 0 ) & (x < image_width ) & (y >= 0 ) & (y < image_height )
79+
80+ v [~ in_bounds ] = 0
81+
82+ x = np .clip (x , 0 , image_width )
83+ y = np .clip (y , 0 , image_height )
84+
85+ x /= image_width
86+ y /= image_height
87+
88+ keypoints [..., 0 ] = x
89+ keypoints [..., 1 ] = y
90+ keypoints [..., 2 ] = v
91+
92+ keypoints = keypoints .reshape (- 1 , n_keypoints * 3 )
93+
6694 return keypoints
6795
6896
0 commit comments