@@ -166,14 +166,14 @@ def load_checkpoint(checkpoint_path, model, optimizer, cpu):
166
166
print ('model loaded from %s' % checkpoint_path )
167
167
168
168
def jaccard (intersection , union , eps = 1e-15 ):
169
- return (intersection + eps ) / (union - intersection + eps )
169
+ return (intersection ) / (union - intersection + eps )
170
170
171
171
def dice (intersection , union , eps = 1e-15 ):
172
- return (2. * intersection + eps ) / (union + eps )
172
+ return (2. * intersection ) / (union + eps )
173
173
174
174
class BCESoftJaccardDice :
175
175
176
- def __init__ (self , bce_weight = 0.5 , mode = "dice" , eps = 1e-15 , weight = None ):
176
+ def __init__ (self , bce_weight = 0.5 , mode = "dice" , eps = 1e-7 , weight = None ):
177
177
self .nll_loss = torch .nn .BCEWithLogitsLoss (weight = weight )
178
178
self .bce_weight = bce_weight
179
179
self .eps = eps
@@ -457,7 +457,7 @@ def fit(self, dataset, dataset_val, **kwargs):
457
457
y_pred = nn .functional .interpolate (y_pred , scale_factor = 2 , mode = 'bilinear' , align_corners = True )
458
458
459
459
loss_fn = BCESoftJaccardDice (bce_weight = bce_loss_weight ,
460
- weight = mask_w .cuda (self .device_idx ), mode = "dice" , eps = 1. )
460
+ weight = mask_w .cuda (self .device_idx ), mode = "dice" )
461
461
loss = loss_fn (y_pred , Variable (mask .cuda (self .device_idx )))
462
462
463
463
self .optimizer .zero_grad ()
0 commit comments