From 8eed424c181f0c5c0b5cf5e50780115d14704fc6 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 12:42:00 -0500 Subject: [PATCH 1/7] add static adjusted mclmc --- blackjax/__init__.py | 2 + blackjax/mcmc/__init__.py | 4 +- blackjax/mcmc/adjusted_mclmc.py | 57 ++---- blackjax/mcmc/adjusted_mclmc_dynamic.py | 257 ++++++++++++++++++++++++ tests/mcmc/test_sampling.py | 110 +++++++++- 5 files changed, 384 insertions(+), 46 deletions(-) create mode 100644 blackjax/mcmc/adjusted_mclmc_dynamic.py diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 6a0de3809..35c9e3b58 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -13,6 +13,7 @@ from .diagnostics import effective_sample_size as ess from .diagnostics import potential_scale_reduction as rhat from .mcmc import adjusted_mclmc as _adjusted_mclmc +from .mcmc import adjusted_mclmc_dynamic as _adjusted_mclmc_dynamic from .mcmc import barker from .mcmc import dynamic_hmc as _dynamic_hmc from .mcmc import elliptical_slice as _elliptical_slice @@ -112,6 +113,7 @@ def generate_top_level_api_from(module): additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk) mclmc = generate_top_level_api_from(_mclmc) +adjusted_mclmc_dynamic = generate_top_level_api_from(_adjusted_mclmc_dynamic) adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc) elliptical_slice = generate_top_level_api_from(_elliptical_slice) ghmc = generate_top_level_api_from(_ghmc) diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py index 1e1317684..fad5dcb97 100644 --- a/blackjax/mcmc/__init__.py +++ b/blackjax/mcmc/__init__.py @@ -1,5 +1,5 @@ from . import ( - adjusted_mclmc, + adjusted_mclmc_dynamic, barker, elliptical_slice, ghmc, @@ -25,5 +25,5 @@ "marginal_latent_gaussian", "random_walk", "mclmc", - "adjusted_mclmc", + "adjusted_mclmc_dynamic", ] diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 81fbc2835..8288772a3 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -11,7 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin".""" +"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin". + +NOTE: For best performance, we recommend using adjusted_mclmc_dynamic instead of this module, which is primarily intended for use in parallelized versions of the algorithm. 
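+
+Example
+-------
+A minimal sketch of the fixed-length interface in this module. The standard
+normal target and the values of ``step_size`` and ``num_integration_steps`` are
+purely illustrative, not recommendations:
+
+.. code::
+
+    import jax
+    import jax.numpy as jnp
+    import blackjax
+
+    logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))
+    rng_key = jax.random.PRNGKey(0)
+
+    alg = blackjax.adjusted_mclmc(
+        logdensity_fn,
+        step_size=0.5,
+        num_integration_steps=10,
+    )
+    state = alg.init(jnp.ones(10))
+    new_state, info = alg.step(rng_key, state)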
+ +""" from typing import Callable, Union import jax @@ -19,28 +23,26 @@ import blackjax.mcmc.integrators as integrators from blackjax.base import SamplingAlgorithm -from blackjax.mcmc.dynamic_hmc import DynamicHMCState, halton_sequence -from blackjax.mcmc.hmc import HMCInfo +from blackjax.mcmc.hmc import HMCInfo, HMCState from blackjax.mcmc.proposal import static_binomial_sampling -from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_unit_vector __all__ = ["init", "build_kernel", "as_top_level_api"] -def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array): +def init(position: ArrayLikeTree, logdensity_fn: Callable): logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) - return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg) + return HMCState(position, logdensity, logdensity_grad) def build_kernel( - integration_steps_fn, + num_integration_steps: int, integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, - next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], sqrt_diag_cov=1.0, ): - """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. + """Build an MHMCHMC kernel where the number of integration steps is chosen randomly. Parameters ---------- @@ -63,15 +65,13 @@ def build_kernel( def kernel( rng_key: PRNGKey, - state: DynamicHMCState, + state: HMCState, logdensity_fn: Callable, step_size: float, L_proposal_factor: float = jnp.inf, - ) -> tuple[DynamicHMCState, HMCInfo]: + ) -> tuple[HMCState, HMCInfo]: """Generate a new sample with the MHMCHMC kernel.""" - num_integration_steps = integration_steps_fn(state.random_generator_arg) - key_momentum, key_integrator = jax.random.split(rng_key, 2) momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( @@ -90,11 +90,10 @@ def kernel( ) return ( - DynamicHMCState( + HMCState( proposal.position, proposal.logdensity, proposal.logdensity_grad, - next_random_arg_fn(state.random_generator_arg), ), info, ) @@ -110,10 +109,9 @@ def as_top_level_api( *, divergence_threshold: int = 1000, integrator: Callable = integrators.isokinetic_mclachlan, - next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), + num_integration_steps, ) -> SamplingAlgorithm: - """Implements the (basic) user interface for the dynamic MHMCHMC kernel. + """Implements the (basic) user interface for the MHMCHMC kernel. Parameters ---------- @@ -140,15 +138,15 @@ def as_top_level_api( """ kernel = build_kernel( - integration_steps_fn=integration_steps_fn, + num_integration_steps, integrator=integrator, - next_random_arg_fn=next_random_arg_fn, sqrt_diag_cov=sqrt_diag_cov, divergence_threshold=divergence_threshold, ) - def init_fn(position: ArrayLikeTree, rng_key: Array): - return init(position, logdensity_fn, rng_key) + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) def update_fn(rng_key: PRNGKey, state): return kernel( @@ -240,18 +238,3 @@ def generate( return sampled_state, info, other_proposal_info return generate - - -def rescale(mu): - """returns s, such that - round(U(0, 1) * s + 0.5) - has expected value mu. 
- """ - k = jnp.floor(2 * mu - 1) - x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) - return k + x - - -def trajectory_length(t, mu): - s = rescale(mu) - return jnp.rint(0.5 + halton_sequence(t) * s) diff --git a/blackjax/mcmc/adjusted_mclmc_dynamic.py b/blackjax/mcmc/adjusted_mclmc_dynamic.py new file mode 100644 index 000000000..81fbc2835 --- /dev/null +++ b/blackjax/mcmc/adjusted_mclmc_dynamic.py @@ -0,0 +1,257 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin".""" +from typing import Callable, Union + +import jax +import jax.numpy as jnp + +import blackjax.mcmc.integrators as integrators +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.dynamic_hmc import DynamicHMCState, halton_sequence +from blackjax.mcmc.hmc import HMCInfo +from blackjax.mcmc.proposal import static_binomial_sampling +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.util import generate_unit_vector + +__all__ = ["init", "build_kernel", "as_top_level_api"] + + +def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array): + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg) + + +def build_kernel( + integration_steps_fn, + integrator: Callable = integrators.isokinetic_mclachlan, + divergence_threshold: float = 1000, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + sqrt_diag_cov=1.0, +): + """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. + + Parameters + ---------- + integrator + The integrator to use to integrate the Hamiltonian dynamics. + divergence_threshold + Value of the difference in energy above which we consider that the transition is divergent. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. Needs to return an `int`. + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. 
+ """ + + def kernel( + rng_key: PRNGKey, + state: DynamicHMCState, + logdensity_fn: Callable, + step_size: float, + L_proposal_factor: float = jnp.inf, + ) -> tuple[DynamicHMCState, HMCInfo]: + """Generate a new sample with the MHMCHMC kernel.""" + + num_integration_steps = integration_steps_fn(state.random_generator_arg) + + key_momentum, key_integrator = jax.random.split(rng_key, 2) + momentum = generate_unit_vector(key_momentum, state.position) + proposal, info, _ = adjusted_mclmc_proposal( + integrator=integrators.with_isokinetic_maruyama( + integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + ), + step_size=step_size, + L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), + num_integration_steps=num_integration_steps, + divergence_threshold=divergence_threshold, + )( + key_integrator, + integrators.IntegratorState( + state.position, momentum, state.logdensity, state.logdensity_grad + ), + ) + + return ( + DynamicHMCState( + proposal.position, + proposal.logdensity, + proposal.logdensity_grad, + next_random_arg_fn(state.random_generator_arg), + ), + info, + ) + + return kernel + + +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + L_proposal_factor: float = jnp.inf, + sqrt_diag_cov=1.0, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.isokinetic_mclachlan, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), +) -> SamplingAlgorithm: + """Implements the (basic) user interface for the dynamic MHMCHMC kernel. + + Parameters + ---------- + logdensity_fn + The log-density function we wish to draw samples from. + step_size + The value to use for the step size in the symplectic integrator. + divergence_threshold + The absolute value of the difference in energy between two states above + which we say that the transition is divergent. The default value is + commonly found in other libraries, and yet is arbitrary. + integrator + (algorithm parameter) The symplectic integrator to use to integrate the trajectory. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. + + + Returns + ------- + A ``SamplingAlgorithm``. + """ + + kernel = build_kernel( + integration_steps_fn=integration_steps_fn, + integrator=integrator, + next_random_arg_fn=next_random_arg_fn, + sqrt_diag_cov=sqrt_diag_cov, + divergence_threshold=divergence_threshold, + ) + + def init_fn(position: ArrayLikeTree, rng_key: Array): + return init(position, logdensity_fn, rng_key) + + def update_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + L_proposal_factor, + ) + + return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] + + +def adjusted_mclmc_proposal( + integrator: Callable, + step_size: Union[float, ArrayLikeTree], + L_proposal_factor: float, + num_integration_steps: int = 1, + divergence_threshold: float = 1000, + *, + sample_proposal: Callable = static_binomial_sampling, +) -> Callable: + """Vanilla MHMCHMC algorithm. + + The algorithm integrates the trajectory applying a integrator + `num_integration_steps` times in one direction to get a proposal and uses a + Metropolis-Hastings acceptance step to either reject or accept this + proposal. 
This is what people usually refer to when they talk about "the + HMC algorithm". + + Parameters + ---------- + integrator + integrator used to build the trajectory step by step. + kinetic_energy + Function that computes the kinetic energy. + step_size + Size of the integration step. + num_integration_steps + Number of times we run the integrator to build the trajectory + divergence_threshold + Threshold above which we say that there is a divergence. + + Returns + ------- + A kernel that generates a new chain state and information about the transition. + + """ + + def step(i, vars): + state, kinetic_energy, rng_key = vars + rng_key, next_rng_key = jax.random.split(rng_key) + next_state, next_kinetic_energy = integrator( + state, step_size, L_proposal_factor, rng_key + ) + + return next_state, kinetic_energy + next_kinetic_energy, next_rng_key + + def build_trajectory(state, num_integration_steps, rng_key): + return jax.lax.fori_loop( + 0 * num_integration_steps, num_integration_steps, step, (state, 0, rng_key) + ) + + def generate( + rng_key, state: integrators.IntegratorState + ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]: + """Generate a new chain state.""" + end_state, kinetic_energy, rng_key = build_trajectory( + state, num_integration_steps, rng_key + ) + + new_energy = -end_state.logdensity + delta_energy = -state.logdensity + end_state.logdensity - kinetic_energy + delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy) + is_diverging = -delta_energy > divergence_threshold + sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) + do_accept, p_accept, other_proposal_info = info + + info = HMCInfo( + state.momentum, + p_accept, + do_accept, + is_diverging, + new_energy, + end_state, + num_integration_steps, + ) + + return sampled_state, info, other_proposal_info + + return generate + + +def rescale(mu): + """returns s, such that + round(U(0, 1) * s + 0.5) + has expected value mu. 
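+
+    (Sketch of the derivation: round(U(0, 1) * s + 0.5) equals ceil(U(0, 1) * s)
+    almost surely, and for s = k + x with integer k and 0 <= x < 1 its expected
+    value is (k * (k + 1) / 2 + (k + 1) * x) / s; choosing k = floor(2 * mu - 1)
+    and solving for x gives the expression below.)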
+ """ + k = jnp.floor(2 * mu - 1) + x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) + return k + x + + +def trajectory_length(t, mu): + s = rescale(mu) + return jnp.rint(0.5 + halton_sequence(t) * s) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 474f67293..45d60f84a 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -14,7 +14,7 @@ import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info -from blackjax.mcmc.adjusted_mclmc import rescale +from blackjax.mcmc.adjusted_mclmc_dynamic import rescale from blackjax.mcmc.integrators import isokinetic_mclachlan from blackjax.util import run_inference_algorithm @@ -146,7 +146,7 @@ def run_mclmc( return samples - def run_adjusted_mclmc( + def run_adjusted_mclmc_dynamic( self, logdensity_fn, num_steps, @@ -158,13 +158,13 @@ def run_adjusted_mclmc( init_key, tune_key, run_key = jax.random.split(key, 3) - initial_state = blackjax.mcmc.adjusted_mclmc.init( + initial_state = blackjax.mcmc.adjusted_mclmc_dynamic.init( position=initial_position, logdensity_fn=logdensity_fn, random_generator_arg=init_key, ) - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel( + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc_dynamic.build_kernel( integrator=integrator, integration_steps_fn=lambda k: jnp.ceil( jax.random.uniform(k) * rescale(avg_num_integration_steps) @@ -177,7 +177,7 @@ def run_adjusted_mclmc( logdensity_fn=logdensity_fn, ) - target_acc_rate = 0.65 + target_acc_rate = 0.9 ( blackjax_state_after_tuning, @@ -197,7 +197,7 @@ def run_adjusted_mclmc( step_size = blackjax_mclmc_sampler_params.step_size L = blackjax_mclmc_sampler_params.L - alg = blackjax.adjusted_mclmc( + alg = blackjax.adjusted_mclmc_dynamic( logdensity_fn=logdensity_fn, step_size=step_size, integration_steps_fn=lambda key: jnp.ceil( @@ -218,6 +218,73 @@ def run_adjusted_mclmc( return out + def run_adjusted_mclmc( + self, + logdensity_fn, + num_steps, + initial_position, + key, + diagonal_preconditioning=False, + ): + integrator = isokinetic_mclachlan + + init_key, tune_key, run_key = jax.random.split(key, 3) + + initial_state = blackjax.mcmc.adjusted_mclmc.init( + position=initial_position, + logdensity_fn=logdensity_fn, + ) + + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel( + integrator=integrator, + num_integration_steps=avg_num_integration_steps, + sqrt_diag_cov=sqrt_diag_cov, + )( + rng_key=rng_key, + state=state, + step_size=step_size, + logdensity_fn=logdensity_fn, + ) + + target_acc_rate = 0.9 + + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + ) = blackjax.adjusted_mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_steps, + state=initial_state, + rng_key=tune_key, + target=target_acc_rate, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.1, + diagonal_preconditioning=diagonal_preconditioning, + ) + + step_size = blackjax_mclmc_sampler_params.step_size + L = blackjax_mclmc_sampler_params.L + + alg = blackjax.adjusted_mclmc( + logdensity_fn=logdensity_fn, + step_size=step_size, + num_integration_steps=L / step_size, + integrator=integrator, + sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + ) + + _, out = run_inference_algorithm( + rng_key=run_key, + 
initial_state=blackjax_state_after_tuning, + inference_algorithm=alg, + num_steps=num_steps, + transform=lambda state, _: state.position, + progress_bar=False, + ) + + return out + @parameterized.parameters( itertools.product( regression_test_cases, [True, False], window_adaptation_filters @@ -334,7 +401,35 @@ def test_mclmc(self): np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) - def test_adjusted_mclmc(self): + @parameterized.parameters([True, False]) + def test_adjusted_mclmc_dynamic(self, diagonal_preconditioning): + """Test the MCLMC kernel.""" + + init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + x_data = jax.random.normal(init_key0, shape=(1000, 1)) + y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) + + logposterior_fn_ = functools.partial( + self.regression_logprob, x=x_data, preds=y_data + ) + logdensity_fn = lambda x: logposterior_fn_(**x) + + states = self.run_adjusted_mclmc_dynamic( + initial_position={"coefs": 1.0, "log_scale": 1.0}, + logdensity_fn=logdensity_fn, + key=inference_key, + num_steps=10000, + diagonal_preconditioning=diagonal_preconditioning, + ) + + coefs_samples = states["coefs"][3000:] + scale_samples = np.exp(states["log_scale"][3000:]) + + np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + + @parameterized.parameters([True, False]) + def test_adjusted_mclmc(self, diagonal_preconditioning): """Test the MCLMC kernel.""" init_key0, init_key1, inference_key = jax.random.split(self.key, 3) @@ -351,6 +446,7 @@ def test_adjusted_mclmc(self): logdensity_fn=logdensity_fn, key=inference_key, num_steps=10000, + diagonal_preconditioning=diagonal_preconditioning, ) coefs_samples = states["coefs"][3000:] From 9dd6bdba039003538eaf635fde7678defdbc7350 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 13:27:08 -0500 Subject: [PATCH 2/7] add static adjusted mclmc --- .../adaptation/adjusted_mclmc_adaptation.py | 10 ++--- blackjax/adaptation/mclmc_adaptation.py | 22 +++++----- blackjax/mcmc/adjusted_mclmc.py | 10 +++-- blackjax/mcmc/adjusted_mclmc_dynamic.py | 10 +++-- blackjax/mcmc/integrators.py | 12 +++--- blackjax/mcmc/mclmc.py | 8 ++-- tests/mcmc/test_integrators.py | 4 +- tests/mcmc/test_sampling.py | 40 +++++++++++-------- 8 files changed, 64 insertions(+), 52 deletions(-) diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index f5d54e5c9..eabb642a3 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -74,7 +74,7 @@ def adjusted_mclmc_find_L_and_step_size( dim = pytree_size(state.position) if params is None: params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, sqrt_diag_cov=jnp.ones((dim,)) + jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, inverse_mass_matrix=jnp.ones((dim,)) ) part1_key, part2_key = jax.random.split(rng_key, 2) @@ -152,7 +152,7 @@ def step(iteration_state, weight_and_key): state=previous_state, avg_num_integration_steps=avg_num_integration_steps, step_size=params.step_size, - sqrt_diag_cov=params.sqrt_diag_cov, + inverse_mass_matrix=params.inverse_mass_matrix, ) # step updating @@ -283,9 +283,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L=params.L * change, step_size=params.step_size * change ) if diagonal_preconditioning: - params = params._replace( - 
sqrt_diag_cov=jnp.sqrt(variances), L=jnp.sqrt(dim) - ) + params = params._replace(inverse_mass_matrix=variances, L=jnp.sqrt(dim)) initial_da, update_da, final_da = dual_averaging_adaptation(target=target) ( @@ -323,7 +321,7 @@ def step(state, key): state=state, step_size=params.step_size, avg_num_integration_steps=params.L / params.step_size, - sqrt_diag_cov=params.sqrt_diag_cov, + inverse_mass_matrix=params.inverse_mass_matrix, ) return next_state, next_state.position diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 8452b6171..aa192b964 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -30,13 +30,13 @@ class MCLMCAdaptationState(NamedTuple): The momentum decoherent rate for the MCLMC algorithm. step_size The step size used for the MCLMC algorithm. - sqrt_diag_cov + inverse_mass_matrix A matrix used for preconditioning. """ L: float step_size: float - sqrt_diag_cov: float + inverse_mass_matrix: float def mclmc_find_L_and_step_size( @@ -87,10 +87,10 @@ def mclmc_find_L_and_step_size( Example ------- .. code:: - kernel = lambda sqrt_diag_cov : blackjax.mcmc.mclmc.build_kernel( + kernel = lambda inverse_mass_matrix : blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=integrator, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, ) ( @@ -106,7 +106,7 @@ def mclmc_find_L_and_step_size( """ dim = pytree_size(state.position) params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov=jnp.ones((dim,)) + jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,)) ) part1_key, part2_key = jax.random.split(rng_key, 2) @@ -123,7 +123,7 @@ def mclmc_find_L_and_step_size( if frac_tune3 != 0: state, params = make_adaptation_L( - mclmc_kernel(params.sqrt_diag_cov), frac=frac_tune3, Lfactor=0.4 + mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=0.4 )(state, params, num_steps, part2_key) return state, params @@ -152,7 +152,7 @@ def predictor(previous_state, params, adaptive_state, rng_key): rng_key, nan_key = jax.random.split(rng_key) # dynamics - next_state, info = kernel(params.sqrt_diag_cov)( + next_state, info = kernel(params.inverse_mass_matrix)( rng_key=rng_key, state=previous_state, L=params.L, @@ -247,15 +247,15 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L = params.L # determine L - sqrt_diag_cov = params.sqrt_diag_cov + inverse_mass_matrix = params.inverse_mass_matrix if num_steps2 > 1: x_average, x_squared_average = average[0], average[1] variances = x_squared_average - jnp.square(x_average) L = jnp.sqrt(jnp.sum(variances)) if diagonal_preconditioning: - sqrt_diag_cov = jnp.sqrt(variances) - params = params._replace(sqrt_diag_cov=sqrt_diag_cov) + inverse_mass_matrix = variances + params = params._replace(inverse_mass_matrix=inverse_mass_matrix) L = jnp.sqrt(dim) # readjust the stepsize @@ -265,7 +265,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): xs=(jnp.ones(steps), keys), state=state, params=params ) - return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov) + return state, MCLMCAdaptationState(L, params.step_size, inverse_mass_matrix) return L_step_size_adaptation diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 8288772a3..9b868562c 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -40,7 +40,7 @@ def build_kernel( num_integration_steps: int, integrator: Callable = 
integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, ): """Build an MHMCHMC kernel where the number of integration steps is chosen randomly. @@ -76,7 +76,9 @@ def kernel( momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( integrator=integrators.with_isokinetic_maruyama( - integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + integrator( + logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix + ) ), step_size=step_size, L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), @@ -105,7 +107,7 @@ def as_top_level_api( logdensity_fn: Callable, step_size: float, L_proposal_factor: float = jnp.inf, - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, *, divergence_threshold: int = 1000, integrator: Callable = integrators.isokinetic_mclachlan, @@ -140,7 +142,7 @@ def as_top_level_api( kernel = build_kernel( num_integration_steps, integrator=integrator, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, divergence_threshold=divergence_threshold, ) diff --git a/blackjax/mcmc/adjusted_mclmc_dynamic.py b/blackjax/mcmc/adjusted_mclmc_dynamic.py index 81fbc2835..1a69e1a28 100644 --- a/blackjax/mcmc/adjusted_mclmc_dynamic.py +++ b/blackjax/mcmc/adjusted_mclmc_dynamic.py @@ -38,7 +38,7 @@ def build_kernel( integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, ): """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. @@ -76,7 +76,9 @@ def kernel( momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( integrator=integrators.with_isokinetic_maruyama( - integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + integrator( + logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix + ) ), step_size=step_size, L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), @@ -106,7 +108,7 @@ def as_top_level_api( logdensity_fn: Callable, step_size: float, L_proposal_factor: float = jnp.inf, - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, *, divergence_threshold: int = 1000, integrator: Callable = integrators.isokinetic_mclachlan, @@ -143,7 +145,7 @@ def as_top_level_api( integration_steps_fn=integration_steps_fn, integrator=integrator, next_random_arg_fn=next_random_arg_fn, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, divergence_threshold=divergence_threshold, ) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 593683ca4..733e7e960 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -311,7 +311,9 @@ def _normalized_flatten_array(x, tol=1e-13): return jnp.where(norm > tol, x / norm, x), norm -def esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0): +def esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0): + sqrt_inverse_mass_matrix = jax.tree_util.tree_map(jnp.sqrt, inverse_mass_matrix) + def update( momentum: ArrayTree, logdensity_grad: ArrayTree, @@ -330,7 +332,7 @@ def update( logdensity_grad = logdensity_grad flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) - flatten_grads = flatten_grads * sqrt_diag_cov + flatten_grads = flatten_grads * sqrt_inverse_mass_matrix flatten_momentum, _ = ravel_pytree(momentum) dims = 
flatten_momentum.shape[0] normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) @@ -342,7 +344,7 @@ def update( + 2 * zeta * flatten_momentum ) new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw) - gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov) + gr = unravel_fn(new_momentum_normalized * sqrt_inverse_mass_matrix) next_momentum = unravel_fn(new_momentum_normalized) kinetic_energy_change = ( delta @@ -374,11 +376,11 @@ def format_isokinetic_state_output( def generate_isokinetic_integrator(coefficients): def isokinetic_integrator( - logdensity_fn: Callable, sqrt_diag_cov: ArrayTree = 1.0 + logdensity_fn: Callable, inverse_mass_matrix: ArrayTree = 1.0 ) -> GeneralIntegrator: position_update_fn = euclidean_position_update_fn(logdensity_fn) one_step = generalized_two_stage_integrator( - esh_dynamics_momentum_update_one_step(sqrt_diag_cov), + esh_dynamics_momentum_update_one_step(inverse_mass_matrix), position_update_fn, coefficients, format_output_fn=format_isokinetic_state_output, diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index e7a69849b..ff9638a1f 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -60,7 +60,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): +def build_kernel(logdensity_fn, inverse_mass_matrix, integrator): """Build a HMC kernel. Parameters @@ -81,7 +81,7 @@ def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): """ step = with_isokinetic_maruyama( - integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + integrator(logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix) ) def kernel( @@ -107,7 +107,7 @@ def as_top_level_api( L, step_size, integrator=isokinetic_mclachlan, - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, ) -> SamplingAlgorithm: """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. Since most users only need to specify the kernel @@ -155,7 +155,7 @@ def as_top_level_api( A ``SamplingAlgorithm``. 
""" - kernel = build_kernel(logdensity_fn, sqrt_diag_cov, integrator) + kernel = build_kernel(logdensity_fn, inverse_mass_matrix, integrator) def init_fn(position: ArrayLike, rng_key: PRNGKey): return init(position, logdensity_fn, rng_key) diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index c38009e5e..c37c0ede6 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -238,7 +238,7 @@ def test_esh_momentum_update(self, dims): # Efficient implementation update_stable = self.variant( - esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) + esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0) ) next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0) np.testing.assert_array_almost_equal(next_momentum, next_momentum1) @@ -263,7 +263,7 @@ def test_isokinetic_velocity_verlet(self): next_state, kinetic_energy_change = step(initial_state, step_size) # explicit integration - op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) + op1 = esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0) op2 = integrators.euclidean_position_update_fn(logdensity_fn) position, momentum, _, logdensity_grad = initial_state momentum, kinetic_grad, kinetic_energy_change0 = op1( diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 45d60f84a..a4ea66a9b 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -112,10 +112,10 @@ def run_mclmc( position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key ) - kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=blackjax.mcmc.mclmc.isokinetic_mclachlan, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, ) ( @@ -133,7 +133,7 @@ def run_mclmc( logdensity_fn, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, - sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, ) _, samples = run_inference_algorithm( @@ -144,6 +144,8 @@ def run_mclmc( transform=lambda state, info: state.position, ) + print(samples["coefs"][0].item()) + return samples def run_adjusted_mclmc_dynamic( @@ -164,12 +166,12 @@ def run_adjusted_mclmc_dynamic( random_generator_arg=init_key, ) - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc_dynamic.build_kernel( + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc_dynamic.build_kernel( integrator=integrator, integration_steps_fn=lambda k: jnp.ceil( jax.random.uniform(k) * rescale(avg_num_integration_steps) ), - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, )( rng_key=rng_key, state=state, @@ -204,7 +206,7 @@ def run_adjusted_mclmc_dynamic( jax.random.uniform(key) * rescale(L / step_size) ), integrator=integrator, - sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, ) _, out = run_inference_algorithm( @@ -216,6 +218,8 @@ def run_adjusted_mclmc_dynamic( progress_bar=False, ) + print(blackjax_mclmc_sampler_params.inverse_mass_matrix[1].item()) + return out def run_adjusted_mclmc( @@ -235,10 +239,10 @@ def run_adjusted_mclmc( logdensity_fn=logdensity_fn, ) - kernel = lambda rng_key, state, avg_num_integration_steps, 
step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel( + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc.build_kernel( integrator=integrator, num_integration_steps=avg_num_integration_steps, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, )( rng_key=rng_key, state=state, @@ -271,7 +275,7 @@ def run_adjusted_mclmc( step_size=step_size, num_integration_steps=L / step_size, integrator=integrator, - sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, ) _, out = run_inference_algorithm( @@ -402,7 +406,10 @@ def test_mclmc(self): np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) @parameterized.parameters([True, False]) - def test_adjusted_mclmc_dynamic(self, diagonal_preconditioning): + def test_adjusted_mclmc_dynamic( + self, + diagonal_preconditioning, + ): """Test the MCLMC kernel.""" init_key0, init_key1, inference_key = jax.random.split(self.key, 3) @@ -495,7 +502,7 @@ def __init__(self, d, condition_number): integrator = isokinetic_mclachlan - def get_sqrt_diag_cov(): + def get_inverse_mass_matrix(): init_key, tune_key = jax.random.split(key) initial_position = model.sample_init(init_key) @@ -506,10 +513,10 @@ def get_sqrt_diag_cov(): rng_key=init_key, ) - kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=model.logdensity_fn, integrator=integrator, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, ) ( @@ -523,13 +530,14 @@ def get_sqrt_diag_cov(): diagonal_preconditioning=True, ) - return blackjax_mclmc_sampler_params.sqrt_diag_cov + return blackjax_mclmc_sampler_params.inverse_mass_matrix - sqrt_diag_cov = get_sqrt_diag_cov() + inverse_mass_matrix = get_inverse_mass_matrix() assert ( jnp.abs( jnp.dot( - (sqrt_diag_cov**2) / jnp.linalg.norm(sqrt_diag_cov**2), + (inverse_mass_matrix**2) + / jnp.linalg.norm(inverse_mass_matrix**2), eigs / jnp.linalg.norm(eigs), ) - 1 From a49bb35f37293f0033ea4c9c5b8daf7ff62c1461 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 13:54:36 -0500 Subject: [PATCH 3/7] add static adjusted mclmc --- tests/mcmc/test_sampling.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index a4ea66a9b..d788696f8 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -144,8 +144,6 @@ def run_mclmc( transform=lambda state, info: state.position, ) - print(samples["coefs"][0].item()) - return samples def run_adjusted_mclmc_dynamic( @@ -218,8 +216,6 @@ def run_adjusted_mclmc_dynamic( progress_bar=False, ) - print(blackjax_mclmc_sampler_params.inverse_mass_matrix[1].item()) - return out def run_adjusted_mclmc( From 35d71ffb9ece4940b113ed2caedf4f302b20ddfd Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 14:02:35 -0500 Subject: [PATCH 4/7] draft --- blackjax/adaptation/ensemble_mclmc.py | 224 +++++++++++++++++++++ blackjax/adaptation/ensemble_umclmc.py | 265 +++++++++++++++++++++++++ blackjax/adaptation/step_size.py | 43 ++++ blackjax/mcmc/mclmc.py | 2 +- blackjax/util.py | 123 +++++++++++- 5 files changed, 655 insertions(+), 2 deletions(-) create mode 100644 blackjax/adaptation/ensemble_mclmc.py create mode 100644 blackjax/adaptation/ensemble_umclmc.py diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py new file mode 
100644 index 000000000..dabf5a3be --- /dev/null +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -0,0 +1,224 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#"""Public API for the MCLMC Kernel""" + +from typing import Callable, NamedTuple, Any + +import jax +import jax.numpy as jnp + +from blackjax.util import run_eca +from blackjax.mcmc.integrators import generate_isokinetic_integrator, velocity_verlet_coefficients, mclachlan_coefficients, omelyan_coefficients +from blackjax.mcmc.hmc import HMCState +from blackjax.mcmc.adjusted_mclmc import build_kernel as build_kernel_malt +import blackjax.adaptation.ensemble_umclmc as umclmc +from blackjax.adaptation.ensemble_umclmc import equipartition_diagonal, equipartition_diagonal_loss, equipartition_fullrank, equipartition_fullrank_loss + +from blackjax.adaptation.step_size import dual_averaging_adaptation, bisection_monotonic_fn + + + +class AdaptationState(NamedTuple): + steps_per_sample: float + step_size: float + epsadap_state: Any + sample_count: int + + + +def build_kernel(logdensity_fn, integrator, sqrt_diag_cov): + """MCLMC kernel""" + + kernel = build_kernel_malt(logdensity_fn, integrator, sqrt_diag_cov= sqrt_diag_cov, L_proposal_factor = 1.25) + + def sequential_kernel(key, state, adap): + return kernel(key, state, step_size= adap.step_size, num_integration_steps= adap.steps_per_sample) + + return sequential_kernel + + + +class Adaptation: + + def __init__(self, adap_state, num_adaptation_samples, + steps_per_sample, acc_prob_target= 0.8, + observables = lambda x: 0., + observables_for_bias = lambda x: 0., contract= lambda x: 0.): + + self.num_adaptation_samples= num_adaptation_samples + self.observables = observables + self.observables_for_bias = observables_for_bias + self.contract = contract + + ### Determine the initial hyperparameters ### + + ## stepsize ## + #if we switched to the more accurate integrator we can use longer step size + #integrator_factor = jnp.sqrt(10.) if mclachlan else 1. + # Let's use the stepsize which will be optimal for the adjusted method. The energy variance after N steps scales as sigma^2 ~ N^2 eps^6 = eps^4 L^2 + # In the adjusted method we want sigma^2 = 2 mu = 2 * 0.41 = 0.82 + # With the current eps, we had sigma^2 = EEVPD * d for N = 1. 
+ # Combining the two we have EEVPD * d / 0.82 = eps^6 / eps_new^4 L^2 + #adjustment_factor = jnp.power(0.82 / (num_dims * adap_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) + step_size = adap_state.step_size #* integrator_factor * adjustment_factor + + #steps_per_sample = (int)(jnp.max(jnp.array([Lfull / step_size, 1]))) + + ### Initialize the dual averaging adaptation ### + #da_init_fn, self.epsadap_update, _ = dual_averaging_adaptation(target= acc_prob_target) + #epsadap_state = da_init_fn(step_size) + + ### Initialize the bisection for finding the step size + epsadap_state, self.epsadap_update = bisection_monotonic_fn(acc_prob_target) + + self.initial_state = AdaptationState(steps_per_sample, step_size, epsadap_state, 0) + + + def summary_statistics_fn(self, state, info, rng_key): + + return {'acceptance_probability': info.acceptance_rate, + #'inv_acceptance_probability': 1./info.acceptance_rate, + 'equipartition_diagonal': equipartition_diagonal(state), + 'equipartition_fullrank': equipartition_fullrank(state, rng_key), + 'observables': self.observables(state.position), + 'observables_for_bias': self.observables_for_bias(state.position) + } + + + def update(self, adaptation_state, Etheta): + + # combine the expectation values to get useful scalars + acc_prob = Etheta['acceptance_probability'] + #acc_prob = 1./Etheta['inv_acceptance_probability'] + equi_diag = equipartition_diagonal_loss(Etheta['equipartition_diagonal']) + equi_full = equipartition_fullrank_loss(Etheta['equipartition_fullrank']) + true_bias = self.contract(Etheta['observables_for_bias']) + + + info_to_be_stored = {'L': adaptation_state.step_size * adaptation_state.steps_per_sample, 'steps_per_sample': adaptation_state.steps_per_sample, 'step_size': adaptation_state.step_size, + 'acc_prob': acc_prob, + 'equi_diag': equi_diag, 'equi_full': equi_full, 'bias': true_bias, + 'observables': Etheta['observables'] + } + + # hyperparameter adaptation + + # Dual Averaging + # adaptation_phase = adaptation_state.sample_count < self.num_adaptation_samples + + # def update(_): + # da_state = self.epsadap_update(adaptation_state.epsadap_state, acc_prob) + # step_size = jnp.exp(da_state.log_step_size) + # return da_state, step_size + + # def dont_update(_): + # da_state = adaptation_state.epsadap_state + # return da_state, jnp.exp(da_state.log_step_size_avg) + + # epsadap_state, step_size = jax.lax.cond(adaptation_phase, update, dont_update, operand=None) + + # Bisection + epsadap_state, step_size = self.epsadap_update(adaptation_state.epsadap_state, adaptation_state.step_size, acc_prob) + + return AdaptationState(adaptation_state.steps_per_sample, step_size, epsadap_state, adaptation_state.sample_count + 1), info_to_be_stored + + + +def bias(model): + """should be transfered to benchmarks/""" + + def observables(position): + return jnp.square(model.transform(position)) + + def contract(sampler_E_x2): + bsq = jnp.square(sampler_E_x2 - model.E_x2) / model.Var_x2 + return jnp.array([jnp.max(bsq), jnp.average(bsq)]) + + return observables, contract + + + +def while_steps_num(cond): + if jnp.all(cond): + return len(cond) + else: + return jnp.argmin(cond) + 1 + + +def emaus(model, num_steps1, num_steps2, num_chains, mesh, rng_key, + alpha= 1.9, bias_type= 0, save_frac= 0.2, C= 0.1, power= 3./8., early_stop= True, r_end= 5e-3,# stage1 parameters + diagonal_preconditioning= True, integrator_coefficients= None, steps_per_sample= 10, acc_prob= None, + observables = lambda x: None, + ensemble_observables= None + ): + + observables_for_bias, 
contract = bias(model) + key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) + + # initialize the chains + initial_state = umclmc.initialize(key_init, model.logdensity_fn, model.sample_init, num_chains, mesh) + + ### burn-in with the unadjusted method ### + kernel = umclmc.build_kernel(model.logdensity_fn) + save_num= (int)(jnp.rint(save_frac * num_steps1)) + adap = umclmc.Adaptation(model.ndims, alpha= alpha, bias_type= bias_type, save_num= save_num, C=C, power= power, r_end = r_end, + observables= observables, observables_for_bias= observables_for_bias, contract= contract) + final_state, final_adaptation_state, info1 = run_eca(key_umclmc, initial_state, kernel, adap, num_steps1, num_chains, mesh, ensemble_observables) + + if early_stop: # here I am cheating a bit, because I am not sure if it is possible to do a while loop in jax and save something at every step. Therefore I rerun burn-in with exactly the same parameters and stop at the point where the orignal while loop would have stopped. The release implementation should not have that. + + num_steps_while = while_steps_num((info1[0] if ensemble_observables != None else info1)['while_cond']) + #print(num_steps_while, save_num) + final_state, final_adaptation_state, info1 = run_eca(key_umclmc, initial_state, kernel, adap, num_steps_while, num_chains, mesh, ensemble_observables) + + ### refine the results with the adjusted method ### + _acc_prob = acc_prob + if integrator_coefficients == None: + high_dims = model.ndims > 200 + _integrator_coefficients = omelyan_coefficients if high_dims else mclachlan_coefficients + if acc_prob == None: + _acc_prob = 0.9 if high_dims else 0.7 + + else: + _integrator_coefficients = integrator_coefficients + if acc_prob == None: + _acc_prob = 0.9 + + + integrator = generate_isokinetic_integrator(_integrator_coefficients) + gradient_calls_per_step= len(_integrator_coefficients) // 2 #scheme = BABAB..AB scheme has len(scheme)//2 + 1 Bs. The last doesn't count because that gradient can be reused in the next step. + + if diagonal_preconditioning: + sqrt_diag_cov= final_adaptation_state.sqrt_diag_cov + + # scale the stepsize so that it reflects averag scale change of the preconditioning + average_scale_change = jnp.sqrt(jnp.average(jnp.square(sqrt_diag_cov))) + final_adaptation_state = final_adaptation_state._replace(step_size= final_adaptation_state.step_size / average_scale_change) + + else: + sqrt_diag_cov= 1. + + kernel = build_kernel(model.logdensity_fn, integrator, sqrt_diag_cov= sqrt_diag_cov) + initial_state= HMCState(final_state.position, final_state.logdensity, final_state.logdensity_grad) + num_samples = num_steps2 // (gradient_calls_per_step * steps_per_sample) + num_adaptation_samples = num_samples//2 # number of samples after which the stepsize is fixed. + + adap = Adaptation(final_adaptation_state, num_adaptation_samples, steps_per_sample, _acc_prob, + observables= observables, observables_for_bias= observables_for_bias, contract= contract) + + final_state, final_adaptation_state, info2 = run_eca(key_mclmc, initial_state, kernel, adap, num_samples, num_chains, mesh, ensemble_observables) + + return info1, info2, gradient_calls_per_step, _acc_prob + + \ No newline at end of file diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py new file mode 100644 index 000000000..73b9ae4f7 --- /dev/null +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -0,0 +1,265 @@ + +# Copyright 2020- The Blackjax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#"""Public API for the MCLMC Kernel""" + +import jax +import jax.numpy as jnp +from jax.flatten_util import ravel_pytree +from typing import Callable, NamedTuple, Any + +from blackjax.mcmc.integrators import IntegratorState, isokinetic_velocity_verlet +from blackjax.types import Array, ArrayLike +from blackjax.util import pytree_size +from blackjax.mcmc import mclmc +from blackjax.mcmc.integrators import _normalized_flatten_array +from blackjax.util import ensemble_execute_fn + + + +def no_nans(a): + flat_a, unravel_fn = ravel_pytree(a) + return jnp.all(jnp.isfinite(flat_a)) + + +def nan_reject(nonans, old, new): + """Equivalent to + return new if nonans else old""" + + return jax.lax.cond(nonans, lambda _: new, lambda _: old, operand=None) + + + +def build_kernel(logdensity_fn): + """MCLMC kernel (with nan rejection)""" + + kernel = mclmc.build_kernel(logdensity_fn= logdensity_fn, integrator= isokinetic_velocity_verlet) + + + def sequential_kernel(key, state, adap): + + new_state, info = kernel(key, state, adap.L, adap.step_size) + + # reject the new state if there were nans + nonans = no_nans(new_state) + new_state = nan_reject(nonans, state, new_state) + + return new_state, {'nans': 1-nonans, 'energy_change': info.energy_change * nonans, 'logdensity': info.logdensity * nonans} + + + return sequential_kernel + + + +def initialize(rng_key, logdensity_fn, sample_init, num_chains, mesh): + """initialize the chains based on the equipartition of the initial condition. + We initialize the velocity along grad log p if E_ii > 1 and along -grad log p if E_ii < 1. + """ + + def sequential_init(key, x, args): + """initialize the position using sample_init and the velocity along the gradient""" + position = sample_init(key) + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + flat_g, unravel_fn = ravel_pytree(logdensity_grad) + velocity = unravel_fn(_normalized_flatten_array(flat_g)[0]) # = grad logp/ |grad logp| + return IntegratorState(position, velocity, logdensity, logdensity_grad), None + + def summary_statistics_fn(state): + """compute the diagonal elements of the equipartition matrix""" + return -state.position * state.logdensity_grad + + def ensemble_init(key, state, signs): + """flip the velocity, depending on the equipartition condition""" + velocity = jax.tree_util.tree_map(lambda sign, u: sign * u, signs, state.momentum) + return IntegratorState(state.position, velocity, state.logdensity, state.logdensity_grad), None + + key1, key2= jax.random.split(rng_key) + initial_state, equipartition = ensemble_execute_fn(sequential_init, key1, num_chains, mesh, summary_statistics_fn= summary_statistics_fn) + signs = -2. * (equipartition < 1.) + 1. 
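+    # signs is +1 where the ensemble-averaged equipartition diagonal E_ii >= 1 and -1 where E_ii < 1,
+    # so ensemble_init below keeps the velocity along +grad log p in the first case and flips it to
+    # -grad log p in the second, as described in the docstring of this function.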
+ initial_state, _ = ensemble_execute_fn(ensemble_init, key2, num_chains, mesh, x= initial_state, args= signs) + + return initial_state + + +def update_history(new_vals, history): + return jnp.concatenate((new_vals[None, :], history[:-1])) + +def update_history_scalar(new_val, history): + return jnp.concatenate((new_val * jnp.ones(1), history[:-1])) + +def contract_history(theta, weights): + + square_average = jnp.square(jnp.average(theta, weights= weights, axis= 0)) + average_square = jnp.average(jnp.square(theta), weights= weights, axis= 0) + + r = (average_square - square_average) / square_average + + return jnp.array([jnp.max(r), + jnp.average(r) + ]) + + +class History(NamedTuple): + observables: Array + stopping: Array + weights: Array + + +class AdaptationState(NamedTuple): + + L: float + sqrt_diag_cov: Any + step_size: float + + step_count: int + EEVPD: float + EEVPD_wanted: float + history: Any + + +def equipartition_diagonal(state): + """Ei = E_ensemble (- grad log p_i x_i ). Ei is 1 if we have converged. + equipartition_loss = average over parameters (Ei)""" + return jax.tree_util.tree_map(lambda x, g: -x * g, state.position, state.logdensity_grad) + + + +def equipartition_fullrank(state, rng_key): + """loss = Tr[(1 - E)^T (1 - E)] / d^2 + where Eij = is the equipartition patrix. + Loss is computed with the Hutchinson's trick.""" + + x, unravel_fn = ravel_pytree(state.position) + g, unravel_fn = ravel_pytree(state.logdensity_grad) + d = len(x) + + def func(z): + """z here has the same shape as position""" + return (z + jnp.dot(z, g) * x) + + z = jax.random.rademacher(rng_key, (100, d)) # = delta_ij + return jax.vmap(func)(z) + + +def equipartition_diagonal_loss(Eii): + Eii_flat, unravel_fn = ravel_pytree(Eii) + return jnp.average(jnp.square(1.- Eii_flat)) + + +def equipartition_fullrank_loss(delta_z): + d = delta_z.shape[-1] + return jnp.average(jnp.square(delta_z)) / d + + +class Adaptation: + + def __init__(self, num_dims, + alpha= 1., C= 0.1, power = 3./8., r_end= 0.01, + bias_type= 0, save_num = 10, + observables= lambda x: 0., observables_for_bias= lambda x: 0., contract= lambda x: 0. 
+ ): + + self.num_dims = num_dims + self.alpha = alpha + self.C = C + self.power = power + self.r_end = r_end + self.observables = observables + self.observables_for_bias = observables_for_bias + self.contract = contract + self.bias_type = bias_type + self.save_num = save_num + #sigma = unravel_fn(jnp.ones(flat_pytree.shape, dtype = flat_pytree.dtype)) + + r_save_num = save_num + + history = History(observables= jnp.zeros((r_save_num, num_dims)), + stopping= jnp.full((save_num,), jnp.nan), + weights= jnp.zeros(r_save_num)) + + self.initial_state = AdaptationState(L= jnp.inf, # do not add noise for the first step + sqrt_diag_cov= jnp.ones(num_dims), + step_size= 0.01 * jnp.sqrt(num_dims), + step_count= 0, + EEVPD=1e-3, EEVPD_wanted=1e-3, + history=history) + + + def summary_statistics_fn(self, state, info, rng_key): + + position_flat, unravel_fn = ravel_pytree(state.position) + + return {'equipartition_diagonal': equipartition_diagonal(state), + 'equipartition_fullrank': equipartition_fullrank(state, rng_key), + 'x': position_flat, 'xsq': jnp.square(position_flat), + 'E': info['energy_change'], 'Esq': jnp.square(info['energy_change']), + 'rejection_rate_nans': info['nans'], + 'observables_for_bias': self.observables_for_bias(state.position), + 'observables': self.observables(state.position), + 'entropy': - info['logdensity'] + } + + + def update(self, adaptation_state, Etheta): + + # combine the expectation values to get useful scalars + equi_diag = equipartition_diagonal_loss(Etheta['equipartition_diagonal']) + equi_full = equipartition_fullrank_loss(Etheta['equipartition_fullrank']) + + history_observables = update_history(Etheta['observables_for_bias'], adaptation_state.history.observables) + history_weights = update_history_scalar(1., adaptation_state.history.weights) + fluctuations = contract_history(history_observables, history_weights) + history_stopping = update_history_scalar(jax.lax.cond(adaptation_state.step_count > len(history_weights), lambda _: fluctuations[0], lambda _: jnp.nan, operand=None), + adaptation_state.history.stopping) + history = History(history_observables, history_stopping, history_weights) + + L = self.alpha * jnp.sqrt(jnp.sum(Etheta['xsq'] - jnp.square(Etheta['x']))) # average over the ensemble, sum over parameters (to get sqrt(d)) + sqrt_diag_cov = jnp.sqrt(Etheta['xsq'] - jnp.square(Etheta['x'])) + EEVPD = (Etheta['Esq'] - jnp.square(Etheta['E'])) / self.num_dims + true_bias = self.contract(Etheta['observables_for_bias']) + nans = (Etheta['rejection_rate_nans'] > 0.) #| (~jnp.isfinite(eps_factor)) + + # hyperparameter adaptation + # estimate bias + bias = jnp.array([fluctuations[0], fluctuations[1], equi_full, equi_diag])[self.bias_type] # r_max, r_avg, equi_full, equi_diag + EEVPD_wanted = self.C * jnp.power(bias, self.power) + + + eps_factor = jnp.power(EEVPD_wanted / EEVPD, 1./6.) + eps_factor = jnp.clip(eps_factor, 0.3, 3.) + + eps_factor = nan_reject(1-nans, 0.5, eps_factor) # reduce the stepsize if there were nans + + # determine if we want to finish this stage (i.e. if loss is no longer decreassing) + #increasing = history.stopping[0] > history.stopping[-1] # will be false if some elements of history are still nan (have not been filled yet). 
Do not be tempted to simply change to while_cond = history[0] < history[-1] + #while_cond = ~increasing + while_cond = (fluctuations[0] > self.r_end) | (adaptation_state.step_count < self.save_num) + + info_to_be_stored = {'L': adaptation_state.L, 'step_size': adaptation_state.step_size, + 'EEVPD_wanted': EEVPD_wanted, 'EEVPD': EEVPD, + 'equi_diag': equi_diag, 'equi_full': equi_full, 'bias': true_bias, + 'r_max': fluctuations[0], 'r_avg': fluctuations[1], + 'while_cond': while_cond, 'entropy': Etheta['entropy'], + 'observables': Etheta['observables']} + + adaptation_state_new = AdaptationState(L, + sqrt_diag_cov, + adaptation_state.step_size * eps_factor, + adaptation_state.step_count + 1, + EEVPD, + EEVPD_wanted, + history) + + return adaptation_state_new, info_to_be_stored + diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 2b06172c0..7a076b962 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -257,3 +257,46 @@ def update(rss_state: ReasonableStepSizeState) -> ReasonableStepSizeState: rss_state = jax.lax.while_loop(do_continue, update, rss_state) return rss_state.step_size + +def bisection_monotonic_fn(acc_prob_wanted, reduce_shift = jnp.log(2.), tolerance= 0.03): + """Bisection of a monotonically decreassing function, that doesn't require an initially bracketing interval.""" + + def update(state, exp_x, acc_rate_new): + + bounds, terminated = state + + # update the bounds + acc_high = acc_rate_new > acc_prob_wanted + x = jnp.log(exp_x) + + def on_true(bounds): + bounds0 = jnp.max(jnp.array([bounds[0], x])) + return jnp.array([bounds0, bounds[1]]), bounds0 + reduce_shift + + def on_false(bounds): + bounds1 = jnp.min(jnp.array([bounds[1], x])) + return jnp.array([bounds[0], bounds1]), bounds1 - reduce_shift + + + bounds_new, x_new = jax.lax.cond(acc_high, on_true, on_false, bounds) + + + # if we have already found a bracketing interval, do bisection, otherwise further reduce or increase the bounds + bracketing = jnp.all(jnp.isfinite(bounds_new)) + + def reduce(bounds): + return x_new + + def bisect(bounds): + return jnp.average(bounds) + + x_new = jax.lax.cond(bracketing, bisect, reduce, bounds_new) + + stepsize = terminated * exp_x + (1-terminated) * jnp.exp(x_new) + + terminated_new = (jnp.abs(acc_rate_new - acc_prob_wanted) < tolerance) | terminated + + return (bounds_new, terminated_new), stepsize + + + return (jnp.array([-jnp.inf, jnp.inf]), False), update \ No newline at end of file diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index ff9638a1f..e5cc46213 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -60,7 +60,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, inverse_mass_matrix, integrator): +def build_kernel(logdensity_fn, inverse_mass_matrix=1.0, integrator=isokinetic_mclachlan): """Build a HMC kernel. 
Parameters diff --git a/blackjax/util.py b/blackjax/util.py index 8cdcd45ee..8f6494167 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -4,11 +4,14 @@ from typing import Callable, Union import jax.numpy as jnp -from jax import jit, lax +from jax import jit, lax, device_put, vmap from jax.flatten_util import ravel_pytree from jax.random import normal, split from jax.tree_util import tree_leaves, tree_map +from jax.sharding import Mesh, PartitionSpec, NamedSharding +from jax.experimental.shard_map import shard_map + from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -314,3 +317,121 @@ def incremental_value_update( ) total += weight return total, average + +def eca_step(kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info= None): + + def _step(state_all, xs): + """This function operates on a single device.""" + state, adaptation_state = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. + _, keys_sampling, key_adaptation = xs # keys_sampling.shape = (chains_per_device, ) + + # update the state of all chains on this device + state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) + + # combine all the chains to compute expectation values + theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) + Etheta = tree_map(lambda theta: lax.psum(jnp.sum(theta, axis= 0), axis_name= 'chains') / num_chains, theta) + + # use these to adapt the hyperparameters of the dynamics + adaptation_state, info_to_be_stored = adaptation_update(adaptation_state, Etheta) + + return (state, adaptation_state), info_to_be_stored + + + if ensemble_info != None: + + def step(state_all, xs): + (state, adaptation_state), info_to_be_stored = _step(state_all, xs) + return (state, adaptation_state), (info_to_be_stored, vmap(ensemble_info)(state.position)) + + return step + + else: + return _step + + +def run_eca(rng_key, initial_state, kernel, adaptation, num_steps, num_chains, mesh, ensemble_info= None): + + step = eca_step(kernel, adaptation.summary_statistics_fn, adaptation.update, num_chains, ensemble_info) + + + def all_steps(initial_state, keys_sampling, keys_adaptation): + """This function operates on a single device. key is a random key for this device.""" + + initial_state_all = (initial_state, adaptation.initial_state) + + # run sampling + xs = (jnp.arange(num_steps), keys_sampling.T, keys_adaptation) # keys for all steps that will be performed. 
keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) + + final_state_all, info_history = lax.scan(step, initial_state_all, xs) + final_state, final_adaptation_state = final_state_all + return final_state, final_adaptation_state, info_history # info history is composed of averages over all chains, so it is a couple of scalars + + + p, pscalar = PartitionSpec('chains'), PartitionSpec() + parallel_execute = shard_map(all_steps, + mesh= mesh, + in_specs= (p, p, pscalar), + out_specs= (p, pscalar, pscalar), + check_rep=False + ) + + # produce all random keys that will be needed + key_sampling, key_adaptation = split(rng_key) + keys_adaptation = split(key_adaptation, num_steps) + distribute_keys = lambda key, shape: device_put(split(key, shape), NamedSharding(mesh, p)) # random keys, distributed across devices + keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) + + + # run sampling in parallel + final_state, final_adaptation_state, info_history = parallel_execute(initial_state, keys_sampling, keys_adaptation) + + return final_state, final_adaptation_state, info_history + + + + +def ensemble_execute_fn(func, rng_key, num_chains, mesh, + x= None, + args= None, + summary_statistics_fn= lambda y: 0., + ): + """Given a sequential function + func(rng_key, x, args) = y, + evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. + Args: + x: array distributed over all decvices + args: additional arguments for func, not distributed. + summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. + rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. + + Returns: + y: array distributed over all decvices. Need not be of the same shape as x. + Etheta: expected values of the summary statistics + """ + p, pscalar = PartitionSpec('chains'), PartitionSpec() + + if x == None: + X= device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) + else: + X= x + + adaptation_update= lambda _, Etheta: (Etheta, None) + + _F = eca_step(func, lambda y, info, key: summary_statistics_fn(y), adaptation_update, num_chains) + + def F(x, keys): + """This function operates on a single device. 
key is a random key for this device.""" + y, summary_statistics = _F((x, args), (None, keys, None))[0] + return y, summary_statistics + + parallel_execute = shard_map(F, + mesh= mesh, + in_specs= (p, p), + out_specs= (p, pscalar), + check_rep=False + ) + + keys = device_put(split(rng_key, num_chains), NamedSharding(mesh, p)) # random keys, distributed across devices + # apply F in parallel + return parallel_execute(X, keys) \ No newline at end of file From 0f3df53011f00ecec9d879f23c01f08588608241 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 14:13:11 -0500 Subject: [PATCH 5/7] draft --- blackjax/adaptation/ensemble_mclmc.py | 14 +++++++------- blackjax/adaptation/ensemble_umclmc.py | 8 ++++---- blackjax/mcmc/adjusted_mclmc.py | 4 ++-- blackjax/util.py | 3 +++ 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index dabf5a3be..b303f8ae7 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -37,13 +37,13 @@ class AdaptationState(NamedTuple): -def build_kernel(logdensity_fn, integrator, sqrt_diag_cov): +def build_kernel(logdensity_fn, integrator, inverse_mass_matrix): """MCLMC kernel""" - kernel = build_kernel_malt(logdensity_fn, integrator, sqrt_diag_cov= sqrt_diag_cov, L_proposal_factor = 1.25) + kernel = build_kernel_malt(logdensity_fn=logdensity_fn, integrator=integrator, inverse_mass_matrix= inverse_mass_matrix,) def sequential_kernel(key, state, adap): - return kernel(key, state, step_size= adap.step_size, num_integration_steps= adap.steps_per_sample) + return kernel(key, state, step_size= adap.step_size, num_integration_steps= adap.steps_per_sample, L_proposal_factor = 1.25,) return sequential_kernel @@ -200,16 +200,16 @@ def emaus(model, num_steps1, num_steps2, num_chains, mesh, rng_key, gradient_calls_per_step= len(_integrator_coefficients) // 2 #scheme = BABAB..AB scheme has len(scheme)//2 + 1 Bs. The last doesn't count because that gradient can be reused in the next step. if diagonal_preconditioning: - sqrt_diag_cov= final_adaptation_state.sqrt_diag_cov + inverse_mass_matrix= jnp.sqrt(final_adaptation_state.inverse_mass_matrix) # scale the stepsize so that it reflects averag scale change of the preconditioning - average_scale_change = jnp.sqrt(jnp.average(jnp.square(sqrt_diag_cov))) + average_scale_change = jnp.sqrt(jnp.average(inverse_mass_matrix)) final_adaptation_state = final_adaptation_state._replace(step_size= final_adaptation_state.step_size / average_scale_change) else: - sqrt_diag_cov= 1. + inverse_mass_matrix= 1. - kernel = build_kernel(model.logdensity_fn, integrator, sqrt_diag_cov= sqrt_diag_cov) + kernel = build_kernel(model.logdensity_fn, integrator, inverse_mass_matrix= inverse_mass_matrix) initial_state= HMCState(final_state.position, final_state.logdensity, final_state.logdensity_grad) num_samples = num_steps2 // (gradient_calls_per_step * steps_per_sample) num_adaptation_samples = num_samples//2 # number of samples after which the stepsize is fixed. 
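A minimal usage sketch of the bisection step-size controller added to blackjax/adaptation/step_size.py above, driven outside of the ensemble adaptation loop. The toy acc_rate_of function is a made-up stand-in for the ensemble-averaged acceptance rate that the adjusted kernel would report at a given step size; only the (state, update) contract of bisection_monotonic_fn is taken from the patch.

    import jax.numpy as jnp
    from blackjax.adaptation.step_size import bisection_monotonic_fn

    # hypothetical, monotonically decreasing acceptance-rate model
    acc_rate_of = lambda step_size: jnp.exp(-step_size)

    # state = (bounds on the log step size, terminated flag)
    state, update = bisection_monotonic_fn(0.8)

    step_size = 0.5
    for _ in range(10):
        acc_rate = acc_rate_of(step_size)
        # shift the bounds until they bracket the target acceptance rate, then bisect;
        # once |acc_rate - 0.8| < tolerance the controller freezes the step size
        state, step_size = update(state, step_size, acc_rate)
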
diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 73b9ae4f7..458361103 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -119,7 +119,7 @@ class History(NamedTuple): class AdaptationState(NamedTuple): L: float - sqrt_diag_cov: Any + inverse_mass_matrix: Any step_size: float step_count: int @@ -189,7 +189,7 @@ def __init__(self, num_dims, weights= jnp.zeros(r_save_num)) self.initial_state = AdaptationState(L= jnp.inf, # do not add noise for the first step - sqrt_diag_cov= jnp.ones(num_dims), + inverse_mass_matrix= jnp.ones(num_dims), step_size= 0.01 * jnp.sqrt(num_dims), step_count= 0, EEVPD=1e-3, EEVPD_wanted=1e-3, @@ -225,7 +225,7 @@ def update(self, adaptation_state, Etheta): history = History(history_observables, history_stopping, history_weights) L = self.alpha * jnp.sqrt(jnp.sum(Etheta['xsq'] - jnp.square(Etheta['x']))) # average over the ensemble, sum over parameters (to get sqrt(d)) - sqrt_diag_cov = jnp.sqrt(Etheta['xsq'] - jnp.square(Etheta['x'])) + inverse_mass_matrix = Etheta['xsq'] - jnp.square(Etheta['x']) EEVPD = (Etheta['Esq'] - jnp.square(Etheta['E'])) / self.num_dims true_bias = self.contract(Etheta['observables_for_bias']) nans = (Etheta['rejection_rate_nans'] > 0.) #| (~jnp.isfinite(eps_factor)) @@ -254,7 +254,7 @@ def update(self, adaptation_state, Etheta): 'observables': Etheta['observables']} adaptation_state_new = AdaptationState(L, - sqrt_diag_cov, + inverse_mass_matrix, adaptation_state.step_size * eps_factor, adaptation_state.step_count + 1, EEVPD, diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 9b868562c..8a5a37e55 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -37,7 +37,7 @@ def init(position: ArrayLikeTree, logdensity_fn: Callable): def build_kernel( - num_integration_steps: int, + logdensity_fn: Callable, integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, inverse_mass_matrix=1.0, @@ -66,8 +66,8 @@ def build_kernel( def kernel( rng_key: PRNGKey, state: HMCState, - logdensity_fn: Callable, step_size: float, + num_integration_steps: int, L_proposal_factor: float = jnp.inf, ) -> tuple[HMCState, HMCInfo]: """Generate a new sample with the MHMCHMC kernel.""" diff --git a/blackjax/util.py b/blackjax/util.py index 8f6494167..92eb77f40 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -377,7 +377,10 @@ def all_steps(initial_state, keys_sampling, keys_adaptation): ) # produce all random keys that will be needed + # rng_key = rng_key if not isinstance(rng_key, jnp.ndarray) else rng_key[0] + key_sampling, key_adaptation = split(rng_key) + num_steps = jnp.array(num_steps).item() keys_adaptation = split(key_adaptation, num_steps) distribute_keys = lambda key, shape: device_put(split(key, shape), NamedSharding(mesh, p)) # random keys, distributed across devices keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) From 04522f57c94cf6d5aa2bbb25f9cd14450be645da Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 14:18:55 -0500 Subject: [PATCH 6/7] change order of parameters --- blackjax/mcmc/adjusted_mclmc.py | 16 ++++++++-------- tests/mcmc/test_sampling.py | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 9b868562c..f390402f2 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -37,7 +37,7 @@ def init(position: 
ArrayLikeTree, logdensity_fn: Callable): def build_kernel( - num_integration_steps: int, + logdensity_fn: Callable, integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, inverse_mass_matrix=1.0, @@ -66,8 +66,8 @@ def build_kernel( def kernel( rng_key: PRNGKey, state: HMCState, - logdensity_fn: Callable, step_size: float, + num_integration_steps: int, L_proposal_factor: float = jnp.inf, ) -> tuple[HMCState, HMCInfo]: """Generate a new sample with the MHMCHMC kernel.""" @@ -140,7 +140,7 @@ def as_top_level_api( """ kernel = build_kernel( - num_integration_steps, + logdensity_fn=logdensity_fn, integrator=integrator, inverse_mass_matrix=inverse_mass_matrix, divergence_threshold=divergence_threshold, @@ -152,11 +152,11 @@ def init_fn(position: ArrayLikeTree, rng_key=None): def update_fn(rng_key: PRNGKey, state): return kernel( - rng_key, - state, - logdensity_fn, - step_size, - L_proposal_factor, + rng_key=rng_key, + state=state, + step_size=step_size, + num_integration_steps=num_integration_steps, + L_proposal_factor=L_proposal_factor, ) return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index d788696f8..e9068326e 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -237,13 +237,13 @@ def run_adjusted_mclmc( kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc.build_kernel( integrator=integrator, - num_integration_steps=avg_num_integration_steps, inverse_mass_matrix=inverse_mass_matrix, + logdensity_fn=logdensity_fn, )( rng_key=rng_key, state=state, step_size=step_size, - logdensity_fn=logdensity_fn, + num_integration_steps=avg_num_integration_steps, ) target_acc_rate = 0.9 From a7c99b92ee59b82db855d4b919a6254aa979a97d Mon Sep 17 00:00:00 2001 From: = Date: Thu, 16 Jan 2025 21:00:11 +0000 Subject: [PATCH 7/7] draft --- blackjax/adaptation/ensemble_mclmc.py | 352 +++++++++++++++---------- blackjax/adaptation/ensemble_umclmc.py | 319 ++++++++++++---------- blackjax/adaptation/step_size.py | 43 ++- blackjax/mcmc/mclmc.py | 4 +- blackjax/util.py | 200 ++++++++------ tests/mcmc/test_sampling.py | 102 +++++++ 6 files changed, 655 insertions(+), 365 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index b303f8ae7..c01791daa 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -11,142 +11,158 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-#"""Public API for the MCLMC Kernel""" +# """Public API for the MCLMC Kernel""" -from typing import Callable, NamedTuple, Any +from typing import Any, NamedTuple import jax import jax.numpy as jnp -from blackjax.util import run_eca -from blackjax.mcmc.integrators import generate_isokinetic_integrator, velocity_verlet_coefficients, mclachlan_coefficients, omelyan_coefficients -from blackjax.mcmc.hmc import HMCState -from blackjax.mcmc.adjusted_mclmc import build_kernel as build_kernel_malt import blackjax.adaptation.ensemble_umclmc as umclmc -from blackjax.adaptation.ensemble_umclmc import equipartition_diagonal, equipartition_diagonal_loss, equipartition_fullrank, equipartition_fullrank_loss - -from blackjax.adaptation.step_size import dual_averaging_adaptation, bisection_monotonic_fn +from blackjax.adaptation.ensemble_umclmc import ( + equipartition_diagonal, + equipartition_diagonal_loss, + equipartition_fullrank, + equipartition_fullrank_loss, +) +from blackjax.adaptation.step_size import bisection_monotonic_fn +from blackjax.mcmc.adjusted_mclmc import build_kernel as build_kernel_malt +from blackjax.mcmc.hmc import HMCState +from blackjax.mcmc.integrators import ( + generate_isokinetic_integrator, + mclachlan_coefficients, + omelyan_coefficients, +) +from blackjax.util import run_eca - class AdaptationState(NamedTuple): steps_per_sample: float step_size: float epsadap_state: Any sample_count: int - - -def build_kernel(logdensity_fn, integrator, inverse_mass_matrix): - """MCLMC kernel""" - - kernel = build_kernel_malt(logdensity_fn=logdensity_fn, integrator=integrator, inverse_mass_matrix= inverse_mass_matrix,) - - def sequential_kernel(key, state, adap): - return kernel(key, state, step_size= adap.step_size, num_integration_steps= adap.steps_per_sample, L_proposal_factor = 1.25,) - - return sequential_kernel +# put the arguments of build_kernel in a suitable order +build_kernel = lambda logdensity_fn, integrator, inverse_mass_matrix: lambda key, state, adap: build_kernel_malt( + logdensity_fn=logdensity_fn, + integrator=integrator, + inverse_mass_matrix=inverse_mass_matrix, +)(rng_key=key, state=state, step_size=adap.step_size, num_integration_steps=adap.steps_per_sample, L_proposal_factor=1.25) class Adaptation: - - def __init__(self, adap_state, num_adaptation_samples, - steps_per_sample, acc_prob_target= 0.8, - observables = lambda x: 0., - observables_for_bias = lambda x: 0., contract= lambda x: 0.): - - self.num_adaptation_samples= num_adaptation_samples + def __init__( + self, + adap_state, + num_adaptation_samples, + steps_per_sample, + acc_prob_target=0.8, + observables=lambda x: 0.0, + observables_for_bias=lambda x: 0.0, + contract=lambda x: 0.0, + ): + self.num_adaptation_samples = num_adaptation_samples self.observables = observables self.observables_for_bias = observables_for_bias self.contract = contract - - ### Determine the initial hyperparameters ### - - ## stepsize ## - #if we switched to the more accurate integrator we can use longer step size - #integrator_factor = jnp.sqrt(10.) if mclachlan else 1. + + # Determine the initial hyperparameters # + + # stepsize # + # if we switched to the more accurate integrator we can use longer step size + # integrator_factor = jnp.sqrt(10.) if mclachlan else 1. # Let's use the stepsize which will be optimal for the adjusted method. 
The energy variance after N steps scales as sigma^2 ~ N^2 eps^6 = eps^4 L^2 # In the adjusted method we want sigma^2 = 2 mu = 2 * 0.41 = 0.82 - # With the current eps, we had sigma^2 = EEVPD * d for N = 1. + # With the current eps, we had sigma^2 = EEVPD * d for N = 1. # Combining the two we have EEVPD * d / 0.82 = eps^6 / eps_new^4 L^2 - #adjustment_factor = jnp.power(0.82 / (num_dims * adap_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) - step_size = adap_state.step_size #* integrator_factor * adjustment_factor - - #steps_per_sample = (int)(jnp.max(jnp.array([Lfull / step_size, 1]))) - - ### Initialize the dual averaging adaptation ### - #da_init_fn, self.epsadap_update, _ = dual_averaging_adaptation(target= acc_prob_target) - #epsadap_state = da_init_fn(step_size) - - ### Initialize the bisection for finding the step size + # adjustment_factor = jnp.power(0.82 / (num_dims * adap_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) + step_size = adap_state.step_size # * integrator_factor * adjustment_factor + + # steps_per_sample = (int)(jnp.max(jnp.array([Lfull / step_size, 1]))) + + # Initialize the dual averaging adaptation # + # da_init_fn, self.epsadap_update, _ = dual_averaging_adaptation(target= acc_prob_target) + # epsadap_state = da_init_fn(step_size) + + # Initialize the bisection for finding the step size epsadap_state, self.epsadap_update = bisection_monotonic_fn(acc_prob_target) - - self.initial_state = AdaptationState(steps_per_sample, step_size, epsadap_state, 0) - - + + self.initial_state = AdaptationState( + steps_per_sample, step_size, epsadap_state, 0 + ) + def summary_statistics_fn(self, state, info, rng_key): - - return {'acceptance_probability': info.acceptance_rate, - #'inv_acceptance_probability': 1./info.acceptance_rate, - 'equipartition_diagonal': equipartition_diagonal(state), - 'equipartition_fullrank': equipartition_fullrank(state, rng_key), - 'observables': self.observables(state.position), - 'observables_for_bias': self.observables_for_bias(state.position) - } - + return { + "acceptance_probability": info.acceptance_rate, + "equipartition_diagonal": equipartition_diagonal(state), + "equipartition_fullrank": equipartition_fullrank(state, rng_key), + "observables": self.observables(state.position), + "observables_for_bias": self.observables_for_bias(state.position), + } def update(self, adaptation_state, Etheta): - # combine the expectation values to get useful scalars - acc_prob = Etheta['acceptance_probability'] - #acc_prob = 1./Etheta['inv_acceptance_probability'] - equi_diag = equipartition_diagonal_loss(Etheta['equipartition_diagonal']) - equi_full = equipartition_fullrank_loss(Etheta['equipartition_fullrank']) - true_bias = self.contract(Etheta['observables_for_bias']) - - - info_to_be_stored = {'L': adaptation_state.step_size * adaptation_state.steps_per_sample, 'steps_per_sample': adaptation_state.steps_per_sample, 'step_size': adaptation_state.step_size, - 'acc_prob': acc_prob, - 'equi_diag': equi_diag, 'equi_full': equi_full, 'bias': true_bias, - 'observables': Etheta['observables'] - } + acc_prob = Etheta["acceptance_probability"] + # acc_prob = 1./Etheta['inv_acceptance_probability'] + equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"]) + equi_full = equipartition_fullrank_loss(Etheta["equipartition_fullrank"]) + true_bias = self.contract(Etheta["observables_for_bias"]) + + info_to_be_stored = { + "L": adaptation_state.step_size * adaptation_state.steps_per_sample, + "steps_per_sample": adaptation_state.steps_per_sample, + 
"step_size": adaptation_state.step_size, + "acc_prob": acc_prob, + "equi_diag": equi_diag, + "equi_full": equi_full, + "bias": true_bias, + "observables": Etheta["observables"], + } # hyperparameter adaptation - + # Dual Averaging - # adaptation_phase = adaptation_state.sample_count < self.num_adaptation_samples - + # adaptation_phase = adaptation_state.sample_count < self.num_adaptation_samples + # def update(_): # da_state = self.epsadap_update(adaptation_state.epsadap_state, acc_prob) # step_size = jnp.exp(da_state.log_step_size) # return da_state, step_size - + # def dont_update(_): # da_state = adaptation_state.epsadap_state # return da_state, jnp.exp(da_state.log_step_size_avg) - + # epsadap_state, step_size = jax.lax.cond(adaptation_phase, update, dont_update, operand=None) - - # Bisection - epsadap_state, step_size = self.epsadap_update(adaptation_state.epsadap_state, adaptation_state.step_size, acc_prob) - - return AdaptationState(adaptation_state.steps_per_sample, step_size, epsadap_state, adaptation_state.sample_count + 1), info_to_be_stored + # Bisection + epsadap_state, step_size = self.epsadap_update( + adaptation_state.epsadap_state, adaptation_state.step_size, acc_prob + ) + + return ( + AdaptationState( + adaptation_state.steps_per_sample, + step_size, + epsadap_state, + adaptation_state.sample_count + 1, + ), + info_to_be_stored, + ) def bias(model): """should be transfered to benchmarks/""" - + def observables(position): return jnp.square(model.transform(position)) - + def contract(sampler_E_x2): bsq = jnp.square(sampler_E_x2 - model.E_x2) / model.Var_x2 return jnp.array([jnp.max(bsq), jnp.average(bsq)]) - - return observables, contract + return observables, contract def while_steps_num(cond): @@ -156,69 +172,141 @@ def while_steps_num(cond): return jnp.argmin(cond) + 1 -def emaus(model, num_steps1, num_steps2, num_chains, mesh, rng_key, - alpha= 1.9, bias_type= 0, save_frac= 0.2, C= 0.1, power= 3./8., early_stop= True, r_end= 5e-3,# stage1 parameters - diagonal_preconditioning= True, integrator_coefficients= None, steps_per_sample= 10, acc_prob= None, - observables = lambda x: None, - ensemble_observables= None - ): - +def emaus( + model, + num_steps1, + num_steps2, + num_chains, + mesh, + rng_key, + alpha=1.9, + bias_type=0, + save_frac=0.2, + C=0.1, + power=3.0 / 8.0, + early_stop=True, + r_end=5e-3, # stage1 parameters + diagonal_preconditioning=True, + integrator_coefficients=None, + steps_per_sample=10, + acc_prob=None, + observables=lambda x: None, + ensemble_observables=None, +): observables_for_bias, contract = bias(model) key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) - + # initialize the chains - initial_state = umclmc.initialize(key_init, model.logdensity_fn, model.sample_init, num_chains, mesh) - - ### burn-in with the unadjusted method ### + initial_state = umclmc.initialize( + key_init, model.logdensity_fn, model.sample_init, num_chains, mesh + ) + + # burn-in with the unadjusted method # kernel = umclmc.build_kernel(model.logdensity_fn) - save_num= (int)(jnp.rint(save_frac * num_steps1)) - adap = umclmc.Adaptation(model.ndims, alpha= alpha, bias_type= bias_type, save_num= save_num, C=C, power= power, r_end = r_end, - observables= observables, observables_for_bias= observables_for_bias, contract= contract) - final_state, final_adaptation_state, info1 = run_eca(key_umclmc, initial_state, kernel, adap, num_steps1, num_chains, mesh, ensemble_observables) - - if early_stop: # here I am cheating a bit, because I am not sure if it is possible to 
do a while loop in jax and save something at every step. Therefore I rerun burn-in with exactly the same parameters and stop at the point where the orignal while loop would have stopped. The release implementation should not have that. - - num_steps_while = while_steps_num((info1[0] if ensemble_observables != None else info1)['while_cond']) - #print(num_steps_while, save_num) - final_state, final_adaptation_state, info1 = run_eca(key_umclmc, initial_state, kernel, adap, num_steps_while, num_chains, mesh, ensemble_observables) - - ### refine the results with the adjusted method ### + save_num = (int)(jnp.rint(save_frac * num_steps1)) + adap = umclmc.Adaptation( + model.ndims, + alpha=alpha, + bias_type=bias_type, + save_num=save_num, + C=C, + power=power, + r_end=r_end, + observables=observables, + observables_for_bias=observables_for_bias, + contract=contract, + ) + final_state, final_adaptation_state, info1 = run_eca( + key_umclmc, + initial_state, + kernel, + adap, + num_steps1, + num_chains, + mesh, + ensemble_observables, + ) + + if ( + early_stop + ): # here I am cheating a bit, because I am not sure if it is possible to do a while loop in jax and save something at every step. Therefore I rerun burn-in with exactly the same parameters and stop at the point where the orignal while loop would have stopped. The release implementation should not have that. + num_steps_while = while_steps_num( + (info1[0] if ensemble_observables is not None else info1)["while_cond"] + ) + # print(num_steps_while, save_num) + final_state, final_adaptation_state, info1 = run_eca( + key_umclmc, + initial_state, + kernel, + adap, + num_steps_while, + num_chains, + mesh, + ensemble_observables, + ) + + # refine the results with the adjusted method # _acc_prob = acc_prob - if integrator_coefficients == None: + if integrator_coefficients is None: high_dims = model.ndims > 200 - _integrator_coefficients = omelyan_coefficients if high_dims else mclachlan_coefficients - if acc_prob == None: + _integrator_coefficients = ( + omelyan_coefficients if high_dims else mclachlan_coefficients + ) + if acc_prob is None: _acc_prob = 0.9 if high_dims else 0.7 - + else: _integrator_coefficients = integrator_coefficients - if acc_prob == None: + if acc_prob is None: _acc_prob = 0.9 - - + integrator = generate_isokinetic_integrator(_integrator_coefficients) - gradient_calls_per_step= len(_integrator_coefficients) // 2 #scheme = BABAB..AB scheme has len(scheme)//2 + 1 Bs. The last doesn't count because that gradient can be reused in the next step. + gradient_calls_per_step = ( + len(_integrator_coefficients) // 2 + ) # scheme = BABAB..AB scheme has len(scheme)//2 + 1 Bs. The last doesn't count because that gradient can be reused in the next step. if diagonal_preconditioning: - inverse_mass_matrix= jnp.sqrt(final_adaptation_state.inverse_mass_matrix) - + inverse_mass_matrix = jnp.sqrt(final_adaptation_state.inverse_mass_matrix) + # scale the stepsize so that it reflects averag scale change of the preconditioning average_scale_change = jnp.sqrt(jnp.average(inverse_mass_matrix)) - final_adaptation_state = final_adaptation_state._replace(step_size= final_adaptation_state.step_size / average_scale_change) + final_adaptation_state = final_adaptation_state._replace( + step_size=final_adaptation_state.step_size / average_scale_change + ) else: - inverse_mass_matrix= 1. 
- - kernel = build_kernel(model.logdensity_fn, integrator, inverse_mass_matrix= inverse_mass_matrix) - initial_state= HMCState(final_state.position, final_state.logdensity, final_state.logdensity_grad) + inverse_mass_matrix = 1.0 + + kernel = build_kernel( + model.logdensity_fn, integrator, inverse_mass_matrix=inverse_mass_matrix + ) + initial_state = HMCState( + final_state.position, final_state.logdensity, final_state.logdensity_grad + ) num_samples = num_steps2 // (gradient_calls_per_step * steps_per_sample) - num_adaptation_samples = num_samples//2 # number of samples after which the stepsize is fixed. - - adap = Adaptation(final_adaptation_state, num_adaptation_samples, steps_per_sample, _acc_prob, - observables= observables, observables_for_bias= observables_for_bias, contract= contract) - - final_state, final_adaptation_state, info2 = run_eca(key_mclmc, initial_state, kernel, adap, num_samples, num_chains, mesh, ensemble_observables) - + num_adaptation_samples = ( + num_samples // 2 + ) # number of samples after which the stepsize is fixed. + + adap = Adaptation( + final_adaptation_state, + num_adaptation_samples, + steps_per_sample, + _acc_prob, + observables=observables, + observables_for_bias=observables_for_bias, + contract=contract, + ) + + final_state, final_adaptation_state, info2 = run_eca( + key_mclmc, + initial_state, + kernel, + adap, + num_samples, + num_chains, + mesh, + ensemble_observables, + ) + return info1, info2, gradient_calls_per_step, _acc_prob - - \ No newline at end of file diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 458361103..7e2390def 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -1,4 +1,3 @@ - # Copyright 2020- The Blackjax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,22 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-#"""Public API for the MCLMC Kernel""" +# """Public API for the MCLMC Kernel""" + +from typing import Any, NamedTuple import jax import jax.numpy as jnp from jax.flatten_util import ravel_pytree -from typing import Callable, NamedTuple, Any -from blackjax.mcmc.integrators import IntegratorState, isokinetic_velocity_verlet -from blackjax.types import Array, ArrayLike -from blackjax.util import pytree_size from blackjax.mcmc import mclmc -from blackjax.mcmc.integrators import _normalized_flatten_array +from blackjax.mcmc.integrators import ( + IntegratorState, + _normalized_flatten_array, + isokinetic_velocity_verlet, +) +from blackjax.types import Array from blackjax.util import ensemble_execute_fn - def no_nans(a): flat_a, unravel_fn = ravel_pytree(a) return jnp.all(jnp.isfinite(flat_a)) @@ -35,79 +36,96 @@ def no_nans(a): def nan_reject(nonans, old, new): """Equivalent to - return new if nonans else old""" - - return jax.lax.cond(nonans, lambda _: new, lambda _: old, operand=None) + return new if nonans else old""" + return jax.lax.cond(nonans, lambda _: new, lambda _: old, operand=None) def build_kernel(logdensity_fn): """MCLMC kernel (with nan rejection)""" - - kernel = mclmc.build_kernel(logdensity_fn= logdensity_fn, integrator= isokinetic_velocity_verlet) - - + + kernel = mclmc.build_kernel( + logdensity_fn=logdensity_fn, integrator=isokinetic_velocity_verlet + ) + def sequential_kernel(key, state, adap): - new_state, info = kernel(key, state, adap.L, adap.step_size) - + # reject the new state if there were nans nonans = no_nans(new_state) new_state = nan_reject(nonans, state, new_state) - - return new_state, {'nans': 1-nonans, 'energy_change': info.energy_change * nonans, 'logdensity': info.logdensity * nonans} - + return new_state, { + "nans": 1 - nonans, + "energy_change": info.energy_change * nonans, + "logdensity": info.logdensity * nonans, + } + return sequential_kernel - def initialize(rng_key, logdensity_fn, sample_init, num_chains, mesh): """initialize the chains based on the equipartition of the initial condition. - We initialize the velocity along grad log p if E_ii > 1 and along -grad log p if E_ii < 1. + We initialize the velocity along grad log p if E_ii > 1 and along -grad log p if E_ii < 1. """ - + def sequential_init(key, x, args): """initialize the position using sample_init and the velocity along the gradient""" position = sample_init(key) logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) flat_g, unravel_fn = ravel_pytree(logdensity_grad) - velocity = unravel_fn(_normalized_flatten_array(flat_g)[0]) # = grad logp/ |grad logp| + velocity = unravel_fn( + _normalized_flatten_array(flat_g)[0] + ) # = grad logp/ |grad logp| return IntegratorState(position, velocity, logdensity, logdensity_grad), None - + def summary_statistics_fn(state): """compute the diagonal elements of the equipartition matrix""" return -state.position * state.logdensity_grad - + def ensemble_init(key, state, signs): """flip the velocity, depending on the equipartition condition""" - velocity = jax.tree_util.tree_map(lambda sign, u: sign * u, signs, state.momentum) - return IntegratorState(state.position, velocity, state.logdensity, state.logdensity_grad), None - - key1, key2= jax.random.split(rng_key) - initial_state, equipartition = ensemble_execute_fn(sequential_init, key1, num_chains, mesh, summary_statistics_fn= summary_statistics_fn) - signs = -2. * (equipartition < 1.) + 1. 
- initial_state, _ = ensemble_execute_fn(ensemble_init, key2, num_chains, mesh, x= initial_state, args= signs) - + velocity = jax.tree_util.tree_map( + lambda sign, u: sign * u, signs, state.momentum + ) + return ( + IntegratorState( + state.position, velocity, state.logdensity, state.logdensity_grad + ), + None, + ) + + key1, key2 = jax.random.split(rng_key) + initial_state, equipartition = ensemble_execute_fn( + sequential_init, + key1, + num_chains, + mesh, + summary_statistics_fn=summary_statistics_fn, + ) + signs = -2.0 * (equipartition < 1.0) + 1.0 + initial_state, _ = ensemble_execute_fn( + ensemble_init, key2, num_chains, mesh, x=initial_state, args=signs + ) + return initial_state - - + + def update_history(new_vals, history): return jnp.concatenate((new_vals[None, :], history[:-1])) + def update_history_scalar(new_val, history): return jnp.concatenate((new_val * jnp.ones(1), history[:-1])) + def contract_history(theta, weights): - - square_average = jnp.square(jnp.average(theta, weights= weights, axis= 0)) - average_square = jnp.average(jnp.square(theta), weights= weights, axis= 0) - + square_average = jnp.square(jnp.average(theta, weights=weights, axis=0)) + average_square = jnp.average(jnp.square(theta), weights=weights, axis=0) + r = (average_square - square_average) / square_average - - return jnp.array([jnp.max(r), - jnp.average(r) - ]) + + return jnp.array([jnp.max(r), jnp.average(r)]) class History(NamedTuple): @@ -117,44 +135,44 @@ class History(NamedTuple): class AdaptationState(NamedTuple): - L: float inverse_mass_matrix: Any step_size: float - + step_count: int EEVPD: float EEVPD_wanted: float - history: Any - + history: Any + def equipartition_diagonal(state): - """Ei = E_ensemble (- grad log p_i x_i ). Ei is 1 if we have converged. + """Ei = E_ensemble (- grad log p_i x_i ). Ei is 1 if we have converged. equipartition_loss = average over parameters (Ei)""" - return jax.tree_util.tree_map(lambda x, g: -x * g, state.position, state.logdensity_grad) - + return jax.tree_util.tree_map( + lambda x, g: -x * g, state.position, state.logdensity_grad + ) def equipartition_fullrank(state, rng_key): """loss = Tr[(1 - E)^T (1 - E)] / d^2 - where Eij = is the equipartition patrix. - Loss is computed with the Hutchinson's trick.""" + where Eij = is the equipartition patrix. + Loss is computed with the Hutchinson's trick.""" x, unravel_fn = ravel_pytree(state.position) g, unravel_fn = ravel_pytree(state.logdensity_grad) d = len(x) - + def func(z): """z here has the same shape as position""" - return (z + jnp.dot(z, g) * x) + return z + jnp.dot(z, g) * x - z = jax.random.rademacher(rng_key, (100, d)) # = delta_ij + z = jax.random.rademacher(rng_key, (100, d)) # = delta_ij return jax.vmap(func)(z) def equipartition_diagonal_loss(Eii): Eii_flat, unravel_fn = ravel_pytree(Eii) - return jnp.average(jnp.square(1.- Eii_flat)) + return jnp.average(jnp.square(1.0 - Eii_flat)) def equipartition_fullrank_loss(delta_z): @@ -163,103 +181,138 @@ def equipartition_fullrank_loss(delta_z): class Adaptation: - - def __init__(self, num_dims, - alpha= 1., C= 0.1, power = 3./8., r_end= 0.01, - bias_type= 0, save_num = 10, - observables= lambda x: 0., observables_for_bias= lambda x: 0., contract= lambda x: 0. 
- ): - + def __init__( + self, + num_dims, + alpha=1.0, + C=0.1, + power=3.0 / 8.0, + r_end=0.01, + bias_type=0, + save_num=10, + observables=lambda x: 0.0, + observables_for_bias=lambda x: 0.0, + contract=lambda x: 0.0, + ): self.num_dims = num_dims self.alpha = alpha self.C = C self.power = power self.r_end = r_end self.observables = observables - self.observables_for_bias = observables_for_bias + self.observables_for_bias = observables_for_bias self.contract = contract self.bias_type = bias_type self.save_num = save_num - #sigma = unravel_fn(jnp.ones(flat_pytree.shape, dtype = flat_pytree.dtype)) - + # sigma = unravel_fn(jnp.ones(flat_pytree.shape, dtype = flat_pytree.dtype)) + r_save_num = save_num - - history = History(observables= jnp.zeros((r_save_num, num_dims)), - stopping= jnp.full((save_num,), jnp.nan), - weights= jnp.zeros(r_save_num)) - - self.initial_state = AdaptationState(L= jnp.inf, # do not add noise for the first step - inverse_mass_matrix= jnp.ones(num_dims), - step_size= 0.01 * jnp.sqrt(num_dims), - step_count= 0, - EEVPD=1e-3, EEVPD_wanted=1e-3, - history=history) - - + + history = History( + observables=jnp.zeros((r_save_num, num_dims)), + stopping=jnp.full((save_num,), jnp.nan), + weights=jnp.zeros(r_save_num), + ) + + self.initial_state = AdaptationState( + L=jnp.inf, # do not add noise for the first step + inverse_mass_matrix=jnp.ones(num_dims), + step_size=0.01 * jnp.sqrt(num_dims), + step_count=0, + EEVPD=1e-3, + EEVPD_wanted=1e-3, + history=history, + ) + def summary_statistics_fn(self, state, info, rng_key): - position_flat, unravel_fn = ravel_pytree(state.position) - - return {'equipartition_diagonal': equipartition_diagonal(state), - 'equipartition_fullrank': equipartition_fullrank(state, rng_key), - 'x': position_flat, 'xsq': jnp.square(position_flat), - 'E': info['energy_change'], 'Esq': jnp.square(info['energy_change']), - 'rejection_rate_nans': info['nans'], - 'observables_for_bias': self.observables_for_bias(state.position), - 'observables': self.observables(state.position), - 'entropy': - info['logdensity'] - } - - + + return { + "equipartition_diagonal": equipartition_diagonal(state), + "equipartition_fullrank": equipartition_fullrank(state, rng_key), + "x": position_flat, + "xsq": jnp.square(position_flat), + "E": info["energy_change"], + "Esq": jnp.square(info["energy_change"]), + "rejection_rate_nans": info["nans"], + "observables_for_bias": self.observables_for_bias(state.position), + "observables": self.observables(state.position), + "entropy": -info["logdensity"], + } + def update(self, adaptation_state, Etheta): - # combine the expectation values to get useful scalars - equi_diag = equipartition_diagonal_loss(Etheta['equipartition_diagonal']) - equi_full = equipartition_fullrank_loss(Etheta['equipartition_fullrank']) - - history_observables = update_history(Etheta['observables_for_bias'], adaptation_state.history.observables) - history_weights = update_history_scalar(1., adaptation_state.history.weights) + equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"]) + equi_full = equipartition_fullrank_loss(Etheta["equipartition_fullrank"]) + + history_observables = update_history( + Etheta["observables_for_bias"], adaptation_state.history.observables + ) + history_weights = update_history_scalar(1.0, adaptation_state.history.weights) fluctuations = contract_history(history_observables, history_weights) - history_stopping = update_history_scalar(jax.lax.cond(adaptation_state.step_count > len(history_weights), lambda _: 
fluctuations[0], lambda _: jnp.nan, operand=None), - adaptation_state.history.stopping) + history_stopping = update_history_scalar( + jax.lax.cond( + adaptation_state.step_count > len(history_weights), + lambda _: fluctuations[0], + lambda _: jnp.nan, + operand=None, + ), + adaptation_state.history.stopping, + ) history = History(history_observables, history_stopping, history_weights) - - L = self.alpha * jnp.sqrt(jnp.sum(Etheta['xsq'] - jnp.square(Etheta['x']))) # average over the ensemble, sum over parameters (to get sqrt(d)) - inverse_mass_matrix = Etheta['xsq'] - jnp.square(Etheta['x']) - EEVPD = (Etheta['Esq'] - jnp.square(Etheta['E'])) / self.num_dims - true_bias = self.contract(Etheta['observables_for_bias']) - nans = (Etheta['rejection_rate_nans'] > 0.) #| (~jnp.isfinite(eps_factor)) + + L = self.alpha * jnp.sqrt( + jnp.sum(Etheta["xsq"] - jnp.square(Etheta["x"])) + ) # average over the ensemble, sum over parameters (to get sqrt(d)) + inverse_mass_matrix = Etheta["xsq"] - jnp.square(Etheta["x"]) + EEVPD = (Etheta["Esq"] - jnp.square(Etheta["E"])) / self.num_dims + true_bias = self.contract(Etheta["observables_for_bias"]) + nans = Etheta["rejection_rate_nans"] > 0.0 # | (~jnp.isfinite(eps_factor)) # hyperparameter adaptation # estimate bias - bias = jnp.array([fluctuations[0], fluctuations[1], equi_full, equi_diag])[self.bias_type] # r_max, r_avg, equi_full, equi_diag + bias = jnp.array([fluctuations[0], fluctuations[1], equi_full, equi_diag])[ + self.bias_type + ] # r_max, r_avg, equi_full, equi_diag EEVPD_wanted = self.C * jnp.power(bias, self.power) - - eps_factor = jnp.power(EEVPD_wanted / EEVPD, 1./6.) - eps_factor = jnp.clip(eps_factor, 0.3, 3.) - - eps_factor = nan_reject(1-nans, 0.5, eps_factor) # reduce the stepsize if there were nans + eps_factor = jnp.power(EEVPD_wanted / EEVPD, 1.0 / 6.0) + eps_factor = jnp.clip(eps_factor, 0.3, 3.0) + + eps_factor = nan_reject( + 1 - nans, 0.5, eps_factor + ) # reduce the stepsize if there were nans # determine if we want to finish this stage (i.e. if loss is no longer decreassing) - #increasing = history.stopping[0] > history.stopping[-1] # will be false if some elements of history are still nan (have not been filled yet). Do not be tempted to simply change to while_cond = history[0] < history[-1] - #while_cond = ~increasing - while_cond = (fluctuations[0] > self.r_end) | (adaptation_state.step_count < self.save_num) - - info_to_be_stored = {'L': adaptation_state.L, 'step_size': adaptation_state.step_size, - 'EEVPD_wanted': EEVPD_wanted, 'EEVPD': EEVPD, - 'equi_diag': equi_diag, 'equi_full': equi_full, 'bias': true_bias, - 'r_max': fluctuations[0], 'r_avg': fluctuations[1], - 'while_cond': while_cond, 'entropy': Etheta['entropy'], - 'observables': Etheta['observables']} - - adaptation_state_new = AdaptationState(L, - inverse_mass_matrix, - adaptation_state.step_size * eps_factor, - adaptation_state.step_count + 1, - EEVPD, - EEVPD_wanted, - history) - + # increasing = history.stopping[0] > history.stopping[-1] # will be false if some elements of history are still nan (have not been filled yet). 
Do not be tempted to simply change to while_cond = history[0] < history[-1] + # while_cond = ~increasing + while_cond = (fluctuations[0] > self.r_end) | ( + adaptation_state.step_count < self.save_num + ) + + info_to_be_stored = { + "L": adaptation_state.L, + "step_size": adaptation_state.step_size, + "EEVPD_wanted": EEVPD_wanted, + "EEVPD": EEVPD, + "equi_diag": equi_diag, + "equi_full": equi_full, + "bias": true_bias, + "r_max": fluctuations[0], + "r_avg": fluctuations[1], + "while_cond": while_cond, + "entropy": Etheta["entropy"], + "observables": Etheta["observables"], + } + + adaptation_state_new = AdaptationState( + L, + inverse_mass_matrix, + adaptation_state.step_size * eps_factor, + adaptation_state.step_count + 1, + EEVPD, + EEVPD_wanted, + history, + ) + return adaptation_state_new, info_to_be_stored - diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 7a076b962..94c634ce3 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -258,45 +258,44 @@ def update(rss_state: ReasonableStepSizeState) -> ReasonableStepSizeState: return rss_state.step_size -def bisection_monotonic_fn(acc_prob_wanted, reduce_shift = jnp.log(2.), tolerance= 0.03): + +def bisection_monotonic_fn(acc_prob_wanted, reduce_shift=jnp.log(2.0), tolerance=0.03): """Bisection of a monotonically decreassing function, that doesn't require an initially bracketing interval.""" - + def update(state, exp_x, acc_rate_new): - bounds, terminated = state - + # update the bounds acc_high = acc_rate_new > acc_prob_wanted x = jnp.log(exp_x) - + def on_true(bounds): bounds0 = jnp.max(jnp.array([bounds[0], x])) return jnp.array([bounds0, bounds[1]]), bounds0 + reduce_shift - + def on_false(bounds): bounds1 = jnp.min(jnp.array([bounds[1], x])) return jnp.array([bounds[0], bounds1]), bounds1 - reduce_shift - - - bounds_new, x_new = jax.lax.cond(acc_high, on_true, on_false, bounds) - - + + bounds_new, x_new = jax.lax.cond(acc_high, on_true, on_false, bounds) + # if we have already found a bracketing interval, do bisection, otherwise further reduce or increase the bounds bracketing = jnp.all(jnp.isfinite(bounds_new)) - - def reduce(bounds): + + def reduce(bounds): return x_new def bisect(bounds): return jnp.average(bounds) - + x_new = jax.lax.cond(bracketing, bisect, reduce, bounds_new) - - stepsize = terminated * exp_x + (1-terminated) * jnp.exp(x_new) - - terminated_new = (jnp.abs(acc_rate_new - acc_prob_wanted) < tolerance) | terminated - + + stepsize = terminated * exp_x + (1 - terminated) * jnp.exp(x_new) + + terminated_new = ( + jnp.abs(acc_rate_new - acc_prob_wanted) < tolerance + ) | terminated + return (bounds_new, terminated_new), stepsize - - - return (jnp.array([-jnp.inf, jnp.inf]), False), update \ No newline at end of file + + return (jnp.array([-jnp.inf, jnp.inf]), False), update diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index e5cc46213..2299dc68e 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -60,7 +60,9 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, inverse_mass_matrix=1.0, integrator=isokinetic_mclachlan): +def build_kernel( + logdensity_fn, inverse_mass_matrix=1.0, integrator=isokinetic_mclachlan +): """Build a HMC kernel. 
Parameters diff --git a/blackjax/util.py b/blackjax/util.py index 92eb77f40..e8c42f11f 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -4,14 +4,13 @@ from typing import Callable, Union import jax.numpy as jnp -from jax import jit, lax, device_put, vmap +from jax import device_put, jit, lax, vmap +from jax.experimental.shard_map import shard_map from jax.flatten_util import ravel_pytree from jax.random import normal, split +from jax.sharding import NamedSharding, PartitionSpec from jax.tree_util import tree_leaves, tree_map -from jax.sharding import Mesh, PartitionSpec, NamedSharding -from jax.experimental.shard_map import shard_map - from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -318,123 +317,170 @@ def incremental_value_update( total += weight return total, average -def eca_step(kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info= None): +def eca_step( + kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info=None +): def _step(state_all, xs): """This function operates on a single device.""" - state, adaptation_state = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. - _, keys_sampling, key_adaptation = xs # keys_sampling.shape = (chains_per_device, ) - + ( + state, + adaptation_state, + ) = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. + ( + _, + keys_sampling, + key_adaptation, + ) = xs # keys_sampling.shape = (chains_per_device, ) + # update the state of all chains on this device state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) - + # combine all the chains to compute expectation values theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) - Etheta = tree_map(lambda theta: lax.psum(jnp.sum(theta, axis= 0), axis_name= 'chains') / num_chains, theta) + Etheta = tree_map( + lambda theta: lax.psum(jnp.sum(theta, axis=0), axis_name="chains") + / num_chains, + theta, + ) # use these to adapt the hyperparameters of the dynamics - adaptation_state, info_to_be_stored = adaptation_update(adaptation_state, Etheta) - + adaptation_state, info_to_be_stored = adaptation_update( + adaptation_state, Etheta + ) + return (state, adaptation_state), info_to_be_stored - - - if ensemble_info != None: - + + if ensemble_info is not None: + def step(state_all, xs): (state, adaptation_state), info_to_be_stored = _step(state_all, xs) - return (state, adaptation_state), (info_to_be_stored, vmap(ensemble_info)(state.position)) - + return (state, adaptation_state), ( + info_to_be_stored, + vmap(ensemble_info)(state.position), + ) + return step else: return _step -def run_eca(rng_key, initial_state, kernel, adaptation, num_steps, num_chains, mesh, ensemble_info= None): - - step = eca_step(kernel, adaptation.summary_statistics_fn, adaptation.update, num_chains, ensemble_info) - +def run_eca( + rng_key, + initial_state, + kernel, + adaptation, + num_steps, + num_chains, + mesh, + ensemble_info=None, +): + step = eca_step( + kernel, + adaptation.summary_statistics_fn, + adaptation.update, + num_chains, + ensemble_info, + ) def all_steps(initial_state, keys_sampling, keys_adaptation): """This function operates on a single device. 
key is a random key for this device.""" - + initial_state_all = (initial_state, adaptation.initial_state) - + # run sampling - xs = (jnp.arange(num_steps), keys_sampling.T, keys_adaptation) # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) - + xs = ( + jnp.arange(num_steps), + keys_sampling.T, + keys_adaptation, + ) # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) + final_state_all, info_history = lax.scan(step, initial_state_all, xs) final_state, final_adaptation_state = final_state_all - return final_state, final_adaptation_state, info_history # info history is composed of averages over all chains, so it is a couple of scalars - + return ( + final_state, + final_adaptation_state, + info_history, + ) # info history is composed of averages over all chains, so it is a couple of scalars + + p, pscalar = PartitionSpec("chains"), PartitionSpec() + parallel_execute = shard_map( + all_steps, + mesh=mesh, + in_specs=(p, p, pscalar), + out_specs=(p, pscalar, pscalar), + check_rep=False, + ) - p, pscalar = PartitionSpec('chains'), PartitionSpec() - parallel_execute = shard_map(all_steps, - mesh= mesh, - in_specs= (p, p, pscalar), - out_specs= (p, pscalar, pscalar), - check_rep=False - ) - # produce all random keys that will be needed # rng_key = rng_key if not isinstance(rng_key, jnp.ndarray) else rng_key[0] key_sampling, key_adaptation = split(rng_key) num_steps = jnp.array(num_steps).item() keys_adaptation = split(key_adaptation, num_steps) - distribute_keys = lambda key, shape: device_put(split(key, shape), NamedSharding(mesh, p)) # random keys, distributed across devices + distribute_keys = lambda key, shape: device_put( + split(key, shape), NamedSharding(mesh, p) + ) # random keys, distributed across devices keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) - # run sampling in parallel - final_state, final_adaptation_state, info_history = parallel_execute(initial_state, keys_sampling, keys_adaptation) - - return final_state, final_adaptation_state, info_history + final_state, final_adaptation_state, info_history = parallel_execute( + initial_state, keys_sampling, keys_adaptation + ) + return final_state, final_adaptation_state, info_history +def ensemble_execute_fn( + func, + rng_key, + num_chains, + mesh, + x=None, + args=None, + summary_statistics_fn=lambda y: 0.0, +): + """Given a sequential function + func(rng_key, x, args) = y, + evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. + Args: + x: array distributed over all decvices + args: additional arguments for func, not distributed. + summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. + rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. -def ensemble_execute_fn(func, rng_key, num_chains, mesh, - x= None, - args= None, - summary_statistics_fn= lambda y: 0., - ): - """Given a sequential function - func(rng_key, x, args) = y, - evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. - Args: - x: array distributed over all decvices - args: additional arguments for func, not distributed. - summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. 
- rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. - - Returns: - y: array distributed over all decvices. Need not be of the same shape as x. - Etheta: expected values of the summary statistics + Returns: + y: array distributed over all decvices. Need not be of the same shape as x. + Etheta: expected values of the summary statistics """ - p, pscalar = PartitionSpec('chains'), PartitionSpec() - - if x == None: - X= device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) - else: - X= x - - adaptation_update= lambda _, Etheta: (Etheta, None) - - _F = eca_step(func, lambda y, info, key: summary_statistics_fn(y), adaptation_update, num_chains) + p, pscalar = PartitionSpec("chains"), PartitionSpec() + + if x is None: + X = device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) + else: + X = x + + adaptation_update = lambda _, Etheta: (Etheta, None) + + _F = eca_step( + func, + lambda y, info, key: summary_statistics_fn(y), + adaptation_update, + num_chains, + ) def F(x, keys): """This function operates on a single device. key is a random key for this device.""" y, summary_statistics = _F((x, args), (None, keys, None))[0] return y, summary_statistics - parallel_execute = shard_map(F, - mesh= mesh, - in_specs= (p, p), - out_specs= (p, pscalar), - check_rep=False - ) - - keys = device_put(split(rng_key, num_chains), NamedSharding(mesh, p)) # random keys, distributed across devices + parallel_execute = shard_map( + F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar), check_rep=False + ) + + keys = device_put( + split(rng_key, num_chains), NamedSharding(mesh, p) + ) # random keys, distributed across devices # apply F in parallel - return parallel_execute(X, keys) \ No newline at end of file + return parallel_execute(X, keys) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index e9068326e..5d3dece82 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -2,6 +2,7 @@ import functools import itertools +from blackjax.adaptation.ensemble_mclmc import emaus import chex import jax import jax.numpy as jnp @@ -284,6 +285,31 @@ def run_adjusted_mclmc( ) return out + + def run_emaus( + self, + initial_position, + logdensity_fn, + key, + num_steps, + diagonal_preconditioning, + ): + + mesh = jax.sharding.Mesh(jax.devices(), 'chains') + + from blackjax.mcmc.integrators import velocity_verlet_coefficients, mclachlan_coefficients, omelyan_coefficients + + + integrator_coefficients = mclachlan_coefficients + + info1, info2, grads_per_step, _acc_prob = emaus(logdensity_fn, num_steps1=1000, num_steps2=3000, num_chains=4000, mesh=mesh, rng_key=key, + alpha = 1.9, bias_type= 3, C= 0.1, power= 3./8., + early_stop=1, r_end= 1e-2, diagonal_preconditioning= diagonal_preconditioning, integrator_coefficients= integrator_coefficients, + steps_per_sample= 15, acc_prob= None, ensemble_observables= lambda x: x + #ensemble_observables = lambda x: vec @ x + ) # run the algorithm + + return info2[1].reshape(info2[1].shape[0]*info2[1].shape[1], info2[1].shape[2]) @parameterized.parameters( itertools.product( @@ -457,6 +483,41 @@ def test_adjusted_mclmc(self, diagonal_preconditioning): np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + + # TODO: add preconditioning + def test_emaus(self,): + """Test the MCLMC kernel.""" + + init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + x_data = 
jax.random.normal(init_key0, shape=(1000, 1)) + y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) + + logposterior_fn_ = functools.partial( + self.regression_logprob, x=x_data, preds=y_data + ) + logdensity_fn = lambda x: logposterior_fn_(**x) + + model = Banana() + + states = self.run_emaus( + initial_position={"coefs": 1.0, "log_scale": 1.0}, + logdensity_fn=model, + key=inference_key, + num_steps=10000, + diagonal_preconditioning=True, + ) + + # coefs_samples = states["coefs"][3000:] + # scale_samples = np.exp(states["log_scale"][3000:]) + + # samples = states[3000:] + + print((states**2).mean(axis=0), Banana().E_x2) + + np.testing.assert_allclose((states**2).mean(axis=0), Banana().E_x2, atol=1e-2) + + # np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) + # np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) def test_mclmc_preconditioning(self): class IllConditionedGaussian: @@ -1223,5 +1284,46 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal): ) + +#TODO: remove +class Banana(): + """Banana target fromm the Inference Gym""" + + def __init__(self, initialization= 'wide'): + self.name = 'Banana' + self.ndims = 2 + self.curvature = 0.03 + + self.transform = lambda x: x + self.E_x2 = jnp.array([100.0, 19.0]) #the first is analytic the second is by drawing 10^8 samples from the generative model. Relative accuracy is around 10^-5. + self.Var_x2 = jnp.array([20000.0, 4600.898]) + + if initialization == 'map': + self.sample_init = lambda key: jnp.array([0, -100.0 * self.curvature]) + elif initialization == 'posterior': + self.sample_init = lambda key: self.posterior_draw(key) + elif initialization == 'wide': + self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) * jnp.array([10.0, 5.0]) * 2 + else: + raise ValueError('initialization = '+initialization +' is not a valid option.') + + def logdensity_fn(self, x): + mu2 = self.curvature * (x[0] ** 2 - 100) + return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) + + def posterior_draw(self, key): + z = jax.random.normal(key, shape = (2, )) + x0 = 10.0 * z[0] + x1 = self.curvature * (x0 ** 2 - 100) + z[1] + return jnp.array([x0, x1]) + + def ground_truth(self): + x = jax.vmap(self.posterior_draw)(jax.random.split(jax.random.PRNGKey(0), 100000000)) + print(jnp.average(x, axis=0)) + print(jnp.average(jnp.square(x), axis=0)) + print(jnp.std(jnp.square(x[:, 0])) ** 2, jnp.std(jnp.square(x[:, 1])) ** 2) + + if __name__ == "__main__": absltest.main() +
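
For completeness, a minimal end-to-end sketch of calling the emaus driver added in blackjax/adaptation/ensemble_mclmc.py, in the same spirit as run_emaus/test_emaus above. The StandardNormal class is a hypothetical stand-in for a target; the attributes it exposes (ndims, logdensity_fn, sample_init, transform, E_x2, Var_x2) are the ones emaus and its bias() helper read from the model object, and the chain/step counts below are arbitrary.

    import jax
    import jax.numpy as jnp
    from blackjax.adaptation.ensemble_mclmc import emaus

    class StandardNormal:
        """d-dimensional standard normal, exposing the interface emaus expects."""

        def __init__(self, ndims=10):
            self.ndims = ndims
            self.transform = lambda x: x
            self.E_x2 = jnp.ones(ndims)           # E[x_i^2] = 1
            self.Var_x2 = 2.0 * jnp.ones(ndims)   # Var[x_i^2] = E[x^4] - E[x^2]^2 = 3 - 1
            self.sample_init = lambda key: 4.0 * jax.random.normal(key, shape=(ndims,))

        def logdensity_fn(self, x):
            return -0.5 * jnp.sum(jnp.square(x))

    mesh = jax.sharding.Mesh(jax.devices(), "chains")

    info1, info2, grads_per_step, acc_prob = emaus(
        StandardNormal(),
        num_steps1=500,    # budget for the unadjusted (umclmc) burn-in stage
        num_steps2=2000,   # gradient-call budget for the adjusted stage
        num_chains=128,    # should be divisible by the number of devices in the mesh
        mesh=mesh,
        rng_key=jax.random.PRNGKey(0),
        diagonal_preconditioning=True,
        ensemble_observables=lambda x: x,  # keep per-chain positions next to the scalar diagnostics
    )

    # with ensemble_observables set, info2[1] holds the positions recorded at every adjusted
    # step; the test above flattens its first two axes before computing sample moments
    samples = info2[1].reshape(-1, info2[1].shape[-1])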