Skip to content

Commit 244ff38

Browse files
authored
Add Bayesian VAR(2) example script #1658 (#1915)
* Add Bayesian VAR(2) example script * Added index with thumbnail * Fix Linting issues * Apply ruff formatting to examples/var2.py * Added header * Added Header * Added Header * Added header * shape fixed and added event * Added dim and event
1 parent d8bc7f7 commit 244ff38

File tree

3 files changed

+217
-0
lines changed

3 files changed

+217
-0
lines changed
128 KB
Loading

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ NumPyro documentation
8383
examples/zero_inflated_poisson
8484
examples/cvae
8585
tutorials/tbip
86+
examples/var2
8687

8788
.. nbgallery::
8889
:maxdepth: 1

examples/var2.py

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
# Copyright Contributors to the Pyro project.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
r"""
5+
Example: VAR(2) process
6+
=======================
7+
8+
In this example, we demonstrate how to implement and perform Bayesian inference for a
9+
Vector Autoregressive process of order 2 (VAR(2)). VAR models are widely used in
10+
time series analysis, especially for capturing the dynamics between multiple variables.
11+
12+
A VAR(2) process for a multivariate time series :math:`y_t` with :math:`K` variables is defined as:
13+
14+
.. math::
15+
16+
y_t = c + \Phi_1 y_{t-1} + \Phi_2 y_{t-2} + \epsilon_t
17+
18+
Here, :math:`c` is a constant vector, :math:`\Phi_1` and :math:`\Phi_2` are coefficient matrices for lag 1
19+
and lag 2, respectively, and :math:`\epsilon_t` is a Gaussian noise term with zero mean and a
20+
covariance matrix :math:`\Sigma`.
21+
22+
This example uses NumPyro's `scan` utility to efficiently model the temporal dependencies without
23+
explicit Python loops.
24+
25+
For more general time series forecasting techniques and examples, refer to the
26+
`Time Series Forecasting` tutorial:
27+
https://num.pyro.ai/en/stable/tutorials/time_series_forecasting.html#Forecasting
28+
29+
Reference
30+
---------
31+
For more information on Vector Autoregressive models, see:
32+
https://otexts.com/fpp2/VAR.html
33+
34+
.. image:: ../_static/img/examples/var2.png
35+
:align: center
36+
"""
37+
38+
import argparse
39+
import os
40+
import time
41+
42+
import matplotlib.pyplot as plt
43+
import numpy as np
44+
45+
from jax import random
46+
import jax.numpy as jnp
47+
48+
import numpyro
49+
from numpyro.contrib.control_flow import scan
50+
import numpyro.distributions as dist
51+
52+
53+
def var2_scan(y):
54+
T, K = y.shape # Number of time steps and number of variables
55+
56+
# Priors for constants and coefficients
57+
c = numpyro.sample("c", dist.Normal(0, 1).expand([K])) # Constants vector of size K
58+
Phi1 = numpyro.sample(
59+
"Phi1", dist.Normal(0, 1).expand([K, K]).to_event(2)
60+
) # Coefficients for lag 1
61+
Phi2 = numpyro.sample(
62+
"Phi2", dist.Normal(0, 1).expand([K, K]).to_event(2)
63+
) # Coefficients for lag 2
64+
65+
# Priors for error terms
66+
sigma = numpyro.sample("sigma", dist.HalfNormal(1.0).expand([K]).to_event(1))
67+
L_omega = numpyro.sample(
68+
"L_omega", dist.LKJCholesky(dimension=K, concentration=1.0)
69+
)
70+
L_Sigma = (
71+
sigma[..., None] * L_omega
72+
) # Alternative: jnp.einsum("...i,...ij->...ij", sigma, L_omega)
73+
74+
def transition(carry, t):
75+
y_prev1, y_prev2, y_obs = carry # Previous two observations and observed data
76+
m_t = c + jnp.dot(Phi1, y_prev1) + jnp.dot(Phi2, y_prev2) # Mean prediction
77+
# Conditioned on observed y
78+
y_t = numpyro.sample(
79+
f"y_{t}",
80+
dist.MultivariateNormal(loc=m_t, scale_tril=L_Sigma),
81+
obs=y_obs[t],
82+
)
83+
new_carry = (y_t, y_prev1, y_obs)
84+
return new_carry, m_t
85+
86+
# Initial carry: observations at time steps 1 and 0
87+
init_carry = (y[1], y[0], y[2:])
88+
89+
# Time indices starting from time step 2
90+
time_indices = jnp.arange(T - 2)
91+
92+
# Run the scan
93+
_, mu = scan(transition, init_carry, time_indices)
94+
95+
# Store the mean trajectory as a deterministic variable
96+
numpyro.deterministic("mu", mu)
97+
98+
99+
def generate_var2_data(T, K, c, Phi1, Phi2, sigma):
100+
"""
101+
Generate time series data from a VAR(2) process.
102+
Args:
103+
T (int): Number of time steps.
104+
K (int): Number of variables in the time series.
105+
c (array): Constants (shape: (K,)).
106+
Phi1 (array): Coefficients for lag 1 (shape: (K, K)).
107+
Phi2 (array): Coefficients for lag 2 (shape: (K, K)).
108+
sigma (array): Covariance matrix for the noise (shape: (K, K)).
109+
Returns:
110+
np.ndarray: Generated time series data (shape: (T, K)).
111+
"""
112+
# Initialize time series with random values
113+
y = np.zeros((T, K))
114+
y[:2] = np.random.multivariate_normal(mean=np.zeros(K), cov=sigma, size=2)
115+
116+
# Generate the time series
117+
for t in range(2, T):
118+
y[t] = (
119+
c
120+
+ Phi1 @ y[t - 1]
121+
+ Phi2 @ y[t - 2]
122+
+ np.random.multivariate_normal(mean=np.zeros(K), cov=sigma)
123+
)
124+
125+
return y
126+
127+
128+
def run_inference(model, args, rng_key, y):
129+
"""
130+
Run MCMC inference for the given model.
131+
Args:
132+
model: The probabilistic model to infer.
133+
args: Command-line arguments.
134+
rng_key: PRNG key for randomness.
135+
y: Observed time series data.
136+
"""
137+
start = time.time()
138+
sampler = numpyro.infer.NUTS(model)
139+
mcmc = numpyro.infer.MCMC(
140+
sampler,
141+
num_warmup=args.num_warmup,
142+
num_samples=args.num_samples,
143+
num_chains=args.num_chains,
144+
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
145+
)
146+
mcmc.run(rng_key, y=y)
147+
mcmc.print_summary()
148+
print("\nMCMC elapsed time:", time.time() - start)
149+
return mcmc.get_samples()
150+
151+
152+
def main(args):
153+
# Generate artificial dataset
154+
T = args.num_data # Number of time steps
155+
K = 2 # Number of variables
156+
c_true = jnp.array([0.5, -0.3]) # Constants
157+
Phi1_true = jnp.array([[0.7, 0.1], [0.2, 0.5]]) # Coefficients for lag 1
158+
Phi2_true = jnp.array([[0.2, -0.1], [-0.1, 0.2]]) # Coefficients for lag 2
159+
sigma_true = jnp.array([[0.1, 0.02], [0.02, 0.1]]) # Covariance matrix
160+
161+
rng_key = random.PRNGKey(0)
162+
y = generate_var2_data(T, K, c_true, Phi1_true, Phi2_true, sigma_true)
163+
164+
# Perform inference
165+
samples = run_inference(var2_scan, args, rng_key, y)
166+
167+
# Prediction
168+
mean_prediction = samples["mu"].mean(axis=0)
169+
lower_bound = jnp.percentile(samples["mu"], 2.5, axis=0) # 2.5th percentile
170+
upper_bound = jnp.percentile(samples["mu"], 97.5, axis=0) # 97.5th percentile
171+
172+
# Plot results
173+
fig, axes = plt.subplots(K, 1, figsize=(10, 6), sharex=True)
174+
time_steps = jnp.arange(T)
175+
176+
for i in range(K):
177+
# True values
178+
axes[i].plot(time_steps, y[:, i], label=f"True Variable {i + 1}", color="blue")
179+
# Posterior mean prediction
180+
axes[i].plot(
181+
time_steps[2:],
182+
mean_prediction[:, i],
183+
label=f"Predicted Mean Variable {i + 1}",
184+
color="orange",
185+
)
186+
# 95% confidence interval
187+
axes[i].fill_between(
188+
time_steps[2:],
189+
lower_bound[:, i],
190+
upper_bound[:, i],
191+
color="orange",
192+
alpha=0.2,
193+
label="95% CI",
194+
)
195+
axes[i].set_title(f"Variable {i + 1}")
196+
axes[i].legend()
197+
axes[i].grid(True)
198+
199+
plt.xlabel("Time Steps")
200+
plt.tight_layout()
201+
plt.savefig("var2.png")
202+
203+
204+
if __name__ == "__main__":
205+
parser = argparse.ArgumentParser(description="VAR(2) example")
206+
parser.add_argument("--num-data", nargs="?", default=100, type=int)
207+
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
208+
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
209+
parser.add_argument("--num-chains", nargs="?", default=1, type=int)
210+
parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
211+
args = parser.parse_args()
212+
213+
numpyro.set_platform(args.device)
214+
numpyro.set_host_device_count(args.num_chains)
215+
216+
main(args)

0 commit comments

Comments
 (0)