Skip to content

Commit 753553f

Browse files
authored
Add HSGP contribution module (#1794)
* hsgp_init * add licesnse * simplyfy function names and add author reference * feedback part 3 * feedback part 4 * fix name
1 parent 5eb134d commit 753553f

File tree

10 files changed

+1977
-0
lines changed

10 files changed

+1977
-0
lines changed

docs/source/contrib.rst

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,205 @@ Stochastic Support
9797
:undoc-members:
9898
:show-inheritance:
9999
:member-order: bysource
100+
101+
102+
Hilbert Space Gaussian Processes Approximation
103+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
104+
105+
This module contains helper functions for use in the Hilbert Space Gaussian Process (HSGP) approximation method
106+
described in [1] and [2].
107+
108+
.. warning::
109+
This module is experimental. Currently, it only supports Gaussian processes with one-dimensional inputs.
110+
111+
112+
**Why do we need an approximation?**
113+
114+
Gaussian processes do not scale well with the number of data points. Recall we had to invert the kernel matrix!
115+
The computational complexity of the Gaussian process model is :math:`\mathcal{O}(n^3)`, where :math:`n` is the number of data
116+
points. The HSGP approximation method is a way to reduce the computational complexity of the Gaussian process model
117+
to :math:`\mathcal{O}(mn + m)`, where :math:`m` is the number of basis functions used in the approximation.
118+
119+
**Approximation Strategy Steps:**
120+
121+
We strongly recommend reading [1] and [2] for a detailed explanation of the approximation method. In [3] you can find
122+
a practical approach using NumPyro and PyMC.
123+
124+
Here we provide the main steps and ingredients of the approximation method:
125+
126+
1. Each stationary kernel :math:`k` has an associated spectral density :math:`S(\omega)`. There are closed formulas for the most common kernels. These formulas depend on the hyperparameters of the kernel (e.g. amplitudes and length scales).
127+
2. We can approximate the spectral density :math:`S(\omega)` as a polynomial series in :math:`||\omega||`. We call :math:`\omega` the frequency.
128+
3. We can interpret these polynomial terms as "powers" of the Laplacian operator. The key observation is that the Fourier transform of the Laplacian operator is :math:`||\omega||^2`.
129+
4. Next, we impose Dirichlet boundary conditions on the Laplacian operator which makes it self-adjoint and with discrete spectrum.
130+
5. We identify the expansion in (2) with the sum of powers of the Laplacian operator in the eigenbasis of (4).
131+
132+
For the one dimensional case the approximation formula, in the non-centered parameterization, is:
133+
134+
.. math::
135+
136+
f(x) \approx \sum_{j = 1}^{m}
137+
\overbrace{\color{red}{\left(S(\sqrt{\lambda_j})\right)^{1/2}}}^{\text{all hyperparameters are here!}}
138+
\times
139+
\underbrace{\color{blue}{\phi_{j}(x)}}_{\text{easy to compute!}}
140+
\times
141+
\overbrace{\color{green}{\beta_{j}}}^{\sim \: \text{Normal}(0,1)}
142+
143+
where :math:`\lambda_j` are the eigenvalues of the Laplacian operator, :math:`\phi_{j}(x)` are the eigenfunctions of the
144+
Laplacian operator, and :math:`\beta_{j}` are the coefficients of the expansion (see Eq. (8) in [2]). We expect this
145+
to be a good approximation for a finite number of :math:`m` terms in the series as long as the inputs values :math:`x`
146+
are not too close to the boundaries `ell` amd `-ell`.
147+
148+
.. note::
149+
Even though the periodic kernel is not stationary, one can still adapt and find a similar approximation formula.
150+
See Appendix B in [2] for more details.
151+
152+
**Example:**
153+
154+
Here is an example of how to use the HSGP approximation method with NumPyro. We will use the squared exponential kernel.
155+
Other kernels can be used similarly.
156+
157+
.. code-block:: python
158+
159+
>>> from jax import random
160+
>>> import jax.numpy as jnp
161+
162+
>>> import numpyro
163+
>>> from numpyro.contrib.hsgp.approximation import hsgp_squared_exponential
164+
>>> import numpyro.distributions as dist
165+
>>> from numpyro.infer import MCMC, NUTS
166+
167+
168+
>>> def generate_synthetic_data(rng_key, start, stop: float, num, scale):
169+
... """Generate synthetic data."""
170+
... x = jnp.linspace(start=start, stop=stop, num=num)
171+
... y = jnp.sin(4 * jnp.pi * x) + jnp.sin(7 * jnp.pi * x)
172+
... y_obs = y + scale * random.normal(rng_key, shape=(num,))
173+
... return x, y_obs
174+
175+
176+
>>> rng_key = random.PRNGKey(seed=42)
177+
>>> rng_key, rng_subkey = random.split(rng_key)
178+
>>> x, y_obs = generate_synthetic_data(
179+
... rng_key=rng_subkey, start=0, stop=1, num=80, scale=0.3
180+
>>> )
181+
182+
183+
>>> def model(x, ell, m, non_centered, y=None):
184+
... # --- Priors ---
185+
... alpha = numpyro.sample("alpha", dist.InverseGamma(concentration=12, rate=10))
186+
... length = numpyro.sample("length", dist.InverseGamma(concentration=6, rate=1))
187+
... noise = numpyro.sample("noise", dist.InverseGamma(concentration=12, rate=10))
188+
... # --- Parametrization ---
189+
... f = hsgp_squared_exponential(
190+
... x=x, alpha=alpha, length=length, ell=ell, m=m, non_centered=non_centered
191+
... )
192+
... # --- Likelihood ---
193+
... with numpyro.plate("data", x.shape[0]):
194+
... numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y)
195+
196+
197+
>>> sampler = NUTS(model)
198+
>>> mcmc = MCMC(sampler=sampler, num_warmup=500, num_samples=1_000, num_chains=2)
199+
200+
>>> rng_key, rng_subkey = random.split(rng_key)
201+
202+
>>> ell = 1.3
203+
>>> m = 20
204+
>>> non_centered = True
205+
206+
>>> mcmc.run(rng_subkey, x, ell, m, non_centered, y_obs)
207+
208+
>>> mcmc.print_summary()
209+
210+
mean std median 5.0% 95.0% n_eff r_hat
211+
alpha 1.24 0.34 1.18 0.72 1.74 1804.01 1.00
212+
beta[0] -0.10 0.66 -0.10 -1.24 0.92 1819.91 1.00
213+
beta[1] 0.00 0.71 -0.01 -1.09 1.26 1872.82 1.00
214+
beta[2] -0.05 0.69 -0.03 -1.09 1.16 2105.88 1.00
215+
beta[3] 0.25 0.74 0.26 -0.98 1.42 2281.30 1.00
216+
beta[4] -0.17 0.69 -0.17 -1.21 1.00 2551.39 1.00
217+
beta[5] 0.09 0.75 0.10 -1.13 1.30 3232.13 1.00
218+
beta[6] -0.49 0.75 -0.49 -1.65 0.82 3042.31 1.00
219+
beta[7] 0.42 0.75 0.44 -0.78 1.65 2885.42 1.00
220+
beta[8] 0.69 0.71 0.71 -0.48 1.82 2811.68 1.00
221+
beta[9] -1.43 0.75 -1.40 -2.63 -0.21 2858.68 1.00
222+
beta[10] 0.33 0.71 0.33 -0.77 1.51 2198.65 1.00
223+
beta[11] 1.09 0.73 1.11 -0.23 2.18 2765.99 1.00
224+
beta[12] -0.91 0.72 -0.91 -2.06 0.31 2586.53 1.00
225+
beta[13] 0.05 0.70 0.04 -1.16 1.12 2569.59 1.00
226+
beta[14] -0.44 0.71 -0.44 -1.58 0.73 2626.09 1.00
227+
beta[15] 0.69 0.73 0.70 -0.45 1.88 2626.32 1.00
228+
beta[16] 0.98 0.74 0.98 -0.15 2.28 2282.86 1.00
229+
beta[17] -2.54 0.77 -2.52 -3.82 -1.29 3347.56 1.00
230+
beta[18] 1.35 0.66 1.35 0.30 2.46 2638.17 1.00
231+
beta[19] 1.10 0.54 1.09 0.25 2.01 2428.37 1.00
232+
length 0.07 0.01 0.07 0.06 0.09 2321.67 1.00
233+
noise 0.33 0.03 0.33 0.29 0.38 2472.83 1.00
234+
235+
Number of divergences: 0
236+
237+
238+
.. note::
239+
Additional examples with code can be found in [3], [4] and [5].
240+
241+
**References:**
242+
243+
1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression.
244+
Stat Comput 30, 419-446 (2020).
245+
246+
2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
247+
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).
248+
249+
3. `Orduz, J., A Conceptual and Practical Introduction to Hilbert Space GPs Approximation Methods <https://juanitorduz.github.io/hsgp_intro>`_.
250+
251+
4. `Example: Hilbert space approximation for Gaussian processes <https://num.pyro.ai/en/stable/examples/hsgp.html>`_.
252+
253+
5. `Gelman, Vehtari, Simpson, et al., Bayesian workflow book - Birthdays <https://avehtari.github.io/casestudies/Birthdays/birthdays.html>`_.
254+
255+
.. note::
256+
The code of this module is based on the code of the example
257+
`Example: Hilbert space approximation for Gaussian processes <https://num.pyro.ai/en/stable/examples/hsgp.html>`_ by `Omar Sosa Rodríguez <https://github.com/omarfsosa>`_.
258+
259+
sqrt_eigenvalues
260+
----------------
261+
.. autofunction:: numpyro.contrib.hsgp.laplacian.sqrt_eigenvalues
262+
263+
eigenfunctions
264+
--------------
265+
.. autofunction:: numpyro.contrib.hsgp.laplacian.eigenfunctions
266+
267+
eigenfunctions_periodic
268+
-----------------------
269+
.. autofunction:: numpyro.contrib.hsgp.laplacian.eigenfunctions_periodic
270+
271+
spectral_density_squared_exponential
272+
------------------------------------
273+
.. autofunction:: numpyro.contrib.hsgp.spectral_densities.spectral_density_squared_exponential
274+
275+
spectral_density_matern
276+
-----------------------
277+
.. autofunction:: numpyro.contrib.hsgp.spectral_densities.spectral_density_matern
278+
279+
diag_spectral_density_squared_exponential
280+
-----------------------------------------
281+
.. autofunction:: numpyro.contrib.hsgp.spectral_densities.diag_spectral_density_squared_exponential
282+
283+
diag_spectral_density_matern
284+
----------------------------
285+
.. autofunction:: numpyro.contrib.hsgp.spectral_densities.diag_spectral_density_matern
286+
287+
diag_spectral_density_periodic
288+
------------------------------
289+
.. autofunction:: numpyro.contrib.hsgp.spectral_densities.diag_spectral_density_periodic
290+
291+
hsgp_squared_exponential
292+
------------------------
293+
.. autofunction:: numpyro.contrib.hsgp.approximation.hsgp_squared_exponential
294+
295+
hsgp_matern
296+
-----------
297+
.. autofunction:: numpyro.contrib.hsgp.approximation.hsgp_matern
298+
299+
hsgp_periodic_non_centered
300+
--------------------------
301+
.. autofunction:: numpyro.contrib.hsgp.approximation.hsgp_periodic_non_centered

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ NumPyro documentation
3737
tutorials/bad_posterior_geometry
3838
tutorials/truncated_distributions
3939
tutorials/censoring
40+
tutorials/hsgp_example
4041

4142
.. nbgallery::
4243
:maxdepth: 1

notebooks/source/hsgp_example.ipynb

Lines changed: 940 additions & 0 deletions
Large diffs are not rendered by default.

numpyro/contrib/hsgp/__init__.py

Whitespace-only changes.
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# Copyright Contributors to the Pyro project.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""
5+
This module contains the low-rank approximation functions of the Hilbert space Gaussian process.
6+
"""
7+
8+
from jaxlib.xla_extension import ArrayImpl
9+
10+
import jax.numpy as jnp
11+
12+
import numpyro
13+
from numpyro.contrib.hsgp.laplacian import eigenfunctions, eigenfunctions_periodic
14+
from numpyro.contrib.hsgp.spectral_densities import (
15+
diag_spectral_density_matern,
16+
diag_spectral_density_periodic,
17+
diag_spectral_density_squared_exponential,
18+
)
19+
import numpyro.distributions as dist
20+
21+
22+
def _non_centered_approximation(phi: ArrayImpl, spd: ArrayImpl, m: int) -> ArrayImpl:
23+
with numpyro.plate("basis", m):
24+
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=1.0))
25+
26+
return phi @ (spd * beta)
27+
28+
29+
def _centered_approximation(phi: ArrayImpl, spd: ArrayImpl, m: int) -> ArrayImpl:
30+
with numpyro.plate("basis", m):
31+
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd))
32+
33+
return phi @ beta
34+
35+
36+
def linear_approximation(
37+
phi: ArrayImpl, spd: ArrayImpl, m: int, non_centered: bool = True
38+
) -> ArrayImpl:
39+
"""
40+
Linear approximation formula of the Hilbert space Gaussian process.
41+
42+
See Eq. (8) in [1].
43+
44+
**References:**
45+
1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
46+
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).
47+
48+
:param ArrayImpl phi: laplacian eigenfunctions
49+
:param ArrayImpl spd: square root of the diagonal of the spectral density evaluated at square
50+
root of the first `m` eigenvalues.
51+
:param int m: number of eigenfunctions in the approximation
52+
:param bool non_centered: whether to use a non-centered parameterization
53+
:return: The low-rank approximation linear model
54+
:rtype: ArrayImpl
55+
"""
56+
if non_centered:
57+
return _non_centered_approximation(phi, spd, m)
58+
return _centered_approximation(phi, spd, m)
59+
60+
61+
def hsgp_squared_exponential(
62+
x: ArrayImpl,
63+
alpha: float,
64+
length: float,
65+
ell: float,
66+
m: int,
67+
non_centered: bool = True,
68+
) -> ArrayImpl:
69+
"""
70+
Hilbert space Gaussian process approximation using the squared exponential kernel.
71+
72+
The main idea of the approach is to combine the associated spectral density of the
73+
squared exponential kernel and the spectrum of the Dirichlet Laplacian operator to
74+
obtain a low-rank approximation of the Gram matrix. For more details see [1, 2].
75+
76+
**References:**
77+
78+
1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression.
79+
Stat Comput 30, 419-446 (2020).
80+
81+
2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
82+
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).
83+
84+
:param ArrayImpl x: input data
85+
:param float alpha: amplitude of the squared exponential kernel
86+
:param float length: length scale of the squared exponential kernel
87+
:param float ell: positive value that parametrizes the length of the one-dimensional box so that the input data
88+
lies in the interval [-ell, ell]. We expect the approximation to be valid within this interval
89+
:param int m: number of eigenvalues to compute and include in the approximation
90+
:param bool non_centered: whether to use a non-centered parameterization. By default, it is set to True
91+
:return: the low-rank approximation linear model
92+
:rtype: ArrayImpl
93+
"""
94+
phi = eigenfunctions(x=x, ell=ell, m=m)
95+
spd = jnp.sqrt(
96+
diag_spectral_density_squared_exponential(
97+
alpha=alpha, length=length, ell=ell, m=m
98+
)
99+
)
100+
return linear_approximation(phi=phi, spd=spd, m=m, non_centered=non_centered)
101+
102+
103+
def hsgp_matern(
104+
x: ArrayImpl,
105+
nu: float,
106+
alpha: float,
107+
length: float,
108+
ell: float,
109+
m: int,
110+
non_centered: bool = True,
111+
):
112+
"""
113+
Hilbert space Gaussian process approximation using the Matérn kernel.
114+
115+
The main idea of the approach is to combine the associated spectral density of the
116+
Matérn kernel kernel and the spectrum of the Dirichlet Laplacian operator to obtain
117+
a low-rank approximation of the Gram matrix. For more details see [1, 2].
118+
119+
**References:**
120+
121+
1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression.
122+
Stat Comput 30, 419-446 (2020).
123+
124+
2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
125+
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).
126+
127+
:param ArrayImpl x: input data
128+
:param float nu: smoothness parameter
129+
:param float alpha: amplitude of the squared exponential kernel
130+
:param float length: length scale of the squared exponential kernel
131+
:param float ell: positive value that parametrizes the length of the one-dimensional box so that the input data
132+
lies in the interval [-ell, ell]. We expect the approximation to be valid within this interval.
133+
:param int m: number of eigenvalues to compute and include in the approximation
134+
:param bool non_centered: whether to use a non-centered parameterization. By default, it is set to True.
135+
:return: the low-rank approximation linear model
136+
:rtype: ArrayImpl
137+
"""
138+
phi = eigenfunctions(x=x, ell=ell, m=m)
139+
spd = jnp.sqrt(
140+
diag_spectral_density_matern(nu=nu, alpha=alpha, length=length, ell=ell, m=m)
141+
)
142+
return linear_approximation(phi=phi, spd=spd, m=m, non_centered=non_centered)
143+
144+
145+
def hsgp_periodic_non_centered(
146+
x: ArrayImpl, alpha: float, length: float, w0: float, m: int
147+
) -> ArrayImpl:
148+
"""
149+
Low rank approximation for the periodic squared exponential kernel in the non-centered parametrization.
150+
151+
See Appendix B in [1].
152+
153+
**References:**
154+
155+
1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
156+
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).
157+
158+
:param ArrayImpl x: input data
159+
:param float alpha: amplitude
160+
:param float length: length scale
161+
:param float w0: frequency of the periodic kernel
162+
:param int m: number of eigenvalues to compute and include in the approximation
163+
:return: the low-rank approximation linear model
164+
:rtype: ArrayImpl
165+
"""
166+
q2 = diag_spectral_density_periodic(alpha=alpha, length=length, m=m)
167+
cosines, sines = eigenfunctions_periodic(x=x, w0=w0, m=m)
168+
169+
with numpyro.plate("cos_basis", m):
170+
beta_cos = numpyro.sample("beta_cos", dist.Normal(0, 1))
171+
172+
with numpyro.plate("sin_basis", m - 1):
173+
beta_sin = numpyro.sample("beta_sin", dist.Normal(0, 1))
174+
175+
# The first eigenfunction for the sine component
176+
# is zero, so the first parameter wouldn't contribute to the approximation.
177+
# We set it to zero to identify the model and avoid divergences.
178+
zero = jnp.array([0.0])
179+
beta_sin = jnp.concatenate((zero, beta_sin))
180+
181+
return cosines @ (q2 * beta_cos) + sines @ (q2 * beta_sin)

0 commit comments

Comments
 (0)