Skip to content

Commit

Permalink
Add byte tracker code
Browse files Browse the repository at this point in the history
  • Loading branch information
jrrodri committed Oct 13, 2024
1 parent fb3ea03 commit d33cd07
Show file tree
Hide file tree
Showing 8 changed files with 833 additions and 10 deletions.
16 changes: 11 additions & 5 deletions abraia/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,19 @@ def get_mask(row, box, size):
def mask_to_polygon(mask, origin):
"""Returns the largest bounding polygon based on the segmentation mask."""
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
polygon = np.array([])
for contour in contours:
contour = contour.reshape(-1, 2)
polygon = contour if contour.shape[0] > polygon.shape[0] else polygon
lengths = [len(contour) for contour in contours]
polygon = contours[np.argmax(lengths)].reshape(-1, 2)
polygon = polygon + np.array(origin)
return polygon.tolist()


def approximate_polygon(polygon, approx=0.02):
contour = np.array([polygon]).astype(np.int32)
epsilon = approx * cv2.arcLength(contour, True)
contour = cv2.approxPolyDP(contour, epsilon, True)
return contour.reshape(-1, 2).tolist()


def process_output(outputs, size, shape, classes, confidence, iou_threshold):
"""Converts the RAW model output from YOLOv8 to an array of detected
objects, containing the bounding box, label and the probability.
Expand Down Expand Up @@ -147,7 +152,8 @@ def process_output(outputs, size, shape, classes, confidence, iou_threshold):
mask = get_mask(mask, result['box'], size)
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
result['polygon'] = mask_to_polygon(mask, (x, y))
result.pop('mask', None)
# result.pop('mask', None)
result['mask'] = mask
return results


Expand Down
21 changes: 16 additions & 5 deletions abraia/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,33 @@ def draw_filled_rectangle(img, rect, color, opacity = 1):


def draw_polygon(img, polygon, color, thickness = 2):
polygon = np.round(polygon).astype(np.int32)
cv2.polylines(img, [polygon], True, color, thickness)
points = np.round(polygon).astype(np.int32)
cv2.polylines(img, [points], True, color, thickness)
return img


def draw_filled_polygon(img, polygon, color, opacity = 1):
polygon = np.round(polygon).astype(np.int32)
points = np.round(polygon).astype(np.int32)
if opacity == 1:
cv2.fillPoly(img, [polygon], color)
cv2.fillPoly(img, [points], color)
else:
img_copy = img.copy()
cv2.fillPoly(img_copy, [polygon], color)
cv2.fillPoly(img_copy, [points], color)
cv2.addWeighted(img_copy, opacity, img, 1 - opacity, 0, img)
return img


def draw_blurred_polygon(img, polygon):
w_k = int(0.1 * max(img.shape[:2]))
w_k = w_k + 1 if w_k % 2 == 0 else w_k
blurred_img = cv2.GaussianBlur(img, (w_k, w_k), 0)
points = np.round(polygon).astype(np.int32)
mask = np.zeros(img.shape, dtype=np.uint8)
mask = cv2.fillPoly(mask, [points], (255, 255, 255))
img = np.where(mask==0, img, blurred_img)
return img


def draw_text(img, text, point, background_color = None, text_color = (255, 255, 255),
text_scale = 0.8, padding = 8):
text_font, text_thickness = cv2.FONT_HERSHEY_DUPLEX, 1
Expand Down
4 changes: 4 additions & 0 deletions abraia/tracker/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

from .byte_tracker import BYTETracker as ByteTracker

__all__ = ['ByteTracker']
52 changes: 52 additions & 0 deletions abraia/tracker/basetrack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import numpy as np
from collections import OrderedDict


class TrackState(object):
New = 0
Tracked = 1
Lost = 2
Removed = 3


class BaseTrack(object):
_count = 0

track_id = 0
is_activated = False
state = TrackState.New

history = OrderedDict()
features = []
curr_feature = None
score = 0
start_frame = 0
frame_id = 0
time_since_update = 0

# multi-camera
location = (np.inf, np.inf)

@property
def end_frame(self):
return self.frame_id

@staticmethod
def next_id():
BaseTrack._count += 1
return BaseTrack._count

def activate(self, *args):
raise NotImplementedError

def predict(self):
raise NotImplementedError

def update(self, *args, **kwargs):
raise NotImplementedError

def mark_lost(self):
self.state = TrackState.Lost

def mark_removed(self):
self.state = TrackState.Removed
Loading

0 comments on commit d33cd07

Please sign in to comment.