Updating to `jax==0.5.0` causes a number of test failures. Most of these look like precision issues (arrays agree to only about four decimal places, which is not enough for the stricter tests), but some of the kernel gradients also appear to be flipped.
Once this is fixed, remove the `!=0.5.0` specifier from the `jax` dependency in `pyproject.toml`.
Some of the test failures may simply be due to random seed dependence: jax made breaking changes to its random number generation in the 0.5.0 release, so many of our tests now run on different inputs.
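One way to check the seed-dependence theory is to draw the same samples under the old and new RNG settings and compare them. The sketch below assumes the relevant change is the `jax_threefry_partitionable` config flag, which, as far as I know, became enabled by default in 0.5.0. The helper `sample` is hypothetical and not part of our test suite.

```python
import jax


def sample(partitionable: bool):
    """Hypothetical helper: draw four normals with the flag set as given."""
    # Assumption: this flag is the RNG change that affects our tests.
    jax.config.update("jax_threefry_partitionable", partitionable)
    key = jax.random.key(0)
    return jax.random.normal(key, (4,))


new_style = sample(True)   # setting assumed to be the default from 0.5.0
old_style = sample(False)  # setting assumed to be the pre-0.5.0 default

# If the two arrays differ, tests that use a fixed seed are seeing new inputs.
print("flag on: ", new_style)
print("flag off:", old_style)
```

If a failing test passes again with the flag set to `False`, that would suggest the failure comes from the changed inputs and not from a numerical regression.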