@@ -1463,8 +1463,10 @@ def _sample_latent(self, *args, **kwargs):
1463
1463
if self .global_guide is not None :
1464
1464
global_latents = self .global_guide (* args , ** kwargs )
1465
1465
rng_key = numpyro .prng_key ()
1466
- with handlers .block (), handlers .seed (rng_seed = rng_key ), handlers .substitute (
1467
- data = global_latents
1466
+ with (
1467
+ handlers .block (),
1468
+ handlers .seed (rng_seed = rng_key ),
1469
+ handlers .substitute (data = global_latents ),
1468
1470
):
1469
1471
global_outputs = self .global_guide .model (* args , ** kwargs )
1470
1472
local_args = (global_outputs ,)
@@ -1575,9 +1577,12 @@ def fn(x):
1575
1577
if self .local_guide is not None :
1576
1578
key = numpyro .prng_key ()
1577
1579
subsample_guide = partial (_subsample_model , self .local_guide )
1578
- with handlers .block (), handlers .trace () as tr , handlers .seed (
1579
- rng_seed = key
1580
- ), handlers .substitute (data = local_guide_params ):
1580
+ with (
1581
+ handlers .block (),
1582
+ handlers .trace () as tr ,
1583
+ handlers .seed (rng_seed = key ),
1584
+ handlers .substitute (data = local_guide_params ),
1585
+ ):
1581
1586
with warnings .catch_warnings ():
1582
1587
warnings .simplefilter ("ignore" )
1583
1588
subsample_guide (* local_args , ** local_kwargs )
0 commit comments