Skip to content

Commit d33cd07

Browse files
committed
Add byte tracker code
1 parent fb3ea03 commit d33cd07

File tree

8 files changed

+833
-10
lines changed

8 files changed

+833
-10
lines changed

abraia/detect.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,19 @@ def get_mask(row, box, size):
103103
def mask_to_polygon(mask, origin):
    """Returns the largest bounding polygon based on the segmentation mask.

    Args:
        mask: Binary image handed to ``cv2.findContours``.
        origin: ``(x, y)`` offset added to every point of the polygon.

    Returns:
        A list of ``[x, y]`` points for the external contour that has the
        most points, or an empty list when no contour is found.
    """
    # Index -2 selects the contour list for both the two-value and the
    # three-value return signatures of cv2.findContours.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        # np.argmax raises ValueError on an empty sequence.
        return []
    lengths = [len(contour) for contour in contours]
    polygon = contours[np.argmax(lengths)].reshape(-1, 2)
    polygon = polygon + np.array(origin)
    return polygon.tolist()
112110

113111

112+
def approximate_polygon(polygon, approx=0.02):
    """Simplifies a closed polygon with ``cv2.approxPolyDP``.

    Args:
        polygon: Sequence of ``[x, y]`` points; coordinates are cast to int32.
        approx: Fraction of the closed perimeter used as the approximation
            tolerance (epsilon).

    Returns:
        The simplified polygon as a list of ``[x, y]`` integer points, or an
        empty list when ``polygon`` has no points.
    """
    if len(polygon) == 0:
        # Nothing to simplify; avoids passing an empty contour to OpenCV.
        return []
    contour = np.array([polygon]).astype(np.int32)
    epsilon = approx * cv2.arcLength(contour, True)
    contour = cv2.approxPolyDP(contour, epsilon, True)
    return contour.reshape(-1, 2).tolist()
117+
118+
114119
def process_output(outputs, size, shape, classes, confidence, iou_threshold):
115120
"""Converts the RAW model output from YOLOv8 to an array of detected
116121
objects, containing the bounding box, label and the probability.
@@ -147,7 +152,8 @@ def process_output(outputs, size, shape, classes, confidence, iou_threshold):
147152
mask = get_mask(mask, result['box'], size)
148153
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
149154
result['polygon'] = mask_to_polygon(mask, (x, y))
150-
result.pop('mask', None)
155+
# result.pop('mask', None)
156+
result['mask'] = mask
151157
return results
152158

153159

abraia/draw.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,22 +51,33 @@ def draw_filled_rectangle(img, rect, color, opacity = 1):
5151

5252

5353
def draw_polygon(img, polygon, color, thickness = 2):
    """Draws the closed outline of ``polygon`` on ``img`` and returns ``img``.

    Polygon coordinates are rounded to the nearest integer before drawing.
    """
    vertices = np.round(polygon).astype(np.int32)
    outlines = [vertices]
    is_closed = True
    cv2.polylines(img, outlines, is_closed, color, thickness)
    return img
5757

5858

5959
def draw_filled_polygon(img, polygon, color, opacity = 1):
    """Fills ``polygon`` on ``img`` with ``color`` and returns ``img``.

    With ``opacity`` equal to 1 the polygon is filled directly. Otherwise the
    polygon is filled on a copy, which is then blended back into ``img`` with
    weight ``opacity``.
    """
    shapes = [np.round(polygon).astype(np.int32)]
    if opacity != 1:
        # Paint on a copy, then blend the copy into the original image.
        overlay = img.copy()
        cv2.fillPoly(overlay, shapes, color)
        cv2.addWeighted(overlay, opacity, img, 1 - opacity, 0, img)
        return img
    cv2.fillPoly(img, shapes, color)
    return img
6868

6969

70+
def draw_blurred_polygon(img, polygon):
    """Returns a copy of ``img`` with the area inside ``polygon`` blurred.

    Args:
        img: Image array with shape ``(height, width)`` or
            ``(height, width, channels)``.
        polygon: Sequence of ``[x, y]`` points; coordinates are rounded to
            the nearest integer.

    Returns:
        A new array in which pixels inside the polygon come from a
        Gaussian-blurred version of ``img`` and all other pixels come from
        ``img`` itself. ``img`` is not modified.
    """
    # Kernel side is 10% of the longest image side, forced to an odd value.
    w_k = int(0.1 * max(img.shape[:2])) | 1
    blurred_img = cv2.GaussianBlur(img, (w_k, w_k), 0)
    points = np.round(polygon).astype(np.int32)
    # A single-channel mask is broadcast over every channel, so images with
    # any number of channels (including an alpha channel) are blurred fully.
    mask = np.zeros(img.shape[:2], dtype=np.uint8)
    cv2.fillPoly(mask, [points], 255)
    if img.ndim == 3:
        mask = mask[:, :, np.newaxis]
    return np.where(mask == 0, img, blurred_img)
79+
80+
7081
def draw_text(img, text, point, background_color = None, text_color = (255, 255, 255),
7182
text_scale = 0.8, padding = 8):
7283
text_font, text_thickness = cv2.FONT_HERSHEY_DUPLEX, 1

abraia/tracker/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
from .byte_tracker import BYTETracker as ByteTracker
3+
4+
__all__ = ['ByteTracker']

abraia/tracker/basetrack.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import numpy as np
2+
from collections import OrderedDict
3+
4+
5+
class TrackState(object):
    """Integer codes for the states a track can be in."""

    # New = 0, Tracked = 1, Lost = 2, Removed = 3
    New, Tracked, Lost, Removed = range(4)
10+
11+
12+
class BaseTrack(object):
    """Base class for tracks: shared default attributes and state helpers.

    Subclasses are expected to implement ``activate``, ``predict`` and
    ``update``. Subclasses that define their own ``__init__`` should call
    ``super().__init__()`` so each instance gets its own ``history`` and
    ``features`` containers.
    """

    # Counter used by next_id(); stored on BaseTrack itself, so it is shared
    # by every subclass.
    _count = 0

    track_id = 0
    is_activated = False
    state = TrackState.New

    # Class-level defaults kept for backward compatibility; __init__ replaces
    # them with per-instance containers.
    history = OrderedDict()
    features = []
    curr_feature = None
    score = 0
    start_frame = 0
    frame_id = 0
    time_since_update = 0

    # multi-camera
    location = (np.inf, np.inf)

    def __init__(self):
        # Mutable class attributes would otherwise be shared by all tracks.
        self.history = OrderedDict()
        self.features = []

    @property
    def end_frame(self):
        """Returns the current ``frame_id`` of the track."""
        return self.frame_id

    @staticmethod
    def next_id():
        """Increments the global track counter and returns its new value."""
        BaseTrack._count += 1
        return BaseTrack._count

    def activate(self, *args):
        """Must be implemented by subclasses."""
        raise NotImplementedError

    def predict(self):
        """Must be implemented by subclasses."""
        raise NotImplementedError

    def update(self, *args, **kwargs):
        """Must be implemented by subclasses."""
        raise NotImplementedError

    def mark_lost(self):
        """Sets the track state to ``TrackState.Lost``."""
        self.state = TrackState.Lost

    def mark_removed(self):
        """Sets the track state to ``TrackState.Removed``."""
        self.state = TrackState.Removed

0 commit comments

Comments
 (0)