-
Notifications
You must be signed in to change notification settings - Fork 11
gh-502: add array API tests and typing for grf._transformations
#614
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
066c47f
to
cdda0e6
Compare
cdda0e6
to
9d88886
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. A couple of small comments.
np.testing.assert_array_almost_equal_nulp(glass.grf.icorr(t1, t2, y), x, nulp=5) | ||
np.testing.assert_array_almost_equal_nulp(glass.grf.icorr(t1, t2, y), x, nulp=8) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any idea why this has changed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I get this with JAX with anything less than 8 -
============================================================================== FAILURES ===============================================================================
__________________________________________________________________ test_sqnormal_sqnormal[jax.numpy] __________________________________________________________________
xp = <module 'jax.numpy' from '/Users/saransh/Code/UCL/glass/.env/lib/python3.13/site-packages/jax/numpy/__init__.py'>
urng = <glass.jax.Generator object at 0x111a3ffd0>
def test_sqnormal_sqnormal(xp: types.ModuleType, urng: UnifiedGenerator) -> None:
lam1, var1 = urng.uniform(size=2)
a1 = xp.sqrt(1 - var1)
t1 = glass.grf.SquaredNormal(a1, lam1)
lam2, var2 = urng.uniform(size=2)
a2 = xp.sqrt(1 - var2)
t2 = glass.grf.SquaredNormal(a2, lam2)
# https://arxiv.org/pdf/2408.16903, (E.7)
x = urng.random(10)
y = 2 * lam1 * lam2 * x * (x + 2 * a1 * a2)
dy = 4 * lam1 * lam2 * (x + a1 * a2)
np.testing.assert_array_equal(glass.grf.corr(t1, t2, x), y)
> np.testing.assert_array_almost_equal_nulp(glass.grf.icorr(t1, t2, y), x, nulp=5)
E AssertionError: Arrays are not equal to 5 ULP (max is 8)
tests/grf/test_transformations.py:95: AssertionError
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay makes sense. Bit annoying.
Co-authored-by: Patrick J. Roddy <[email protected]>
I'm going to merge this just so we have fewer stale PRs when we resume work in September |
Description
Add array API tests and typing for
grf._transformations
. Should be okay to merge after #610.Fixes: #502
Checks