Skip to content

Commit 21e1c4c

Browse files
SiegeLordEx authored and tensorflower-gardener committed
FunMC: Extract the resampling logic into its own function, remove old systematic_resample.
PiperOrigin-RevId: 721110369
1 parent f3147cd commit 21e1c4c

File tree

5 files changed

+119
-156
lines changed

5 files changed

+119
-156
lines changed

Diff for: discussion/probabilistic_bundle_adjustment/ProbBundle.ipynb

+1 -1
Original file line numberDiff line numberDiff line change
@@ -3377,7 +3377,7 @@
33773377
" if (self.auto_resample.value or self.resample) or not jnp.all(\n",
33783378
" jnp.isfinite(extra.target_log_prob)\n",
33793379
" ):\n",
3380-
" (_, _), ancestor_idx = fun_mc.systematic_resample(\n",
3380+
" (_, _), ancestor_idx = fun_mc.resample(\n",
33813381
" (),\n",
33823382
" resample_strength * extra.target_log_prob,\n",
33833383
" jax.random.key(self.step),\n",

Diff for: spinoffs/fun_mc/fun_mc/fun_mc_lib.py

-55
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@
128128
'SimpleDualAveragesState',
129129
'splitting_integrator_step',
130130
'State',
131-
'systematic_resample',
132131
'trace',
133132
'transform_log_prob_fn',
134133
'TransitionOperator',
@@ -3466,60 +3465,6 @@ def clip_part(v):
34663465
)
34673466

34683467

3469-
@util.named_call
3470-
def systematic_resample(
3471-
particles: State,
3472-
log_weights: FloatArray,
3473-
seed: Any,
3474-
do_resample: Optional[BooleanArray] = None,
3475-
) -> tuple[tuple[State, FloatArray], IntArray]:
3476-
"""Systematically resamples particles in proportion to their weights.
3477-
3478-
This uses the algorithm from [1].
3479-
3480-
Args:
3481-
particles: The particles.
3482-
log_weights: Un-normalized weights.
3483-
seed: PRNG seed.
3484-
do_resample: Whether to perform the resample. If None, resampling is
3485-
performed unconditionally.
3486-
3487-
Returns:
3488-
particles_and_weights: tuple of resampled particles and weights.
3489-
ancestor_idx: Indices from which the returned particles were sampled from.
3490-
3491-
#### References
3492-
3493-
[1] Maskell, S., Alun-Jones, B., & Macleod, M. (2006). A Single Instruction
3494-
Multiple Data Particle Filter. 2006 IEEE Nonlinear Statistical Signal
3495-
Processing Workshop. https://doi.org/10.1109/NSSPW.2006.4378818
3496-
"""
3497-
log_weights = jnp.asarray(log_weights)
3498-
log_weights = jnp.where(
3499-
jnp.isnan(log_weights),
3500-
jnp.array(-float('inf'), log_weights.dtype),
3501-
log_weights,
3502-
)
3503-
probs = jax.nn.softmax(log_weights)
3504-
num_particles = probs.shape[0]
3505-
3506-
shift = util.random_uniform([], log_weights.dtype, seed)
3507-
pie = jnp.cumsum(probs) * num_particles + shift
3508-
repeats = jnp.array(util.diff(jnp.floor(pie), prepend=0), jnp.int32)
3509-
parent_idxs = util.repeat(
3510-
jnp.arange(num_particles), repeats, total_repeat_length=num_particles
3511-
)
3512-
if do_resample is not None:
3513-
parent_idxs = jnp.where(do_resample, parent_idxs, jnp.arange(num_particles))
3514-
new_particles = util.map_tree(lambda x: x[parent_idxs], particles)
3515-
new_log_weights = jnp.full(
3516-
log_weights.shape, tfp.math.reduce_logmeanexp(log_weights)
3517-
)
3518-
if do_resample is not None:
3519-
new_log_weights = jnp.where(do_resample, new_log_weights, log_weights)
3520-
return (new_particles, new_log_weights), parent_idxs
3521-
3522-
35233468
class GeometricAnnealingPathExtra(NamedTuple):
35243469
"""Extra outputs of `geometric_annealing_path`.
35253470

Diff for: spinoffs/fun_mc/fun_mc/fun_mc_test.py

-56
Original file line numberDiff line numberDiff line change
@@ -2039,62 +2039,6 @@ def eval_fn(x):
20392039
self.assertAllCloseNested(value, fn(x))
20402040
self.assertAllCloseNested(expected_grad, grad)
20412041

2042-
def testSystematicResample(self):
2043-
probs = self._constant([0.0, 0.5, 0.2, 0.3, 0.0])
2044-
log_weights = jnp.log(probs)
2045-
particles = jnp.arange(probs.shape[0])
2046-
2047-
@jax.jit
2048-
def body(seed):
2049-
seed, resample_seed = util.split_seed(seed, 2)
2050-
(new_particles, new_log_weights), _ = fun_mc.systematic_resample(
2051-
particles, log_weights, resample_seed
2052-
)
2053-
return seed, (new_particles, new_log_weights)
2054-
2055-
_, (new_particles, new_log_weights) = fun_mc.trace(
2056-
self._make_seed(_test_seed()), body, 1000, trace_mask=(True, False)
2057-
)
2058-
2059-
new_particles_probs = jnp.mean(
2060-
jnp.array(new_particles[..., jnp.newaxis] == particles, jnp.float32),
2061-
(0, 1),
2062-
)
2063-
2064-
self.assertAllClose(new_particles_probs, probs, atol=0.05)
2065-
self.assertEqual(new_particles_probs[0], 0.0)
2066-
self.assertEqual(new_particles_probs[-1], 0.0)
2067-
self.assertAllClose(
2068-
new_log_weights,
2069-
jnp.full(probs.shape, tfp.math.reduce_logmeanexp(log_weights)),
2070-
)
2071-
2072-
def testSystematicResampleAncestors(self):
2073-
log_weights = self._constant([-float('inf'), 0.0])
2074-
particles = jnp.arange(log_weights.shape[0])
2075-
seed = self._make_seed(_test_seed())
2076-
2077-
(new_particles, new_log_weights), ancestors = fun_mc.systematic_resample(
2078-
particles, log_weights, seed=seed
2079-
)
2080-
self.assertAllEqual(new_particles, jnp.ones_like(particles))
2081-
self.assertAllEqual(new_log_weights, jnp.log(self._constant([0.5, 0.5])))
2082-
self.assertAllEqual(ancestors, jnp.ones_like(particles))
2083-
2084-
(new_particles, new_log_weights), ancestors = fun_mc.systematic_resample(
2085-
particles, log_weights, do_resample=True, seed=seed
2086-
)
2087-
self.assertAllEqual(new_particles, jnp.ones_like(particles))
2088-
self.assertAllEqual(new_log_weights, jnp.log(self._constant([0.5, 0.5])))
2089-
self.assertAllEqual(ancestors, jnp.ones_like(particles))
2090-
2091-
(new_particles, new_log_weights), ancestors = fun_mc.systematic_resample(
2092-
particles, log_weights, do_resample=False, seed=seed
2093-
)
2094-
self.assertAllEqual(new_particles, particles)
2095-
self.assertAllEqual(new_log_weights, log_weights)
2096-
self.assertAllEqual(ancestors, particles)
2097-
20982042

20992043
@test_util.multi_backend_test(globals(), 'fun_mc_test')
21002044
class FunMCTest32(FunMCTest):

Diff for: spinoffs/fun_mc/fun_mc/smc.py

+91 -44
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
'conditional_systematic_resampling',
4949
'effective_sample_size_predicate',
5050
'ParticleGatherFn',
51+
'resample',
5152
'ResamplingPredicate',
5253
'SampleAncestorsFn',
5354
'sequential_monte_carlo_init',
@@ -78,6 +79,8 @@ def systematic_resampling(
7879
) -> Int[Array, 'num_particles']:
7980
"""Generate parent indices via systematic resampling.
8081
82+
This uses the algorithm from [1].
83+
8184
Args:
8285
log_weights: Unnormalized log-scale weights.
8386
seed: PRNG seed.
@@ -87,13 +90,14 @@ def systematic_resampling(
8790
Returns:
8891
parent_idxs: parent indices such that the marginal probability that a
8992
randomly chosen element will be `i` is equal to `softmax(log_weights)[i]`.
93+
94+
#### References
95+
96+
[1] Maskell, S., Alun-Jones, B., & Macleod, M. (2006). A Single Instruction
97+
Multiple Data Particle Filter. 2006 IEEE Nonlinear Statistical Signal
98+
Processing Workshop. https://doi.org/10.1109/NSSPW.2006.4378818
9099
"""
91100
shift_seed, permute_seed = util.split_seed(seed, 2)
92-
log_weights = jnp.where(
93-
jnp.isnan(log_weights),
94-
jnp.array(-float('inf'), log_weights.dtype),
95-
log_weights,
96-
)
97101
probs = jax.nn.softmax(log_weights)
98102
# A common situation is all -inf log_weights, which creates a NaN vector.
99103
probs = jnp.where(
@@ -146,11 +150,6 @@ def conditional_systematic_resampling(
146150
https://www.jstor.org/stable/43590414
147151
"""
148152
mixture_seed, shift_seed, permute_seed = util.split_seed(seed, 3)
149-
log_weights = jnp.where(
150-
jnp.isnan(log_weights),
151-
jnp.array(-float('inf'), log_weights.dtype),
152-
log_weights,
153-
)
154153
probs = jax.nn.softmax(log_weights)
155154
num_particles = log_weights.shape[0]
156155

@@ -377,7 +376,7 @@ def __call__(
377376

378377

379378
@types.runtime_typed
380-
def _defalt_pytree_gather(
379+
def _default_pytree_gather(
381380
state: State,
382381
indices: Int[Array, 'num_particles'],
383382
) -> State:
@@ -395,6 +394,75 @@ def _defalt_pytree_gather(
395394
return util.map_tree(lambda x: x[indices], state)
396395

397396

397+
@types.runtime_typed
398+
def resample(
399+
state: State,
400+
log_weights: Float[Array, 'num_particles'],
401+
seed: Seed,
402+
do_resample: BoolScalar = True,
403+
sample_ancestors_fn: SampleAncestorsFn = systematic_resampling,
404+
state_gather_fn: ParticleGatherFn[State] = _default_pytree_gather,
405+
) -> tuple[
406+
tuple[State, Float[Array, 'num_particles']], Int[Array, 'num_particles']
407+
]:
408+
"""Possibly resamples state according to the log_weights.
409+
410+
The state should represent the same number of particles as implied by the
411+
length of `log_weights`. If resampling occurs, the new log weights are the
412+
log-mean-exp of the incoming log weights. Otherwise, they are unchanged. By
413+
default, this function performs systematic resampling.
414+
415+
Args:
416+
state: The particles.
417+
log_weights: Un-normalized log weights. NaN log weights are treated as -inf.
418+
seed: Random seed.
419+
do_resample: Whether to resample.
420+
sample_ancestors_fn: Ancestor index sampling function.
421+
state_gather_fn: State gather function.
422+
423+
Returns:
424+
state_and_log_weights: tuple of the resampled state and log weights.
425+
ancestor_idx: Indices that indicate which elements of the original state the
426+
returned state particles were sampled from.
427+
"""
428+
429+
def do_resample_fn(
430+
state,
431+
log_weights,
432+
seed,
433+
):
434+
log_weights = jnp.where(
435+
jnp.isnan(log_weights),
436+
jnp.array(-float('inf'), log_weights.dtype),
437+
log_weights,
438+
)
439+
ancestor_idxs = sample_ancestors_fn(log_weights, seed)
440+
new_state = state_gather_fn(state, ancestor_idxs)
441+
num_particles = log_weights.shape[0]
442+
new_log_weights = jnp.full(
443+
(num_particles,), tfp.math.reduce_logmeanexp(log_weights)
444+
)
445+
return (new_state, new_log_weights), ancestor_idxs
446+
447+
def dont_resample_fn(
448+
state,
449+
log_weights,
450+
seed,
451+
):
452+
del seed
453+
num_particles = log_weights.shape[0]
454+
return (state, log_weights), jnp.arange(num_particles)
455+
456+
return _smart_cond(
457+
do_resample,
458+
do_resample_fn,
459+
dont_resample_fn,
460+
state,
461+
log_weights,
462+
seed,
463+
)
464+
465+
398466
@types.runtime_typed
399467
def sequential_monte_carlo_init(
400468
state: State,
@@ -430,7 +498,7 @@ def sequential_monte_carlo_step(
430498
seed: Seed,
431499
resampling_pred: ResamplingPredicate = effective_sample_size_predicate,
432500
sample_ancestors_fn: SampleAncestorsFn = systematic_resampling,
433-
state_gather_fn: ParticleGatherFn[State] = _defalt_pytree_gather,
501+
state_gather_fn: ParticleGatherFn[State] = _default_pytree_gather,
434502
) -> tuple[
435503
SequentialMonteCarloState[State], SequentialMonteCarloExtra[State, Extra]
436504
]:
@@ -461,43 +529,21 @@ def sequential_monte_carlo_step(
461529
"""
462530
resample_seed, kernel_seed = util.split_seed(seed, 2)
463531

464-
def do_resample(
465-
state,
466-
log_weights,
467-
seed,
468-
):
469-
ancestor_idxs = sample_ancestors_fn(log_weights, seed)
470-
new_state = state_gather_fn(state, ancestor_idxs)
471-
num_particles = log_weights.shape[0]
472-
new_log_weights = jnp.full(
473-
(num_particles,), tfp.math.reduce_logmeanexp(log_weights)
474-
)
475-
return (new_state, ancestor_idxs, new_log_weights)
476-
477-
def dont_resample(
478-
state,
479-
log_weights,
480-
seed,
481-
):
482-
del seed
483-
num_particles = log_weights.shape[0]
484-
return state, jnp.arange(num_particles), log_weights
485-
486532
# NOTE: We don't explicitly disable resampling at the first step. However, if
487533
# we initialize the log weights to zeros, either of
488534
# 1. resampling according to the effective sample size criterion and
489535
# 2. using systematic resampling effectively disables resampling at the first
490536
# step.
491537
# First-step resampling can always be forced via the `resampling_pred`.
492-
should_resample = resampling_pred(smc_state)
493-
state_after_resampling, ancestor_idxs, log_weights_after_resampling = (
494-
_smart_cond(
495-
should_resample,
496-
do_resample,
497-
dont_resample,
498-
smc_state.state,
499-
smc_state.log_weights,
500-
resample_seed,
538+
do_resample = resampling_pred(smc_state)
539+
(state_after_resampling, log_weights_after_resampling), ancestor_idxs = (
540+
resample(
541+
state=smc_state.state,
542+
log_weights=smc_state.log_weights,
543+
do_resample=do_resample,
544+
seed=resample_seed,
545+
sample_ancestors_fn=sample_ancestors_fn,
546+
state_gather_fn=state_gather_fn,
501547
)
502548
)
503549

@@ -516,7 +562,7 @@ def dont_resample(
516562
smc_extra = SequentialMonteCarloExtra(
517563
incremental_log_weights=incremental_log_weights,
518564
kernel_extra=kernel_extra,
519-
resampled=should_resample,
565+
resampled=do_resample,
520566
ancestor_idxs=ancestor_idxs,
521567
state_after_resampling=state_after_resampling,
522568
log_weights_after_resampling=log_weights_after_resampling,
@@ -711,6 +757,7 @@ def inner_kernel(state, stage, tlp_fn, seed):
711757
)
712758

713759

760+
@types.runtime_typed
714761
def _smart_cond(
715762
pred: BoolScalar,
716763
true_fn: Callable[..., T],

Diff for: spinoffs/fun_mc/fun_mc/smc_test.py

+27
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,33 @@ def kernel(seed):
272272
)
273273
self.assertAllClose(rejection_freqs, conditional_freqs, atol=0.05)
274274

275+
def test_resample(self):
276+
state = jnp.array([3, 2, 1, 0])
277+
log_weights = jnp.array([-jnp.inf, float('NaN'), 1.0, 1.0], self._dtype)
278+
seed = _test_seed()
279+
280+
(new_state, new_log_weights), ancestor_idxs = smc.resample(
281+
state=state, log_weights=log_weights, seed=seed
282+
)
283+
284+
self.assertAllTrue(new_state != 3)
285+
self.assertAllTrue(new_state != 2)
286+
self.assertAllTrue(~jnp.isnan(new_log_weights))
287+
self.assertAllEqual(3 - new_state, ancestor_idxs)
288+
289+
def test_resample_but_dont(self):
290+
state = jnp.array([3, 2, 1, 0])
291+
log_weights = jnp.array([-jnp.inf, float('NaN'), 1.0, 1.0], self._dtype)
292+
seed = _test_seed()
293+
294+
(new_state, new_log_weights), ancestor_idxs = smc.resample(
295+
state=state, log_weights=log_weights, do_resample=False, seed=seed
296+
)
297+
298+
self.assertAllEqual(new_state, state)
299+
self.assertAllEqual(new_log_weights, log_weights)
300+
self.assertAllEqual(ancestor_idxs, jnp.arange(state.shape[0]))
301+
275302
def test_smc_runs_and_shapes_correct(self):
276303
num_particles = 3
277304
num_timesteps = 20

0 commit comments

Comments
 (0)