Skip to content

Commit

Permalink
Change aggregate stats type and added unique method. (#603)
Browse files Browse the repository at this point in the history
* added the `unique()` method to stats

* fix broken test

* Update test_core_stats_functions.py
  • Loading branch information
nwlandry authored Oct 20, 2024
1 parent 1fa9839 commit 139a2bc
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
19 changes: 19 additions & 0 deletions tests/stats/test_core_stats_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

### General functionality

import numpy as np
import pandas as pd
import pytest

Expand Down Expand Up @@ -422,13 +423,15 @@ def test_hypergraph_aggregates(edgelist1, edgelist2, edgelist8):
assert round(H.nodes.degree.mean(), 3) == 1.125
assert round(H.nodes.degree.std(), 3) == 0.331
assert round(H.nodes.degree.var(), 3) == 0.109
assert np.allclose(H.nodes.degree.unique(), np.array([1, 2]))

assert H.edges.order.max() == 2
assert H.edges.order.min() == 0
assert H.edges.order.sum() == 5
assert round(H.edges.order.mean(), 3) == 1.25
assert round(H.edges.order.std(), 3) == 0.829
assert round(H.edges.order.var(), 3) == 0.688
assert np.allclose(H.edges.order.unique(), np.array([0, 1, 2]))

H = xgi.Hypergraph(edgelist2)
assert H.nodes.degree.max() == 2
Expand All @@ -442,13 +445,17 @@ def test_hypergraph_aggregates(edgelist1, edgelist2, edgelist8):
assert round(H.nodes.degree.mean(), 3) == 1.167
assert round(H.nodes.degree.std(), 3) == 0.373
assert round(H.nodes.degree.var(), 3) == 0.139
assert np.allclose(H.nodes.degree.unique(), np.array([1, 2]))

assert H.edges.order.max() == 2
assert H.edges.order.min() == 1
assert H.edges.order.sum() == 4
assert round(H.edges.order.mean(), 3) == 1.333
assert round(H.edges.order.std(), 3) == 0.471
assert round(H.edges.order.var(), 3) == 0.222
assert np.allclose(H.edges.order.unique(), np.array([1, 2]))
assert len(H.edges.order.unique(return_counts=True)) == 2
assert np.allclose(H.edges.order.unique(return_counts=True)[1], np.array([2, 1]))

H = xgi.Hypergraph(edgelist8)
assert H.nodes.degree.max() == 6
Expand Down Expand Up @@ -973,3 +980,15 @@ def test_multi_with_attrs(hyperwithattrs):
5: [2, "blue"],
}
assert multi.asdict(list) == d


def test_aggregate_stats_types(edgelist1):
H = xgi.Hypergraph(edgelist1)
assert isinstance(H.nodes.degree.max(), int)
assert isinstance(H.nodes.degree.min(), int)
assert isinstance(H.nodes.degree.median(), float)
assert isinstance(H.nodes.degree.mean(), float)
assert isinstance(H.nodes.degree.sum(), int)
assert isinstance(H.nodes.degree.std(), float)
assert isinstance(H.nodes.degree.var(), float)
assert isinstance(H.nodes.degree.moment(), float)
21 changes: 13 additions & 8 deletions xgi/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,23 +203,23 @@ def ashist(self, bins=10, bin_edges=False, density=False, log_binning=False):

def max(self):
"""The maximum value of this stat."""
return self.asnumpy().max(axis=0)
return self.asnumpy().max(axis=0).item()

def min(self):
"""The minimum value of this stat."""
return self.asnumpy().min(axis=0)
return self.asnumpy().min(axis=0).item()

def sum(self):
"""The sum of this stat."""
return self.asnumpy().sum(axis=0)
return self.asnumpy().sum(axis=0).item()

def mean(self):
"""The arithmetic mean of this stat."""
return self.asnumpy().mean(axis=0)
return self.asnumpy().mean(axis=0).item()

def median(self):
"""The median of this stat."""
return np.median(self.asnumpy(), axis=0)
return np.median(self.asnumpy(), axis=0).item()

def std(self):
"""The standard deviation of this stat.
Expand All @@ -231,7 +231,7 @@ def std(self):
See https://www.allendowney.com/blog/2024/06/08/which-standard-deviation/
for more details.
"""
return self.asnumpy().std(axis=0)
return self.asnumpy().std(axis=0).item()

def var(self):
"""The variance of this stat.
Expand All @@ -243,7 +243,7 @@ def var(self):
See https://www.allendowney.com/blog/2024/06/08/which-standard-deviation/
for more details.
"""
return self.asnumpy().var(axis=0)
return self.asnumpy().var(axis=0).item()

def moment(self, order=2, center=False):
"""The statistical moments of this stat.
Expand All @@ -257,7 +257,9 @@ def moment(self, order=2, center=False):
"""
arr = self.asnumpy()
return spmoment(arr, moment=order) if center else np.mean(arr**order)
return (
spmoment(arr, moment=order) if center else np.mean(arr**order).item()
)

def argmin(self):
"""The ID corresponding to the minimum of the stat
Expand Down Expand Up @@ -306,6 +308,9 @@ def argsort(self, reverse=False):
d = self.asdict()
return sorted(d, key=d.get, reverse=reverse)

def unique(self, return_counts=False):
return np.unique(self.asnumpy(), return_counts=return_counts)


class NodeStat(IDStat):
"""An arbitrary node-quantity mapping.
Expand Down

0 comments on commit 139a2bc

Please sign in to comment.