Skip to content

Commit 378c99c

Browse files
committed
add option that we don't learn kernel parameters and that the data stats stay fixed.
1 parent 6742eb7 commit 378c99c

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

bsm/statistical_model/gp_statistical_model.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from bsm.bayesian_regression.gaussian_processes.gaussian_processes import GPModelState, GaussianProcess
88
from bsm.statistical_model.abstract_statistical_model import StatisticalModel
9-
from bsm.utils.normalization import Data
9+
from bsm.utils.normalization import Data, DataStats
1010
from bsm.utils.type_aliases import StatisticalModelState
1111
from typing import Union
1212

@@ -15,17 +15,23 @@ class GPStatisticalModel(StatisticalModel[GPModelState]):
1515
def __init__(self,
1616
input_dim: int,
1717
output_dim: int,
18-
f_norm_bound: float = 1.0,
18+
f_norm_bound: float | chex.Array = 1.0,
1919
delta: float = 0.1,
2020
num_training_steps: Union[int, optax.Schedule] = 1000,
2121
beta: chex.Array | optax.Schedule | None = None,
2222
normalize: bool = True,
23+
fixed_kernel_params: bool = False,
24+
normalization_stats: DataStats | None = None,
2325
*args, **kwargs
2426
):
2527
self.normalize = normalize
2628
model = GaussianProcess(input_dim=input_dim, output_dim=output_dim, normalize=normalize, *args, **kwargs)
2729
super().__init__(input_dim, output_dim, model)
30+
self.fixed_kernel_params = fixed_kernel_params
31+
self.normalization_stats = normalization_stats
2832
self.model = model
33+
if f_norm_bound is float:
34+
f_norm_bound = jnp.ones(output_dim) * f_norm_bound
2935
self.f_norm_bound = f_norm_bound
3036
self.delta = delta
3137
self.num_training_steps = num_training_steps
@@ -40,7 +46,12 @@ def __init__(self,
4046
def update(self, stats_model_state: StatisticalModelState, data: Data) -> StatisticalModelState[GPModelState]:
4147
size = len(data.inputs)
4248
num_training_steps = int(self.num_training_steps(size))
43-
new_model_state = self.model.fit_model(data, num_training_steps, stats_model_state.model_state)
49+
if self.fixed_kernel_params:
50+
new_model_state = GPModelState(history=data, data_stats=self.normalization_stats,
51+
params=stats_model_state.model_state.params)
52+
else:
53+
new_model_state = self.model.fit_model(data, num_training_steps, stats_model_state.model_state)
54+
4455
if self._potential_beta is None:
4556
beta = self.compute_beta(new_model_state, data)
4657
return StatisticalModelState(model_state=new_model_state, beta=beta)

0 commit comments

Comments
 (0)