Skip to content

Commit 3fd458d

Browse files
committed
updating versions
1 parent 850abc5 commit 3fd458d

16 files changed

+55
-38
lines changed

bsm/bayesian_regression/bayesian_neural_networks/bnn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Sequence
66

77
import chex
8+
import flax
89
import jax
910
import jax.numpy as jnp
1011
import jax.random as jr
@@ -144,7 +145,7 @@ def step_jit(self,
144145
def _init(self, key):
145146
variables = self.model.init(key, jnp.ones(shape=(self.input_dim,)))
146147
if 'params' in variables:
147-
stats, params = variables.pop('params')
148+
stats, params = flax.core.pop(variables, 'params')
148149
else:
149150
stats, params = variables
150151
del variables # Delete variables to avoid wasting resources

bsm/bayesian_regression/bayesian_neural_networks/deterministic_ensembles.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def apply_eval(self,
4747

4848
noise_level = 0.1
4949
d_l, d_u = 0, 10
50-
xs = jnp.linspace(d_l, d_u, 256).reshape(-1, 1)
50+
xs = jnp.linspace(d_l, d_u, 32).reshape(-1, 1)
5151
ys = jnp.concatenate([jnp.sin(xs), jnp.cos(xs)], axis=1)
5252
ys = ys + noise_level * random.normal(key=random.PRNGKey(0), shape=ys.shape)
5353
data_std = noise_level * jnp.ones(shape=(output_dim,))

bsm/bayesian_regression/bayesian_recurrent_neural_networks/rnn_ensembles.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Sequence, Dict, Tuple, Optional
55

66
import chex
7+
import flax
78
import jax
89
import jax.numpy as jnp
910
import matplotlib.pyplot as plt
@@ -117,7 +118,7 @@ def eval_ll(self,
117118
def _init(self, key):
118119
variables = self.model.init(key, jnp.ones(shape=(1, self.input_dim)))
119120
if 'params' in variables:
120-
stats, params = variables.pop('params')
121+
stats, params = flax.core.pop(variables, 'params')
121122
else:
122123
stats, params = variables
123124
del variables # Delete variables to avoid wasting resources

bsm/bayesian_regression/gaussian_processes/__init__

Whitespace-only changes.

bsm/bayesian_regression/gaussian_processes/gaussian_processes.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from bsm.bayesian_regression.bayesian_regression_model import BayesianRegressionModel
1818
from bsm.bayesian_regression.gaussian_processes.kernels import Kernel, RBF
1919
from bsm.utils.normal_with_aleatoric import ExtendedNormal
20-
from bsm.utils.normalization import Normalizer, DataStats, Data, Stats
20+
from bsm.utils.normalization import Normalizer, DataStats, Data
2121

2222

2323
@chex.dataclass
@@ -35,10 +35,12 @@ def __init__(self,
3535
lr_rate: optax.Schedule | float = optax.constant_schedule(1e-2),
3636
seed: int = 0,
3737
logging_wandb: bool = True,
38+
normalize: bool = True,
3839
*args,
3940
**kwargs
4041
):
4142
super().__init__(*args, **kwargs)
43+
self.normalize = normalize
4244
if kernel is None:
4345
kernel = RBF(self.input_dim)
4446
self.kernel = kernel
@@ -56,13 +58,16 @@ def __init__(self,
5658
self.kernel_multiple_output = vmap(self.kernel.apply, in_axes=(None, None, 0), out_axes=0)
5759

5860
def init(self, key: chex.PRNGKey) -> GPModelState:
59-
keys = jr.split(key, self.output_dim)
60-
params = vmap(self.kernel.init)(keys)
6161
inputs = jnp.zeros(shape=(1, self.input_dim))
6262
outputs = jnp.zeros(shape=(1, self.output_dim))
6363
data = Data(inputs=inputs, outputs=outputs)
64-
data_stats = self.normalizer.compute_stats(data.inputs)
65-
return GPModelState(history=data, data_stats=data_stats, params=params)
64+
if self.normalize:
65+
data_stats = self.normalizer.compute_stats(data.inputs)
66+
else:
67+
data_stats = self.normalizer.init_stats(data.inputs)
68+
keys = jr.split(key, self.output_dim)
69+
params = vmap(self.kernel.init)(keys)
70+
return GPModelState(params=params, data_stats=data_stats, history=data)
6671

6772
def loss(self, vmapped_params, inputs, outputs, data_stats: DataStats):
6873
assert inputs.shape[0] == outputs.shape[0]
@@ -105,7 +110,10 @@ def fit_model(self,
105110
model_state: GPModelState) -> GPModelState:
106111
vmapped_params = model_state.params
107112
opt_state = self.tx.init(vmapped_params)
108-
data_stats = self.normalizer.compute_stats(data)
113+
if self.normalize:
114+
data_stats = self.normalizer.compute_stats(data)
115+
else:
116+
data_stats = self.normalizer.init_stats(data)
109117

110118
def f(carry, _):
111119
opt_state, vmapped_params = carry

bsm/statistical_model/bnn_statistical_model.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22
import jax.numpy as jnp
33
import jax.random as jr
44
import optax
5-
from jax import vmap
65

76
from bsm.bayesian_regression.bayesian_neural_networks.bnn import BNNState
87
from bsm.bayesian_regression.bayesian_neural_networks.bnn import BayesianNeuralNet
98
from bsm.bayesian_regression.bayesian_neural_networks.deterministic_ensembles import DeterministicEnsemble
9+
from bsm.bayesian_regression.bayesian_neural_networks.fsvgd_ensemble import DeterministicFSVGDEnsemble, \
10+
ProbabilisticFSVGDEnsemble
1011
from bsm.bayesian_regression.bayesian_neural_networks.probabilistic_ensembles import ProbabilisticEnsemble
11-
from bsm.bayesian_regression.bayesian_neural_networks.fsvgd_ensemble import DeterministicFSVGDEnsemble, ProbabilisticFSVGDEnsemble
1212
from bsm.statistical_model.abstract_statistical_model import StatisticalModel
1313
from bsm.utils.normalization import Data
14-
from bsm.utils.type_aliases import StatisticalModelState, StatisticalModelOutput
14+
from bsm.utils.type_aliases import StatisticalModelState
1515

1616

1717
class BNNStatisticalModel(StatisticalModel[BNNState]):
@@ -91,6 +91,7 @@ def update(self, stats_model_state: StatisticalModelState[BNNState], data: Data)
9191
plt.plot(test_xs.reshape(-1), test_ys[:, j], label='True', color='green')
9292
by_label = dict(zip(labels, handles))
9393
plt.legend(by_label.values(), by_label.keys())
94+
plt.savefig(f'bnn_{j}.pdf')
9495
plt.show()
9596

9697
num_test_points = 1000
@@ -102,4 +103,4 @@ def update(self, stats_model_state: StatisticalModelState[BNNState], data: Data)
102103
plt.plot(in_domain_test_xs, in_domain_preds.mean[:, j], label='Mean', color='blue')
103104
plt.plot(in_domain_test_xs, in_domain_test_ys[:, j], label='Fun', color='Green')
104105
plt.legend()
105-
plt.show()
106+
plt.show()

bsm/statistical_model/gp_statistical_model.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
import optax
55
from jax import vmap
66

7-
from bsm.statistical_model.abstract_statistical_model import StatisticalModel
87
from bsm.bayesian_regression.gaussian_processes.gaussian_processes import GPModelState, GaussianProcess
8+
from bsm.statistical_model.abstract_statistical_model import StatisticalModel
99
from bsm.utils.normalization import Data
10-
from bsm.utils.type_aliases import StatisticalModelState, StatisticalModelOutput
10+
from bsm.utils.type_aliases import StatisticalModelState
1111

1212

1313
class GPStatisticalModel(StatisticalModel[GPModelState]):
@@ -18,9 +18,11 @@ def __init__(self,
1818
delta: float = 0.1,
1919
num_training_steps: int = 1000,
2020
beta: chex.Array | optax.Schedule | None = None,
21+
normalize: bool = True,
2122
*args, **kwargs
2223
):
23-
model = GaussianProcess(input_dim=input_dim, output_dim=output_dim, *args, **kwargs)
24+
self.normalize = normalize
25+
model = GaussianProcess(input_dim=input_dim, output_dim=output_dim, normalize=normalize, *args, **kwargs)
2426
super().__init__(input_dim, output_dim, model)
2527
self.model = model
2628
self.f_norm_bound = f_norm_bound
@@ -41,8 +43,11 @@ def update(self, stats_model_state: StatisticalModelState, data: Data) -> Statis
4143
return StatisticalModelState(model_state=new_model_state, beta=beta)
4244

4345
def compute_beta(self, model_state: GPModelState, data: Data):
44-
inputs_norm = vmap(self.model.normalizer.normalize, in_axes=(0, None))(data.inputs,
45-
model_state.data_stats.inputs)
46+
if self.normalize:
47+
inputs_norm = vmap(self.model.normalizer.normalize, in_axes=(0, None))(data.inputs,
48+
model_state.data_stats.inputs)
49+
else:
50+
inputs_norm = data.inputs
4651
covariance_matrix = self.model.m_kernel_multiple_output(inputs_norm, inputs_norm, model_state.params)
4752
covariance_matrix = covariance_matrix / (self.model.output_stds ** 2)[:, None, None]
4853
covariance_matrix = covariance_matrix + jnp.eye(covariance_matrix.shape[-1])[None, :, :]
@@ -92,4 +97,5 @@ def compute_beta(self, model_state: GPModelState, data: Data):
9297
plt.plot(test_xs.reshape(-1), test_ys[:, j], label='True', color='green')
9398
by_label = dict(zip(labels, handles))
9499
plt.legend(by_label.values(), by_label.keys())
100+
plt.savefig(f'gp_{j}.pdf')
95101
plt.show()

bsm/utils/normalization.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,25 @@ def __init__(self, num_correction=1e-6):
3030
def compute_stats(self, data: PyTree) -> PyTree[Stats]:
3131
return jtu.tree_map(self.get_stats, data)
3232

33+
def init_stats(self, data: PyTree) -> PyTree[Stats]:
34+
return jtu.tree_map(self._init_stats, data)
35+
3336
@partial(jax.jit, static_argnums=0)
3437
def get_stats(self, data: chex.Array) -> Stats:
3538
assert data.ndim == 2
3639
mean = jnp.mean(data, axis=0)
37-
std = jnp.std(data, axis=0) + self.num_correction
40+
if data.shape[0] > 1:
41+
std = jnp.std(data, axis=0) + self.num_correction
42+
else:
43+
std = jnp.ones_like(mean)
3844
return Stats(mean, std)
3945

46+
@partial(jax.jit, static_argnums=0)
47+
def _init_stats(self, data: chex.Array) -> Stats:
48+
assert data.ndim == 2
49+
mean = jnp.mean(data, axis=0)
50+
return Stats(jnp.zeros_like(mean), jnp.ones_like(mean))
51+
4052
@partial(jax.jit, static_argnums=0)
4153
def normalize(self, datum: chex.Array, stats: Stats) -> chex.Array:
4254
assert datum.ndim == 1

setup.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,17 @@
33
from setuptools import setup, find_packages
44

55
required = [
6-
'flax>=0.4.1,<=0.7.0',
7-
'jax==0.4.13',
6+
'flax>=0.7.0',
7+
'jax>=0.4.13',
88
'jaxtyping>=0.2.20',
9-
'jaxlib==0.4.13',
10-
'pytest==7.4.0',
9+
'jaxlib>==0.4.13',
10+
'pytest>=7.4.0',
1111
'matplotlib>=3.5.1',
1212
'numpy>=1.22.2',
1313
'optax>=0.1.1',
1414
'scipy>=1.8.0',
1515
'wandb>=0.12.11',
16-
'distrax~=0.1.2',
16+
'distrax @ git+https://github.com/deepmind/distrax.git',
1717
'argparse-dataclass>=0.2.1',
1818
'jaxutils',
1919
'chex',

tests/test_deterministic_ensemble_bnn_sm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import jax.numpy as jnp
22
import jax.random as jr
3-
from jax import vmap
43

54
from bsm.bayesian_regression import DeterministicEnsemble
65
from bsm.statistical_model import BNNStatisticalModel
76
from bsm.utils.normalization import Data
8-
from bsm.utils.type_aliases import StatisticalModelOutput
97

108
key = jr.PRNGKey(0)
119
input_dim = 1

tests/test_deterministic_fsvg_ensemble_bnn_sm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import jax.numpy as jnp
22
import jax.random as jr
3-
from jax import vmap
43

54
from bsm.bayesian_regression import DeterministicFSVGDEnsemble
65
from bsm.statistical_model import BNNStatisticalModel
76
from bsm.utils.normalization import Data
8-
from bsm.utils.type_aliases import StatisticalModelOutput
97

108
key = jr.PRNGKey(0)
119
input_dim = 1

tests/test_deterministic_gru_ensemble_bnn_sm.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
from jax import vmap
44

55
from bsm.bayesian_regression import DeterministicGRUEnsemble
6-
from bsm.utils.general_utils import create_windowed_array
76
from bsm.statistical_model.brnn_statistical_model import BRNNStatisticalModel
7+
from bsm.utils.general_utils import create_windowed_array
88
from bsm.utils.normalization import Data
9-
from bsm.utils.type_aliases import StatisticalModelOutput
109

1110
key = jr.PRNGKey(0)
1211
input_dim = 1
@@ -61,7 +60,7 @@ def test_statistical_model_state_of_prediction():
6160
in_domain_test_ys = in_domain_test_ys.transpose()
6261

6362
in_domain_preds = model.predict_batch(in_domain_test_xs,
64-
statistical_model_state)
63+
statistical_model_state)
6564

6665

6766
def test_good_probabilistic_ensemble_fit():

tests/test_gaussian_processes_sm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import jax.numpy as jnp
22
import jax.random as jr
3-
from jax import vmap
43

54
from bsm.statistical_model import GPStatisticalModel
65
from bsm.utils.normalization import Data
7-
from bsm.utils.type_aliases import StatisticalModelOutput
86

97
key = jr.PRNGKey(0)
108
input_dim = 1

tests/test_probabilistic_ensemble_sm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import jax.numpy as jnp
22
import jax.random as jr
3-
from jax import vmap
43

54
from bsm.bayesian_regression import ProbabilisticEnsemble
65
from bsm.statistical_model import BNNStatisticalModel
76
from bsm.utils.normalization import Data
8-
from bsm.utils.type_aliases import StatisticalModelOutput
97

108
key = jr.PRNGKey(0)
119
input_dim = 1

tests/test_probabilistic_fsvg_ensemble_bnn_sm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import jax.numpy as jnp
22
import jax.random as jr
3-
from jax import vmap
43

54
from bsm.bayesian_regression import ProbabilisticFSVGDEnsemble
65
from bsm.statistical_model import BNNStatisticalModel
76
from bsm.utils.normalization import Data
8-
from bsm.utils.type_aliases import StatisticalModelOutput
97

108
key = jr.PRNGKey(0)
119
input_dim = 1

tests/test_probabilistic_gru_ensemble_bnn_sm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from bsm.statistical_model import BRNNStatisticalModel
77
from bsm.utils.general_utils import create_windowed_array
88
from bsm.utils.normalization import Data
9-
from bsm.utils.type_aliases import StatisticalModelOutput
109

1110
key = jr.PRNGKey(0)
1211
input_dim = 1

0 commit comments

Comments
 (0)