3232
3333def log_density (model , model_args , model_kwargs , params , skip_dist_transforms = False ):
3434 """
35- Computes log of joint density for the model given latent values ``params``.
35+ (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
36+ latent values ``params``.
3637
3738 :param model: Python callable containing NumPyro primitives.
3839 :param tuple model_args: args provided to the model.
39- :param dict model_kwargs` : kwargs provided to the model.
40+ :param dict model_kwargs: kwargs provided to the model.
4041 :param dict params: dictionary of current parameter values keyed by site
4142 name.
4243 :param bool skip_dist_transforms: whether to compute log probability of a site
@@ -76,8 +77,9 @@ def log_density(model, model_args, model_kwargs, params, skip_dist_transforms=Fa
7677
7778def transform_fn (transforms , params , invert = False ):
7879 """
79- Callable that applies a transformation from the `transforms` dict to values in the
80- `params` dict and returns the transformed values keyed on the same names.
80+ (EXPERIMENTAL INTERFACE) Callable that applies a transformation from the `transforms`
81+ dict to values in the `params` dict and returns the transformed values keyed on
82+ the same names.
8183
8284 :param transforms: Dictionary of transforms keyed by names. Names in
8385 `transforms` and `params` should align.
@@ -93,17 +95,18 @@ def transform_fn(transforms, params, invert=False):
9395
9496def constrain_fn (model , transforms , model_args , model_kwargs , params ):
9597 """
96- Gets value at each latent site in `model` given unconstrained parameters `params`.
97- The `transforms` is used to transform these unconstrained parameters to base values
98- of the corresponding priors in `model`. If a prior is a transformed distribution,
99- the corresponding base value lies in the support of base distribution. Otherwise,
100- the base value lies in the support of the distribution.
98+ (EXPERIMENTAL INTERFACE) Gets value at each latent site in `model` given
99+ unconstrained parameters `params`. The `transforms` is used to transform these
100+ unconstrained parameters to base values of the corresponding priors in `model`.
101+ If a prior is a transformed distribution, the corresponding base value lies in
102+ the support of base distribution. Otherwise, the base value lies in the support
103+ of the distribution.
101104
102105 :param model: a callable containing NumPyro primitives.
103- :param tuple model_args: args provided to the model.
104- :param dict model_kwargs: kwargs provided to the model.
105106 :param dict transforms: dictionary of transforms keyed by names. Names in
106107 `transforms` and `params` should align.
108+ :param tuple model_args: args provided to the model.
109+ :param dict model_kwargs: kwargs provided to the model.
107110 :param dict params: dictionary of unconstrained values keyed by site
108111 names.
109112 :return: `dict` of transformed params.
@@ -116,16 +119,16 @@ def constrain_fn(model, transforms, model_args, model_kwargs, params):
116119
117120def potential_energy (model , inv_transforms , model_args , model_kwargs , params ):
118121 """
119- Computes potential energy of a model given unconstrained params.
122+ (EXPERIMENTAL INTERFACE) Computes potential energy of a model given unconstrained params.
120123 The `inv_transforms` is used to transform these unconstrained parameters to base values
121124 of the corresponding priors in `model`. If a prior is a transformed distribution,
122125 the corresponding base value lies in the support of base distribution. Otherwise,
123126 the base value lies in the support of the distribution.
124127
125128 :param model: a callable containing NumPyro primitives.
126- :param tuple model_args: args provided to the model.
127- :param dict model_kwargs`: kwargs provided to the model.
128129 :param dict inv_transforms: dictionary of transforms keyed by names.
130+ :param tuple model_args: args provided to the model.
131+ :param dict model_kwargs: kwargs provided to the model.
129132 :param dict params: unconstrained parameters of `model`.
130133 :return: potential energy given unconstrained parameters.
131134 """
@@ -268,8 +271,11 @@ def init_to_value(values):
268271 return partial (_init_to_value , values = values )
269272
270273
271- def find_valid_initial_params (rng_key , model , * model_args , init_strategy = init_to_uniform (),
272- param_as_improper = False , ** model_kwargs ):
274+ def find_valid_initial_params (rng_key , model ,
275+ init_strategy = init_to_uniform (),
276+ param_as_improper = False ,
277+ model_args = (),
278+ model_kwargs = None ):
273279 """
274280 (EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns an initial
275281 valid unconstrained value for all the parameters. This function also returns an
@@ -281,11 +287,11 @@ def find_valid_initial_params(rng_key, model, *model_args, init_strategy=init_to
281287 sample from the prior. The returned `init_params` will have the
282288 batch shape ``rng_key.shape[:-1]``.
283289 :param model: Python callable containing Pyro primitives.
284- :param `*model_args`: args provided to the model.
285290 :param callable init_strategy: a per-site initialization function.
286291 :param bool param_as_improper: a flag to decide whether to consider sites with
287292 `param` statement as sites with improper priors.
288- :param `**model_kwargs`: kwargs provided to the model.
293+ :param tuple model_args: args provided to the model.
294+ :param dict model_kwargs: kwargs provided to the model.
289295 :return: tuple of (`init_params`, `is_valid`).
290296 """
291297 init_strategy = jax .partial (init_strategy , skip_param = not param_as_improper )
@@ -416,8 +422,11 @@ def constrain_fun(*args, **kwargs):
416422 return potential_fn , constrain_fun
417423
418424
419- def initialize_model (rng_key , model , * model_args , init_strategy = init_to_uniform (),
420- dynamic_args = False , ** model_kwargs ):
425+ def initialize_model (rng_key , model ,
426+ init_strategy = init_to_uniform (),
427+ dynamic_args = False ,
428+ model_args = (),
429+ model_kwargs = None ):
421430 """
422431 (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
423432 and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
@@ -427,30 +436,33 @@ def initialize_model(rng_key, model, *model_args, init_strategy=init_to_uniform(
427436 sample from the prior. The returned `init_params` will have the
428437 batch shape ``rng_key.shape[:-1]``.
429438 :param model: Python callable containing Pyro primitives.
430- :param `*model_args`: args provided to the model.
431439 :param callable init_strategy: a per-site initialization function.
432440 See :ref:`init_strategy` section for available functions.
433441 :param bool dynamic_args: if `True`, the `potential_fn` and
434442 `constraints_fn` are themselves dependent on model arguments.
435443 When provided a `*model_args, **model_kwargs`, they return
436444 `potential_fn` and `constraints_fn` callables, respectively.
437- :param `**model_kwargs`: kwargs provided to the model.
445+ :param tuple model_args: args provided to the model.
446+ :param dict model_kwargs: kwargs provided to the model.
438447 :return: tuple of (`init_params`, `potential_fn`, `constrain_fn`),
439448 `init_params` are values from the prior used to initiate MCMC,
440449 `constrain_fn` is a callable that uses inverse transforms
441450 to convert unconstrained HMC samples to constrained values that
442451 lie within the site's support.
443452 """
453+ if model_kwargs is None :
454+ model_kwargs = {}
444455 potential_fun , constrain_fun = get_potential_fn (rng_key if rng_key .ndim == 1 else rng_key [0 ],
445456 model ,
446457 dynamic_args = dynamic_args ,
447458 model_args = model_args ,
448459 model_kwargs = model_kwargs )
449460
450- init_params , is_valid = find_valid_initial_params (rng_key , model , * model_args ,
461+ init_params , is_valid = find_valid_initial_params (rng_key , model ,
451462 init_strategy = init_strategy ,
452463 param_as_improper = True ,
453- ** model_kwargs )
464+ model_args = model_args ,
465+ model_kwargs = model_kwargs )
454466
455467 if not_jax_tracer (is_valid ):
456468 if device_get (~ np .all (is_valid )):
@@ -559,11 +571,8 @@ def get_samples(self, rng_key, *args, **kwargs):
559571
560572def log_likelihood (model , posterior_samples , * args , ** kwargs ):
561573 """
562- Returns log likelihood at observation nodes of model, given samples of all latent variables.
563-
564- .. warning::
565- The interface for the `log_likelihood` function is experimental, and
566- might change in the future.
574+ (EXPERIMENTAL INTERFACE) Returns log likelihood at observation nodes of model,
575+ given samples of all latent variables.
567576
568577 :param model: Python callable containing Pyro primitives.
569578 :param dict posterior_samples: dictionary of samples from the posterior.
0 commit comments