@@ -72,7 +72,11 @@ def transition_fn(carry, y):
7272 with handlers .mask (mask = first_capture_mask ):
7373 mu_z_t = first_capture_mask * phi * z + (1 - first_capture_mask )
7474 # NumPyro exactly sums out the discrete states z_t.
75- z = numpyro .sample ("z" , dist .Bernoulli (dist .util .clamp_probs (mu_z_t )))
75+ z = numpyro .sample (
76+ "z" ,
77+ dist .Bernoulli (dist .util .clamp_probs (mu_z_t )),
78+ infer = {"enumerate" : "parallel" },
79+ )
7680 mu_y_t = rho * z
7781 numpyro .sample (
7882 "y" , dist .Bernoulli (dist .util .clamp_probs (mu_y_t )), obs = y
@@ -112,7 +116,11 @@ def transition_fn(carry, y):
112116 with handlers .mask (mask = first_capture_mask ):
113117 mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask )
114118 # NumPyro exactly sums out the discrete states z_t.
115- z = numpyro .sample ("z" , dist .Bernoulli (dist .util .clamp_probs (mu_z_t )))
119+ z = numpyro .sample (
120+ "z" ,
121+ dist .Bernoulli (dist .util .clamp_probs (mu_z_t )),
122+ infer = {"enumerate" : "parallel" },
123+ )
116124 mu_y_t = rho * z
117125 numpyro .sample (
118126 "y" , dist .Bernoulli (dist .util .clamp_probs (mu_y_t )), obs = y
@@ -160,7 +168,11 @@ def transition_fn(carry, y):
160168 with handlers .mask (mask = first_capture_mask ):
161169 mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask )
162170 # NumPyro exactly sums out the discrete states z_t.
163- z = numpyro .sample ("z" , dist .Bernoulli (dist .util .clamp_probs (mu_z_t )))
171+ z = numpyro .sample (
172+ "z" ,
173+ dist .Bernoulli (dist .util .clamp_probs (mu_z_t )),
174+ infer = {"enumerate" : "parallel" },
175+ )
164176 mu_y_t = rho * z
165177 numpyro .sample (
166178 "y" , dist .Bernoulli (dist .util .clamp_probs (mu_y_t )), obs = y
@@ -202,7 +214,11 @@ def transition_fn(carry, y):
202214 with handlers .mask (mask = first_capture_mask ):
203215 mu_z_t = first_capture_mask * phi * z + (1 - first_capture_mask )
204216 # NumPyro exactly sums out the discrete states z_t.
205- z = numpyro .sample ("z" , dist .Bernoulli (dist .util .clamp_probs (mu_z_t )))
217+ z = numpyro .sample (
218+ "z" ,
219+ dist .Bernoulli (dist .util .clamp_probs (mu_z_t )),
220+ infer = {"enumerate" : "parallel" },
221+ )
206222 mu_y_t = rho * z
207223 numpyro .sample (
208224 "y" , dist .Bernoulli (dist .util .clamp_probs (mu_y_t )), obs = y
@@ -249,7 +265,11 @@ def transition_fn(carry, y):
249265 with handlers .mask (mask = first_capture_mask ):
250266 mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask )
251267 # NumPyro exactly sums out the discrete states z_t.
252- z = numpyro .sample ("z" , dist .Bernoulli (dist .util .clamp_probs (mu_z_t )))
268+ z = numpyro .sample (
269+ "z" ,
270+ dist .Bernoulli (dist .util .clamp_probs (mu_z_t )),
271+ infer = {"enumerate" : "parallel" },
272+ )
253273 mu_y_t = rho * z
254274 numpyro .sample (
255275 "y" , dist .Bernoulli (dist .util .clamp_probs (mu_y_t )), obs = y
0 commit comments