From 8eed424c181f0c5c0b5cf5e50780115d14704fc6 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 12:42:00 -0500 Subject: [PATCH 01/34] add static adjusted mclmc --- blackjax/__init__.py | 2 + blackjax/mcmc/__init__.py | 4 +- blackjax/mcmc/adjusted_mclmc.py | 57 ++---- blackjax/mcmc/adjusted_mclmc_dynamic.py | 257 ++++++++++++++++++++++++ tests/mcmc/test_sampling.py | 110 +++++++++- 5 files changed, 384 insertions(+), 46 deletions(-) create mode 100644 blackjax/mcmc/adjusted_mclmc_dynamic.py diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 6a0de3809..35c9e3b58 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -13,6 +13,7 @@ from .diagnostics import effective_sample_size as ess from .diagnostics import potential_scale_reduction as rhat from .mcmc import adjusted_mclmc as _adjusted_mclmc +from .mcmc import adjusted_mclmc_dynamic as _adjusted_mclmc_dynamic from .mcmc import barker from .mcmc import dynamic_hmc as _dynamic_hmc from .mcmc import elliptical_slice as _elliptical_slice @@ -112,6 +113,7 @@ def generate_top_level_api_from(module): additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk) mclmc = generate_top_level_api_from(_mclmc) +adjusted_mclmc_dynamic = generate_top_level_api_from(_adjusted_mclmc_dynamic) adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc) elliptical_slice = generate_top_level_api_from(_elliptical_slice) ghmc = generate_top_level_api_from(_ghmc) diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py index 1e1317684..fad5dcb97 100644 --- a/blackjax/mcmc/__init__.py +++ b/blackjax/mcmc/__init__.py @@ -1,5 +1,5 @@ from . import ( - adjusted_mclmc, + adjusted_mclmc_dynamic, barker, elliptical_slice, ghmc, @@ -25,5 +25,5 @@ "marginal_latent_gaussian", "random_walk", "mclmc", - "adjusted_mclmc", + "adjusted_mclmc_dynamic", ] diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 81fbc2835..8288772a3 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -11,7 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin".""" +"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin". + +NOTE: For best performance, we recommend using adjusted_mclmc_dynamic instead of this module, which is primarily intended for use in parallelized versions of the algorithm. 
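+
+Example
+-------
+A minimal usage sketch of the static kernel added here (illustrative only;
+``logdensity_fn``, ``initial_position`` and a tuned ``step_size`` are assumed to be
+defined elsewhere, and 20 integration steps is an arbitrary choice):
+
+.. code::
+
+    import jax
+    import blackjax
+
+    alg = blackjax.adjusted_mclmc(
+        logdensity_fn=logdensity_fn,
+        step_size=step_size,
+        num_integration_steps=20,
+    )
+    state = alg.init(initial_position)
+    new_state, info = alg.step(jax.random.PRNGKey(0), state)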
+ +""" from typing import Callable, Union import jax @@ -19,28 +23,26 @@ import blackjax.mcmc.integrators as integrators from blackjax.base import SamplingAlgorithm -from blackjax.mcmc.dynamic_hmc import DynamicHMCState, halton_sequence -from blackjax.mcmc.hmc import HMCInfo +from blackjax.mcmc.hmc import HMCInfo, HMCState from blackjax.mcmc.proposal import static_binomial_sampling -from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_unit_vector __all__ = ["init", "build_kernel", "as_top_level_api"] -def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array): +def init(position: ArrayLikeTree, logdensity_fn: Callable): logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) - return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg) + return HMCState(position, logdensity, logdensity_grad) def build_kernel( - integration_steps_fn, + num_integration_steps: int, integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, - next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], sqrt_diag_cov=1.0, ): - """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. + """Build an MHMCHMC kernel where the number of integration steps is chosen randomly. Parameters ---------- @@ -63,15 +65,13 @@ def build_kernel( def kernel( rng_key: PRNGKey, - state: DynamicHMCState, + state: HMCState, logdensity_fn: Callable, step_size: float, L_proposal_factor: float = jnp.inf, - ) -> tuple[DynamicHMCState, HMCInfo]: + ) -> tuple[HMCState, HMCInfo]: """Generate a new sample with the MHMCHMC kernel.""" - num_integration_steps = integration_steps_fn(state.random_generator_arg) - key_momentum, key_integrator = jax.random.split(rng_key, 2) momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( @@ -90,11 +90,10 @@ def kernel( ) return ( - DynamicHMCState( + HMCState( proposal.position, proposal.logdensity, proposal.logdensity_grad, - next_random_arg_fn(state.random_generator_arg), ), info, ) @@ -110,10 +109,9 @@ def as_top_level_api( *, divergence_threshold: int = 1000, integrator: Callable = integrators.isokinetic_mclachlan, - next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), + num_integration_steps, ) -> SamplingAlgorithm: - """Implements the (basic) user interface for the dynamic MHMCHMC kernel. + """Implements the (basic) user interface for the MHMCHMC kernel. Parameters ---------- @@ -140,15 +138,15 @@ def as_top_level_api( """ kernel = build_kernel( - integration_steps_fn=integration_steps_fn, + num_integration_steps, integrator=integrator, - next_random_arg_fn=next_random_arg_fn, sqrt_diag_cov=sqrt_diag_cov, divergence_threshold=divergence_threshold, ) - def init_fn(position: ArrayLikeTree, rng_key: Array): - return init(position, logdensity_fn, rng_key) + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) def update_fn(rng_key: PRNGKey, state): return kernel( @@ -240,18 +238,3 @@ def generate( return sampled_state, info, other_proposal_info return generate - - -def rescale(mu): - """returns s, such that - round(U(0, 1) * s + 0.5) - has expected value mu. 
- """ - k = jnp.floor(2 * mu - 1) - x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) - return k + x - - -def trajectory_length(t, mu): - s = rescale(mu) - return jnp.rint(0.5 + halton_sequence(t) * s) diff --git a/blackjax/mcmc/adjusted_mclmc_dynamic.py b/blackjax/mcmc/adjusted_mclmc_dynamic.py new file mode 100644 index 000000000..81fbc2835 --- /dev/null +++ b/blackjax/mcmc/adjusted_mclmc_dynamic.py @@ -0,0 +1,257 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin".""" +from typing import Callable, Union + +import jax +import jax.numpy as jnp + +import blackjax.mcmc.integrators as integrators +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.dynamic_hmc import DynamicHMCState, halton_sequence +from blackjax.mcmc.hmc import HMCInfo +from blackjax.mcmc.proposal import static_binomial_sampling +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.util import generate_unit_vector + +__all__ = ["init", "build_kernel", "as_top_level_api"] + + +def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array): + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg) + + +def build_kernel( + integration_steps_fn, + integrator: Callable = integrators.isokinetic_mclachlan, + divergence_threshold: float = 1000, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + sqrt_diag_cov=1.0, +): + """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. + + Parameters + ---------- + integrator + The integrator to use to integrate the Hamiltonian dynamics. + divergence_threshold + Value of the difference in energy above which we consider that the transition is divergent. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. Needs to return an `int`. + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. 
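+
+    Example
+    -------
+    One possible ``integration_steps_fn`` (a sketch only; ``avg_steps`` is a
+    hypothetical target for the average number of steps, and ``rescale`` is
+    defined at the bottom of this module):
+
+    .. code::
+
+        integration_steps_fn = lambda key: jnp.ceil(
+            jax.random.uniform(key) * rescale(avg_steps)
+        ).astype(jnp.int32)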
+ """ + + def kernel( + rng_key: PRNGKey, + state: DynamicHMCState, + logdensity_fn: Callable, + step_size: float, + L_proposal_factor: float = jnp.inf, + ) -> tuple[DynamicHMCState, HMCInfo]: + """Generate a new sample with the MHMCHMC kernel.""" + + num_integration_steps = integration_steps_fn(state.random_generator_arg) + + key_momentum, key_integrator = jax.random.split(rng_key, 2) + momentum = generate_unit_vector(key_momentum, state.position) + proposal, info, _ = adjusted_mclmc_proposal( + integrator=integrators.with_isokinetic_maruyama( + integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + ), + step_size=step_size, + L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), + num_integration_steps=num_integration_steps, + divergence_threshold=divergence_threshold, + )( + key_integrator, + integrators.IntegratorState( + state.position, momentum, state.logdensity, state.logdensity_grad + ), + ) + + return ( + DynamicHMCState( + proposal.position, + proposal.logdensity, + proposal.logdensity_grad, + next_random_arg_fn(state.random_generator_arg), + ), + info, + ) + + return kernel + + +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + L_proposal_factor: float = jnp.inf, + sqrt_diag_cov=1.0, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.isokinetic_mclachlan, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), +) -> SamplingAlgorithm: + """Implements the (basic) user interface for the dynamic MHMCHMC kernel. + + Parameters + ---------- + logdensity_fn + The log-density function we wish to draw samples from. + step_size + The value to use for the step size in the symplectic integrator. + divergence_threshold + The absolute value of the difference in energy between two states above + which we say that the transition is divergent. The default value is + commonly found in other libraries, and yet is arbitrary. + integrator + (algorithm parameter) The symplectic integrator to use to integrate the trajectory. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. + + + Returns + ------- + A ``SamplingAlgorithm``. + """ + + kernel = build_kernel( + integration_steps_fn=integration_steps_fn, + integrator=integrator, + next_random_arg_fn=next_random_arg_fn, + sqrt_diag_cov=sqrt_diag_cov, + divergence_threshold=divergence_threshold, + ) + + def init_fn(position: ArrayLikeTree, rng_key: Array): + return init(position, logdensity_fn, rng_key) + + def update_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + L_proposal_factor, + ) + + return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] + + +def adjusted_mclmc_proposal( + integrator: Callable, + step_size: Union[float, ArrayLikeTree], + L_proposal_factor: float, + num_integration_steps: int = 1, + divergence_threshold: float = 1000, + *, + sample_proposal: Callable = static_binomial_sampling, +) -> Callable: + """Vanilla MHMCHMC algorithm. + + The algorithm integrates the trajectory applying a integrator + `num_integration_steps` times in one direction to get a proposal and uses a + Metropolis-Hastings acceptance step to either reject or accept this + proposal. 
This is what people usually refer to when they talk about "the + HMC algorithm". + + Parameters + ---------- + integrator + integrator used to build the trajectory step by step. + kinetic_energy + Function that computes the kinetic energy. + step_size + Size of the integration step. + num_integration_steps + Number of times we run the integrator to build the trajectory + divergence_threshold + Threshold above which we say that there is a divergence. + + Returns + ------- + A kernel that generates a new chain state and information about the transition. + + """ + + def step(i, vars): + state, kinetic_energy, rng_key = vars + rng_key, next_rng_key = jax.random.split(rng_key) + next_state, next_kinetic_energy = integrator( + state, step_size, L_proposal_factor, rng_key + ) + + return next_state, kinetic_energy + next_kinetic_energy, next_rng_key + + def build_trajectory(state, num_integration_steps, rng_key): + return jax.lax.fori_loop( + 0 * num_integration_steps, num_integration_steps, step, (state, 0, rng_key) + ) + + def generate( + rng_key, state: integrators.IntegratorState + ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]: + """Generate a new chain state.""" + end_state, kinetic_energy, rng_key = build_trajectory( + state, num_integration_steps, rng_key + ) + + new_energy = -end_state.logdensity + delta_energy = -state.logdensity + end_state.logdensity - kinetic_energy + delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy) + is_diverging = -delta_energy > divergence_threshold + sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) + do_accept, p_accept, other_proposal_info = info + + info = HMCInfo( + state.momentum, + p_accept, + do_accept, + is_diverging, + new_energy, + end_state, + num_integration_steps, + ) + + return sampled_state, info, other_proposal_info + + return generate + + +def rescale(mu): + """returns s, such that + round(U(0, 1) * s + 0.5) + has expected value mu. 
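+
+    Sketch of the derivation: for U ~ Uniform(0, 1), round(U * s + 0.5) equals
+    floor(U * s) + 1 almost surely. Writing s = k + x with k an integer and
+    0 <= x < 1, E[floor(U * s)] = (k * (k - 1) / 2 + k * x) / s, so requiring
+    the expectation to equal mu and solving for x gives
+    x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu), the expression below;
+    k = floor(2 * mu - 1) is the integer choice that keeps x in [0, 1).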
+ """ + k = jnp.floor(2 * mu - 1) + x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) + return k + x + + +def trajectory_length(t, mu): + s = rescale(mu) + return jnp.rint(0.5 + halton_sequence(t) * s) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 474f67293..45d60f84a 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -14,7 +14,7 @@ import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info -from blackjax.mcmc.adjusted_mclmc import rescale +from blackjax.mcmc.adjusted_mclmc_dynamic import rescale from blackjax.mcmc.integrators import isokinetic_mclachlan from blackjax.util import run_inference_algorithm @@ -146,7 +146,7 @@ def run_mclmc( return samples - def run_adjusted_mclmc( + def run_adjusted_mclmc_dynamic( self, logdensity_fn, num_steps, @@ -158,13 +158,13 @@ def run_adjusted_mclmc( init_key, tune_key, run_key = jax.random.split(key, 3) - initial_state = blackjax.mcmc.adjusted_mclmc.init( + initial_state = blackjax.mcmc.adjusted_mclmc_dynamic.init( position=initial_position, logdensity_fn=logdensity_fn, random_generator_arg=init_key, ) - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel( + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc_dynamic.build_kernel( integrator=integrator, integration_steps_fn=lambda k: jnp.ceil( jax.random.uniform(k) * rescale(avg_num_integration_steps) @@ -177,7 +177,7 @@ def run_adjusted_mclmc( logdensity_fn=logdensity_fn, ) - target_acc_rate = 0.65 + target_acc_rate = 0.9 ( blackjax_state_after_tuning, @@ -197,7 +197,7 @@ def run_adjusted_mclmc( step_size = blackjax_mclmc_sampler_params.step_size L = blackjax_mclmc_sampler_params.L - alg = blackjax.adjusted_mclmc( + alg = blackjax.adjusted_mclmc_dynamic( logdensity_fn=logdensity_fn, step_size=step_size, integration_steps_fn=lambda key: jnp.ceil( @@ -218,6 +218,73 @@ def run_adjusted_mclmc( return out + def run_adjusted_mclmc( + self, + logdensity_fn, + num_steps, + initial_position, + key, + diagonal_preconditioning=False, + ): + integrator = isokinetic_mclachlan + + init_key, tune_key, run_key = jax.random.split(key, 3) + + initial_state = blackjax.mcmc.adjusted_mclmc.init( + position=initial_position, + logdensity_fn=logdensity_fn, + ) + + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel( + integrator=integrator, + num_integration_steps=avg_num_integration_steps, + sqrt_diag_cov=sqrt_diag_cov, + )( + rng_key=rng_key, + state=state, + step_size=step_size, + logdensity_fn=logdensity_fn, + ) + + target_acc_rate = 0.9 + + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + ) = blackjax.adjusted_mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_steps, + state=initial_state, + rng_key=tune_key, + target=target_acc_rate, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.1, + diagonal_preconditioning=diagonal_preconditioning, + ) + + step_size = blackjax_mclmc_sampler_params.step_size + L = blackjax_mclmc_sampler_params.L + + alg = blackjax.adjusted_mclmc( + logdensity_fn=logdensity_fn, + step_size=step_size, + num_integration_steps=L / step_size, + integrator=integrator, + sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + ) + + _, out = run_inference_algorithm( + rng_key=run_key, + 
initial_state=blackjax_state_after_tuning, + inference_algorithm=alg, + num_steps=num_steps, + transform=lambda state, _: state.position, + progress_bar=False, + ) + + return out + @parameterized.parameters( itertools.product( regression_test_cases, [True, False], window_adaptation_filters @@ -334,7 +401,35 @@ def test_mclmc(self): np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) - def test_adjusted_mclmc(self): + @parameterized.parameters([True, False]) + def test_adjusted_mclmc_dynamic(self, diagonal_preconditioning): + """Test the MCLMC kernel.""" + + init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + x_data = jax.random.normal(init_key0, shape=(1000, 1)) + y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) + + logposterior_fn_ = functools.partial( + self.regression_logprob, x=x_data, preds=y_data + ) + logdensity_fn = lambda x: logposterior_fn_(**x) + + states = self.run_adjusted_mclmc_dynamic( + initial_position={"coefs": 1.0, "log_scale": 1.0}, + logdensity_fn=logdensity_fn, + key=inference_key, + num_steps=10000, + diagonal_preconditioning=diagonal_preconditioning, + ) + + coefs_samples = states["coefs"][3000:] + scale_samples = np.exp(states["log_scale"][3000:]) + + np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + + @parameterized.parameters([True, False]) + def test_adjusted_mclmc(self, diagonal_preconditioning): """Test the MCLMC kernel.""" init_key0, init_key1, inference_key = jax.random.split(self.key, 3) @@ -351,6 +446,7 @@ def test_adjusted_mclmc(self): logdensity_fn=logdensity_fn, key=inference_key, num_steps=10000, + diagonal_preconditioning=diagonal_preconditioning, ) coefs_samples = states["coefs"][3000:] From 9dd6bdba039003538eaf635fde7678defdbc7350 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 13:27:08 -0500 Subject: [PATCH 02/34] add static adjusted mclmc --- .../adaptation/adjusted_mclmc_adaptation.py | 10 ++--- blackjax/adaptation/mclmc_adaptation.py | 22 +++++----- blackjax/mcmc/adjusted_mclmc.py | 10 +++-- blackjax/mcmc/adjusted_mclmc_dynamic.py | 10 +++-- blackjax/mcmc/integrators.py | 12 +++--- blackjax/mcmc/mclmc.py | 8 ++-- tests/mcmc/test_integrators.py | 4 +- tests/mcmc/test_sampling.py | 40 +++++++++++-------- 8 files changed, 64 insertions(+), 52 deletions(-) diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index f5d54e5c9..eabb642a3 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -74,7 +74,7 @@ def adjusted_mclmc_find_L_and_step_size( dim = pytree_size(state.position) if params is None: params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, sqrt_diag_cov=jnp.ones((dim,)) + jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, inverse_mass_matrix=jnp.ones((dim,)) ) part1_key, part2_key = jax.random.split(rng_key, 2) @@ -152,7 +152,7 @@ def step(iteration_state, weight_and_key): state=previous_state, avg_num_integration_steps=avg_num_integration_steps, step_size=params.step_size, - sqrt_diag_cov=params.sqrt_diag_cov, + inverse_mass_matrix=params.inverse_mass_matrix, ) # step updating @@ -283,9 +283,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L=params.L * change, step_size=params.step_size * change ) if diagonal_preconditioning: - params = params._replace( - 
sqrt_diag_cov=jnp.sqrt(variances), L=jnp.sqrt(dim) - ) + params = params._replace(inverse_mass_matrix=variances, L=jnp.sqrt(dim)) initial_da, update_da, final_da = dual_averaging_adaptation(target=target) ( @@ -323,7 +321,7 @@ def step(state, key): state=state, step_size=params.step_size, avg_num_integration_steps=params.L / params.step_size, - sqrt_diag_cov=params.sqrt_diag_cov, + inverse_mass_matrix=params.inverse_mass_matrix, ) return next_state, next_state.position diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 8452b6171..aa192b964 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -30,13 +30,13 @@ class MCLMCAdaptationState(NamedTuple): The momentum decoherent rate for the MCLMC algorithm. step_size The step size used for the MCLMC algorithm. - sqrt_diag_cov + inverse_mass_matrix A matrix used for preconditioning. """ L: float step_size: float - sqrt_diag_cov: float + inverse_mass_matrix: float def mclmc_find_L_and_step_size( @@ -87,10 +87,10 @@ def mclmc_find_L_and_step_size( Example ------- .. code:: - kernel = lambda sqrt_diag_cov : blackjax.mcmc.mclmc.build_kernel( + kernel = lambda inverse_mass_matrix : blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=integrator, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, ) ( @@ -106,7 +106,7 @@ def mclmc_find_L_and_step_size( """ dim = pytree_size(state.position) params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov=jnp.ones((dim,)) + jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,)) ) part1_key, part2_key = jax.random.split(rng_key, 2) @@ -123,7 +123,7 @@ def mclmc_find_L_and_step_size( if frac_tune3 != 0: state, params = make_adaptation_L( - mclmc_kernel(params.sqrt_diag_cov), frac=frac_tune3, Lfactor=0.4 + mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=0.4 )(state, params, num_steps, part2_key) return state, params @@ -152,7 +152,7 @@ def predictor(previous_state, params, adaptive_state, rng_key): rng_key, nan_key = jax.random.split(rng_key) # dynamics - next_state, info = kernel(params.sqrt_diag_cov)( + next_state, info = kernel(params.inverse_mass_matrix)( rng_key=rng_key, state=previous_state, L=params.L, @@ -247,15 +247,15 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L = params.L # determine L - sqrt_diag_cov = params.sqrt_diag_cov + inverse_mass_matrix = params.inverse_mass_matrix if num_steps2 > 1: x_average, x_squared_average = average[0], average[1] variances = x_squared_average - jnp.square(x_average) L = jnp.sqrt(jnp.sum(variances)) if diagonal_preconditioning: - sqrt_diag_cov = jnp.sqrt(variances) - params = params._replace(sqrt_diag_cov=sqrt_diag_cov) + inverse_mass_matrix = variances + params = params._replace(inverse_mass_matrix=inverse_mass_matrix) L = jnp.sqrt(dim) # readjust the stepsize @@ -265,7 +265,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): xs=(jnp.ones(steps), keys), state=state, params=params ) - return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov) + return state, MCLMCAdaptationState(L, params.step_size, inverse_mass_matrix) return L_step_size_adaptation diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 8288772a3..9b868562c 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -40,7 +40,7 @@ def build_kernel( num_integration_steps: int, integrator: Callable = 
integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, ): """Build an MHMCHMC kernel where the number of integration steps is chosen randomly. @@ -76,7 +76,9 @@ def kernel( momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( integrator=integrators.with_isokinetic_maruyama( - integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + integrator( + logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix + ) ), step_size=step_size, L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), @@ -105,7 +107,7 @@ def as_top_level_api( logdensity_fn: Callable, step_size: float, L_proposal_factor: float = jnp.inf, - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, *, divergence_threshold: int = 1000, integrator: Callable = integrators.isokinetic_mclachlan, @@ -140,7 +142,7 @@ def as_top_level_api( kernel = build_kernel( num_integration_steps, integrator=integrator, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, divergence_threshold=divergence_threshold, ) diff --git a/blackjax/mcmc/adjusted_mclmc_dynamic.py b/blackjax/mcmc/adjusted_mclmc_dynamic.py index 81fbc2835..1a69e1a28 100644 --- a/blackjax/mcmc/adjusted_mclmc_dynamic.py +++ b/blackjax/mcmc/adjusted_mclmc_dynamic.py @@ -38,7 +38,7 @@ def build_kernel( integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, ): """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. @@ -76,7 +76,9 @@ def kernel( momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( integrator=integrators.with_isokinetic_maruyama( - integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + integrator( + logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix + ) ), step_size=step_size, L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), @@ -106,7 +108,7 @@ def as_top_level_api( logdensity_fn: Callable, step_size: float, L_proposal_factor: float = jnp.inf, - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, *, divergence_threshold: int = 1000, integrator: Callable = integrators.isokinetic_mclachlan, @@ -143,7 +145,7 @@ def as_top_level_api( integration_steps_fn=integration_steps_fn, integrator=integrator, next_random_arg_fn=next_random_arg_fn, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, divergence_threshold=divergence_threshold, ) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 593683ca4..733e7e960 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -311,7 +311,9 @@ def _normalized_flatten_array(x, tol=1e-13): return jnp.where(norm > tol, x / norm, x), norm -def esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0): +def esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0): + sqrt_inverse_mass_matrix = jax.tree_util.tree_map(jnp.sqrt, inverse_mass_matrix) + def update( momentum: ArrayTree, logdensity_grad: ArrayTree, @@ -330,7 +332,7 @@ def update( logdensity_grad = logdensity_grad flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) - flatten_grads = flatten_grads * sqrt_diag_cov + flatten_grads = flatten_grads * sqrt_inverse_mass_matrix flatten_momentum, _ = ravel_pytree(momentum) dims = 
flatten_momentum.shape[0] normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) @@ -342,7 +344,7 @@ def update( + 2 * zeta * flatten_momentum ) new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw) - gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov) + gr = unravel_fn(new_momentum_normalized * sqrt_inverse_mass_matrix) next_momentum = unravel_fn(new_momentum_normalized) kinetic_energy_change = ( delta @@ -374,11 +376,11 @@ def format_isokinetic_state_output( def generate_isokinetic_integrator(coefficients): def isokinetic_integrator( - logdensity_fn: Callable, sqrt_diag_cov: ArrayTree = 1.0 + logdensity_fn: Callable, inverse_mass_matrix: ArrayTree = 1.0 ) -> GeneralIntegrator: position_update_fn = euclidean_position_update_fn(logdensity_fn) one_step = generalized_two_stage_integrator( - esh_dynamics_momentum_update_one_step(sqrt_diag_cov), + esh_dynamics_momentum_update_one_step(inverse_mass_matrix), position_update_fn, coefficients, format_output_fn=format_isokinetic_state_output, diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index e7a69849b..ff9638a1f 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -60,7 +60,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): +def build_kernel(logdensity_fn, inverse_mass_matrix, integrator): """Build a HMC kernel. Parameters @@ -81,7 +81,7 @@ def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): """ step = with_isokinetic_maruyama( - integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + integrator(logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix) ) def kernel( @@ -107,7 +107,7 @@ def as_top_level_api( L, step_size, integrator=isokinetic_mclachlan, - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, ) -> SamplingAlgorithm: """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. Since most users only need to specify the kernel @@ -155,7 +155,7 @@ def as_top_level_api( A ``SamplingAlgorithm``. 
""" - kernel = build_kernel(logdensity_fn, sqrt_diag_cov, integrator) + kernel = build_kernel(logdensity_fn, inverse_mass_matrix, integrator) def init_fn(position: ArrayLike, rng_key: PRNGKey): return init(position, logdensity_fn, rng_key) diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index c38009e5e..c37c0ede6 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -238,7 +238,7 @@ def test_esh_momentum_update(self, dims): # Efficient implementation update_stable = self.variant( - esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) + esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0) ) next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0) np.testing.assert_array_almost_equal(next_momentum, next_momentum1) @@ -263,7 +263,7 @@ def test_isokinetic_velocity_verlet(self): next_state, kinetic_energy_change = step(initial_state, step_size) # explicit integration - op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) + op1 = esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0) op2 = integrators.euclidean_position_update_fn(logdensity_fn) position, momentum, _, logdensity_grad = initial_state momentum, kinetic_grad, kinetic_energy_change0 = op1( diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 45d60f84a..a4ea66a9b 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -112,10 +112,10 @@ def run_mclmc( position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key ) - kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=blackjax.mcmc.mclmc.isokinetic_mclachlan, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, ) ( @@ -133,7 +133,7 @@ def run_mclmc( logdensity_fn, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, - sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, ) _, samples = run_inference_algorithm( @@ -144,6 +144,8 @@ def run_mclmc( transform=lambda state, info: state.position, ) + print(samples["coefs"][0].item()) + return samples def run_adjusted_mclmc_dynamic( @@ -164,12 +166,12 @@ def run_adjusted_mclmc_dynamic( random_generator_arg=init_key, ) - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc_dynamic.build_kernel( + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc_dynamic.build_kernel( integrator=integrator, integration_steps_fn=lambda k: jnp.ceil( jax.random.uniform(k) * rescale(avg_num_integration_steps) ), - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, )( rng_key=rng_key, state=state, @@ -204,7 +206,7 @@ def run_adjusted_mclmc_dynamic( jax.random.uniform(key) * rescale(L / step_size) ), integrator=integrator, - sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, ) _, out = run_inference_algorithm( @@ -216,6 +218,8 @@ def run_adjusted_mclmc_dynamic( progress_bar=False, ) + print(blackjax_mclmc_sampler_params.inverse_mass_matrix[1].item()) + return out def run_adjusted_mclmc( @@ -235,10 +239,10 @@ def run_adjusted_mclmc( logdensity_fn=logdensity_fn, ) - kernel = lambda rng_key, state, avg_num_integration_steps, 
step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel( + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc.build_kernel( integrator=integrator, num_integration_steps=avg_num_integration_steps, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, )( rng_key=rng_key, state=state, @@ -271,7 +275,7 @@ def run_adjusted_mclmc( step_size=step_size, num_integration_steps=L / step_size, integrator=integrator, - sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, ) _, out = run_inference_algorithm( @@ -402,7 +406,10 @@ def test_mclmc(self): np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) @parameterized.parameters([True, False]) - def test_adjusted_mclmc_dynamic(self, diagonal_preconditioning): + def test_adjusted_mclmc_dynamic( + self, + diagonal_preconditioning, + ): """Test the MCLMC kernel.""" init_key0, init_key1, inference_key = jax.random.split(self.key, 3) @@ -495,7 +502,7 @@ def __init__(self, d, condition_number): integrator = isokinetic_mclachlan - def get_sqrt_diag_cov(): + def get_inverse_mass_matrix(): init_key, tune_key = jax.random.split(key) initial_position = model.sample_init(init_key) @@ -506,10 +513,10 @@ def get_sqrt_diag_cov(): rng_key=init_key, ) - kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=model.logdensity_fn, integrator=integrator, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, ) ( @@ -523,13 +530,14 @@ def get_sqrt_diag_cov(): diagonal_preconditioning=True, ) - return blackjax_mclmc_sampler_params.sqrt_diag_cov + return blackjax_mclmc_sampler_params.inverse_mass_matrix - sqrt_diag_cov = get_sqrt_diag_cov() + inverse_mass_matrix = get_inverse_mass_matrix() assert ( jnp.abs( jnp.dot( - (sqrt_diag_cov**2) / jnp.linalg.norm(sqrt_diag_cov**2), + (inverse_mass_matrix**2) + / jnp.linalg.norm(inverse_mass_matrix**2), eigs / jnp.linalg.norm(eigs), ) - 1 From a49bb35f37293f0033ea4c9c5b8daf7ff62c1461 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 13:54:36 -0500 Subject: [PATCH 03/34] add static adjusted mclmc --- tests/mcmc/test_sampling.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index a4ea66a9b..d788696f8 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -144,8 +144,6 @@ def run_mclmc( transform=lambda state, info: state.position, ) - print(samples["coefs"][0].item()) - return samples def run_adjusted_mclmc_dynamic( @@ -218,8 +216,6 @@ def run_adjusted_mclmc_dynamic( progress_bar=False, ) - print(blackjax_mclmc_sampler_params.inverse_mass_matrix[1].item()) - return out def run_adjusted_mclmc( From 35d71ffb9ece4940b113ed2caedf4f302b20ddfd Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 14:02:35 -0500 Subject: [PATCH 04/34] draft --- blackjax/adaptation/ensemble_mclmc.py | 224 +++++++++++++++++++++ blackjax/adaptation/ensemble_umclmc.py | 265 +++++++++++++++++++++++++ blackjax/adaptation/step_size.py | 43 ++++ blackjax/mcmc/mclmc.py | 2 +- blackjax/util.py | 123 +++++++++++- 5 files changed, 655 insertions(+), 2 deletions(-) create mode 100644 blackjax/adaptation/ensemble_mclmc.py create mode 100644 blackjax/adaptation/ensemble_umclmc.py diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py new file 
mode 100644 index 000000000..dabf5a3be --- /dev/null +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -0,0 +1,224 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#"""Public API for the MCLMC Kernel""" + +from typing import Callable, NamedTuple, Any + +import jax +import jax.numpy as jnp + +from blackjax.util import run_eca +from blackjax.mcmc.integrators import generate_isokinetic_integrator, velocity_verlet_coefficients, mclachlan_coefficients, omelyan_coefficients +from blackjax.mcmc.hmc import HMCState +from blackjax.mcmc.adjusted_mclmc import build_kernel as build_kernel_malt +import blackjax.adaptation.ensemble_umclmc as umclmc +from blackjax.adaptation.ensemble_umclmc import equipartition_diagonal, equipartition_diagonal_loss, equipartition_fullrank, equipartition_fullrank_loss + +from blackjax.adaptation.step_size import dual_averaging_adaptation, bisection_monotonic_fn + + + +class AdaptationState(NamedTuple): + steps_per_sample: float + step_size: float + epsadap_state: Any + sample_count: int + + + +def build_kernel(logdensity_fn, integrator, sqrt_diag_cov): + """MCLMC kernel""" + + kernel = build_kernel_malt(logdensity_fn, integrator, sqrt_diag_cov= sqrt_diag_cov, L_proposal_factor = 1.25) + + def sequential_kernel(key, state, adap): + return kernel(key, state, step_size= adap.step_size, num_integration_steps= adap.steps_per_sample) + + return sequential_kernel + + + +class Adaptation: + + def __init__(self, adap_state, num_adaptation_samples, + steps_per_sample, acc_prob_target= 0.8, + observables = lambda x: 0., + observables_for_bias = lambda x: 0., contract= lambda x: 0.): + + self.num_adaptation_samples= num_adaptation_samples + self.observables = observables + self.observables_for_bias = observables_for_bias + self.contract = contract + + ### Determine the initial hyperparameters ### + + ## stepsize ## + #if we switched to the more accurate integrator we can use longer step size + #integrator_factor = jnp.sqrt(10.) if mclachlan else 1. + # Let's use the stepsize which will be optimal for the adjusted method. The energy variance after N steps scales as sigma^2 ~ N^2 eps^6 = eps^4 L^2 + # In the adjusted method we want sigma^2 = 2 mu = 2 * 0.41 = 0.82 + # With the current eps, we had sigma^2 = EEVPD * d for N = 1. 
+ # Combining the two we have EEVPD * d / 0.82 = eps^6 / eps_new^4 L^2 + #adjustment_factor = jnp.power(0.82 / (num_dims * adap_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) + step_size = adap_state.step_size #* integrator_factor * adjustment_factor + + #steps_per_sample = (int)(jnp.max(jnp.array([Lfull / step_size, 1]))) + + ### Initialize the dual averaging adaptation ### + #da_init_fn, self.epsadap_update, _ = dual_averaging_adaptation(target= acc_prob_target) + #epsadap_state = da_init_fn(step_size) + + ### Initialize the bisection for finding the step size + epsadap_state, self.epsadap_update = bisection_monotonic_fn(acc_prob_target) + + self.initial_state = AdaptationState(steps_per_sample, step_size, epsadap_state, 0) + + + def summary_statistics_fn(self, state, info, rng_key): + + return {'acceptance_probability': info.acceptance_rate, + #'inv_acceptance_probability': 1./info.acceptance_rate, + 'equipartition_diagonal': equipartition_diagonal(state), + 'equipartition_fullrank': equipartition_fullrank(state, rng_key), + 'observables': self.observables(state.position), + 'observables_for_bias': self.observables_for_bias(state.position) + } + + + def update(self, adaptation_state, Etheta): + + # combine the expectation values to get useful scalars + acc_prob = Etheta['acceptance_probability'] + #acc_prob = 1./Etheta['inv_acceptance_probability'] + equi_diag = equipartition_diagonal_loss(Etheta['equipartition_diagonal']) + equi_full = equipartition_fullrank_loss(Etheta['equipartition_fullrank']) + true_bias = self.contract(Etheta['observables_for_bias']) + + + info_to_be_stored = {'L': adaptation_state.step_size * adaptation_state.steps_per_sample, 'steps_per_sample': adaptation_state.steps_per_sample, 'step_size': adaptation_state.step_size, + 'acc_prob': acc_prob, + 'equi_diag': equi_diag, 'equi_full': equi_full, 'bias': true_bias, + 'observables': Etheta['observables'] + } + + # hyperparameter adaptation + + # Dual Averaging + # adaptation_phase = adaptation_state.sample_count < self.num_adaptation_samples + + # def update(_): + # da_state = self.epsadap_update(adaptation_state.epsadap_state, acc_prob) + # step_size = jnp.exp(da_state.log_step_size) + # return da_state, step_size + + # def dont_update(_): + # da_state = adaptation_state.epsadap_state + # return da_state, jnp.exp(da_state.log_step_size_avg) + + # epsadap_state, step_size = jax.lax.cond(adaptation_phase, update, dont_update, operand=None) + + # Bisection + epsadap_state, step_size = self.epsadap_update(adaptation_state.epsadap_state, adaptation_state.step_size, acc_prob) + + return AdaptationState(adaptation_state.steps_per_sample, step_size, epsadap_state, adaptation_state.sample_count + 1), info_to_be_stored + + + +def bias(model): + """should be transfered to benchmarks/""" + + def observables(position): + return jnp.square(model.transform(position)) + + def contract(sampler_E_x2): + bsq = jnp.square(sampler_E_x2 - model.E_x2) / model.Var_x2 + return jnp.array([jnp.max(bsq), jnp.average(bsq)]) + + return observables, contract + + + +def while_steps_num(cond): + if jnp.all(cond): + return len(cond) + else: + return jnp.argmin(cond) + 1 + + +def emaus(model, num_steps1, num_steps2, num_chains, mesh, rng_key, + alpha= 1.9, bias_type= 0, save_frac= 0.2, C= 0.1, power= 3./8., early_stop= True, r_end= 5e-3,# stage1 parameters + diagonal_preconditioning= True, integrator_coefficients= None, steps_per_sample= 10, acc_prob= None, + observables = lambda x: None, + ensemble_observables= None + ): + + observables_for_bias, 
contract = bias(model) + key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) + + # initialize the chains + initial_state = umclmc.initialize(key_init, model.logdensity_fn, model.sample_init, num_chains, mesh) + + ### burn-in with the unadjusted method ### + kernel = umclmc.build_kernel(model.logdensity_fn) + save_num= (int)(jnp.rint(save_frac * num_steps1)) + adap = umclmc.Adaptation(model.ndims, alpha= alpha, bias_type= bias_type, save_num= save_num, C=C, power= power, r_end = r_end, + observables= observables, observables_for_bias= observables_for_bias, contract= contract) + final_state, final_adaptation_state, info1 = run_eca(key_umclmc, initial_state, kernel, adap, num_steps1, num_chains, mesh, ensemble_observables) + + if early_stop: # here I am cheating a bit, because I am not sure if it is possible to do a while loop in jax and save something at every step. Therefore I rerun burn-in with exactly the same parameters and stop at the point where the orignal while loop would have stopped. The release implementation should not have that. + + num_steps_while = while_steps_num((info1[0] if ensemble_observables != None else info1)['while_cond']) + #print(num_steps_while, save_num) + final_state, final_adaptation_state, info1 = run_eca(key_umclmc, initial_state, kernel, adap, num_steps_while, num_chains, mesh, ensemble_observables) + + ### refine the results with the adjusted method ### + _acc_prob = acc_prob + if integrator_coefficients == None: + high_dims = model.ndims > 200 + _integrator_coefficients = omelyan_coefficients if high_dims else mclachlan_coefficients + if acc_prob == None: + _acc_prob = 0.9 if high_dims else 0.7 + + else: + _integrator_coefficients = integrator_coefficients + if acc_prob == None: + _acc_prob = 0.9 + + + integrator = generate_isokinetic_integrator(_integrator_coefficients) + gradient_calls_per_step= len(_integrator_coefficients) // 2 #scheme = BABAB..AB scheme has len(scheme)//2 + 1 Bs. The last doesn't count because that gradient can be reused in the next step. + + if diagonal_preconditioning: + sqrt_diag_cov= final_adaptation_state.sqrt_diag_cov + + # scale the stepsize so that it reflects averag scale change of the preconditioning + average_scale_change = jnp.sqrt(jnp.average(jnp.square(sqrt_diag_cov))) + final_adaptation_state = final_adaptation_state._replace(step_size= final_adaptation_state.step_size / average_scale_change) + + else: + sqrt_diag_cov= 1. + + kernel = build_kernel(model.logdensity_fn, integrator, sqrt_diag_cov= sqrt_diag_cov) + initial_state= HMCState(final_state.position, final_state.logdensity, final_state.logdensity_grad) + num_samples = num_steps2 // (gradient_calls_per_step * steps_per_sample) + num_adaptation_samples = num_samples//2 # number of samples after which the stepsize is fixed. + + adap = Adaptation(final_adaptation_state, num_adaptation_samples, steps_per_sample, _acc_prob, + observables= observables, observables_for_bias= observables_for_bias, contract= contract) + + final_state, final_adaptation_state, info2 = run_eca(key_mclmc, initial_state, kernel, adap, num_samples, num_chains, mesh, ensemble_observables) + + return info1, info2, gradient_calls_per_step, _acc_prob + + \ No newline at end of file diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py new file mode 100644 index 000000000..73b9ae4f7 --- /dev/null +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -0,0 +1,265 @@ + +# Copyright 2020- The Blackjax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#"""Public API for the MCLMC Kernel""" + +import jax +import jax.numpy as jnp +from jax.flatten_util import ravel_pytree +from typing import Callable, NamedTuple, Any + +from blackjax.mcmc.integrators import IntegratorState, isokinetic_velocity_verlet +from blackjax.types import Array, ArrayLike +from blackjax.util import pytree_size +from blackjax.mcmc import mclmc +from blackjax.mcmc.integrators import _normalized_flatten_array +from blackjax.util import ensemble_execute_fn + + + +def no_nans(a): + flat_a, unravel_fn = ravel_pytree(a) + return jnp.all(jnp.isfinite(flat_a)) + + +def nan_reject(nonans, old, new): + """Equivalent to + return new if nonans else old""" + + return jax.lax.cond(nonans, lambda _: new, lambda _: old, operand=None) + + + +def build_kernel(logdensity_fn): + """MCLMC kernel (with nan rejection)""" + + kernel = mclmc.build_kernel(logdensity_fn= logdensity_fn, integrator= isokinetic_velocity_verlet) + + + def sequential_kernel(key, state, adap): + + new_state, info = kernel(key, state, adap.L, adap.step_size) + + # reject the new state if there were nans + nonans = no_nans(new_state) + new_state = nan_reject(nonans, state, new_state) + + return new_state, {'nans': 1-nonans, 'energy_change': info.energy_change * nonans, 'logdensity': info.logdensity * nonans} + + + return sequential_kernel + + + +def initialize(rng_key, logdensity_fn, sample_init, num_chains, mesh): + """initialize the chains based on the equipartition of the initial condition. + We initialize the velocity along grad log p if E_ii > 1 and along -grad log p if E_ii < 1. + """ + + def sequential_init(key, x, args): + """initialize the position using sample_init and the velocity along the gradient""" + position = sample_init(key) + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + flat_g, unravel_fn = ravel_pytree(logdensity_grad) + velocity = unravel_fn(_normalized_flatten_array(flat_g)[0]) # = grad logp/ |grad logp| + return IntegratorState(position, velocity, logdensity, logdensity_grad), None + + def summary_statistics_fn(state): + """compute the diagonal elements of the equipartition matrix""" + return -state.position * state.logdensity_grad + + def ensemble_init(key, state, signs): + """flip the velocity, depending on the equipartition condition""" + velocity = jax.tree_util.tree_map(lambda sign, u: sign * u, signs, state.momentum) + return IntegratorState(state.position, velocity, state.logdensity, state.logdensity_grad), None + + key1, key2= jax.random.split(rng_key) + initial_state, equipartition = ensemble_execute_fn(sequential_init, key1, num_chains, mesh, summary_statistics_fn= summary_statistics_fn) + signs = -2. * (equipartition < 1.) + 1. 
+ initial_state, _ = ensemble_execute_fn(ensemble_init, key2, num_chains, mesh, x= initial_state, args= signs) + + return initial_state + + +def update_history(new_vals, history): + return jnp.concatenate((new_vals[None, :], history[:-1])) + +def update_history_scalar(new_val, history): + return jnp.concatenate((new_val * jnp.ones(1), history[:-1])) + +def contract_history(theta, weights): + + square_average = jnp.square(jnp.average(theta, weights= weights, axis= 0)) + average_square = jnp.average(jnp.square(theta), weights= weights, axis= 0) + + r = (average_square - square_average) / square_average + + return jnp.array([jnp.max(r), + jnp.average(r) + ]) + + +class History(NamedTuple): + observables: Array + stopping: Array + weights: Array + + +class AdaptationState(NamedTuple): + + L: float + sqrt_diag_cov: Any + step_size: float + + step_count: int + EEVPD: float + EEVPD_wanted: float + history: Any + + +def equipartition_diagonal(state): + """Ei = E_ensemble (- grad log p_i x_i ). Ei is 1 if we have converged. + equipartition_loss = average over parameters (Ei)""" + return jax.tree_util.tree_map(lambda x, g: -x * g, state.position, state.logdensity_grad) + + + +def equipartition_fullrank(state, rng_key): + """loss = Tr[(1 - E)^T (1 - E)] / d^2 + where Eij = is the equipartition patrix. + Loss is computed with the Hutchinson's trick.""" + + x, unravel_fn = ravel_pytree(state.position) + g, unravel_fn = ravel_pytree(state.logdensity_grad) + d = len(x) + + def func(z): + """z here has the same shape as position""" + return (z + jnp.dot(z, g) * x) + + z = jax.random.rademacher(rng_key, (100, d)) # = delta_ij + return jax.vmap(func)(z) + + +def equipartition_diagonal_loss(Eii): + Eii_flat, unravel_fn = ravel_pytree(Eii) + return jnp.average(jnp.square(1.- Eii_flat)) + + +def equipartition_fullrank_loss(delta_z): + d = delta_z.shape[-1] + return jnp.average(jnp.square(delta_z)) / d + + +class Adaptation: + + def __init__(self, num_dims, + alpha= 1., C= 0.1, power = 3./8., r_end= 0.01, + bias_type= 0, save_num = 10, + observables= lambda x: 0., observables_for_bias= lambda x: 0., contract= lambda x: 0. 
+ ): + + self.num_dims = num_dims + self.alpha = alpha + self.C = C + self.power = power + self.r_end = r_end + self.observables = observables + self.observables_for_bias = observables_for_bias + self.contract = contract + self.bias_type = bias_type + self.save_num = save_num + #sigma = unravel_fn(jnp.ones(flat_pytree.shape, dtype = flat_pytree.dtype)) + + r_save_num = save_num + + history = History(observables= jnp.zeros((r_save_num, num_dims)), + stopping= jnp.full((save_num,), jnp.nan), + weights= jnp.zeros(r_save_num)) + + self.initial_state = AdaptationState(L= jnp.inf, # do not add noise for the first step + sqrt_diag_cov= jnp.ones(num_dims), + step_size= 0.01 * jnp.sqrt(num_dims), + step_count= 0, + EEVPD=1e-3, EEVPD_wanted=1e-3, + history=history) + + + def summary_statistics_fn(self, state, info, rng_key): + + position_flat, unravel_fn = ravel_pytree(state.position) + + return {'equipartition_diagonal': equipartition_diagonal(state), + 'equipartition_fullrank': equipartition_fullrank(state, rng_key), + 'x': position_flat, 'xsq': jnp.square(position_flat), + 'E': info['energy_change'], 'Esq': jnp.square(info['energy_change']), + 'rejection_rate_nans': info['nans'], + 'observables_for_bias': self.observables_for_bias(state.position), + 'observables': self.observables(state.position), + 'entropy': - info['logdensity'] + } + + + def update(self, adaptation_state, Etheta): + + # combine the expectation values to get useful scalars + equi_diag = equipartition_diagonal_loss(Etheta['equipartition_diagonal']) + equi_full = equipartition_fullrank_loss(Etheta['equipartition_fullrank']) + + history_observables = update_history(Etheta['observables_for_bias'], adaptation_state.history.observables) + history_weights = update_history_scalar(1., adaptation_state.history.weights) + fluctuations = contract_history(history_observables, history_weights) + history_stopping = update_history_scalar(jax.lax.cond(adaptation_state.step_count > len(history_weights), lambda _: fluctuations[0], lambda _: jnp.nan, operand=None), + adaptation_state.history.stopping) + history = History(history_observables, history_stopping, history_weights) + + L = self.alpha * jnp.sqrt(jnp.sum(Etheta['xsq'] - jnp.square(Etheta['x']))) # average over the ensemble, sum over parameters (to get sqrt(d)) + sqrt_diag_cov = jnp.sqrt(Etheta['xsq'] - jnp.square(Etheta['x'])) + EEVPD = (Etheta['Esq'] - jnp.square(Etheta['E'])) / self.num_dims + true_bias = self.contract(Etheta['observables_for_bias']) + nans = (Etheta['rejection_rate_nans'] > 0.) #| (~jnp.isfinite(eps_factor)) + + # hyperparameter adaptation + # estimate bias + bias = jnp.array([fluctuations[0], fluctuations[1], equi_full, equi_diag])[self.bias_type] # r_max, r_avg, equi_full, equi_diag + EEVPD_wanted = self.C * jnp.power(bias, self.power) + + + eps_factor = jnp.power(EEVPD_wanted / EEVPD, 1./6.) + eps_factor = jnp.clip(eps_factor, 0.3, 3.) + + eps_factor = nan_reject(1-nans, 0.5, eps_factor) # reduce the stepsize if there were nans + + # determine if we want to finish this stage (i.e. if loss is no longer decreassing) + #increasing = history.stopping[0] > history.stopping[-1] # will be false if some elements of history are still nan (have not been filled yet). 
Do not be tempted to simply change to while_cond = history[0] < history[-1] + #while_cond = ~increasing + while_cond = (fluctuations[0] > self.r_end) | (adaptation_state.step_count < self.save_num) + + info_to_be_stored = {'L': adaptation_state.L, 'step_size': adaptation_state.step_size, + 'EEVPD_wanted': EEVPD_wanted, 'EEVPD': EEVPD, + 'equi_diag': equi_diag, 'equi_full': equi_full, 'bias': true_bias, + 'r_max': fluctuations[0], 'r_avg': fluctuations[1], + 'while_cond': while_cond, 'entropy': Etheta['entropy'], + 'observables': Etheta['observables']} + + adaptation_state_new = AdaptationState(L, + sqrt_diag_cov, + adaptation_state.step_size * eps_factor, + adaptation_state.step_count + 1, + EEVPD, + EEVPD_wanted, + history) + + return adaptation_state_new, info_to_be_stored + diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 2b06172c0..7a076b962 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -257,3 +257,46 @@ def update(rss_state: ReasonableStepSizeState) -> ReasonableStepSizeState: rss_state = jax.lax.while_loop(do_continue, update, rss_state) return rss_state.step_size + +def bisection_monotonic_fn(acc_prob_wanted, reduce_shift = jnp.log(2.), tolerance= 0.03): + """Bisection of a monotonically decreassing function, that doesn't require an initially bracketing interval.""" + + def update(state, exp_x, acc_rate_new): + + bounds, terminated = state + + # update the bounds + acc_high = acc_rate_new > acc_prob_wanted + x = jnp.log(exp_x) + + def on_true(bounds): + bounds0 = jnp.max(jnp.array([bounds[0], x])) + return jnp.array([bounds0, bounds[1]]), bounds0 + reduce_shift + + def on_false(bounds): + bounds1 = jnp.min(jnp.array([bounds[1], x])) + return jnp.array([bounds[0], bounds1]), bounds1 - reduce_shift + + + bounds_new, x_new = jax.lax.cond(acc_high, on_true, on_false, bounds) + + + # if we have already found a bracketing interval, do bisection, otherwise further reduce or increase the bounds + bracketing = jnp.all(jnp.isfinite(bounds_new)) + + def reduce(bounds): + return x_new + + def bisect(bounds): + return jnp.average(bounds) + + x_new = jax.lax.cond(bracketing, bisect, reduce, bounds_new) + + stepsize = terminated * exp_x + (1-terminated) * jnp.exp(x_new) + + terminated_new = (jnp.abs(acc_rate_new - acc_prob_wanted) < tolerance) | terminated + + return (bounds_new, terminated_new), stepsize + + + return (jnp.array([-jnp.inf, jnp.inf]), False), update \ No newline at end of file diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index ff9638a1f..e5cc46213 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -60,7 +60,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, inverse_mass_matrix, integrator): +def build_kernel(logdensity_fn, inverse_mass_matrix=1.0, integrator=isokinetic_mclachlan): """Build a HMC kernel. 
Parameters diff --git a/blackjax/util.py b/blackjax/util.py index 8cdcd45ee..8f6494167 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -4,11 +4,14 @@ from typing import Callable, Union import jax.numpy as jnp -from jax import jit, lax +from jax import jit, lax, device_put, vmap from jax.flatten_util import ravel_pytree from jax.random import normal, split from jax.tree_util import tree_leaves, tree_map +from jax.sharding import Mesh, PartitionSpec, NamedSharding +from jax.experimental.shard_map import shard_map + from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -314,3 +317,121 @@ def incremental_value_update( ) total += weight return total, average + +def eca_step(kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info= None): + + def _step(state_all, xs): + """This function operates on a single device.""" + state, adaptation_state = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. + _, keys_sampling, key_adaptation = xs # keys_sampling.shape = (chains_per_device, ) + + # update the state of all chains on this device + state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) + + # combine all the chains to compute expectation values + theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) + Etheta = tree_map(lambda theta: lax.psum(jnp.sum(theta, axis= 0), axis_name= 'chains') / num_chains, theta) + + # use these to adapt the hyperparameters of the dynamics + adaptation_state, info_to_be_stored = adaptation_update(adaptation_state, Etheta) + + return (state, adaptation_state), info_to_be_stored + + + if ensemble_info != None: + + def step(state_all, xs): + (state, adaptation_state), info_to_be_stored = _step(state_all, xs) + return (state, adaptation_state), (info_to_be_stored, vmap(ensemble_info)(state.position)) + + return step + + else: + return _step + + +def run_eca(rng_key, initial_state, kernel, adaptation, num_steps, num_chains, mesh, ensemble_info= None): + + step = eca_step(kernel, adaptation.summary_statistics_fn, adaptation.update, num_chains, ensemble_info) + + + def all_steps(initial_state, keys_sampling, keys_adaptation): + """This function operates on a single device. key is a random key for this device.""" + + initial_state_all = (initial_state, adaptation.initial_state) + + # run sampling + xs = (jnp.arange(num_steps), keys_sampling.T, keys_adaptation) # keys for all steps that will be performed. 
keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) + + final_state_all, info_history = lax.scan(step, initial_state_all, xs) + final_state, final_adaptation_state = final_state_all + return final_state, final_adaptation_state, info_history # info history is composed of averages over all chains, so it is a couple of scalars + + + p, pscalar = PartitionSpec('chains'), PartitionSpec() + parallel_execute = shard_map(all_steps, + mesh= mesh, + in_specs= (p, p, pscalar), + out_specs= (p, pscalar, pscalar), + check_rep=False + ) + + # produce all random keys that will be needed + key_sampling, key_adaptation = split(rng_key) + keys_adaptation = split(key_adaptation, num_steps) + distribute_keys = lambda key, shape: device_put(split(key, shape), NamedSharding(mesh, p)) # random keys, distributed across devices + keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) + + + # run sampling in parallel + final_state, final_adaptation_state, info_history = parallel_execute(initial_state, keys_sampling, keys_adaptation) + + return final_state, final_adaptation_state, info_history + + + + +def ensemble_execute_fn(func, rng_key, num_chains, mesh, + x= None, + args= None, + summary_statistics_fn= lambda y: 0., + ): + """Given a sequential function + func(rng_key, x, args) = y, + evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. + Args: + x: array distributed over all decvices + args: additional arguments for func, not distributed. + summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. + rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. + + Returns: + y: array distributed over all decvices. Need not be of the same shape as x. + Etheta: expected values of the summary statistics + """ + p, pscalar = PartitionSpec('chains'), PartitionSpec() + + if x == None: + X= device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) + else: + X= x + + adaptation_update= lambda _, Etheta: (Etheta, None) + + _F = eca_step(func, lambda y, info, key: summary_statistics_fn(y), adaptation_update, num_chains) + + def F(x, keys): + """This function operates on a single device. 
key is a random key for this device.""" + y, summary_statistics = _F((x, args), (None, keys, None))[0] + return y, summary_statistics + + parallel_execute = shard_map(F, + mesh= mesh, + in_specs= (p, p), + out_specs= (p, pscalar), + check_rep=False + ) + + keys = device_put(split(rng_key, num_chains), NamedSharding(mesh, p)) # random keys, distributed across devices + # apply F in parallel + return parallel_execute(X, keys) \ No newline at end of file From 0f3df53011f00ecec9d879f23c01f08588608241 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 14:13:11 -0500 Subject: [PATCH 05/34] draft --- blackjax/adaptation/ensemble_mclmc.py | 14 +++++++------- blackjax/adaptation/ensemble_umclmc.py | 8 ++++---- blackjax/mcmc/adjusted_mclmc.py | 4 ++-- blackjax/util.py | 3 +++ 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index dabf5a3be..b303f8ae7 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -37,13 +37,13 @@ class AdaptationState(NamedTuple): -def build_kernel(logdensity_fn, integrator, sqrt_diag_cov): +def build_kernel(logdensity_fn, integrator, inverse_mass_matrix): """MCLMC kernel""" - kernel = build_kernel_malt(logdensity_fn, integrator, sqrt_diag_cov= sqrt_diag_cov, L_proposal_factor = 1.25) + kernel = build_kernel_malt(logdensity_fn=logdensity_fn, integrator=integrator, inverse_mass_matrix= inverse_mass_matrix,) def sequential_kernel(key, state, adap): - return kernel(key, state, step_size= adap.step_size, num_integration_steps= adap.steps_per_sample) + return kernel(key, state, step_size= adap.step_size, num_integration_steps= adap.steps_per_sample, L_proposal_factor = 1.25,) return sequential_kernel @@ -200,16 +200,16 @@ def emaus(model, num_steps1, num_steps2, num_chains, mesh, rng_key, gradient_calls_per_step= len(_integrator_coefficients) // 2 #scheme = BABAB..AB scheme has len(scheme)//2 + 1 Bs. The last doesn't count because that gradient can be reused in the next step. if diagonal_preconditioning: - sqrt_diag_cov= final_adaptation_state.sqrt_diag_cov + inverse_mass_matrix= jnp.sqrt(final_adaptation_state.inverse_mass_matrix) # scale the stepsize so that it reflects averag scale change of the preconditioning - average_scale_change = jnp.sqrt(jnp.average(jnp.square(sqrt_diag_cov))) + average_scale_change = jnp.sqrt(jnp.average(inverse_mass_matrix)) final_adaptation_state = final_adaptation_state._replace(step_size= final_adaptation_state.step_size / average_scale_change) else: - sqrt_diag_cov= 1. + inverse_mass_matrix= 1. - kernel = build_kernel(model.logdensity_fn, integrator, sqrt_diag_cov= sqrt_diag_cov) + kernel = build_kernel(model.logdensity_fn, integrator, inverse_mass_matrix= inverse_mass_matrix) initial_state= HMCState(final_state.position, final_state.logdensity, final_state.logdensity_grad) num_samples = num_steps2 // (gradient_calls_per_step * steps_per_sample) num_adaptation_samples = num_samples//2 # number of samples after which the stepsize is fixed. 
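Note on the hunk above: together with the blackjax/mcmc/adjusted_mclmc.py changes later in this patch, the adjusted MCLMC kernel is now specialised on the log-density, integrator and mass matrix at build time, while the step size, the number of integration steps and L_proposal_factor are supplied per call. A minimal usage sketch, not part of the patch, assuming the post-patch signatures and a hypothetical toy logdensity_fn:

    import jax
    import jax.numpy as jnp
    import blackjax

    # assumed toy target, standard normal in 5 dimensions
    logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))

    # init no longer needs a random generator argument
    state = blackjax.mcmc.adjusted_mclmc.init(jnp.ones(5), logdensity_fn)

    # the kernel is built once from the log-density and mass matrix
    kernel = blackjax.mcmc.adjusted_mclmc.build_kernel(
        logdensity_fn=logdensity_fn,
        inverse_mass_matrix=jnp.ones(5),
    )

    # step size and trajectory length are per-call arguments
    state, info = kernel(
        rng_key=jax.random.PRNGKey(0),
        state=state,
        step_size=0.1,
        num_integration_steps=10,
    )

This mirrors how tests/mcmc/test_sampling.py calls the kernel after this change.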
diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 73b9ae4f7..458361103 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -119,7 +119,7 @@ class History(NamedTuple): class AdaptationState(NamedTuple): L: float - sqrt_diag_cov: Any + inverse_mass_matrix: Any step_size: float step_count: int @@ -189,7 +189,7 @@ def __init__(self, num_dims, weights= jnp.zeros(r_save_num)) self.initial_state = AdaptationState(L= jnp.inf, # do not add noise for the first step - sqrt_diag_cov= jnp.ones(num_dims), + inverse_mass_matrix= jnp.ones(num_dims), step_size= 0.01 * jnp.sqrt(num_dims), step_count= 0, EEVPD=1e-3, EEVPD_wanted=1e-3, @@ -225,7 +225,7 @@ def update(self, adaptation_state, Etheta): history = History(history_observables, history_stopping, history_weights) L = self.alpha * jnp.sqrt(jnp.sum(Etheta['xsq'] - jnp.square(Etheta['x']))) # average over the ensemble, sum over parameters (to get sqrt(d)) - sqrt_diag_cov = jnp.sqrt(Etheta['xsq'] - jnp.square(Etheta['x'])) + inverse_mass_matrix = Etheta['xsq'] - jnp.square(Etheta['x']) EEVPD = (Etheta['Esq'] - jnp.square(Etheta['E'])) / self.num_dims true_bias = self.contract(Etheta['observables_for_bias']) nans = (Etheta['rejection_rate_nans'] > 0.) #| (~jnp.isfinite(eps_factor)) @@ -254,7 +254,7 @@ def update(self, adaptation_state, Etheta): 'observables': Etheta['observables']} adaptation_state_new = AdaptationState(L, - sqrt_diag_cov, + inverse_mass_matrix, adaptation_state.step_size * eps_factor, adaptation_state.step_count + 1, EEVPD, diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 9b868562c..8a5a37e55 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -37,7 +37,7 @@ def init(position: ArrayLikeTree, logdensity_fn: Callable): def build_kernel( - num_integration_steps: int, + logdensity_fn: Callable, integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, inverse_mass_matrix=1.0, @@ -66,8 +66,8 @@ def build_kernel( def kernel( rng_key: PRNGKey, state: HMCState, - logdensity_fn: Callable, step_size: float, + num_integration_steps: int, L_proposal_factor: float = jnp.inf, ) -> tuple[HMCState, HMCInfo]: """Generate a new sample with the MHMCHMC kernel.""" diff --git a/blackjax/util.py b/blackjax/util.py index 8f6494167..92eb77f40 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -377,7 +377,10 @@ def all_steps(initial_state, keys_sampling, keys_adaptation): ) # produce all random keys that will be needed + # rng_key = rng_key if not isinstance(rng_key, jnp.ndarray) else rng_key[0] + key_sampling, key_adaptation = split(rng_key) + num_steps = jnp.array(num_steps).item() keys_adaptation = split(key_adaptation, num_steps) distribute_keys = lambda key, shape: device_put(split(key, shape), NamedSharding(mesh, p)) # random keys, distributed across devices keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) From 04522f57c94cf6d5aa2bbb25f9cd14450be645da Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 14:18:55 -0500 Subject: [PATCH 06/34] change order of parameters --- blackjax/mcmc/adjusted_mclmc.py | 16 ++++++++-------- tests/mcmc/test_sampling.py | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 9b868562c..f390402f2 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -37,7 +37,7 @@ def 
init(position: ArrayLikeTree, logdensity_fn: Callable): def build_kernel( - num_integration_steps: int, + logdensity_fn: Callable, integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, inverse_mass_matrix=1.0, @@ -66,8 +66,8 @@ def build_kernel( def kernel( rng_key: PRNGKey, state: HMCState, - logdensity_fn: Callable, step_size: float, + num_integration_steps: int, L_proposal_factor: float = jnp.inf, ) -> tuple[HMCState, HMCInfo]: """Generate a new sample with the MHMCHMC kernel.""" @@ -140,7 +140,7 @@ def as_top_level_api( """ kernel = build_kernel( - num_integration_steps, + logdensity_fn=logdensity_fn, integrator=integrator, inverse_mass_matrix=inverse_mass_matrix, divergence_threshold=divergence_threshold, @@ -152,11 +152,11 @@ def init_fn(position: ArrayLikeTree, rng_key=None): def update_fn(rng_key: PRNGKey, state): return kernel( - rng_key, - state, - logdensity_fn, - step_size, - L_proposal_factor, + rng_key=rng_key, + state=state, + step_size=step_size, + num_integration_steps=num_integration_steps, + L_proposal_factor=L_proposal_factor, ) return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index d788696f8..e9068326e 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -237,13 +237,13 @@ def run_adjusted_mclmc( kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc.build_kernel( integrator=integrator, - num_integration_steps=avg_num_integration_steps, inverse_mass_matrix=inverse_mass_matrix, + logdensity_fn=logdensity_fn, )( rng_key=rng_key, state=state, step_size=step_size, - logdensity_fn=logdensity_fn, + num_integration_steps=avg_num_integration_steps, ) target_acc_rate = 0.9 From a7c99b92ee59b82db855d4b919a6254aa979a97d Mon Sep 17 00:00:00 2001 From: = Date: Thu, 16 Jan 2025 21:00:11 +0000 Subject: [PATCH 07/34] draft --- blackjax/adaptation/ensemble_mclmc.py | 352 +++++++++++++++---------- blackjax/adaptation/ensemble_umclmc.py | 319 ++++++++++++---------- blackjax/adaptation/step_size.py | 43 ++- blackjax/mcmc/mclmc.py | 4 +- blackjax/util.py | 200 ++++++++------ tests/mcmc/test_sampling.py | 102 +++++++ 6 files changed, 655 insertions(+), 365 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index b303f8ae7..c01791daa 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -11,142 +11,158 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-#"""Public API for the MCLMC Kernel""" +# """Public API for the MCLMC Kernel""" -from typing import Callable, NamedTuple, Any +from typing import Any, NamedTuple import jax import jax.numpy as jnp -from blackjax.util import run_eca -from blackjax.mcmc.integrators import generate_isokinetic_integrator, velocity_verlet_coefficients, mclachlan_coefficients, omelyan_coefficients -from blackjax.mcmc.hmc import HMCState -from blackjax.mcmc.adjusted_mclmc import build_kernel as build_kernel_malt import blackjax.adaptation.ensemble_umclmc as umclmc -from blackjax.adaptation.ensemble_umclmc import equipartition_diagonal, equipartition_diagonal_loss, equipartition_fullrank, equipartition_fullrank_loss - -from blackjax.adaptation.step_size import dual_averaging_adaptation, bisection_monotonic_fn +from blackjax.adaptation.ensemble_umclmc import ( + equipartition_diagonal, + equipartition_diagonal_loss, + equipartition_fullrank, + equipartition_fullrank_loss, +) +from blackjax.adaptation.step_size import bisection_monotonic_fn +from blackjax.mcmc.adjusted_mclmc import build_kernel as build_kernel_malt +from blackjax.mcmc.hmc import HMCState +from blackjax.mcmc.integrators import ( + generate_isokinetic_integrator, + mclachlan_coefficients, + omelyan_coefficients, +) +from blackjax.util import run_eca - class AdaptationState(NamedTuple): steps_per_sample: float step_size: float epsadap_state: Any sample_count: int - - -def build_kernel(logdensity_fn, integrator, inverse_mass_matrix): - """MCLMC kernel""" - - kernel = build_kernel_malt(logdensity_fn=logdensity_fn, integrator=integrator, inverse_mass_matrix= inverse_mass_matrix,) - - def sequential_kernel(key, state, adap): - return kernel(key, state, step_size= adap.step_size, num_integration_steps= adap.steps_per_sample, L_proposal_factor = 1.25,) - - return sequential_kernel +# put the arguments of build_kernel in a suitable order +build_kernel = lambda logdensity_fn, integrator, inverse_mass_matrix: lambda key, state, adap: build_kernel_malt( + logdensity_fn=logdensity_fn, + integrator=integrator, + inverse_mass_matrix=inverse_mass_matrix, +)(rng_key=key, state=state, step_size=adap.step_size, num_integration_steps=adap.steps_per_sample, L_proposal_factor=1.25) class Adaptation: - - def __init__(self, adap_state, num_adaptation_samples, - steps_per_sample, acc_prob_target= 0.8, - observables = lambda x: 0., - observables_for_bias = lambda x: 0., contract= lambda x: 0.): - - self.num_adaptation_samples= num_adaptation_samples + def __init__( + self, + adap_state, + num_adaptation_samples, + steps_per_sample, + acc_prob_target=0.8, + observables=lambda x: 0.0, + observables_for_bias=lambda x: 0.0, + contract=lambda x: 0.0, + ): + self.num_adaptation_samples = num_adaptation_samples self.observables = observables self.observables_for_bias = observables_for_bias self.contract = contract - - ### Determine the initial hyperparameters ### - - ## stepsize ## - #if we switched to the more accurate integrator we can use longer step size - #integrator_factor = jnp.sqrt(10.) if mclachlan else 1. + + # Determine the initial hyperparameters # + + # stepsize # + # if we switched to the more accurate integrator we can use longer step size + # integrator_factor = jnp.sqrt(10.) if mclachlan else 1. # Let's use the stepsize which will be optimal for the adjusted method. 
The energy variance after N steps scales as sigma^2 ~ N^2 eps^6 = eps^4 L^2 # In the adjusted method we want sigma^2 = 2 mu = 2 * 0.41 = 0.82 - # With the current eps, we had sigma^2 = EEVPD * d for N = 1. + # With the current eps, we had sigma^2 = EEVPD * d for N = 1. # Combining the two we have EEVPD * d / 0.82 = eps^6 / eps_new^4 L^2 - #adjustment_factor = jnp.power(0.82 / (num_dims * adap_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) - step_size = adap_state.step_size #* integrator_factor * adjustment_factor - - #steps_per_sample = (int)(jnp.max(jnp.array([Lfull / step_size, 1]))) - - ### Initialize the dual averaging adaptation ### - #da_init_fn, self.epsadap_update, _ = dual_averaging_adaptation(target= acc_prob_target) - #epsadap_state = da_init_fn(step_size) - - ### Initialize the bisection for finding the step size + # adjustment_factor = jnp.power(0.82 / (num_dims * adap_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) + step_size = adap_state.step_size # * integrator_factor * adjustment_factor + + # steps_per_sample = (int)(jnp.max(jnp.array([Lfull / step_size, 1]))) + + # Initialize the dual averaging adaptation # + # da_init_fn, self.epsadap_update, _ = dual_averaging_adaptation(target= acc_prob_target) + # epsadap_state = da_init_fn(step_size) + + # Initialize the bisection for finding the step size epsadap_state, self.epsadap_update = bisection_monotonic_fn(acc_prob_target) - - self.initial_state = AdaptationState(steps_per_sample, step_size, epsadap_state, 0) - - + + self.initial_state = AdaptationState( + steps_per_sample, step_size, epsadap_state, 0 + ) + def summary_statistics_fn(self, state, info, rng_key): - - return {'acceptance_probability': info.acceptance_rate, - #'inv_acceptance_probability': 1./info.acceptance_rate, - 'equipartition_diagonal': equipartition_diagonal(state), - 'equipartition_fullrank': equipartition_fullrank(state, rng_key), - 'observables': self.observables(state.position), - 'observables_for_bias': self.observables_for_bias(state.position) - } - + return { + "acceptance_probability": info.acceptance_rate, + "equipartition_diagonal": equipartition_diagonal(state), + "equipartition_fullrank": equipartition_fullrank(state, rng_key), + "observables": self.observables(state.position), + "observables_for_bias": self.observables_for_bias(state.position), + } def update(self, adaptation_state, Etheta): - # combine the expectation values to get useful scalars - acc_prob = Etheta['acceptance_probability'] - #acc_prob = 1./Etheta['inv_acceptance_probability'] - equi_diag = equipartition_diagonal_loss(Etheta['equipartition_diagonal']) - equi_full = equipartition_fullrank_loss(Etheta['equipartition_fullrank']) - true_bias = self.contract(Etheta['observables_for_bias']) - - - info_to_be_stored = {'L': adaptation_state.step_size * adaptation_state.steps_per_sample, 'steps_per_sample': adaptation_state.steps_per_sample, 'step_size': adaptation_state.step_size, - 'acc_prob': acc_prob, - 'equi_diag': equi_diag, 'equi_full': equi_full, 'bias': true_bias, - 'observables': Etheta['observables'] - } + acc_prob = Etheta["acceptance_probability"] + # acc_prob = 1./Etheta['inv_acceptance_probability'] + equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"]) + equi_full = equipartition_fullrank_loss(Etheta["equipartition_fullrank"]) + true_bias = self.contract(Etheta["observables_for_bias"]) + + info_to_be_stored = { + "L": adaptation_state.step_size * adaptation_state.steps_per_sample, + "steps_per_sample": adaptation_state.steps_per_sample, + 
"step_size": adaptation_state.step_size, + "acc_prob": acc_prob, + "equi_diag": equi_diag, + "equi_full": equi_full, + "bias": true_bias, + "observables": Etheta["observables"], + } # hyperparameter adaptation - + # Dual Averaging - # adaptation_phase = adaptation_state.sample_count < self.num_adaptation_samples - + # adaptation_phase = adaptation_state.sample_count < self.num_adaptation_samples + # def update(_): # da_state = self.epsadap_update(adaptation_state.epsadap_state, acc_prob) # step_size = jnp.exp(da_state.log_step_size) # return da_state, step_size - + # def dont_update(_): # da_state = adaptation_state.epsadap_state # return da_state, jnp.exp(da_state.log_step_size_avg) - + # epsadap_state, step_size = jax.lax.cond(adaptation_phase, update, dont_update, operand=None) - - # Bisection - epsadap_state, step_size = self.epsadap_update(adaptation_state.epsadap_state, adaptation_state.step_size, acc_prob) - - return AdaptationState(adaptation_state.steps_per_sample, step_size, epsadap_state, adaptation_state.sample_count + 1), info_to_be_stored + # Bisection + epsadap_state, step_size = self.epsadap_update( + adaptation_state.epsadap_state, adaptation_state.step_size, acc_prob + ) + + return ( + AdaptationState( + adaptation_state.steps_per_sample, + step_size, + epsadap_state, + adaptation_state.sample_count + 1, + ), + info_to_be_stored, + ) def bias(model): """should be transfered to benchmarks/""" - + def observables(position): return jnp.square(model.transform(position)) - + def contract(sampler_E_x2): bsq = jnp.square(sampler_E_x2 - model.E_x2) / model.Var_x2 return jnp.array([jnp.max(bsq), jnp.average(bsq)]) - - return observables, contract + return observables, contract def while_steps_num(cond): @@ -156,69 +172,141 @@ def while_steps_num(cond): return jnp.argmin(cond) + 1 -def emaus(model, num_steps1, num_steps2, num_chains, mesh, rng_key, - alpha= 1.9, bias_type= 0, save_frac= 0.2, C= 0.1, power= 3./8., early_stop= True, r_end= 5e-3,# stage1 parameters - diagonal_preconditioning= True, integrator_coefficients= None, steps_per_sample= 10, acc_prob= None, - observables = lambda x: None, - ensemble_observables= None - ): - +def emaus( + model, + num_steps1, + num_steps2, + num_chains, + mesh, + rng_key, + alpha=1.9, + bias_type=0, + save_frac=0.2, + C=0.1, + power=3.0 / 8.0, + early_stop=True, + r_end=5e-3, # stage1 parameters + diagonal_preconditioning=True, + integrator_coefficients=None, + steps_per_sample=10, + acc_prob=None, + observables=lambda x: None, + ensemble_observables=None, +): observables_for_bias, contract = bias(model) key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) - + # initialize the chains - initial_state = umclmc.initialize(key_init, model.logdensity_fn, model.sample_init, num_chains, mesh) - - ### burn-in with the unadjusted method ### + initial_state = umclmc.initialize( + key_init, model.logdensity_fn, model.sample_init, num_chains, mesh + ) + + # burn-in with the unadjusted method # kernel = umclmc.build_kernel(model.logdensity_fn) - save_num= (int)(jnp.rint(save_frac * num_steps1)) - adap = umclmc.Adaptation(model.ndims, alpha= alpha, bias_type= bias_type, save_num= save_num, C=C, power= power, r_end = r_end, - observables= observables, observables_for_bias= observables_for_bias, contract= contract) - final_state, final_adaptation_state, info1 = run_eca(key_umclmc, initial_state, kernel, adap, num_steps1, num_chains, mesh, ensemble_observables) - - if early_stop: # here I am cheating a bit, because I am not sure if it is possible to 
do a while loop in jax and save something at every step. Therefore I rerun burn-in with exactly the same parameters and stop at the point where the orignal while loop would have stopped. The release implementation should not have that. - - num_steps_while = while_steps_num((info1[0] if ensemble_observables != None else info1)['while_cond']) - #print(num_steps_while, save_num) - final_state, final_adaptation_state, info1 = run_eca(key_umclmc, initial_state, kernel, adap, num_steps_while, num_chains, mesh, ensemble_observables) - - ### refine the results with the adjusted method ### + save_num = (int)(jnp.rint(save_frac * num_steps1)) + adap = umclmc.Adaptation( + model.ndims, + alpha=alpha, + bias_type=bias_type, + save_num=save_num, + C=C, + power=power, + r_end=r_end, + observables=observables, + observables_for_bias=observables_for_bias, + contract=contract, + ) + final_state, final_adaptation_state, info1 = run_eca( + key_umclmc, + initial_state, + kernel, + adap, + num_steps1, + num_chains, + mesh, + ensemble_observables, + ) + + if ( + early_stop + ): # here I am cheating a bit, because I am not sure if it is possible to do a while loop in jax and save something at every step. Therefore I rerun burn-in with exactly the same parameters and stop at the point where the orignal while loop would have stopped. The release implementation should not have that. + num_steps_while = while_steps_num( + (info1[0] if ensemble_observables is not None else info1)["while_cond"] + ) + # print(num_steps_while, save_num) + final_state, final_adaptation_state, info1 = run_eca( + key_umclmc, + initial_state, + kernel, + adap, + num_steps_while, + num_chains, + mesh, + ensemble_observables, + ) + + # refine the results with the adjusted method # _acc_prob = acc_prob - if integrator_coefficients == None: + if integrator_coefficients is None: high_dims = model.ndims > 200 - _integrator_coefficients = omelyan_coefficients if high_dims else mclachlan_coefficients - if acc_prob == None: + _integrator_coefficients = ( + omelyan_coefficients if high_dims else mclachlan_coefficients + ) + if acc_prob is None: _acc_prob = 0.9 if high_dims else 0.7 - + else: _integrator_coefficients = integrator_coefficients - if acc_prob == None: + if acc_prob is None: _acc_prob = 0.9 - - + integrator = generate_isokinetic_integrator(_integrator_coefficients) - gradient_calls_per_step= len(_integrator_coefficients) // 2 #scheme = BABAB..AB scheme has len(scheme)//2 + 1 Bs. The last doesn't count because that gradient can be reused in the next step. + gradient_calls_per_step = ( + len(_integrator_coefficients) // 2 + ) # scheme = BABAB..AB scheme has len(scheme)//2 + 1 Bs. The last doesn't count because that gradient can be reused in the next step. if diagonal_preconditioning: - inverse_mass_matrix= jnp.sqrt(final_adaptation_state.inverse_mass_matrix) - + inverse_mass_matrix = jnp.sqrt(final_adaptation_state.inverse_mass_matrix) + # scale the stepsize so that it reflects averag scale change of the preconditioning average_scale_change = jnp.sqrt(jnp.average(inverse_mass_matrix)) - final_adaptation_state = final_adaptation_state._replace(step_size= final_adaptation_state.step_size / average_scale_change) + final_adaptation_state = final_adaptation_state._replace( + step_size=final_adaptation_state.step_size / average_scale_change + ) else: - inverse_mass_matrix= 1. 
- - kernel = build_kernel(model.logdensity_fn, integrator, inverse_mass_matrix= inverse_mass_matrix) - initial_state= HMCState(final_state.position, final_state.logdensity, final_state.logdensity_grad) + inverse_mass_matrix = 1.0 + + kernel = build_kernel( + model.logdensity_fn, integrator, inverse_mass_matrix=inverse_mass_matrix + ) + initial_state = HMCState( + final_state.position, final_state.logdensity, final_state.logdensity_grad + ) num_samples = num_steps2 // (gradient_calls_per_step * steps_per_sample) - num_adaptation_samples = num_samples//2 # number of samples after which the stepsize is fixed. - - adap = Adaptation(final_adaptation_state, num_adaptation_samples, steps_per_sample, _acc_prob, - observables= observables, observables_for_bias= observables_for_bias, contract= contract) - - final_state, final_adaptation_state, info2 = run_eca(key_mclmc, initial_state, kernel, adap, num_samples, num_chains, mesh, ensemble_observables) - + num_adaptation_samples = ( + num_samples // 2 + ) # number of samples after which the stepsize is fixed. + + adap = Adaptation( + final_adaptation_state, + num_adaptation_samples, + steps_per_sample, + _acc_prob, + observables=observables, + observables_for_bias=observables_for_bias, + contract=contract, + ) + + final_state, final_adaptation_state, info2 = run_eca( + key_mclmc, + initial_state, + kernel, + adap, + num_samples, + num_chains, + mesh, + ensemble_observables, + ) + return info1, info2, gradient_calls_per_step, _acc_prob - - \ No newline at end of file diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 458361103..7e2390def 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -1,4 +1,3 @@ - # Copyright 2020- The Blackjax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,22 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-#"""Public API for the MCLMC Kernel""" +# """Public API for the MCLMC Kernel""" + +from typing import Any, NamedTuple import jax import jax.numpy as jnp from jax.flatten_util import ravel_pytree -from typing import Callable, NamedTuple, Any -from blackjax.mcmc.integrators import IntegratorState, isokinetic_velocity_verlet -from blackjax.types import Array, ArrayLike -from blackjax.util import pytree_size from blackjax.mcmc import mclmc -from blackjax.mcmc.integrators import _normalized_flatten_array +from blackjax.mcmc.integrators import ( + IntegratorState, + _normalized_flatten_array, + isokinetic_velocity_verlet, +) +from blackjax.types import Array from blackjax.util import ensemble_execute_fn - def no_nans(a): flat_a, unravel_fn = ravel_pytree(a) return jnp.all(jnp.isfinite(flat_a)) @@ -35,79 +36,96 @@ def no_nans(a): def nan_reject(nonans, old, new): """Equivalent to - return new if nonans else old""" - - return jax.lax.cond(nonans, lambda _: new, lambda _: old, operand=None) + return new if nonans else old""" + return jax.lax.cond(nonans, lambda _: new, lambda _: old, operand=None) def build_kernel(logdensity_fn): """MCLMC kernel (with nan rejection)""" - - kernel = mclmc.build_kernel(logdensity_fn= logdensity_fn, integrator= isokinetic_velocity_verlet) - - + + kernel = mclmc.build_kernel( + logdensity_fn=logdensity_fn, integrator=isokinetic_velocity_verlet + ) + def sequential_kernel(key, state, adap): - new_state, info = kernel(key, state, adap.L, adap.step_size) - + # reject the new state if there were nans nonans = no_nans(new_state) new_state = nan_reject(nonans, state, new_state) - - return new_state, {'nans': 1-nonans, 'energy_change': info.energy_change * nonans, 'logdensity': info.logdensity * nonans} - + return new_state, { + "nans": 1 - nonans, + "energy_change": info.energy_change * nonans, + "logdensity": info.logdensity * nonans, + } + return sequential_kernel - def initialize(rng_key, logdensity_fn, sample_init, num_chains, mesh): """initialize the chains based on the equipartition of the initial condition. - We initialize the velocity along grad log p if E_ii > 1 and along -grad log p if E_ii < 1. + We initialize the velocity along grad log p if E_ii > 1 and along -grad log p if E_ii < 1. """ - + def sequential_init(key, x, args): """initialize the position using sample_init and the velocity along the gradient""" position = sample_init(key) logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) flat_g, unravel_fn = ravel_pytree(logdensity_grad) - velocity = unravel_fn(_normalized_flatten_array(flat_g)[0]) # = grad logp/ |grad logp| + velocity = unravel_fn( + _normalized_flatten_array(flat_g)[0] + ) # = grad logp/ |grad logp| return IntegratorState(position, velocity, logdensity, logdensity_grad), None - + def summary_statistics_fn(state): """compute the diagonal elements of the equipartition matrix""" return -state.position * state.logdensity_grad - + def ensemble_init(key, state, signs): """flip the velocity, depending on the equipartition condition""" - velocity = jax.tree_util.tree_map(lambda sign, u: sign * u, signs, state.momentum) - return IntegratorState(state.position, velocity, state.logdensity, state.logdensity_grad), None - - key1, key2= jax.random.split(rng_key) - initial_state, equipartition = ensemble_execute_fn(sequential_init, key1, num_chains, mesh, summary_statistics_fn= summary_statistics_fn) - signs = -2. * (equipartition < 1.) + 1. 
- initial_state, _ = ensemble_execute_fn(ensemble_init, key2, num_chains, mesh, x= initial_state, args= signs) - + velocity = jax.tree_util.tree_map( + lambda sign, u: sign * u, signs, state.momentum + ) + return ( + IntegratorState( + state.position, velocity, state.logdensity, state.logdensity_grad + ), + None, + ) + + key1, key2 = jax.random.split(rng_key) + initial_state, equipartition = ensemble_execute_fn( + sequential_init, + key1, + num_chains, + mesh, + summary_statistics_fn=summary_statistics_fn, + ) + signs = -2.0 * (equipartition < 1.0) + 1.0 + initial_state, _ = ensemble_execute_fn( + ensemble_init, key2, num_chains, mesh, x=initial_state, args=signs + ) + return initial_state - - + + def update_history(new_vals, history): return jnp.concatenate((new_vals[None, :], history[:-1])) + def update_history_scalar(new_val, history): return jnp.concatenate((new_val * jnp.ones(1), history[:-1])) + def contract_history(theta, weights): - - square_average = jnp.square(jnp.average(theta, weights= weights, axis= 0)) - average_square = jnp.average(jnp.square(theta), weights= weights, axis= 0) - + square_average = jnp.square(jnp.average(theta, weights=weights, axis=0)) + average_square = jnp.average(jnp.square(theta), weights=weights, axis=0) + r = (average_square - square_average) / square_average - - return jnp.array([jnp.max(r), - jnp.average(r) - ]) + + return jnp.array([jnp.max(r), jnp.average(r)]) class History(NamedTuple): @@ -117,44 +135,44 @@ class History(NamedTuple): class AdaptationState(NamedTuple): - L: float inverse_mass_matrix: Any step_size: float - + step_count: int EEVPD: float EEVPD_wanted: float - history: Any - + history: Any + def equipartition_diagonal(state): - """Ei = E_ensemble (- grad log p_i x_i ). Ei is 1 if we have converged. + """Ei = E_ensemble (- grad log p_i x_i ). Ei is 1 if we have converged. equipartition_loss = average over parameters (Ei)""" - return jax.tree_util.tree_map(lambda x, g: -x * g, state.position, state.logdensity_grad) - + return jax.tree_util.tree_map( + lambda x, g: -x * g, state.position, state.logdensity_grad + ) def equipartition_fullrank(state, rng_key): """loss = Tr[(1 - E)^T (1 - E)] / d^2 - where Eij = is the equipartition patrix. - Loss is computed with the Hutchinson's trick.""" + where Eij = is the equipartition patrix. + Loss is computed with the Hutchinson's trick.""" x, unravel_fn = ravel_pytree(state.position) g, unravel_fn = ravel_pytree(state.logdensity_grad) d = len(x) - + def func(z): """z here has the same shape as position""" - return (z + jnp.dot(z, g) * x) + return z + jnp.dot(z, g) * x - z = jax.random.rademacher(rng_key, (100, d)) # = delta_ij + z = jax.random.rademacher(rng_key, (100, d)) # = delta_ij return jax.vmap(func)(z) def equipartition_diagonal_loss(Eii): Eii_flat, unravel_fn = ravel_pytree(Eii) - return jnp.average(jnp.square(1.- Eii_flat)) + return jnp.average(jnp.square(1.0 - Eii_flat)) def equipartition_fullrank_loss(delta_z): @@ -163,103 +181,138 @@ def equipartition_fullrank_loss(delta_z): class Adaptation: - - def __init__(self, num_dims, - alpha= 1., C= 0.1, power = 3./8., r_end= 0.01, - bias_type= 0, save_num = 10, - observables= lambda x: 0., observables_for_bias= lambda x: 0., contract= lambda x: 0. 
- ): - + def __init__( + self, + num_dims, + alpha=1.0, + C=0.1, + power=3.0 / 8.0, + r_end=0.01, + bias_type=0, + save_num=10, + observables=lambda x: 0.0, + observables_for_bias=lambda x: 0.0, + contract=lambda x: 0.0, + ): self.num_dims = num_dims self.alpha = alpha self.C = C self.power = power self.r_end = r_end self.observables = observables - self.observables_for_bias = observables_for_bias + self.observables_for_bias = observables_for_bias self.contract = contract self.bias_type = bias_type self.save_num = save_num - #sigma = unravel_fn(jnp.ones(flat_pytree.shape, dtype = flat_pytree.dtype)) - + # sigma = unravel_fn(jnp.ones(flat_pytree.shape, dtype = flat_pytree.dtype)) + r_save_num = save_num - - history = History(observables= jnp.zeros((r_save_num, num_dims)), - stopping= jnp.full((save_num,), jnp.nan), - weights= jnp.zeros(r_save_num)) - - self.initial_state = AdaptationState(L= jnp.inf, # do not add noise for the first step - inverse_mass_matrix= jnp.ones(num_dims), - step_size= 0.01 * jnp.sqrt(num_dims), - step_count= 0, - EEVPD=1e-3, EEVPD_wanted=1e-3, - history=history) - - + + history = History( + observables=jnp.zeros((r_save_num, num_dims)), + stopping=jnp.full((save_num,), jnp.nan), + weights=jnp.zeros(r_save_num), + ) + + self.initial_state = AdaptationState( + L=jnp.inf, # do not add noise for the first step + inverse_mass_matrix=jnp.ones(num_dims), + step_size=0.01 * jnp.sqrt(num_dims), + step_count=0, + EEVPD=1e-3, + EEVPD_wanted=1e-3, + history=history, + ) + def summary_statistics_fn(self, state, info, rng_key): - position_flat, unravel_fn = ravel_pytree(state.position) - - return {'equipartition_diagonal': equipartition_diagonal(state), - 'equipartition_fullrank': equipartition_fullrank(state, rng_key), - 'x': position_flat, 'xsq': jnp.square(position_flat), - 'E': info['energy_change'], 'Esq': jnp.square(info['energy_change']), - 'rejection_rate_nans': info['nans'], - 'observables_for_bias': self.observables_for_bias(state.position), - 'observables': self.observables(state.position), - 'entropy': - info['logdensity'] - } - - + + return { + "equipartition_diagonal": equipartition_diagonal(state), + "equipartition_fullrank": equipartition_fullrank(state, rng_key), + "x": position_flat, + "xsq": jnp.square(position_flat), + "E": info["energy_change"], + "Esq": jnp.square(info["energy_change"]), + "rejection_rate_nans": info["nans"], + "observables_for_bias": self.observables_for_bias(state.position), + "observables": self.observables(state.position), + "entropy": -info["logdensity"], + } + def update(self, adaptation_state, Etheta): - # combine the expectation values to get useful scalars - equi_diag = equipartition_diagonal_loss(Etheta['equipartition_diagonal']) - equi_full = equipartition_fullrank_loss(Etheta['equipartition_fullrank']) - - history_observables = update_history(Etheta['observables_for_bias'], adaptation_state.history.observables) - history_weights = update_history_scalar(1., adaptation_state.history.weights) + equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"]) + equi_full = equipartition_fullrank_loss(Etheta["equipartition_fullrank"]) + + history_observables = update_history( + Etheta["observables_for_bias"], adaptation_state.history.observables + ) + history_weights = update_history_scalar(1.0, adaptation_state.history.weights) fluctuations = contract_history(history_observables, history_weights) - history_stopping = update_history_scalar(jax.lax.cond(adaptation_state.step_count > len(history_weights), lambda _: 
fluctuations[0], lambda _: jnp.nan, operand=None), - adaptation_state.history.stopping) + history_stopping = update_history_scalar( + jax.lax.cond( + adaptation_state.step_count > len(history_weights), + lambda _: fluctuations[0], + lambda _: jnp.nan, + operand=None, + ), + adaptation_state.history.stopping, + ) history = History(history_observables, history_stopping, history_weights) - - L = self.alpha * jnp.sqrt(jnp.sum(Etheta['xsq'] - jnp.square(Etheta['x']))) # average over the ensemble, sum over parameters (to get sqrt(d)) - inverse_mass_matrix = Etheta['xsq'] - jnp.square(Etheta['x']) - EEVPD = (Etheta['Esq'] - jnp.square(Etheta['E'])) / self.num_dims - true_bias = self.contract(Etheta['observables_for_bias']) - nans = (Etheta['rejection_rate_nans'] > 0.) #| (~jnp.isfinite(eps_factor)) + + L = self.alpha * jnp.sqrt( + jnp.sum(Etheta["xsq"] - jnp.square(Etheta["x"])) + ) # average over the ensemble, sum over parameters (to get sqrt(d)) + inverse_mass_matrix = Etheta["xsq"] - jnp.square(Etheta["x"]) + EEVPD = (Etheta["Esq"] - jnp.square(Etheta["E"])) / self.num_dims + true_bias = self.contract(Etheta["observables_for_bias"]) + nans = Etheta["rejection_rate_nans"] > 0.0 # | (~jnp.isfinite(eps_factor)) # hyperparameter adaptation # estimate bias - bias = jnp.array([fluctuations[0], fluctuations[1], equi_full, equi_diag])[self.bias_type] # r_max, r_avg, equi_full, equi_diag + bias = jnp.array([fluctuations[0], fluctuations[1], equi_full, equi_diag])[ + self.bias_type + ] # r_max, r_avg, equi_full, equi_diag EEVPD_wanted = self.C * jnp.power(bias, self.power) - - eps_factor = jnp.power(EEVPD_wanted / EEVPD, 1./6.) - eps_factor = jnp.clip(eps_factor, 0.3, 3.) - - eps_factor = nan_reject(1-nans, 0.5, eps_factor) # reduce the stepsize if there were nans + eps_factor = jnp.power(EEVPD_wanted / EEVPD, 1.0 / 6.0) + eps_factor = jnp.clip(eps_factor, 0.3, 3.0) + + eps_factor = nan_reject( + 1 - nans, 0.5, eps_factor + ) # reduce the stepsize if there were nans # determine if we want to finish this stage (i.e. if loss is no longer decreassing) - #increasing = history.stopping[0] > history.stopping[-1] # will be false if some elements of history are still nan (have not been filled yet). Do not be tempted to simply change to while_cond = history[0] < history[-1] - #while_cond = ~increasing - while_cond = (fluctuations[0] > self.r_end) | (adaptation_state.step_count < self.save_num) - - info_to_be_stored = {'L': adaptation_state.L, 'step_size': adaptation_state.step_size, - 'EEVPD_wanted': EEVPD_wanted, 'EEVPD': EEVPD, - 'equi_diag': equi_diag, 'equi_full': equi_full, 'bias': true_bias, - 'r_max': fluctuations[0], 'r_avg': fluctuations[1], - 'while_cond': while_cond, 'entropy': Etheta['entropy'], - 'observables': Etheta['observables']} - - adaptation_state_new = AdaptationState(L, - inverse_mass_matrix, - adaptation_state.step_size * eps_factor, - adaptation_state.step_count + 1, - EEVPD, - EEVPD_wanted, - history) - + # increasing = history.stopping[0] > history.stopping[-1] # will be false if some elements of history are still nan (have not been filled yet). 
Do not be tempted to simply change to while_cond = history[0] < history[-1] + # while_cond = ~increasing + while_cond = (fluctuations[0] > self.r_end) | ( + adaptation_state.step_count < self.save_num + ) + + info_to_be_stored = { + "L": adaptation_state.L, + "step_size": adaptation_state.step_size, + "EEVPD_wanted": EEVPD_wanted, + "EEVPD": EEVPD, + "equi_diag": equi_diag, + "equi_full": equi_full, + "bias": true_bias, + "r_max": fluctuations[0], + "r_avg": fluctuations[1], + "while_cond": while_cond, + "entropy": Etheta["entropy"], + "observables": Etheta["observables"], + } + + adaptation_state_new = AdaptationState( + L, + inverse_mass_matrix, + adaptation_state.step_size * eps_factor, + adaptation_state.step_count + 1, + EEVPD, + EEVPD_wanted, + history, + ) + return adaptation_state_new, info_to_be_stored - diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 7a076b962..94c634ce3 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -258,45 +258,44 @@ def update(rss_state: ReasonableStepSizeState) -> ReasonableStepSizeState: return rss_state.step_size -def bisection_monotonic_fn(acc_prob_wanted, reduce_shift = jnp.log(2.), tolerance= 0.03): + +def bisection_monotonic_fn(acc_prob_wanted, reduce_shift=jnp.log(2.0), tolerance=0.03): """Bisection of a monotonically decreassing function, that doesn't require an initially bracketing interval.""" - + def update(state, exp_x, acc_rate_new): - bounds, terminated = state - + # update the bounds acc_high = acc_rate_new > acc_prob_wanted x = jnp.log(exp_x) - + def on_true(bounds): bounds0 = jnp.max(jnp.array([bounds[0], x])) return jnp.array([bounds0, bounds[1]]), bounds0 + reduce_shift - + def on_false(bounds): bounds1 = jnp.min(jnp.array([bounds[1], x])) return jnp.array([bounds[0], bounds1]), bounds1 - reduce_shift - - - bounds_new, x_new = jax.lax.cond(acc_high, on_true, on_false, bounds) - - + + bounds_new, x_new = jax.lax.cond(acc_high, on_true, on_false, bounds) + # if we have already found a bracketing interval, do bisection, otherwise further reduce or increase the bounds bracketing = jnp.all(jnp.isfinite(bounds_new)) - - def reduce(bounds): + + def reduce(bounds): return x_new def bisect(bounds): return jnp.average(bounds) - + x_new = jax.lax.cond(bracketing, bisect, reduce, bounds_new) - - stepsize = terminated * exp_x + (1-terminated) * jnp.exp(x_new) - - terminated_new = (jnp.abs(acc_rate_new - acc_prob_wanted) < tolerance) | terminated - + + stepsize = terminated * exp_x + (1 - terminated) * jnp.exp(x_new) + + terminated_new = ( + jnp.abs(acc_rate_new - acc_prob_wanted) < tolerance + ) | terminated + return (bounds_new, terminated_new), stepsize - - - return (jnp.array([-jnp.inf, jnp.inf]), False), update \ No newline at end of file + + return (jnp.array([-jnp.inf, jnp.inf]), False), update diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index e5cc46213..2299dc68e 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -60,7 +60,9 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, inverse_mass_matrix=1.0, integrator=isokinetic_mclachlan): +def build_kernel( + logdensity_fn, inverse_mass_matrix=1.0, integrator=isokinetic_mclachlan +): """Build a HMC kernel. 
Parameters diff --git a/blackjax/util.py b/blackjax/util.py index 92eb77f40..e8c42f11f 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -4,14 +4,13 @@ from typing import Callable, Union import jax.numpy as jnp -from jax import jit, lax, device_put, vmap +from jax import device_put, jit, lax, vmap +from jax.experimental.shard_map import shard_map from jax.flatten_util import ravel_pytree from jax.random import normal, split +from jax.sharding import NamedSharding, PartitionSpec from jax.tree_util import tree_leaves, tree_map -from jax.sharding import Mesh, PartitionSpec, NamedSharding -from jax.experimental.shard_map import shard_map - from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -318,123 +317,170 @@ def incremental_value_update( total += weight return total, average -def eca_step(kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info= None): +def eca_step( + kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info=None +): def _step(state_all, xs): """This function operates on a single device.""" - state, adaptation_state = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. - _, keys_sampling, key_adaptation = xs # keys_sampling.shape = (chains_per_device, ) - + ( + state, + adaptation_state, + ) = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. + ( + _, + keys_sampling, + key_adaptation, + ) = xs # keys_sampling.shape = (chains_per_device, ) + # update the state of all chains on this device state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) - + # combine all the chains to compute expectation values theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) - Etheta = tree_map(lambda theta: lax.psum(jnp.sum(theta, axis= 0), axis_name= 'chains') / num_chains, theta) + Etheta = tree_map( + lambda theta: lax.psum(jnp.sum(theta, axis=0), axis_name="chains") + / num_chains, + theta, + ) # use these to adapt the hyperparameters of the dynamics - adaptation_state, info_to_be_stored = adaptation_update(adaptation_state, Etheta) - + adaptation_state, info_to_be_stored = adaptation_update( + adaptation_state, Etheta + ) + return (state, adaptation_state), info_to_be_stored - - - if ensemble_info != None: - + + if ensemble_info is not None: + def step(state_all, xs): (state, adaptation_state), info_to_be_stored = _step(state_all, xs) - return (state, adaptation_state), (info_to_be_stored, vmap(ensemble_info)(state.position)) - + return (state, adaptation_state), ( + info_to_be_stored, + vmap(ensemble_info)(state.position), + ) + return step else: return _step -def run_eca(rng_key, initial_state, kernel, adaptation, num_steps, num_chains, mesh, ensemble_info= None): - - step = eca_step(kernel, adaptation.summary_statistics_fn, adaptation.update, num_chains, ensemble_info) - +def run_eca( + rng_key, + initial_state, + kernel, + adaptation, + num_steps, + num_chains, + mesh, + ensemble_info=None, +): + step = eca_step( + kernel, + adaptation.summary_statistics_fn, + adaptation.update, + num_chains, + ensemble_info, + ) def all_steps(initial_state, keys_sampling, keys_adaptation): """This function operates on a single device. 
key is a random key for this device.""" - + initial_state_all = (initial_state, adaptation.initial_state) - + # run sampling - xs = (jnp.arange(num_steps), keys_sampling.T, keys_adaptation) # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) - + xs = ( + jnp.arange(num_steps), + keys_sampling.T, + keys_adaptation, + ) # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) + final_state_all, info_history = lax.scan(step, initial_state_all, xs) final_state, final_adaptation_state = final_state_all - return final_state, final_adaptation_state, info_history # info history is composed of averages over all chains, so it is a couple of scalars - + return ( + final_state, + final_adaptation_state, + info_history, + ) # info history is composed of averages over all chains, so it is a couple of scalars + + p, pscalar = PartitionSpec("chains"), PartitionSpec() + parallel_execute = shard_map( + all_steps, + mesh=mesh, + in_specs=(p, p, pscalar), + out_specs=(p, pscalar, pscalar), + check_rep=False, + ) - p, pscalar = PartitionSpec('chains'), PartitionSpec() - parallel_execute = shard_map(all_steps, - mesh= mesh, - in_specs= (p, p, pscalar), - out_specs= (p, pscalar, pscalar), - check_rep=False - ) - # produce all random keys that will be needed # rng_key = rng_key if not isinstance(rng_key, jnp.ndarray) else rng_key[0] key_sampling, key_adaptation = split(rng_key) num_steps = jnp.array(num_steps).item() keys_adaptation = split(key_adaptation, num_steps) - distribute_keys = lambda key, shape: device_put(split(key, shape), NamedSharding(mesh, p)) # random keys, distributed across devices + distribute_keys = lambda key, shape: device_put( + split(key, shape), NamedSharding(mesh, p) + ) # random keys, distributed across devices keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) - # run sampling in parallel - final_state, final_adaptation_state, info_history = parallel_execute(initial_state, keys_sampling, keys_adaptation) - - return final_state, final_adaptation_state, info_history + final_state, final_adaptation_state, info_history = parallel_execute( + initial_state, keys_sampling, keys_adaptation + ) + return final_state, final_adaptation_state, info_history +def ensemble_execute_fn( + func, + rng_key, + num_chains, + mesh, + x=None, + args=None, + summary_statistics_fn=lambda y: 0.0, +): + """Given a sequential function + func(rng_key, x, args) = y, + evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. + Args: + x: array distributed over all decvices + args: additional arguments for func, not distributed. + summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. + rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. -def ensemble_execute_fn(func, rng_key, num_chains, mesh, - x= None, - args= None, - summary_statistics_fn= lambda y: 0., - ): - """Given a sequential function - func(rng_key, x, args) = y, - evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. - Args: - x: array distributed over all decvices - args: additional arguments for func, not distributed. - summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. 
- rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. - - Returns: - y: array distributed over all decvices. Need not be of the same shape as x. - Etheta: expected values of the summary statistics + Returns: + y: array distributed over all decvices. Need not be of the same shape as x. + Etheta: expected values of the summary statistics """ - p, pscalar = PartitionSpec('chains'), PartitionSpec() - - if x == None: - X= device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) - else: - X= x - - adaptation_update= lambda _, Etheta: (Etheta, None) - - _F = eca_step(func, lambda y, info, key: summary_statistics_fn(y), adaptation_update, num_chains) + p, pscalar = PartitionSpec("chains"), PartitionSpec() + + if x is None: + X = device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) + else: + X = x + + adaptation_update = lambda _, Etheta: (Etheta, None) + + _F = eca_step( + func, + lambda y, info, key: summary_statistics_fn(y), + adaptation_update, + num_chains, + ) def F(x, keys): """This function operates on a single device. key is a random key for this device.""" y, summary_statistics = _F((x, args), (None, keys, None))[0] return y, summary_statistics - parallel_execute = shard_map(F, - mesh= mesh, - in_specs= (p, p), - out_specs= (p, pscalar), - check_rep=False - ) - - keys = device_put(split(rng_key, num_chains), NamedSharding(mesh, p)) # random keys, distributed across devices + parallel_execute = shard_map( + F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar), check_rep=False + ) + + keys = device_put( + split(rng_key, num_chains), NamedSharding(mesh, p) + ) # random keys, distributed across devices # apply F in parallel - return parallel_execute(X, keys) \ No newline at end of file + return parallel_execute(X, keys) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index e9068326e..5d3dece82 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -2,6 +2,7 @@ import functools import itertools +from blackjax.adaptation.ensemble_mclmc import emaus import chex import jax import jax.numpy as jnp @@ -284,6 +285,31 @@ def run_adjusted_mclmc( ) return out + + def run_emaus( + self, + initial_position, + logdensity_fn, + key, + num_steps, + diagonal_preconditioning, + ): + + mesh = jax.sharding.Mesh(jax.devices(), 'chains') + + from blackjax.mcmc.integrators import velocity_verlet_coefficients, mclachlan_coefficients, omelyan_coefficients + + + integrator_coefficients = mclachlan_coefficients + + info1, info2, grads_per_step, _acc_prob = emaus(logdensity_fn, num_steps1=1000, num_steps2=3000, num_chains=4000, mesh=mesh, rng_key=key, + alpha = 1.9, bias_type= 3, C= 0.1, power= 3./8., + early_stop=1, r_end= 1e-2, diagonal_preconditioning= diagonal_preconditioning, integrator_coefficients= integrator_coefficients, + steps_per_sample= 15, acc_prob= None, ensemble_observables= lambda x: x + #ensemble_observables = lambda x: vec @ x + ) # run the algorithm + + return info2[1].reshape(info2[1].shape[0]*info2[1].shape[1], info2[1].shape[2]) @parameterized.parameters( itertools.product( @@ -457,6 +483,41 @@ def test_adjusted_mclmc(self, diagonal_preconditioning): np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + + # TODO: add preconditioning + def test_emaus(self,): + """Test the MCLMC kernel.""" + + init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + x_data = 
jax.random.normal(init_key0, shape=(1000, 1)) + y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) + + logposterior_fn_ = functools.partial( + self.regression_logprob, x=x_data, preds=y_data + ) + logdensity_fn = lambda x: logposterior_fn_(**x) + + model = Banana() + + states = self.run_emaus( + initial_position={"coefs": 1.0, "log_scale": 1.0}, + logdensity_fn=model, + key=inference_key, + num_steps=10000, + diagonal_preconditioning=True, + ) + + # coefs_samples = states["coefs"][3000:] + # scale_samples = np.exp(states["log_scale"][3000:]) + + # samples = states[3000:] + + print((states**2).mean(axis=0), Banana().E_x2) + + np.testing.assert_allclose((states**2).mean(axis=0), Banana().E_x2, atol=1e-2) + + # np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) + # np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) def test_mclmc_preconditioning(self): class IllConditionedGaussian: @@ -1223,5 +1284,46 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal): ) + +#TODO: remove +class Banana(): + """Banana target fromm the Inference Gym""" + + def __init__(self, initialization= 'wide'): + self.name = 'Banana' + self.ndims = 2 + self.curvature = 0.03 + + self.transform = lambda x: x + self.E_x2 = jnp.array([100.0, 19.0]) #the first is analytic the second is by drawing 10^8 samples from the generative model. Relative accuracy is around 10^-5. + self.Var_x2 = jnp.array([20000.0, 4600.898]) + + if initialization == 'map': + self.sample_init = lambda key: jnp.array([0, -100.0 * self.curvature]) + elif initialization == 'posterior': + self.sample_init = lambda key: self.posterior_draw(key) + elif initialization == 'wide': + self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) * jnp.array([10.0, 5.0]) * 2 + else: + raise ValueError('initialization = '+initialization +' is not a valid option.') + + def logdensity_fn(self, x): + mu2 = self.curvature * (x[0] ** 2 - 100) + return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) + + def posterior_draw(self, key): + z = jax.random.normal(key, shape = (2, )) + x0 = 10.0 * z[0] + x1 = self.curvature * (x0 ** 2 - 100) + z[1] + return jnp.array([x0, x1]) + + def ground_truth(self): + x = jax.vmap(self.posterior_draw)(jax.random.split(jax.random.PRNGKey(0), 100000000)) + print(jnp.average(x, axis=0)) + print(jnp.average(jnp.square(x), axis=0)) + print(jnp.std(jnp.square(x[:, 0])) ** 2, jnp.std(jnp.square(x[:, 1])) ** 2) + + if __name__ == "__main__": absltest.main() + From 6972f235b5dbb553ff38276ac596fdc645a762d5 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 3 Feb 2025 13:44:45 -0500 Subject: [PATCH 08/34] mid cleanup --- blackjax/adaptation/ensemble_mclmc.py | 107 ++++++++++++------------ blackjax/adaptation/ensemble_umclmc.py | 2 + blackjax/mcmc/metrics.py | 9 +- blackjax/optimizers/lbfgs.py | 4 +- blackjax/sgmcmc/csgld.py | 1 + blackjax/sgmcmc/sgnht.py | 1 + blackjax/smc/tuning/from_kernel_info.py | 1 + blackjax/smc/tuning/from_particles.py | 1 + tests/adaptation/test_mass_matrix.py | 1 + tests/mcmc/test_sampling.py | 107 +++++++++++++++--------- tests/mcmc/test_trajectory.py | 1 + tests/mcmc/test_uturn.py | 1 + tests/optimizers/test_optimizers.py | 1 + tests/optimizers/test_pathfinder.py | 1 + tests/smc/test_resampling.py | 1 + tests/smc/test_smc.py | 1 + tests/smc/test_smc_ess.py | 1 + tests/smc/test_solver.py | 1 + tests/smc/test_tempered_smc.py | 9 +- tests/test_benchmarks.py | 1 + tests/test_compilation.py | 1 + tests/test_diagnostics.py | 1 + 22 files 
changed, 149 insertions(+), 105 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index c01791daa..786d55195 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -22,8 +22,6 @@ from blackjax.adaptation.ensemble_umclmc import ( equipartition_diagonal, equipartition_diagonal_loss, - equipartition_fullrank, - equipartition_fullrank_loss, ) from blackjax.adaptation.step_size import bisection_monotonic_fn from blackjax.mcmc.adjusted_mclmc import build_kernel as build_kernel_malt @@ -37,30 +35,38 @@ class AdaptationState(NamedTuple): + steps_per_sample: float step_size: float - epsadap_state: Any - sample_count: int + stepsize_adaptation_state: ( + Any # the state of the bisection algorithm to find a stepsize + ) + iteration: int + -# put the arguments of build_kernel in a suitable order build_kernel = lambda logdensity_fn, integrator, inverse_mass_matrix: lambda key, state, adap: build_kernel_malt( logdensity_fn=logdensity_fn, integrator=integrator, inverse_mass_matrix=inverse_mass_matrix, -)(rng_key=key, state=state, step_size=adap.step_size, num_integration_steps=adap.steps_per_sample, L_proposal_factor=1.25) - +)( + rng_key=key, + state=state, + step_size=adap.step_size, + num_integration_steps=adap.steps_per_sample, + L_proposal_factor=1.25, +) class Adaptation: def __init__( self, - adap_state, - num_adaptation_samples, - steps_per_sample, + adaptation_state, + num_adaptation_samples, # amount of tuning in the adjusted phase before fixing params + steps_per_sample, # L/eps (same for each chain: currently fixed to 15) acc_prob_target=0.8, - observables=lambda x: 0.0, - observables_for_bias=lambda x: 0.0, - contract=lambda x: 0.0, + observables=lambda x: 0.0, # just for diagnostics: some function of a given chain at given timestep + observables_for_bias=lambda x: 0.0, # just for diagnostics: the above, but averaged over all chains + contract=lambda x: 0.0, # just for diagnostics: observabiels for bias, contracted over dimensions ): self.num_adaptation_samples = num_adaptation_samples self.observables = observables @@ -76,27 +82,30 @@ def __init__( # In the adjusted method we want sigma^2 = 2 mu = 2 * 0.41 = 0.82 # With the current eps, we had sigma^2 = EEVPD * d for N = 1. 
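+        # (EEVPD is presumably the expected energy-error variance per dimension
+        #  estimated during the unadjusted phase, and d = num_dims.)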
# Combining the two we have EEVPD * d / 0.82 = eps^6 / eps_new^4 L^2 - # adjustment_factor = jnp.power(0.82 / (num_dims * adap_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) - step_size = adap_state.step_size # * integrator_factor * adjustment_factor + # adjustment_factor = jnp.power(0.82 / (num_dims * adaptation_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) + step_size = adaptation_state.step_size # * integrator_factor * adjustment_factor # steps_per_sample = (int)(jnp.max(jnp.array([Lfull / step_size, 1]))) # Initialize the dual averaging adaptation # # da_init_fn, self.epsadap_update, _ = dual_averaging_adaptation(target= acc_prob_target) - # epsadap_state = da_init_fn(step_size) + # stepsize_adaptation_state = da_init_fn(step_size) # Initialize the bisection for finding the step size - epsadap_state, self.epsadap_update = bisection_monotonic_fn(acc_prob_target) + stepsize_adaptation_state, self.epsadap_update = bisection_monotonic_fn( + acc_prob_target + ) self.initial_state = AdaptationState( - steps_per_sample, step_size, epsadap_state, 0 + steps_per_sample, step_size, stepsize_adaptation_state, 0 ) def summary_statistics_fn(self, state, info, rng_key): return { "acceptance_probability": info.acceptance_rate, - "equipartition_diagonal": equipartition_diagonal(state), - "equipartition_fullrank": equipartition_fullrank(state, rng_key), + "equipartition_diagonal": equipartition_diagonal( + state + ), # metric for bias: equipartition theorem gives todo... "observables": self.observables(state.position), "observables_for_bias": self.observables_for_bias(state.position), } @@ -106,8 +115,7 @@ def update(self, adaptation_state, Etheta): acc_prob = Etheta["acceptance_probability"] # acc_prob = 1./Etheta['inv_acceptance_probability'] equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"]) - equi_full = equipartition_fullrank_loss(Etheta["equipartition_fullrank"]) - true_bias = self.contract(Etheta["observables_for_bias"]) + true_bias = self.contract(Etheta["observables_for_bias"]) # remove info_to_be_stored = { "L": adaptation_state.step_size * adaptation_state.steps_per_sample, @@ -115,38 +123,23 @@ def update(self, adaptation_state, Etheta): "step_size": adaptation_state.step_size, "acc_prob": acc_prob, "equi_diag": equi_diag, - "equi_full": equi_full, "bias": true_bias, "observables": Etheta["observables"], } - # hyperparameter adaptation - - # Dual Averaging - # adaptation_phase = adaptation_state.sample_count < self.num_adaptation_samples - - # def update(_): - # da_state = self.epsadap_update(adaptation_state.epsadap_state, acc_prob) - # step_size = jnp.exp(da_state.log_step_size) - # return da_state, step_size - - # def dont_update(_): - # da_state = adaptation_state.epsadap_state - # return da_state, jnp.exp(da_state.log_step_size_avg) - - # epsadap_state, step_size = jax.lax.cond(adaptation_phase, update, dont_update, operand=None) - - # Bisection - epsadap_state, step_size = self.epsadap_update( - adaptation_state.epsadap_state, adaptation_state.step_size, acc_prob + # Bisection to find step size + stepsize_adaptation_state, step_size = self.epsadap_update( + adaptation_state.stepsize_adaptation_state, + adaptation_state.step_size, + acc_prob, ) return ( AdaptationState( adaptation_state.steps_per_sample, step_size, - epsadap_state, - adaptation_state.sample_count + 1, + stepsize_adaptation_state, + adaptation_state.iteration + 1, ), info_to_be_stored, ) @@ -174,24 +167,25 @@ def while_steps_num(cond): def emaus( model, - num_steps1, - num_steps2, + 
num_steps1, # max number in phase 1 + num_steps2, # fixed number in phase 2 num_chains, mesh, rng_key, - alpha=1.9, - bias_type=0, - save_frac=0.2, - C=0.1, - power=3.0 / 8.0, - early_stop=True, + alpha=1.9, # L = \sqrt{d}*\alpha*vars + bias_type=0, # eliminate (fix to diagonal rank) + save_frac=0.2, # to end stage one, the fraction of stage 1 samples used to estimate fluctuation. min is: save_frac*num_steps1 + C=0.1, # constant in stage 1 that determines step size (eq (9) in paper) + power=3.0 / 8.0, # eliminate + early_stop=True, # for stage 1 r_end=5e-3, # stage1 parameters diagonal_preconditioning=True, - integrator_coefficients=None, + integrator_coefficients=None, # (for stage 2) steps_per_sample=10, acc_prob=None, observables=lambda x: None, ensemble_observables=None, + diagnostics=True ): observables_for_bias, contract = bias(model) key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) @@ -309,4 +303,9 @@ def emaus( ensemble_observables, ) - return info1, info2, gradient_calls_per_step, _acc_prob + if diagnostics: + info = {"phase_1" : info1, "phase_2" : info2} + else: + info = None + + return info, gradient_calls_per_step, _acc_prob, final_state diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 7e2390def..d430b767e 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -128,6 +128,7 @@ def contract_history(theta, weights): return jnp.array([jnp.max(r), jnp.average(r)]) +# used for the early stopping class History(NamedTuple): observables: Array stopping: Array @@ -224,6 +225,7 @@ def __init__( history=history, ) + # info 1 def summary_statistics_fn(self, state, info, rng_key): position_flat, unravel_fn = ravel_pytree(state.position) diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py index f0720acf4..70e33d3a4 100644 --- a/blackjax/mcmc/metrics.py +++ b/blackjax/mcmc/metrics.py @@ -43,8 +43,7 @@ class KineticEnergy(Protocol): def __call__( self, momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None - ) -> Numeric: - ... + ) -> Numeric: ... class CheckTurning(Protocol): @@ -55,8 +54,7 @@ def __call__( momentum_sum: ArrayLikeTree, position_left: Optional[ArrayLikeTree] = None, position_right: Optional[ArrayLikeTree] = None, - ) -> bool: - ... + ) -> bool: ... class Scale(Protocol): @@ -67,8 +65,7 @@ def __call__( *, inv: bool, trans: bool, - ) -> ArrayLikeTree: - ... + ) -> ArrayLikeTree: ... class Metric(NamedTuple): diff --git a/blackjax/optimizers/lbfgs.py b/blackjax/optimizers/lbfgs.py index 0dd59f003..aef55200f 100644 --- a/blackjax/optimizers/lbfgs.py +++ b/blackjax/optimizers/lbfgs.py @@ -269,9 +269,7 @@ def compute_next_alpha(s_l, z_l, alpha_lm1): b = z_l.T @ s_l c = s_l.T @ jnp.diag(1.0 / alpha_lm1) @ s_l inv_alpha_l = ( - a / (b * alpha_lm1) - + z_l**2 / b - - (a * s_l**2) / (b * c * alpha_lm1**2) + a / (b * alpha_lm1) + z_l**2 / b - (a * s_l**2) / (b * c * alpha_lm1**2) ) return 1.0 / inv_alpha_l diff --git a/blackjax/sgmcmc/csgld.py b/blackjax/sgmcmc/csgld.py index 506740c50..02766ca3d 100644 --- a/blackjax/sgmcmc/csgld.py +++ b/blackjax/sgmcmc/csgld.py @@ -41,6 +41,7 @@ class ContourSGLDState(NamedTuple): Index `i` such that the current position belongs to :math:`S_i`. 
""" + position: ArrayTree energy_pdf: Array energy_idx: int diff --git a/blackjax/sgmcmc/sgnht.py b/blackjax/sgmcmc/sgnht.py index ad9547406..7bcb2ccef 100644 --- a/blackjax/sgmcmc/sgnht.py +++ b/blackjax/sgmcmc/sgnht.py @@ -35,6 +35,7 @@ class SGNHTState(NamedTuple): Scalar thermostat controlling kinetic energy. """ + position: ArrayTree momentum: ArrayTree xi: float diff --git a/blackjax/smc/tuning/from_kernel_info.py b/blackjax/smc/tuning/from_kernel_info.py index a039e66c1..5725cc363 100644 --- a/blackjax/smc/tuning/from_kernel_info.py +++ b/blackjax/smc/tuning/from_kernel_info.py @@ -2,6 +2,7 @@ strategies to tune the parameters of mcmc kernels used within smc, based on MCMC states """ + import jax import jax.numpy as jnp diff --git a/blackjax/smc/tuning/from_particles.py b/blackjax/smc/tuning/from_particles.py index 4c8ca98da..2d0b737fa 100755 --- a/blackjax/smc/tuning/from_particles.py +++ b/blackjax/smc/tuning/from_particles.py @@ -2,6 +2,7 @@ strategies to tune the parameters of mcmc kernels used within SMC, based on particles. """ + import jax import jax.numpy as jnp from jax._src.flatten_util import ravel_pytree diff --git a/tests/adaptation/test_mass_matrix.py b/tests/adaptation/test_mass_matrix.py index 622b2111c..97d6ea882 100644 --- a/tests/adaptation/test_mass_matrix.py +++ b/tests/adaptation/test_mass_matrix.py @@ -1,4 +1,5 @@ """Test the welford adaptation algorithm.""" + import itertools import chex diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 5d3dece82..7d600e14c 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -1,4 +1,5 @@ """Test the accuracy of the MCMC kernels.""" + import functools import itertools @@ -285,31 +286,50 @@ def run_adjusted_mclmc( ) return out - + def run_emaus( - self, - initial_position, - logdensity_fn, - key, - num_steps, - diagonal_preconditioning, - ): + self, + initial_position, + logdensity_fn, + key, + num_steps, + diagonal_preconditioning, + ): - mesh = jax.sharding.Mesh(jax.devices(), 'chains') + mesh = jax.sharding.Mesh(jax.devices(), "chains") - from blackjax.mcmc.integrators import velocity_verlet_coefficients, mclachlan_coefficients, omelyan_coefficients + from blackjax.mcmc.integrators import ( + velocity_verlet_coefficients, + mclachlan_coefficients, + omelyan_coefficients, + ) + integrator_coefficients = mclachlan_coefficients - integrator_coefficients = mclachlan_coefficients + info1, info2, grads_per_step, _acc_prob = emaus( + logdensity_fn, + num_steps1=1000, + num_steps2=3000, + num_chains=4000, + mesh=mesh, + rng_key=key, + alpha=1.9, + bias_type=3, + C=0.1, + power=3.0 / 8.0, + early_stop=1, + r_end=1e-2, + diagonal_preconditioning=diagonal_preconditioning, + integrator_coefficients=integrator_coefficients, + steps_per_sample=15, + acc_prob=None, + ensemble_observables=lambda x: x, + # ensemble_observables = lambda x: vec @ x + ) # run the algorithm - info1, info2, grads_per_step, _acc_prob = emaus(logdensity_fn, num_steps1=1000, num_steps2=3000, num_chains=4000, mesh=mesh, rng_key=key, - alpha = 1.9, bias_type= 3, C= 0.1, power= 3./8., - early_stop=1, r_end= 1e-2, diagonal_preconditioning= diagonal_preconditioning, integrator_coefficients= integrator_coefficients, - steps_per_sample= 15, acc_prob= None, ensemble_observables= lambda x: x - #ensemble_observables = lambda x: vec @ x - ) # run the algorithm - - return info2[1].reshape(info2[1].shape[0]*info2[1].shape[1], info2[1].shape[2]) + return info2[1].reshape( + info2[1].shape[0] * info2[1].shape[1], 
info2[1].shape[2] + ) @parameterized.parameters( itertools.product( @@ -483,9 +503,11 @@ def test_adjusted_mclmc(self, diagonal_preconditioning): np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) - + # TODO: add preconditioning - def test_emaus(self,): + def test_emaus( + self, + ): """Test the MCLMC kernel.""" init_key0, init_key1, inference_key = jax.random.split(self.key, 3) @@ -593,8 +615,7 @@ def get_inverse_mass_matrix(): assert ( jnp.abs( jnp.dot( - (inverse_mass_matrix**2) - / jnp.linalg.norm(inverse_mass_matrix**2), + (inverse_mass_matrix**2) / jnp.linalg.norm(inverse_mass_matrix**2), eigs / jnp.linalg.norm(eigs), ) - 1 @@ -1284,41 +1305,50 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal): ) - -#TODO: remove -class Banana(): +# TODO: remove +class Banana: """Banana target fromm the Inference Gym""" - def __init__(self, initialization= 'wide'): - self.name = 'Banana' + def __init__(self, initialization="wide"): + self.name = "Banana" self.ndims = 2 self.curvature = 0.03 - + self.transform = lambda x: x - self.E_x2 = jnp.array([100.0, 19.0]) #the first is analytic the second is by drawing 10^8 samples from the generative model. Relative accuracy is around 10^-5. + self.E_x2 = jnp.array( + [100.0, 19.0] + ) # the first is analytic the second is by drawing 10^8 samples from the generative model. Relative accuracy is around 10^-5. self.Var_x2 = jnp.array([20000.0, 4600.898]) - if initialization == 'map': + if initialization == "map": self.sample_init = lambda key: jnp.array([0, -100.0 * self.curvature]) - elif initialization == 'posterior': + elif initialization == "posterior": self.sample_init = lambda key: self.posterior_draw(key) - elif initialization == 'wide': - self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) * jnp.array([10.0, 5.0]) * 2 + elif initialization == "wide": + self.sample_init = ( + lambda key: jax.random.normal(key, shape=(self.ndims,)) + * jnp.array([10.0, 5.0]) + * 2 + ) else: - raise ValueError('initialization = '+initialization +' is not a valid option.') + raise ValueError( + "initialization = " + initialization + " is not a valid option." 
+ ) def logdensity_fn(self, x): mu2 = self.curvature * (x[0] ** 2 - 100) return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) def posterior_draw(self, key): - z = jax.random.normal(key, shape = (2, )) + z = jax.random.normal(key, shape=(2,)) x0 = 10.0 * z[0] - x1 = self.curvature * (x0 ** 2 - 100) + z[1] + x1 = self.curvature * (x0**2 - 100) + z[1] return jnp.array([x0, x1]) def ground_truth(self): - x = jax.vmap(self.posterior_draw)(jax.random.split(jax.random.PRNGKey(0), 100000000)) + x = jax.vmap(self.posterior_draw)( + jax.random.split(jax.random.PRNGKey(0), 100000000) + ) print(jnp.average(x, axis=0)) print(jnp.average(jnp.square(x), axis=0)) print(jnp.std(jnp.square(x[:, 0])) ** 2, jnp.std(jnp.square(x[:, 1])) ** 2) @@ -1326,4 +1356,3 @@ def ground_truth(self): if __name__ == "__main__": absltest.main() - diff --git a/tests/mcmc/test_trajectory.py b/tests/mcmc/test_trajectory.py index e93280400..bc8490b19 100644 --- a/tests/mcmc/test_trajectory.py +++ b/tests/mcmc/test_trajectory.py @@ -1,4 +1,5 @@ """Test the trajectory integration""" + import chex import jax import jax.numpy as jnp diff --git a/tests/mcmc/test_uturn.py b/tests/mcmc/test_uturn.py index 7f9f597d6..ff1f261f6 100644 --- a/tests/mcmc/test_uturn.py +++ b/tests/mcmc/test_uturn.py @@ -1,4 +1,5 @@ """Test the iterative u-turn criterion.""" + import chex import jax.numpy as jnp from absl.testing import absltest, parameterized diff --git a/tests/optimizers/test_optimizers.py b/tests/optimizers/test_optimizers.py index a7549842f..47f437af2 100644 --- a/tests/optimizers/test_optimizers.py +++ b/tests/optimizers/test_optimizers.py @@ -1,4 +1,5 @@ """Test optimizers.""" + import functools import chex diff --git a/tests/optimizers/test_pathfinder.py b/tests/optimizers/test_pathfinder.py index b9b9c69be..f40e79410 100644 --- a/tests/optimizers/test_pathfinder.py +++ b/tests/optimizers/test_pathfinder.py @@ -1,4 +1,5 @@ """Test the pathfinder algorithm.""" + import functools import chex diff --git a/tests/smc/test_resampling.py b/tests/smc/test_resampling.py index 20cb0d813..e6570f8f6 100644 --- a/tests/smc/test_resampling.py +++ b/tests/smc/test_resampling.py @@ -1,4 +1,5 @@ """Test the resampling functions for SMC.""" + import itertools import chex diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index b0e86e0b0..5c5e73259 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -1,4 +1,5 @@ """Test the generic SMC sampler""" + import functools import chex diff --git a/tests/smc/test_smc_ess.py b/tests/smc/test_smc_ess.py index 570d392d9..1f02b8d61 100644 --- a/tests/smc/test_smc_ess.py +++ b/tests/smc/test_smc_ess.py @@ -1,4 +1,5 @@ """Test the ess function""" + import functools import chex diff --git a/tests/smc/test_solver.py b/tests/smc/test_solver.py index 49db84129..8bcdd6a07 100644 --- a/tests/smc/test_solver.py +++ b/tests/smc/test_solver.py @@ -1,4 +1,5 @@ """Test the solving functions""" + import itertools import chex diff --git a/tests/smc/test_tempered_smc.py b/tests/smc/test_tempered_smc.py index 527457d62..ef8b8cb08 100644 --- a/tests/smc/test_tempered_smc.py +++ b/tests/smc/test_tempered_smc.py @@ -1,4 +1,5 @@ """Test the tempered SMC steps and routine""" + import functools import chex @@ -79,9 +80,11 @@ def logprior_fn(x): base_params, jax.tree.map(lambda x: jnp.repeat(x, num_particles, axis=0), base_params), jax.tree_util.tree_map_with_path( - lambda path, x: jnp.repeat(x, num_particles, axis=0) - if path[0].key == "step_size" - else x, + lambda path, x: ( + jnp.repeat(x, 
num_particles, axis=0) + if path[0].key == "step_size" + else x + ), base_params, ), ] diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index 2d108a48d..ea9c6aa66 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -5,6 +5,7 @@ obviously more models. It should also be run in CI. """ + import functools import jax diff --git a/tests/test_compilation.py b/tests/test_compilation.py index 7179b71ba..2d3b03a44 100644 --- a/tests/test_compilation.py +++ b/tests/test_compilation.py @@ -5,6 +5,7 @@ internal changes do not trigger more compilations than is necessary. """ + import chex import jax import jax.numpy as jnp diff --git a/tests/test_diagnostics.py b/tests/test_diagnostics.py index b583c8645..1d7c74846 100644 --- a/tests/test_diagnostics.py +++ b/tests/test_diagnostics.py @@ -1,4 +1,5 @@ """Test MCMC diagnostics.""" + import functools import itertools From c96a8e82c1a2d170f47c01ab30d669ab65fe35d2 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 4 Feb 2025 13:16:03 -0500 Subject: [PATCH 09/34] fix while loop --- .../adaptation/adjusted_mclmc_adaptation.py | 2 +- blackjax/adaptation/ensemble_mclmc.py | 50 ++++++++++--------- blackjax/util.py | 29 ++++++++++- 3 files changed, 55 insertions(+), 26 deletions(-) diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index eabb642a3..8c9fafc60 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -100,7 +100,7 @@ def adjusted_mclmc_find_L_and_step_size( state, params = adjusted_mclmc_make_adaptation_L( mclmc_kernel, frac=frac_tune3, - Lfactor=0.5, + Lfactor=0.3, max=max, eigenvector=eigenvector, )(state, params, num_steps, part2_key1) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 786d55195..95f2a9456 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -111,9 +111,7 @@ def summary_statistics_fn(self, state, info, rng_key): } def update(self, adaptation_state, Etheta): - # combine the expectation values to get useful scalars acc_prob = Etheta["acceptance_probability"] - # acc_prob = 1./Etheta['inv_acceptance_probability'] equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"]) true_bias = self.contract(Etheta["observables_for_bias"]) # remove @@ -173,10 +171,8 @@ def emaus( mesh, rng_key, alpha=1.9, # L = \sqrt{d}*\alpha*vars - bias_type=0, # eliminate (fix to diagonal rank) save_frac=0.2, # to end stage one, the fraction of stage 1 samples used to estimate fluctuation. 
min is: save_frac*num_steps1 C=0.1, # constant in stage 1 that determines step size (eq (9) in paper) - power=3.0 / 8.0, # eliminate early_stop=True, # for stage 1 r_end=5e-3, # stage1 parameters diagonal_preconditioning=True, @@ -187,6 +183,28 @@ def emaus( ensemble_observables=None, diagnostics=True ): + + """ + model: the target density object + num_steps1: number of steps in the first phase + num_steps2: number of steps in the second phase + num_chains: number of chains + mesh: the mesh object, used for distributing the computation across cpus and nodes + rng_key: the random key + alpha: L = \sqrt{d}*\alpha*variances + save_frac: the fraction of samples used to estimate the fluctuation in the first phase + C: constant in stage 1 that determines step size (eq (9) of EMAUS paper) + early_stop: whether to stop the first phase early + r_end + diagonal_preconditioning: whether to use diagonal preconditioning + integrator_coefficients: the coefficients of the integrator + steps_per_sample: the number of steps per sample + acc_prob: the acceptance probability + observables: the observables (for diagnostic use) + ensemble_observables: observable calculated over the ensemble (for diagnostic use) + diagnostics: whether to return diagnostics + """ + observables_for_bias, contract = bias(model) key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) @@ -201,15 +219,16 @@ def emaus( adap = umclmc.Adaptation( model.ndims, alpha=alpha, - bias_type=bias_type, + bias_type=3, save_num=save_num, C=C, - power=power, + power=3.0 / 8.0, r_end=r_end, observables=observables, observables_for_bias=observables_for_bias, contract=contract, ) + final_state, final_adaptation_state, info1 = run_eca( key_umclmc, initial_state, @@ -219,26 +238,9 @@ def emaus( num_chains, mesh, ensemble_observables, + early_stop=early_stop, ) - if ( - early_stop - ): # here I am cheating a bit, because I am not sure if it is possible to do a while loop in jax and save something at every step. Therefore I rerun burn-in with exactly the same parameters and stop at the point where the orignal while loop would have stopped. The release implementation should not have that. - num_steps_while = while_steps_num( - (info1[0] if ensemble_observables is not None else info1)["while_cond"] - ) - # print(num_steps_while, save_num) - final_state, final_adaptation_state, info1 = run_eca( - key_umclmc, - initial_state, - kernel, - adap, - num_steps_while, - num_chains, - mesh, - ensemble_observables, - ) - # refine the results with the adjusted method # _acc_prob = acc_prob if integrator_coefficients is None: diff --git a/blackjax/util.py b/blackjax/util.py index e8c42f11f..53543f662 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -11,6 +11,8 @@ from jax.sharding import NamedSharding, PartitionSpec from jax.tree_util import tree_leaves, tree_map + +import jax from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -375,6 +377,7 @@ def run_eca( num_chains, mesh, ensemble_info=None, + early_stop=False, ): step = eca_step( kernel, @@ -396,7 +399,31 @@ def all_steps(initial_state, keys_sampling, keys_adaptation): keys_adaptation, ) # keys for all steps that will be performed. 
keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) - final_state_all, info_history = lax.scan(step, initial_state_all, xs) + # ((a, Int) -> (a, Int)) + def step_while(a): + x, i, _ = a + + auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) + + # output, info = step(x, (jnp.arange(num_steps)[0],keys_sampling.T[0],keys_adaptation[0])) + output, info = step(x,auxilliary_input) + + + # jax.debug.print("info {x}", x=info[0].get("while_cond")) + # jax.debug.print("info {x}", x=i) + + return (output, i + 1, info[0].get("while_cond")) + + # jax.debug.print("initial {x}", x=0) + if early_stop: + final_state_all, i, _ = lax.while_loop( + lambda a: ((a[1] < num_steps) & a[2] ), step_while, (initial_state_all, 0, True) + ) + info_history = None + + else: + final_state_all, info_history = lax.scan(step, initial_state_all, xs) + final_state, final_adaptation_state = final_state_all return ( final_state, From 805113a2716a41e6e318a03580be5a04a33431fa Mon Sep 17 00:00:00 2001 From: = Date: Thu, 6 Feb 2025 17:55:41 -0500 Subject: [PATCH 10/34] test passes --- blackjax/adaptation/ensemble_mclmc.py | 46 +++++++++++------ blackjax/adaptation/ensemble_umclmc.py | 36 +++++++++++-- blackjax/mcmc/integrators.py | 12 +++++ blackjax/mcmc/mclmc.py | 5 ++ blackjax/util.py | 14 +++++ tests/mcmc/test_sampling.py | 71 +++++++++++++++++--------- 6 files changed, 139 insertions(+), 45 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 95f2a9456..5b95aacbd 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -164,13 +164,16 @@ def while_steps_num(cond): def emaus( - model, + logdensity_fn, + sample_init, + transform, + ndims, num_steps1, # max number in phase 1 num_steps2, # fixed number in phase 2 num_chains, mesh, rng_key, - alpha=1.9, # L = \sqrt{d}*\alpha*vars + alpha=1.9, # L = sqrt{d}*alpha*vars save_frac=0.2, # to end stage one, the fraction of stage 1 samples used to estimate fluctuation. 
min is: save_frac*num_steps1 C=0.1, # constant in stage 1 that determines step size (eq (9) in paper) early_stop=True, # for stage 1 @@ -183,7 +186,6 @@ def emaus( ensemble_observables=None, diagnostics=True ): - """ model: the target density object num_steps1: number of steps in the first phase @@ -191,7 +193,7 @@ def emaus( num_chains: number of chains mesh: the mesh object, used for distributing the computation across cpus and nodes rng_key: the random key - alpha: L = \sqrt{d}*\alpha*variances + alpha: L = sqrt{d}*alpha*variances save_frac: the fraction of samples used to estimate the fluctuation in the first phase C: constant in stage 1 that determines step size (eq (9) of EMAUS paper) early_stop: whether to stop the first phase early @@ -205,29 +207,35 @@ def emaus( diagnostics: whether to return diagnostics """ - observables_for_bias, contract = bias(model) + # observables_for_bias, contract = bias(model) key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) # initialize the chains initial_state = umclmc.initialize( - key_init, model.logdensity_fn, model.sample_init, num_chains, mesh + key_init, logdensity_fn, sample_init, num_chains, mesh ) + + # jax.debug.print("{x} foo", x=jax.flatten_util.ravel_pytree(initial_state.position)[0].shape[-1]) + ndims = 2 + # burn-in with the unadjusted method # - kernel = umclmc.build_kernel(model.logdensity_fn) + kernel = umclmc.build_kernel(logdensity_fn) save_num = (int)(jnp.rint(save_frac * num_steps1)) adap = umclmc.Adaptation( - model.ndims, + ndims, alpha=alpha, bias_type=3, save_num=save_num, C=C, power=3.0 / 8.0, r_end=r_end, - observables=observables, - observables_for_bias=observables_for_bias, - contract=contract, + # observables=observables, + observables_for_bias=lambda position: jnp.square(transform(jax.flatten_util.ravel_pytree(position)[0])), + # contract=contract, ) + + # jax.debug.print("initial_state.momentum: {x}", x=initial_state.momentum) final_state, final_adaptation_state, info1 = run_eca( key_umclmc, @@ -241,10 +249,13 @@ def emaus( early_stop=early_stop, ) + # print(final_state.position['coefs'].shape, "\n\nfoo\n\n") + # jax.debug.print("final_state.position: {x}", x=jnp.mean(final_state.position['coefs'])) + # refine the results with the adjusted method # _acc_prob = acc_prob if integrator_coefficients is None: - high_dims = model.ndims > 200 + high_dims = ndims > 200 _integrator_coefficients = ( omelyan_coefficients if high_dims else mclachlan_coefficients ) @@ -274,8 +285,11 @@ def emaus( inverse_mass_matrix = 1.0 kernel = build_kernel( - model.logdensity_fn, integrator, inverse_mass_matrix=inverse_mass_matrix + logdensity_fn, integrator, inverse_mass_matrix=inverse_mass_matrix ) + + + initial_state = HMCState( final_state.position, final_state.logdensity, final_state.logdensity_grad ) @@ -289,9 +303,9 @@ def emaus( num_adaptation_samples, steps_per_sample, _acc_prob, - observables=observables, - observables_for_bias=observables_for_bias, - contract=contract, + # observables=observables, + # observables_for_bias=observables_for_bias, + # contract=contract, ) final_state, final_adaptation_state, info2 = run_eca( diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index d430b767e..4e124421b 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -77,17 +77,33 @@ def sequential_init(key, x, args): velocity = unravel_fn( _normalized_flatten_array(flat_g)[0] ) # = grad logp/ |grad logp| + + jax.debug.print("logdensity {x}", 
x=logdensity_fn(position)) + # jax.debug.print("velocity {x}", x=velocity) + jax.debug.print("position {x}", x=position) + # jax.debug.print("logdensity_grad {x}", x=logdensity_grad) + # jax.debug.print("logdensity {x}", x=logdensity) + # jax.debug.print("flat_g {x}", x=flat_g) return IntegratorState(position, velocity, logdensity, logdensity_grad), None def summary_statistics_fn(state): """compute the diagonal elements of the equipartition matrix""" - return -state.position * state.logdensity_grad + return 0 # -state.position * state.logdensity_grad + # TODO: restore! def ensemble_init(key, state, signs): """flip the velocity, depending on the equipartition condition""" - velocity = jax.tree_util.tree_map( - lambda sign, u: sign * u, signs, state.momentum + # velocity = jax.tree_util.tree_map( + # lambda sign, u: sign * u, signs, state.momentum + # ) + momentum, unflatten = jax.flatten_util.ravel_pytree(state.momentum) + + velocity_flat = jax.tree_util.tree_map( + lambda sign, u: sign*u, signs, momentum ) + + velocity = unflatten(velocity_flat) + return ( IntegratorState( state.position, velocity, state.logdensity, state.logdensity_grad @@ -103,6 +119,9 @@ def ensemble_init(key, state, signs): mesh, summary_statistics_fn=summary_statistics_fn, ) + + # jax.debug.print("initial_state {x}", x=initial_state.momentum) + signs = -2.0 * (equipartition < 1.0) + 1.0 initial_state, _ = ensemble_execute_fn( ensemble_init, key2, num_chains, mesh, x=initial_state, args=signs @@ -112,7 +131,14 @@ def ensemble_init(key, state, signs): def update_history(new_vals, history): + # new_vals = jax.flatten_util.ravel_pytree(new_vals)[0] + # history = jax.flatten_util.ravel_pytree(history)[0] + # print(new_vals, "FOOO\n\n") + + new_vals, _ = jax.flatten_util.ravel_pytree(new_vals) + # print(history, "FOOO\n\n") return jnp.concatenate((new_vals[None, :], history[:-1])) + # return history # TODO CHANGE BACK!!!! 
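+# A small illustration, not part of the algorithm: update_history keeps a
+# fixed-length rolling buffer, writing the newest ensemble average into row 0
+# and dropping the oldest row, e.g. (shapes assumed for the sketch)
+#
+#     history = jnp.zeros((save_num, ndims))
+#     history = update_history(new_vals, history)  # new_vals now sits in history[0]
+#
+# contract_history later reduces this buffer to the fluctuation estimate used
+# for the early-stopping decision in the unadjusted phase.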
def update_history_scalar(new_val, history): @@ -192,7 +218,7 @@ def __init__( bias_type=0, save_num=10, observables=lambda x: 0.0, - observables_for_bias=lambda x: 0.0, + observables_for_bias=lambda x: x, contract=lambda x: 0.0, ): self.num_dims = num_dims @@ -250,6 +276,8 @@ def update(self, adaptation_state, Etheta): history_observables = update_history( Etheta["observables_for_bias"], adaptation_state.history.observables ) + # history_observables = adaptation_state.history.observables + history_weights = update_history_scalar(1.0, adaptation_state.history.weights) fluctuations = contract_history(history_observables, history_weights) history_stopping = update_history_scalar( diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 733e7e960..1a5711add 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -169,6 +169,7 @@ def update( position, kinetic_grad, ) + # jax.debug.print("new_position {x}", x=new_position) logdensity, logdensity_grad = logdensity_and_grad_fn(new_position) return new_position, logdensity, logdensity_grad, None @@ -330,6 +331,7 @@ def update( """ del is_last_call + # jax.debug.print("old momentum {x}", x=momentum) logdensity_grad = logdensity_grad flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) flatten_grads = flatten_grads * sqrt_inverse_mass_matrix @@ -338,6 +340,7 @@ def update( normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) momentum_proj = jnp.dot(flatten_momentum, normalized_gradient) delta = step_size * coef * gradient_norm / (dims - 1) + # jax.debug.print("delta {x}", x=delta) zeta = jnp.exp(-delta) new_momentum_raw = ( normalized_gradient * (1 - zeta) * (1 + zeta + momentum_proj * (1 - zeta)) @@ -353,6 +356,8 @@ def update( ) * (dims - 1) if previous_kinetic_energy_change is not None: kinetic_energy_change += previous_kinetic_energy_change + + # jax.debug.print("new_momentum {x}", x=next_momentum) return next_momentum, gr, kinetic_energy_change return update @@ -417,11 +422,15 @@ def partially_refresh_momentum(momentum, rng_key, step_size, L): momentum with random change in angle """ + # jax.debug.print("momentum unflat {x}", x=momentum) m, unravel_fn = ravel_pytree(momentum) + # jax.debug.print("momentum {x}", x=m) dim = m.shape[0] nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype) + # jax.debug.print("z {x}", x=z) new_momentum = unravel_fn((m + z) / jnp.linalg.norm(m + z)) + # jax.debug.print("new_momentum {x}", x=new_momentum) # return new_momentum return jax.lax.cond( jnp.isinf(L), @@ -435,6 +444,7 @@ def with_isokinetic_maruyama(integrator): def stochastic_integrator(init_state, step_size, L_proposal, rng_key): key1, key2 = jax.random.split(rng_key) # partial refreshment + # jax.debug.print("state before noise {x}", x=init_state.momentum) state = init_state._replace( momentum=partially_refresh_momentum( momentum=init_state.momentum, @@ -443,8 +453,10 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key): step_size=step_size * 0.5, ) ) + # jax.debug.print("state after noise {x}", x=state.momentum) # one step of the deterministic dynamics state, info = integrator(state, step_size) + # jax.debug.print("state after integ {x}", x=state.position) # partial refreshment state = state._replace( diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 2299dc68e..824fd9215 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -89,10 +89,15 @@ def build_kernel( def kernel( 
rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float ) -> tuple[IntegratorState, MCLMCInfo]: + # jax.debug.print("state momentum 1 {x}", x=state.momentum) (position, momentum, logdensity, logdensitygrad), kinetic_change = step( state, step_size, L, rng_key ) + # jax.debug.print("state position 2 {x}", x=position) + # jax.debug.print("state position {x}", x=state.position.mean(axis=0)) + # jax.debug.print("state position 2 {x}", x=position.mean(axis=0)) + return IntegratorState( position, momentum, logdensity, logdensitygrad ), MCLMCInfo( diff --git a/blackjax/util.py b/blackjax/util.py index 53543f662..9b0a96046 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -337,6 +337,7 @@ def _step(state_all, xs): # update the state of all chains on this device state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) + # combine all the chains to compute expectation values theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) @@ -404,16 +405,29 @@ def step_while(a): x, i, _ = a auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) + # jax.debug.print("momentum init {x}", x=x[0].momentum) # output, info = step(x, (jnp.arange(num_steps)[0],keys_sampling.T[0],keys_adaptation[0])) + # print(x, "\n\n") output, info = step(x,auxilliary_input) + # print(output, "\n\n\nFOOO\n\n\n") + + # jax.debug.print("\nbar\n {x}", x=output[0].position['coefs'].mean(axis=0)) + # jax.debug.print("\nbar\n {x}", x=output[0].position.mean(axis=0)) + + check_state, _ = vmap(kernel, (0, 0, None))(xs[1][i], output[0], output[1]) + # jax.debug.print("\nbaz\n {x}", x=check_state.position['coefs'].mean(axis=0)) + # jax.debug.print("\nbaz\n {x}", x=check_state.position.mean(axis=0)) # jax.debug.print("info {x}", x=info[0].get("while_cond")) # jax.debug.print("info {x}", x=i) return (output, i + 1, info[0].get("while_cond")) + # flatten with ravel: use ravel, not tree_map + # initial_state_all = ravel_pytree(initial_state_all)[0] + # jax.debug.print("initial {x}", x=0) if early_stop: final_state_all, i, _ = lax.while_loop( diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 7d600e14c..5e2d47e89 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -289,10 +289,11 @@ def run_adjusted_mclmc( def run_emaus( self, - initial_position, + sample_init, logdensity_fn, + ndims, + transform, key, - num_steps, diagonal_preconditioning, ): @@ -306,17 +307,18 @@ def run_emaus( integrator_coefficients = mclachlan_coefficients - info1, info2, grads_per_step, _acc_prob = emaus( - logdensity_fn, - num_steps1=1000, - num_steps2=3000, - num_chains=4000, + info, grads_per_step, _acc_prob, final_state = emaus( + logdensity_fn=logdensity_fn, + sample_init=sample_init, + transform=transform, + ndims=ndims, + num_steps1=100, + num_steps2=300, + num_chains=100, mesh=mesh, rng_key=key, alpha=1.9, - bias_type=3, C=0.1, - power=3.0 / 8.0, early_stop=1, r_end=1e-2, diagonal_preconditioning=diagonal_preconditioning, @@ -327,9 +329,7 @@ def run_emaus( # ensemble_observables = lambda x: vec @ x ) # run the algorithm - return info2[1].reshape( - info2[1].shape[0] * info2[1].shape[1], info2[1].shape[2] - ) + return final_state.position @parameterized.parameters( itertools.product( @@ -511,35 +511,56 @@ def test_emaus( """Test the MCLMC kernel.""" init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + + # model = Banana() + # logdensity_fn = model.logdensity_fn + # sample_init = model.sample_init + + x_data = jax.random.normal(init_key0, 
shape=(1000, 1)) y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) logposterior_fn_ = functools.partial( self.regression_logprob, x=x_data, preds=y_data ) - logdensity_fn = lambda x: logposterior_fn_(**x) + # logdensity_fn = lambda x: logposterior_fn_(coefs=x[0], log_scale=x[1]) + logdensity_fn = lambda x: logposterior_fn_(coefs=x['coefs'][0], log_scale=x['log_scale'][0]) + # logdensity_fn = lambda x: logposterior_fn_(**x) - model = Banana() + # jax.debug.print("logposterior_fn_ {x}", x=logdensity_fn(jnp.array([[1.5606847], [1.719502]]))) + # jax.debug.print("logposterior_fn_ {x}", x=logdensity_fn({"coefs": jnp.array(1.5606847), "log_scale": jnp.array(1.719502)})) - states = self.run_emaus( - initial_position={"coefs": 1.0, "log_scale": 1.0}, - logdensity_fn=model, + + def sample_init(key): + key1, key2 = jax.random.split(key) + coefs = jax.random.uniform(key1, shape=(1,), minval=1, maxval=2) + log_scale = jax.random.uniform(key2, shape=(1,), minval=1, maxval=2) + return {"coefs": coefs, "log_scale": log_scale} + # return jnp.concatenate([coefs, log_scale]) + + + samples = self.run_emaus( + sample_init=sample_init, + logdensity_fn=logdensity_fn, + transform=lambda x: x, + ndims=2, key=inference_key, - num_steps=10000, diagonal_preconditioning=True, ) - # coefs_samples = states["coefs"][3000:] - # scale_samples = np.exp(states["log_scale"][3000:]) - # samples = states[3000:] - print((states**2).mean(axis=0), Banana().E_x2) + # # jax.debug.print("pos mean, {x}", x=jnp.mean(samples["coefs"][-1])) + + + coefs_samples = samples["coefs"] + scale_samples = np.exp(samples["log_scale"]) - np.testing.assert_allclose((states**2).mean(axis=0), Banana().E_x2, atol=1e-2) + jax.debug.print("coefs_samples mean {x}", x=jnp.mean(coefs_samples)) + jax.debug.print("scale_samples mean {x}", x=jnp.mean(scale_samples)) - # np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) - # np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) def test_mclmc_preconditioning(self): class IllConditionedGaussian: From 16841c6bb96d6b3380b59f5693298ea4db80360a Mon Sep 17 00:00:00 2001 From: = Date: Thu, 6 Feb 2025 18:00:26 -0500 Subject: [PATCH 11/34] precommit --- blackjax/adaptation/ensemble_mclmc.py | 22 ++++++--------- blackjax/adaptation/ensemble_umclmc.py | 15 +++------- blackjax/mcmc/integrators.py | 13 +-------- blackjax/mcmc/mclmc.py | 5 ---- blackjax/mcmc/metrics.py | 9 ++++-- blackjax/optimizers/lbfgs.py | 4 ++- blackjax/util.py | 26 +++-------------- tests/mcmc/test_sampling.py | 39 ++++++-------------------- 8 files changed, 34 insertions(+), 99 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 5b95aacbd..7294f4936 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -35,7 +35,6 @@ class AdaptationState(NamedTuple): - steps_per_sample: float step_size: float stepsize_adaptation_state: ( @@ -83,7 +82,9 @@ def __init__( # With the current eps, we had sigma^2 = EEVPD * d for N = 1. 
# Combining the two we have EEVPD * d / 0.82 = eps^6 / eps_new^4 L^2 # adjustment_factor = jnp.power(0.82 / (num_dims * adaptation_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) - step_size = adaptation_state.step_size # * integrator_factor * adjustment_factor + step_size = ( + adaptation_state.step_size + ) # * integrator_factor * adjustment_factor # steps_per_sample = (int)(jnp.max(jnp.array([Lfull / step_size, 1]))) @@ -184,7 +185,7 @@ def emaus( acc_prob=None, observables=lambda x: None, ensemble_observables=None, - diagnostics=True + diagnostics=True, ): """ model: the target density object @@ -215,8 +216,6 @@ def emaus( key_init, logdensity_fn, sample_init, num_chains, mesh ) - - # jax.debug.print("{x} foo", x=jax.flatten_util.ravel_pytree(initial_state.position)[0].shape[-1]) ndims = 2 # burn-in with the unadjusted method # @@ -231,12 +230,12 @@ def emaus( power=3.0 / 8.0, r_end=r_end, # observables=observables, - observables_for_bias=lambda position: jnp.square(transform(jax.flatten_util.ravel_pytree(position)[0])), + observables_for_bias=lambda position: jnp.square( + transform(jax.flatten_util.ravel_pytree(position)[0]) + ), # contract=contract, ) - # jax.debug.print("initial_state.momentum: {x}", x=initial_state.momentum) - final_state, final_adaptation_state, info1 = run_eca( key_umclmc, initial_state, @@ -249,9 +248,6 @@ def emaus( early_stop=early_stop, ) - # print(final_state.position['coefs'].shape, "\n\nfoo\n\n") - # jax.debug.print("final_state.position: {x}", x=jnp.mean(final_state.position['coefs'])) - # refine the results with the adjusted method # _acc_prob = acc_prob if integrator_coefficients is None: @@ -288,8 +284,6 @@ def emaus( logdensity_fn, integrator, inverse_mass_matrix=inverse_mass_matrix ) - - initial_state = HMCState( final_state.position, final_state.logdensity, final_state.logdensity_grad ) @@ -320,7 +314,7 @@ def emaus( ) if diagnostics: - info = {"phase_1" : info1, "phase_2" : info2} + info = {"phase_1": info1, "phase_2": info2} else: info = None diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 4e124421b..f919a82e8 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -78,17 +78,12 @@ def sequential_init(key, x, args): _normalized_flatten_array(flat_g)[0] ) # = grad logp/ |grad logp| - jax.debug.print("logdensity {x}", x=logdensity_fn(position)) - # jax.debug.print("velocity {x}", x=velocity) - jax.debug.print("position {x}", x=position) - # jax.debug.print("logdensity_grad {x}", x=logdensity_grad) - # jax.debug.print("logdensity {x}", x=logdensity) - # jax.debug.print("flat_g {x}", x=flat_g) return IntegratorState(position, velocity, logdensity, logdensity_grad), None def summary_statistics_fn(state): """compute the diagonal elements of the equipartition matrix""" - return 0 # -state.position * state.logdensity_grad + return 0 # -state.position * state.logdensity_grad + # TODO: restore! 
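+    # Background for the diagnostic that is temporarily disabled above: for samples
+    # from the target density, integration by parts gives E[-x_i * d(log p)/dx_i] = 1
+    # in every coordinate, so the ensemble average of -position * logdensity_grad
+    # should approach a vector of ones; equipartition_diagonal_loss measures the
+    # deviation from that.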
def ensemble_init(key, state, signs): @@ -99,7 +94,7 @@ def ensemble_init(key, state, signs): momentum, unflatten = jax.flatten_util.ravel_pytree(state.momentum) velocity_flat = jax.tree_util.tree_map( - lambda sign, u: sign*u, signs, momentum + lambda sign, u: sign * u, signs, momentum ) velocity = unflatten(velocity_flat) @@ -120,8 +115,6 @@ def ensemble_init(key, state, signs): summary_statistics_fn=summary_statistics_fn, ) - # jax.debug.print("initial_state {x}", x=initial_state.momentum) - signs = -2.0 * (equipartition < 1.0) + 1.0 initial_state, _ = ensemble_execute_fn( ensemble_init, key2, num_chains, mesh, x=initial_state, args=signs @@ -277,7 +270,7 @@ def update(self, adaptation_state, Etheta): Etheta["observables_for_bias"], adaptation_state.history.observables ) # history_observables = adaptation_state.history.observables - + history_weights = update_history_scalar(1.0, adaptation_state.history.weights) fluctuations = contract_history(history_observables, history_weights) history_stopping = update_history_scalar( diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 1a5711add..49700697d 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -169,7 +169,6 @@ def update( position, kinetic_grad, ) - # jax.debug.print("new_position {x}", x=new_position) logdensity, logdensity_grad = logdensity_and_grad_fn(new_position) return new_position, logdensity, logdensity_grad, None @@ -331,7 +330,6 @@ def update( """ del is_last_call - # jax.debug.print("old momentum {x}", x=momentum) logdensity_grad = logdensity_grad flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) flatten_grads = flatten_grads * sqrt_inverse_mass_matrix @@ -340,7 +338,6 @@ def update( normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) momentum_proj = jnp.dot(flatten_momentum, normalized_gradient) delta = step_size * coef * gradient_norm / (dims - 1) - # jax.debug.print("delta {x}", x=delta) zeta = jnp.exp(-delta) new_momentum_raw = ( normalized_gradient * (1 - zeta) * (1 + zeta + momentum_proj * (1 - zeta)) @@ -357,7 +354,6 @@ def update( if previous_kinetic_energy_change is not None: kinetic_energy_change += previous_kinetic_energy_change - # jax.debug.print("new_momentum {x}", x=next_momentum) return next_momentum, gr, kinetic_energy_change return update @@ -422,16 +418,12 @@ def partially_refresh_momentum(momentum, rng_key, step_size, L): momentum with random change in angle """ - # jax.debug.print("momentum unflat {x}", x=momentum) m, unravel_fn = ravel_pytree(momentum) - # jax.debug.print("momentum {x}", x=m) dim = m.shape[0] nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype) - # jax.debug.print("z {x}", x=z) new_momentum = unravel_fn((m + z) / jnp.linalg.norm(m + z)) - # jax.debug.print("new_momentum {x}", x=new_momentum) - # return new_momentum + return jax.lax.cond( jnp.isinf(L), lambda _: momentum, @@ -444,7 +436,6 @@ def with_isokinetic_maruyama(integrator): def stochastic_integrator(init_state, step_size, L_proposal, rng_key): key1, key2 = jax.random.split(rng_key) # partial refreshment - # jax.debug.print("state before noise {x}", x=init_state.momentum) state = init_state._replace( momentum=partially_refresh_momentum( momentum=init_state.momentum, @@ -453,10 +444,8 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key): step_size=step_size * 0.5, ) ) - # jax.debug.print("state after noise {x}", x=state.momentum) # one step of the deterministic 
dynamics state, info = integrator(state, step_size) - # jax.debug.print("state after integ {x}", x=state.position) # partial refreshment state = state._replace( diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 824fd9215..2299dc68e 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -89,15 +89,10 @@ def build_kernel( def kernel( rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float ) -> tuple[IntegratorState, MCLMCInfo]: - # jax.debug.print("state momentum 1 {x}", x=state.momentum) (position, momentum, logdensity, logdensitygrad), kinetic_change = step( state, step_size, L, rng_key ) - # jax.debug.print("state position 2 {x}", x=position) - # jax.debug.print("state position {x}", x=state.position.mean(axis=0)) - # jax.debug.print("state position 2 {x}", x=position.mean(axis=0)) - return IntegratorState( position, momentum, logdensity, logdensitygrad ), MCLMCInfo( diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py index 70e33d3a4..f0720acf4 100644 --- a/blackjax/mcmc/metrics.py +++ b/blackjax/mcmc/metrics.py @@ -43,7 +43,8 @@ class KineticEnergy(Protocol): def __call__( self, momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None - ) -> Numeric: ... + ) -> Numeric: + ... class CheckTurning(Protocol): @@ -54,7 +55,8 @@ def __call__( momentum_sum: ArrayLikeTree, position_left: Optional[ArrayLikeTree] = None, position_right: Optional[ArrayLikeTree] = None, - ) -> bool: ... + ) -> bool: + ... class Scale(Protocol): @@ -65,7 +67,8 @@ def __call__( *, inv: bool, trans: bool, - ) -> ArrayLikeTree: ... + ) -> ArrayLikeTree: + ... class Metric(NamedTuple): diff --git a/blackjax/optimizers/lbfgs.py b/blackjax/optimizers/lbfgs.py index aef55200f..0dd59f003 100644 --- a/blackjax/optimizers/lbfgs.py +++ b/blackjax/optimizers/lbfgs.py @@ -269,7 +269,9 @@ def compute_next_alpha(s_l, z_l, alpha_lm1): b = z_l.T @ s_l c = s_l.T @ jnp.diag(1.0 / alpha_lm1) @ s_l inv_alpha_l = ( - a / (b * alpha_lm1) + z_l**2 / b - (a * s_l**2) / (b * c * alpha_lm1**2) + a / (b * alpha_lm1) + + z_l**2 / b + - (a * s_l**2) / (b * c * alpha_lm1**2) ) return 1.0 / inv_alpha_l diff --git a/blackjax/util.py b/blackjax/util.py index 9b0a96046..5965ac93e 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -11,8 +11,6 @@ from jax.sharding import NamedSharding, PartitionSpec from jax.tree_util import tree_leaves, tree_map - -import jax from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -337,7 +335,6 @@ def _step(state_all, xs): # update the state of all chains on this device state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) - # combine all the chains to compute expectation values theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) @@ -405,33 +402,18 @@ def step_while(a): x, i, _ = a auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) - # jax.debug.print("momentum init {x}", x=x[0].momentum) - - # output, info = step(x, (jnp.arange(num_steps)[0],keys_sampling.T[0],keys_adaptation[0])) - # print(x, "\n\n") - output, info = step(x,auxilliary_input) - # print(output, "\n\n\nFOOO\n\n\n") - - # jax.debug.print("\nbar\n {x}", x=output[0].position['coefs'].mean(axis=0)) - # jax.debug.print("\nbar\n {x}", x=output[0].position.mean(axis=0)) + output, info = step(x, auxilliary_input) check_state, _ = vmap(kernel, (0, 0, None))(xs[1][i], output[0], output[1]) - - # jax.debug.print("\nbaz\n {x}", 
x=check_state.position['coefs'].mean(axis=0)) - # jax.debug.print("\nbaz\n {x}", x=check_state.position.mean(axis=0)) - # jax.debug.print("info {x}", x=info[0].get("while_cond")) - # jax.debug.print("info {x}", x=i) return (output, i + 1, info[0].get("while_cond")) - # flatten with ravel: use ravel, not tree_map - # initial_state_all = ravel_pytree(initial_state_all)[0] - - # jax.debug.print("initial {x}", x=0) if early_stop: final_state_all, i, _ = lax.while_loop( - lambda a: ((a[1] < num_steps) & a[2] ), step_while, (initial_state_all, 0, True) + lambda a: ((a[1] < num_steps) & a[2]), + step_while, + (initial_state_all, 0, True), ) info_history = None diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 5e2d47e89..493b441d0 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -3,7 +3,6 @@ import functools import itertools -from blackjax.adaptation.ensemble_mclmc import emaus import chex import jax import jax.numpy as jnp @@ -16,6 +15,7 @@ import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info +from blackjax.adaptation.ensemble_mclmc import emaus from blackjax.mcmc.adjusted_mclmc_dynamic import rescale from blackjax.mcmc.integrators import isokinetic_mclachlan from blackjax.util import run_inference_algorithm @@ -296,14 +296,9 @@ def run_emaus( key, diagonal_preconditioning, ): - mesh = jax.sharding.Mesh(jax.devices(), "chains") - from blackjax.mcmc.integrators import ( - velocity_verlet_coefficients, - mclachlan_coefficients, - omelyan_coefficients, - ) + from blackjax.mcmc.integrators import mclachlan_coefficients integrator_coefficients = mclachlan_coefficients @@ -511,33 +506,22 @@ def test_emaus( """Test the MCLMC kernel.""" init_key0, init_key1, inference_key = jax.random.split(self.key, 3) - - # model = Banana() - # logdensity_fn = model.logdensity_fn - # sample_init = model.sample_init - x_data = jax.random.normal(init_key0, shape=(1000, 1)) y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) logposterior_fn_ = functools.partial( self.regression_logprob, x=x_data, preds=y_data ) - # logdensity_fn = lambda x: logposterior_fn_(coefs=x[0], log_scale=x[1]) - logdensity_fn = lambda x: logposterior_fn_(coefs=x['coefs'][0], log_scale=x['log_scale'][0]) - # logdensity_fn = lambda x: logposterior_fn_(**x) - - # jax.debug.print("logposterior_fn_ {x}", x=logdensity_fn(jnp.array([[1.5606847], [1.719502]]))) - # jax.debug.print("logposterior_fn_ {x}", x=logdensity_fn({"coefs": jnp.array(1.5606847), "log_scale": jnp.array(1.719502)})) - + logdensity_fn = lambda x: logposterior_fn_( + coefs=x["coefs"][0], log_scale=x["log_scale"][0] + ) def sample_init(key): key1, key2 = jax.random.split(key) coefs = jax.random.uniform(key1, shape=(1,), minval=1, maxval=2) - log_scale = jax.random.uniform(key2, shape=(1,), minval=1, maxval=2) + log_scale = jax.random.uniform(key2, shape=(1,), minval=1, maxval=2) return {"coefs": coefs, "log_scale": log_scale} - # return jnp.concatenate([coefs, log_scale]) - samples = self.run_emaus( sample_init=sample_init, @@ -548,17 +532,9 @@ def sample_init(key): diagonal_preconditioning=True, ) - - - # # jax.debug.print("pos mean, {x}", x=jnp.mean(samples["coefs"][-1])) - - coefs_samples = samples["coefs"] scale_samples = np.exp(samples["log_scale"]) - jax.debug.print("coefs_samples mean {x}", x=jnp.mean(coefs_samples)) - jax.debug.print("scale_samples mean {x}", x=jnp.mean(scale_samples)) - 
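+        # The data above are generated as y = 3 * x + eps with unit-scale noise, so
+        # the posterior means checked below should be close to coefs = 3 and scale = 1.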
np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) @@ -636,7 +612,8 @@ def get_inverse_mass_matrix(): assert ( jnp.abs( jnp.dot( - (inverse_mass_matrix**2) / jnp.linalg.norm(inverse_mass_matrix**2), + (inverse_mass_matrix**2) + / jnp.linalg.norm(inverse_mass_matrix**2), eigs / jnp.linalg.norm(eigs), ) - 1 From 52ce7ad935f63e6174ef4224c5abb86dc2fbaa40 Mon Sep 17 00:00:00 2001 From: = Date: Thu, 6 Feb 2025 18:18:55 -0500 Subject: [PATCH 12/34] update --- tests/mcmc/test_sampling.py | 150 ++++++++++++++++++------------------ 1 file changed, 75 insertions(+), 75 deletions(-) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 7718ad2b6..ae83927d8 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -15,10 +15,10 @@ import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info +from blackjax.adaptation.ensemble_mclmc import emaus from blackjax.mcmc.adjusted_mclmc_dynamic import rescale from blackjax.mcmc.integrators import isokinetic_mclachlan from blackjax.util import run_inference_algorithm -from blackjax.adaptation.ensemble_mclmc import emaus def orbit_samples(orbits, weights, rng_key): @@ -291,43 +291,43 @@ def run_adjusted_mclmc_static( return out def run_emaus( - self, - sample_init, - logdensity_fn, - ndims, - transform, - key, - diagonal_preconditioning, - ): - mesh = jax.sharding.Mesh(jax.devices(), "chains") - - from blackjax.mcmc.integrators import mclachlan_coefficients - - integrator_coefficients = mclachlan_coefficients - - info, grads_per_step, _acc_prob, final_state = emaus( - logdensity_fn=logdensity_fn, - sample_init=sample_init, - transform=transform, - ndims=ndims, - num_steps1=100, - num_steps2=300, - num_chains=100, - mesh=mesh, - rng_key=key, - alpha=1.9, - C=0.1, - early_stop=1, - r_end=1e-2, - diagonal_preconditioning=diagonal_preconditioning, - integrator_coefficients=integrator_coefficients, - steps_per_sample=15, - acc_prob=None, - ensemble_observables=lambda x: x, - # ensemble_observables = lambda x: vec @ x - ) # run the algorithm - - return final_state.position + self, + sample_init, + logdensity_fn, + ndims, + transform, + key, + diagonal_preconditioning, + ): + mesh = jax.sharding.Mesh(jax.devices(), "chains") + + from blackjax.mcmc.integrators import mclachlan_coefficients + + integrator_coefficients = mclachlan_coefficients + + info, grads_per_step, _acc_prob, final_state = emaus( + logdensity_fn=logdensity_fn, + sample_init=sample_init, + transform=transform, + ndims=ndims, + num_steps1=100, + num_steps2=300, + num_chains=100, + mesh=mesh, + rng_key=key, + alpha=1.9, + C=0.1, + early_stop=1, + r_end=1e-2, + diagonal_preconditioning=diagonal_preconditioning, + integrator_coefficients=integrator_coefficients, + steps_per_sample=15, + acc_prob=None, + ensemble_observables=lambda x: x, + # ensemble_observables = lambda x: vec @ x + ) # run the algorithm + + return final_state.position @parameterized.parameters( itertools.product( @@ -576,42 +576,42 @@ def get_inverse_mass_matrix(): ) def test_emaus( - self, - ): - """Test the MCLMC kernel.""" - - init_key0, init_key1, inference_key = jax.random.split(self.key, 3) - - x_data = jax.random.normal(init_key0, shape=(1000, 1)) - y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) - - logposterior_fn_ = functools.partial( - self.regression_logprob, x=x_data, preds=y_data - 
) - logdensity_fn = lambda x: logposterior_fn_( - coefs=x["coefs"][0], log_scale=x["log_scale"][0] - ) - - def sample_init(key): - key1, key2 = jax.random.split(key) - coefs = jax.random.uniform(key1, shape=(1,), minval=1, maxval=2) - log_scale = jax.random.uniform(key2, shape=(1,), minval=1, maxval=2) - return {"coefs": coefs, "log_scale": log_scale} - - samples = self.run_emaus( - sample_init=sample_init, - logdensity_fn=logdensity_fn, - transform=lambda x: x, - ndims=2, - key=inference_key, - diagonal_preconditioning=True, - ) - - coefs_samples = samples["coefs"] - scale_samples = np.exp(samples["log_scale"]) - - np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) - np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + self, + ): + """Test the MCLMC kernel.""" + + init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + + x_data = jax.random.normal(init_key0, shape=(1000, 1)) + y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) + + logposterior_fn_ = functools.partial( + self.regression_logprob, x=x_data, preds=y_data + ) + logdensity_fn = lambda x: logposterior_fn_( + coefs=x["coefs"][0], log_scale=x["log_scale"][0] + ) + + def sample_init(key): + key1, key2 = jax.random.split(key) + coefs = jax.random.uniform(key1, shape=(1,), minval=1, maxval=2) + log_scale = jax.random.uniform(key2, shape=(1,), minval=1, maxval=2) + return {"coefs": coefs, "log_scale": log_scale} + + samples = self.run_emaus( + sample_init=sample_init, + logdensity_fn=logdensity_fn, + transform=lambda x: x, + ndims=2, + key=inference_key, + diagonal_preconditioning=True, + ) + + coefs_samples = samples["coefs"] + scale_samples = np.exp(samples["log_scale"]) + + np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) @parameterized.parameters(regression_test_cases) def test_pathfinder_adaptation( @@ -1296,4 +1296,4 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal): if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From cc7bfbd00bac75a3f0f50802dc9a5f37a0909e50 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 10 Feb 2025 11:59:31 -0500 Subject: [PATCH 13/34] docstrings --- blackjax/sgmcmc/csgld.py | 1 - blackjax/sgmcmc/sgnht.py | 1 - blackjax/util.py | 22 +++++++++++++++++++--- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/blackjax/sgmcmc/csgld.py b/blackjax/sgmcmc/csgld.py index 02766ca3d..506740c50 100644 --- a/blackjax/sgmcmc/csgld.py +++ b/blackjax/sgmcmc/csgld.py @@ -41,7 +41,6 @@ class ContourSGLDState(NamedTuple): Index `i` such that the current position belongs to :math:`S_i`. """ - position: ArrayTree energy_pdf: Array energy_idx: int diff --git a/blackjax/sgmcmc/sgnht.py b/blackjax/sgmcmc/sgnht.py index 7bcb2ccef..ad9547406 100644 --- a/blackjax/sgmcmc/sgnht.py +++ b/blackjax/sgmcmc/sgnht.py @@ -35,7 +35,6 @@ class SGNHTState(NamedTuple): Scalar thermostat controlling kinetic energy. """ - position: ArrayTree momentum: ArrayTree xi: float diff --git a/blackjax/util.py b/blackjax/util.py index 5965ac93e..d0f97aa90 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -377,6 +377,25 @@ def run_eca( ensemble_info=None, early_stop=False, ): + """ + Run ensemble of chains in parallel on multiple devices. 
+ ----------------------------------------------------- + Args: + rng_key: random key + initial_state: initial state of the system + kernel: kernel for the dynamics + adaptation: adaptation object + num_steps: number of steps to run + num_chains: number of chains + mesh: mesh for parallelization + ensemble_info: function that takes the state of the system and returns some information about the ensemble + early_stop: whether to stop early + Returns: + final_state: final state of the system + final_adaptation_state: final adaptation state + info_history: history of the information that was stored at each step (if early_stop is False, then this is None) + """ + step = eca_step( kernel, adaptation.summary_statistics_fn, @@ -405,8 +424,6 @@ def step_while(a): output, info = step(x, auxilliary_input) - check_state, _ = vmap(kernel, (0, 0, None))(xs[1][i], output[0], output[1]) - return (output, i + 1, info[0].get("while_cond")) if early_stop: @@ -437,7 +454,6 @@ def step_while(a): ) # produce all random keys that will be needed - # rng_key = rng_key if not isinstance(rng_key, jnp.ndarray) else rng_key[0] key_sampling, key_adaptation = split(rng_key) num_steps = jnp.array(num_steps).item() From 58d9920a2fa6281334e7832cb4a6c37dbc9181f9 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 10 Feb 2025 12:11:55 -0500 Subject: [PATCH 14/34] remove debug statements --- blackjax/adaptation/ensemble_mclmc.py | 30 +++++++++----------------- blackjax/adaptation/ensemble_umclmc.py | 14 +----------- 2 files changed, 11 insertions(+), 33 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 7294f4936..7a117665f 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -84,13 +84,8 @@ def __init__( # adjustment_factor = jnp.power(0.82 / (num_dims * adaptation_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) step_size = ( adaptation_state.step_size - ) # * integrator_factor * adjustment_factor + ) - # steps_per_sample = (int)(jnp.max(jnp.array([Lfull / step_size, 1]))) - - # Initialize the dual averaging adaptation # - # da_init_fn, self.epsadap_update, _ = dual_averaging_adaptation(target= acc_prob_target) - # stepsize_adaptation_state = da_init_fn(step_size) # Initialize the bisection for finding the step size stepsize_adaptation_state, self.epsadap_update = bisection_monotonic_fn( @@ -169,18 +164,18 @@ def emaus( sample_init, transform, ndims, - num_steps1, # max number in phase 1 - num_steps2, # fixed number in phase 2 + num_steps1, + num_steps2, num_chains, mesh, rng_key, - alpha=1.9, # L = sqrt{d}*alpha*vars - save_frac=0.2, # to end stage one, the fraction of stage 1 samples used to estimate fluctuation. 
min is: save_frac*num_steps1 - C=0.1, # constant in stage 1 that determines step size (eq (9) in paper) - early_stop=True, # for stage 1 - r_end=5e-3, # stage1 parameters + alpha=1.9, + save_frac=0.2, + C=0.1, + early_stop=True, + r_end=5e-3, diagonal_preconditioning=True, - integrator_coefficients=None, # (for stage 2) + integrator_coefficients=None, steps_per_sample=10, acc_prob=None, observables=lambda x: None, @@ -229,11 +224,9 @@ def emaus( C=C, power=3.0 / 8.0, r_end=r_end, - # observables=observables, observables_for_bias=lambda position: jnp.square( transform(jax.flatten_util.ravel_pytree(position)[0]) ), - # contract=contract, ) final_state, final_adaptation_state, info1 = run_eca( @@ -248,7 +241,7 @@ def emaus( early_stop=early_stop, ) - # refine the results with the adjusted method # + # refine the results with the adjusted method _acc_prob = acc_prob if integrator_coefficients is None: high_dims = ndims > 200 @@ -297,9 +290,6 @@ def emaus( num_adaptation_samples, steps_per_sample, _acc_prob, - # observables=observables, - # observables_for_bias=observables_for_bias, - # contract=contract, ) final_state, final_adaptation_state, info2 = run_eca( diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index f919a82e8..830099ce6 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -88,9 +88,7 @@ def summary_statistics_fn(state): def ensemble_init(key, state, signs): """flip the velocity, depending on the equipartition condition""" - # velocity = jax.tree_util.tree_map( - # lambda sign, u: sign * u, signs, state.momentum - # ) + momentum, unflatten = jax.flatten_util.ravel_pytree(state.momentum) velocity_flat = jax.tree_util.tree_map( @@ -124,14 +122,8 @@ def ensemble_init(key, state, signs): def update_history(new_vals, history): - # new_vals = jax.flatten_util.ravel_pytree(new_vals)[0] - # history = jax.flatten_util.ravel_pytree(history)[0] - # print(new_vals, "FOOO\n\n") - new_vals, _ = jax.flatten_util.ravel_pytree(new_vals) - # print(history, "FOOO\n\n") return jnp.concatenate((new_vals[None, :], history[:-1])) - # return history # TODO CHANGE BACK!!!! def update_history_scalar(new_val, history): @@ -146,8 +138,6 @@ def contract_history(theta, weights): return jnp.array([jnp.max(r), jnp.average(r)]) - -# used for the early stopping class History(NamedTuple): observables: Array stopping: Array @@ -224,8 +214,6 @@ def __init__( self.contract = contract self.bias_type = bias_type self.save_num = save_num - # sigma = unravel_fn(jnp.ones(flat_pytree.shape, dtype = flat_pytree.dtype)) - r_save_num = save_num history = History( From 4274b07fdbace00a49b453364310a10ab2e805da Mon Sep 17 00:00:00 2001 From: = Date: Mon, 10 Feb 2025 12:14:57 -0500 Subject: [PATCH 15/34] precommit --- blackjax/adaptation/ensemble_mclmc.py | 5 +---- blackjax/adaptation/ensemble_umclmc.py | 1 + 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 7a117665f..363b2df7f 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -82,10 +82,7 @@ def __init__( # With the current eps, we had sigma^2 = EEVPD * d for N = 1. 
# Combining the two we have EEVPD * d / 0.82 = eps^6 / eps_new^4 L^2 # adjustment_factor = jnp.power(0.82 / (num_dims * adaptation_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) - step_size = ( - adaptation_state.step_size - ) - + step_size = adaptation_state.step_size # Initialize the bisection for finding the step size stepsize_adaptation_state, self.epsadap_update = bisection_monotonic_fn( diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 830099ce6..e65f05c50 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -138,6 +138,7 @@ def contract_history(theta, weights): return jnp.array([jnp.max(r), jnp.average(r)]) + class History(NamedTuple): observables: Array stopping: Array From d951a4456e0caa3286123654e3a198d664d88768 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 11 Feb 2025 12:12:18 -0500 Subject: [PATCH 16/34] modify test --- blackjax/util.py | 6 +++++- tests/mcmc/test_sampling.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index d0f97aa90..ee71af2b9 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -321,6 +321,10 @@ def incremental_value_update( def eca_step( kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info=None ): + """ + Construct a single step of ensemble chain adaptation (eca) to be performed in parallel on multiple devices. + """ + def _step(state_all, xs): """This function operates on a single device.""" ( @@ -378,7 +382,7 @@ def run_eca( early_stop=False, ): """ - Run ensemble of chains in parallel on multiple devices. + Run ensemble chain adaptation (eca) in parallel on multiple devices. ----------------------------------------------------- Args: rng_key: random key diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index ae83927d8..5d43801d7 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -312,7 +312,7 @@ def run_emaus( ndims=ndims, num_steps1=100, num_steps2=300, - num_chains=100, + num_chains=300, mesh=mesh, rng_key=key, alpha=1.9, From ba8f6ebebf4106e844bf305fa53f25c4c9186bb2 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 11 Feb 2025 12:30:44 -0500 Subject: [PATCH 17/34] modify test --- tests/mcmc/test_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 5d43801d7..e73d3c557 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -312,7 +312,7 @@ def run_emaus( ndims=ndims, num_steps1=100, num_steps2=300, - num_chains=300, + num_chains=400, mesh=mesh, rng_key=key, alpha=1.9, From 29dcd5440ca341fd656de94b690ea9dc7bb2843f Mon Sep 17 00:00:00 2001 From: = Date: Tue, 11 Feb 2025 12:39:20 -0500 Subject: [PATCH 18/34] modify test --- tests/mcmc/test_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index e73d3c557..0379bdd74 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -312,7 +312,7 @@ def run_emaus( ndims=ndims, num_steps1=100, num_steps2=300, - num_chains=400, + num_chains=700, mesh=mesh, rng_key=key, alpha=1.9, From 4b7d8b004cabe3b93dd2496d4a36feabe39bff77 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 11 Feb 2025 12:52:30 -0500 Subject: [PATCH 19/34] modify test --- tests/mcmc/test_sampling.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/mcmc/test_sampling.py 
b/tests/mcmc/test_sampling.py index 0379bdd74..a8272cfab 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -312,7 +312,7 @@ def run_emaus( ndims=ndims, num_steps1=100, num_steps2=300, - num_chains=700, + num_chains=800, mesh=mesh, rng_key=key, alpha=1.9, @@ -610,8 +610,10 @@ def sample_init(key): coefs_samples = samples["coefs"] scale_samples = np.exp(samples["log_scale"]) - np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) - np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + print(np.mean(scale_samples), np.mean(coefs_samples), "foo") + + np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) @parameterized.parameters(regression_test_cases) def test_pathfinder_adaptation( From 67cfd715b687782103d0cc7e0269d025e0195994 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 22 Feb 2025 15:32:44 -0500 Subject: [PATCH 20/34] clean up and bug fix --- blackjax/adaptation/ensemble_mclmc.py | 4 +--- blackjax/adaptation/ensemble_umclmc.py | 14 +++++++------- blackjax/mcmc/termination.py | 6 +++--- tests/mcmc/test_sampling.py | 2 -- 4 files changed, 11 insertions(+), 15 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 363b2df7f..bbc1bb5f7 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -81,7 +81,7 @@ def __init__( # In the adjusted method we want sigma^2 = 2 mu = 2 * 0.41 = 0.82 # With the current eps, we had sigma^2 = EEVPD * d for N = 1. # Combining the two we have EEVPD * d / 0.82 = eps^6 / eps_new^4 L^2 - # adjustment_factor = jnp.power(0.82 / (num_dims * adaptation_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) + # adjustment_factor = jnp.power(0.82 / (ndims * adaptation_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) step_size = adaptation_state.step_size # Initialize the bisection for finding the step size @@ -208,8 +208,6 @@ def emaus( key_init, logdensity_fn, sample_init, num_chains, mesh ) - ndims = 2 - # burn-in with the unadjusted method # kernel = umclmc.build_kernel(logdensity_fn) save_num = (int)(jnp.rint(save_frac * num_steps1)) diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index e65f05c50..4068bee67 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -123,7 +123,7 @@ def ensemble_init(key, state, signs): def update_history(new_vals, history): new_vals, _ = jax.flatten_util.ravel_pytree(new_vals) - return jnp.concatenate((new_vals[None, :], history[:-1])) + return jnp.concatenate((new_vals[None, :], history[:-1, :])) def update_history_scalar(new_val, history): @@ -194,7 +194,7 @@ def equipartition_fullrank_loss(delta_z): class Adaptation: def __init__( self, - num_dims, + ndims, alpha=1.0, C=0.1, power=3.0 / 8.0, @@ -205,7 +205,7 @@ def __init__( observables_for_bias=lambda x: x, contract=lambda x: 0.0, ): - self.num_dims = num_dims + self.ndims = ndims self.alpha = alpha self.C = C self.power = power @@ -218,15 +218,15 @@ def __init__( r_save_num = save_num history = History( - observables=jnp.zeros((r_save_num, num_dims)), + observables=jnp.zeros((r_save_num, ndims)), stopping=jnp.full((save_num,), jnp.nan), weights=jnp.zeros(r_save_num), ) self.initial_state = AdaptationState( L=jnp.inf, # do not add noise for the first step - inverse_mass_matrix=jnp.ones(num_dims), - step_size=0.01 * jnp.sqrt(num_dims), + inverse_mass_matrix=jnp.ones(ndims), + 
step_size=0.01 * jnp.sqrt(ndims), step_count=0, EEVPD=1e-3, EEVPD_wanted=1e-3, @@ -277,7 +277,7 @@ def update(self, adaptation_state, Etheta): jnp.sum(Etheta["xsq"] - jnp.square(Etheta["x"])) ) # average over the ensemble, sum over parameters (to get sqrt(d)) inverse_mass_matrix = Etheta["xsq"] - jnp.square(Etheta["x"]) - EEVPD = (Etheta["Esq"] - jnp.square(Etheta["E"])) / self.num_dims + EEVPD = (Etheta["Esq"] - jnp.square(Etheta["E"])) / self.ndims true_bias = self.contract(Etheta["observables_for_bias"]) nans = Etheta["rejection_rate_nans"] > 0.0 # | (~jnp.isfinite(eps_factor)) diff --git a/blackjax/mcmc/termination.py b/blackjax/mcmc/termination.py index eb1276da3..9fa7cee6c 100644 --- a/blackjax/mcmc/termination.py +++ b/blackjax/mcmc/termination.py @@ -33,10 +33,10 @@ def iterative_uturn_numpyro(is_turning: CheckTurning): def new_state(chain_state, max_num_doublings) -> IterativeUTurnState: flat, _ = jax.flatten_util.ravel_pytree(chain_state.position) - num_dims = jnp.shape(flat)[0] + ndims = jnp.shape(flat)[0] return IterativeUTurnState( - jnp.zeros((max_num_doublings, num_dims)), - jnp.zeros((max_num_doublings, num_dims)), + jnp.zeros((max_num_doublings, ndims)), + jnp.zeros((max_num_doublings, ndims)), 0, 0, ) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index a8272cfab..886cdce0d 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -610,8 +610,6 @@ def sample_init(key): coefs_samples = samples["coefs"] scale_samples = np.exp(samples["log_scale"]) - print(np.mean(scale_samples), np.mean(coefs_samples), "foo") - np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) From 5ee0b8a188e131643316249fdaa4c967d35bb31e Mon Sep 17 00:00:00 2001 From: Reuben Date: Thu, 27 Feb 2025 07:03:09 -0500 Subject: [PATCH 21/34] Update blackjax/adaptation/step_size.py Co-authored-by: Junpeng Lao --- blackjax/adaptation/step_size.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 94c634ce3..61ceed592 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -270,12 +270,14 @@ def update(state, exp_x, acc_rate_new): x = jnp.log(exp_x) def on_true(bounds): - bounds0 = jnp.max(jnp.array([bounds[0], x])) - return jnp.array([bounds0, bounds[1]]), bounds0 + reduce_shift + lower, upper = bounds + lower = jnp.max(jnp.array([lower, x])) + return jnp.array([lower, upper]), lower + reduce_shift def on_false(bounds): - bounds1 = jnp.min(jnp.array([bounds[1], x])) - return jnp.array([bounds[0], bounds1]), bounds1 - reduce_shift + lower, upper = bounds + upper = jnp.min(jnp.array([upper, x])) + return jnp.array([lower, upper]), upper - reduce_shift bounds_new, x_new = jax.lax.cond(acc_high, on_true, on_false, bounds) From 2b3362502f799b3e6d06b0cc7e92425e3aa8fc32 Mon Sep 17 00:00:00 2001 From: = Date: Thu, 27 Feb 2025 07:03:32 -0500 Subject: [PATCH 22/34] clean up and bug fix --- blackjax/adaptation/ensemble_mclmc.py | 9 ++++----- blackjax/adaptation/step_size.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index bbc1bb5f7..15d386346 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -61,7 +61,7 @@ def __init__( self, adaptation_state, num_adaptation_samples, # amount of tuning in the adjusted phase before fixing 
params - steps_per_sample, # L/eps (same for each chain: currently fixed to 15) + steps_per_sample=15, # L/eps acc_prob_target=0.8, observables=lambda x: 0.0, # just for diagnostics: some function of a given chain at given timestep observables_for_bias=lambda x: 0.0, # just for diagnostics: the above, but averaged over all chains @@ -85,9 +85,8 @@ def __init__( step_size = adaptation_state.step_size # Initialize the bisection for finding the step size - stepsize_adaptation_state, self.epsadap_update = bisection_monotonic_fn( - acc_prob_target - ) + self.epsadap_update = bisection_monotonic_fn(acc_prob_target) + stepsize_adaptation_state = ((jnp.array([-jnp.inf, jnp.inf]), False),) self.initial_state = AdaptationState( steps_per_sample, step_size, stepsize_adaptation_state, 0 @@ -173,7 +172,7 @@ def emaus( r_end=5e-3, diagonal_preconditioning=True, integrator_coefficients=None, - steps_per_sample=10, + steps_per_sample=15, acc_prob=None, observables=lambda x: None, ensemble_observables=None, diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 94c634ce3..ee2298c8f 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -298,4 +298,4 @@ def bisect(bounds): return (bounds_new, terminated_new), stepsize - return (jnp.array([-jnp.inf, jnp.inf]), False), update + return update From cc5e09a57a64700069fd63b9bc549e4a164dbda9 Mon Sep 17 00:00:00 2001 From: = Date: Thu, 27 Feb 2025 07:17:23 -0500 Subject: [PATCH 23/34] clean up and bug fix --- blackjax/adaptation/ensemble_mclmc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 15d386346..a43d0eb59 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -86,7 +86,7 @@ def __init__( # Initialize the bisection for finding the step size self.epsadap_update = bisection_monotonic_fn(acc_prob_target) - stepsize_adaptation_state = ((jnp.array([-jnp.inf, jnp.inf]), False),) + stepsize_adaptation_state = (jnp.array([-jnp.inf, jnp.inf]), False) self.initial_state = AdaptationState( steps_per_sample, step_size, stepsize_adaptation_state, 0 From bd40cf9c3a88ecd6599f96f049dc19650244530b Mon Sep 17 00:00:00 2001 From: = Date: Wed, 5 Mar 2025 14:03:45 -0500 Subject: [PATCH 24/34] bug present in minimal_repro_3.py --- blackjax/adaptation/ensemble_mclmc.py | 8 + blackjax/mcmc/alternate_emaus.py | 85 +++++ tests/mcmc/minimal_repro.py | 300 +++++++++++++++ tests/mcmc/minimal_repro_2.py | 419 +++++++++++++++++++++ tests/mcmc/minimal_repro_3.py | 514 ++++++++++++++++++++++++++ tests/mcmc/test_sampling.py | 3 +- 6 files changed, 1328 insertions(+), 1 deletion(-) create mode 100644 blackjax/mcmc/alternate_emaus.py create mode 100644 tests/mcmc/minimal_repro.py create mode 100644 tests/mcmc/minimal_repro_2.py create mode 100644 tests/mcmc/minimal_repro_3.py diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index a43d0eb59..3edae76bd 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -13,6 +13,12 @@ # limitations under the License. 
# """Public API for the MCLMC Kernel""" +# import jax +# import jax.numpy as jnp +# from blackjax.util import run_eca +# import blackjax.adaptation.ensemble_umclmc as umclmc + + from typing import Any, NamedTuple import jax @@ -286,6 +292,8 @@ def emaus( _acc_prob, ) + + final_state, final_adaptation_state, info2 = run_eca( key_mclmc, initial_state, diff --git a/blackjax/mcmc/alternate_emaus.py b/blackjax/mcmc/alternate_emaus.py new file mode 100644 index 000000000..6010bab73 --- /dev/null +++ b/blackjax/mcmc/alternate_emaus.py @@ -0,0 +1,85 @@ +import jax +import jax.numpy as jnp +from blackjax.util import run_eca +import blackjax.adaptation.ensemble_umclmc as umclmc + + +def emaus( + logdensity_fn, + sample_init, + transform, + ndims, + num_steps1, + num_steps2, + num_chains, + mesh, + rng_key, + alpha=1.9, + save_frac=0.2, + C=0.1, + early_stop=True, + r_end=5e-3, + diagonal_preconditioning=True, + integrator_coefficients=None, + steps_per_sample=15, + acc_prob=None, + observables=lambda x: None, + ensemble_observables=None, + diagnostics=True, +): + """ + model: the target density object + num_steps1: number of steps in the first phase + num_steps2: number of steps in the second phase + num_chains: number of chains + mesh: the mesh object, used for distributing the computation across cpus and nodes + rng_key: the random key + alpha: L = sqrt{d}*alpha*variances + save_frac: the fraction of samples used to estimate the fluctuation in the first phase + C: constant in stage 1 that determines step size (eq (9) of EMAUS paper) + early_stop: whether to stop the first phase early + r_end + diagonal_preconditioning: whether to use diagonal preconditioning + integrator_coefficients: the coefficients of the integrator + steps_per_sample: the number of steps per sample + acc_prob: the acceptance probability + observables: the observables (for diagnostic use) + ensemble_observables: observable calculated over the ensemble (for diagnostic use) + diagnostics: whether to return diagnostics + """ + + # observables_for_bias, contract = bias(model) + key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) + + # initialize the chains + initial_state = umclmc.initialize( + key_init, logdensity_fn, sample_init, num_chains, mesh + ) + + # burn-in with the unadjusted method # + kernel = umclmc.build_kernel(logdensity_fn) + save_num = (int)(jnp.rint(save_frac * num_steps1)) + adap = umclmc.Adaptation( + ndims, + alpha=alpha, + bias_type=3, + save_num=save_num, + C=C, + power=3.0 / 8.0, + r_end=r_end, + observables_for_bias=lambda position: jnp.square( + transform(jax.flatten_util.ravel_pytree(position)[0]) + ), + ) + + final_state, final_adaptation_state, info1 = run_eca( + key_umclmc, + initial_state, + kernel, + adap, + num_steps1, + num_chains, + mesh, + ensemble_observables, + early_stop=early_stop, + ) \ No newline at end of file diff --git a/tests/mcmc/minimal_repro.py b/tests/mcmc/minimal_repro.py new file mode 100644 index 000000000..639b624fc --- /dev/null +++ b/tests/mcmc/minimal_repro.py @@ -0,0 +1,300 @@ +import jax +import jax.numpy as jnp +from jax import device_put, jit, lax, vmap +from jax.experimental.shard_map import shard_map +from jax.flatten_util import ravel_pytree +from jax.random import normal, split +from jax.sharding import NamedSharding, PartitionSpec +from jax.tree_util import tree_leaves, tree_map + +import blackjax.adaptation.ensemble_umclmc as umclmc + + +def eca_step( + kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info=None +): + """ + Construct 
a single step of ensemble chain adaptation (eca) to be performed in parallel on multiple devices. + """ + + def _step(state_all, xs): + """This function operates on a single device.""" + ( + state, + adaptation_state, + ) = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. + ( + _, + keys_sampling, + key_adaptation, + ) = xs # keys_sampling.shape = (chains_per_device, ) + + # update the state of all chains on this device + state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) + + # combine all the chains to compute expectation values + theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) + Etheta = tree_map( + lambda theta: lax.psum(jnp.sum(theta, axis=0), axis_name="chains") + / num_chains, + theta, + ) + + # use these to adapt the hyperparameters of the dynamics + adaptation_state, info_to_be_stored = adaptation_update( + adaptation_state, Etheta + ) + + return (state, adaptation_state), info_to_be_stored + + if ensemble_info is not None: + + def step(state_all, xs): + (state, adaptation_state), info_to_be_stored = _step(state_all, xs) + return (state, adaptation_state), ( + info_to_be_stored, + vmap(ensemble_info)(state.position), + ) + + return step + + else: + return _step + + + + +def ensemble_execute_fn( + func, + rng_key, + num_chains, + mesh, + x=None, + args=None, + summary_statistics_fn=lambda y: 0.0, +): + """Given a sequential function + func(rng_key, x, args) = y, + evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. + Args: + x: array distributed over all decvices + args: additional arguments for func, not distributed. + summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. + rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. + + Returns: + y: array distributed over all decvices. Need not be of the same shape as x. + Etheta: expected values of the summary statistics + """ + p, pscalar = PartitionSpec("chains"), PartitionSpec() + + if x is None: + X = device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) + else: + X = x + + adaptation_update = lambda _, Etheta: (Etheta, None) + + _F = eca_step( + func, + lambda y, info, key: summary_statistics_fn(y), + adaptation_update, + num_chains, + ) + + def F(x, keys): + """This function operates on a single device. key is a random key for this device.""" + y, summary_statistics = _F((x, args), (None, keys, None))[0] + return y, summary_statistics + + parallel_execute = shard_map( + F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar), check_rep=False + ) + + keys = device_put( + split(rng_key, num_chains), NamedSharding(mesh, p) + ) # random keys, distributed across devices + # apply F in parallel + return parallel_execute(X, keys) + +def run_eca( + rng_key, + initial_state, + kernel, + adaptation, + num_steps, + num_chains, + mesh, + ensemble_info=None, + early_stop=False, +): + """ + Run ensemble chain adaptation (eca) in parallel on multiple devices. 
+ ----------------------------------------------------- + Args: + rng_key: random key + initial_state: initial state of the system + kernel: kernel for the dynamics + adaptation: adaptation object + num_steps: number of steps to run + num_chains: number of chains + mesh: mesh for parallelization + ensemble_info: function that takes the state of the system and returns some information about the ensemble + early_stop: whether to stop early + Returns: + final_state: final state of the system + final_adaptation_state: final adaptation state + info_history: history of the information that was stored at each step (if early_stop is False, then this is None) + """ + + step = eca_step( + kernel, + adaptation.summary_statistics_fn, + adaptation.update, + num_chains, + ensemble_info, + ) + + def all_steps(initial_state, keys_sampling, keys_adaptation): + """This function operates on a single device. key is a random key for this device.""" + + initial_state_all = (initial_state, adaptation.initial_state) + + # run sampling + xs = ( + jnp.arange(num_steps), + keys_sampling.T, + keys_adaptation, + ) # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) + + # ((a, Int) -> (a, Int)) + def step_while(a): + x, i, _ = a + + auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) + + output, info = step(x, auxilliary_input) + + return (output, i + 1, info[0].get("while_cond")) + + if early_stop: + final_state_all, i, _ = lax.while_loop( + lambda a: ((a[1] < num_steps) & a[2]), + step_while, + (initial_state_all, 0, True), + ) + info_history = None + + else: + final_state_all, info_history = lax.scan(step, initial_state_all, xs) + + final_state, final_adaptation_state = final_state_all + return ( + final_state, + final_adaptation_state, + info_history, + ) # info history is composed of averages over all chains, so it is a couple of scalars + + p, pscalar = PartitionSpec("chains"), PartitionSpec() + parallel_execute = shard_map( + all_steps, + mesh=mesh, + in_specs=(p, p, pscalar), + out_specs=(p, pscalar, pscalar), + check_rep=False, + ) + + # produce all random keys that will be needed + + key_sampling, key_adaptation = split(rng_key) + num_steps = jnp.array(num_steps).item() + keys_adaptation = split(key_adaptation, num_steps) + distribute_keys = lambda key, shape: device_put( + split(key, shape), NamedSharding(mesh, p) + ) # random keys, distributed across devices + keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) + + # run sampling in parallel + final_state, final_adaptation_state, info_history = parallel_execute( + initial_state, keys_sampling, keys_adaptation + ) + + return final_state, final_adaptation_state, info_history + + +mesh = jax.sharding.Mesh(devices=jax.devices(),axis_names= "chains") + +key_init, key_umclmc, key_mclmc = jax.random.split(jax.random.key(0), 3) + +num_chains = 128 +ndims = 2 + +def logdensity_fn(x): + mu2 = 0.03 * (x[0] ** 2 - 100) + return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) + +def transform(x): + return x + +def sample_init(key): + z = jax.random.normal(key, shape=(2,)) + x0 = 10.0 * z[0] + x1 = 0.03 * (x0**2 - 100) + z[1] + return jnp.array([x0, x1]) + +# initialize the chains +initial_state = umclmc.initialize( + key_init, logdensity_fn, sample_init, num_chains, mesh +) + +alpha = 1.9 +C = 0.1 +r_end=5e-3 +ensemble_observables=lambda x: x + +# burn-in with the unadjusted method # +kernel = umclmc.build_kernel(logdensity_fn) +save_num = 20 # 
(int)(jnp.rint(save_frac * num_steps1)) +adap = umclmc.Adaptation( + ndims, + alpha=alpha, + bias_type=3, + save_num=save_num, + C=C, + power=3.0 / 8.0, + r_end=r_end, + observables_for_bias=lambda position: jnp.square( + transform(jax.flatten_util.ravel_pytree(position)[0]) + ), +) + + +final_state, final_adaptation_state, info1 = run_eca( + key_umclmc, + initial_state, + kernel, + adap, + 100, + num_chains, + mesh, + ensemble_observables, + early_stop=True, + ) + + +# a = jnp.array([8.0, 4.0]) + +# def f(rng_key, x, args): +# return x + normal(rng_key, x.shape) + a, a + +# out = ensemble_execute_fn( +# func = f, +# rng_key = jax.random.PRNGKey(0), +# num_chains = 4, +# mesh = mesh, +# x = None, +# args = None, +# summary_statistics_fn = lambda y: a, +# ) + +# print(out) \ No newline at end of file diff --git a/tests/mcmc/minimal_repro_2.py b/tests/mcmc/minimal_repro_2.py new file mode 100644 index 000000000..c73f52681 --- /dev/null +++ b/tests/mcmc/minimal_repro_2.py @@ -0,0 +1,419 @@ +import jax +import jax.numpy as jnp +from jax import device_put, jit, lax, vmap +from jax.experimental.shard_map import shard_map +from jax.flatten_util import ravel_pytree +from jax.random import normal, split +from jax.sharding import NamedSharding, PartitionSpec +from jax.tree_util import tree_leaves, tree_map + +from blackjax.util import run_eca + +import blackjax.adaptation.ensemble_umclmc as umclmc + + +# def eca_step( +# kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info=None +# ): +# """ +# Construct a single step of ensemble chain adaptation (eca) to be performed in parallel on multiple devices. +# """ + +# def _step(state_all, xs): +# """This function operates on a single device.""" +# ( +# state, +# adaptation_state, +# ) = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. +# ( +# _, +# keys_sampling, +# key_adaptation, +# ) = xs # keys_sampling.shape = (chains_per_device, ) + +# # update the state of all chains on this device +# state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) + +# # combine all the chains to compute expectation values +# theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) +# Etheta = tree_map( +# lambda theta: lax.psum(jnp.sum(theta, axis=0), axis_name="chains") +# / num_chains, +# theta, +# ) + +# # use these to adapt the hyperparameters of the dynamics +# adaptation_state, info_to_be_stored = adaptation_update( +# adaptation_state, Etheta +# ) + +# return (state, adaptation_state), info_to_be_stored + +# if ensemble_info is not None: + +# def step(state_all, xs): +# (state, adaptation_state), info_to_be_stored = _step(state_all, xs) +# return (state, adaptation_state), ( +# info_to_be_stored, +# vmap(ensemble_info)(state.position), +# ) + +# return step + +# else: +# return _step + + + + +# def ensemble_execute_fn( +# func, +# rng_key, +# num_chains, +# mesh, +# x=None, +# args=None, +# summary_statistics_fn=lambda y: 0.0, +# ): +# """Given a sequential function +# func(rng_key, x, args) = y, +# evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. +# Args: +# x: array distributed over all decvices +# args: additional arguments for func, not distributed. +# summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. 
+# rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. + +# Returns: +# y: array distributed over all decvices. Need not be of the same shape as x. +# Etheta: expected values of the summary statistics +# """ +# p, pscalar = PartitionSpec("chains"), PartitionSpec() + +# if x is None: +# X = device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) +# else: +# X = x + +# adaptation_update = lambda _, Etheta: (Etheta, None) + +# _F = eca_step( +# func, +# lambda y, info, key: summary_statistics_fn(y), +# adaptation_update, +# num_chains, +# ) + +# def F(x, keys): +# """This function operates on a single device. key is a random key for this device.""" +# y, summary_statistics = _F((x, args), (None, keys, None))[0] +# return y, summary_statistics + +# parallel_execute = shard_map( +# F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar), check_rep=False +# ) + +# keys = device_put( +# split(rng_key, num_chains), NamedSharding(mesh, p) +# ) # random keys, distributed across devices +# # apply F in parallel +# return parallel_execute(X, keys) + +# def run_eca( +# rng_key, +# initial_state, +# kernel, +# adaptation, +# num_steps, +# num_chains, +# mesh, +# ensemble_info=None, +# early_stop=False, +# ): +# """ +# Run ensemble chain adaptation (eca) in parallel on multiple devices. +# ----------------------------------------------------- +# Args: +# rng_key: random key +# initial_state: initial state of the system +# kernel: kernel for the dynamics +# adaptation: adaptation object +# num_steps: number of steps to run +# num_chains: number of chains +# mesh: mesh for parallelization +# ensemble_info: function that takes the state of the system and returns some information about the ensemble +# early_stop: whether to stop early +# Returns: +# final_state: final state of the system +# final_adaptation_state: final adaptation state +# info_history: history of the information that was stored at each step (if early_stop is False, then this is None) +# """ + +# step = eca_step( +# kernel, +# adaptation.summary_statistics_fn, +# adaptation.update, +# num_chains, +# ensemble_info, +# ) + +# def all_steps(initial_state, keys_sampling, keys_adaptation): +# """This function operates on a single device. key is a random key for this device.""" + +# initial_state_all = (initial_state, adaptation.initial_state) + +# # run sampling +# xs = ( +# jnp.arange(num_steps), +# keys_sampling.T, +# keys_adaptation, +# ) # keys for all steps that will be performed. 
keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) + +# # ((a, Int) -> (a, Int)) +# def step_while(a): +# x, i, _ = a + +# auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) + +# output, info = step(x, auxilliary_input) + +# return (output, i + 1, info[0].get("while_cond")) + +# if early_stop: +# final_state_all, i, _ = lax.while_loop( +# lambda a: ((a[1] < num_steps) & a[2]), +# step_while, +# (initial_state_all, 0, True), +# ) +# info_history = None + +# else: +# final_state_all, info_history = lax.scan(step, initial_state_all, xs) + +# final_state, final_adaptation_state = final_state_all +# return ( +# final_state, +# final_adaptation_state, +# info_history, +# ) # info history is composed of averages over all chains, so it is a couple of scalars + +# p, pscalar = PartitionSpec("chains"), PartitionSpec() +# parallel_execute = shard_map( +# all_steps, +# mesh=mesh, +# in_specs=(p, p, pscalar), +# out_specs=(p, pscalar, pscalar), +# check_rep=False, +# ) + +# # produce all random keys that will be needed + +# key_sampling, key_adaptation = split(rng_key) +# num_steps = jnp.array(num_steps).item() +# keys_adaptation = split(key_adaptation, num_steps) +# distribute_keys = lambda key, shape: device_put( +# split(key, shape), NamedSharding(mesh, p) +# ) # random keys, distributed across devices +# keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) + +# # run sampling in parallel +# final_state, final_adaptation_state, info_history = parallel_execute( +# initial_state, keys_sampling, keys_adaptation +# ) + +# return final_state, final_adaptation_state, info_history + + +mesh = jax.sharding.Mesh(devices=jax.devices(),axis_names= "chains") + +# key_init, key_umclmc, key_mclmc = jax.random.split(jax.random.key(0), 3) + +num_chains = 128 +ndims = 2 + +def logdensity_fn(x): + mu2 = 0.03 * (x[0] ** 2 - 100) + return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) + +def transform(x): + return x + +def sample_init(key): + z = jax.random.normal(key, shape=(2,)) + x0 = 10.0 * z[0] + x1 = 0.03 * (x0**2 - 100) + z[1] + return jnp.array([x0, x1]) + +# # initialize the chains +# initial_state = umclmc.initialize( +# key_init, logdensity_fn, sample_init, num_chains, mesh +# ) + +# alpha = 1.9 +# C = 0.1 +# r_end=5e-3 +# ensemble_observables=lambda x: x + +# # burn-in with the unadjusted method # +# kernel = umclmc.build_kernel(logdensity_fn) +# save_num = 20 # (int)(jnp.rint(save_frac * num_steps1)) +# adap = umclmc.Adaptation( +# ndims, +# alpha=alpha, +# bias_type=3, +# save_num=save_num, +# C=C, +# power=3.0 / 8.0, +# r_end=r_end, +# observables_for_bias=lambda position: jnp.square( +# transform(jax.flatten_util.ravel_pytree(position)[0]) +# ), +# ) + + +# final_state, final_adaptation_state, info1 = run_eca( +# key_umclmc, +# initial_state, +# kernel, +# adap, +# 100, +# num_chains, +# mesh, +# ensemble_observables, +# early_stop=True, +# ) + +from blackjax.mcmc.integrators import mclachlan_coefficients + +import sys +# sys.path.append(".") +# sys.path.append("../") +from blackjax.adaptation.ensemble_mclmc import emaus +# from blackjax.mcmc.alternate_emaus import emaus + + +# def emaus( +# logdensity_fn, +# sample_init, +# transform, +# ndims, +# num_steps1, +# num_steps2, +# num_chains, +# mesh, +# rng_key, +# alpha=1.9, +# save_frac=0.2, +# C=0.1, +# early_stop=True, +# r_end=5e-3, +# diagonal_preconditioning=True, +# integrator_coefficients=None, +# steps_per_sample=15, +# acc_prob=None, +# observables=lambda x: None, +# 
ensemble_observables=None, +# diagnostics=True, +# ): +# """ +# model: the target density object +# num_steps1: number of steps in the first phase +# num_steps2: number of steps in the second phase +# num_chains: number of chains +# mesh: the mesh object, used for distributing the computation across cpus and nodes +# rng_key: the random key +# alpha: L = sqrt{d}*alpha*variances +# save_frac: the fraction of samples used to estimate the fluctuation in the first phase +# C: constant in stage 1 that determines step size (eq (9) of EMAUS paper) +# early_stop: whether to stop the first phase early +# r_end +# diagonal_preconditioning: whether to use diagonal preconditioning +# integrator_coefficients: the coefficients of the integrator +# steps_per_sample: the number of steps per sample +# acc_prob: the acceptance probability +# observables: the observables (for diagnostic use) +# ensemble_observables: observable calculated over the ensemble (for diagnostic use) +# diagnostics: whether to return diagnostics +# """ + +# # observables_for_bias, contract = bias(model) +# key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) + +# # initialize the chains +# initial_state = umclmc.initialize( +# key_init, logdensity_fn, sample_init, num_chains, mesh +# ) + +# # burn-in with the unadjusted method # +# kernel = umclmc.build_kernel(logdensity_fn) +# save_num = (int)(jnp.rint(save_frac * num_steps1)) +# adap = umclmc.Adaptation( +# ndims, +# alpha=alpha, +# bias_type=3, +# save_num=save_num, +# C=C, +# power=3.0 / 8.0, +# r_end=r_end, +# observables_for_bias=lambda position: jnp.square( +# transform(jax.flatten_util.ravel_pytree(position)[0]) +# ), +# ) + +# final_state, final_adaptation_state, info1 = run_eca( +# key_umclmc, +# initial_state, +# kernel, +# adap, +# num_steps1, +# num_chains, +# mesh, +# ensemble_observables, +# early_stop=early_stop, +# ) + +key = jax.random.key(0) + +emaus( + logdensity_fn=logdensity_fn, + sample_init=sample_init, + transform=transform, + ndims=ndims, + num_steps1=100, + num_steps2=300, + num_chains=num_chains, + mesh=mesh, + rng_key=key, + alpha=1.9, + C=0.1, + early_stop=1, + r_end=1e-2, + diagonal_preconditioning=True, + integrator_coefficients=mclachlan_coefficients, + steps_per_sample=15, + acc_prob=None, + ensemble_observables=lambda x: x, + # adap=adap, + # kernel=kernel, + # initial_state=initial_state, + # key_umclmc=key_umclmc, + # ensemble_observables = lambda x: vec @ x + ) # run the algorithm + + +# a = jnp.array([8.0, 4.0]) + +# def f(rng_key, x, args): +# return x + normal(rng_key, x.shape) + a, a + +# out = ensemble_execute_fn( +# func = f, +# rng_key = jax.random.PRNGKey(0), +# num_chains = 4, +# mesh = mesh, +# x = None, +# args = None, +# summary_statistics_fn = lambda y: a, +# ) + +# print(out) \ No newline at end of file diff --git a/tests/mcmc/minimal_repro_3.py b/tests/mcmc/minimal_repro_3.py new file mode 100644 index 000000000..56e08caeb --- /dev/null +++ b/tests/mcmc/minimal_repro_3.py @@ -0,0 +1,514 @@ + + +from typing import Any, NamedTuple + +import jax +import jax.numpy as jnp + +import blackjax.adaptation.ensemble_umclmc as umclmc +from blackjax.adaptation.ensemble_umclmc import ( + equipartition_diagonal, + equipartition_diagonal_loss, +) +from blackjax.adaptation.step_size import bisection_monotonic_fn +from blackjax.mcmc.adjusted_mclmc import build_kernel as build_kernel_malt +from blackjax.mcmc.hmc import HMCState +from blackjax.mcmc.integrators import ( + generate_isokinetic_integrator, + mclachlan_coefficients, + 
omelyan_coefficients, +) +import jax +import jax.numpy as jnp +from jax import device_put, jit, lax, vmap +from jax.experimental.shard_map import shard_map +from jax.flatten_util import ravel_pytree +from jax.random import normal, split +from jax.sharding import NamedSharding, PartitionSpec +from jax.tree_util import tree_leaves, tree_map + +import blackjax.adaptation.ensemble_umclmc as umclmc + + +def eca_step( + kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info=None +): + """ + Construct a single step of ensemble chain adaptation (eca) to be performed in parallel on multiple devices. + """ + + def _step(state_all, xs): + """This function operates on a single device.""" + ( + state, + adaptation_state, + ) = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. + ( + _, + keys_sampling, + key_adaptation, + ) = xs # keys_sampling.shape = (chains_per_device, ) + + # update the state of all chains on this device + state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) + + # combine all the chains to compute expectation values + theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) + Etheta = tree_map( + lambda theta: lax.psum(jnp.sum(theta, axis=0), axis_name="chains") + / num_chains, + theta, + ) + + # use these to adapt the hyperparameters of the dynamics + adaptation_state, info_to_be_stored = adaptation_update( + adaptation_state, Etheta + ) + + return (state, adaptation_state), info_to_be_stored + + if ensemble_info is not None: + + def step(state_all, xs): + (state, adaptation_state), info_to_be_stored = _step(state_all, xs) + return (state, adaptation_state), ( + info_to_be_stored, + vmap(ensemble_info)(state.position), + ) + + return step + + else: + return _step + + + + +def ensemble_execute_fn( + func, + rng_key, + num_chains, + mesh, + x=None, + args=None, + summary_statistics_fn=lambda y: 0.0, +): + """Given a sequential function + func(rng_key, x, args) = y, + evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. + Args: + x: array distributed over all decvices + args: additional arguments for func, not distributed. + summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. + rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. + + Returns: + y: array distributed over all decvices. Need not be of the same shape as x. + Etheta: expected values of the summary statistics + """ + p, pscalar = PartitionSpec("chains"), PartitionSpec() + + if x is None: + X = device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) + else: + X = x + + adaptation_update = lambda _, Etheta: (Etheta, None) + + _F = eca_step( + func, + lambda y, info, key: summary_statistics_fn(y), + adaptation_update, + num_chains, + ) + + def F(x, keys): + """This function operates on a single device. 
key is a random key for this device.""" + y, summary_statistics = _F((x, args), (None, keys, None))[0] + return y, summary_statistics + + parallel_execute = shard_map( + F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar), check_rep=False + ) + + keys = device_put( + split(rng_key, num_chains), NamedSharding(mesh, p) + ) # random keys, distributed across devices + # apply F in parallel + return parallel_execute(X, keys) + +def run_eca( + rng_key, + initial_state, + kernel, + adaptation, + num_steps, + num_chains, + mesh, + ensemble_info=None, + early_stop=False, +): + """ + Run ensemble chain adaptation (eca) in parallel on multiple devices. + ----------------------------------------------------- + Args: + rng_key: random key + initial_state: initial state of the system + kernel: kernel for the dynamics + adaptation: adaptation object + num_steps: number of steps to run + num_chains: number of chains + mesh: mesh for parallelization + ensemble_info: function that takes the state of the system and returns some information about the ensemble + early_stop: whether to stop early + Returns: + final_state: final state of the system + final_adaptation_state: final adaptation state + info_history: history of the information that was stored at each step (if early_stop is False, then this is None) + """ + + step = eca_step( + kernel, + adaptation.summary_statistics_fn, + adaptation.update, + num_chains, + ensemble_info, + ) + + def all_steps(initial_state, keys_sampling, keys_adaptation): + """This function operates on a single device. key is a random key for this device.""" + + initial_state_all = (initial_state, adaptation.initial_state) + + # run sampling + xs = ( + jnp.arange(num_steps), + keys_sampling.T, + keys_adaptation, + ) # keys for all steps that will be performed. 
keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) + + # ((a, Int) -> (a, Int)) + def step_while(a): + x, i, _ = a + + auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) + + output, info = step(x, auxilliary_input) + + return (output, i + 1, info[0].get("while_cond")) + + if early_stop: + final_state_all, i, _ = lax.while_loop( + lambda a: ((a[1] < num_steps) & a[2]), + step_while, + (initial_state_all, 0, True), + ) + info_history = None + + else: + final_state_all, info_history = lax.scan(step, initial_state_all, xs) + + final_state, final_adaptation_state = final_state_all + return ( + final_state, + final_adaptation_state, + info_history, + ) # info history is composed of averages over all chains, so it is a couple of scalars + + p, pscalar = PartitionSpec("chains"), PartitionSpec() + parallel_execute = shard_map( + all_steps, + mesh=mesh, + in_specs=(p, p, pscalar), + out_specs=(p, pscalar, pscalar), + check_rep=False, + ) + + # produce all random keys that will be needed + + key_sampling, key_adaptation = split(rng_key) + num_steps = jnp.array(num_steps).item() + keys_adaptation = split(key_adaptation, num_steps) + distribute_keys = lambda key, shape: device_put( + split(key, shape), NamedSharding(mesh, p) + ) # random keys, distributed across devices + keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) + + # run sampling in parallel + final_state, final_adaptation_state, info_history = parallel_execute( + initial_state, keys_sampling, keys_adaptation + ) + + return final_state, final_adaptation_state, info_history + +# from blackjax.util import run_eca + + + +class AdaptationState(NamedTuple): + steps_per_sample: float + step_size: float + stepsize_adaptation_state: ( + Any # the state of the bisection algorithm to find a stepsize + ) + iteration: int + + +build_kernel = lambda logdensity_fn, integrator, inverse_mass_matrix: lambda key, state, adap: build_kernel_malt( + logdensity_fn=logdensity_fn, + integrator=integrator, + inverse_mass_matrix=inverse_mass_matrix, +)( + rng_key=key, + state=state, + step_size=adap.step_size, + num_integration_steps=adap.steps_per_sample, + L_proposal_factor=1.25, +) + + +class Adaptation: + def __init__( + self, + adaptation_state, + num_adaptation_samples, # amount of tuning in the adjusted phase before fixing params + steps_per_sample=15, # L/eps + acc_prob_target=0.8, + observables=lambda x: 0.0, # just for diagnostics: some function of a given chain at given timestep + observables_for_bias=lambda x: 0.0, # just for diagnostics: the above, but averaged over all chains + contract=lambda x: 0.0, # just for diagnostics: observabiels for bias, contracted over dimensions + ): + self.num_adaptation_samples = num_adaptation_samples + self.observables = observables + self.observables_for_bias = observables_for_bias + self.contract = contract + + # Determine the initial hyperparameters # + + # stepsize # + # if we switched to the more accurate integrator we can use longer step size + # integrator_factor = jnp.sqrt(10.) if mclachlan else 1. + # Let's use the stepsize which will be optimal for the adjusted method. The energy variance after N steps scales as sigma^2 ~ N^2 eps^6 = eps^4 L^2 + # In the adjusted method we want sigma^2 = 2 mu = 2 * 0.41 = 0.82 + # With the current eps, we had sigma^2 = EEVPD * d for N = 1. 
+ # Combining the two we have EEVPD * d / 0.82 = eps^6 / eps_new^4 L^2 + # adjustment_factor = jnp.power(0.82 / (ndims * adaptation_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) + step_size = adaptation_state.step_size + + # Initialize the bisection for finding the step size + self.epsadap_update = bisection_monotonic_fn(acc_prob_target) + stepsize_adaptation_state = (jnp.array([-jnp.inf, jnp.inf]), False) + + self.initial_state = AdaptationState( + steps_per_sample, step_size, stepsize_adaptation_state, 0 + ) + + def summary_statistics_fn(self, state, info, rng_key): + return { + "acceptance_probability": info.acceptance_rate, + "equipartition_diagonal": equipartition_diagonal( + state + ), # metric for bias: equipartition theorem gives todo... + "observables": self.observables(state.position), + "observables_for_bias": self.observables_for_bias(state.position), + } + + def update(self, adaptation_state, Etheta): + acc_prob = Etheta["acceptance_probability"] + equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"]) + true_bias = self.contract(Etheta["observables_for_bias"]) # remove + + info_to_be_stored = { + "L": adaptation_state.step_size * adaptation_state.steps_per_sample, + "steps_per_sample": adaptation_state.steps_per_sample, + "step_size": adaptation_state.step_size, + "acc_prob": acc_prob, + "equi_diag": equi_diag, + "bias": true_bias, + "observables": Etheta["observables"], + } + + # Bisection to find step size + stepsize_adaptation_state, step_size = self.epsadap_update( + adaptation_state.stepsize_adaptation_state, + adaptation_state.step_size, + acc_prob, + ) + + return ( + AdaptationState( + adaptation_state.steps_per_sample, + step_size, + stepsize_adaptation_state, + adaptation_state.iteration + 1, + ), + info_to_be_stored, + ) + + +def bias(model): + """should be transfered to benchmarks/""" + + def observables(position): + return jnp.square(model.transform(position)) + + def contract(sampler_E_x2): + bsq = jnp.square(sampler_E_x2 - model.E_x2) / model.Var_x2 + return jnp.array([jnp.max(bsq), jnp.average(bsq)]) + + return observables, contract + + +def while_steps_num(cond): + if jnp.all(cond): + return len(cond) + else: + return jnp.argmin(cond) + 1 + + + +def logdensity_fn(x): + mu2 = 0.03 * (x[0] ** 2 - 100) + return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) + +def transform(x): + return x + +def sample_init(key): + z = jax.random.normal(key, shape=(2,)) + x0 = 10.0 * z[0] + x1 = 0.03 * (x0**2 - 100) + z[1] + return jnp.array([x0, x1]) + +num_chains = 128 + +mesh = jax.sharding.Mesh(devices=jax.devices(),axis_names= "chains") + +key_init, key_umclmc, key_mclmc = jax.random.split(jax.random.key(0), 3) + +integrator_coefficients = mclachlan_coefficients + +acc_prob = None + +# initialize the chains +initial_state = umclmc.initialize( + key_init, logdensity_fn, sample_init, num_chains, mesh +) + +diagonal_preconditioning = False +ndims = 2 + +alpha = 1.9 +C = 0.1 +r_end=5e-3 +ensemble_observables=lambda x: x + +# burn-in with the unadjusted method # +kernel = umclmc.build_kernel(logdensity_fn) +save_num = 20 # (int)(jnp.rint(save_frac * num_steps1)) +adap = umclmc.Adaptation( + ndims, + alpha=alpha, + bias_type=3, + save_num=save_num, + C=C, + power=3.0 / 8.0, + r_end=r_end, + observables_for_bias=lambda position: jnp.square( + transform(jax.flatten_util.ravel_pytree(position)[0]) + ), +) + +final_state, final_adaptation_state, info1 = run_eca( + key_umclmc, + initial_state, + kernel, + adap, + 100, + num_chains, + mesh, + 
ensemble_observables, + early_stop=True, + ) + +# refine the results with the adjusted method +_acc_prob = acc_prob +if integrator_coefficients is None: + high_dims = ndims > 200 + _integrator_coefficients = ( + omelyan_coefficients if high_dims else mclachlan_coefficients + ) + if acc_prob is None: + _acc_prob = 0.9 if high_dims else 0.7 + +else: + _integrator_coefficients = integrator_coefficients + if acc_prob is None: + _acc_prob = 0.9 + +integrator = generate_isokinetic_integrator(_integrator_coefficients) +gradient_calls_per_step = ( + len(_integrator_coefficients) // 2 +) # scheme = BABAB..AB scheme has len(scheme)//2 + 1 Bs. The last doesn't count because that gradient can be reused in the next step. + +if diagonal_preconditioning: + inverse_mass_matrix = jnp.sqrt(final_adaptation_state.inverse_mass_matrix) + + # scale the stepsize so that it reflects averag scale change of the preconditioning + average_scale_change = jnp.sqrt(jnp.average(inverse_mass_matrix)) + final_adaptation_state = final_adaptation_state._replace( + step_size=final_adaptation_state.step_size / average_scale_change + ) + +else: + inverse_mass_matrix = 1.0 + +kernel = build_kernel( + logdensity_fn, integrator, inverse_mass_matrix=inverse_mass_matrix +) +steps_per_sample = 15 +num_steps2 = 100 + + +initial_state = HMCState( + final_state.position, final_state.logdensity, final_state.logdensity_grad + ) + +print(initial_state.position.shape, "bar\n\n") + +# pos = jax.random.normal(key_mclmc, shape=(num_chains, ndims)) + + + +# print("baz", logdensity_fn(pos)) + +# initial_state = HMCState( +# pos, logdensity_fn(pos[0]), jax.grad(logdensity_fn)(pos[0]) +# ) + + +num_samples = num_steps2 // (gradient_calls_per_step * steps_per_sample) +num_adaptation_samples = ( + num_samples // 2 +) # number of samples after which the stepsize is fixed. 
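The Adaptation instantiated just below drives the step size with bisection_monotonic_fn(acc_prob_target), whose update is invoked in Adaptation.update as (stepsize_adaptation_state, step_size, acc_prob) -> (stepsize_adaptation_state, step_size), starting from the state (jnp.array([-inf, inf]), False). As a rough, illustrative stand-in for that calling convention (not the blackjax implementation), a monotone bisection on the log step size could look like this:

import jax
import jax.numpy as jnp

def bisection_monotonic_sketch(acc_prob_target):
    """Illustrative stand-in: assumes acceptance decreases monotonically with step size."""

    def update(state, step_size, acc_prob):
        bounds, _ = state                       # log-step-size bracket [lower, upper]
        log_eps = jnp.log(step_size)
        too_large = acc_prob < acc_prob_target  # step too big -> acceptance too low
        # tighten the bracket on the side we just probed
        bounds = jax.lax.cond(
            too_large,
            lambda b: b.at[1].set(jnp.minimum(b[1], log_eps)),
            lambda b: b.at[0].set(jnp.maximum(b[0], log_eps)),
            bounds,
        )
        bracketed = jnp.all(jnp.isfinite(bounds))
        # bisect once both sides are known, otherwise move geometrically
        new_log_eps = jnp.where(
            bracketed,
            0.5 * (bounds[0] + bounds[1]),
            log_eps + jnp.where(too_large, -jnp.log(2.0), jnp.log(2.0)),
        )
        return (bounds, bracketed), jnp.exp(new_log_eps)

    return update

Bracketing in log space keeps the proposed step size positive, and halving/doubling before a bracket exists is a standard warm-up for a monotone bisection.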
+ +adap = Adaptation( + final_adaptation_state, + num_adaptation_samples, + steps_per_sample, + _acc_prob, +) + + + +final_state, final_adaptation_state, info2 = run_eca( + key_mclmc, + initial_state, + kernel, + adap, + num_samples, + num_chains, + mesh, + ensemble_observables, +) + diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 886cdce0d..57c8aedf7 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -5,6 +5,7 @@ import chex import jax +# jax.config.update("jax_traceback_filtering", "off") import jax.numpy as jnp import jax.scipy.stats as stats import numpy as np @@ -299,7 +300,7 @@ def run_emaus( key, diagonal_preconditioning, ): - mesh = jax.sharding.Mesh(jax.devices(), "chains") + mesh = jax.sharding.Mesh(devices=jax.devices(),axis_names= "chains") from blackjax.mcmc.integrators import mclachlan_coefficients From f7f8d86ed633e431872154b6a5e47e16bb668bed Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Thu, 6 Mar 2025 08:12:14 -0800 Subject: [PATCH 25/34] wip --- blackjax/adaptation/ensemble_mclmc.py | 18 ++++++++++-------- blackjax/adaptation/ensemble_umclmc.py | 4 ++-- blackjax/adaptation/mclmc_adaptation.py | 4 +++- blackjax/util.py | 10 +++++++++- 4 files changed, 24 insertions(+), 12 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index bbc1bb5f7..cc5ff60c9 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -65,7 +65,7 @@ def __init__( acc_prob_target=0.8, observables=lambda x: 0.0, # just for diagnostics: some function of a given chain at given timestep observables_for_bias=lambda x: 0.0, # just for diagnostics: the above, but averaged over all chains - contract=lambda x: 0.0, # just for diagnostics: observabiels for bias, contracted over dimensions + contract=lambda x: 0.0, # just for diagnostics: observables for bias, contracted over dimensions ): self.num_adaptation_samples = num_adaptation_samples self.observables = observables @@ -106,7 +106,7 @@ def summary_statistics_fn(self, state, info, rng_key): def update(self, adaptation_state, Etheta): acc_prob = Etheta["acceptance_probability"] equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"]) - true_bias = self.contract(Etheta["observables_for_bias"]) # remove + true_bias = self.contract(Etheta["observables_for_bias"]) info_to_be_stored = { "L": adaptation_state.step_size * adaptation_state.steps_per_sample, @@ -159,7 +159,6 @@ def while_steps_num(cond): def emaus( logdensity_fn, sample_init, - transform, ndims, num_steps1, num_steps2, @@ -175,9 +174,10 @@ def emaus( integrator_coefficients=None, steps_per_sample=10, acc_prob=None, - observables=lambda x: None, + observables_for_bias=lambda x: 0.0, ensemble_observables=None, diagnostics=True, + contract=lambda x: 0.0, ): """ model: the target density object @@ -210,7 +210,7 @@ def emaus( # burn-in with the unadjusted method # kernel = umclmc.build_kernel(logdensity_fn) - save_num = (int)(jnp.rint(save_frac * num_steps1)) + save_num = (jnp.rint(save_frac * num_steps1)).astype(int) adap = umclmc.Adaptation( ndims, alpha=alpha, @@ -219,9 +219,8 @@ def emaus( C=C, power=3.0 / 8.0, r_end=r_end, - observables_for_bias=lambda position: jnp.square( - transform(jax.flatten_util.ravel_pytree(position)[0]) - ), + observables_for_bias=observables_for_bias, + contract=contract, ) final_state, final_adaptation_state, info1 = run_eca( @@ -285,8 +284,11 @@ def emaus( num_adaptation_samples, 
steps_per_sample, _acc_prob, + contract=contract, + observables_for_bias=observables_for_bias, ) + final_state, final_adaptation_state, info2 = run_eca( key_mclmc, initial_state, diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 4068bee67..1f1bf518b 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -72,6 +72,7 @@ def initialize(rng_key, logdensity_fn, sample_init, num_chains, mesh): def sequential_init(key, x, args): """initialize the position using sample_init and the velocity along the gradient""" position = sample_init(key) + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) flat_g, unravel_fn = ravel_pytree(logdensity_grad) velocity = unravel_fn( @@ -82,9 +83,8 @@ def sequential_init(key, x, args): def summary_statistics_fn(state): """compute the diagonal elements of the equipartition matrix""" - return 0 # -state.position * state.logdensity_grad + return -state.position * state.logdensity_grad - # TODO: restore! def ensemble_init(key, state, signs): """flip the velocity, depending on the equipartition condition""" diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 60fd46359..16612c05c 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -50,6 +50,7 @@ def mclmc_find_L_and_step_size( desired_energy_var=5e-4, trust_in_estimate=1.5, num_effective_samples=150, + params=None, diagonal_preconditioning=True, ): """ @@ -105,7 +106,8 @@ def mclmc_find_L_and_step_size( ) """ dim = pytree_size(state.position) - params = MCLMCAdaptationState( + if params is None: + params = MCLMCAdaptationState( jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,)) ) part1_key, part2_key = jax.random.split(rng_key, 2) diff --git a/blackjax/util.py b/blackjax/util.py index ee71af2b9..4668befee 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -10,7 +10,7 @@ from jax.random import normal, split from jax.sharding import NamedSharding, PartitionSpec from jax.tree_util import tree_leaves, tree_map - +import jax from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -352,11 +352,14 @@ def _step(state_all, xs): adaptation_state, info_to_be_stored = adaptation_update( adaptation_state, Etheta ) + return (state, adaptation_state), info_to_be_stored + if ensemble_info is not None: + def step(state_all, xs): (state, adaptation_state), info_to_be_stored = _step(state_all, xs) return (state, adaptation_state), ( @@ -381,6 +384,7 @@ def run_eca( ensemble_info=None, early_stop=False, ): + """ Run ensemble chain adaptation (eca) in parallel on multiple devices. 
----------------------------------------------------- @@ -413,6 +417,7 @@ def all_steps(initial_state, keys_sampling, keys_adaptation): initial_state_all = (initial_state, adaptation.initial_state) + # run sampling xs = ( jnp.arange(num_steps), @@ -441,6 +446,8 @@ def step_while(a): else: final_state_all, info_history = lax.scan(step, initial_state_all, xs) + + final_state, final_adaptation_state = final_state_all return ( final_state, @@ -448,6 +455,7 @@ def step_while(a): info_history, ) # info history is composed of averages over all chains, so it is a couple of scalars + p, pscalar = PartitionSpec("chains"), PartitionSpec() parallel_execute = shard_map( all_steps, From a5eb4f4fc1a3bb3561c852822fbd5a5894191395 Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Thu, 6 Mar 2025 08:19:06 -0800 Subject: [PATCH 26/34] wip --- blackjax/adaptation/mclmc_adaptation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 2a2d13831..5a673414f 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -52,7 +52,6 @@ def mclmc_find_L_and_step_size( num_effective_samples=150, params=None, diagonal_preconditioning=True, - params=None, ): """ Finds the optimal value of the parameters for the MCLMC algorithm. From f40489803bf17e5da42fbba4aead158c50a24f89 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 10 Mar 2025 13:07:43 -0400 Subject: [PATCH 27/34] bug fix --- blackjax/adaptation/ensemble_mclmc.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 3edae76bd..ea5f51c48 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -285,6 +285,10 @@ def emaus( num_samples // 2 ) # number of samples after which the stepsize is fixed. + final_adaptation_state = final_adaptation_state._replace( + step_size=final_adaptation_state.step_size.item() + ) + adap = Adaptation( final_adaptation_state, num_adaptation_samples, @@ -292,8 +296,6 @@ def emaus( _acc_prob, ) - - final_state, final_adaptation_state, info2 = run_eca( key_mclmc, initial_state, From 9b00e2811e866948bc00e41f3d504bfcd4550ce3 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 10 Mar 2025 13:28:51 -0400 Subject: [PATCH 28/34] bug fix --- blackjax/adaptation/ensemble_mclmc.py | 10 +++------- blackjax/adaptation/ensemble_umclmc.py | 14 +++++++++----- blackjax/util.py | 10 +--------- tests/mcmc/test_sampling.py | 6 ++---- 4 files changed, 15 insertions(+), 25 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index e153f8a65..22de73887 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -101,9 +101,7 @@ def __init__( def summary_statistics_fn(self, state, info, rng_key): return { "acceptance_probability": info.acceptance_rate, - "equipartition_diagonal": equipartition_diagonal( - state - ), # metric for bias: equipartition theorem gives todo... 
+ "equipartition_diagonal": equipartition_diagonal(state), "observables": self.observables(state.position), "observables_for_bias": self.observables_for_bias(state.position), } @@ -111,7 +109,7 @@ def summary_statistics_fn(self, state, info, rng_key): def update(self, adaptation_state, Etheta): acc_prob = Etheta["acceptance_probability"] equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"]) - true_bias = self.contract(Etheta["observables_for_bias"]) + true_bias = self.contract(Etheta["observables_for_bias"]) info_to_be_stored = { "L": adaptation_state.step_size * adaptation_state.steps_per_sample, @@ -179,7 +177,7 @@ def emaus( integrator_coefficients=None, steps_per_sample=15, acc_prob=None, - observables_for_bias=lambda x: 0.0, + observables_for_bias=lambda x: x, ensemble_observables=None, diagnostics=True, contract=lambda x: 0.0, @@ -205,7 +203,6 @@ def emaus( diagnostics: whether to return diagnostics """ - # observables_for_bias, contract = bias(model) key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) # initialize the chains @@ -297,7 +294,6 @@ def emaus( observables_for_bias=observables_for_bias, ) - final_state, final_adaptation_state, info2 = run_eca( key_mclmc, initial_state, diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 1f1bf518b..3cc62ae00 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -72,7 +72,7 @@ def initialize(rng_key, logdensity_fn, sample_init, num_chains, mesh): def sequential_init(key, x, args): """initialize the position using sample_init and the velocity along the gradient""" position = sample_init(key) - + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) flat_g, unravel_fn = ravel_pytree(logdensity_grad) velocity = unravel_fn( @@ -83,8 +83,12 @@ def sequential_init(key, x, args): def summary_statistics_fn(state): """compute the diagonal elements of the equipartition matrix""" - return -state.position * state.logdensity_grad + flat_pos, unflatten = jax.flatten_util.ravel_pytree(state.position) + flat_g, unravel_fn = ravel_pytree(state.logdensity_grad) + return unravel_fn(-flat_pos * flat_g) + # return 0 + # -state.position # * state.logdensity_grad def ensemble_init(key, state, signs): """flip the velocity, depending on the equipartition condition""" @@ -113,7 +117,9 @@ def ensemble_init(key, state, signs): summary_statistics_fn=summary_statistics_fn, ) - signs = -2.0 * (equipartition < 1.0) + 1.0 + flat_equi, _ = ravel_pytree(equipartition) + + signs = -2.0 * (flat_equi < 1.0) + 1.0 initial_state, _ = ensemble_execute_fn( ensemble_init, key2, num_chains, mesh, x=initial_state, args=signs ) @@ -122,7 +128,6 @@ def ensemble_init(key, state, signs): def update_history(new_vals, history): - new_vals, _ = jax.flatten_util.ravel_pytree(new_vals) return jnp.concatenate((new_vals[None, :], history[:-1, :])) @@ -258,7 +263,6 @@ def update(self, adaptation_state, Etheta): history_observables = update_history( Etheta["observables_for_bias"], adaptation_state.history.observables ) - # history_observables = adaptation_state.history.observables history_weights = update_history_scalar(1.0, adaptation_state.history.weights) fluctuations = contract_history(history_observables, history_weights) diff --git a/blackjax/util.py b/blackjax/util.py index 4668befee..ee71af2b9 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -10,7 +10,7 @@ from jax.random import normal, split from jax.sharding import NamedSharding, 
PartitionSpec from jax.tree_util import tree_leaves, tree_map -import jax + from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -352,14 +352,11 @@ def _step(state_all, xs): adaptation_state, info_to_be_stored = adaptation_update( adaptation_state, Etheta ) - return (state, adaptation_state), info_to_be_stored - if ensemble_info is not None: - def step(state_all, xs): (state, adaptation_state), info_to_be_stored = _step(state_all, xs) return (state, adaptation_state), ( @@ -384,7 +381,6 @@ def run_eca( ensemble_info=None, early_stop=False, ): - """ Run ensemble chain adaptation (eca) in parallel on multiple devices. ----------------------------------------------------- @@ -417,7 +413,6 @@ def all_steps(initial_state, keys_sampling, keys_adaptation): initial_state_all = (initial_state, adaptation.initial_state) - # run sampling xs = ( jnp.arange(num_steps), @@ -446,8 +441,6 @@ def step_while(a): else: final_state_all, info_history = lax.scan(step, initial_state_all, xs) - - final_state, final_adaptation_state = final_state_all return ( final_state, @@ -455,7 +448,6 @@ def step_while(a): info_history, ) # info history is composed of averages over all chains, so it is a couple of scalars - p, pscalar = PartitionSpec("chains"), PartitionSpec() parallel_execute = shard_map( all_steps, diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 57c8aedf7..d2cbd1501 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -5,6 +5,7 @@ import chex import jax + # jax.config.update("jax_traceback_filtering", "off") import jax.numpy as jnp import jax.scipy.stats as stats @@ -296,11 +297,10 @@ def run_emaus( sample_init, logdensity_fn, ndims, - transform, key, diagonal_preconditioning, ): - mesh = jax.sharding.Mesh(devices=jax.devices(),axis_names= "chains") + mesh = jax.sharding.Mesh(devices=jax.devices(), axis_names="chains") from blackjax.mcmc.integrators import mclachlan_coefficients @@ -309,7 +309,6 @@ def run_emaus( info, grads_per_step, _acc_prob, final_state = emaus( logdensity_fn=logdensity_fn, sample_init=sample_init, - transform=transform, ndims=ndims, num_steps1=100, num_steps2=300, @@ -602,7 +601,6 @@ def sample_init(key): samples = self.run_emaus( sample_init=sample_init, logdensity_fn=logdensity_fn, - transform=lambda x: x, ndims=2, key=inference_key, diagonal_preconditioning=True, From f35f98eaffa8dc8ffbc112a4dbf5cf010f7ab7e6 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 10 Mar 2025 13:33:42 -0400 Subject: [PATCH 29/34] bug fix --- tests/mcmc/minimal_repro.py | 300 -------------------- tests/mcmc/minimal_repro_2.py | 419 --------------------------- tests/mcmc/minimal_repro_3.py | 514 ---------------------------------- 3 files changed, 1233 deletions(-) delete mode 100644 tests/mcmc/minimal_repro.py delete mode 100644 tests/mcmc/minimal_repro_2.py delete mode 100644 tests/mcmc/minimal_repro_3.py diff --git a/tests/mcmc/minimal_repro.py b/tests/mcmc/minimal_repro.py deleted file mode 100644 index 639b624fc..000000000 --- a/tests/mcmc/minimal_repro.py +++ /dev/null @@ -1,300 +0,0 @@ -import jax -import jax.numpy as jnp -from jax import device_put, jit, lax, vmap -from jax.experimental.shard_map import shard_map -from jax.flatten_util import ravel_pytree -from jax.random import normal, split -from jax.sharding import NamedSharding, PartitionSpec -from jax.tree_util import tree_leaves, tree_map - -import 
blackjax.adaptation.ensemble_umclmc as umclmc - - -def eca_step( - kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info=None -): - """ - Construct a single step of ensemble chain adaptation (eca) to be performed in parallel on multiple devices. - """ - - def _step(state_all, xs): - """This function operates on a single device.""" - ( - state, - adaptation_state, - ) = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. - ( - _, - keys_sampling, - key_adaptation, - ) = xs # keys_sampling.shape = (chains_per_device, ) - - # update the state of all chains on this device - state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) - - # combine all the chains to compute expectation values - theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) - Etheta = tree_map( - lambda theta: lax.psum(jnp.sum(theta, axis=0), axis_name="chains") - / num_chains, - theta, - ) - - # use these to adapt the hyperparameters of the dynamics - adaptation_state, info_to_be_stored = adaptation_update( - adaptation_state, Etheta - ) - - return (state, adaptation_state), info_to_be_stored - - if ensemble_info is not None: - - def step(state_all, xs): - (state, adaptation_state), info_to_be_stored = _step(state_all, xs) - return (state, adaptation_state), ( - info_to_be_stored, - vmap(ensemble_info)(state.position), - ) - - return step - - else: - return _step - - - - -def ensemble_execute_fn( - func, - rng_key, - num_chains, - mesh, - x=None, - args=None, - summary_statistics_fn=lambda y: 0.0, -): - """Given a sequential function - func(rng_key, x, args) = y, - evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. - Args: - x: array distributed over all decvices - args: additional arguments for func, not distributed. - summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. - rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. - - Returns: - y: array distributed over all decvices. Need not be of the same shape as x. - Etheta: expected values of the summary statistics - """ - p, pscalar = PartitionSpec("chains"), PartitionSpec() - - if x is None: - X = device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) - else: - X = x - - adaptation_update = lambda _, Etheta: (Etheta, None) - - _F = eca_step( - func, - lambda y, info, key: summary_statistics_fn(y), - adaptation_update, - num_chains, - ) - - def F(x, keys): - """This function operates on a single device. key is a random key for this device.""" - y, summary_statistics = _F((x, args), (None, keys, None))[0] - return y, summary_statistics - - parallel_execute = shard_map( - F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar), check_rep=False - ) - - keys = device_put( - split(rng_key, num_chains), NamedSharding(mesh, p) - ) # random keys, distributed across devices - # apply F in parallel - return parallel_execute(X, keys) - -def run_eca( - rng_key, - initial_state, - kernel, - adaptation, - num_steps, - num_chains, - mesh, - ensemble_info=None, - early_stop=False, -): - """ - Run ensemble chain adaptation (eca) in parallel on multiple devices. 
- ----------------------------------------------------- - Args: - rng_key: random key - initial_state: initial state of the system - kernel: kernel for the dynamics - adaptation: adaptation object - num_steps: number of steps to run - num_chains: number of chains - mesh: mesh for parallelization - ensemble_info: function that takes the state of the system and returns some information about the ensemble - early_stop: whether to stop early - Returns: - final_state: final state of the system - final_adaptation_state: final adaptation state - info_history: history of the information that was stored at each step (if early_stop is False, then this is None) - """ - - step = eca_step( - kernel, - adaptation.summary_statistics_fn, - adaptation.update, - num_chains, - ensemble_info, - ) - - def all_steps(initial_state, keys_sampling, keys_adaptation): - """This function operates on a single device. key is a random key for this device.""" - - initial_state_all = (initial_state, adaptation.initial_state) - - # run sampling - xs = ( - jnp.arange(num_steps), - keys_sampling.T, - keys_adaptation, - ) # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) - - # ((a, Int) -> (a, Int)) - def step_while(a): - x, i, _ = a - - auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) - - output, info = step(x, auxilliary_input) - - return (output, i + 1, info[0].get("while_cond")) - - if early_stop: - final_state_all, i, _ = lax.while_loop( - lambda a: ((a[1] < num_steps) & a[2]), - step_while, - (initial_state_all, 0, True), - ) - info_history = None - - else: - final_state_all, info_history = lax.scan(step, initial_state_all, xs) - - final_state, final_adaptation_state = final_state_all - return ( - final_state, - final_adaptation_state, - info_history, - ) # info history is composed of averages over all chains, so it is a couple of scalars - - p, pscalar = PartitionSpec("chains"), PartitionSpec() - parallel_execute = shard_map( - all_steps, - mesh=mesh, - in_specs=(p, p, pscalar), - out_specs=(p, pscalar, pscalar), - check_rep=False, - ) - - # produce all random keys that will be needed - - key_sampling, key_adaptation = split(rng_key) - num_steps = jnp.array(num_steps).item() - keys_adaptation = split(key_adaptation, num_steps) - distribute_keys = lambda key, shape: device_put( - split(key, shape), NamedSharding(mesh, p) - ) # random keys, distributed across devices - keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) - - # run sampling in parallel - final_state, final_adaptation_state, info_history = parallel_execute( - initial_state, keys_sampling, keys_adaptation - ) - - return final_state, final_adaptation_state, info_history - - -mesh = jax.sharding.Mesh(devices=jax.devices(),axis_names= "chains") - -key_init, key_umclmc, key_mclmc = jax.random.split(jax.random.key(0), 3) - -num_chains = 128 -ndims = 2 - -def logdensity_fn(x): - mu2 = 0.03 * (x[0] ** 2 - 100) - return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) - -def transform(x): - return x - -def sample_init(key): - z = jax.random.normal(key, shape=(2,)) - x0 = 10.0 * z[0] - x1 = 0.03 * (x0**2 - 100) + z[1] - return jnp.array([x0, x1]) - -# initialize the chains -initial_state = umclmc.initialize( - key_init, logdensity_fn, sample_init, num_chains, mesh -) - -alpha = 1.9 -C = 0.1 -r_end=5e-3 -ensemble_observables=lambda x: x - -# burn-in with the unadjusted method # -kernel = umclmc.build_kernel(logdensity_fn) -save_num = 20 # 
(int)(jnp.rint(save_frac * num_steps1)) -adap = umclmc.Adaptation( - ndims, - alpha=alpha, - bias_type=3, - save_num=save_num, - C=C, - power=3.0 / 8.0, - r_end=r_end, - observables_for_bias=lambda position: jnp.square( - transform(jax.flatten_util.ravel_pytree(position)[0]) - ), -) - - -final_state, final_adaptation_state, info1 = run_eca( - key_umclmc, - initial_state, - kernel, - adap, - 100, - num_chains, - mesh, - ensemble_observables, - early_stop=True, - ) - - -# a = jnp.array([8.0, 4.0]) - -# def f(rng_key, x, args): -# return x + normal(rng_key, x.shape) + a, a - -# out = ensemble_execute_fn( -# func = f, -# rng_key = jax.random.PRNGKey(0), -# num_chains = 4, -# mesh = mesh, -# x = None, -# args = None, -# summary_statistics_fn = lambda y: a, -# ) - -# print(out) \ No newline at end of file diff --git a/tests/mcmc/minimal_repro_2.py b/tests/mcmc/minimal_repro_2.py deleted file mode 100644 index c73f52681..000000000 --- a/tests/mcmc/minimal_repro_2.py +++ /dev/null @@ -1,419 +0,0 @@ -import jax -import jax.numpy as jnp -from jax import device_put, jit, lax, vmap -from jax.experimental.shard_map import shard_map -from jax.flatten_util import ravel_pytree -from jax.random import normal, split -from jax.sharding import NamedSharding, PartitionSpec -from jax.tree_util import tree_leaves, tree_map - -from blackjax.util import run_eca - -import blackjax.adaptation.ensemble_umclmc as umclmc - - -# def eca_step( -# kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info=None -# ): -# """ -# Construct a single step of ensemble chain adaptation (eca) to be performed in parallel on multiple devices. -# """ - -# def _step(state_all, xs): -# """This function operates on a single device.""" -# ( -# state, -# adaptation_state, -# ) = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. -# ( -# _, -# keys_sampling, -# key_adaptation, -# ) = xs # keys_sampling.shape = (chains_per_device, ) - -# # update the state of all chains on this device -# state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) - -# # combine all the chains to compute expectation values -# theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) -# Etheta = tree_map( -# lambda theta: lax.psum(jnp.sum(theta, axis=0), axis_name="chains") -# / num_chains, -# theta, -# ) - -# # use these to adapt the hyperparameters of the dynamics -# adaptation_state, info_to_be_stored = adaptation_update( -# adaptation_state, Etheta -# ) - -# return (state, adaptation_state), info_to_be_stored - -# if ensemble_info is not None: - -# def step(state_all, xs): -# (state, adaptation_state), info_to_be_stored = _step(state_all, xs) -# return (state, adaptation_state), ( -# info_to_be_stored, -# vmap(ensemble_info)(state.position), -# ) - -# return step - -# else: -# return _step - - - - -# def ensemble_execute_fn( -# func, -# rng_key, -# num_chains, -# mesh, -# x=None, -# args=None, -# summary_statistics_fn=lambda y: 0.0, -# ): -# """Given a sequential function -# func(rng_key, x, args) = y, -# evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. -# Args: -# x: array distributed over all decvices -# args: additional arguments for func, not distributed. -# summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. 
-# rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. - -# Returns: -# y: array distributed over all decvices. Need not be of the same shape as x. -# Etheta: expected values of the summary statistics -# """ -# p, pscalar = PartitionSpec("chains"), PartitionSpec() - -# if x is None: -# X = device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) -# else: -# X = x - -# adaptation_update = lambda _, Etheta: (Etheta, None) - -# _F = eca_step( -# func, -# lambda y, info, key: summary_statistics_fn(y), -# adaptation_update, -# num_chains, -# ) - -# def F(x, keys): -# """This function operates on a single device. key is a random key for this device.""" -# y, summary_statistics = _F((x, args), (None, keys, None))[0] -# return y, summary_statistics - -# parallel_execute = shard_map( -# F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar), check_rep=False -# ) - -# keys = device_put( -# split(rng_key, num_chains), NamedSharding(mesh, p) -# ) # random keys, distributed across devices -# # apply F in parallel -# return parallel_execute(X, keys) - -# def run_eca( -# rng_key, -# initial_state, -# kernel, -# adaptation, -# num_steps, -# num_chains, -# mesh, -# ensemble_info=None, -# early_stop=False, -# ): -# """ -# Run ensemble chain adaptation (eca) in parallel on multiple devices. -# ----------------------------------------------------- -# Args: -# rng_key: random key -# initial_state: initial state of the system -# kernel: kernel for the dynamics -# adaptation: adaptation object -# num_steps: number of steps to run -# num_chains: number of chains -# mesh: mesh for parallelization -# ensemble_info: function that takes the state of the system and returns some information about the ensemble -# early_stop: whether to stop early -# Returns: -# final_state: final state of the system -# final_adaptation_state: final adaptation state -# info_history: history of the information that was stored at each step (if early_stop is False, then this is None) -# """ - -# step = eca_step( -# kernel, -# adaptation.summary_statistics_fn, -# adaptation.update, -# num_chains, -# ensemble_info, -# ) - -# def all_steps(initial_state, keys_sampling, keys_adaptation): -# """This function operates on a single device. key is a random key for this device.""" - -# initial_state_all = (initial_state, adaptation.initial_state) - -# # run sampling -# xs = ( -# jnp.arange(num_steps), -# keys_sampling.T, -# keys_adaptation, -# ) # keys for all steps that will be performed. 
keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) - -# # ((a, Int) -> (a, Int)) -# def step_while(a): -# x, i, _ = a - -# auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) - -# output, info = step(x, auxilliary_input) - -# return (output, i + 1, info[0].get("while_cond")) - -# if early_stop: -# final_state_all, i, _ = lax.while_loop( -# lambda a: ((a[1] < num_steps) & a[2]), -# step_while, -# (initial_state_all, 0, True), -# ) -# info_history = None - -# else: -# final_state_all, info_history = lax.scan(step, initial_state_all, xs) - -# final_state, final_adaptation_state = final_state_all -# return ( -# final_state, -# final_adaptation_state, -# info_history, -# ) # info history is composed of averages over all chains, so it is a couple of scalars - -# p, pscalar = PartitionSpec("chains"), PartitionSpec() -# parallel_execute = shard_map( -# all_steps, -# mesh=mesh, -# in_specs=(p, p, pscalar), -# out_specs=(p, pscalar, pscalar), -# check_rep=False, -# ) - -# # produce all random keys that will be needed - -# key_sampling, key_adaptation = split(rng_key) -# num_steps = jnp.array(num_steps).item() -# keys_adaptation = split(key_adaptation, num_steps) -# distribute_keys = lambda key, shape: device_put( -# split(key, shape), NamedSharding(mesh, p) -# ) # random keys, distributed across devices -# keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) - -# # run sampling in parallel -# final_state, final_adaptation_state, info_history = parallel_execute( -# initial_state, keys_sampling, keys_adaptation -# ) - -# return final_state, final_adaptation_state, info_history - - -mesh = jax.sharding.Mesh(devices=jax.devices(),axis_names= "chains") - -# key_init, key_umclmc, key_mclmc = jax.random.split(jax.random.key(0), 3) - -num_chains = 128 -ndims = 2 - -def logdensity_fn(x): - mu2 = 0.03 * (x[0] ** 2 - 100) - return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) - -def transform(x): - return x - -def sample_init(key): - z = jax.random.normal(key, shape=(2,)) - x0 = 10.0 * z[0] - x1 = 0.03 * (x0**2 - 100) + z[1] - return jnp.array([x0, x1]) - -# # initialize the chains -# initial_state = umclmc.initialize( -# key_init, logdensity_fn, sample_init, num_chains, mesh -# ) - -# alpha = 1.9 -# C = 0.1 -# r_end=5e-3 -# ensemble_observables=lambda x: x - -# # burn-in with the unadjusted method # -# kernel = umclmc.build_kernel(logdensity_fn) -# save_num = 20 # (int)(jnp.rint(save_frac * num_steps1)) -# adap = umclmc.Adaptation( -# ndims, -# alpha=alpha, -# bias_type=3, -# save_num=save_num, -# C=C, -# power=3.0 / 8.0, -# r_end=r_end, -# observables_for_bias=lambda position: jnp.square( -# transform(jax.flatten_util.ravel_pytree(position)[0]) -# ), -# ) - - -# final_state, final_adaptation_state, info1 = run_eca( -# key_umclmc, -# initial_state, -# kernel, -# adap, -# 100, -# num_chains, -# mesh, -# ensemble_observables, -# early_stop=True, -# ) - -from blackjax.mcmc.integrators import mclachlan_coefficients - -import sys -# sys.path.append(".") -# sys.path.append("../") -from blackjax.adaptation.ensemble_mclmc import emaus -# from blackjax.mcmc.alternate_emaus import emaus - - -# def emaus( -# logdensity_fn, -# sample_init, -# transform, -# ndims, -# num_steps1, -# num_steps2, -# num_chains, -# mesh, -# rng_key, -# alpha=1.9, -# save_frac=0.2, -# C=0.1, -# early_stop=True, -# r_end=5e-3, -# diagonal_preconditioning=True, -# integrator_coefficients=None, -# steps_per_sample=15, -# acc_prob=None, -# observables=lambda x: None, -# 
ensemble_observables=None, -# diagnostics=True, -# ): -# """ -# model: the target density object -# num_steps1: number of steps in the first phase -# num_steps2: number of steps in the second phase -# num_chains: number of chains -# mesh: the mesh object, used for distributing the computation across cpus and nodes -# rng_key: the random key -# alpha: L = sqrt{d}*alpha*variances -# save_frac: the fraction of samples used to estimate the fluctuation in the first phase -# C: constant in stage 1 that determines step size (eq (9) of EMAUS paper) -# early_stop: whether to stop the first phase early -# r_end -# diagonal_preconditioning: whether to use diagonal preconditioning -# integrator_coefficients: the coefficients of the integrator -# steps_per_sample: the number of steps per sample -# acc_prob: the acceptance probability -# observables: the observables (for diagnostic use) -# ensemble_observables: observable calculated over the ensemble (for diagnostic use) -# diagnostics: whether to return diagnostics -# """ - -# # observables_for_bias, contract = bias(model) -# key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) - -# # initialize the chains -# initial_state = umclmc.initialize( -# key_init, logdensity_fn, sample_init, num_chains, mesh -# ) - -# # burn-in with the unadjusted method # -# kernel = umclmc.build_kernel(logdensity_fn) -# save_num = (int)(jnp.rint(save_frac * num_steps1)) -# adap = umclmc.Adaptation( -# ndims, -# alpha=alpha, -# bias_type=3, -# save_num=save_num, -# C=C, -# power=3.0 / 8.0, -# r_end=r_end, -# observables_for_bias=lambda position: jnp.square( -# transform(jax.flatten_util.ravel_pytree(position)[0]) -# ), -# ) - -# final_state, final_adaptation_state, info1 = run_eca( -# key_umclmc, -# initial_state, -# kernel, -# adap, -# num_steps1, -# num_chains, -# mesh, -# ensemble_observables, -# early_stop=early_stop, -# ) - -key = jax.random.key(0) - -emaus( - logdensity_fn=logdensity_fn, - sample_init=sample_init, - transform=transform, - ndims=ndims, - num_steps1=100, - num_steps2=300, - num_chains=num_chains, - mesh=mesh, - rng_key=key, - alpha=1.9, - C=0.1, - early_stop=1, - r_end=1e-2, - diagonal_preconditioning=True, - integrator_coefficients=mclachlan_coefficients, - steps_per_sample=15, - acc_prob=None, - ensemble_observables=lambda x: x, - # adap=adap, - # kernel=kernel, - # initial_state=initial_state, - # key_umclmc=key_umclmc, - # ensemble_observables = lambda x: vec @ x - ) # run the algorithm - - -# a = jnp.array([8.0, 4.0]) - -# def f(rng_key, x, args): -# return x + normal(rng_key, x.shape) + a, a - -# out = ensemble_execute_fn( -# func = f, -# rng_key = jax.random.PRNGKey(0), -# num_chains = 4, -# mesh = mesh, -# x = None, -# args = None, -# summary_statistics_fn = lambda y: a, -# ) - -# print(out) \ No newline at end of file diff --git a/tests/mcmc/minimal_repro_3.py b/tests/mcmc/minimal_repro_3.py deleted file mode 100644 index 56e08caeb..000000000 --- a/tests/mcmc/minimal_repro_3.py +++ /dev/null @@ -1,514 +0,0 @@ - - -from typing import Any, NamedTuple - -import jax -import jax.numpy as jnp - -import blackjax.adaptation.ensemble_umclmc as umclmc -from blackjax.adaptation.ensemble_umclmc import ( - equipartition_diagonal, - equipartition_diagonal_loss, -) -from blackjax.adaptation.step_size import bisection_monotonic_fn -from blackjax.mcmc.adjusted_mclmc import build_kernel as build_kernel_malt -from blackjax.mcmc.hmc import HMCState -from blackjax.mcmc.integrators import ( - generate_isokinetic_integrator, - mclachlan_coefficients, - 
omelyan_coefficients, -) -import jax -import jax.numpy as jnp -from jax import device_put, jit, lax, vmap -from jax.experimental.shard_map import shard_map -from jax.flatten_util import ravel_pytree -from jax.random import normal, split -from jax.sharding import NamedSharding, PartitionSpec -from jax.tree_util import tree_leaves, tree_map - -import blackjax.adaptation.ensemble_umclmc as umclmc - - -def eca_step( - kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info=None -): - """ - Construct a single step of ensemble chain adaptation (eca) to be performed in parallel on multiple devices. - """ - - def _step(state_all, xs): - """This function operates on a single device.""" - ( - state, - adaptation_state, - ) = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. - ( - _, - keys_sampling, - key_adaptation, - ) = xs # keys_sampling.shape = (chains_per_device, ) - - # update the state of all chains on this device - state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) - - # combine all the chains to compute expectation values - theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) - Etheta = tree_map( - lambda theta: lax.psum(jnp.sum(theta, axis=0), axis_name="chains") - / num_chains, - theta, - ) - - # use these to adapt the hyperparameters of the dynamics - adaptation_state, info_to_be_stored = adaptation_update( - adaptation_state, Etheta - ) - - return (state, adaptation_state), info_to_be_stored - - if ensemble_info is not None: - - def step(state_all, xs): - (state, adaptation_state), info_to_be_stored = _step(state_all, xs) - return (state, adaptation_state), ( - info_to_be_stored, - vmap(ensemble_info)(state.position), - ) - - return step - - else: - return _step - - - - -def ensemble_execute_fn( - func, - rng_key, - num_chains, - mesh, - x=None, - args=None, - summary_statistics_fn=lambda y: 0.0, -): - """Given a sequential function - func(rng_key, x, args) = y, - evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. - Args: - x: array distributed over all decvices - args: additional arguments for func, not distributed. - summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. - rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. - - Returns: - y: array distributed over all decvices. Need not be of the same shape as x. - Etheta: expected values of the summary statistics - """ - p, pscalar = PartitionSpec("chains"), PartitionSpec() - - if x is None: - X = device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) - else: - X = x - - adaptation_update = lambda _, Etheta: (Etheta, None) - - _F = eca_step( - func, - lambda y, info, key: summary_statistics_fn(y), - adaptation_update, - num_chains, - ) - - def F(x, keys): - """This function operates on a single device. 
key is a random key for this device.""" - y, summary_statistics = _F((x, args), (None, keys, None))[0] - return y, summary_statistics - - parallel_execute = shard_map( - F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar), check_rep=False - ) - - keys = device_put( - split(rng_key, num_chains), NamedSharding(mesh, p) - ) # random keys, distributed across devices - # apply F in parallel - return parallel_execute(X, keys) - -def run_eca( - rng_key, - initial_state, - kernel, - adaptation, - num_steps, - num_chains, - mesh, - ensemble_info=None, - early_stop=False, -): - """ - Run ensemble chain adaptation (eca) in parallel on multiple devices. - ----------------------------------------------------- - Args: - rng_key: random key - initial_state: initial state of the system - kernel: kernel for the dynamics - adaptation: adaptation object - num_steps: number of steps to run - num_chains: number of chains - mesh: mesh for parallelization - ensemble_info: function that takes the state of the system and returns some information about the ensemble - early_stop: whether to stop early - Returns: - final_state: final state of the system - final_adaptation_state: final adaptation state - info_history: history of the information that was stored at each step (if early_stop is False, then this is None) - """ - - step = eca_step( - kernel, - adaptation.summary_statistics_fn, - adaptation.update, - num_chains, - ensemble_info, - ) - - def all_steps(initial_state, keys_sampling, keys_adaptation): - """This function operates on a single device. key is a random key for this device.""" - - initial_state_all = (initial_state, adaptation.initial_state) - - # run sampling - xs = ( - jnp.arange(num_steps), - keys_sampling.T, - keys_adaptation, - ) # keys for all steps that will be performed. 
keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) - - # ((a, Int) -> (a, Int)) - def step_while(a): - x, i, _ = a - - auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) - - output, info = step(x, auxilliary_input) - - return (output, i + 1, info[0].get("while_cond")) - - if early_stop: - final_state_all, i, _ = lax.while_loop( - lambda a: ((a[1] < num_steps) & a[2]), - step_while, - (initial_state_all, 0, True), - ) - info_history = None - - else: - final_state_all, info_history = lax.scan(step, initial_state_all, xs) - - final_state, final_adaptation_state = final_state_all - return ( - final_state, - final_adaptation_state, - info_history, - ) # info history is composed of averages over all chains, so it is a couple of scalars - - p, pscalar = PartitionSpec("chains"), PartitionSpec() - parallel_execute = shard_map( - all_steps, - mesh=mesh, - in_specs=(p, p, pscalar), - out_specs=(p, pscalar, pscalar), - check_rep=False, - ) - - # produce all random keys that will be needed - - key_sampling, key_adaptation = split(rng_key) - num_steps = jnp.array(num_steps).item() - keys_adaptation = split(key_adaptation, num_steps) - distribute_keys = lambda key, shape: device_put( - split(key, shape), NamedSharding(mesh, p) - ) # random keys, distributed across devices - keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) - - # run sampling in parallel - final_state, final_adaptation_state, info_history = parallel_execute( - initial_state, keys_sampling, keys_adaptation - ) - - return final_state, final_adaptation_state, info_history - -# from blackjax.util import run_eca - - - -class AdaptationState(NamedTuple): - steps_per_sample: float - step_size: float - stepsize_adaptation_state: ( - Any # the state of the bisection algorithm to find a stepsize - ) - iteration: int - - -build_kernel = lambda logdensity_fn, integrator, inverse_mass_matrix: lambda key, state, adap: build_kernel_malt( - logdensity_fn=logdensity_fn, - integrator=integrator, - inverse_mass_matrix=inverse_mass_matrix, -)( - rng_key=key, - state=state, - step_size=adap.step_size, - num_integration_steps=adap.steps_per_sample, - L_proposal_factor=1.25, -) - - -class Adaptation: - def __init__( - self, - adaptation_state, - num_adaptation_samples, # amount of tuning in the adjusted phase before fixing params - steps_per_sample=15, # L/eps - acc_prob_target=0.8, - observables=lambda x: 0.0, # just for diagnostics: some function of a given chain at given timestep - observables_for_bias=lambda x: 0.0, # just for diagnostics: the above, but averaged over all chains - contract=lambda x: 0.0, # just for diagnostics: observabiels for bias, contracted over dimensions - ): - self.num_adaptation_samples = num_adaptation_samples - self.observables = observables - self.observables_for_bias = observables_for_bias - self.contract = contract - - # Determine the initial hyperparameters # - - # stepsize # - # if we switched to the more accurate integrator we can use longer step size - # integrator_factor = jnp.sqrt(10.) if mclachlan else 1. - # Let's use the stepsize which will be optimal for the adjusted method. The energy variance after N steps scales as sigma^2 ~ N^2 eps^6 = eps^4 L^2 - # In the adjusted method we want sigma^2 = 2 mu = 2 * 0.41 = 0.82 - # With the current eps, we had sigma^2 = EEVPD * d for N = 1. 
- # Combining the two we have EEVPD * d / 0.82 = eps^6 / eps_new^4 L^2 - # adjustment_factor = jnp.power(0.82 / (ndims * adaptation_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) - step_size = adaptation_state.step_size - - # Initialize the bisection for finding the step size - self.epsadap_update = bisection_monotonic_fn(acc_prob_target) - stepsize_adaptation_state = (jnp.array([-jnp.inf, jnp.inf]), False) - - self.initial_state = AdaptationState( - steps_per_sample, step_size, stepsize_adaptation_state, 0 - ) - - def summary_statistics_fn(self, state, info, rng_key): - return { - "acceptance_probability": info.acceptance_rate, - "equipartition_diagonal": equipartition_diagonal( - state - ), # metric for bias: equipartition theorem gives todo... - "observables": self.observables(state.position), - "observables_for_bias": self.observables_for_bias(state.position), - } - - def update(self, adaptation_state, Etheta): - acc_prob = Etheta["acceptance_probability"] - equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"]) - true_bias = self.contract(Etheta["observables_for_bias"]) # remove - - info_to_be_stored = { - "L": adaptation_state.step_size * adaptation_state.steps_per_sample, - "steps_per_sample": adaptation_state.steps_per_sample, - "step_size": adaptation_state.step_size, - "acc_prob": acc_prob, - "equi_diag": equi_diag, - "bias": true_bias, - "observables": Etheta["observables"], - } - - # Bisection to find step size - stepsize_adaptation_state, step_size = self.epsadap_update( - adaptation_state.stepsize_adaptation_state, - adaptation_state.step_size, - acc_prob, - ) - - return ( - AdaptationState( - adaptation_state.steps_per_sample, - step_size, - stepsize_adaptation_state, - adaptation_state.iteration + 1, - ), - info_to_be_stored, - ) - - -def bias(model): - """should be transfered to benchmarks/""" - - def observables(position): - return jnp.square(model.transform(position)) - - def contract(sampler_E_x2): - bsq = jnp.square(sampler_E_x2 - model.E_x2) / model.Var_x2 - return jnp.array([jnp.max(bsq), jnp.average(bsq)]) - - return observables, contract - - -def while_steps_num(cond): - if jnp.all(cond): - return len(cond) - else: - return jnp.argmin(cond) + 1 - - - -def logdensity_fn(x): - mu2 = 0.03 * (x[0] ** 2 - 100) - return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) - -def transform(x): - return x - -def sample_init(key): - z = jax.random.normal(key, shape=(2,)) - x0 = 10.0 * z[0] - x1 = 0.03 * (x0**2 - 100) + z[1] - return jnp.array([x0, x1]) - -num_chains = 128 - -mesh = jax.sharding.Mesh(devices=jax.devices(),axis_names= "chains") - -key_init, key_umclmc, key_mclmc = jax.random.split(jax.random.key(0), 3) - -integrator_coefficients = mclachlan_coefficients - -acc_prob = None - -# initialize the chains -initial_state = umclmc.initialize( - key_init, logdensity_fn, sample_init, num_chains, mesh -) - -diagonal_preconditioning = False -ndims = 2 - -alpha = 1.9 -C = 0.1 -r_end=5e-3 -ensemble_observables=lambda x: x - -# burn-in with the unadjusted method # -kernel = umclmc.build_kernel(logdensity_fn) -save_num = 20 # (int)(jnp.rint(save_frac * num_steps1)) -adap = umclmc.Adaptation( - ndims, - alpha=alpha, - bias_type=3, - save_num=save_num, - C=C, - power=3.0 / 8.0, - r_end=r_end, - observables_for_bias=lambda position: jnp.square( - transform(jax.flatten_util.ravel_pytree(position)[0]) - ), -) - -final_state, final_adaptation_state, info1 = run_eca( - key_umclmc, - initial_state, - kernel, - adap, - 100, - num_chains, - mesh, - 
ensemble_observables, - early_stop=True, - ) - -# refine the results with the adjusted method -_acc_prob = acc_prob -if integrator_coefficients is None: - high_dims = ndims > 200 - _integrator_coefficients = ( - omelyan_coefficients if high_dims else mclachlan_coefficients - ) - if acc_prob is None: - _acc_prob = 0.9 if high_dims else 0.7 - -else: - _integrator_coefficients = integrator_coefficients - if acc_prob is None: - _acc_prob = 0.9 - -integrator = generate_isokinetic_integrator(_integrator_coefficients) -gradient_calls_per_step = ( - len(_integrator_coefficients) // 2 -) # scheme = BABAB..AB scheme has len(scheme)//2 + 1 Bs. The last doesn't count because that gradient can be reused in the next step. - -if diagonal_preconditioning: - inverse_mass_matrix = jnp.sqrt(final_adaptation_state.inverse_mass_matrix) - - # scale the stepsize so that it reflects averag scale change of the preconditioning - average_scale_change = jnp.sqrt(jnp.average(inverse_mass_matrix)) - final_adaptation_state = final_adaptation_state._replace( - step_size=final_adaptation_state.step_size / average_scale_change - ) - -else: - inverse_mass_matrix = 1.0 - -kernel = build_kernel( - logdensity_fn, integrator, inverse_mass_matrix=inverse_mass_matrix -) -steps_per_sample = 15 -num_steps2 = 100 - - -initial_state = HMCState( - final_state.position, final_state.logdensity, final_state.logdensity_grad - ) - -print(initial_state.position.shape, "bar\n\n") - -# pos = jax.random.normal(key_mclmc, shape=(num_chains, ndims)) - - - -# print("baz", logdensity_fn(pos)) - -# initial_state = HMCState( -# pos, logdensity_fn(pos[0]), jax.grad(logdensity_fn)(pos[0]) -# ) - - -num_samples = num_steps2 // (gradient_calls_per_step * steps_per_sample) -num_adaptation_samples = ( - num_samples // 2 -) # number of samples after which the stepsize is fixed. 
- -adap = Adaptation( - final_adaptation_state, - num_adaptation_samples, - steps_per_sample, - _acc_prob, -) - - - -final_state, final_adaptation_state, info2 = run_eca( - key_mclmc, - initial_state, - kernel, - adap, - num_samples, - num_chains, - mesh, - ensemble_observables, -) - From b55ab0df1666efc8757fd1999c98e2146dfbcf22 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 10 Mar 2025 13:36:39 -0400 Subject: [PATCH 30/34] bug fix --- blackjax/adaptation/ensemble_umclmc.py | 1 + blackjax/mcmc/alternate_emaus.py | 85 -------------------------- 2 files changed, 1 insertion(+), 85 deletions(-) delete mode 100644 blackjax/mcmc/alternate_emaus.py diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 3cc62ae00..d5df4ee92 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -128,6 +128,7 @@ def ensemble_init(key, state, signs): def update_history(new_vals, history): + new_vals, _ = jax.flatten_util.ravel_pytree(new_vals) return jnp.concatenate((new_vals[None, :], history[:-1, :])) diff --git a/blackjax/mcmc/alternate_emaus.py b/blackjax/mcmc/alternate_emaus.py deleted file mode 100644 index 6010bab73..000000000 --- a/blackjax/mcmc/alternate_emaus.py +++ /dev/null @@ -1,85 +0,0 @@ -import jax -import jax.numpy as jnp -from blackjax.util import run_eca -import blackjax.adaptation.ensemble_umclmc as umclmc - - -def emaus( - logdensity_fn, - sample_init, - transform, - ndims, - num_steps1, - num_steps2, - num_chains, - mesh, - rng_key, - alpha=1.9, - save_frac=0.2, - C=0.1, - early_stop=True, - r_end=5e-3, - diagonal_preconditioning=True, - integrator_coefficients=None, - steps_per_sample=15, - acc_prob=None, - observables=lambda x: None, - ensemble_observables=None, - diagnostics=True, -): - """ - model: the target density object - num_steps1: number of steps in the first phase - num_steps2: number of steps in the second phase - num_chains: number of chains - mesh: the mesh object, used for distributing the computation across cpus and nodes - rng_key: the random key - alpha: L = sqrt{d}*alpha*variances - save_frac: the fraction of samples used to estimate the fluctuation in the first phase - C: constant in stage 1 that determines step size (eq (9) of EMAUS paper) - early_stop: whether to stop the first phase early - r_end - diagonal_preconditioning: whether to use diagonal preconditioning - integrator_coefficients: the coefficients of the integrator - steps_per_sample: the number of steps per sample - acc_prob: the acceptance probability - observables: the observables (for diagnostic use) - ensemble_observables: observable calculated over the ensemble (for diagnostic use) - diagnostics: whether to return diagnostics - """ - - # observables_for_bias, contract = bias(model) - key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) - - # initialize the chains - initial_state = umclmc.initialize( - key_init, logdensity_fn, sample_init, num_chains, mesh - ) - - # burn-in with the unadjusted method # - kernel = umclmc.build_kernel(logdensity_fn) - save_num = (int)(jnp.rint(save_frac * num_steps1)) - adap = umclmc.Adaptation( - ndims, - alpha=alpha, - bias_type=3, - save_num=save_num, - C=C, - power=3.0 / 8.0, - r_end=r_end, - observables_for_bias=lambda position: jnp.square( - transform(jax.flatten_util.ravel_pytree(position)[0]) - ), - ) - - final_state, final_adaptation_state, info1 = run_eca( - key_umclmc, - initial_state, - kernel, - adap, - num_steps1, - num_chains, - mesh, - ensemble_observables, - 
From e6da5c2a951685ffc3c9683ed8195f2b496e72a3 Mon Sep 17 00:00:00 2001
From: Reuben Harry Cohn-Gordon
Date: Mon, 10 Mar 2025 12:19:58 -0700
Subject: [PATCH 31/34] changes

---
 blackjax/adaptation/ensemble_mclmc.py |  6 +++---
 blackjax/util.py                      | 12 ++++++++----
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py
index ebc87cd5a..8c16008b2 100644
--- a/blackjax/adaptation/ensemble_mclmc.py
+++ b/blackjax/adaptation/ensemble_mclmc.py
@@ -222,7 +222,7 @@ def emaus(
         contract=contract,
     )
 
-    final_state, final_adaptation_state, info1 = run_eca(
+    final_state, final_adaptation_state, info1, steps_done_phase_1 = run_eca(
         key_umclmc,
         initial_state,
         kernel,
@@ -288,7 +288,7 @@ def emaus(
     )
 
 
-    final_state, final_adaptation_state, info2 = run_eca(
+    final_state, final_adaptation_state, info2, steps_done_phase_2 = run_eca(
         key_mclmc,
         initial_state,
         kernel,
@@ -300,7 +300,7 @@ def emaus(
     )
 
     if diagnostics:
-        info = {"phase_1": info1, "phase_2": info2}
+        info = {"phase_1": {'steps_done' : steps_done_phase_1}, "phase_2": info2}
     else:
         info = None
 
diff --git a/blackjax/util.py b/blackjax/util.py
index 4668befee..2c87d0e8e 100644
--- a/blackjax/util.py
+++ b/blackjax/util.py
@@ -441,11 +441,12 @@ def step_while(a):
                 step_while,
                 (initial_state_all, 0, True),
             )
+            steps_done = i
            info_history = None
 
         else:
             final_state_all, info_history = lax.scan(step, initial_state_all, xs)
-
+            steps_done = num_steps
 
         final_state, final_adaptation_state = final_state_all
 
@@ -453,6 +454,7 @@ def step_while(a):
             final_state,
             final_adaptation_state,
             info_history,
+            steps_done
         )  # info history is composed of averages over all chains, so it is a couple of scalars
 
 
@@ -461,7 +463,7 @@ def step_while(a):
         all_steps,
         mesh=mesh,
         in_specs=(p, p, pscalar),
-        out_specs=(p, pscalar, pscalar),
+        out_specs=(p, pscalar, pscalar, pscalar),
         check_rep=False,
     )
 
@@ -476,11 +478,13 @@ def step_while(a):
     keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps))
 
     # run sampling in parallel
-    final_state, final_adaptation_state, info_history = parallel_execute(
+    final_state, final_adaptation_state, info_history, steps_done = parallel_execute(
         initial_state, keys_sampling, keys_adaptation
     )
 
-    return final_state, final_adaptation_state, info_history
+
+
+    return final_state, final_adaptation_state, info_history, steps_done
 
 
 def ensemble_execute_fn(

From 0bd14143170dad939b7d1c30a4c4df86381c5231 Mon Sep 17 00:00:00 2001
From: =
Date: Mon, 10 Mar 2025 15:24:24 -0400
Subject: [PATCH 32/34] bug fix

---
 blackjax/adaptation/ensemble_mclmc.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py
index cb763d94a..0b418c9e9 100644
--- a/blackjax/adaptation/ensemble_mclmc.py
+++ b/blackjax/adaptation/ensemble_mclmc.py
@@ -294,8 +294,7 @@ def emaus(
         observables_for_bias=observables_for_bias,
     )
 
-
-    final_state, final_adaptation_state, info2, steps_done_phase_2 = run_eca(
+    final_state, final_adaptation_state, info2, _ = run_eca(
         key_mclmc,
         initial_state,
         kernel,
@@ -307,7 +306,7 @@ def emaus(
     )
 
     if diagnostics:
-        info = {"phase_1": {'steps_done' : steps_done_phase_1}, "phase_2": info2}
+        info = {"phase_1": {"steps_done": steps_done_phase_1}, "phase_2": info2}
     else:
         info = None
 
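The two patches above change run_eca to also report how many steps were actually executed, which matters once the first phase can stop early: the step counter simply rides along in the while-loop carry and is returned next to the final state. A generic, self-contained sketch of that pattern (the cond/body functions here are stand-ins, not the ones in blackjax.util):

    import jax.numpy as jnp
    from jax import lax

    def run_until(step_fn, init_state, max_steps):
        # carry = (state, i, keep_going); stop when the step says so or the budget runs out
        def cond(carry):
            _, i, keep_going = carry
            return jnp.logical_and(keep_going, max_steps > i)

        def body(carry):
            state, i, _ = carry
            state, keep_going = step_fn(state, i)
            return state, i + 1, keep_going

        state, steps_done, _ = lax.while_loop(cond, body, (init_state, 0, True))
        return state, steps_done

    # toy step: halve x each iteration and stop once it is small
    final_x, steps_done = run_until(
        lambda x, i: (0.5 * x, x > 1e-3), jnp.float32(1.0), 100
    )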
From 13a375ca0ba1cb516992f8114b535945bd6da767 Mon Sep 17 00:00:00 2001
From: =
Date: Mon, 10 Mar 2025 15:29:01 -0400
Subject: [PATCH 33/34] bug fix

---
 blackjax/util.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/blackjax/util.py b/blackjax/util.py
index 91dcc01db..d6a2d15ca 100644
--- a/blackjax/util.py
+++ b/blackjax/util.py
@@ -442,14 +442,13 @@ def step_while(a):
         else:
             final_state_all, info_history = lax.scan(step, initial_state_all, xs)
             steps_done = num_steps
-
         final_state, final_adaptation_state = final_state_all
 
         return (
             final_state,
             final_adaptation_state,
             info_history,
-            steps_done
+            steps_done,
         )  # info history is composed of averages over all chains, so it is a couple of scalars
 
     p, pscalar = PartitionSpec("chains"), PartitionSpec()
@@ -476,8 +475,6 @@ def step_while(a):
         initial_state, keys_sampling, keys_adaptation
     )
 
-
-
     return final_state, final_adaptation_state, info_history, steps_done
 
 

From 3906c0e370aafc0e6710c5465f575e7031fc0879 Mon Sep 17 00:00:00 2001
From: =
Date: Tue, 1 Apr 2025 18:22:46 -0400
Subject: [PATCH 34/34] emaus diagnostics

---
 blackjax/adaptation/ensemble_mclmc.py |  2 +-
 blackjax/util.py                      | 40 ++++++++++++++++++++++++---
 2 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py
index 0b418c9e9..7b47b2164 100644
--- a/blackjax/adaptation/ensemble_mclmc.py
+++ b/blackjax/adaptation/ensemble_mclmc.py
@@ -306,7 +306,7 @@ def emaus(
     )
 
     if diagnostics:
-        info = {"phase_1": {"steps_done": steps_done_phase_1}, "phase_2": info2}
+        info = {"phase_1": info1, "phase_2": info2}
     else:
         info = None
 
diff --git a/blackjax/util.py b/blackjax/util.py
index d6a2d15ca..b074b79e7 100644
--- a/blackjax/util.py
+++ b/blackjax/util.py
@@ -420,15 +420,35 @@ def all_steps(initial_state, keys_sampling, keys_adaptation):
             keys_adaptation,
         )  # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, )
 
 
-        # ((a, Int) -> (a, Int))
+        EEVPD = jnp.zeros((num_steps,))
+        EEVPD_wanted = jnp.zeros((num_steps,))
+        L = jnp.zeros((num_steps,))
+        entropy = jnp.zeros((num_steps,))
+        equi_diag = jnp.zeros((num_steps,))
+        equi_full = jnp.zeros((num_steps,))
+        observables = jnp.zeros((num_steps,))
+        r_avg = jnp.zeros((num_steps,))
+        r_max = jnp.zeros((num_steps,))
+        step_size = jnp.zeros((num_steps,))
+
         def step_while(a):
             x, i, _ = a
             auxilliary_input = (xs[0][i], xs[1][i], xs[2][i])
-            output, info = step(x, auxilliary_input)
+            output, (info, pos) = step(x, auxilliary_input)
+            EEVPD.at[i].set(info.get("EEVPD"))
+            EEVPD_wanted.at[i].set(info.get("EEVPD_wanted"))
+            L.at[i].set(info.get("L"))
+            entropy.at[i].set(info.get("entropy"))
+            equi_diag.at[i].set(info.get("equi_diag"))
+            equi_full.at[i].set(info.get("equi_full"))
+            observables.at[i].set(info.get("observables"))
+            r_avg.at[i].set(info.get("r_avg"))
+            r_max.at[i].set(info.get("r_max"))
+            step_size.at[i].set(info.get("step_size"))
 
-            return (output, i + 1, info[0].get("while_cond"))
+            return (output, i + 1, info.get("while_cond"))
 
 
         if early_stop:
             final_state_all, i, _ = lax.while_loop(
@@ -437,7 +457,19 @@ def step_while(a):
                 (initial_state_all, 0, True),
             )
             steps_done = i
-            info_history = None
+            info_history = {
+                "EEVPD": EEVPD,
+                "EEVPD_wanted": EEVPD_wanted,
+                "L": L,
+                "entropy": entropy,
+                "equi_diag": equi_diag,
+                "equi_full": equi_full,
+                "observables": observables,
+                "r_avg": r_avg,
+                "r_max": r_max,
+                "step_size": step_size,
+                "steps_done": steps_done,
+            }
 
         else:
             final_state_all, info_history = lax.scan(step, initial_state_all, xs)
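Recording per-step diagnostics inside lax.while_loop relies on functional array updates: under jit, buffer.at[i].set(value) returns a new array rather than modifying the buffer in place, so for the per-step values to survive the loop the buffers have to travel in the while-loop carry. A minimal sketch of that pattern (a generic illustration with placeholder names, not the blackjax.util implementation):

    import jax.numpy as jnp
    from jax import lax

    def run_with_diagnostics(step_fn, init_state, num_steps):
        buffers = {
            "step_size": jnp.zeros(num_steps),
            "L": jnp.zeros(num_steps),
        }

        def body(carry):
            state, bufs, i, _ = carry
            state, info = step_fn(state, i)  # info: dict of scalars, incl. "while_cond"
            # .at[i].set returns a new array; keep the result in the carry,
            # otherwise the update is discarded
            bufs = {k: v.at[i].set(info[k]) for k, v in bufs.items()}
            return state, bufs, i + 1, info["while_cond"]

        def cond(carry):
            _, _, i, keep_going = carry
            return jnp.logical_and(keep_going, num_steps > i)

        state, buffers, steps_done, _ = lax.while_loop(
            cond, body, (init_state, buffers, 0, True)
        )
        return state, buffers, steps_done

    # toy usage: shrink a scalar state and stop once it is small
    final_x, diag, steps_done = run_with_diagnostics(
        lambda x, i: (0.5 * x, {"step_size": 0.1, "L": x, "while_cond": x > 1e-2}),
        jnp.float32(1.0),
        50,
    )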