I have implemented versions of both of these (a `SkewMultivariateNormal` and a `SkewMultivariateStudentT`):
```python
from __future__ import annotations

from typing import Union, cast

import jax
from jax import lax
from jax import random
import jax.numpy as jnp
from jax.random import PRNGKey
from jax.scipy.linalg import cho_solve
from jax.scipy.special import gammaln
from numpy.typing import NDArray

from numpyro.distributions import (
    Chi2,
    Distribution,
    MultivariateNormal,
    MultivariateStudentT,
    Normal,
    constraints,
)
from numpyro.distributions.util import is_prng_key, promote_shapes, validate_sample


def delta(skewers_: NDArray[float], cov_: NDArray[float]):
    return (jnp.einsum("...ij,...j->...i", cov_, skewers_)) / jnp.sqrt(
        1 + jnp.einsum("...j,...jk,...k->...", skewers_, cov_, skewers_)[..., jnp.newaxis]
    )


# "Efficient computation of the distribution functions of Student's t, chi-squared
# and F to moderate accuracy": https://sci-hub.se/10.1080/00949658208810542
# We have to use an approximation because `betainc` doesn't have gradients defined,
# which means we can't use the official `StudentT.cdf`.
@jax.jit
def t_cdf_approx(df: Union[NDArray[float], float], t: Union[NDArray[float], float]):
    a = df - 1 / 2
    b = 48 * a**2
    # Add epsilon to avoid undefined gradient at 0
    z = jnp.sqrt(a * jnp.log(1 + t**2 / df) + 1e-24)
    u = (
        z
        + (z**3 + 3 * z) / b
        - (4 * z**7 + 33 * z**5 + 240 * z**3 + 855 * z) / (10 * b * (b + 0.8 * z**4 + 100))
    )
    return Normal(loc=0, scale=1).cdf(u * jnp.sign(t))


# "Regularized Multivariate Regression Models with Skew-t Error Distributions":
# https://epublications.marquette.edu/cgi/viewcontent.cgi?article=1225&context=mscs_fac
class SkewMultivariateNormal(Distribution):  # type: ignore # pylint: disable=too-many-instance-attributes
    arg_constraints = {
        "loc": constraints.real_vector,
        "scale_tril": constraints.lower_cholesky,
        "skewers": constraints.real_vector,
    }
    support = constraints.real_vector
    reparametrized_params = ["loc", "scale_tril", "skewers"]
    uv_norm = Normal(0.0, 1.0)

    @staticmethod
    def mk_big_mv_norm(loc: NDArray[float], skewers: NDArray[float], scale_tril: NDArray[float]):
        cov = jnp.einsum("...ij,...hj->...ih", scale_tril, scale_tril)
        delta_ = delta(skewers, cov)
        cov_star = jnp.block(
            [
                [jnp.ones(skewers.shape[:-1] + (1, 1)), jnp.expand_dims(delta_, axis=-2)],
                [jnp.expand_dims(delta_, axis=-1), cov],
            ]
        )
        return MultivariateNormal(
            loc=jnp.zeros(loc.shape[-1] + 1), scale_tril=jnp.linalg.cholesky(cov_star)
        )

    def __init__(
        self,
        scale_tril: NDArray[float],
        skewers: NDArray[float],
        loc: Union[NDArray[float], float] = 0,
        validate_args: None = None,
    ):
        if jnp.ndim(loc) == 0:
            (loc_,) = promote_shapes(loc, shape=(1,))
        else:
            loc_ = cast(NDArray[float], loc)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(loc_)[:-1], jnp.shape(scale_tril)[:-2], jnp.shape(skewers)[:-1]
        )
        (self.loc,) = promote_shapes(loc_, shape=batch_shape + loc_.shape[-1:])
        (self.skewers,) = promote_shapes(skewers, shape=batch_shape + skewers.shape[-1:])
        (self.scale_tril,) = promote_shapes(scale_tril, shape=batch_shape + scale_tril.shape[-2:])
        cov_batch = jnp.einsum("...ij,...hj->...ih", self.scale_tril, self.scale_tril)
        self._std_devs = jnp.sqrt(jnp.diagonal(cov_batch, axis1=-2, axis2=-1))

        # Used for sampling
        self._big_mv_norm = self.mk_big_mv_norm(
            # The blog post just uses unstandardized skewers here, but that leads to
            # a discrepancy between sampling and log_prob
            loc=self.loc,
            skewers=skewers / self._std_devs,
            scale_tril=scale_tril,
        )
        # Used for log_prob
        self._mv_norm = MultivariateNormal(loc_, scale_tril=scale_tril)

        skew_mean = jnp.sqrt(2 / jnp.pi) * delta(self.skewers / self._std_devs, cov_batch)
        self._mean = self.loc + skew_mean
        # The paper just uses `mean` here, but that's definitely not right because
        # it potentially leads to covariance matrices which are not positive semi-definite
        self._covariance = cov_batch - jnp.einsum("...i,...j->...ij", skew_mean, skew_mean)

        event_shape = jnp.shape(self.scale_tril)[-1:]
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    @validate_sample
    def log_prob(self, value: NDArray[float]) -> NDArray[float]:
        return (
            jnp.log(2)
            + self._mv_norm.log_prob(value)
            + jnp.log(
                self.uv_norm.cdf(
                    jnp.einsum("...k,...k->...", (value - self.loc) / self._std_devs, self.skewers)
                )
            )
        )

    @staticmethod
    def infer_shapes(loc: NDArray[float], scale_tril: NDArray[float], skewers: NDArray[float]):
        event_shape = (scale_tril[-1],)
        batch_shape = lax.broadcast_shapes(loc[:-1], scale_tril[:-2], skewers[:-1])
        return batch_shape, event_shape

    # https://gregorygundersen.com/blog/2020/12/29/multivariate-skew-normal/
    def sample(self, key: PRNGKey, sample_shape: tuple[int, ...] = ()) -> NDArray[float]:
        assert is_prng_key(key)
        x = self._big_mv_norm.sample(key, sample_shape=sample_shape)
        sign_bit, samples = x[..., 0, jnp.newaxis], x[..., 1:]
        return jnp.where(sign_bit <= 0, -1 * samples, samples) + self.loc

    @property
    def mean(self):
        return jnp.broadcast_to(self._mean, self.shape())

    @property
    def covariance_matrix(self):
        return self._covariance


# https://epublications.marquette.edu/cgi/viewcontent.cgi?article=1225&context=mscs_fac
class SkewMultivariateStudentT(Distribution):  # type: ignore # pylint: disable=too-many-instance-attributes
    arg_constraints = {
        "df": constraints.positive,
        "loc": constraints.real_vector,
        "scale_tril": constraints.lower_cholesky,
        "skewers": constraints.real_vector,
    }
    support = constraints.real_vector
    reparametrized_params = ["df", "loc", "scale_tril", "skewers"]

    def __init__(  # pylint: disable=too-many-arguments
        self,
        df: float,
        scale_tril: NDArray[float],
        skewers: NDArray[float],
        loc: Union[NDArray[float], float] = 0,
        validate_args: None = None,
    ):
        if jnp.ndim(loc) == 0:
            (loc_,) = promote_shapes(loc, shape=(1,))
        else:
            loc_ = cast(NDArray[float], loc)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(df), jnp.shape(loc_)[:-1], jnp.shape(scale_tril)[:-2], jnp.shape(skewers)[:-1]
        )
        (self.df,) = promote_shapes(df, shape=batch_shape)
        (self.loc,) = promote_shapes(loc_, shape=batch_shape + loc_.shape[-1:])
        (self.skewers,) = promote_shapes(skewers, shape=batch_shape + skewers.shape[-1:])
        (self.scale_tril,) = promote_shapes(scale_tril, shape=batch_shape + scale_tril.shape[-2:])
        self._width = scale_tril.shape[-1]

        # For log_prob
        self._mv_t = MultivariateStudentT(df=df, scale_tril=scale_tril, loc=loc)
        eye = jnp.broadcast_to(jnp.eye(self._width), shape=batch_shape + scale_tril.shape[-2:])
        prec_scale_tril = jnp.linalg.cholesky(cho_solve((self.scale_tril, True), eye))
        self.prec = jnp.einsum("...ij,...hj->...ih", prec_scale_tril, prec_scale_tril)
        cov_batch = jnp.einsum("...ij,...hj->...ih", self.scale_tril, self.scale_tril)
        self._std_devs = jnp.sqrt(jnp.diagonal(cov_batch, axis1=-2, axis2=-1))

        # For sample
        self._mv_skew_norm = SkewMultivariateNormal(
            scale_tril=scale_tril, loc=jnp.zeros(self._width), skewers=skewers
        )
        self._chi2 = Chi2(self.df)

        # Mean
        b = jnp.sqrt(self.df / jnp.pi) * jnp.exp(gammaln((self.df - 1) / 2) - gammaln(self.df / 2))
        skew_mean = b[..., jnp.newaxis] * delta(self.skewers / self._std_devs, cov_batch)
        self._mean = self.loc + skew_mean
        # The paper says we should multiply by the std devs, but that produces results that
        # disagree with `sample` and with `SkewMultivariateNormal`.
        # It also says we should use `_mean` instead of `skew_mean`, but that allows for
        # covariance matrices which are not positive semi-definite.
        self._covariance = jnp.array((self.df / (self.df - 2)))[
            ..., jnp.newaxis, jnp.newaxis
        ] * cov_batch - jnp.einsum("...i,...j->...ij", skew_mean, skew_mean)

        event_shape = jnp.shape(self.scale_tril)[-1:]
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    @validate_sample
    def log_prob(self, value: NDArray[float]) -> NDArray[float]:
        distance = value - self.loc
        Qy = jnp.einsum("...j,...jk,...k->...", distance, self.prec, distance)
        # Have to use an approximation because `betainc` doesn't have gradients defined,
        # which means we can't use the official `StudentT.cdf`
        skew = t_cdf_approx(
            self.df + self._width,
            jnp.einsum(
                "...k,...k->...",
                self.skewers,
                jnp.einsum(
                    "...i,...->...i",
                    distance / self._std_devs,
                    jnp.sqrt((self.df + self._width) / (Qy + self.df)),
                ),
            ),
        )
        return jnp.log(2) + self._mv_t.log_prob(value) + jnp.log(skew)

    @staticmethod
    def infer_shapes(df: float, loc: NDArray[float], scale_tril: NDArray[float], skewers: NDArray[float]):
        event_shape = (scale_tril[-1],)
        batch_shape = lax.broadcast_shapes(df, loc[:-1], scale_tril[:-2], skewers[:-1])
        return batch_shape, event_shape

    def sample(self, key: PRNGKey, sample_shape: tuple[int, ...] = ()) -> NDArray[float]:
        assert is_prng_key(key)
        key_normal, key_chi2 = random.split(key)
        normal = self._mv_skew_norm.sample(key_normal, sample_shape=sample_shape)
        chi = self._chi2.sample(key_chi2, sample_shape)
        return self.loc + jnp.einsum("...i,...->...i", normal, jnp.sqrt(self.df / chi))

    @property
    def mean(self):
        return jnp.broadcast_to(self._mean, self.shape())

    @property
    def covariance_matrix(self):
        return self._covariance
```

I also have some code testing them.
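For example, a spot check of `t_cdf_approx` against SciPy's exact Student-t CDF might look like this (illustrative values; `scipy` is only needed as the reference here):

```python
import numpy as np
from scipy import stats

df = 7.0
ts = np.linspace(-6.0, 6.0, 25)
approx = np.asarray(t_cdf_approx(df, ts))
exact = stats.t(df=df).cdf(ts)
# The 1982 paper promises only "moderate accuracy", so expect agreement
# to a few decimal places rather than machine precision.
print(np.max(np.abs(approx - exact)))
```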
- Is there interest in upstreaming these?
- Are there obvious simplifications?
- `SkewMultivariateStudentT` is notably slower than `MultivariateStudentT` in some circumstances. Are there any obvious performance improvements available? (A benchmark sketch follows below.)
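For the last point, this is the kind of minimal benchmark that could isolate the `log_prob` cost (the dimension, batch size, and parameter values are arbitrary; `block_until_ready` keeps JAX's async dispatch from skewing the numbers):

```python
import timeit

import jax
import jax.numpy as jnp
from jax import random
from numpyro.distributions import MultivariateStudentT

d = 10
scale_tril = jnp.eye(d)
skew_t = SkewMultivariateStudentT(df=5.0, scale_tril=scale_tril, skewers=jnp.ones(d))
plain_t = MultivariateStudentT(df=5.0, scale_tril=scale_tril)

x = random.normal(random.PRNGKey(0), (10_000, d))
skew_lp = jax.jit(skew_t.log_prob)
plain_lp = jax.jit(plain_t.log_prob)
# Warm up once so compilation time isn't counted
skew_lp(x).block_until_ready()
plain_lp(x).block_until_ready()

print(timeit.timeit(lambda: skew_lp(x).block_until_ready(), number=100))
print(timeit.timeit(lambda: plain_lp(x).block_until_ready(), number=100))
```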