Skip to content

Commit 99cd396

Browse files
committed
util.stats supports weights for variance
1 parent 88a39a4 commit 99cd396

File tree

2 files changed

+14
-13
lines changed

2 files changed

+14
-13
lines changed

Orange/statistics/util.py

+11 -11
Original file line number | Diff line number | Diff line change
@@ -341,31 +341,31 @@ def stats(X, weights=None, compute_variance=False):
341341
"""
342342
is_numeric = np.issubdtype(X.dtype, np.number)
343343
is_sparse = sp.issparse(X)
344-
weighted = weights is not None and X.dtype != object
345-
weights = weights if weighted else None
344+
345+
if X.size and is_numeric:
346+
if compute_variance:
347+
means, vars = nan_mean_variance_axis(X, axis=0, weights=weights)
348+
else:
349+
means = nanmean(X, axis=0, weights=weights)
350+
vars = np.zeros(X.shape[1] if X.ndim == 2 else 1)
346351

347352
if X.size and is_numeric and not is_sparse:
348353
nans = np.isnan(X).sum(axis=0)
349354
return np.column_stack((
350355
np.nanmin(X, axis=0),
351356
np.nanmax(X, axis=0),
352-
nanmean(X, axis=0, weights=weights),
353-
nanvar(X, axis=0) if compute_variance else \
354-
np.zeros(X.shape[1] if X.ndim == 2 else 1),
357+
means,
358+
vars,
355359
nans,
356360
X.shape[0] - nans))
357361
elif is_sparse and X.size:
358-
if compute_variance and weighted:
359-
raise NotImplementedError
360-
361362
non_zero = np.bincount(X.nonzero()[1], minlength=X.shape[1])
362363
X = X.tocsc()
363364
return np.column_stack((
364365
nanmin(X, axis=0),
365366
nanmax(X, axis=0),
366-
nanmean(X, axis=0, weights=weights),
367-
nanvar(X, axis=0) if compute_variance else \
368-
np.zeros(X.shape[1] if X.ndim == 2 else 1),
367+
means,
368+
vars,
369369
X.shape[0] - non_zero,
370370
non_zero))
371371
else:

Orange/tests/test_statistics.py

+3 -2
Original file line number | Diff line number | Diff line change
@@ -155,8 +155,9 @@ def test_stats_weights_sparse(self):
155155
np.testing.assert_equal(stats(X, weights), [[0, 2, 1.5, 0, 1, 1],
156156
[1, 3, 2.5, 0, 0, 2]])
157157

158-
with self.assertRaises(NotImplementedError):
159-
stats(X, weights, compute_variance=True)
158+
np.testing.assert_equal(stats(X, weights, compute_variance=True),
159+
[[0, 2, 1.5, 0.75, 1, 1],
160+
[1, 3, 2.5, 0.75, 0, 2]])
160161

161162
def test_stats_non_numeric(self):
162163
X = np.array([

0 commit comments

Comments
 (0)