66from operator import itemgetter
77import warnings
88
9+ import jax
910from jax import eval_shape , random , vmap
1011from jax .lax import stop_gradient
1112import jax .numpy as jnp
@@ -33,6 +34,9 @@ class ELBO:
3334
3435 :param num_particles: The number of particles/samples used to form the ELBO
3536 (gradient) estimators.
37+ :param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
38+ num_particles-many particles in parallel. If False use `jax.lax.map`.
39+ Defaults to True.
3640 """
3741
3842 """
@@ -42,8 +46,9 @@ class ELBO:
4246 """
4347 can_infer_discrete = False
4448
45- def __init__ (self , num_particles = 1 ):
49+ def __init__ (self , num_particles = 1 , vectorize_particles = True ):
4650 self .num_particles = num_particles
51+ self .vectorize_particles = vectorize_particles
4752
4853 def loss (self , rng_key , param_map , model , guide , * args , ** kwargs ):
4954 """
@@ -108,11 +113,11 @@ class Trace_ELBO(ELBO):
108113
109114 :param num_particles: The number of particles/samples used to form the ELBO
110115 (gradient) estimators.
116+ :param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
117+ num_particles-many particles in parallel. If False use `jax.lax.map`.
118+ Defaults to True.
111119 """
112120
113- def __init__ (self , num_particles = 1 ):
114- self .num_particles = num_particles
115-
116121 def loss_with_mutable_state (
117122 self , rng_key , param_map , model , guide , * args , ** kwargs
118123 ):
@@ -163,7 +168,10 @@ def single_particle_elbo(rng_key):
163168 return {"loss" : - elbo , "mutable_state" : mutable_state }
164169 else :
165170 rng_keys = random .split (rng_key , self .num_particles )
166- elbos , mutable_state = vmap (single_particle_elbo )(rng_keys )
171+ if self .vectorize_particles :
172+ elbos , mutable_state = vmap (single_particle_elbo )(rng_keys )
173+ else :
174+ elbos , mutable_state = jax .lax .map (single_particle_elbo , rng_keys )
167175 return {"loss" : - jnp .mean (elbos ), "mutable_state" : mutable_state }
168176
169177
@@ -291,7 +299,10 @@ def single_particle_elbo(rng_key):
291299 return {"loss" : - elbo , "mutable_state" : mutable_state }
292300 else :
293301 rng_keys = random .split (rng_key , self .num_particles )
294- elbos , mutable_state = vmap (single_particle_elbo )(rng_keys )
302+ if self .vectorize_particles :
303+ elbos , mutable_state = vmap (single_particle_elbo )(rng_keys )
304+ else :
305+ elbos , mutable_state = jax .lax .map (single_particle_elbo , rng_keys )
295306 return {"loss" : - jnp .mean (elbos ), "mutable_state" : mutable_state }
296307
297308
@@ -311,6 +322,9 @@ class RenyiELBO(ELBO):
311322 Here :math:`\alpha \neq 1`. Default is 0.
312323 :param num_particles: The number of particles/samples
313324 used to form the objective (gradient) estimator. Default is 2.
325+ :param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
326+ num_particles-many particles in parallel. If False use `jax.lax.map`.
327+ Defaults to True.
314328
315329 Example::
316330
@@ -427,7 +441,10 @@ def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
427441 )
428442
429443 rng_keys = random .split (rng_key , self .num_particles )
430- elbos , common_plate_scale = vmap (single_particle_elbo )(rng_keys )
444+ if self .vectorize_particles :
445+ elbos , common_plate_scale = vmap (single_particle_elbo )(rng_keys )
446+ else :
447+ elbos , common_plate_scale = jax .lax .map (single_particle_elbo , rng_keys )
431448 assert common_plate_scale .shape == (self .num_particles ,)
432449 assert elbos .shape [0 ] == self .num_particles
433450 scaled_elbos = (1.0 - self .alpha ) * elbos
@@ -695,8 +712,10 @@ class TraceGraph_ELBO(ELBO):
695712
696713 can_infer_discrete = True
697714
698- def __init__ (self , num_particles = 1 ):
699- super ().__init__ (num_particles = num_particles )
715+ def __init__ (self , num_particles = 1 , vectorize_particles = True ):
716+ super ().__init__ (
717+ num_particles = num_particles , vectorize_particles = vectorize_particles
718+ )
700719
701720 def loss (self , rng_key , param_map , model , guide , * args , ** kwargs ):
702721 """
@@ -771,7 +790,10 @@ def single_particle_elbo(rng_key):
771790 return - single_particle_elbo (rng_key )
772791 else :
773792 rng_keys = random .split (rng_key , self .num_particles )
774- return - jnp .mean (vmap (single_particle_elbo )(rng_keys ))
793+ if self .vectorize_particles :
794+ return - jnp .mean (vmap (single_particle_elbo )(rng_keys ))
795+ else :
796+ return - jnp .mean (jax .lax .map (single_particle_elbo , rng_keys ))
775797
776798
777799def get_importance_trace_enum (
@@ -953,9 +975,13 @@ class TraceEnum_ELBO(ELBO):
953975
954976 can_infer_discrete = True
955977
956- def __init__ (self , num_particles = 1 , max_plate_nesting = float ("inf" )):
978+ def __init__ (
979+ self , num_particles = 1 , max_plate_nesting = float ("inf" ), vectorize_particles = True
980+ ):
957981 self .max_plate_nesting = max_plate_nesting
958- super ().__init__ (num_particles = num_particles )
982+ super ().__init__ (
983+ num_particles = num_particles , vectorize_particles = vectorize_particles
984+ )
959985
960986 def loss (self , rng_key , param_map , model , guide , * args , ** kwargs ):
961987 def single_particle_elbo (rng_key ):
@@ -1128,4 +1154,7 @@ def single_particle_elbo(rng_key):
11281154 return - single_particle_elbo (rng_key )
11291155 else :
11301156 rng_keys = random .split (rng_key , self .num_particles )
1131- return - jnp .mean (vmap (single_particle_elbo )(rng_keys ))
1157+ if self .vectorize_particles :
1158+ return - jnp .mean (vmap (single_particle_elbo )(rng_keys ))
1159+ else :
1160+ return - jnp .mean (jax .lax .map (single_particle_elbo , rng_keys ))
0 commit comments