Closed
Issue
Problem Description
The Gaussian and GaussLegendre integrators don't work with the jax backend; they throw:
TypeError: prod requires ndarray or scalar arguments, got <class 'list'> at position 0.
The error is raised in
torchquad/torchquad/integration/gaussian.py
Lines 67 to 69 in bf8ed5c
return anp.prod(
    anp.meshgrid(*([weights] * dim), like=backend), axis=0
).ravel()
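For context, the expression above builds the weights of a tensor-product rule: each entry is the product of one 1D weight per dimension. The sketch below reproduces it with plain NumPy (an assumption, so it runs without JAX) to show where the types diverge. `meshgrid` returns a Python sequence of `dim` arrays rather than a single array. NumPy's `prod` silently converts that sequence, while `jax.numpy.prod` rejects it with the TypeError quoted above. The weights `w` are arbitrary illustrative values.

```python
import numpy as np

# Hypothetical 1D weights; any values work for illustrating the shapes.
w = np.array([0.5, 1.0, 0.5])
dim = 2

# meshgrid returns a sequence of `dim` arrays, not one stacked array.
grids = np.meshgrid(*([w] * dim))
print(type(grids).__name__, len(grids), grids[0].shape)

# NumPy's prod converts the sequence to an array internally; jax.numpy.prod
# does not, which is the failure reported in this issue.
weights_nd = np.prod(grids, axis=0).ravel()
print(weights_nd)  # [0.25 0.5 0.25 0.5 1. 0.5 0.25 0.5 0.25]
```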
What Needs to be Done
I propose converting the list returned by anp.meshgrid into an array with anp.stack. This fix works for me with jax; however, we should first check that it doesn't cause problems with the other backends.
return anp.prod(
    anp.stack(anp.meshgrid(*([weights] * dim), like=backend)), axis=0
).ravel()
How Can It Be Tested or Reproduced
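The proposed change can also be checked in isolation, without torchquad. Here is a minimal sketch, again in NumPy. `np.stack` turns the meshgrid sequence into one array of shape `(dim, n, ..., n)`, so `prod` always receives an array. The helper names `product_weights` and `product_weights_outer` are hypothetical, not torchquad code. The reference construction (an iterated outer product) is an independent choice made for comparison. Because every factor uses the same 1D weights, each product is symmetric in its indices, so meshgrid's default `xy` ordering does not change the flattened result. `meshgrid`, `stack` and `prod` also exist in `jax.numpy`, so swapping the import should exercise the JAX path, though that is not verified here.

```python
import functools

import numpy as np


def product_weights(weights, dim):
    """Tensor-product weights via the stack-then-prod fix (hypothetical helper)."""
    grids = np.meshgrid(*([weights] * dim))
    stacked = np.stack(grids)  # shape (dim, n, ..., n): a single array
    return np.prod(stacked, axis=0).ravel()


def product_weights_outer(weights, dim):
    """Independent reference: iterated outer product of the 1D weights."""
    return functools.reduce(np.multiply.outer, [weights] * dim).ravel()


w = np.array([0.2, 0.3, 0.5])
for dim in (1, 2, 3):
    assert np.allclose(product_weights(w, dim), product_weights_outer(w, dim))
print("stacked meshgrid matches the outer product for dim 1-3")
```

Since these weights sum to 1, the product weights in any dimension also sum to 1, which is a cheap extra sanity check.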
With jax 0.4.35 and torchquad 0.4.0, run:
import jax
import jax.numpy as jnp
from torchquad import set_up_backend, MonteCarlo, Gaussian, GaussLegendre
set_up_backend(backend="jax")
@jax.jit
def some_function(x):
    return jnp.power(x[:, 0] - x[:, 1], 2)
g = Gaussian()
# It also fails with GaussLegendre
# g = GaussLegendre()
integral_value = g.integrate(
    lambda x: some_function(x),
    dim=2,
    N=10000,
    integration_domain=jnp.asarray([[-1.0, 1.0], [-1.0, 1.0]]),
)
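Once the integrators run, the exact value of this integral makes a useful sanity check. Expanding (x0 - x1)^2 = x0^2 - 2*x0*x1 + x1^2 over [-1, 1]^2, the cross term vanishes by symmetry and each square contributes (2/3) * 2 = 4/3, so the integral is 8/3 ≈ 2.6667. The sketch below is an independent NumPy tensor-product Gauss-Legendre rule, not torchquad's implementation. It assumes 100 points per dimension to mirror N=10000 in two dimensions, and it reuses the stacked-meshgrid weights from the proposed fix.

```python
import numpy as np


def some_function(x):
    # Same integrand as the reproduction, written with NumPy.
    return (x[:, 0] - x[:, 1]) ** 2


n = 100  # points per dimension; 100**2 = 10000 matches N above
nodes, w = np.polynomial.legendre.leggauss(n)  # 1D rule on [-1, 1]

# 2D grid of points, flattened in the same order as the product weights.
X0, X1 = np.meshgrid(nodes, nodes)
points = np.stack([X0.ravel(), X1.ravel()], axis=1)  # shape (n * n, 2)
weights = np.prod(np.stack(np.meshgrid(w, w)), axis=0).ravel()

integral = np.sum(weights * some_function(points))
print(integral)  # close to 8/3
```

Gauss-Legendre with n points is exact for polynomials up to degree 2n - 1, so this quadratic integrand should match 8/3 up to floating-point rounding.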