Skip to content

Commit 3704fd3

Browse files
authored
added gpu support
1 parent c0f6bb3 commit 3704fd3

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

losses.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,19 @@
33
import torch.nn.functional as F
44

55
class FocalLoss(nn.Module):
6-
7-
def __init__(self, gamma = 1.0):
6+
7+
def __init__(self, device, gamma = 1.0):
88
super(FocalLoss, self).__init__()
9-
self.gamma = torch.tensor(gamma, dtype = torch.float32)
9+
self.device = device
10+
self.gamma = torch.tensor(gamma, dtype = torch.float32).to(device)
1011
self.eps = 1e-6
11-
12+
13+
# self.BCE_loss = nn.BCEWithLogitsLoss(reduction='none').to(device)
14+
1215
def forward(self, input, target):
1316

14-
BCE_loss = F.binary_cross_entropy_with_logits(input, target, reduction='none')
17+
BCE_loss = F.binary_cross_entropy_with_logits(input, target, reduction='none').to(self.device)
18+
# BCE_loss = self.BCE_loss(input, target)
1519
pt = torch.exp(-BCE_loss) # prevents nans when probability 0
1620
F_loss = (1-pt)**self.gamma * BCE_loss
1721

0 commit comments

Comments
 (0)