Skip to content

Initialize numpyro NUTS with SVI #80

@stanbiryukov

Description

@stanbiryukov

Hi there, great initial work wrapping numpyro and pyro into a more user-friendly interface! I'm having an issue with a few simple models where the numpyro backend gives me the following error:
Cannot find valid initial parameters. Please check your model again.
It seems to occur when there are many predictors in the formula.

Pyro sampling and SVI both work fine for this model with the default Cauchy beta priors. Any thoughts on initializing numpyro NUTS from an SVI fit, or perhaps from maximum a posteriori (MAP) estimates?
It's tough to figure out what exactly is causing MCMC to fail immediately, but I'm assuming it's the initial values. Full traceback:

RuntimeError                              Traceback (most recent call last)
<ipython-input-185-9da937b2eeba> in <module>
----> 1 fit = model.fit(backend=numpyro, seed=8877, iter=1000, warmup=500)

/opt/conda/lib/python3.6/site-packages/brmp/__init__.py in fit(self, algo, **kwargs)
    173         """
    174         assert algo in ['prior', 'nuts', 'svi']
--> 175         return getattr(self, algo)(**kwargs)
    176 
    177     def nuts(self, iter=10, warmup=None, num_chains=1, seed=None, backend=numpyro_backend):

/opt/conda/lib/python3.6/site-packages/brmp/__init__.py in nuts(self, iter, warmup, num_chains, seed, backend)
    200         """
    201         warmup = iter // 2 if warmup is None else warmup
--> 202         return self.run_algo('nuts', backend, iter, warmup, num_chains, seed)
    203 
    204     def svi(self, iter=10, num_samples=10, seed=None, backend=pyro_backend, **kwargs):

/opt/conda/lib/python3.6/site-packages/brmp/__init__.py in run_algo(self, name, backend, df, *args, **kwargs)
    154         data = self.model.encode(df) if df is not None else self.data
    155         assets_wrapper = self.model.gen(backend)
--> 156         return assets_wrapper.run_algo(name, data_from_numpy(backend, data), *args, **kwargs)
    157 
    158     def fit(self, algo='nuts', **kwargs):

/opt/conda/lib/python3.6/site-packages/brmp/__init__.py in run_algo(self, name, data, *args, **kwargs)
     75 
     76     def run_algo(self, name, data, *args, **kwargs):
---> 77         samples = getattr(self.backend, name)(data, self.assets, *args, **kwargs)
     78         return Fit(self.model.formula, self.model.metadata,
     79                    self.model.contrasts, data,

/opt/conda/lib/python3.6/site-packages/brmp/numpyro_backend.py in nuts(data, assets, iter, warmup, num_chains, seed)
     86     # `num_chains` > 1 to achieve parallel chains.
     87     mcmc = MCMC(kernel, warmup, iter, num_chains=num_chains)
---> 88     mcmc.run(rng, **data)
     89     samples = mcmc.get_samples(group_by_chain=True)
     90 

/opt/conda/lib/python3.6/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, collect_warmup, init_params, *args, **kwargs)
    639         if self.num_chains == 1:
    640             states_flat = self._single_chain_mcmc((rng_key, init_params), collect_fields, collect_warmup,
--> 641                                                   args, kwargs)
    642             states = tree_map(lambda x: x[np.newaxis, ...], states_flat)
    643         else:

/opt/conda/lib/python3.6/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, collect_fields, collect_warmup, args, kwargs)
    582         rng_key, init_params = init
    583         init_state, constrain_fn = self.sampler.init(rng_key, self.num_warmup, init_params,
--> 584                                                      model_args=args, model_kwargs=kwargs)
    585         if self.constrain_fn is None:
    586             constrain_fn = identity if constrain_fn is None else constrain_fn

/opt/conda/lib/python3.6/site-packages/numpyro/infer/mcmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    409                 rng_key, rng_key_init_model = np.swapaxes(vmap(random.split)(rng_key), 0, 1)
    410             init_params_, self.potential_fn, constrain_fn = initialize_model(
--> 411                 rng_key_init_model, self.model, *model_args, init_strategy=self.init_strategy, **model_kwargs)
    412             if init_params is None:
    413                 init_params = init_params_

/opt/conda/lib/python3.6/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, *model_args, **model_kwargs)
    413     if not_jax_tracer(is_valid):
    414         if device_get(~np.all(is_valid)):
--> 415             raise RuntimeError("Cannot find valid initial parameters. Please check your model again.")
    416     return init_params, potential_fn, constrain_fun
    417 

RuntimeError: Cannot find valid initial parameters. Please check your model again.

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions