48
48
'conditional_systematic_resampling' ,
49
49
'effective_sample_size_predicate' ,
50
50
'ParticleGatherFn' ,
51
+ 'resample' ,
51
52
'ResamplingPredicate' ,
52
53
'SampleAncestorsFn' ,
53
54
'sequential_monte_carlo_init' ,
@@ -78,6 +79,8 @@ def systematic_resampling(
78
79
) -> Int [Array , 'num_particles' ]:
79
80
"""Generate parent indices via systematic resampling.
80
81
82
+ This uses the algorithm from [1].
83
+
81
84
Args:
82
85
log_weights: Unnormalized log-scale weights.
83
86
seed: PRNG seed.
@@ -87,13 +90,14 @@ def systematic_resampling(
87
90
Returns:
88
91
parent_idxs: parent indices such that the marginal probability that a
89
92
randomly chosen element will be `i` is equal to `softmax(log_weights)[i]`.
93
+
94
+ #### References
95
+
96
+ [1] Maskell, S., Alun-Jones, B., & Macleod, M. (2006). A Single Instruction
97
+ Multiple Data Particle Filter. 2006 IEEE Nonlinear Statistical Signal
98
+ Processing Workshop. https://doi.org/10.1109/NSSPW.2006.4378818
90
99
"""
91
100
shift_seed , permute_seed = util .split_seed (seed , 2 )
92
- log_weights = jnp .where (
93
- jnp .isnan (log_weights ),
94
- jnp .array (- float ('inf' ), log_weights .dtype ),
95
- log_weights ,
96
- )
97
101
probs = jax .nn .softmax (log_weights )
98
102
# A common situation is all -inf log_weights that creats a NaN vector.
99
103
probs = jnp .where (
@@ -146,11 +150,6 @@ def conditional_systematic_resampling(
146
150
https://www.jstor.org/stable/43590414
147
151
"""
148
152
mixture_seed , shift_seed , permute_seed = util .split_seed (seed , 3 )
149
- log_weights = jnp .where (
150
- jnp .isnan (log_weights ),
151
- jnp .array (- float ('inf' ), log_weights .dtype ),
152
- log_weights ,
153
- )
154
153
probs = jax .nn .softmax (log_weights )
155
154
num_particles = log_weights .shape [0 ]
156
155
@@ -377,7 +376,7 @@ def __call__(
377
376
378
377
379
378
@types .runtime_typed
380
- def _defalt_pytree_gather (
379
+ def _default_pytree_gather (
381
380
state : State ,
382
381
indices : Int [Array , 'num_particles' ],
383
382
) -> State :
@@ -395,6 +394,75 @@ def _defalt_pytree_gather(
395
394
return util .map_tree (lambda x : x [indices ], state )
396
395
397
396
397
+ @types .runtime_typed
398
+ def resample (
399
+ state : State ,
400
+ log_weights : Float [Array , 'num_particles' ],
401
+ seed : Seed ,
402
+ do_resample : BoolScalar = True ,
403
+ sample_ancestors_fn : SampleAncestorsFn = systematic_resampling ,
404
+ state_gather_fn : ParticleGatherFn [State ] = _default_pytree_gather ,
405
+ ) -> tuple [
406
+ tuple [State , Float [Array , 'num_particles' ]], Int [Array , 'num_particles' ]
407
+ ]:
408
+ """Possibly resamples state according to the log_weights.
409
+
410
+ The state should represent the same number of particles as implied by the
411
+ length of `log_weights`. If resampling occurs, the new log weights are
412
+ log-mean-exp of the incoming log weights. Otherwise, they are unchanged. By
413
+ default, this function performs systematic resampling.
414
+
415
+ Args:
416
+ state: The particles.
417
+ log_weights: Un-normalized log weights. NaN log weights are treated as -inf.
418
+ seed: Random seed.
419
+ do_resample: Whether to resample.
420
+ sample_ancestors_fn: Ancestor index sampling function.
421
+ state_gather_fn: State gather function.
422
+
423
+ Returns:
424
+ state_and_log_weights: tuple of the resampled state and log weights.
425
+ ancestor_idx: Indices that indicate which elements of the original state the
426
+ returned state particles were sampled from.
427
+ """
428
+
429
+ def do_resample_fn (
430
+ state ,
431
+ log_weights ,
432
+ seed ,
433
+ ):
434
+ log_weights = jnp .where (
435
+ jnp .isnan (log_weights ),
436
+ jnp .array (- float ('inf' ), log_weights .dtype ),
437
+ log_weights ,
438
+ )
439
+ ancestor_idxs = sample_ancestors_fn (log_weights , seed )
440
+ new_state = state_gather_fn (state , ancestor_idxs )
441
+ num_particles = log_weights .shape [0 ]
442
+ new_log_weights = jnp .full (
443
+ (num_particles ,), tfp .math .reduce_logmeanexp (log_weights )
444
+ )
445
+ return (new_state , new_log_weights ), ancestor_idxs
446
+
447
+ def dont_resample_fn (
448
+ state ,
449
+ log_weights ,
450
+ seed ,
451
+ ):
452
+ del seed
453
+ num_particles = log_weights .shape [0 ]
454
+ return (state , log_weights ), jnp .arange (num_particles )
455
+
456
+ return _smart_cond (
457
+ do_resample ,
458
+ do_resample_fn ,
459
+ dont_resample_fn ,
460
+ state ,
461
+ log_weights ,
462
+ seed ,
463
+ )
464
+
465
+
398
466
@types .runtime_typed
399
467
def sequential_monte_carlo_init (
400
468
state : State ,
@@ -430,7 +498,7 @@ def sequential_monte_carlo_step(
430
498
seed : Seed ,
431
499
resampling_pred : ResamplingPredicate = effective_sample_size_predicate ,
432
500
sample_ancestors_fn : SampleAncestorsFn = systematic_resampling ,
433
- state_gather_fn : ParticleGatherFn [State ] = _defalt_pytree_gather ,
501
+ state_gather_fn : ParticleGatherFn [State ] = _default_pytree_gather ,
434
502
) -> tuple [
435
503
SequentialMonteCarloState [State ], SequentialMonteCarloExtra [State , Extra ]
436
504
]:
@@ -461,43 +529,21 @@ def sequential_monte_carlo_step(
461
529
"""
462
530
resample_seed , kernel_seed = util .split_seed (seed , 2 )
463
531
464
- def do_resample (
465
- state ,
466
- log_weights ,
467
- seed ,
468
- ):
469
- ancestor_idxs = sample_ancestors_fn (log_weights , seed )
470
- new_state = state_gather_fn (state , ancestor_idxs )
471
- num_particles = log_weights .shape [0 ]
472
- new_log_weights = jnp .full (
473
- (num_particles ,), tfp .math .reduce_logmeanexp (log_weights )
474
- )
475
- return (new_state , ancestor_idxs , new_log_weights )
476
-
477
- def dont_resample (
478
- state ,
479
- log_weights ,
480
- seed ,
481
- ):
482
- del seed
483
- num_particles = log_weights .shape [0 ]
484
- return state , jnp .arange (num_particles ), log_weights
485
-
486
532
# NOTE: We don't explicitly disable resampling at the first step. However, if
487
533
# we initialize the log weights to zeros, either of
488
534
# 1. resampling according to the effective sample size criterion and
489
535
# 2. using systematic resampling effectively disables resampling at the first
490
536
# step.
491
537
# First-step resampling can always be forced via the `resampling_pred`.
492
- should_resample = resampling_pred (smc_state )
493
- state_after_resampling , ancestor_idxs , log_weights_after_resampling = (
494
- _smart_cond (
495
- should_resample ,
496
- do_resample ,
497
- dont_resample ,
498
- smc_state . state ,
499
- smc_state . log_weights ,
500
- resample_seed ,
538
+ do_resample = resampling_pred (smc_state )
539
+ ( state_after_resampling , log_weights_after_resampling ), ancestor_idxs = (
540
+ resample (
541
+ state = smc_state . state ,
542
+ log_weights = smc_state . log_weights ,
543
+ do_resample = do_resample ,
544
+ seed = resample_seed ,
545
+ sample_ancestors_fn = sample_ancestors_fn ,
546
+ state_gather_fn = state_gather_fn ,
501
547
)
502
548
)
503
549
@@ -516,7 +562,7 @@ def dont_resample(
516
562
smc_extra = SequentialMonteCarloExtra (
517
563
incremental_log_weights = incremental_log_weights ,
518
564
kernel_extra = kernel_extra ,
519
- resampled = should_resample ,
565
+ resampled = do_resample ,
520
566
ancestor_idxs = ancestor_idxs ,
521
567
state_after_resampling = state_after_resampling ,
522
568
log_weights_after_resampling = log_weights_after_resampling ,
@@ -711,6 +757,7 @@ def inner_kernel(state, stage, tlp_fn, seed):
711
757
)
712
758
713
759
760
+ @types .runtime_typed
714
761
def _smart_cond (
715
762
pred : BoolScalar ,
716
763
true_fn : Callable [..., T ],
0 commit comments