|
| 1 | +# Copyright 2022 The PyMC Developers |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | + |
| 16 | +import collections |
| 17 | +import sys |
| 18 | +from typing import Optional |
| 19 | + |
| 20 | +import arviz as az |
| 21 | +import blackjax |
| 22 | +import jax |
| 23 | +import jax.numpy as jnp |
| 24 | +import jax.random as random |
| 25 | +import numpy as np |
| 26 | +import pymc as pm |
| 27 | +from pymc import modelcontext |
| 28 | +from pymc.sampling import RandomSeed, _get_seeds_per_chain |
| 29 | +from pymc.sampling_jax import get_jaxified_graph |
| 30 | +from pymc.util import get_default_varnames |
| 31 | + |
| 32 | + |
| 33 | +def convert_flat_trace_to_idata( |
| 34 | + samples, |
| 35 | + dims=None, |
| 36 | + coords=None, |
| 37 | + include_transformed=False, |
| 38 | + postprocessing_backend="cpu", |
| 39 | + model=None, |
| 40 | +): |
| 41 | + |
| 42 | + model = modelcontext(model) |
| 43 | + init_position_dict = model.initial_point() |
| 44 | + trace = collections.defaultdict(list) |
| 45 | + astart = pm.blocking.DictToArrayBijection.map(init_position_dict) |
| 46 | + for sample in samples: |
| 47 | + raveld_vars = pm.blocking.RaveledVars(sample, astart.point_map_info) |
| 48 | + point = pm.blocking.DictToArrayBijection.rmap(raveld_vars, init_position_dict) |
| 49 | + for p, v in point.items(): |
| 50 | + trace[p].append(v.tolist()) |
| 51 | + |
| 52 | + trace = {k: np.asarray(v)[None, ...] for k, v in trace.items()} |
| 53 | + |
| 54 | + var_names = model.unobserved_value_vars |
| 55 | + vars_to_sample = list(get_default_varnames(var_names, include_transformed=include_transformed)) |
| 56 | + print("Transforming variables...", file=sys.stdout) |
| 57 | + jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample) |
| 58 | + result = jax.vmap(jax.vmap(jax_fn))( |
| 59 | + *jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0]) |
| 60 | + ) |
| 61 | + |
| 62 | + trace = {v.name: r for v, r in zip(vars_to_sample, result)} |
| 63 | + idata = az.from_dict(trace, dims=dims, coords=coords) |
| 64 | + |
| 65 | + return idata |
| 66 | + |
| 67 | + |
| 68 | +def fit_pathfinder( |
| 69 | + iterations=5_000, |
| 70 | + random_seed: Optional[RandomSeed] = None, |
| 71 | + postprocessing_backend="cpu", |
| 72 | + ftol=1e-4, |
| 73 | + model=None, |
| 74 | +): |
| 75 | + """ |
| 76 | + Fit the pathfinder algorithm as implemented in blackjax |
| 77 | +
|
| 78 | + Requires the JAX backend |
| 79 | +
|
| 80 | + Parameters |
| 81 | + ---------- |
| 82 | + iterations : int |
| 83 | + Number of iterations to run. |
| 84 | + random_seed : int |
| 85 | + Random seed to set. |
| 86 | + postprocessing_backend : str |
| 87 | + Where to compute transformations of the trace. |
| 88 | + "cpu" or "gpu". |
| 89 | + ftol : float |
| 90 | + Floating point tolerance |
| 91 | +
|
| 92 | + Returns |
| 93 | + ------- |
| 94 | + arviz.InferenceData |
| 95 | +
|
| 96 | + Reference |
| 97 | + --------- |
| 98 | + https://arxiv.org/abs/2108.03782 |
| 99 | + """ |
| 100 | + |
| 101 | + (random_seed,) = _get_seeds_per_chain(random_seed, 1) |
| 102 | + |
| 103 | + model = modelcontext(model) |
| 104 | + |
| 105 | + rvs = [rv.name for rv in model.value_vars] |
| 106 | + init_position_dict = model.initial_point() |
| 107 | + init_position = [init_position_dict[rv] for rv in rvs] |
| 108 | + |
| 109 | + new_logprob, new_input = pm.aesaraf.join_nonshared_inputs( |
| 110 | + init_position_dict, (model.logp(),), model.value_vars, () |
| 111 | + ) |
| 112 | + |
| 113 | + logprob_fn_list = get_jaxified_graph([new_input], new_logprob) |
| 114 | + |
| 115 | + def logprob_fn(x): |
| 116 | + return logprob_fn_list(x)[0] |
| 117 | + |
| 118 | + dim = sum(v.size for v in init_position_dict.values()) |
| 119 | + |
| 120 | + rng_key = random.PRNGKey(random_seed) |
| 121 | + w0 = random.multivariate_normal(rng_key, 2.0 + jnp.zeros(dim), jnp.eye(dim)) |
| 122 | + path = blackjax.vi.pathfinder.init(rng_key, logprob_fn, w0, return_path=True, ftol=ftol) |
| 123 | + |
| 124 | + pathfinder = blackjax.kernels.pathfinder(rng_key, logprob_fn, ftol=ftol) |
| 125 | + state = pathfinder.init(w0) |
| 126 | + |
| 127 | + def inference_loop(rng_key, kernel, initial_state, num_samples): |
| 128 | + @jax.jit |
| 129 | + def one_step(state, rng_key): |
| 130 | + state, info = kernel(rng_key, state) |
| 131 | + return state, (state, info) |
| 132 | + |
| 133 | + keys = jax.random.split(rng_key, num_samples) |
| 134 | + return jax.lax.scan(one_step, initial_state, keys) |
| 135 | + |
| 136 | + _, rng_key = random.split(rng_key) |
| 137 | + print("Running pathfinder...", file=sys.stdout) |
| 138 | + _, (_, samples) = inference_loop(rng_key, pathfinder.step, state, iterations) |
| 139 | + |
| 140 | + dims = { |
| 141 | + var_name: [dim for dim in dims if dim is not None] |
| 142 | + for var_name, dims in model.RV_dims.items() |
| 143 | + } |
| 144 | + |
| 145 | + idata = convert_flat_trace_to_idata( |
| 146 | + samples, postprocessing_backend=postprocessing_backend, coords=model.coords, dims=dims |
| 147 | + ) |
| 148 | + |
| 149 | + return idata |
0 commit comments