Skip to content

Commit

Permalink
FIX: sort named values
Browse files Browse the repository at this point in the history
  • Loading branch information
MarekWadinger committed Aug 2, 2023
1 parent 318e973 commit cdb56c9
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions river/proba/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class MultivariateGaussian(base.ContinuousDistribution):
Retrieving current state in nice format is simple
>>> p
𝒩(
μ=(0.416, 0.387, 0.518),
μ=(0.518, 0.387, 0.416),
σ^2=(
[ 0.076 0.020 -0.010]
[ 0.020 0.113 -0.053]
Expand All @@ -152,17 +152,17 @@ class MultivariateGaussian(base.ContinuousDistribution):
>>> p.n_samples
8.0
>>> p.mode # doctest: +ELLIPSIS
{'red': 0.415..., 'green': 0.386..., 'blue': 0.517...}
{'blue': 0.5179..., 'green': 0.3866..., 'red': 0.4158...}
To retrieve pdf and cdf
>>> p(x) # doctest: +ELLIPSIS
1.26921953490694...
0.97967086129734...
>>> p.cdf(x) # doctest: +ELLIPSIS
0.00787141517849810...
0.00509653891791713...
To sample data from distribution
>>> p.sample() # doctest: +ELLIPSIS
[0.203..., -0.0532..., 0.840...]
[0.3053..., -0.0532..., 0.7388...]
MultivariateGaussian works with `utils.Rolling`
Expand Down Expand Up @@ -199,12 +199,11 @@ class MultivariateGaussian(base.ContinuousDistribution):
... p_ = p_.update(x['blue'])
>>> p.sigma['blue']['blue'] == p_.sigma
True
""" # noqa: W291
"""

def __init__(self, seed=None):
super().__init__(seed)
self._var = covariance.EmpiricalCovariance(ddof=1)
self.feature_names_in_ = None

# TODO: add method _from_state to initialize model (for warm starting)

Expand All @@ -220,7 +219,7 @@ def mu(self):
"""The mean value of the distribution."""
return {
key1: values.mean.get()
for (key1, key2), values in self._var.matrix.items()
for (key1, key2), values in sorted(self._var.matrix.items())
if key1 == key2
}

Expand Down Expand Up @@ -256,19 +255,19 @@ def __repr__(self):
var_str = " [" + var_str.replace("\n", "]\n [") + "]"
return f"𝒩(\n μ=({mu_str}),\n σ^2=(\n{var_str}\n )\n)"

def update(self, x, w=1.0):
def update(self, x):
# TODO: add support for weigthed samples
self._var.update(x)
return self

def revert(self, x, w=1.0):
def revert(self, x):
# TODO: add support for weigthed samples
self._var.revert(x)
return self

def __call__(self, x):
"""PDF(x) method."""
x = list(x.values())
x = [x[i] for i in self.mu]
var = self.var
if var is not None:
try:
Expand All @@ -283,7 +282,7 @@ def __call__(self, x):
return 0.0 # pragma: no cover

def cdf(self, x):
x = list(x.values())
x = [x[i] for i in self.mu]
return multivariate_normal([*self.mu.values()], self.var, allow_singular=True).cdf(x)

def sample(self):
Expand Down

0 comments on commit cdb56c9

Please sign in to comment.