Skip to content

Commit 2e2e8bf

Browse files
committed Nov 6, 2017
probably more stable
1 parent d08b7f6 commit 2e2e8bf

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed
 

generalized_dice_loss.py

+17-15
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import tensorflow as tf
22

3-
def generalized_dice_loss(pred, true, p=2, q=1, eps=1E-64):
3+
def generalized_dice_loss(pred, true, p=1, q=1, eps=1E-6):
44
"""pred and true are tensors of shape (b, w_0, w_1, ..., c) where
55
b ... batch size
66
w_k ... width of input in k-th dimension
@@ -17,10 +17,7 @@ def generalized_dice_loss(pred, true, p=2, q=1, eps=1E-64):
1717
assert(pred.get_shape()[1:] == true.get_shape()[1:])
1818

1919
m = "the values in your last layer must be strictly in [0, 1]"
20-
with tf.control_dependencies([tf.assert_non_negative(pred, message=m),
21-
tf.assert_non_negative(true, message=m),
22-
tf.assert_less_equal(pred, 1.0, message=m),
23-
tf.assert_less_equal(true, 1.0, message=m)]):
20+
with tf.control_dependencies([]):
2421

2522
shape_pred = pred.get_shape()
2623
shape_true = true.get_shape()
@@ -34,19 +31,19 @@ def generalized_dice_loss(pred, true, p=2, q=1, eps=1E-64):
3431
# no class reweighting at all
3532
if p == 0:
3633
# unweighted intersection and union
37-
inter = tf.reduce_sum(pred*true, axis=[1, 2])
38-
union = tf.reduce_sum(pred+true, axis=[1, 2])
34+
inter = tf.reduce_mean(pred*true, axis=[1, 2])
35+
union = tf.reduce_mean(pred+true, axis=[1, 2])
3936
else:
4037
# inverse L_p weighting for class cardinalities
4138
weights = tf.abs(tf.reduce_sum(true, axis=[1]))**p+eps
4239
weights = tf.expand_dims(tf.reduce_sum(weights, axis=[-1]), -1) \
4340
/ weights
4441

4542
# weighted intersection and union
46-
inter = tf.reduce_sum(weights*tf.reduce_sum(pred*true, axis=[1]),
47-
axis=[-1])
48-
union = tf.reduce_sum(weights*tf.reduce_sum(pred+true, axis=[1]),
49-
axis=[-1])
43+
inter = tf.reduce_mean(weights*tf.reduce_mean(pred*true, axis=[1]),
44+
axis=[-1])
45+
union = tf.reduce_mean(weights*tf.reduce_mean(pred+true, axis=[1]),
46+
axis=[-1])
5047

5148
# the traditional dice formula
5249
loss = 1.0-2.0*(inter+eps)/(union+eps)
@@ -59,7 +56,7 @@ def generalized_dice_loss(pred, true, p=2, q=1, eps=1E-64):
5956
weights = tf.abs(loss)**q+eps
6057
weights = tf.reduce_sum(weights)/weights
6158

62-
return tf.reduce_sum(loss*weights)/tf.reduce_sum(weights)
59+
return tf.reduce_mean(loss*weights)/tf.reduce_mean(weights)
6360

6461
if __name__ == "__main__":
6562
import numpy as np
@@ -86,19 +83,24 @@ def convert_to_mask(batch, threshold=0.5):
8683
y = activation(tf.tensordot(x, W, axes=[[1],[0]])+b)
8784

8885
loss = generalized_dice_loss(y, x_)
89-
step = tf.train.AdamOptimizer(0.001).minimize(loss)
86+
step = tf.train.AdamOptimizer(0.01).minimize(loss)
9087

9188
sess = tf.Session()
9289
sess.run(tf.global_variables_initializer())
9390

94-
for iteration in range(2**14):
91+
for iteration in range(2**16):
9592
batch_x, _ = mnist.train.next_batch(batch_size)
9693
step_, loss_ = sess.run([step, loss],
9794
feed_dict={x : batch_x,
9895
x_: convert_to_mask(batch_x)})
9996

10097
if iteration % print_every == 0:
101-
print "loss :", loss_
98+
batch_x, _ = mnist.test.next_batch(10000)
99+
loss_val = sess.run(loss,
100+
feed_dict={x : batch_x,
101+
x_: convert_to_mask(batch_x)})
102+
103+
print "loss:", loss_, "loss_val:", loss_val
102104

103105
import matplotlib; matplotlib.use("Agg")
104106
import pylab as pl

0 commit comments

Comments
 (0)
Please sign in to comment.