
Equinox models integration #1709

Open
juanitorduz opened this issue Dec 27, 2023 · 10 comments · May be fixed by #2005
Labels
enhancement New feature or request

Comments

@juanitorduz
Contributor

It would be nice to have equinox_module and random_equinox_module model functions in https://github.com/pyro-ppl/numpyro/blob/master/numpyro/contrib/module.py, as Equinox seems to be under quite active development.

Would this be a good addition?

I could give it a shot in the upcoming months, but I will need some guidance :) Still, I am also happy if a more experienced dev wants to give it a go. XD

@fehiepsi
Member

Hi @juanitorduz, if you need this feature, please feel free to put it in contrib.module. I guess you can mimic random_flax_module for an implementation. If you need to clarify something, please leave a comment in this issue thread.

@juanitorduz
Contributor Author

Great! Makes sense. Thank you @fehiepsi ! I'll give it a try in the upcoming months!

@danielward27
Contributor

I've been using this in my package flowjax for registering parameters for equinox modules.


from collections.abc import Callable

import equinox as eqx
import numpyro
from jaxtyping import PyTree


def register_params(
    name: str,
    model: PyTree,
    filter_spec: Callable | PyTree = eqx.is_inexact_array,
):
    """Register numpyro params for an arbitrary pytree.

    This partitions the parameters and static components, registers the parameters using
    numpyro.param, then recombines them. This should be called from within an inference
    context to have an effect, e.g. within a numpyro model or guide function.

    Args:
        name: Name for the parameter set.
        model: The pytree (e.g. an equinox module, flowjax distribution/bijection).
        filter_spec: Equinox `filter_spec` for specifying trainable parameters. Either a
            callable `leaf -> bool`, or a PyTree with prefix structure matching `model`
            with True/False values. Defaults to `eqx.is_inexact_array`.

    """
    params, static = eqx.partition(model, filter_spec)
    if callable(params):
        # Wrap to avoid special handling of callables by numpyro. Numpyro expects a
        # callable to be used for lazy initialization, whereas in our case it is likely
        # a callable module we wish to train.
        params = numpyro.param(name, lambda _: params)
    else:
        params = numpyro.param(name, params)
    return eqx.combine(params, static)

It's not particularly well tested, and I'm not familiar with the implementations for other frameworks, but maybe it's another useful reference. After training I just use eqx.combine(trained_params, model) to retrieve the trained module.

@juanitorduz
Contributor Author

Thank you @danielward27 ! This will be a great entry point! (I am planning to tackle this sometime in February)

@kylejcaron
Contributor

I'm happy to pick this up; I have a current project this would help with, and I can mirror some aspects of @juanitorduz's nnx implementation. Thanks for the register_params function @danielward27!

A big open question I have: the other numpyro contrib modules organize params as dictionaries, while this is a pytree module. I'm curious how this will impact passing priors to a random_equinox_module, which is done via a dictionary in the other modules. If anyone has any ideas here, please let me know!

I left 3 examples of models with very different architectures below. Example 2, where there are multiple nodes with the same name but we only want a prior on one of them, seems to be the most difficult edge case:

  • example 1: maybe we want to apply a dist.Cauchy() prior to each bias term in each MLP linear layer: prior={"Linear.bias": dist.Cauchy()}
  • example 2: maybe we want a more specific prior on just the first Linear layer of an MLP (for example's sake): prior={"Linear_0.weight": dist.Normal(0, 0.01)}
  • example 3: maybe we want different priors for each of Q, K, V in multi-head attention: prior={"query.weight": dist.Normal(0, 0.01), "key.weight": dist.Normal(0, 0.01)}

I have a minimal example working locally:

import numpy as np
import numpyro
from numpyro import distributions as dist
from numpyro.infer import SVI, Predictive, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
import equinox as eqx
import jax

import matplotlib.pyplot as plt


rng = np.random.default_rng(99)
N = 1000
X = rng.normal(0, 1.5, size=(N,3))
y = rng.normal(1.2 + X @ np.array([0.4, -0.1, 0.22]), 0.1)


def register_params(name, model, filter_spec=eqx.is_inexact_array):
    """from daniel ward https://github.com/pyro-ppl/numpyro/issues/1709"""
    params, static = eqx.partition(model, filter_spec)
    if callable(params):
        # Wrap to avoid special handling of callables by numpyro. Numpyro expects a
        # callable to be used for lazy initialization, whereas in our case it is likely
        # a callable module we wish to train.
        params = numpyro.param(name, lambda _: params)
    else:
        params = numpyro.param(name, params)
    return eqx.combine(params, static)

def equinox_module(name, nn_module, *args, **kwargs):
    """Minimal implementation."""
    rng_key = numpyro.prng_key()
    nn_module = nn_module(*args, key=rng_key, **kwargs)
    nn_module = register_params(name + "$params", nn_module)
    return jax.vmap(nn_module)

def model(X, y=None):
    linear_model = equinox_module("linear", eqx.nn.Linear, in_features=X.shape[-1], out_features=1)
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    
    # Use the model to make predictions
    with numpyro.plate("data", X.shape[0]):
        mu = numpyro.deterministic("mu", linear_model(X).squeeze(-1))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y)


# Fit model and check predictions
guide = AutoNormal(model, init_loc_fn=numpyro.infer.init_to_median())
svi = SVI(model, guide, numpyro.optim.Adam(0.05), Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(0), 10000, X=X, y=y)

pred_func = Predictive(model, guide=guide, params=svi_result.params, num_samples=1000)
preds = pred_func(jax.random.PRNGKey(0), X)

plt.hist(preds['mu'].mean(0) - y, bins=25)
plt.show()

@juanitorduz
Contributor Author

In the nnx case we use a lot of nnx helper functions converting from state to dicts (#1990). I wonder if any of those could be reused.

@fehiepsi will probably have better ideas of how to move this forward! :)

Cool stuff @kylejcaron 🚀

@kylejcaron
Contributor

kylejcaron commented Mar 10, 2025

> In the nnx case we use a lot of nnx helper functions converting from state to dicts (#1990) I wonder if any of those could be reused.
>
> @fehiepsi will probably have better ideas of how to move this forward! :)
>
> Cool stuff @kylejcaron 🚀

I'll take a look thank you!

And leaving a design note here:

It looks like pytree key paths might be the solution for prior naming, but I don't love how they choose names.

For example, this:

import jax.tree_util as jtu
[jtu.keystr(path) for path, vals in jtu.tree_flatten_with_path(mlp)[0]]

returns the following, which gets the job done, but I would've thought that working with the name Linear rather than layers would be better:

['.layers[0].weight',
 '.layers[0].bias',
 '.layers[1].weight',
 '.layers[1].bias',
 '.layers[2].weight',
 '.layers[2].bias',
 '.activation',
 '.final_activation']

If there are any pytree experts out there who can weigh in on mapping a dictionary of priors onto pytree nodes, let me know (examples in my last post)!

@juanitorduz
Contributor Author

Why don't you open a PR and take it from there? (This is how we started 💪)

@fehiepsi
Member

nnx operates on dicts, so it is not complicated to reason about the paths. I'm not sure about equinox; I can't find similar model-surgery operators in its docs.

@kylejcaron
Contributor

> nnx operates on dict, so it is not complicated to reason about the paths. I'm not sure about equinox - i can't find similar model surgery operators in the docs.

Just found an example: patrick-kidger/equinox#657

I'll open up a PR soon using this approach, but if we find anything better for naming we can pivot.

Thank you both!

@kylejcaron kylejcaron linked a pull request Mar 11, 2025 that will close this issue