import tensorflow as tf

- def generalized_dice_loss(pred, true, p=2, q=1, eps=1E-64):
+ def generalized_dice_loss(pred, true, p=1, q=1, eps=1E-6):
      """pred and true are tensors of shape (b, w_0, w_1, ..., c) where
      b   ... batch size
      w_k ... width of input in k-th dimension
@@ -17,10 +17,7 @@ def generalized_dice_loss(pred, true, p=2, q=1, eps=1E-64):
      assert(pred.get_shape()[1:] == true.get_shape()[1:])

      m = "the values in your last layer must be strictly in [0, 1]"
-     with tf.control_dependencies([tf.assert_non_negative(pred, message=m),
-                                   tf.assert_non_negative(true, message=m),
-                                   tf.assert_less_equal(pred, 1.0, message=m),
-                                   tf.assert_less_equal(true, 1.0, message=m)]):
+     with tf.control_dependencies([]):

          shape_pred = pred.get_shape()
          shape_true = true.get_shape()
@@ -34,19 +31,19 @@ def generalized_dice_loss(pred, true, p=2, q=1, eps=1E-64):
          # no class reweighting at all
          if p == 0:
              # unweighted intersection and union
-             inter = tf.reduce_sum(pred*true, axis=[1, 2])
-             union = tf.reduce_sum(pred+true, axis=[1, 2])
+             inter = tf.reduce_mean(pred*true, axis=[1, 2])
+             union = tf.reduce_mean(pred+true, axis=[1, 2])
          else:
              # inverse L_p weighting for class cardinalities
              weights = tf.abs(tf.reduce_sum(true, axis=[1]))**p + eps
              weights = tf.expand_dims(tf.reduce_sum(weights, axis=[-1]), -1) \
                      / weights
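              # each class now carries weight (sum over classes of n_c**p) / n_c**p,
              # so classes with larger cardinality contribute less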

              # weighted intersection and union
-             inter = tf.reduce_sum(weights*tf.reduce_sum(pred*true, axis=[1]),
-                                   axis=[-1])
-             union = tf.reduce_sum(weights*tf.reduce_sum(pred+true, axis=[1]),
-                                   axis=[-1])
+             inter = tf.reduce_mean(weights*tf.reduce_mean(pred*true, axis=[1]),
+                                    axis=[-1])
+             union = tf.reduce_mean(weights*tf.reduce_mean(pred+true, axis=[1]),
+                                    axis=[-1])

          # the traditional dice formula
          loss = 1.0 - 2.0*(inter+eps)/(union+eps)
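          # eps avoids a 0/0 division for classes absent from both pred and true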
@@ -59,7 +56,7 @@ def generalized_dice_loss(pred, true, p=2, q=1, eps=1E-64):
          weights = tf.abs(loss)**q + eps
          weights = tf.reduce_sum(weights)/weights
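          # inverse L_q weighting of the per-sample losses,
          # mirroring the class weighting above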

-         return tf.reduce_sum(loss*weights)/tf.reduce_sum(weights)
+         return tf.reduce_mean(loss*weights)/tf.reduce_mean(weights)

  if __name__ == "__main__":
      import numpy as np
@@ -86,19 +83,24 @@ def convert_to_mask(batch, threshold=0.5):
      y = activation(tf.tensordot(x, W, axes=[[1],[0]])+b)

      loss = generalized_dice_loss(y, x_)
-     step = tf.train.AdamOptimizer(0.001).minimize(loss)
+     step = tf.train.AdamOptimizer(0.01).minimize(loss)

      sess = tf.Session()
      sess.run(tf.global_variables_initializer())

-     for iteration in range(2**14):
+     for iteration in range(2**16):
          batch_x, _ = mnist.train.next_batch(batch_size)
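          # labels are discarded; the target is the binarized input itself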
          step_, loss_ = sess.run([step, loss],
                                  feed_dict={x: batch_x,
                                             x_: convert_to_mask(batch_x)})

          if iteration % print_every == 0:
-             print "loss :", loss_
+             batch_x, _ = mnist.test.next_batch(10000)
+             loss_val = sess.run(loss,
+                                 feed_dict={x: batch_x,
+                                            x_: convert_to_mask(batch_x)})
+
+             print "loss:", loss_, "loss_val:", loss_val

      import matplotlib; matplotlib.use("Agg")
      import pylab as pl