Skip to content

Commit f560e1e

Browse files
twieckiricardoV94
authored andcommitted
Add wrapper for running blackjax pathfinder.
1 parent 549b7fb commit f560e1e

File tree

9 files changed

+243
-0
lines changed

9 files changed

+243
-0
lines changed

.github/workflows/test.yml

+4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ jobs:
1919
matrix:
2020
os: [ubuntu-18.04]
2121
floatx: [float32, float64]
22+
test-subset:
23+
- pymc_experimental/tests
2224
fail-fast: false
2325
runs-on: ${{ matrix.os }}
2426
env:
@@ -80,6 +82,8 @@ jobs:
8082
matrix:
8183
os: [windows-latest]
8284
floatx: [float32, float64]
85+
test-subset:
86+
- pymc_experimental/tests
8387
fail-fast: false
8488
runs-on: ${{ matrix.os }}
8589
env:

conda-envs/environment-test-py38.yml

+1
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ dependencies:
1111
- xhistogram
1212
- pip:
1313
- "git+https://github.com/pymc-devs/pymc.git@main"
14+
- blackjax

conda-envs/environment-test-py39.yml

+1
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ dependencies:
1111
- xhistogram
1212
- pip:
1313
- "git+https://github.com/pymc-devs/pymc.git@main"
14+
- blackjax

pymc_experimental/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@
1212

1313

1414
from pymc_experimental import distributions, gp, utils
15+
from pymc_experimental.inference.fit import fit
+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from pymc_experimental.inference.fit import fit

pymc_experimental/inference/fit.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2022 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def fit(method, **kwargs):
17+
"""
18+
Fit a model with an inference algorithm
19+
20+
Parameters
21+
----------
22+
method : str
23+
Which inference method to run.
24+
Supported: pathfinder
25+
26+
kwargs are passed on.
27+
28+
Returns
29+
-------
30+
arviz.InferenceData
31+
"""
32+
if method == "pathfinder":
33+
try:
34+
from pymc_experimental.inference.pathfinder import fit_pathfinder
35+
except ImportError as exc:
36+
raise RuntimeError("Need BlackJAX to use `pathfinder`") from exc
37+
return fit_pathfinder(**kwargs)
+149
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Copyright 2022 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import collections
17+
import sys
18+
from typing import Optional
19+
20+
import arviz as az
21+
import blackjax
22+
import jax
23+
import jax.numpy as jnp
24+
import jax.random as random
25+
import numpy as np
26+
import pymc as pm
27+
from pymc import modelcontext
28+
from pymc.sampling import RandomSeed, _get_seeds_per_chain
29+
from pymc.sampling_jax import get_jaxified_graph
30+
from pymc.util import get_default_varnames
31+
32+
33+
def convert_flat_trace_to_idata(
34+
samples,
35+
dims=None,
36+
coords=None,
37+
include_transformed=False,
38+
postprocessing_backend="cpu",
39+
model=None,
40+
):
41+
42+
model = modelcontext(model)
43+
init_position_dict = model.initial_point()
44+
trace = collections.defaultdict(list)
45+
astart = pm.blocking.DictToArrayBijection.map(init_position_dict)
46+
for sample in samples:
47+
raveld_vars = pm.blocking.RaveledVars(sample, astart.point_map_info)
48+
point = pm.blocking.DictToArrayBijection.rmap(raveld_vars, init_position_dict)
49+
for p, v in point.items():
50+
trace[p].append(v.tolist())
51+
52+
trace = {k: np.asarray(v)[None, ...] for k, v in trace.items()}
53+
54+
var_names = model.unobserved_value_vars
55+
vars_to_sample = list(get_default_varnames(var_names, include_transformed=include_transformed))
56+
print("Transforming variables...", file=sys.stdout)
57+
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
58+
result = jax.vmap(jax.vmap(jax_fn))(
59+
*jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
60+
)
61+
62+
trace = {v.name: r for v, r in zip(vars_to_sample, result)}
63+
idata = az.from_dict(trace, dims=dims, coords=coords)
64+
65+
return idata
66+
67+
68+
def fit_pathfinder(
69+
iterations=5_000,
70+
random_seed: Optional[RandomSeed] = None,
71+
postprocessing_backend="cpu",
72+
ftol=1e-4,
73+
model=None,
74+
):
75+
"""
76+
Fit the pathfinder algorithm as implemented in blackjax
77+
78+
Requires the JAX backend
79+
80+
Parameters
81+
----------
82+
iterations : int
83+
Number of iterations to run.
84+
random_seed : int
85+
Random seed to set.
86+
postprocessing_backend : str
87+
Where to compute transformations of the trace.
88+
"cpu" or "gpu".
89+
ftol : float
90+
Floating point tolerance
91+
92+
Returns
93+
-------
94+
arviz.InferenceData
95+
96+
Reference
97+
---------
98+
https://arxiv.org/abs/2108.03782
99+
"""
100+
101+
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
102+
103+
model = modelcontext(model)
104+
105+
rvs = [rv.name for rv in model.value_vars]
106+
init_position_dict = model.initial_point()
107+
init_position = [init_position_dict[rv] for rv in rvs]
108+
109+
new_logprob, new_input = pm.aesaraf.join_nonshared_inputs(
110+
init_position_dict, (model.logp(),), model.value_vars, ()
111+
)
112+
113+
logprob_fn_list = get_jaxified_graph([new_input], new_logprob)
114+
115+
def logprob_fn(x):
116+
return logprob_fn_list(x)[0]
117+
118+
dim = sum(v.size for v in init_position_dict.values())
119+
120+
rng_key = random.PRNGKey(random_seed)
121+
w0 = random.multivariate_normal(rng_key, 2.0 + jnp.zeros(dim), jnp.eye(dim))
122+
path = blackjax.vi.pathfinder.init(rng_key, logprob_fn, w0, return_path=True, ftol=ftol)
123+
124+
pathfinder = blackjax.kernels.pathfinder(rng_key, logprob_fn, ftol=ftol)
125+
state = pathfinder.init(w0)
126+
127+
def inference_loop(rng_key, kernel, initial_state, num_samples):
128+
@jax.jit
129+
def one_step(state, rng_key):
130+
state, info = kernel(rng_key, state)
131+
return state, (state, info)
132+
133+
keys = jax.random.split(rng_key, num_samples)
134+
return jax.lax.scan(one_step, initial_state, keys)
135+
136+
_, rng_key = random.split(rng_key)
137+
print("Running pathfinder...", file=sys.stdout)
138+
_, (_, samples) = inference_loop(rng_key, pathfinder.step, state, iterations)
139+
140+
dims = {
141+
var_name: [dim for dim in dims if dim is not None]
142+
for var_name, dims in model.RV_dims.items()
143+
}
144+
145+
idata = convert_flat_trace_to_idata(
146+
samples, postprocessing_backend=postprocessing_backend, coords=model.coords, dims=dims
147+
)
148+
149+
return idata
+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2022 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import sys
16+
17+
import numpy as np
18+
import pymc as pm
19+
import pytest
20+
21+
import pymc_experimental as pmx
22+
23+
24+
@pytest.mark.skipif(sys.platform == "win32", reason="JAX not supported on windows.")
25+
def test_pathfinder():
26+
# Data of the Eight Schools Model
27+
J = 8
28+
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
29+
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
30+
31+
with pm.Model() as model:
32+
33+
mu = pm.Normal("mu", mu=0.0, sigma=10.0)
34+
tau = pm.HalfCauchy("tau", 5.0)
35+
36+
theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
37+
theta_1 = mu + tau * theta
38+
obs = pm.Normal("obs", mu=theta, sigma=sigma, shape=J, observed=y)
39+
40+
idata = pmx.fit(method="pathfinder", iterations=100)
41+
42+
assert idata is not None
43+
assert "theta" in idata.posterior._variables.keys()
44+
assert "tau" in idata.posterior._variables.keys()
45+
assert "mu" in idata.posterior._variables.keys()
46+
assert idata.posterior["mu"].shape == (1, 100)
47+
assert idata.posterior["tau"].shape == (1, 100)
48+
assert idata.posterior["theta"].shape == (1, 100, 8)

requirements-dev.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
dask[all]
2+
blackjax

0 commit comments

Comments
 (0)