Skip to content

Commit 9df98f1

Browse files
committed
ci: apply fixes
Signed-off-by: nstarman <[email protected]>
1 parent e01aea7 commit 9df98f1

File tree

7 files changed

+39
-20
lines changed

7 files changed

+39
-20
lines changed

numpyro/contrib/control_flow/scan.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,10 @@ def body_fn(wrapped_carry, x, prefix=None):
197197
)
198198
return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)
199199

200-
with handlers.block(
201-
hide_fn=lambda site: not site["name"].startswith("_PREV_")
202-
), enum(first_available_dim=first_available_dim):
200+
with (
201+
handlers.block(hide_fn=lambda site: not site["name"].startswith("_PREV_")),
202+
enum(first_available_dim=first_available_dim),
203+
):
203204
wrapped_carry = (0, rng_key, init)
204205
y0s = []
205206
# We run unroll_steps + 1 where the last step is used for rolling with `lax.scan`

numpyro/infer/autoguide.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -1463,8 +1463,10 @@ def _sample_latent(self, *args, **kwargs):
14631463
if self.global_guide is not None:
14641464
global_latents = self.global_guide(*args, **kwargs)
14651465
rng_key = numpyro.prng_key()
1466-
with handlers.block(), handlers.seed(rng_seed=rng_key), handlers.substitute(
1467-
data=global_latents
1466+
with (
1467+
handlers.block(),
1468+
handlers.seed(rng_seed=rng_key),
1469+
handlers.substitute(data=global_latents),
14681470
):
14691471
global_outputs = self.global_guide.model(*args, **kwargs)
14701472
local_args = (global_outputs,)
@@ -1575,9 +1577,12 @@ def fn(x):
15751577
if self.local_guide is not None:
15761578
key = numpyro.prng_key()
15771579
subsample_guide = partial(_subsample_model, self.local_guide)
1578-
with handlers.block(), handlers.trace() as tr, handlers.seed(
1579-
rng_seed=key
1580-
), handlers.substitute(data=local_guide_params):
1580+
with (
1581+
handlers.block(),
1582+
handlers.trace() as tr,
1583+
handlers.seed(rng_seed=key),
1584+
handlers.substitute(data=local_guide_params),
1585+
):
15811586
with warnings.catch_warnings():
15821587
warnings.simplefilter("ignore")
15831588
subsample_guide(*local_args, **local_kwargs)

numpyro/infer/elbo.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -874,8 +874,11 @@ def get_importance_trace_enum(
874874
trace as _trace,
875875
)
876876

877-
with plate_to_enum_plate(), enum(
878-
first_available_dim=(-max_plate_nesting - 1) if max_plate_nesting else None
877+
with (
878+
plate_to_enum_plate(),
879+
enum(
880+
first_available_dim=(-max_plate_nesting - 1) if max_plate_nesting else None
881+
),
879882
):
880883
guide = substitute(guide, data=params)
881884
with _without_rsample_stop_gradient():

numpyro/infer/inspect.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,10 @@ def get_trace():
5858

5959
def _get_log_probs(model, model_args, model_kwargs, **sample):
6060
# Note: We use seed 0 for parameter initialization.
61-
with handlers.trace() as tr, handlers.seed(rng_seed=0), handlers.substitute(
62-
data=sample
61+
with (
62+
handlers.trace() as tr,
63+
handlers.seed(rng_seed=0),
64+
handlers.substitute(data=sample),
6365
):
6466
model(*model_args, **model_kwargs)
6567
return {
@@ -370,8 +372,9 @@ def process_message(self, msg):
370372

371373
# Note: We use seed 0 for parameter initialization.
372374
with handlers.trace() as tr, handlers.seed(rng_seed=0):
373-
with handlers.substitute(data=sample), substitute_deterministic(
374-
data=sample
375+
with (
376+
handlers.substitute(data=sample),
377+
substitute_deterministic(data=sample),
375378
):
376379
model(*model_args, **model_kwargs)
377380
provenance_arrays = {}

test/infer/test_hmc_util.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,9 @@ def optimize(f):
5757
@pytest.mark.parametrize("regularize", [True, False])
5858
@pytest.mark.filterwarnings("ignore:numpy.linalg support is experimental:UserWarning")
5959
def test_welford_covariance(jitted, diagonal, regularize):
60-
with optional(jitted, disable_jit()), optional(
61-
jitted, control_flow_prims_disabled()
60+
with (
61+
optional(jitted, disable_jit()),
62+
optional(jitted, control_flow_prims_disabled()),
6263
):
6364
np.random.seed(0)
6465
loc = np.random.randn(3)

test/test_distributions.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -2221,8 +2221,12 @@ def g(x):
22212221

22222222

22232223
def test_beta_proportion_invalid_mean():
2224-
with dist.distribution.validation_enabled(), pytest.raises(
2225-
ValueError, match=r"^BetaProportion distribution got invalid mean parameter\.$"
2224+
with (
2225+
dist.distribution.validation_enabled(),
2226+
pytest.raises(
2227+
ValueError,
2228+
match=r"^BetaProportion distribution got invalid mean parameter\.$",
2229+
),
22262230
):
22272231
dist.BetaProportion(1.0, 1.0)
22282232

test/test_handlers.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -372,8 +372,10 @@ def test_subsample_substitute():
372372
data = jnp.arange(100.0)
373373
subsample_size = 7
374374
subsample = jnp.array([13, 3, 30, 4, 1, 68, 5])
375-
with handlers.trace() as tr, handlers.seed(rng_seed=0), handlers.substitute(
376-
data={"a": subsample}
375+
with (
376+
handlers.trace() as tr,
377+
handlers.seed(rng_seed=0),
378+
handlers.substitute(data={"a": subsample}),
377379
):
378380
with numpyro.plate("a", len(data), subsample_size=subsample_size) as idx:
379381
assert data[idx].shape == (subsample_size,)

0 commit comments

Comments
 (0)