6
6
7
7
from bsm .bayesian_regression .gaussian_processes .gaussian_processes import GPModelState , GaussianProcess
8
8
from bsm .statistical_model .abstract_statistical_model import StatisticalModel
9
- from bsm .utils .normalization import Data
9
+ from bsm .utils .normalization import Data , DataStats
10
10
from bsm .utils .type_aliases import StatisticalModelState
11
11
from typing import Union
12
12
@@ -15,17 +15,23 @@ class GPStatisticalModel(StatisticalModel[GPModelState]):
15
15
def __init__ (self ,
16
16
input_dim : int ,
17
17
output_dim : int ,
18
- f_norm_bound : float = 1.0 ,
18
+ f_norm_bound : float | chex . Array = 1.0 ,
19
19
delta : float = 0.1 ,
20
20
num_training_steps : Union [int , optax .Schedule ] = 1000 ,
21
21
beta : chex .Array | optax .Schedule | None = None ,
22
22
normalize : bool = True ,
23
+ fixed_kernel_params : bool = False ,
24
+ normalization_stats : DataStats | None = None ,
23
25
* args , ** kwargs
24
26
):
25
27
self .normalize = normalize
26
28
model = GaussianProcess (input_dim = input_dim , output_dim = output_dim , normalize = normalize , * args , ** kwargs )
27
29
super ().__init__ (input_dim , output_dim , model )
30
+ self .fixed_kernel_params = fixed_kernel_params
31
+ self .normalization_stats = normalization_stats
28
32
self .model = model
33
+ if f_norm_bound is float :
34
+ f_norm_bound = jnp .ones (output_dim ) * f_norm_bound
29
35
self .f_norm_bound = f_norm_bound
30
36
self .delta = delta
31
37
self .num_training_steps = num_training_steps
@@ -40,7 +46,12 @@ def __init__(self,
40
46
def update (self , stats_model_state : StatisticalModelState , data : Data ) -> StatisticalModelState [GPModelState ]:
41
47
size = len (data .inputs )
42
48
num_training_steps = int (self .num_training_steps (size ))
43
- new_model_state = self .model .fit_model (data , num_training_steps , stats_model_state .model_state )
49
+ if self .fixed_kernel_params :
50
+ new_model_state = GPModelState (history = data , data_stats = self .normalization_stats ,
51
+ params = stats_model_state .model_state .params )
52
+ else :
53
+ new_model_state = self .model .fit_model (data , num_training_steps , stats_model_state .model_state )
54
+
44
55
if self ._potential_beta is None :
45
56
beta = self .compute_beta (new_model_state , data )
46
57
return StatisticalModelState (model_state = new_model_state , beta = beta )
0 commit comments