Skip to content

Commit 4376a88

Browse files
authored
Merge pull request #5428 from VesnaT/dbscan_normalize
[ENH] DBSCAN: Optional normalization
2 parents 6f25bda + 5a0f03e commit 4376a88

File tree

3 files changed: 68 additions and 6 deletions.

Orange/widgets/unsupervised/owdbscan.py

Lines changed: 21 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,7 @@ class Error(widget.OWWidget.Error):
7878
min_samples = Setting(4)
7979
eps = Setting(0.5)
8080
metric_idx = Setting(0)
81+
normalize = Setting(True)
8182
auto_commit = Setting(True)
8283
k_distances = None
8384
cut_point = None
@@ -102,6 +103,8 @@ def __init__(self):
102103
gui.comboBox(box, self, "metric_idx",
103104
items=list(zip(*self.METRICS))[0],
104105
callback=self._metirc_changed)
106+
gui.checkBox(box, self, "normalize", "Normalize features",
107+
callback=self._on_normalize_changed)
105108

106109
gui.auto_apply(self.buttonsArea, self, "auto_commit")
107110
gui.rubber(self.controlArea)
@@ -161,9 +164,9 @@ def _compute_cut_point(self):
161164
self.cut_point = int(DEFAULT_CUT_POINT * len(self.k_distances))
162165
self.eps = self.k_distances[self.cut_point]
163166

164-
if self.eps < EPS_BOTTOM_LIMIT:
165-
self.eps = np.min(
166-
self.k_distances[self.k_distances >= EPS_BOTTOM_LIMIT])
167+
mask = self.k_distances >= EPS_BOTTOM_LIMIT
168+
if self.eps < EPS_BOTTOM_LIMIT and sum(mask):
169+
self.eps = np.min(self.k_distances[mask])
167170
self.cut_point = self._find_nearest_dist(self.eps)
168171

169172
@Inputs.data
@@ -180,13 +183,18 @@ def set_data(self, data):
180183
if self.data is None:
181184
return
182185

183-
# preprocess data
184-
for pp in PREPROCESSORS:
185-
self.data_normalized = pp(self.data_normalized)
186+
self._preprocess_data()
186187

187188
self._compute_and_plot()
188189
self.unconditional_commit()
189190

191+
def _preprocess_data(self):
192+
self.data_normalized = self.data
193+
for pp in PREPROCESSORS:
194+
if isinstance(pp, Normalize) and not self.normalize:
195+
continue
196+
self.data_normalized = pp(self.data_normalized)
197+
190198
def send_data(self):
191199
model = self.model
192200

@@ -248,6 +256,13 @@ def _min_samples_changed(self):
248256
self._compute_and_plot(cut_point=self.cut_point)
249257
self._invalidate()
250258

259+
def _on_normalize_changed(self):
260+
if not self.data:
261+
return
262+
self._preprocess_data()
263+
self._compute_and_plot()
264+
self._invalidate()
265+
251266

252267
if __name__ == "__main__":
253268
a = QApplication(sys.argv)

Orange/widgets/unsupervised/tests/test_owdbscan.py

Lines changed: 46 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,13 @@
11
# pylint: disable=protected-access
2+
import unittest
3+
24
import numpy as np
35
from scipy.sparse import csr_matrix, csc_matrix
46

57
from Orange.data import Table
8+
from Orange.clustering import DBSCAN
69
from Orange.distance import Euclidean
10+
from Orange.preprocess import Normalize, Continuize, SklImpute
711
from Orange.widgets.tests.base import WidgetTest
812
from Orange.widgets.tests.utils import simulate, possible_duplicate_table
913
from Orange.widgets.unsupervised.owdbscan import OWDBSCAN, get_kth_distances
@@ -226,3 +230,45 @@ def test_missing_data(self):
226230
self.send_signal(w.Inputs.data, self.iris)
227231
output = self.get_output(w.Outputs.annotated_data)
228232
self.assertTupleEqual((150, 1), output[:, "Cluster"].metas.shape)
233+
234+
def test_normalize_data(self):
235+
# not normalized
236+
self.widget.controls.normalize.setChecked(False)
237+
238+
data = Table("heart_disease")
239+
self.send_signal(self.widget.Inputs.data, data)
240+
241+
kwargs = {"eps": self.widget.eps,
242+
"min_samples": self.widget.min_samples,
243+
"metric": "euclidean"}
244+
clusters = DBSCAN(**kwargs)(data)
245+
246+
output = self.get_output(self.widget.Outputs.annotated_data)
247+
output_clusters = output.metas[:, 0]
248+
output_clusters[np.isnan(output_clusters)] = -1
249+
np.testing.assert_array_equal(output_clusters, clusters)
250+
251+
# normalized
252+
self.widget.controls.normalize.setChecked(True)
253+
254+
kwargs = {"eps": self.widget.eps,
255+
"min_samples": self.widget.min_samples,
256+
"metric": "euclidean"}
257+
for pp in (Continuize(), Normalize(), SklImpute()):
258+
data = pp(data)
259+
clusters = DBSCAN(**kwargs)(data)
260+
261+
output = self.get_output(self.widget.Outputs.annotated_data)
262+
output_clusters = output.metas[:, 0]
263+
output_clusters[np.isnan(output_clusters)] = -1
264+
np.testing.assert_array_equal(output_clusters, clusters)
265+
266+
def test_normalize_changed(self):
267+
self.send_signal(self.widget.Inputs.data, self.iris)
268+
simulate.combobox_run_through_all(self.widget.controls.metric_idx)
269+
self.widget.controls.normalize.setChecked(False)
270+
simulate.combobox_run_through_all(self.widget.controls.metric_idx)
271+
272+
273+
if __name__ == '__main__':
274+
unittest.main()

Orange/widgets/utils/slidergraph.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -104,6 +104,7 @@ def clear_plot(self):
104104
This function clears the plot and removes data.
105105
"""
106106
self.clear()
107+
self.setRange(xRange=(0.0, 1.0), yRange=(0.0, 1.0))
107108
self.plot_horlabel = []
108109
self.plot_horline = []
109110
self._line = None

0 commit comments

Comments
 (0)