
Commit e44fe60 (0 parents)

Author: Hundt
Commit message: first commit

File tree

2 files changed: +81, -0 lines


README.md

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
# generalized_dice_loss

segmentation.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
import tensorflow as tf
import numpy as np
from functools import reduce


def generalized_dice_loss(pred, true, eps=1E-64):
    """pred and true are tensors of shape (b, w_0, w_1, ..., c) where
    b   ... batch size
    w_k ... width of the input in the k-th dimension
    c   ... number of segments/classes
    furthermore, both tensors take values exclusively in [0, 1]"""

    # note: 1E-64 underflows to 0.0 in float32 arithmetic; a larger eps
    # (e.g. 1e-7) may be needed in practice if NaNs appear
    assert pred.get_shape()[1:] == true.get_shape()[1:]

    shape_pred = pred.get_shape()
    shape_true = true.get_shape()
    prod_pred = reduce(lambda x, y: x*y, shape_pred[1:-1], tf.Dimension(1))
    prod_true = reduce(lambda x, y: x*y, shape_true[1:-1], tf.Dimension(1))

    # reshape to shape (b, W, c) where W is the product of the w_k
    pred = tf.reshape(pred, [-1, prod_pred, shape_pred[-1]])
    true = tf.reshape(true, [-1, prod_true, shape_true[-1]])

    # inverse square weighting for class cardinalities: each class is
    # weighted by 1/|class|^2, normalized over all classes
    weights = tf.square(tf.reduce_sum(true, axis=[1]))+eps
    weights = tf.expand_dims(tf.reduce_sum(weights, axis=[-1]), -1)/weights

    # the traditional dice formula, built from the weighted per-class sums
    inter = tf.reduce_sum(weights*tf.reduce_sum(pred*true, axis=[1]), axis=[-1])
    union = tf.reduce_sum(weights*tf.reduce_sum(pred+true, axis=[1]), axis=[-1])

    return tf.reduce_mean(1.0-2.0*(inter+eps)/(union+eps))
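# written out, the loss above evaluates (per batch element, then averaged):
#   GDL = 1 - 2*(sum_c w_c sum_i p_ic*t_ic + eps)
#             / (sum_c w_c sum_i (p_ic + t_ic) + eps)
# with w_c proportional to 1/(sum_i t_ic)^2 up to eps, i.e. the
# inverse-square class weighting of the generalized Dice loss
# (cf. Sudre et al., 2017); the normalization factor in w_c cancels
# in the ratio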

def convert_to_mask(batch, threshold=0.5):
    """converts a batch of gray-scale images of shape (b, W) into a
    binary two-class mask of shape (b, W, 2): channel 0 marks pixels
    above the threshold (foreground), channel 1 the rest"""

    result = np.zeros(batch.shape+(2,), dtype=batch.dtype)
    result[:, :, 0] = batch > threshold
    result[:, :, 1] = batch <= threshold

    return result
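# example: convert_to_mask(np.array([[0.9, 0.1]], dtype=np.float32))
# returns [[[1., 0.], [0., 1.]]]: the first pixel is foreground (> 0.5),
# the second is background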

if __name__ == "__main__":
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

    batch_size, print_every = 128, 1024
    # squash outputs to [0, 1] so the dice loss is well-defined
    activation = lambda x: 0.5*(tf.tanh(x)+1)

    x = tf.placeholder(tf.float32, [None, 784])
    x_ = tf.placeholder(tf.float32, [None, 784, 2])

    # a single dense layer mapping each image to a two-class mask per pixel
    W = tf.Variable(tf.zeros([784, 784, 2]))
    b = tf.Variable(tf.zeros([784, 2]))

    y = activation(tf.tensordot(x, W, axes=[[1], [0]])+b)

    loss = generalized_dice_loss(y, x_)
    step = tf.train.AdamOptimizer(0.001).minimize(loss)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # train the model to reproduce the thresholded input images
    for iteration in range(2**16):
        batch_x, _ = mnist.train.next_batch(batch_size)
        _, loss_ = sess.run([step, loss],
                            feed_dict={x : batch_x,
                                       x_: convert_to_mask(batch_x)})

        if iteration % print_every == 0:
            print("loss:", loss_)

    # visualize input, foreground channel, and background channel
    import matplotlib; matplotlib.use("Agg")
    import pylab as pl

    for index, image in enumerate(mnist.test.next_batch(batch_size)[0]):
        predict = sess.run(y, feed_dict={x: np.expand_dims(image, 0)})
        pl.subplot(131)
        pl.imshow(image.reshape((28, 28)))
        pl.subplot(132)
        pl.imshow(predict[0, :, 0].reshape((28, 28)))
        pl.subplot(133)
        pl.imshow(predict[0, :, 1].reshape((28, 28)))
        pl.savefig(str(index)+".png")
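
The loss can also be sanity-checked in isolation. A minimal sketch, assuming the file above is saved as segmentation.py and the same TF 1.x API: feeding a binary mask as both prediction and target should give a loss near 0, while feeding the same mask with its two channels swapped (every pixel misclassified) should give a loss near 1.

import numpy as np
import tensorflow as tf
from segmentation import generalized_dice_loss, convert_to_mask

# a random batch of "images", thresholded into a complementary two-class mask
batch = np.random.rand(4, 784).astype(np.float32)
mask = convert_to_mask(batch)   # shape (4, 784, 2)

pred = tf.placeholder(tf.float32, [None, 784, 2])
true = tf.placeholder(tf.float32, [None, 784, 2])
loss = generalized_dice_loss(pred, true)

with tf.Session() as sess:
    # identical prediction and target: prints a value close to 0
    print(sess.run(loss, feed_dict={pred: mask, true: mask}))
    # channels swapped: prints a value close to 1
    print(sess.run(loss, feed_dict={pred: mask[:, :, ::-1], true: mask}))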
