import optax
from jax import vmap, jit
from jax.scipy.stats import multivariate_normal
- from jaxtyping import PyTree
+ from jaxtyping import PyTree, Float, Array, Scalar
import wandb
from bsm.bayesian_regression.bayesian_regression_model import BayesianRegressionModel
from bsm.bayesian_regression.gaussian_processes.kernels import Kernel, RBF
from bsm.utils.normal_with_aleatoric import ExtendedNormal
from bsm.utils.normalization import Normalizer, DataStats, Data
+ from bsm.bayesian_regression.gaussian_processes.rkhs_optimization import alpha_minimize_distance, alpha_minimize_norm
@chex.dataclass
class GPModelState:
    history: Data
    data_stats: DataStats
    params: PyTree
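+   # Coefficients of the RKHS-regularized posterior mean; None until
+   # fit_model computes them via compute_alphas_for_regularized_mean.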
+   alphas: Float[Array, 'output_dim num_data'] | None = None


class GaussianProcess(BayesianRegressionModel[GPModelState]):
@@ -34,6 +36,11 @@ def __init__(self,
                 seed: int = 0,
                 logging_wandb: bool = True,
                 normalize: bool = True,
+                predict_regularized_mean: bool = False,
+                # Should be in ['minimize_norm', 'minimize_distance']
+                regularized_mean_strategy: str = 'minimize_norm',
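+                # Per-output-dimension constants for the RKHS solvers: the norm
+                # bound is used by 'minimize_distance', beta by 'minimize_norm'.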
+                f_norm_bound_RKHS_optimization: Float[Array, 'output_dim'] | Scalar = jnp.array(100.0),
+                beta_RKHS_optimization: Float[Array, 'output_dim'] | Scalar = jnp.array(3.0),
                 *args,
                 **kwargs
                 ):
@@ -48,6 +55,17 @@ def __init__(self,
        self.tx = optax.adamw(learning_rate=lr_rate, weight_decay=weight_decay)
        self.key = jr.PRNGKey(seed)
        self.logging_wandb = logging_wandb
+       self.predict_regularized_mean = predict_regularized_mean
+       self.regularized_mean_strategy = regularized_mean_strategy
+
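+       # Broadcast a scalar bound to one value per output dimension.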
+       if f_norm_bound_RKHS_optimization.shape == ():
+           f_norm_bound_RKHS_optimization = f_norm_bound_RKHS_optimization * jnp.ones(shape=(self.output_dim,))
+       self.f_norm_bound_RKHS_optimization = f_norm_bound_RKHS_optimization
+
+       if beta_RKHS_optimization.shape == ():
+           beta_RKHS_optimization = beta_RKHS_optimization * jnp.ones(shape=(self.output_dim,))
+       self.beta_RKHS_optimization = beta_RKHS_optimization

        self.v_kernel = vmap(self.kernel.apply, in_axes=(0, None, None), out_axes=0)
        self.m_kernel = vmap(self.v_kernel, in_axes=(None, 0, None), out_axes=1)
@@ -65,7 +83,7 @@ def init(self, key: chex.PRNGKey) -> GPModelState:
        data_stats = self.normalizer.init_stats(data)
        keys = jr.split(key, self.output_dim)
        params = vmap(self.kernel.init)(keys)
-       return GPModelState(params=params, data_stats=data_stats, history=data)
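+       # Initialize alphas with a placeholder; fit_model overwrites it when
+       # predict_regularized_mean is enabled.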
+       return GPModelState(params=params, data_stats=data_stats, history=data, alphas=jnp.ones_like(outputs))

    def loss(self, vmapped_params, inputs, outputs, data_stats: DataStats):
        assert inputs.shape[0] == outputs.shape[0]
@@ -116,6 +134,44 @@ def _train_model(self, num_training_steps: int, model_state: GPModelState, data_
        new_model_state = GPModelState(history=data, data_stats=data_stats, params=vmapped_params)
        return new_model_state

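+   # Computes the coefficients alpha of the regularized posterior mean
+   # m(x) = k(x, X) @ alpha by solving an RKHS optimization problem
+   # ('minimize_norm' or 'minimize_distance') independently per output dimension.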
+   def compute_alphas_for_regularized_mean(self, gp_model: GPModelState) -> Float[Array, 'output_dim num_data']:
+       # Compute the kernel matrix on the normalized training data
+       num_data = gp_model.history.inputs.shape[0]
+       history_inputs_norm = vmap(self.normalizer.normalize, in_axes=(0, None))(gp_model.history.inputs,
+                                                                                gp_model.data_stats.inputs)
+       history_outputs_norm = vmap(self.normalizer.normalize, in_axes=(0, None))(gp_model.history.outputs,
+                                                                                 gp_model.data_stats.outputs)
+       covariance_matrix = self.m_kernel_multiple_output(history_inputs_norm, history_inputs_norm, gp_model.params)
+       assert covariance_matrix.shape == (self.output_dim, num_data, num_data)
+       # Add the observation-noise term sigma^2 * I per output dimension
+       extended_eye = jnp.repeat(jnp.eye(covariance_matrix.shape[-1])[None, ...], repeats=self.output_dim, axis=0)
+       outputs_stds_norm = self.normalizer.normalize_std(self.output_stds, gp_model.data_stats.outputs)
+       noise_term = extended_eye * outputs_stds_norm[:, None, None] ** 2
+       noisy_covariance_matrix = covariance_matrix + noise_term
+       cholesky_tuples = vmap(jax.scipy.linalg.cho_factor)(noisy_covariance_matrix)
+
+       # Compute the standard posterior-mean coefficients alpha_mu = (K + sigma^2 I)^{-1} y
+       denoised_mean = vmap(jax.scipy.linalg.cho_solve, in_axes=((0, None), 1))((cholesky_tuples[0], False),
+                                                                                history_outputs_norm)
+
+       # Solve the RKHS optimization problem for each output dimension
+       new_alphas = []
+       for i in range(self.output_dim):
+           # Pick the solver that matches the configured strategy
+           if self.regularized_mean_strategy == 'minimize_distance':
+               alpha_value, prob = alpha_minimize_distance(kernel_matrix=covariance_matrix[i],
+                                                           sigma=outputs_stds_norm[i],
+                                                           alpha_mu=denoised_mean[i],
+                                                           norm_bound=self.f_norm_bound_RKHS_optimization[i])
+           elif self.regularized_mean_strategy == 'minimize_norm':
+               alpha_value, prob = alpha_minimize_norm(kernel_matrix=covariance_matrix[i],
+                                                       sigma=outputs_stds_norm[i],
+                                                       alpha_mu=denoised_mean[i],
+                                                       beta=self.beta_RKHS_optimization[i])
+           new_alphas.append(alpha_value)
+       new_alphas = jnp.stack(new_alphas, axis=0)
+       return new_alphas
+
    def fit_model(self,
                  data: Data,
                  num_training_steps: int,
@@ -127,6 +183,9 @@ def fit_model(self,
        data_stats = self.normalizer.init_stats(data)

        new_model_state = self._train_model(num_training_steps, model_state, data_stats, data)
+       if self.predict_regularized_mean:
+           alphas = self.compute_alphas_for_regularized_mean(new_model_state)
+           new_model_state = new_model_state.replace(alphas=alphas)
        return new_model_state

    @partial(jit, static_argnums=0)
@@ -165,7 +224,10 @@ def posterior(self, input, gp_model: GPModelState) -> Tuple[ExtendedNormal, Exte
        # Compute posterior mean
        denoised_mean = vmap(jax.scipy.linalg.cho_solve, in_axes=((0, None), 1))((cholesky_tuples[0], False),
                                                                                 history_outputs_norm)
-       mean = vmap(jnp.dot)(k_x_X, denoised_mean)
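+       # Use the precomputed RKHS-regularized coefficients when requested,
+       # otherwise fall back to the standard GP posterior mean.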
+       if self.predict_regularized_mean:
+           mean = vmap(jnp.dot)(k_x_X, gp_model.alphas)
+       else:
+           mean = vmap(jnp.dot)(k_x_X, denoised_mean)

        # Denormalize
        mean = self.normalizer.denormalize(mean, gp_model.data_stats.outputs)
@@ -184,7 +246,9 @@ def posterior(self, input, gp_model: GPModelState) -> Tuple[ExtendedNormal, Exte
if __name__ == '__main__':
    import time
    import matplotlib.pyplot as plt
-   jax.config.update('jax_log_compiles', True)
+
+   # jax.config.update('jax_log_compiles', True)
+   jax.config.update("jax_enable_x64", True)

    key = jr.PRNGKey(0)
    input_dim = 1
@@ -200,7 +264,8 @@ def posterior(self, input, gp_model: GPModelState) -> Tuple[ExtendedNormal, Exte
    logging = False
    num_particles = 10
-   model = GaussianProcess(input_dim=input_dim, output_dim=output_dim, output_stds=data_std, logging_wandb=False)
+   model = GaussianProcess(input_dim=input_dim, output_dim=output_dim, output_stds=data_std, logging_wandb=False,
+                           predict_regularized_mean=True, regularized_mean_strategy='minimize_distance')
    model_state = model.init(model.key)
    start_time = time.time()
    print('Starting with training')