Skip to content

Commit 761c3f4

Browse files
committed
Handle COCO crowds.
1 parent de3ad95 commit 761c3f4

File tree

2 files changed

+46
-4
lines changed

2 files changed

+46
-4
lines changed

coco.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,14 @@ def load_mask(self, image_id):
171171
# and end up rounded out. Skip those objects.
172172
if m.max() < 1:
173173
continue
174+
# Is it a crowd? If so, use a negative class ID.
175+
if annotation['iscrowd']:
176+
# Use negative class ID for crowds
177+
class_id *= -1
178+
# For crowd masks, annToMask() sometimes returns a mask
179+
# smaller than the given dimensions. If so, resize it.
180+
if m.shape[0] != image_info["height"] or m.shape[1] != image_info["width"]:
181+
m = np.ones([image_info["height"], image_info["width"]], dtype=bool)
174182
instance_masks.append(m)
175183
class_ids.append(class_id)
176184

model.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -494,16 +494,32 @@ def detection_targets_graph(proposals, gt_class_ids, gt_boxes, gt_masks, config)
494494
gt_masks = tf.gather(gt_masks, tf.where(non_zeros)[:, 0], axis=2,
495495
name="trim_gt_masks")
496496

497+
# Handle COCO crowds
498+
# A crowd box in COCO is a bounding box around several instances. Exclude
499+
# them from training. A crowd box is given a negative class ID.
500+
crowd_ix = tf.where(gt_class_ids < 0)[:, 0]
501+
non_crowd_ix = tf.where(gt_class_ids > 0)[:, 0]
502+
crowd_boxes = tf.gather(gt_boxes, crowd_ix)
503+
crowd_masks = tf.gather(gt_masks, crowd_ix, axis=2)
504+
gt_class_ids = tf.gather(gt_class_ids, non_crowd_ix)
505+
gt_boxes = tf.gather(gt_boxes, non_crowd_ix)
506+
gt_masks = tf.gather(gt_masks, non_crowd_ix, axis=2)
507+
497508
# Compute overlaps matrix [proposals, gt_boxes]
498509
overlaps = overlaps_graph(proposals, gt_boxes)
499510

511+
# Compute overlaps with crowd boxes [anchors, crowds]
512+
crowd_overlaps = overlaps_graph(proposals, crowd_boxes)
513+
crowd_iou_max = tf.reduce_max(crowd_overlaps, axis=1)
514+
no_crowd_bool = (crowd_iou_max < 0.001)
515+
500516
# Determine postive and negative ROIs
501517
roi_iou_max = tf.reduce_max(overlaps, axis=1)
502518
# 1. Positive ROIs are those with >= 0.5 IoU with a GT box
503519
positive_roi_bool = (roi_iou_max >= 0.5)
504520
positive_indices = tf.where(positive_roi_bool)[:, 0]
505-
# 2. Negative ROIs are those with < 0.5 with every GT box
506-
negative_indices = tf.where(roi_iou_max < 0.5)[:, 0]
521+
# 2. Negative ROIs are those with < 0.5 with every GT box. Skip crowds.
522+
negative_indices = tf.where(tf.logical_and(roi_iou_max < 0.5, no_crowd_bool))[:, 0]
507523

508524
# Subsample ROIs. Aim for 33% positive
509525
# Positive ROIs
@@ -1357,6 +1373,23 @@ def build_rpn_targets(image_shape, anchors, gt_class_ids, gt_boxes, config):
13571373
# RPN bounding boxes: [max anchors per image, (dy, dx, log(dh), log(dw))]
13581374
rpn_bbox = np.zeros((config.RPN_TRAIN_ANCHORS_PER_IMAGE, 4))
13591375

1376+
# Handle COCO crowds
1377+
# A crowd box in COCO is a bounding box around several instances. Exclude
1378+
# them from training. A crowd box is given a negative class ID.
1379+
crowd_ix = np.where(gt_class_ids < 0)[0]
1380+
if crowd_ix.shape[0] > 0:
1381+
# Filter out crowds from ground truth class IDs and boxes
1382+
non_crowd_ix = np.where(gt_class_ids > 0)[0]
1383+
crowd_boxes = gt_boxes[crowd_ix]
1384+
gt_class_ids = gt_class_ids[non_crowd_ix]
1385+
gt_boxes = gt_boxes[non_crowd_ix]
1386+
# Compute overlaps with crowd boxes [anchors, crowds]
1387+
crowd_overlaps = utils.compute_overlaps(anchors, crowd_boxes)
1388+
crowd_iou_max = np.amax(crowd_overlaps, axis=1)
1389+
no_crowd_bool = (crowd_iou_max < 0.001)
1390+
else:
1391+
# All anchors don't intersect a crowd
1392+
no_crowd_bool = np.ones([anchors.shape[0]], dtype=bool)
13601393

13611394
# Compute overlaps [num_anchors, num_gt_boxes]
13621395
overlaps = utils.compute_overlaps(anchors, gt_boxes)
@@ -1369,10 +1402,11 @@ def build_rpn_targets(image_shape, anchors, gt_class_ids, gt_boxes, config):
13691402
# However, don't keep any GT box unmatched (rare, but happens). Instead,
13701403
# match it to the closest anchor (even if its max IoU is < 0.3).
13711404
#
1372-
# 1. Set negative anchors first. It gets overwritten if a gt box is matched to them.
1405+
# 1. Set negative anchors first. They get overwritten below if a GT box is
1406+
# matched to them. Skip boxes in crowd areas.
13731407
anchor_iou_argmax = np.argmax(overlaps, axis=1)
13741408
anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax]
1375-
rpn_match[anchor_iou_max < 0.3] = -1
1409+
rpn_match[(anchor_iou_max < 0.3) & (no_crowd_bool)] = -1
13761410
# 2. Set an anchor for each GT box (regardless of IoU value).
13771411
# TODO: If multiple anchors have the same IoU match all of them
13781412
gt_iou_argmax = np.argmax(overlaps, axis=0)

0 commit comments

Comments
 (0)