
Feature Request: mix arguments of different-sized arrays (& floats)? #982

Open
@pfackeldey

Dear iminuit developers,

Thank you very much for this great package!

I am the author of evermore, a pure-JAX package for building binned likelihoods in HEP. With it, one can construct arbitrary PyTrees of nuisance parameters and use them in the loss function. Grouping parameters into arrays makes it possible to modify bin contents in a vectorized fashion, which is highly efficient (especially for Barlow-Beeston(-lite)). Users have some parameters that are single values (e.g. a single cross-section uncertainty) and some that are represented as arrays (e.g. Barlow-Beeston statistical uncertainties).
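
To illustrate why this grouping pays off, here is a minimal NumPy sketch of a binned Poisson negative log-likelihood in which one scalar parameter (a cross-section scale) and one per-bin array (Barlow-Beeston-lite-style scale factors) modify all bins at once. The function name `nll`, the toy inputs and the simplified Gaussian constraint on the per-bin factors are assumptions for illustration only; this is not evermore's API.

```python
import numpy as np


def nll(xsec_scale, gamma, nominal, observed, rel_stat_unc):
    """Binned Poisson NLL (up to constants) with one scalar and one per-bin nuisance array.

    xsec_scale   : float, scales the whole prediction (e.g. a cross-section parameter)
    gamma        : array with one entry per bin (Barlow-Beeston-lite-style scale factors)
    nominal      : nominal expected counts per bin
    observed     : observed counts per bin
    rel_stat_unc : relative statistical uncertainty per bin (width of the constraint)
    """
    # Vectorized modification of every bin at once: no Python loop over bins.
    expected = nominal * xsec_scale * gamma
    # Poisson term, dropping the parameter-independent log(n!) part.
    poisson = np.sum(expected - observed * np.log(expected))
    # Simplified (assumed) Gaussian constraint pulling each gamma towards 1.
    constraint = 0.5 * np.sum(((gamma - 1.0) / rel_stat_unc) ** 2)
    return poisson + constraint


nominal = np.array([10.0, 20.0, 30.0])
observed = np.array([12.0, 18.0, 33.0])
rel_stat_unc = np.array([0.1, 0.05, 0.05])
value = nll(1.0, np.ones(3), nominal, observed, rel_stat_unc)
```

Because `gamma` is a single array, adding bins adds no new Python-level parameters; that array-valued argument next to a plain float is exactly the mixture this request is about.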

Thus, I'd like to ask whether it would be possible to add support for mixtures of different-sized arrays (and floats), e.g.:

```python
import numpy as np
from iminuit import Minuit

def fun(x, c):
    return c + x[0]**2 + x[1]**4

# requested: x is an array of size 2, c an array of size 1 (or a plain float)
Minuit(fun, x=np.ones(2), c=np.ones(1))
```
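
Until such a feature exists, one possible workaround for a flat set of keyword arguments is to pack all parameters into a single 1D array and generate one name per element. The helpers `pack` and `make_cost` below are hypothetical names for this sketch, and the Minuit call appears only in a comment because it assumes iminuit's interface for cost functions that take a single array, combined with the `name` keyword.

```python
import numpy as np


def pack(params):
    """Flatten a dict of floats/arrays into one 1D vector plus per-element names.

    Returns the vector, the element names, and the (key, shape) layout needed to unpack.
    """
    values, names, layout = [], [], []
    for key, value in params.items():
        arr = np.asarray(value, dtype=float)
        layout.append((key, arr.shape))
        flat = arr.ravel()
        values.append(flat)
        if arr.ndim == 0:
            names.append(key)  # a plain float keeps its own name, e.g. "c"
        else:
            names.extend(f"{key}[{i}]" for i in range(flat.size))  # e.g. "x[0]", "x[1]"
    return np.concatenate(values), names, layout


def make_cost(fun, layout):
    """Wrap fun(**params) so that it accepts the single flat vector produced by pack()."""
    def cost(vec):
        vec = np.asarray(vec, dtype=float)
        kwargs, offset = {}, 0
        for key, shape in layout:
            size = int(np.prod(shape))  # np.prod(()) == 1.0, so scalars take one slot
            kwargs[key] = vec[offset:offset + size].reshape(shape)
            offset += size
        return fun(**kwargs)
    return cost


def fun(x, c):
    return c + x[0] ** 2 + x[1] ** 4


x0, names, layout = pack({"x": np.ones(2), "c": 1.0})
cost = make_cost(fun, layout)
# With iminuit (assumed single-array interface): Minuit(cost, x0, name=names)
```

Each element then shows up under its own name (`x[0]`, `x[1]`, `c`), and the same slicing logic as in `cost` maps fitted values back to the original shapes.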

This is particularly handy when working with JAX loss functions, where the parameters (x, c) are in practice often a nested PyTree of jax.Arrays of arbitrary size:

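```python
import jax.numpy as jnp
import jax.tree_util as jtu
from functools import partial
from iminuit import Minuit

params = {"x": jnp.ones(2), "c": jnp.ones(1)}

def fun(params):
    x = params["x"]
    c = params["c"]
    return c + x[0]**2 + x[1]**4

def wrapped_fun(flat_params, treedef):
    # rebuild the original PyTree from its list of leaves
    params = jtu.tree_unflatten(treedef, flat_params)
    return fun(params)

# leaves of the dict (in sorted-key order) plus the structure needed to rebuild it
flat_params, treedef = jtu.tree_flatten(params)

# requested: one array-valued Minuit parameter per leaf, named after the dict keys
Minuit(partial(wrapped_fun, treedef=treedef), flat_params, name=treedef.node_data()[1])
```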
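
For nested PyTrees, treedef.node_data() only describes the top-level node, so element names are better derived from each leaf's path. JAX ships related utilities (`jax.flatten_util.ravel_pytree` returns a flat vector plus an unravel function, and `jax.tree_util.tree_flatten_with_path` returns leaf paths). The NumPy-only sketch below combines both ideas for plain dicts, lists and tuples; `ravel_with_names` and its helpers are hypothetical names, and the Minuit call in the comment assumes iminuit's single-array interface with the `name` keyword.

```python
import numpy as np


def _leaf_paths(tree, path):
    """Yield (path, leaf) pairs; dict keys are visited in sorted order, like JAX does."""
    if isinstance(tree, dict):
        for key in sorted(tree):
            yield from _leaf_paths(tree[key], path + (str(key),))
    elif isinstance(tree, (list, tuple)):
        for i, sub in enumerate(tree):
            yield from _leaf_paths(sub, path + (str(i),))
    else:
        yield path, np.asarray(tree, dtype=float)


def _rebuild(tree, chunks):
    """Copy `tree` (plain dicts/lists/tuples only), taking leaves in order from `chunks`."""
    if isinstance(tree, dict):
        return {key: _rebuild(tree[key], chunks) for key in sorted(tree)}
    if isinstance(tree, (list, tuple)):
        return type(tree)(_rebuild(sub, chunks) for sub in tree)
    return next(chunks)


def ravel_with_names(tree):
    """Flatten a nested tree into (vector, per-element names, unravel function)."""
    pairs = list(_leaf_paths(tree, ()))
    names, shapes = [], []
    for path, leaf in pairs:
        base = "/".join(path)
        shapes.append(leaf.shape)
        if leaf.ndim == 0:
            names.append(base)
        else:
            # flat (C-order) element index appended to the leaf path
            names.extend(f"{base}[{i}]" for i in range(leaf.size))
    vec = np.concatenate([leaf.ravel() for _, leaf in pairs])

    def unravel(v):
        chunks, offset = [], 0
        for shape in shapes:
            size = int(np.prod(shape))
            chunks.append(np.asarray(v[offset:offset + size]).reshape(shape))
            offset += size
        return _rebuild(tree, iter(chunks))

    return vec, names, unravel


def fun(params):
    return params["c"] + params["x"][0] ** 2 + params["x"][1] ** 4


params = {"x": np.array([1.0, 2.0]), "c": 0.5, "bb": {"gamma": np.ones(3)}}
vec, names, unravel = ravel_with_names(params)


def cost(v):
    return fun(unravel(v))

# With iminuit (assumed single-array interface): Minuit(cost, vec, name=names)
```

Here `names` comes out as `bb/gamma[0]`, `bb/gamma[1]`, `bb/gamma[2]`, `c`, `x[0]`, `x[1]`, so every fitted value can be traced back to its position in the original PyTree.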

In this example, params is just a simple dictionary, but the same approach would work with arbitrary (nested) PyTree structures if iminuit supported arrays of arbitrary size as loss-function arguments.

Best, Peter

PS: JAX optimisers, e.g. those in optax, can minimise directly with respect to these PyTree structures. The optimiser returns the original PyTree structure with its leaves containing the fitted parameter values, so the wrapped_fun step that converts a PyTree into a list of arguments is not even needed.
