Skip to content

Commit 56b88c3

Browse files
authored
Add vectorize_particles to ELBO (#1624)
* add vectorize_particles to ELBO * address comments
1 parent 4e37df3 commit 56b88c3

File tree

2 files changed

+64
-13
lines changed

2 files changed

+64
-13
lines changed

numpyro/infer/elbo.py

Lines changed: 42 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
from operator import itemgetter
77
import warnings
88

9+
import jax
910
from jax import eval_shape, random, vmap
1011
from jax.lax import stop_gradient
1112
import jax.numpy as jnp
@@ -33,6 +34,9 @@ class ELBO:
3334
3435
:param num_particles: The number of particles/samples used to form the ELBO
3536
(gradient) estimators.
37+
:param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
38+
num_particles-many particles in parallel. If False, use `jax.lax.map`, which evaluates the particles sequentially.
39+
Defaults to True.
3640
"""
3741

3842
"""
@@ -42,8 +46,9 @@ class ELBO:
4246
"""
4347
can_infer_discrete = False
4448

45-
def __init__(self, num_particles=1):
49+
def __init__(self, num_particles=1, vectorize_particles=True):
4650
self.num_particles = num_particles
51+
self.vectorize_particles = vectorize_particles
4752

4853
def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
4954
"""
@@ -108,11 +113,11 @@ class Trace_ELBO(ELBO):
108113
109114
:param num_particles: The number of particles/samples used to form the ELBO
110115
(gradient) estimators.
116+
:param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
117+
num_particles-many particles in parallel. If False, use `jax.lax.map`, which evaluates the particles sequentially.
118+
Defaults to True.
111119
"""
112120

113-
def __init__(self, num_particles=1):
114-
self.num_particles = num_particles
115-
116121
def loss_with_mutable_state(
117122
self, rng_key, param_map, model, guide, *args, **kwargs
118123
):
@@ -163,7 +168,10 @@ def single_particle_elbo(rng_key):
163168
return {"loss": -elbo, "mutable_state": mutable_state}
164169
else:
165170
rng_keys = random.split(rng_key, self.num_particles)
166-
elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
171+
if self.vectorize_particles:
172+
elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
173+
else:
174+
elbos, mutable_state = jax.lax.map(single_particle_elbo, rng_keys)
167175
return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}
168176

169177

@@ -291,7 +299,10 @@ def single_particle_elbo(rng_key):
291299
return {"loss": -elbo, "mutable_state": mutable_state}
292300
else:
293301
rng_keys = random.split(rng_key, self.num_particles)
294-
elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
302+
if self.vectorize_particles:
303+
elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
304+
else:
305+
elbos, mutable_state = jax.lax.map(single_particle_elbo, rng_keys)
295306
return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}
296307

297308

@@ -311,6 +322,9 @@ class RenyiELBO(ELBO):
311322
Here :math:`\alpha \neq 1`. Default is 0.
312323
:param num_particles: The number of particles/samples
313324
used to form the objective (gradient) estimator. Default is 2.
325+
:param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
326+
num_particles-many particles in parallel. If False, use `jax.lax.map`, which evaluates the particles sequentially.
327+
Defaults to True.
314328
315329
Example::
316330
@@ -427,7 +441,10 @@ def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
427441
)
428442

429443
rng_keys = random.split(rng_key, self.num_particles)
430-
elbos, common_plate_scale = vmap(single_particle_elbo)(rng_keys)
444+
if self.vectorize_particles:
445+
elbos, common_plate_scale = vmap(single_particle_elbo)(rng_keys)
446+
else:
447+
elbos, common_plate_scale = jax.lax.map(single_particle_elbo, rng_keys)
431448
assert common_plate_scale.shape == (self.num_particles,)
432449
assert elbos.shape[0] == self.num_particles
433450
scaled_elbos = (1.0 - self.alpha) * elbos
@@ -695,8 +712,10 @@ class TraceGraph_ELBO(ELBO):
695712

696713
can_infer_discrete = True
697714

698-
def __init__(self, num_particles=1):
699-
super().__init__(num_particles=num_particles)
715+
def __init__(self, num_particles=1, vectorize_particles=True):
716+
super().__init__(
717+
num_particles=num_particles, vectorize_particles=vectorize_particles
718+
)
700719

701720
def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
702721
"""
@@ -771,7 +790,10 @@ def single_particle_elbo(rng_key):
771790
return -single_particle_elbo(rng_key)
772791
else:
773792
rng_keys = random.split(rng_key, self.num_particles)
774-
return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
793+
if self.vectorize_particles:
794+
return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
795+
else:
796+
return -jnp.mean(jax.lax.map(single_particle_elbo, rng_keys))
775797

776798

777799
def get_importance_trace_enum(
@@ -953,9 +975,13 @@ class TraceEnum_ELBO(ELBO):
953975

954976
can_infer_discrete = True
955977

956-
def __init__(self, num_particles=1, max_plate_nesting=float("inf")):
978+
def __init__(
979+
self, num_particles=1, max_plate_nesting=float("inf"), vectorize_particles=True
980+
):
957981
self.max_plate_nesting = max_plate_nesting
958-
super().__init__(num_particles=num_particles)
982+
super().__init__(
983+
num_particles=num_particles, vectorize_particles=vectorize_particles
984+
)
959985

960986
def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
961987
def single_particle_elbo(rng_key):
@@ -1128,4 +1154,7 @@ def single_particle_elbo(rng_key):
11281154
return -single_particle_elbo(rng_key)
11291155
else:
11301156
rng_keys = random.split(rng_key, self.num_particles)
1131-
return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
1157+
if self.vectorize_particles:
1158+
return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
1159+
else:
1160+
return -jnp.mean(jax.lax.map(single_particle_elbo, rng_keys))

test/infer/test_svi.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -170,6 +170,28 @@ def get_renyi(n=N, k=K, fix_indices=True):
170170
assert_allclose(atol, 0.0, atol=1e-5)
171171

172172

173+
def test_vectorized_particle():
174+
data = jnp.array([1.0] * 8 + [0.0] * 2)
175+
176+
def model(data):
177+
f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
178+
with numpyro.plate("N", len(data)):
179+
numpyro.sample("obs", dist.Bernoulli(f), obs=data)
180+
181+
def guide(data):
182+
alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
183+
beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
184+
numpyro.sample("beta", dist.Beta(alpha_q, beta_q))
185+
186+
vmap_results = SVI(
187+
model, guide, optim.Adam(0.1), Trace_ELBO(vectorize_particles=True)
188+
).run(random.PRNGKey(0), 100, data)
189+
map_results = SVI(
190+
model, guide, optim.Adam(0.1), Trace_ELBO(vectorize_particles=False)
191+
).run(random.PRNGKey(0), 100, data)
192+
assert_allclose(vmap_results.losses, map_results.losses, atol=1e-5)
193+
194+
173195
@pytest.mark.parametrize("elbo", [Trace_ELBO(), RenyiELBO(num_particles=10)])
174196
@pytest.mark.parametrize("optimizer", [optim.Adam(0.01), optimizers.adam(0.01)])
175197
def test_beta_bernoulli(elbo, optimizer):

0 commit comments

Comments
 (0)