Skip to content

Commit c5eb7fc

Browse files
committed
OWBoxPlot: Fix quartiles computation
The code assumed that all quartiles would (or need to) be distinct. Replace the linear scan with a numpy implementation.
1 parent 8e0c7d8 commit c5eb7fc

File tree

2 files changed

+80
-19
lines changed

2 files changed

+80
-19
lines changed

Orange/widgets/visualize/owboxplot.py

+38-18
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,38 @@ def compute_scale(min_, max_):
4444
return first_val, step
4545

4646

47+
def _quantiles(a, freq, q, interpolation="midpoint"):
48+
"""
49+
Somewhat like np.quantiles, but with explicit sample frequencies.
50+
51+
* Only 'higher', 'lower' and 'midpoint' interpolation.
52+
* `a` MUST be sorted.
53+
"""
54+
a = np.asarray(a)
55+
freq = np.asarray(freq)
56+
assert a.size > 0 and a.size == freq.size
57+
cumdist = np.cumsum(freq)
58+
cumdist /= cumdist[-1]
59+
60+
if interpolation == "midpoint": # R quantile(..., type=2)
61+
left = np.searchsorted(cumdist, q, side="left")
62+
right = np.searchsorted(cumdist, q, side="right")
63+
# no mid point for the right most position
64+
np.clip(right, 0, a.size - 1, out=right)
65+
# right and left will be different only on the `q` boundaries
66+
# (excluding the right most sample)
67+
return (a[left] + a[right]) / 2
68+
elif interpolation == "higher": # R quantile(... type=1)
69+
right = np.searchsorted(cumdist, q, side="right")
70+
np.clip(right, 0, a.size - 1, out=right)
71+
return a[right]
72+
elif interpolation == "lower":
73+
left = np.searchsorted(cumdist, q, side="left")
74+
return a[left]
75+
else: # pragma: no covers
76+
raise ValueError("invalid interpolation: '{}'".format(interpolation))
77+
78+
4779
class BoxData:
4880
def __init__(self, dist, attr, group_val_index=None, group_var=None):
4981
self.dist = dist
@@ -55,24 +87,12 @@ def __init__(self, dist, attr, group_val_index=None, group_var=None):
5587
self.mean = float(np.sum(dist[0] * dist[1]) / n)
5688
self.var = float(np.sum(dist[1] * (dist[0] - self.mean) ** 2) / n)
5789
self.dev = math.sqrt(self.var)
58-
s = 0
59-
thresholds = [n / 4, n / 2, n / 4 * 3]
60-
thresh_i = 0
61-
q = []
62-
for i, e in enumerate(dist[1]):
63-
s += e
64-
if s >= thresholds[thresh_i]:
65-
if s == thresholds[thresh_i] and i + 1 < dist.shape[1]:
66-
q.append(float((dist[0, i] + dist[0, i + 1]) / 2))
67-
else:
68-
q.append(float(dist[0, i]))
69-
thresh_i += 1
70-
if thresh_i == 3:
71-
self.q25, self.median, self.q75 = q
72-
break
73-
else:
74-
self.q25 = self.q75 = None
75-
self.median = q[1] if len(q) == 2 else None
90+
a, freq = np.asarray(dist)
91+
q25, median, q75 = _quantiles(a, freq, [0.25, 0.5, 0.75])
92+
self.median = median
93+
# The code below omits the q25 or q75 in the plot when they are None
94+
self.q25 = None if q25 == median else q25
95+
self.q75 = None if q75 == median else q75
7696
self.conditions = [FilterContinuous(attr, FilterContinuous.Between,
7797
self.q25, self.q75)]
7898
if group_val_index is not None:

Orange/widgets/visualize/tests/test_owboxplot.py

+42-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
# Test methods with long descriptive names can omit docstrings
22
# pylint: disable=missing-docstring
3+
import unittest
34

5+
import numpy.testing
46
import numpy as np
57
from AnyQt.QtCore import QItemSelectionModel
68
from AnyQt.QtTest import QTest
79

810
from Orange.data import Table, ContinuousVariable, StringVariable, Domain
9-
from Orange.widgets.visualize.owboxplot import OWBoxPlot, FilterGraphicsRectItem
11+
from Orange.widgets.visualize.owboxplot import (
12+
OWBoxPlot, FilterGraphicsRectItem, _quantiles
13+
)
1014
from Orange.widgets.tests.base import WidgetTest, WidgetOutputsTestMixin
1115

1216

@@ -208,3 +212,40 @@ def __select_value(self, list, value):
208212
if m.data(idx) == value:
209213
list.selectionModel().setCurrentIndex(
210214
idx, QItemSelectionModel.ClearAndSelect)
215+
216+
217+
class TestUtils(unittest.TestCase):
218+
def test(self):
219+
np.testing.assert_array_equal(
220+
_quantiles(range(1, 8 + 1), [1.] * 8, [0.0, 0.25, 0.5, 0.75, 1.0]),
221+
[1., 2.5, 4.5, 6.5, 8.]
222+
)
223+
np.testing.assert_array_equal(
224+
_quantiles(range(1, 8 + 1), [1.] * 8, [0.0, 0.25, 0.5, 0.75, 1.0]),
225+
[1., 2.5, 4.5, 6.5, 8.]
226+
)
227+
np.testing.assert_array_equal(
228+
_quantiles(range(1, 4 + 1), [1., 2., 1., 2],
229+
[0.0, 0.25, 0.5, 0.75, 1.0]),
230+
[1.0, 2.0, 2.5, 4.0, 4.0]
231+
)
232+
np.testing.assert_array_equal(
233+
_quantiles(range(1, 4 + 1), [2., 1., 1., 2.],
234+
[0.0, 0.25, 0.5, 0.75, 1.0]),
235+
[1.0, 1.0, 2.5, 4.0, 4.0]
236+
)
237+
np.testing.assert_array_equal(
238+
_quantiles(range(1, 4 + 1), [1., 1., 1., 1.],
239+
[0.0, 0.25, 0.5, 0.75, 1.0]),
240+
[1.0, 1.5, 2.5, 3.5, 4.0]
241+
)
242+
np.testing.assert_array_equal(
243+
_quantiles(range(1, 4 + 1), [1., 1., 1., 1.],
244+
[0.0, 0.25, 0.5, 0.75, 1.0], interpolation="higher"),
245+
[1, 2, 3, 4, 4]
246+
)
247+
np.testing.assert_array_equal(
248+
_quantiles(range(1, 4 + 1), [1., 1., 1., 1.],
249+
[0.0, 0.25, 0.5, 0.75, 1.0], interpolation="lower"),
250+
[1, 1, 2, 3, 4]
251+
)

0 commit comments

Comments
 (0)