diff --git a/src/main/python/tensorframes_snippets/kmeans_demo.py b/src/main/python/tensorframes_snippets/kmeans_demo.py
index eb9cee4..71fd94d 100644
--- a/src/main/python/tensorframes_snippets/kmeans_demo.py
+++ b/src/main/python/tensorframes_snippets/kmeans_demo.py
@@ -19,28 +19,11 @@ def tf_compute_distances(points, start_centers):
     :param start_centers: a numpy array of shape num_centroid x dim
     :return: a TF tensor of shape num_points x num_centroids
     """
-    with tf.variable_scope("distances"):
-        # The dimensions in the problem
-        (num_centroids, _) = np.shape(start_centers)
-        # The shape of the block is extracted as a TF variable.
-        num_points = tf.shape(points)[0]
-        # The centers are embedded in the TF program.
-        centers = tf.constant(start_centers)
-        # Computation of the minimum distance. This is a standard implementation that follows
-        # what MLlib does.
-        squares = tf.reduce_sum(tf.square(points), reduction_indices=1)
-        center_squares = tf.reduce_sum(tf.square(centers), reduction_indices=1)
-        prods = tf.matmul(points, centers, transpose_b = True)
-        # This code simply expresses two outer products: center_squares * ones(num_points)
-        # and ones(num_centroids) * squares
-        t1a = tf.expand_dims(center_squares, 0)
-        t1b = tf.stack([num_points, 1])
-        t1 = tf.tile(t1a, t1b)
-        t2a = tf.expand_dims(squares, 1)
-        t2b = tf.stack([1, num_centroids])
-        t2 = tf.tile(t2a, t2b)
-        distances = t1 + t2 - 2 * prods
-    return distances
+    # Pairwise Euclidean distances, computed for all point/center pairs at once
+    # using the expansion ||p - c||^2 = ||p||^2 - 2 p.c + ||c||^2.
+    return tf.sqrt(tf.norm(points, axis=1, keep_dims=True) ** 2
+                   - 2 * tf.matmul(points, start_centers, transpose_b=True)
+                   + tf.transpose(tf.norm(start_centers, axis=1, keep_dims=True)) ** 2)
 
 
 def run_one_step(dataframe, start_centers):
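
For reference only (not part of the patch): a minimal numpy sketch checking the algebraic identity the new return statement relies on, ||p - c||^2 = ||p||^2 - 2 p.c + ||c||^2, evaluated for all point/center pairs at once. The shapes and array names below are illustrative, not taken from the demo.

    import numpy as np

    points = np.random.rand(5, 3)    # num_points x dim
    centers = np.random.rand(2, 3)   # num_centroids x dim

    # Vectorized form, mirroring the TF expression in tf_compute_distances.
    fast = np.sqrt(np.sum(points ** 2, axis=1, keepdims=True)
                   - 2 * points @ centers.T
                   + np.sum(centers ** 2, axis=1, keepdims=True).T)

    # Brute-force reference: one norm per (point, center) pair.
    slow = np.array([[np.linalg.norm(p - c) for c in centers] for p in points])

    assert np.allclose(fast, slow)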