-
Notifications
You must be signed in to change notification settings - Fork 213
/
loss.py
55 lines (43 loc) · 1.85 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
from torch import nn
from torch.nn import functional as F
import utils
import numpy as np
class LossBinary:
    """Binary segmentation loss mixing BCE-with-logits and a soft Jaccard term.

    The value returned by calling an instance is::

        (1 - jaccard_weight) * BCEWithLogits(outputs, targets)
            - jaccard_weight * log(J)

    where ``J = (I + eps) / (U - I + eps)``, ``I`` is the sum of
    ``sigmoid(outputs) * (targets == 1)`` and ``U`` is the sum of
    ``sigmoid(outputs)`` plus the sum of ``(targets == 1)``.

    With ``jaccard_weight == 0`` (the default) this is exactly
    ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self, jaccard_weight=0):
        """Build the BCE criterion and remember the mixing weight.

        Args:
            jaccard_weight: Weight of the ``-log(J)`` term; the BCE term is
                weighted by ``1 - jaccard_weight``. A falsy value disables
                the Jaccard term entirely.
        """
        self.nll_loss = nn.BCEWithLogitsLoss()
        self.jaccard_weight = jaccard_weight

    def __call__(self, outputs, targets):
        """Compute the combined loss.

        Args:
            outputs: Raw logits (they are passed to ``BCEWithLogitsLoss`` and
                through a sigmoid here, so they must not be pre-activated).
            targets: Float tensor of the same shape as ``outputs``; elements
                equal to 1 are treated as foreground for the Jaccard term.

        Returns:
            A scalar tensor holding the loss.
        """
        loss = (1 - self.jaccard_weight) * self.nll_loss(outputs, targets)
        if self.jaccard_weight:
            # eps keeps the ratio finite when prediction and target are both
            # empty (the ratio then evaluates to 1 and its log to 0).
            eps = 1e-15
            jaccard_target = (targets == 1).float()
            # torch.sigmoid replaces the deprecated F.sigmoid spelling.
            jaccard_output = torch.sigmoid(outputs)
            intersection = (jaccard_output * jaccard_target).sum()
            # "union" here is |A| + |B|; the true soft union is obtained
            # below by subtracting the intersection.
            union = jaccard_output.sum() + jaccard_target.sum()
            loss -= self.jaccard_weight * torch.log(
                (intersection + eps) / (union - intersection + eps))
        return loss
class LossMulti:
    """Multi-class segmentation loss mixing NLL and per-class soft Jaccard.

    The value returned by calling an instance is::

        (1 - jaccard_weight) * NLLLoss(outputs, targets)
            - jaccard_weight * sum_over_classes(log(J_cls))

    where, for each class index ``cls`` in ``range(num_classes)``,
    ``J_cls = (I + eps) / (U - I + eps)`` with ``I`` the sum of
    ``exp(outputs[:, cls]) * (targets == cls)`` and ``U`` the sum of
    ``exp(outputs[:, cls])`` plus the sum of ``(targets == cls)``.
    """

    def __init__(self, jaccard_weight=0, class_weights=None, num_classes=1):
        """Build the NLL criterion and remember the Jaccard settings.

        Args:
            jaccard_weight: Weight of the Jaccard term; the NLL term is
                weighted by ``1 - jaccard_weight``. A falsy value disables
                the Jaccard term entirely.
            class_weights: Optional NumPy array of per-class weights for the
                NLL term. It is cast to float32, converted to a tensor and
                passed through ``utils.cuda``.
            num_classes: Number of class indices the Jaccard term loops over.
        """
        if class_weights is not None:
            # NOTE(review): utils.cuda is defined outside this file;
            # presumably it moves the tensor to the GPU when one is
            # available. Confirm against utils.
            nll_weight = utils.cuda(
                torch.from_numpy(class_weights.astype(np.float32)))
        else:
            nll_weight = None
        # nn.NLLLoss is the drop-in replacement for the deprecated
        # nn.NLLLoss2d alias and accepts the same (N, C, ...) inputs.
        self.nll_loss = nn.NLLLoss(weight=nll_weight)
        self.jaccard_weight = jaccard_weight
        self.num_classes = num_classes

    def __call__(self, outputs, targets):
        """Compute the combined loss.

        Args:
            outputs: Per-class scores with the class axis at dimension 1.
                They are fed to ``NLLLoss`` and exponentiated for the Jaccard
                term, so they are assumed to be log-probabilities (e.g. the
                result of ``log_softmax``). TODO confirm against the model.
            targets: Integer class-index tensor accepted by ``NLLLoss``.

        Returns:
            A scalar tensor holding the loss.
        """
        loss = (1 - self.jaccard_weight) * self.nll_loss(outputs, targets)
        if self.jaccard_weight:
            # eps keeps the ratio finite for classes absent from both the
            # prediction and the target.
            eps = 1e-15
            for cls in range(self.num_classes):
                jaccard_target = (targets == cls).float()
                jaccard_output = outputs[:, cls].exp()
                intersection = (jaccard_output * jaccard_target).sum()
                # |A| + |B|; the intersection is subtracted below to get
                # the soft union.
                union = jaccard_output.sum() + jaccard_target.sum()
                loss -= torch.log(
                    (intersection + eps) / (union - intersection + eps)
                ) * self.jaccard_weight
        return loss