Skip to content

Commit 49fdb23

Browse files
authored
Merge pull request #7010 from markotoplak/same_clustering_for_new_data
[ENH] K-means: clusters can be inferred for new data
2 parents ba72a1b + 6ff8cff commit 49fdb23

File tree

6 files changed

+87
-15
lines changed

6 files changed

+87
-15
lines changed

Orange/clustering/clustering.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@ def __init__(self, projector):
1313
self.projector = projector
1414
self.domain = None
1515
self.original_domain = None
16-
self.labels = projector.labels_
16+
17+
@property
18+
def labels(self):
19+
# converted into a property for __eq__ and __hash__ implementation
20+
return self.projector.labels_
1721

1822
def __call__(self, data):
1923
def fix_dim(x):
@@ -57,6 +61,17 @@ def predict(self, X):
5761
raise NotImplementedError(
5862
"This clustering algorithm does not support predicting.")
5963

64+
def __eq__(self, other):
65+
if self is other:
66+
return True
67+
return type(self) is type(other) \
68+
and self.projector == other.projector \
69+
and self.domain == other.domain \
70+
and self.original_domain == other.original_domain
71+
72+
def __hash__(self):
73+
return hash((type(self), self.projector, self.domain, self.original_domain))
74+
6075

6176
class Clustering(metaclass=WrapperMeta):
6277
"""

Orange/clustering/kmeans.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,20 @@
1111

1212
class KMeansModel(ClusteringModel):
1313

14+
InheritEq = True
15+
1416
def __init__(self, projector):
1517
super().__init__(projector)
16-
self.centroids = projector.cluster_centers_
17-
self.k = projector.get_params()["n_clusters"]
18+
19+
@property
20+
def centroids(self):
21+
# converted into a property for __eq__ and __hash__ implementation
22+
return self.projector.cluster_centers_
23+
24+
@property
25+
def k(self):
26+
# converted into a property for __eq__ and __hash__ implementation
27+
return self.projector.get_params()["n_clusters"]
1828

1929
def predict(self, X):
2030
return self.projector.predict(X)

Orange/tests/test_clustering_kmeans.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,27 @@ def test_model_data_table_domain(self):
138138
# totally different domain - should fail
139139
self.assertRaises(DomainTransformationError, c, Table("housing"))
140140

141+
def test_model_eq_hash(self):
142+
kmeans = KMeans(n_clusters=2, max_iter=10, random_state=42)
143+
d = self.iris
144+
k1 = kmeans.get_model(d)
145+
k2 = kmeans.get_model(d)
146+
147+
# results are the same
148+
c1, c2 = k1(d), k2(d)
149+
np.testing.assert_equal(c1, c2)
150+
151+
# transformations are not because .projector is a different object
152+
self.assertNotEqual(k1, k2)
153+
self.assertNotEqual(k1.projector, k2.projector)
154+
self.assertNotEqual(hash(k1), hash(k2))
155+
self.assertNotEqual(hash(k1.projector), hash(k2.projector))
156+
157+
# if projector was hacked to be the same, they match
158+
k1.projector = k2.projector
159+
self.assertEqual(k1, k2)
160+
self.assertEqual(hash(k1), hash(k2))
161+
141162
def test_deprecated_silhouette(self):
142163
with warnings.catch_warnings(record=True) as w:
143164
KMeans(compute_silhouette_score=True)

Orange/widgets/unsupervised/owkmeans.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def has_attributes(self):
283283
return len(self.data.domain.attributes)
284284

285285
@staticmethod
286-
def _compute_clustering(data, k, init, n_init, max_iter, random_state):
286+
def _compute_clustering(data, k, init, n_init, max_iter, random_state, original_domain):
287287
# type: (Table, int, str, int, int, bool) -> KMeansModel
288288
if k > len(data):
289289
raise NotEnoughData()
@@ -293,6 +293,9 @@ def _compute_clustering(data, k, init, n_init, max_iter, random_state):
293293
random_state=random_state, preprocessors=[]
294294
).get_model(data)
295295

296+
# set explicit original domain because data was preprocessed separately
297+
model.original_domain = original_domain
298+
296299
if data.X.shape[0] <= SILHOUETTE_MAX_SAMPLES:
297300
model.silhouette_samples = silhouette_samples(data.X, model.labels)
298301
model.silhouette = np.mean(model.silhouette_samples)
@@ -365,6 +368,7 @@ def __launch_tasks(self, ks):
365368
n_init=self.n_init,
366369
max_iter=self.max_iterations,
367370
random_state=RANDOM_STATE,
371+
original_domain=self.data.domain,
368372
) for k in ks]
369373
watcher = FutureSetWatcher(futures)
370374
watcher.resultReadyAt.connect(self.__clustering_complete)
@@ -521,15 +525,16 @@ def send_data(self):
521525
domain = self.data.domain
522526
cluster_var = DiscreteVariable(
523527
get_unique_names(domain, "Cluster"),
524-
values=["C%d" % (x + 1) for x in range(km.k)]
528+
values=["C%d" % (x + 1) for x in range(km.k)],
529+
compute_value=km
525530
)
526-
clust_ids = km.labels
527531
silhouette_var = ContinuousVariable(
528532
get_unique_names(domain, "Silhouette"))
529533
if km.silhouette_samples is not None:
530534
self.Warning.no_silhouettes.clear()
531535
scores = np.arctan(km.silhouette_samples) / np.pi + 0.5
532536
clust_scores = []
537+
clust_ids = km.labels
533538
for i in range(km.k):
534539
in_clust = clust_ids == i
535540
if in_clust.any():
@@ -545,7 +550,6 @@ def send_data(self):
545550
new_domain = add_columns(domain, metas=[cluster_var, silhouette_var])
546551
new_table = self.data.transform(new_domain)
547552
with new_table.unlocked(new_table.metas):
548-
new_table.set_column(cluster_var, clust_ids)
549553
new_table.set_column(silhouette_var, scores)
550554

551555
domain_attributes = set(domain.attributes)

Orange/widgets/unsupervised/tests/test_owkmeans.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import Orange.clustering
1313
from Orange.data import Table, Domain
14+
from Orange.data.table import DomainTransformationError
1415
from Orange.widgets import gui
1516
from Orange.widgets.tests.base import WidgetTest
1617
from Orange.widgets.unsupervised.owkmeans import OWKMeans, ClusterTableModel
@@ -200,20 +201,41 @@ def test_data_on_output(self):
200201
# removing data should have cleared the output
201202
self.assertEqual(self.widget.data, None)
202203

204+
def test_clusters_compute_value(self):
205+
orig_data = self.data[:20]
206+
self.send_signal(self.widget.Inputs.data, orig_data, wait=5000)
207+
out = self.get_output(self.widget.Outputs.annotated_data)
208+
orig = out.get_column("Cluster")
209+
210+
transformed = orig_data.transform(out.domain).get_column("Cluster")
211+
np.testing.assert_equal(orig, transformed)
212+
213+
new_data = self.data[20:40]
214+
transformed = new_data.transform(out.domain).get_column("Cluster")
215+
np.testing.assert_equal(np.isnan(transformed), False)
216+
217+
incompatible_data = Table("iris")
218+
with self.assertRaises(DomainTransformationError):
219+
transformed = incompatible_data.transform(out.domain)
220+
203221
def test_centroids_on_output(self):
204222
widget = self.widget
205223
widget.optimize_k = False
206224
widget.k = 4
207225
self.send_signal(widget.Inputs.data, self.data)
208226
self.commit_and_wait()
209-
widget.clusterings[widget.k].labels = np.array([0] * 100 + [1] * 203).flatten()
210-
widget.clusterings[widget.k].silhouette_samples = np.arange(303) / 303
211-
widget.send_data()
227+
km = widget.clusterings[widget.k]
228+
212229
out = self.get_output(widget.Outputs.centroids)
213-
np.testing.assert_array_almost_equal(
214-
np.array([[0, np.mean(np.arctan(np.arange(100) / 303)) / np.pi + 0.5],
215-
[1, np.mean(np.arctan(np.arange(100, 303) / 303)) / np.pi + 0.5],
216-
[2, 0], [3, 0]]), out.metas.astype(float))
230+
sklearn_centroids = km.centroids
231+
np.testing.assert_equal(sklearn_centroids, out.X)
232+
233+
scores = np.arctan(km.silhouette_samples) / np.pi + 0.5
234+
silhouette = [np.mean(scores[km.labels == i]) for i in range(4)]
235+
self.assertTrue(2, len(out.domain.metas))
236+
np.testing.assert_almost_equal([0, 1, 2, 3], out.get_column("Cluster"))
237+
np.testing.assert_almost_equal(silhouette, out.get_column("Silhouette"))
238+
217239
self.assertEqual(out.name, "heart_disease centroids")
218240

219241
def test_centroids_domain_on_output(self):

i18n/si/msgs.jaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -893,7 +893,7 @@ clustering/hierarchical.py:
893893
clustering/kmeans.py:
894894
KMeans: false
895895
class `KMeansModel`:
896-
def `__init__`:
896+
def `k`:
897897
n_clusters: false
898898
class `KMeans`:
899899
def `__init__`:

0 commit comments

Comments
 (0)