-
Notifications
You must be signed in to change notification settings - Fork 95
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pin jaxlib to use cuda120 build with cuda-nvcc and update README note #549
Changes from 6 commits
c3f34a5
f7c2bbf
d62d7f2
1cf6637
c43e2bf
5049f6c
9c5dd55
a5fb3e9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
import pytest | ||
import importlib | ||
import sys | ||
import os | ||
import logging | ||
|
||
import pytest | ||
|
||
packages = [ | ||
# machine learning stuff | ||
|
@@ -23,10 +24,39 @@ def test_start(): | |
assert os.environ['PANGEO_ENV'] == 'ml-notebook' | ||
|
||
def test_jax_tf_together(): | ||
""" sometimes this impport fails due to sharing private symbols | ||
complicated longer story, but it is better to | ||
ensure they can coexist | ||
""" | ||
import tensorflow, jax | ||
Sometimes this import fails due to sharing private symbols. | ||
Complicated long story, but it is better to ensure they can coexist | ||
""" | ||
import jax | ||
import tensorflow | ||
assert int(tensorflow.__version__[0]) >= 2 | ||
assert int(jax.__version__[0]) >= 0 | ||
|
||
|
||
def test_jax_random_number_generator(): | ||
""" | ||
Ensure that initializing a random number generator on JaX works. | ||
|
||
Regression test for checking that JaX and cuda-nvcc are installed and compatible on | ||
GPU devices, see https://github.com/pangeo-data/pangeo-docker-images/issues/438. | ||
""" | ||
import jax | ||
import numpy as np | ||
from jax import random | ||
|
||
# Test running on CPU | ||
with jax.default_device(jax.devices("cpu")[0]): | ||
key = random.key(seed=42) | ||
x = random.normal(key=key) | ||
np.testing.assert_allclose(x, -0.18471177) | ||
|
||
# Test running on GPU (need to run locally) | ||
try: | ||
gpu_device = jax.devices("gpu")[0] | ||
with jax.default_device(gpu_device): | ||
key = random.key(seed=24) | ||
x = random.normal(key=key) | ||
np.testing.assert_allclose(x, -1.168644) | ||
except RuntimeError: # Unknown backend: 'gpu' requested | ||
logging.log(level=logging.INFO, msg="JAX was not tested on a GPU device") | ||
Comment on lines
+54
to
+62
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This doesn't work on GitHub Actions without a GPU, but I ran this locally using:
The tests passed on my computer with an NVIDIA GPU, so should be ok I think. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If there were Azure credits for Pangeo we could add GPU runners https://docs.github.com/en/actions/using-github-hosted-runners/about-larger-runners/managing-larger-runners ! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This pin on using a
cuda120*
build of jaxlib might actually need to go away if we're planning on adding multi-arch builds ofml-notebook
(see #399 (review) for context), but will deal with this later since the same pin is present on Tensorflow.