Skip to content

Commit cc777e8

Browse files
neerajpradfehiepsi
authored andcommitted
Minor changes to infer.util for 0.2.2 (#487)
* Minor changes to infer.util for 0.2.2 * fix test * fix test; address comment * fix invocation
1 parent db872b1 commit cc777e8

File tree

6 files changed

+59
-44
lines changed

6 files changed

+59
-44
lines changed

examples/stochastic_volatility.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def main(args):
8282
_, fetch = load_dataset(SP500, shuffle=False)
8383
dates, returns = fetch()
8484
init_rng_key, sample_rng_key = random.split(random.PRNGKey(args.rng_seed))
85-
init_params, potential_fn, constrain_fn = initialize_model(init_rng_key, model, returns)
85+
init_params, potential_fn, constrain_fn = initialize_model(init_rng_key, model, model_args=(returns,))
8686
init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
8787
hmc_state = init_kernel(init_params, args.num_warmup, rng_key=sample_rng_key)
8888
hmc_states = fori_collect(args.num_warmup, args.num_warmup + args.num_samples, sample_kernel, hmc_state,

numpyro/contrib/autoguide/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,10 @@ def __init__(self, model, prefix="auto", init_strategy=init_to_uniform()):
136136
def _setup_prototype(self, *args, **kwargs):
137137
super(AutoContinuous, self)._setup_prototype(*args, **kwargs)
138138
rng_key = numpyro.sample("_{}_rng_key_init".format(self.prefix), dist.PRNGIdentity())
139-
init_params, _ = handlers.block(find_valid_initial_params)(rng_key, self.model, *args,
139+
init_params, _ = handlers.block(find_valid_initial_params)(rng_key, self.model,
140140
init_strategy=self.init_strategy,
141-
**kwargs)
141+
model_args=args,
142+
model_kwargs=kwargs)
142143
self._inv_transforms = {}
143144
self._has_transformed_dist = False
144145
unconstrained_sites = {}

numpyro/infer/mcmc.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS'):
151151
... return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
152152
>>>
153153
>>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0),
154-
... model, data, labels)
154+
... model, model_args=(data, labels,))
155155
>>> init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
156156
>>> hmc_state = init_kernel(init_params,
157157
... trajectory_length=10,
@@ -495,10 +495,11 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg
495495
' `potential_fn`.')
496496
# Find valid initial params
497497
if self._model and not init_params:
498-
init_params, is_valid = find_valid_initial_params(rng_key, self._model, *model_args,
498+
init_params, is_valid = find_valid_initial_params(rng_key, self._model,
499499
init_strategy=self._init_strategy,
500500
param_as_improper=True,
501-
**model_kwargs)
501+
model_args=model_args,
502+
model_kwargs=model_kwargs)
502503
if not_jax_tracer(is_valid):
503504
if device_get(~np.all(is_valid)):
504505
raise RuntimeError("Cannot find valid initial parameters. "

numpyro/infer/util.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@
3232

3333
def log_density(model, model_args, model_kwargs, params, skip_dist_transforms=False):
3434
"""
35-
Computes log of joint density for the model given latent values ``params``.
35+
(EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
36+
latent values ``params``.
3637
3738
:param model: Python callable containing NumPyro primitives.
3839
:param tuple model_args: args provided to the model.
39-
:param dict model_kwargs`: kwargs provided to the model.
40+
:param dict model_kwargs: kwargs provided to the model.
4041
:param dict params: dictionary of current parameter values keyed by site
4142
name.
4243
:param bool skip_dist_transforms: whether to compute log probability of a site
@@ -76,8 +77,9 @@ def log_density(model, model_args, model_kwargs, params, skip_dist_transforms=Fa
7677

7778
def transform_fn(transforms, params, invert=False):
7879
"""
79-
Callable that applies a transformation from the `transforms` dict to values in the
80-
`params` dict and returns the transformed values keyed on the same names.
80+
(EXPERIMENTAL INTERFACE) Callable that applies a transformation from the `transforms`
81+
dict to values in the `params` dict and returns the transformed values keyed on
82+
the same names.
8183
8284
:param transforms: Dictionary of transforms keyed by names. Names in
8385
`transforms` and `params` should align.
@@ -93,17 +95,18 @@ def transform_fn(transforms, params, invert=False):
9395

9496
def constrain_fn(model, transforms, model_args, model_kwargs, params):
9597
"""
96-
Gets value at each latent site in `model` given unconstrained parameters `params`.
97-
The `transforms` is used to transform these unconstrained parameters to base values
98-
of the corresponding priors in `model`. If a prior is a transformed distribution,
99-
the corresponding base value lies in the support of base distribution. Otherwise,
100-
the base value lies in the support of the distribution.
98+
(EXPERIMENTAL INTERFACE) Gets value at each latent site in `model` given
99+
unconstrained parameters `params`. The `transforms` is used to transform these
100+
unconstrained parameters to base values of the corresponding priors in `model`.
101+
If a prior is a transformed distribution, the corresponding base value lies in
102+
the support of base distribution. Otherwise, the base value lies in the support
103+
of the distribution.
101104
102105
:param model: a callable containing NumPyro primitives.
103-
:param tuple model_args: args provided to the model.
104-
:param dict model_kwargs: kwargs provided to the model.
105106
:param dict transforms: dictionary of transforms keyed by names. Names in
106107
`transforms` and `params` should align.
108+
:param tuple model_args: args provided to the model.
109+
:param dict model_kwargs: kwargs provided to the model.
107110
:param dict params: dictionary of unconstrained values keyed by site
108111
names.
109112
:return: `dict` of transformed params.
@@ -116,16 +119,16 @@ def constrain_fn(model, transforms, model_args, model_kwargs, params):
116119

117120
def potential_energy(model, inv_transforms, model_args, model_kwargs, params):
118121
"""
119-
Computes potential energy of a model given unconstrained params.
122+
(EXPERIMENTAL INTERFACE) Computes potential energy of a model given unconstrained params.
120123
The `inv_transforms` is used to transform these unconstrained parameters to base values
121124
of the corresponding priors in `model`. If a prior is a transformed distribution,
122125
the corresponding base value lies in the support of base distribution. Otherwise,
123126
the base value lies in the support of the distribution.
124127
125128
:param model: a callable containing NumPyro primitives.
126-
:param tuple model_args: args provided to the model.
127-
:param dict model_kwargs`: kwargs provided to the model.
128129
:param dict inv_transforms: dictionary of transforms keyed by names.
130+
:param tuple model_args: args provided to the model.
131+
:param dict model_kwargs: kwargs provided to the model.
129132
:param dict params: unconstrained parameters of `model`.
130133
:return: potential energy given unconstrained parameters.
131134
"""
@@ -268,8 +271,11 @@ def init_to_value(values):
268271
return partial(_init_to_value, values=values)
269272

270273

271-
def find_valid_initial_params(rng_key, model, *model_args, init_strategy=init_to_uniform(),
272-
param_as_improper=False, **model_kwargs):
274+
def find_valid_initial_params(rng_key, model,
275+
init_strategy=init_to_uniform(),
276+
param_as_improper=False,
277+
model_args=(),
278+
model_kwargs=None):
273279
"""
274280
(EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns an initial
275281
valid unconstrained value for all the parameters. This function also returns an
@@ -281,11 +287,11 @@ def find_valid_initial_params(rng_key, model, *model_args, init_strategy=init_to
281287
sample from the prior. The returned `init_params` will have the
282288
batch shape ``rng_key.shape[:-1]``.
283289
:param model: Python callable containing Pyro primitives.
284-
:param `*model_args`: args provided to the model.
285290
:param callable init_strategy: a per-site initialization function.
286291
:param bool param_as_improper: a flag to decide whether to consider sites with
287292
`param` statement as sites with improper priors.
288-
:param `**model_kwargs`: kwargs provided to the model.
293+
:param tuple model_args: args provided to the model.
294+
:param dict model_kwargs: kwargs provided to the model.
289295
:return: tuple of (`init_params`, `is_valid`).
290296
"""
291297
init_strategy = jax.partial(init_strategy, skip_param=not param_as_improper)
@@ -416,8 +422,11 @@ def constrain_fun(*args, **kwargs):
416422
return potential_fn, constrain_fun
417423

418424

419-
def initialize_model(rng_key, model, *model_args, init_strategy=init_to_uniform(),
420-
dynamic_args=False, **model_kwargs):
425+
def initialize_model(rng_key, model,
426+
init_strategy=init_to_uniform(),
427+
dynamic_args=False,
428+
model_args=(),
429+
model_kwargs=None):
421430
"""
422431
(EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
423432
and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
@@ -427,30 +436,33 @@ def initialize_model(rng_key, model, *model_args, init_strategy=init_to_uniform(
427436
sample from the prior. The returned `init_params` will have the
428437
batch shape ``rng_key.shape[:-1]``.
429438
:param model: Python callable containing Pyro primitives.
430-
:param `*model_args`: args provided to the model.
431439
:param callable init_strategy: a per-site initialization function.
432440
See :ref:`init_strategy` section for available functions.
433441
:param bool dynamic_args: if `True`, the `potential_fn` and
434442
`constraints_fn` are themselves dependent on model arguments.
435443
When provided a `*model_args, **model_kwargs`, they return
436444
`potential_fn` and `constraints_fn` callables, respectively.
437-
:param `**model_kwargs`: kwargs provided to the model.
445+
:param tuple model_args: args provided to the model.
446+
:param dict model_kwargs: kwargs provided to the model.
438447
:return: tuple of (`init_params`, `potential_fn`, `constrain_fn`),
439448
`init_params` are values from the prior used to initiate MCMC,
440449
`constrain_fn` is a callable that uses inverse transforms
441450
to convert unconstrained HMC samples to constrained values that
442451
lie within the site's support.
443452
"""
453+
if model_kwargs is None:
454+
model_kwargs = {}
444455
potential_fun, constrain_fun = get_potential_fn(rng_key if rng_key.ndim == 1 else rng_key[0],
445456
model,
446457
dynamic_args=dynamic_args,
447458
model_args=model_args,
448459
model_kwargs=model_kwargs)
449460

450-
init_params, is_valid = find_valid_initial_params(rng_key, model, *model_args,
461+
init_params, is_valid = find_valid_initial_params(rng_key, model,
451462
init_strategy=init_strategy,
452463
param_as_improper=True,
453-
**model_kwargs)
464+
model_args=model_args,
465+
model_kwargs=model_kwargs)
454466

455467
if not_jax_tracer(is_valid):
456468
if device_get(~np.all(is_valid)):
@@ -559,11 +571,8 @@ def get_samples(self, rng_key, *args, **kwargs):
559571

560572
def log_likelihood(model, posterior_samples, *args, **kwargs):
561573
"""
562-
Returns log likelihood at observation nodes of model, given samples of all latent variables.
563-
564-
.. warning::
565-
The interface for the `log_likelihood` function is experimental, and
566-
might change in the future.
574+
(EXPERIMENTAL INTERFACE) Returns log likelihood at observation nodes of model,
575+
given samples of all latent variables.
567576
568577
:param model: Python callable containing Pyro primitives.
569578
:param dict posterior_samples: dictionary of samples from the posterior.

test/test_infer_util.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -192,11 +192,13 @@ def model(data):
192192
])
193193

194194
rng_keys = random.split(random.PRNGKey(1), 2)
195-
init_params, _, _ = initialize_model(rng_keys, model, count_data,
196-
init_strategy=init_strategy)
195+
init_params, _, _ = initialize_model(rng_keys, model,
196+
init_strategy=init_strategy,
197+
model_args=(count_data,))
197198
for i in range(2):
198-
init_params_i, _, _ = initialize_model(rng_keys[i], model, count_data,
199-
init_strategy=init_strategy)
199+
init_params_i, _, _ = initialize_model(rng_keys[i], model,
200+
init_strategy=init_strategy,
201+
model_args=(count_data,))
200202
for name, p in init_params.items():
201203
# XXX: the result is equal if we disable fast-math-mode
202204
assert_allclose(p[i], init_params_i[name], atol=1e-6)
@@ -219,11 +221,13 @@ def model(data):
219221
data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
220222

221223
rng_keys = random.split(random.PRNGKey(1), 2)
222-
init_params, _, _ = initialize_model(rng_keys, model, data,
223-
init_strategy=init_strategy)
224+
init_params, _, _ = initialize_model(rng_keys, model,
225+
init_strategy=init_strategy,
226+
model_args=(data,))
224227
for i in range(2):
225-
init_params_i, _, _ = initialize_model(rng_keys[i], model, data,
226-
init_strategy=init_strategy)
228+
init_params_i, _, _ = initialize_model(rng_keys[i], model,
229+
init_strategy=init_strategy,
230+
model_args=(data,))
227231
for name, p in init_params.items():
228232
# XXX: the result is equal if we disable fast-math-mode
229233
assert_allclose(p[i], init_params_i[name], atol=1e-6)

test/test_mcmc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def model(data):
459459

460460
true_probs = np.array([0.9, 0.1])
461461
data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
462-
init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), model, data)
462+
init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), model, model_args=(data,))
463463
init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
464464
hmc_state = init_kernel(init_params,
465465
trajectory_length=1.,

0 commit comments

Comments
 (0)