
Commit b40128e

add the option to have regularized posterior mean

1 parent 044c1e6

2 files changed: +81 −9 lines

bsm/bayesian_regression/gaussian_processes/gaussian_processes.py (70 additions, 5 deletions)
```diff
@@ -9,20 +9,22 @@
 import optax
 from jax import vmap, jit
 from jax.scipy.stats import multivariate_normal
-from jaxtyping import PyTree
+from jaxtyping import PyTree, Float, Array, Scalar
 
 import wandb
 from bsm.bayesian_regression.bayesian_regression_model import BayesianRegressionModel
 from bsm.bayesian_regression.gaussian_processes.kernels import Kernel, RBF
 from bsm.utils.normal_with_aleatoric import ExtendedNormal
 from bsm.utils.normalization import Normalizer, DataStats, Data
+from bsm.bayesian_regression.gaussian_processes.rkhs_optimization import alpha_minimize_distance, alpha_minimize_norm
 
 
 @chex.dataclass
 class GPModelState:
     history: Data
     data_stats: DataStats
     params: PyTree
+    alphas: Float[Array, 'output_dim num_data'] | None = None
 
 
 class GaussianProcess(BayesianRegressionModel[GPModelState]):
@@ -34,6 +36,11 @@ def __init__(self,
                  seed: int = 0,
                  logging_wandb: bool = True,
                  normalize: bool = True,
+                 predict_regularized_mean: bool = False,
+                 regularized_mean_strategy: str = 'minimize_norm',
+                 # Should be in ['minimize_norm', 'minimize_distance']
+                 f_norm_bound_RKHS_optimization: Float[Array, 'output_dim'] | Scalar = jnp.array(100.0),
+                 beta_RKHS_optimization: Float[Array, 'output_dim'] | Scalar = jnp.array(3.0),
                  *args,
                  **kwargs
                  ):
@@ -48,6 +55,17 @@
         self.tx = optax.adamw(learning_rate=lr_rate, weight_decay=weight_decay)
         self.key = jr.PRNGKey(seed)
         self.logging_wandb = logging_wandb
+        self.predict_regularized_mean = predict_regularized_mean
+        self.regularized_mean_strategy = regularized_mean_strategy
+
+        if f_norm_bound_RKHS_optimization.shape == ():
+            f_norm_bound_RKHS_optimization = f_norm_bound_RKHS_optimization * jnp.ones(shape=(self.output_dim,))
+        self.f_norm_bound_RKHS_optimization = f_norm_bound_RKHS_optimization
+
+        if beta_RKHS_optimization.shape == ():
+            beta_RKHS_optimization = beta_RKHS_optimization * jnp.ones(
+                shape=(self.output_dim,))
+        self.beta_RKHS_optimization = beta_RKHS_optimization
 
         self.v_kernel = vmap(self.kernel.apply, in_axes=(0, None, None), out_axes=0)
         self.m_kernel = vmap(self.v_kernel, in_axes=(None, 0, None), out_axes=1)
@@ -65,7 +83,7 @@ def init(self, key: chex.PRNGKey) -> GPModelState:
         data_stats = self.normalizer.init_stats(data)
         keys = jr.split(key, self.output_dim)
         params = vmap(self.kernel.init)(keys)
-        return GPModelState(params=params, data_stats=data_stats, history=data)
+        return GPModelState(params=params, data_stats=data_stats, history=data, alphas=jnp.ones_like(outputs))
 
     def loss(self, vmapped_params, inputs, outputs, data_stats: DataStats):
         assert inputs.shape[0] == outputs.shape[0]
@@ -116,6 +134,44 @@ def _train_model(...)
         new_model_state = GPModelState(history=data, data_stats=data_stats, params=vmapped_params)
         return new_model_state
 
+    def compute_alphas_for_regularized_mean(self, gp_model: GPModelState) -> Float[Array, 'output_dim num_data']:
+        # Compute the kernel matrix on the normalized training data
+        num_data = gp_model.history.inputs.shape[0]
+        history_inputs_norm = vmap(self.normalizer.normalize, in_axes=(0, None))(gp_model.history.inputs,
+                                                                                 gp_model.data_stats.inputs)
+        history_outputs_norm = vmap(self.normalizer.normalize, in_axes=(0, None))(gp_model.history.outputs,
+                                                                                  gp_model.data_stats.outputs)
+        covariance_matrix = self.m_kernel_multiple_output(history_inputs_norm, history_inputs_norm, gp_model.params)
+        assert covariance_matrix.shape == (self.output_dim, num_data, num_data)
+        # Add the observation-noise term
+        extended_eye = jnp.repeat(jnp.eye(covariance_matrix.shape[-1])[None, ...], repeats=self.output_dim, axis=0)
+        outputs_stds_norm = self.normalizer.normalize_std(self.output_stds, gp_model.data_stats.outputs)
+        noise_term = extended_eye * outputs_stds_norm[:, None, None] ** 2
+        noisy_covariance_matrix = covariance_matrix + noise_term
+        cholesky_tuples = vmap(jax.scipy.linalg.cho_factor)(noisy_covariance_matrix)
+
+        # Compute the standard posterior-mean coefficients (K + sigma^2 I)^{-1} y
+        denoised_mean = vmap(jax.scipy.linalg.cho_solve, in_axes=((0, None), 1))((cholesky_tuples[0], False),
+                                                                                 history_outputs_norm)
+
+        # Solve one convex program per output dimension
+        new_alphas = []
+        for i in range(self.output_dim):
+            # Dispatch on the configured strategy
+            if self.regularized_mean_strategy == 'minimize_distance':
+                alpha_value, prob = alpha_minimize_distance(kernel_matrix=covariance_matrix[i],
+                                                            sigma=outputs_stds_norm[i],
+                                                            alpha_mu=denoised_mean[i],
+                                                            norm_bound=self.f_norm_bound_RKHS_optimization[i])
+            elif self.regularized_mean_strategy == 'minimize_norm':
+                alpha_value, prob = alpha_minimize_norm(kernel_matrix=covariance_matrix[i],
+                                                        sigma=outputs_stds_norm[i],
+                                                        alpha_mu=denoised_mean[i],
+                                                        beta=self.beta_RKHS_optimization[i])
+            new_alphas.append(alpha_value)
+        new_alphas = jnp.stack(new_alphas, axis=0)
+        return new_alphas
+
     def fit_model(self,
                   data: Data,
                   num_training_steps: int,
```
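For reference, the two strategies dispatched in `compute_alphas_for_regularized_mean` correspond to the following convex programs over the representer coefficients α of the posterior mean μ(·) = Σᵢ αᵢ k(·, xᵢ), reconstructed from the `rkhs_optimization.py` changes further below; α_μ = (K + σ²I)⁻¹y are the standard coefficients and B is `f_norm_bound_RKHS_optimization`:

```latex
% 'minimize_norm': smallest RKHS norm while staying within a beta-scaled
% confidence region around the unregularized posterior mean
\min_{\alpha}\; \alpha^\top K \alpha
\quad \text{s.t.} \quad
(\alpha - \alpha_\mu)^\top \left(K + \tfrac{1}{\sigma^2} K K^\top\right)(\alpha - \alpha_\mu) \le 4\beta^2

% 'minimize_distance': closest to the unregularized posterior mean
% subject to a bound on the RKHS norm
\min_{\alpha}\; (\alpha - \alpha_\mu)^\top \left(K + \tfrac{1}{\sigma^2} K K^\top\right)(\alpha - \alpha_\mu)
\quad \text{s.t.} \quad
\alpha^\top K \alpha \le B
```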
```diff
@@ -127,6 +183,9 @@ def fit_model(...)
         data_stats = self.normalizer.init_stats(data)
 
         new_model_state = self._train_model(num_training_steps, model_state, data_stats, data)
+        if self.predict_regularized_mean:
+            alphas = self.compute_alphas_for_regularized_mean(new_model_state)
+            new_model_state = new_model_state.replace(alphas=alphas)
         return new_model_state
 
     @partial(jit, static_argnums=0)
@@ -165,7 +224,10 @@ def posterior(self, input, gp_model: GPModelState) -> Tuple[ExtendedNormal, ExtendedNormal]:
         # Compute posterior mean
         denoised_mean = vmap(jax.scipy.linalg.cho_solve, in_axes=((0, None), 1))((cholesky_tuples[0], False),
                                                                                  history_outputs_norm)
-        mean = vmap(jnp.dot)(k_x_X, denoised_mean)
+        if self.predict_regularized_mean:
+            mean = vmap(jnp.dot)(k_x_X, gp_model.alphas)
+        else:
+            mean = vmap(jnp.dot)(k_x_X, denoised_mean)
 
         # Denormalize
         mean = self.normalizer.denormalize(mean, gp_model.data_stats.outputs)
@@ -184,7 +246,9 @@
 if __name__ == '__main__':
     import time
     import matplotlib.pyplot as plt
-    jax.config.update('jax_log_compiles', True)
+
+    # jax.config.update('jax_log_compiles', True)
+    jax.config.update("jax_enable_x64", True)
 
     key = jr.PRNGKey(0)
     input_dim = 1
@@ -200,7 +264,8 @@
 
     logging = False
     num_particles = 10
-    model = GaussianProcess(input_dim=input_dim, output_dim=output_dim, output_stds=data_std, logging_wandb=False)
+    model = GaussianProcess(input_dim=input_dim, output_dim=output_dim, output_stds=data_std, logging_wandb=False,
+                            predict_regularized_mean=True, regularized_mean_strategy='minimize_distance')
     model_state = model.init(model.key)
     start_time = time.time()
     print('Starting with training')
```
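As a minimal sketch of what the `posterior` change amounts to (a toy single-output example with a hand-rolled RBF kernel; normalization and the per-output vmapping of the library are omitted), the stored alphas simply replace the usual Cholesky solve inside the dot product k(x, X) · α:

```python
# Minimal sketch, not the library API: toy 1-D GP showing the mean as a
# dot product k(x, X) @ alpha, where alpha is either the standard solve
# or the regularized coefficients stored in GPModelState.alphas.
import jax
import jax.numpy as jnp

X = jnp.linspace(-3.0, 3.0, 20)[:, None]   # training inputs, shape (20, 1)
y = jnp.sin(X[:, 0])                        # training targets
sigma = 0.1                                 # observation-noise std

def rbf(a, b):
    return jnp.exp(-0.5 * jnp.sum((a - b) ** 2))

K = jax.vmap(lambda a: jax.vmap(lambda b: rbf(a, b))(X))(X)   # (20, 20)
alpha_mu = jnp.linalg.solve(K + sigma ** 2 * jnp.eye(20), y)  # standard GP coefficients

def posterior_mean(x, alphas):
    k_x_X = jax.vmap(lambda b: rbf(x, b))(X)                  # k(x, X)
    return k_x_X @ alphas

# With predict_regularized_mean=True the same product is taken with the
# QCQP solution instead of alpha_mu.
print(posterior_mean(jnp.array([0.5]), alpha_mu))
```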

bsm/bayesian_regression/gaussian_processes/rkhs_optimization.py (11 additions, 4 deletions)
```diff
@@ -10,13 +10,16 @@ def alpha_minimize_norm(kernel_matrix: Float[Array, 'n_obs n_obs'],
                         sigma: Scalar,
                         alpha_mu: Float[Array, 'n_obs'],
                         beta: Scalar = 2) -> Tuple[Float[Array, 'n_obs'], Problem]:
+    numerical_correction = 0.0
     n_obs = kernel_matrix.shape[0]
     alpha = cp.Variable(n_obs)
     alpha_diff = alpha - alpha_mu
     K = kernel_matrix
+    I = jnp.eye(n_obs)
 
-    constraints = [alpha_diff @ cp.psd_wrap(K + 1 / sigma ** 2 * K @ K.T) @ alpha_diff <= 4 * beta ** 2]
-    objective = cp.Minimize(alpha @ cp.psd_wrap(K) @ alpha)
+    constraints = [
+        alpha_diff @ cp.psd_wrap(K + 1 / sigma ** 2 * K @ K.T + numerical_correction * I) @ alpha_diff <= 4 * beta ** 2]
+    objective = cp.Minimize(alpha @ cp.psd_wrap(K + numerical_correction * I) @ alpha)
     prob = cp.Problem(objective, constraints)
 
     # The optimal objective value is returned by `prob.solve()`.
```
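The pattern above is a plain QCQP in cvxpy. Here is a self-contained toy run of the minimize-norm variant (the NumPy data and RBF kernel are my own illustration, not part of the commit; the library passes a learned kernel matrix instead):

```python
# Toy end-to-end run of the minimize-norm QCQP (illustrative data only).
import cvxpy as cp
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(-3.0, 3.0, 20)
y = np.sin(x) + 0.1 * rng.normal(size=20)

K = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2)           # RBF kernel matrix
sigma, beta = 0.1, 3.0
alpha_mu = np.linalg.solve(K + sigma ** 2 * np.eye(20), y)  # standard GP coefficients

alpha = cp.Variable(20)
diff = alpha - alpha_mu
M = cp.psd_wrap(K + 1 / sigma ** 2 * K @ K.T)
prob = cp.Problem(cp.Minimize(alpha @ cp.psd_wrap(K) @ alpha),
                  [diff @ M @ diff <= 4 * beta ** 2])
prob.solve()
print(prob.status, alpha.value[:3])                          # regularized coefficients
```

Per its type annotation, `alpha_minimize_norm` returns exactly this pair of solution vector and cvxpy `Problem`, so callers can inspect `prob.status` when the solver struggles, for instance when `K @ K.T` is badly conditioned; the `numerical_correction * I` term (currently zero) is a hook for such conditioning issues.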
```diff
@@ -29,13 +32,17 @@ def alpha_minimize_distance(kernel_matrix: Float[Array, 'n_obs n_obs'],
                             sigma: Scalar,
                             alpha_mu: Float[Array, 'n_obs'],
                             norm_bound: Scalar = 3) -> Tuple[Float[Array, 'n_obs'], Problem]:
+    numerical_correction = 0.0
     n_obs = kernel_matrix.shape[0]
     alpha = cp.Variable(n_obs)
     alpha_diff = alpha - alpha_mu
     K = kernel_matrix
 
-    constraints = [alpha @ cp.psd_wrap(K) @ alpha <= norm_bound]
-    objective = cp.Minimize(alpha_diff @ cp.psd_wrap(K + 1 / sigma ** 2 * K @ K.T) @ alpha_diff)
+    I = jnp.eye(n_obs)
+
+    constraints = [alpha @ cp.psd_wrap(K + numerical_correction * I) @ alpha <= norm_bound]
+    objective = cp.Minimize(
+        alpha_diff @ cp.psd_wrap(K + 1 / sigma ** 2 * K @ K.T + numerical_correction * I) @ alpha_diff)
 
     prob = cp.Problem(objective, constraints)
```