Description
Dear iminuit developers,
Thank you very much for this great package!
I am the author of evermore, a pure JAX package for building binned likelihoods in HEP. It lets users construct arbitrary PyTrees of nuisance parameters and use them in their loss function. Grouping parameters into arrays is highly efficient, because bin contents can then be modified in a vectorized fashion (especially for Barlow-Beeston[-lite]). As a result, some parameters are single values (e.g. a single cross-section uncertainty), while others are arrays (e.g. Barlow-Beeston statistical uncertainties).
I would therefore like to ask whether it would be possible to support mixtures of differently sized arrays (and floats) as parameters, e.g.:
```python
import numpy as np
from iminuit import Minuit

def fun(x, c):
    return c + x[0]**2 + x[1]**4

# Requested: array-valued parameters of different sizes passed as keywords
Minuit(fun, x=np.ones(2), c=np.ones(1))
```
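For comparison, a possible workaround with the current API is to pack all keyword parameters into one flat vector, since iminuit already accepts a cost function that takes a single array argument. This is a minimal sketch assuming NumPy inputs; `flatten_kwargs` is a hypothetical helper, not part of iminuit, and the commented `Minuit` call assumes its single-array-argument mode with one name per element.

```python
import numpy as np


def flatten_kwargs(fcn, **starts):
    """Hypothetical helper (not part of iminuit): wrap fcn(**kwargs) as a
    function of one flat vector. Floats and arrays of any shape are packed
    in keyword order. Returns (flat_fcn, flat_start, names)."""
    keys = list(starts)
    arrays = [np.asarray(starts[k], dtype=float) for k in keys]
    shapes = [a.shape for a in arrays]
    sizes = [a.size for a in arrays]
    # Positions in the flat vector where one parameter ends and the next begins
    offsets = np.cumsum(sizes)[:-1]

    def flat_fcn(p):
        parts = np.split(np.asarray(p, dtype=float), offsets)
        # Restore each original shape; a float comes back as a 0-d array
        kwargs = {k: part.reshape(s) for k, part, s in zip(keys, parts, shapes)}
        return fcn(**kwargs)

    flat_start = np.concatenate([a.ravel() for a in arrays])
    # One name per scalar entry: "c" for a float, "x[0]", "x[1]", ... for arrays
    names = []
    for k, a in zip(keys, arrays):
        names.extend([k] if a.shape == () else [f"{k}[{i}]" for i in range(a.size)])
    return flat_fcn, flat_start, names


def fun(x, c):
    return c + x[0]**2 + x[1]**4


flat_fun, start, names = flatten_kwargs(fun, x=np.ones(2), c=1.0)
print(names)            # ['x[0]', 'x[1]', 'c']
print(flat_fun(start))  # 3.0
# With iminuit installed (single-array-argument mode), this could be fitted as
# m = Minuit(flat_fun, start, name=names)
```

The cost of this workaround is that every call splits and reshapes the vector, and the per-element names have to be generated by hand; native support for mixed-size keyword parameters would avoid both.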
This is particularly handy when working with JAX loss functions, where the parameters (`x`, `c`) are in practice often a nested PyTree of `jax.Array`s of arbitrary size:
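```python
import jax.numpy as jnp
import jax.tree_util as jtu
from functools import partial
from iminuit import Minuit

# Parameters as a PyTree (here: a dict of arrays of different sizes)
params = {"x": jnp.ones(2), "c": jnp.ones(1)}

def fun(params):
    x = params["x"]
    c = params["c"]
    return c + x[0]**2 + x[1]**4

# Rebuild the PyTree from its flat list of leaves before calling fun
def wrapped_fun(flat_params, treedef):
    params = jtu.tree_unflatten(treedef, flat_params)
    return fun(params)

flat_params, treedef = jtu.tree_flatten(params)
# Requested: one (array-valued) Minuit parameter per leaf, named after the dict keys
Minuit(partial(wrapped_fun, treedef=treedef), flat_params, name=treedef.node_data()[1])
```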
In this example, `params` is just a simple dictionary, but the same approach would work with arbitrary (nested) PyTree structures if iminuit supported arrays of arbitrary size as keyword arguments of the loss function.
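For nested structures, the same flattening can be done with one name per scalar entry, which is what iminuit needs when a single array argument is used. JAX itself provides `jax.flatten_util.ravel_pytree` for the ravel/unravel step; the sketch below mimics it for nested dicts in plain NumPy so that dotted names can be generated alongside. `ravel_nested` and the example keys (`xsec`, `bb.gamma`) are hypothetical, and only dicts with string keys are handled (no lists, tuples, or custom PyTree nodes).

```python
import numpy as np


def _leaves(tree, path=()):
    # Depth-first walk over nested dicts; keys visited in sorted order,
    # as jax.tree_util does for dicts
    if isinstance(tree, dict):
        for key in sorted(tree):
            yield from _leaves(tree[key], path + (key,))
    else:
        yield path, np.asarray(tree, dtype=float)


def ravel_nested(tree):
    """Hypothetical helper: flatten a nested dict (root must be a dict) of
    floats/arrays. Returns (flat_values, names, unravel)."""
    leaves = list(_leaves(tree))
    flat = np.concatenate([leaf.ravel() for _, leaf in leaves])
    names = []
    for path, leaf in leaves:
        base = ".".join(path)
        if leaf.shape == ():
            names.append(base)
        else:
            names.extend(f"{base}[{i}]" for i in range(leaf.size))

    def unravel(vec):
        # Rebuild the nested dict with the original shapes from a flat vector
        vec = np.asarray(vec, dtype=float)
        out, pos = {}, 0
        for path, leaf in leaves:
            node = out
            for key in path[:-1]:
                node = node.setdefault(key, {})
            node[path[-1]] = vec[pos:pos + leaf.size].reshape(leaf.shape)
            pos += leaf.size
        return out

    return flat, names, unravel


# Hypothetical parameters: one cross-section scale and a Barlow-Beeston-like array
params = {"xsec": 1.0, "bb": {"gamma": np.ones(3)}}
flat, names, unravel = ravel_nested(params)
print(names)  # ['bb.gamma[0]', 'bb.gamma[1]', 'bb.gamma[2]', 'xsec']


def loss(params):
    return params["xsec"] ** 2 + np.sum((params["bb"]["gamma"] - 2.0) ** 2)


def flat_loss(vec):
    return loss(unravel(vec))


print(flat_loss(flat))  # 4.0
# With iminuit: Minuit(flat_loss, flat, name=names)
```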
Best, Peter
PS: JAX optimisers, e.g. those in `optax`, can minimise directly with respect to these PyTree structures. The minimiser returns the original PyTree structure, with the fitted parameter values in its leaves. There, one does not even need the `wrapped_fun` step to convert a PyTree into a list of arguments.