Skip to content

[ENH] K-means: clusters can be inferred for new data #7010

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion Orange/clustering/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
self.projector = projector
self.domain = None
self.original_domain = None
self.labels = projector.labels_

@property
def labels(self):
# converted into a property for __eq__ and __hash__ implementation
return self.projector.labels_

def __call__(self, data):
def fix_dim(x):
Expand Down Expand Up @@ -57,6 +61,17 @@
raise NotImplementedError(
"This clustering algorithm does not support predicting.")

def __eq__(self, other):
if self is other:
return True

Check warning on line 66 in Orange/clustering/clustering.py

View check run for this annotation

Codecov / codecov/patch

Orange/clustering/clustering.py#L66

Added line #L66 was not covered by tests
return type(self) is type(other) \
and self.projector == other.projector \
and self.domain == other.domain \
and self.original_domain == other.original_domain

def __hash__(self):
return hash((type(self), self.projector, self.domain, self.original_domain))


class Clustering(metaclass=WrapperMeta):
"""
Expand Down
14 changes: 12 additions & 2 deletions Orange/clustering/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,20 @@

class KMeansModel(ClusteringModel):

InheritEq = True

def __init__(self, projector):
super().__init__(projector)
self.centroids = projector.cluster_centers_
self.k = projector.get_params()["n_clusters"]

@property
def centroids(self):
# converted into a property for __eq__ and __hash__ implementation
return self.projector.cluster_centers_

@property
def k(self):
# converted into a property for __eq__ and __hash__ implementation
return self.projector.get_params()["n_clusters"]

def predict(self, X):
return self.projector.predict(X)
Expand Down
21 changes: 21 additions & 0 deletions Orange/tests/test_clustering_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,27 @@ def test_model_data_table_domain(self):
# totally different domain - should fail
self.assertRaises(DomainTransformationError, c, Table("housing"))

def test_model_eq_hash(self):
kmeans = KMeans(n_clusters=2, max_iter=10, random_state=42)
d = self.iris
k1 = kmeans.get_model(d)
k2 = kmeans.get_model(d)

# results are the same
c1, c2 = k1(d), k2(d)
np.testing.assert_equal(c1, c2)

# transformations are not because .projector is a different object
self.assertNotEqual(k1, k2)
self.assertNotEqual(k1.projector, k2.projector)
self.assertNotEqual(hash(k1), hash(k2))
self.assertNotEqual(hash(k1.projector), hash(k2.projector))

# if projector was hacket to be the same, they match
k1.projector = k2.projector
self.assertEqual(k1, k2)
self.assertEqual(hash(k1), hash(k2))

def test_deprecated_silhouette(self):
with warnings.catch_warnings(record=True) as w:
KMeans(compute_silhouette_score=True)
Expand Down
12 changes: 8 additions & 4 deletions Orange/widgets/unsupervised/owkmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def has_attributes(self):
return len(self.data.domain.attributes)

@staticmethod
def _compute_clustering(data, k, init, n_init, max_iter, random_state):
def _compute_clustering(data, k, init, n_init, max_iter, random_state, original_domain):
# type: (Table, int, str, int, int, bool) -> KMeansModel
if k > len(data):
raise NotEnoughData()
Expand All @@ -293,6 +293,9 @@ def _compute_clustering(data, k, init, n_init, max_iter, random_state):
random_state=random_state, preprocessors=[]
).get_model(data)

# set explict original domain because data was preprocessed separately
model.original_domain = original_domain

if data.X.shape[0] <= SILHOUETTE_MAX_SAMPLES:
model.silhouette_samples = silhouette_samples(data.X, model.labels)
model.silhouette = np.mean(model.silhouette_samples)
Expand Down Expand Up @@ -365,6 +368,7 @@ def __launch_tasks(self, ks):
n_init=self.n_init,
max_iter=self.max_iterations,
random_state=RANDOM_STATE,
original_domain=self.data.domain,
) for k in ks]
watcher = FutureSetWatcher(futures)
watcher.resultReadyAt.connect(self.__clustering_complete)
Expand Down Expand Up @@ -521,15 +525,16 @@ def send_data(self):
domain = self.data.domain
cluster_var = DiscreteVariable(
get_unique_names(domain, "Cluster"),
values=["C%d" % (x + 1) for x in range(km.k)]
values=["C%d" % (x + 1) for x in range(km.k)],
compute_value=km
)
clust_ids = km.labels
silhouette_var = ContinuousVariable(
get_unique_names(domain, "Silhouette"))
if km.silhouette_samples is not None:
self.Warning.no_silhouettes.clear()
scores = np.arctan(km.silhouette_samples) / np.pi + 0.5
clust_scores = []
clust_ids = km.labels
for i in range(km.k):
in_clust = clust_ids == i
if in_clust.any():
Expand All @@ -545,7 +550,6 @@ def send_data(self):
new_domain = add_columns(domain, metas=[cluster_var, silhouette_var])
new_table = self.data.transform(new_domain)
with new_table.unlocked(new_table.metas):
new_table.set_column(cluster_var, clust_ids)
new_table.set_column(silhouette_var, scores)

domain_attributes = set(domain.attributes)
Expand Down
36 changes: 29 additions & 7 deletions Orange/widgets/unsupervised/tests/test_owkmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import Orange.clustering
from Orange.data import Table, Domain
from Orange.data.table import DomainTransformationError
from Orange.widgets import gui
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.unsupervised.owkmeans import OWKMeans, ClusterTableModel
Expand Down Expand Up @@ -200,20 +201,41 @@ def test_data_on_output(self):
# removing data should have cleared the output
self.assertEqual(self.widget.data, None)

def test_clusters_compute_value(self):
orig_data = self.data[:20]
self.send_signal(self.widget.Inputs.data, orig_data, wait=5000)
out = self.get_output(self.widget.Outputs.annotated_data)
orig = out.get_column("Cluster")

transformed = orig_data.transform(out.domain).get_column("Cluster")
np.testing.assert_equal(orig, transformed)

new_data = self.data[20:40]
transformed = new_data.transform(out.domain).get_column("Cluster")
np.testing.assert_equal(np.isnan(transformed), False)

incompatible_data = Table("iris")
with self.assertRaises(DomainTransformationError):
transformed = incompatible_data.transform(out.domain)

def test_centroids_on_output(self):
widget = self.widget
widget.optimize_k = False
widget.k = 4
self.send_signal(widget.Inputs.data, self.data)
self.commit_and_wait()
widget.clusterings[widget.k].labels = np.array([0] * 100 + [1] * 203).flatten()
widget.clusterings[widget.k].silhouette_samples = np.arange(303) / 303
widget.send_data()
km = widget.clusterings[widget.k]

out = self.get_output(widget.Outputs.centroids)
np.testing.assert_array_almost_equal(
np.array([[0, np.mean(np.arctan(np.arange(100) / 303)) / np.pi + 0.5],
[1, np.mean(np.arctan(np.arange(100, 303) / 303)) / np.pi + 0.5],
[2, 0], [3, 0]]), out.metas.astype(float))
sklearn_centroids = km.centroids
np.testing.assert_equal(sklearn_centroids, out.X)

scores = np.arctan(km.silhouette_samples) / np.pi + 0.5
silhouette = [np.mean(scores[km.labels == i]) for i in range(4)]
self.assertTrue(2, len(out.domain.metas))
np.testing.assert_almost_equal([0, 1, 2, 3], out.get_column("Cluster"))
np.testing.assert_almost_equal(silhouette, out.get_column("Silhouette"))

self.assertEqual(out.name, "heart_disease centroids")

def test_centroids_domain_on_output(self):
Expand Down
2 changes: 1 addition & 1 deletion i18n/si/msgs.jaml
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ clustering/hierarchical.py:
clustering/kmeans.py:
KMeans: false
class `KMeansModel`:
def `__init__`:
def `k`:
n_clusters: false
class `KMeans`:
def `__init__`:
Expand Down
Loading