diff --git a/Orange/widgets/visualize/owboxplot.py b/Orange/widgets/visualize/owboxplot.py index e65be8b73dd..0a28058c25c 100644 --- a/Orange/widgets/visualize/owboxplot.py +++ b/Orange/widgets/visualize/owboxplot.py @@ -44,6 +44,38 @@ def compute_scale(min_, max_): return first_val, step +def _quantiles(a, freq, q, interpolation="midpoint"): + """ + Somewhat like np.quantiles, but with explicit sample frequencies. + + * Only 'higher', 'lower' and 'midpoint' interpolation. + * `a` MUST be sorted. + """ + a = np.asarray(a) + freq = np.asarray(freq) + assert a.size > 0 and a.size == freq.size + cumdist = np.cumsum(freq) + cumdist /= cumdist[-1] + + if interpolation == "midpoint": # R quantile(..., type=2) + left = np.searchsorted(cumdist, q, side="left") + right = np.searchsorted(cumdist, q, side="right") + # no mid point for the right most position + np.clip(right, 0, a.size - 1, out=right) + # right and left will be different only on the `q` boundaries + # (excluding the right most sample) + return (a[left] + a[right]) / 2 + elif interpolation == "higher": # R quantile(... type=1) + right = np.searchsorted(cumdist, q, side="right") + np.clip(right, 0, a.size - 1, out=right) + return a[right] + elif interpolation == "lower": + left = np.searchsorted(cumdist, q, side="left") + return a[left] + else: # pragma: no cover + raise ValueError("invalid interpolation: '{}'".format(interpolation)) + + class BoxData: def __init__(self, dist, attr, group_val_index=None, group_var=None): self.dist = dist @@ -55,24 +87,12 @@ def __init__(self, dist, attr, group_val_index=None, group_var=None): self.mean = float(np.sum(dist[0] * dist[1]) / n) self.var = float(np.sum(dist[1] * (dist[0] - self.mean) ** 2) / n) self.dev = math.sqrt(self.var) - s = 0 - thresholds = [n / 4, n / 2, n / 4 * 3] - thresh_i = 0 - q = [] - for i, e in enumerate(dist[1]): - s += e - if s >= thresholds[thresh_i]: - if s == thresholds[thresh_i] and i + 1 < dist.shape[1]: - q.append(float((dist[0, i] + dist[0, i + 1]) / 2)) - else: - q.append(float(dist[0, i])) - thresh_i += 1 - if thresh_i == 3: - self.q25, self.median, self.q75 = q - break - else: - self.q25 = self.q75 = None - self.median = q[1] if len(q) == 2 else None + a, freq = np.asarray(dist) + q25, median, q75 = _quantiles(a, freq, [0.25, 0.5, 0.75]) + self.median = median + # The code below omits the q25 or q75 in the plot when they are None + self.q25 = None if q25 == median else q25 + self.q75 = None if q75 == median else q75 self.conditions = [FilterContinuous(attr, FilterContinuous.Between, self.q25, self.q75)] if group_val_index is not None: diff --git a/Orange/widgets/visualize/tests/test_owboxplot.py b/Orange/widgets/visualize/tests/test_owboxplot.py index 51a2617b90a..1466a38965b 100644 --- a/Orange/widgets/visualize/tests/test_owboxplot.py +++ b/Orange/widgets/visualize/tests/test_owboxplot.py @@ -1,12 +1,15 @@ # Test methods with long descriptive names can omit docstrings # pylint: disable=missing-docstring +import unittest import numpy as np from AnyQt.QtCore import QItemSelectionModel from AnyQt.QtTest import QTest from Orange.data import Table, ContinuousVariable, StringVariable, Domain -from Orange.widgets.visualize.owboxplot import OWBoxPlot, FilterGraphicsRectItem +from Orange.widgets.visualize.owboxplot import ( + OWBoxPlot, FilterGraphicsRectItem, _quantiles +) from Orange.widgets.tests.base import WidgetTest, WidgetOutputsTestMixin @@ -208,3 +211,40 @@ def __select_value(self, list, value): if m.data(idx) == value: list.selectionModel().setCurrentIndex( idx, QItemSelectionModel.ClearAndSelect) + + +class TestUtils(unittest.TestCase): + def test(self): + np.testing.assert_array_equal( + _quantiles(range(1, 8 + 1), [1.] * 8, [0.0, 0.25, 0.5, 0.75, 1.0]), + [1., 2.5, 4.5, 6.5, 8.] + ) + np.testing.assert_array_equal( + _quantiles(range(1, 8 + 1), [1.] * 8, [0.0, 0.25, 0.5, 0.75, 1.0]), + [1., 2.5, 4.5, 6.5, 8.] + ) + np.testing.assert_array_equal( + _quantiles(range(1, 4 + 1), [1., 2., 1., 2], + [0.0, 0.25, 0.5, 0.75, 1.0]), + [1.0, 2.0, 2.5, 4.0, 4.0] + ) + np.testing.assert_array_equal( + _quantiles(range(1, 4 + 1), [2., 1., 1., 2.], + [0.0, 0.25, 0.5, 0.75, 1.0]), + [1.0, 1.0, 2.5, 4.0, 4.0] + ) + np.testing.assert_array_equal( + _quantiles(range(1, 4 + 1), [1., 1., 1., 1.], + [0.0, 0.25, 0.5, 0.75, 1.0]), + [1.0, 1.5, 2.5, 3.5, 4.0] + ) + np.testing.assert_array_equal( + _quantiles(range(1, 4 + 1), [1., 1., 1., 1.], + [0.0, 0.25, 0.5, 0.75, 1.0], interpolation="higher"), + [1, 2, 3, 4, 4] + ) + np.testing.assert_array_equal( + _quantiles(range(1, 4 + 1), [1., 1., 1., 1.], + [0.0, 0.25, 0.5, 0.75, 1.0], interpolation="lower"), + [1, 1, 2, 3, 4] + )