# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

r"""
Example: VAR(2) process
=======================

In this example, we demonstrate how to implement and perform Bayesian inference for a
Vector Autoregressive process of order 2 (VAR(2)). VAR models are widely used in
time series analysis, especially for capturing the dynamics between multiple variables.

A VAR(2) process for a multivariate time series :math:`y_t` with :math:`K` variables is defined as:

.. math::

    y_t = c + \Phi_1 y_{t-1} + \Phi_2 y_{t-2} + \epsilon_t

Here, :math:`c` is a constant vector, :math:`\Phi_1` and :math:`\Phi_2` are coefficient matrices for lag 1
and lag 2, respectively, and :math:`\epsilon_t` is a Gaussian noise term with zero mean and
covariance matrix :math:`\Sigma`.
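
For instance, with :math:`K = 2` variables (as in the simulated data below), the first
component of this recursion expands to

.. math::

    y_{1,t} = c_1 + (\Phi_1)_{11} y_{1,t-1} + (\Phi_1)_{12} y_{2,t-1} + (\Phi_2)_{11} y_{1,t-2} + (\Phi_2)_{12} y_{2,t-2} + \epsilon_{1,t}
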
This example uses NumPyro's `scan` utility to efficiently model the temporal dependencies without
explicit Python loops.

For more general time series forecasting techniques and examples, refer to the
`Time Series Forecasting` tutorial:
https://num.pyro.ai/en/stable/tutorials/time_series_forecasting.html#Forecasting

Reference
---------
For more information on Vector Autoregressive models, see:
https://otexts.com/fpp2/VAR.html

.. image:: ../_static/img/examples/var2.png
    :align: center
"""

import argparse
import os
import time

import matplotlib.pyplot as plt
import numpy as np

from jax import random
import jax.numpy as jnp

import numpyro
from numpyro.contrib.control_flow import scan
import numpyro.distributions as dist


def var2_scan(y):
    T, K = y.shape  # Number of time steps and number of variables

    # Priors for constants and coefficients
    c = numpyro.sample("c", dist.Normal(0, 1).expand([K]))  # Constants vector of size K
    Phi1 = numpyro.sample(
        "Phi1", dist.Normal(0, 1).expand([K, K]).to_event(2)
    )  # Coefficients for lag 1
    Phi2 = numpyro.sample(
        "Phi2", dist.Normal(0, 1).expand([K, K]).to_event(2)
    )  # Coefficients for lag 2

    # Priors for error terms
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0).expand([K]).to_event(1))
    L_omega = numpyro.sample(
        "L_omega", dist.LKJCholesky(dimension=K, concentration=1.0)
    )
    L_Sigma = (
        sigma[..., None] * L_omega
    )  # Alternative: jnp.einsum("...i,...ij->...ij", sigma, L_omega)
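    # L_Sigma is the Cholesky factor of the full covariance matrix:
    # L_Sigma @ L_Sigma.T = diag(sigma) @ Omega @ diag(sigma) = Sigma,
    # where Omega = L_omega @ L_omega.T is the correlation matrix drawn from the LKJ prior.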

    def transition(carry, t):
        y_prev1, y_prev2, y_obs = carry  # Previous two observations and observed data
        m_t = c + jnp.dot(Phi1, y_prev1) + jnp.dot(Phi2, y_prev2)  # Mean prediction
        # Conditioned on observed y
        y_t = numpyro.sample(
            f"y_{t}",
            dist.MultivariateNormal(loc=m_t, scale_tril=L_Sigma),
            obs=y_obs[t],
        )
        new_carry = (y_t, y_prev1, y_obs)
        return new_carry, m_t

    # Initial carry: the two most recent observations (y[1], y[0]) plus the
    # observed data y[2:] that the transition conditions on
    init_carry = (y[1], y[0], y[2:])

    # Indices 0, ..., T - 3 into y[2:], i.e. absolute time steps 2, ..., T - 1
    time_indices = jnp.arange(T - 2)

    # Run the scan
    _, mu = scan(transition, init_carry, time_indices)
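    # Conceptually, the `scan` call above corresponds to the pure-Python loop below
    # (shown only for illustration; `scan` additionally handles the sample sites and
    # lets JAX trace and compile the whole loop once):
    #
    #     carry = init_carry
    #     mus = []
    #     for t in range(T - 2):
    #         carry, m_t = transition(carry, t)
    #         mus.append(m_t)
    #     mu = jnp.stack(mus)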

    # Store the mean trajectory as a deterministic variable
    numpyro.deterministic("mu", mu)


def generate_var2_data(T, K, c, Phi1, Phi2, sigma):
    """
    Generate time series data from a VAR(2) process.

    Args:
        T (int): Number of time steps.
        K (int): Number of variables in the time series.
        c (array): Constants (shape: (K,)).
        Phi1 (array): Coefficients for lag 1 (shape: (K, K)).
        Phi2 (array): Coefficients for lag 2 (shape: (K, K)).
        sigma (array): Covariance matrix for the noise (shape: (K, K)).

    Returns:
        np.ndarray: Generated time series data (shape: (T, K)).
    """
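    # For example, with T=100 (the default --num-data) and K=2, as used in main() below,
    # this returns a (100, 2) array whose first two rows are drawn from N(0, sigma).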
    # Initialize the series; the first two observations are drawn from N(0, sigma)
    y = np.zeros((T, K))
    y[:2] = np.random.multivariate_normal(mean=np.zeros(K), cov=sigma, size=2)

    # Generate the remaining time steps from the VAR(2) recursion
    for t in range(2, T):
        y[t] = (
            c
            + Phi1 @ y[t - 1]
            + Phi2 @ y[t - 2]
            + np.random.multivariate_normal(mean=np.zeros(K), cov=sigma)
        )

    return y


def run_inference(model, args, rng_key, y):
    """
    Run MCMC inference for the given model.

    Args:
        model: The probabilistic model to infer.
        args: Command-line arguments.
        rng_key: PRNG key for randomness.
        y: Observed time series data.

    Returns:
        dict: Posterior samples keyed by site name.
    """
    start = time.time()
    sampler = numpyro.infer.NUTS(model)
    mcmc = numpyro.infer.MCMC(
        sampler,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
    )
    mcmc.run(rng_key, y=y)
    mcmc.print_summary()
    print("\nMCMC elapsed time:", time.time() - start)
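    # get_samples() concatenates draws from all chains; pass group_by_chain=True
    # if per-chain samples are needed instead.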
    return mcmc.get_samples()


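# Optional: a minimal forecasting sketch (not used by main() below). It rolls the VAR(2)
# recursion forward at the posterior means of "c", "Phi1", and "Phi2", ignoring both
# observation noise and parameter uncertainty, so it only gives a rough point forecast
# of the mean path. The helper name and signature are illustrative, not part of the model.
def forecast_var2(samples, y, num_steps):
    c = samples["c"].mean(axis=0)
    Phi1 = samples["Phi1"].mean(axis=0)
    Phi2 = samples["Phi2"].mean(axis=0)
    y_prev1, y_prev2 = y[-1], y[-2]
    forecasts = []
    for _ in range(num_steps):
        y_next = c + Phi1 @ y_prev1 + Phi2 @ y_prev2
        forecasts.append(y_next)
        y_prev1, y_prev2 = y_next, y_prev1
    return jnp.stack(forecasts)

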
def main(args):
    # Generate artificial dataset
    T = args.num_data  # Number of time steps
    K = 2  # Number of variables
    c_true = jnp.array([0.5, -0.3])  # Constants
    Phi1_true = jnp.array([[0.7, 0.1], [0.2, 0.5]])  # Coefficients for lag 1
    Phi2_true = jnp.array([[0.2, -0.1], [-0.1, 0.2]])  # Coefficients for lag 2
    sigma_true = jnp.array([[0.1, 0.02], [0.02, 0.1]])  # Covariance matrix

    rng_key = random.PRNGKey(0)
    y = generate_var2_data(T, K, c_true, Phi1_true, Phi2_true, sigma_true)

    # Perform inference
    samples = run_inference(var2_scan, args, rng_key, y)

    # Prediction
    mean_prediction = samples["mu"].mean(axis=0)
    lower_bound = jnp.percentile(samples["mu"], 2.5, axis=0)  # 2.5th percentile
    upper_bound = jnp.percentile(samples["mu"], 97.5, axis=0)  # 97.5th percentile
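    # Note: these are pointwise 95% credible intervals for the mean trajectory mu,
    # not predictive intervals for y itself (which would additionally include the
    # observation noise with covariance Sigma).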

    # Plot results
    fig, axes = plt.subplots(K, 1, figsize=(10, 6), sharex=True)
    time_steps = jnp.arange(T)

    for i in range(K):
        # True values
        axes[i].plot(time_steps, y[:, i], label=f"True Variable {i + 1}", color="blue")
        # Posterior mean prediction
        axes[i].plot(
            time_steps[2:],
            mean_prediction[:, i],
            label=f"Predicted Mean Variable {i + 1}",
            color="orange",
        )
        # 95% credible interval
        axes[i].fill_between(
            time_steps[2:],
            lower_bound[:, i],
            upper_bound[:, i],
            color="orange",
            alpha=0.2,
            label="95% CI",
        )
        axes[i].set_title(f"Variable {i + 1}")
        axes[i].legend()
        axes[i].grid(True)

    plt.xlabel("Time Steps")
    plt.tight_layout()
    plt.savefig("var2.png")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="VAR(2) example")
    parser.add_argument("--num-data", nargs="?", default=100, type=int)
    parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
    parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
    parser.add_argument("--num-chains", nargs="?", default=1, type=int)
    parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)