From 318a1bd4d09849419e866964a780dabbb9a3ebc7 Mon Sep 17 00:00:00 2001
From: hamadichihaoui
Date: Sat, 11 Jul 2020 18:38:49 +0300
Subject: [PATCH] implement iou_loss

---
 effdet/bench.py    |   2 +-
 effdet/iou_loss.py | 105 +++++++++++++++++++++++++++++++++++++++
 effdet/loss.py     |  68 ++++++++++++++++++++++---
 3 files changed, 167 insertions(+), 8 deletions(-)
 create mode 100644 effdet/iou_loss.py

diff --git a/effdet/bench.py b/effdet/bench.py
index a35df694..ff0e024c 100644
--- a/effdet/bench.py
+++ b/effdet/bench.py
@@ -86,7 +86,7 @@ def __init__(self, model, config):
             config.num_scales, config.aspect_ratios,
             config.anchor_scale, config.image_size)
         self.anchor_labeler = AnchorLabeler(self.anchors, config.num_classes, match_threshold=0.5)
-        self.loss_fn = DetectionLoss(self.config)
+        self.loss_fn = DetectionLoss(self.config, self.anchors)
 
     def forward(self, x, target):
         class_out, box_out = self.model(x)
diff --git a/effdet/iou_loss.py b/effdet/iou_loss.py
new file mode 100644
index 00000000..fc13aeb6
--- /dev/null
+++ b/effdet/iou_loss.py
@@ -0,0 +1,105 @@
+'''
+IoU-family similarity measures for pairs of axis-aligned boxes.
+
+Based on:
+    https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/86a370aa2cadea6ba7e5dffb2efc4bacc4c863ea/utils/box/box_utils.py#L47
+
+    Distance-IoU Loss: Faster and Better Learning for Bounding Box Regression
+    https://arxiv.org/pdf/1911.08287.pdf
+    Generalized Intersection over Union: A Metric and A Loss for Bounding Box Regression
+    https://giou.stanford.edu/GIoU.pdf
+    UnitBox: An Advanced Object Detection Network
+    https://arxiv.org/pdf/1608.01471.pdf
+
+Every function compares two [N, 4] tensors row by row. Columns 0/1 are read
+as a box's minimum corner and columns 2/3 as its maximum corner.
+
+Important (compute_c_iou only): targets -> bboxes1, preds -> bboxes2.
+'''
+
+import torch
+from torch import nn
+import numpy as np
+
+# Added to every denominator so degenerate boxes (zero area or zero height)
+# divide by a tiny positive number instead of zero. 10e-16 is 1e-15.
+eps = 10e-16
+
+
+def _inter_union(bboxes1, bboxes2):
+    """Return row-wise (intersection area, union area) for two [N, 4] box tensors."""
+    area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1])
+    area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1])
+    overlap0 = torch.min(bboxes1[:, 2], bboxes2[:, 2]) - torch.max(bboxes1[:, 0], bboxes2[:, 0])
+    overlap1 = torch.min(bboxes1[:, 3], bboxes2[:, 3]) - torch.max(bboxes1[:, 1], bboxes2[:, 1])
+    # clamp stays on the inputs' device and dtype; torch.where against a fresh
+    # torch.tensor(0.) would mix in a float32 CPU scalar.
+    inter = torch.clamp(overlap0, min=0) * torch.clamp(overlap1, min=0)
+    return inter, area1 + area2 - inter
+
+
+def _enclosing_sides(bboxes1, bboxes2):
+    """Return the row-wise side lengths of the smallest box enclosing both inputs."""
+    side0 = torch.max(bboxes1[:, 2], bboxes2[:, 2]) - torch.min(bboxes1[:, 0], bboxes2[:, 0])
+    side1 = torch.max(bboxes1[:, 3], bboxes2[:, 3]) - torch.min(bboxes1[:, 1], bboxes2[:, 1])
+    return side0, side1
+
+
+def _center_dist_sq(bboxes1, bboxes2):
+    """Return the row-wise squared Euclidean distance between the box centres."""
+    delta0 = (bboxes1[:, 2] + bboxes1[:, 0]) / 2 - (bboxes2[:, 2] + bboxes2[:, 0]) / 2
+    delta1 = (bboxes1[:, 3] + bboxes1[:, 1]) / 2 - (bboxes2[:, 3] + bboxes2[:, 1]) / 2
+    return delta0 ** 2 + delta1 ** 2
+
+
+def compute_iou(bboxes1, bboxes2):
+    """Row-wise IoU of two [N, 4] box tensors, clamped to [0, 1]."""
+    assert bboxes1.size(0) == bboxes2.size(0)
+    inter, union = _inter_union(bboxes1, bboxes2)
+    return torch.clamp(inter / (union + eps), min=0, max=1.0)
+
+
+def compute_g_iou(bboxes1, bboxes2):
+    """Row-wise Generalized IoU of two [N, 4] box tensors, clamped to [-1, 1].
+
+    The floor is -1, not 0: with a floor of 0 every pair of non-overlapping
+    boxes would score exactly 0 and receive no gradient, which discards the
+    enclosing-box penalty that separates GIoU from plain IoU.
+    """
+    inter, union = _inter_union(bboxes1, bboxes2)
+    side0, side1 = _enclosing_sides(bboxes1, bboxes2)
+    enclosing = side0 * side1
+    g_iou = inter / (union + eps) - (enclosing - union) / (enclosing + eps)
+    return torch.clamp(g_iou, min=-1.0, max=1.0)
+
+
+def compute_d_iou(bboxes1, bboxes2):
+    """Row-wise Distance-IoU of two [N, 4] box tensors, clamped to [-1, 1]."""
+    inter, union = _inter_union(bboxes1, bboxes2)
+    side0, side1 = _enclosing_sides(bboxes1, bboxes2)
+    # squared centre distance over the squared diagonal of the enclosing box
+    penalty = _center_dist_sq(bboxes1, bboxes2) / (side0 ** 2 + side1 ** 2 + eps)
+    return torch.clamp(inter / (union + eps) - penalty, min=-1.0, max=1.0)
+
+
+def compute_c_iou(bboxes1, bboxes2):
+    """Row-wise Complete-IoU of two [N, 4] box tensors, clamped to [-1, 1].
+
+    bboxes1 are the targets and bboxes2 the predictions.
+    """
+    w1 = bboxes1[:, 2] - bboxes1[:, 0]
+    h1 = bboxes1[:, 3] - bboxes1[:, 1]
+    w2 = bboxes2[:, 2] - bboxes2[:, 0]
+    h2 = bboxes2[:, 3] - bboxes2[:, 1]
+    inter, union = _inter_union(bboxes1, bboxes2)
+    side0, side1 = _enclosing_sides(bboxes1, bboxes2)
+    iou = inter / (union + eps)
+    distance = _center_dist_sq(bboxes1, bboxes2) / (side0 ** 2 + side1 ** 2 + eps)
+    # torch.atan (not np.arctan) keeps the aspect-ratio term on the autograd
+    # graph and accepts tensors that require grad or live off the CPU.
+    v = 4 / np.pi ** 2 * (torch.atan(w1 / (h1 + eps)) - torch.atan(w2 / (h2 + eps))) ** 2
+    with torch.no_grad():
+        # the trade-off weight is held constant during backpropagation
+        alpha = v / (1 - iou + v + eps)
+    c_iou = iou - (distance + alpha * v)
+    return torch.clamp(c_iou, min=-1.0, max=1.0)
diff --git a/effdet/loss.py b/effdet/loss.py
index cce77875..c409b5be 100644
--- a/effdet/loss.py
+++ b/effdet/loss.py
@@ -3,7 +3,8 @@
 import torch.nn.functional as F
 
 from typing import Optional, List
 
-
+from .anchors import decode_box_outputs
+from .iou_loss import compute_iou, compute_g_iou, compute_d_iou, compute_c_iou
 def focal_loss(logits, targets, alpha: float, gamma: float, normalizer):
     """Compute the focal loss between `logits` and the golden `target` values.
@@ -119,8 +120,35 @@ def _box_loss(box_outputs, box_targets, num_positives, delta: float = 0.1):
     return box_loss
 
 
+
+class IouLoss(nn.Module):
+    """Reduce ``1 - similarity(target, pred)`` over paired [N, 4] boxes.
+
+    ``losstype`` selects the similarity: 'Iou', 'Giou' or 'Diou'; any other
+    value falls through to CIoU. ``reduction='mean'`` divides the summed loss
+    by N; any other value returns the plain sum.
+    """
+
+    def __init__(self, losstype='Giou', reduction='mean'):
+        super(IouLoss, self).__init__()
+        self.reduction = reduction
+        self.loss = losstype
+
+    def forward(self, target_bboxes, pred_bboxes):
+        # Targets are passed first; compute_c_iou reads bboxes1 as the targets.
+        metrics = {'Iou': compute_iou, 'Giou': compute_g_iou, 'Diou': compute_d_iou}
+        similarity = metrics.get(self.loss, compute_c_iou)(target_bboxes, pred_bboxes)
+        loss = torch.sum(1.0 - similarity)
+        if self.reduction != 'mean':
+            return loss
+        # No boxes at all: the sum is already 0, so divide by 1 instead of 0 and
+        # return 0 rather than NaN.
+        num = max(target_bboxes.shape[0], 1)
+        return loss / num
+
+
 class DetectionLoss(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, anchors, use_iou_loss = False):
         super(DetectionLoss, self).__init__()
         self.config = config
         self.num_classes = config.num_classes
@@ -128,6 +156,10 @@ def __init__(self, config):
         self.gamma = config.gamma
         self.delta = config.delta
         self.box_loss_weight = config.box_loss_weight
+        self.use_iou_loss = use_iou_loss
+        if self.use_iou_loss:
+            self.anchors = anchors
+            self.iou_loss = IouLoss()
 
     def forward(
             self, cls_outputs: List[torch.Tensor], box_outputs: List[torch.Tensor],
@@ -161,6 +193,11 @@ def forward(
 
         cls_losses = []
         box_losses = []
+        if self.use_iou_loss:
+            box_outputs_list = []
+            cls_targets_list = []
+            box_targets_list = []
+
         for l in range(levels):
             cls_targets_at_level = cls_targets[l]
             box_targets_at_level = box_targets[l]
@@ -182,12 +219,29 @@ def forward(
             cls_loss = cls_loss.view(bs, height, width, -1, self.num_classes)
             cls_loss *= (cls_targets_at_level != -2).unsqueeze(-1).float()
             cls_losses.append(cls_loss.sum())
+            if not self.use_iou_loss:
+                box_losses.append(_box_loss(
+                    box_outputs[l].permute(0, 2, 3, 1),
+                    box_targets_at_level,
+                    num_positives_sum,
+                    delta=self.delta))
+
+            else:
+                box_outputs_list.append(box_outputs[l].permute(0, 2, 3, 1).reshape([bs, -1, 4]))
+                cls_targets_list.append(cls_targets_at_level.permute(0, 2, 3, 1).reshape([bs, -1, 1]))
+                box_targets_list.append(box_targets_at_level.permute(0, 2, 3, 1).reshape([bs, -1, 4]))
+        # NOTE(review): box_outputs_list is a Python list, so `.shape[0]` below raises
+        # AttributeError; the per-level tensors presumably need torch.cat(..., dim=1) first - confirm.
+        if self.use_iou_loss:
+            # apply bounding box regression to anchors
+            for k in range(box_outputs_list.shape[0]):
+                pred_boxes = decode_box_outputs(box_outputs_list[k].T.float(), self.anchors.boxes.T, output_xyxy=True)
+                target_boxes = decode_box_outputs(box_targets_list[k].T.float(), self.anchors.boxes.T, output_xyxy=True)
+                # indices where an anchor is assigned to target box. NOTE(review): `== 0.0` keeps zero-valued entries - confirm the mask is not inverted
+                indices = box_targets_list[k] == 0.0
+                pred_boxes = torch.clamp(pred_boxes, 0)
+                box_losses.append(self.iou_loss(target_boxes[indices.view(-1)], pred_boxes[indices.view(-1)]))
 
-            box_losses.append(_box_loss(
-                box_outputs[l].permute(0, 2, 3, 1),
-                box_targets_at_level,
-                num_positives_sum,
-                delta=self.delta))
 
         # Sum per level losses to total loss.
         cls_loss = torch.sum(torch.stack(cls_losses, dim=-1), dim=-1)