File tree 1 file changed +9
-5
lines changed
1 file changed +9
-5
lines changed Original file line number Diff line number Diff line change 3
3
import torch.nn.functional as F
4
4
5
5
class FocalLoss(nn.Module):
    """Focal loss for binary classification with logits.

    Scales the per-element binary cross-entropy by ``(1 - p_t) ** gamma``,
    which down-weights well-classified examples so training focuses on
    hard ones.

    Args:
        device: torch device (or device string) that internal tensors are
            moved to; also used by ``forward`` to place the BCE result.
        gamma: focusing parameter; ``gamma == 0`` reduces the weighting
            factor to 1 (plain BCE). Defaults to 1.0.
    """

    def __init__(self, device, gamma=1.0):
        super().__init__()
        self.device = device
        # Store gamma as a tensor already on the target device so the
        # power op in forward() does not copy it host->device every call.
        self.gamma = torch.tensor(gamma, dtype=torch.float32).to(device)
        # NOTE(review): eps is set but not used in the visible code —
        # presumably intended for numerical stability; confirm it is
        # still needed before removing.
        self.eps = 1e-6
def forward (self , input , target ):
13
16
14
- BCE_loss = F .binary_cross_entropy_with_logits (input , target , reduction = 'none' )
17
+ BCE_loss = F .binary_cross_entropy_with_logits (input , target , reduction = 'none' ).to (self .device )
18
+ # BCE_loss = self.BCE_loss(input, target)
15
19
pt = torch .exp (- BCE_loss ) # prevents nans when probability 0
16
20
F_loss = (1 - pt )** self .gamma * BCE_loss
17
21
You can’t perform that action at this time.
0 commit comments