Equinox models integration #1709
Comments
Hi @juanitorduz, if you need this feature, please feel free to put it in |
Great! Makes sense. Thank you @fehiepsi ! I'll give it a try in the upcoming months! |
I've been using this in my package flowjax for registering parameters for equinox modules.
It's not particularly well tested, and I'm not familiar with the implementations for other frameworks, but maybe it's another useful reference. After training I just use |
Thank you @danielward27 ! This will be a great entry point! (I am planning to tackle this sometime in February) |
I'm happy to pick this up: I have a current project that this would help with, and I can mirror some aspects of @juanitorduz's nnx implementation. Thanks for the A big open question I have: it looks like the other numpyro contrib modules organize params as dictionaries, while this is a pytree module. Curious how this will impact passing in priors in a I left 3 examples of models with very different architectures below. Example 2, where there are multiple nodes with the same name but we only want a prior on one of them, seems to be the most difficult edge case:
I have a minimal example working locally:

import numpy as np
import numpyro
from numpyro import distributions as dist
from numpyro.infer import SVI, Predictive, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
import equinox as eqx
import jax
import matplotlib.pyplot as plt

rng = np.random.default_rng(99)
N = 1000
X = rng.normal(0, 1.5, size=(N, 3))
y = rng.normal(1.2 + X @ np.array([0.4, -0.1, 0.22]), 0.1)


def register_params(name, model, filter_spec=eqx.is_inexact_array):
    """from daniel ward https://github.com/pyro-ppl/numpyro/issues/1709"""
    params, static = eqx.partition(model, filter_spec)
    if callable(params):
        # Wrap to avoid special handling of callables by numpyro. Numpyro expects a
        # callable to be used for lazy initialization, whereas in our case it is likely
        # a callable module we wish to train.
        params = numpyro.param(name, lambda _: params)
    else:
        params = numpyro.param(name, params)
    return eqx.combine(params, static)


def equinox_module(name, nn_module, *args, **kwargs):
    """minimal implementation"""
    rng_key = numpyro.prng_key()
    nn_module = nn_module(*args, key=rng_key, **kwargs)
    nn_module = register_params(name + '$params', nn_module)
    return jax.vmap(nn_module)


def model(X, y=None):
    linear_model = equinox_module("linear", eqx.nn.Linear, in_features=X.shape[-1], out_features=1)
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    # Use the model to make predictions
    with numpyro.plate("data", X.shape[0]):
        mu = numpyro.deterministic("mu", linear_model(X).squeeze(-1))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y)


# Fit model and check predictions
guide = AutoNormal(model, init_loc_fn=numpyro.infer.init_to_median())
svi = SVI(model, guide, numpyro.optim.Adam(0.05), Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(0), 10000, X=X, y=y)
pred_func = Predictive(model, guide=guide, params=svi_result.params, num_samples=1000)
preds = pred_func(jax.random.PRNGKey(0), X)
plt.hist(preds['mu'].mean(0) - y, bins=25)
plt.show()
In the nnx case we use a lot of nnx helper functions converting from state to dicts (#1990). I wonder if any of those could be reused. @fehiepsi will probably have better ideas of how to move this forward! :) Cool stuff @kylejcaron 🚀 |
I'll take a look, thank you! And leaving a design note here: it looks like pytree key paths might be the solution for naming priors, but I don't love how they choose names. For example this: returns the following which gets the job done but I would've thought that working with the name
If there are any pytree experts out there that can weigh in on mapping a dictionary of priors to update pytree nodes let me know (examples in my last post)! |
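One possible way to approach the question above, sketched with plain JAX only: flatten the pytree with key paths, turn each path into a string name, and use a dictionary keyed by those names to replace selected leaves. Everything named here (`Linear`, `MLP`, `leaf_names`, `replace_leaves`) is hypothetical and not part of numpyro or Equinox; the `NamedTuple` classes are stand-ins for Equinox modules, which are also pytrees, so the same `jax.tree_util` calls should apply to them. In a numpyro model the override values would presumably come from `numpyro.sample(name, prior)` rather than from constants.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Linear(NamedTuple):
    # Hypothetical stand-in for a layer such as eqx.nn.Linear.
    weight: jax.Array
    bias: jax.Array


class MLP(NamedTuple):
    # Two layers that both contain leaves called "weight" and "bias",
    # which is the duplicate-name edge case discussed in this thread.
    hidden: Linear
    out: Linear


def leaf_names(tree):
    """Map each leaf's key path, rendered as a string, to the leaf itself."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    return {jax.tree_util.keystr(path): leaf for path, leaf in leaves_with_paths}


def replace_leaves(tree, overrides):
    """Return a copy of `tree` with leaves swapped where their path name is in `overrides`."""

    def swap(path, leaf):
        return overrides.get(jax.tree_util.keystr(path), leaf)

    return jax.tree_util.tree_map_with_path(swap, tree)


mlp = MLP(
    hidden=Linear(weight=jnp.ones((4, 3)), bias=jnp.zeros(4)),
    out=Linear(weight=jnp.ones((1, 4)), bias=jnp.zeros(1)),
)

names = leaf_names(mlp)
# Pick only the output layer's weight, leaving the hidden layer's weight alone.
target = next(name for name in names if "out" in name and "weight" in name)
updated = replace_leaves(mlp, {target: jnp.full((1, 4), 0.5)})
```

Because the full key path is part of each name, two leaves that share the attribute name `weight` stay distinguishable, at the cost of the somewhat awkward path-style names mentioned earlier.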
Why don't you open a PR and take it from there (this is how we started 💪 ) |
nnx operates on dicts, so it is not complicated to reason about the paths. I'm not sure about equinox - I can't find similar model surgery operators in the docs. |
Just found an example! patrick-kidger/equinox#657 I'll open up a PR soon using this approach, but if we find anything better for naming we can pivot. Thank you both! |
It would be nice to have equinox_module and random_equinox_module model functions in https://github.com/pyro-ppl/numpyro/blob/master/numpyro/contrib/module.py, as Equinox seems to be in quite active development. Would this be a good addition?
I could give it a shot in the upcoming months but I will need some guidance :) Still, I am also happy if a more experienced dev wants to give it a go. XD.