Skip to content

Commit a5acdef

Browse files
committed
owkmeans: fix original_domain of output KMeansModel
1 parent ab0b5de commit a5acdef

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

Orange/widgets/unsupervised/owkmeans.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def has_attributes(self):
283283
return len(self.data.domain.attributes)
284284

285285
@staticmethod
286-
def _compute_clustering(data, k, init, n_init, max_iter, random_state):
286+
def _compute_clustering(data, k, init, n_init, max_iter, random_state, original_domain):
287287
# type: (Table, int, str, int, int, bool) -> KMeansModel
288288
if k > len(data):
289289
raise NotEnoughData()
@@ -293,6 +293,9 @@ def _compute_clustering(data, k, init, n_init, max_iter, random_state):
293293
random_state=random_state, preprocessors=[]
294294
).get_model(data)
295295

296+
# set explict original domain because data was preprocessed separately
297+
model.original_domain = original_domain
298+
296299
if data.X.shape[0] <= SILHOUETTE_MAX_SAMPLES:
297300
model.silhouette_samples = silhouette_samples(data.X, model.labels)
298301
model.silhouette = np.mean(model.silhouette_samples)
@@ -365,6 +368,7 @@ def __launch_tasks(self, ks):
365368
n_init=self.n_init,
366369
max_iter=self.max_iterations,
367370
random_state=RANDOM_STATE,
371+
original_domain=self.data.domain,
368372
) for k in ks]
369373
watcher = FutureSetWatcher(futures)
370374
watcher.resultReadyAt.connect(self.__clustering_complete)

Orange/widgets/unsupervised/tests/test_owkmeans.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import Orange.clustering
1313
from Orange.data import Table, Domain
14+
from Orange.data.table import DomainTransformationError
1415
from Orange.widgets import gui
1516
from Orange.widgets.tests.base import WidgetTest
1617
from Orange.widgets.unsupervised.owkmeans import OWKMeans, ClusterTableModel
@@ -214,8 +215,8 @@ def test_clusters_compute_value(self):
214215
np.testing.assert_equal(np.isnan(transformed), False)
215216

216217
incompatible_data = Table("iris")
217-
transformed = incompatible_data.transform(out.domain).get_column("Cluster")
218-
np.testing.assert_equal(np.isnan(transformed), True)
218+
with self.assertRaises(DomainTransformationError):
219+
transformed = incompatible_data.transform(out.domain)
219220

220221
def test_centroids_on_output(self):
221222
widget = self.widget

0 commit comments

Comments
 (0)