Commit 09649eb

Pin jaxlib to use cuda120 build with cuda-nvcc and update README note (#549)
* Pin jaxlib to use cuda120 build

  The conda-lock solver was picking up the cpu build of jaxlib instead of the cuda build. Adding an explicit pin here to pick up the cuda120 build version.

* Update note on cuda-nvcc workaround

* Add unit test for jax random number generator

  Regression test to ensure that JAX and cuda-nvcc are installed and compatible on GPU devices.

* Wrap jax test on GPU in try-except block to pass on GH Actions CI

  Since we're using GitHub Actions runners without a GPU, the GPU part of the JAX random number generator test cannot run there, so the check is wrapped in a try-except block. Also tidied up some import statements and docstrings in the test file.

* Update note in README to say cuda-nvcc workaround is no longer needed

  Since `cuda-nvcc` is installed with jaxlib's cuda120 builds, the workaround to conda install cuda-nvcc should not be needed anymore.

* Add link to XLA (Accelerated Linear Algebra) page

  Link to https://openxla.org/xla so that people know what XLA is.
1 parent cbcd24e commit 09649eb

3 files changed, with 44 additions and 8 deletions

README.md

Lines changed: 7 additions & 2 deletions
@@ -54,7 +54,12 @@ The primary use of these Docker images is running on Pangeo Cloud deployments wi
 * Since 2020.10.16, [mamba](https://github.com/mamba-org/mamba) is installed into the base-image and conda-lock environment and is used by default to solve for a compatible environment (see #146)
 * For a simple list of packages for a given image, you can use a link like this: https://github.com/pangeo-data/pangeo-docker-images/blob/2020.10.08/pangeo-notebook/packages.txt
 * To compare changes between two images, you can use a link like this: https://github.com/pangeo-data/pangeo-docker-images/compare/2020.10.03..2020.10.08
-* Our `ml-notebook` image now contains JAX and TensorFlow with XLA enabled. Due to licensing issues, conda-forge does not have `ptxas`, but `ptxas` is needed for XLA to work correctly. Should you like to use JAX and/or TensorFlow with XLA optimization, please install `ptxas` on your own, for example, by `conda install -c nvidia cuda-nvcc`. At the time of writing (October 2022), JAX throws a compilation error if the `ptxas` version is higher than the driver version. There does not exist an easy solution for K80 GPUs, but in the case of T4 GPUs, you should install `conda install -c nvidia cuda-nvcc==11.6.*` to be safe. Alternatively for any GPU, you could set an environment variable to resolve the error caused by JAX: `XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1"`. The aforementioned error will be removed (and likely turned into a warning) in a future version of JAX. See https://github.com/google/jax/issues/12776#issuecomment-1276649134
+* As of 2024.05.21, the `ml-notebook` and `pytorch-notebook` docker images contain
+machine learning libraries built with CUDA 12. In previous versions, we have suggested
+`ml-notebook` users to install `cuda-nvcc` manually to obtain JAX and/or TensorFlow
+with [XLA](https://openxla.org/xla) optimization, but this workaround should no longer
+be needed if you are using `ml-notebook` 2024.06.02 or newer that comes with
+`cuda-nvcc` pre-installed.
 * There used to be a `pangeo/forge` image, built for use with [pangeo-forge](https://pangeo-forge.org/). It is
 no longer actively maintained or used, but you can still use the [historical tags](https://quay.io/repository/pangeo/forge?tab=tags)
-if you wish.
+if you wish.
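The removed README note says XLA needs `ptxas`, which `cuda-nvcc` provides. A minimal sketch of how a user might confirm that `ptxas` is actually reachable inside an image is shown here; the helper names `find_ptxas` and `ptxas_version_text` are hypothetical, and the sketch assumes only that `ptxas` accepts `--version`.

```python
import shutil
import subprocess


def find_ptxas():
    """Return the full path to ptxas if it is on PATH, otherwise None."""
    return shutil.which("ptxas")


def ptxas_version_text(path):
    """Return whatever `ptxas --version` prints to stdout (may span several lines)."""
    result = subprocess.run(
        [path, "--version"], capture_output=True, text=True, check=False
    )
    return result.stdout.strip()


if __name__ == "__main__":
    ptxas = find_ptxas()
    if ptxas is None:
        # On images older than the ones described above, this is the case
        # where the manual cuda-nvcc install would have been needed.
        print("ptxas not found on PATH")
    else:
        print(ptxas)
        print(ptxas_version_text(ptxas))
```

Run inside the container, this prints either the `ptxas` location and version banner or a short message saying it is missing.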

ml-notebook/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ dependencies:
   - cuda-version>=12.0
   - flax>=0.8.0
   - jax
+  - jaxlib>=0.4.23=cuda120*
   - jupyterlab-nvdashboard
   - keras-cv
   - tensorflow>=2.15.0=cuda120*
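The pin `jaxlib>=0.4.23=cuda120*` combines a version constraint with a build-string glob. The sketch below illustrates how such a spec splits into its parts and why a CPU build no longer qualifies. It is a simplified illustration, not conda's real parser: it handles only the forms seen in this hunk (for example, a bare `name=version` pin is out of scope), and the build strings in the usage lines are made up for the example.

```python
import re
from fnmatch import fnmatchcase

# Simplified pattern: package name, optional version constraint, optional build glob.
SPEC_RE = re.compile(
    r"(?P<name>[A-Za-z0-9_.\-]+)"
    r"(?P<version>(?:>=|<=|==|!=|>|<)[0-9A-Za-z.*]+)?"
    r"(?:=(?P<build>[^=\s]+))?"
)


def split_spec(spec):
    """Split a spec such as 'jaxlib>=0.4.23=cuda120*' into (name, version, build)."""
    match = SPEC_RE.fullmatch(spec)
    if match is None:
        raise ValueError(f"unsupported spec: {spec!r}")
    return match.group("name"), match.group("version"), match.group("build")


def build_allowed(spec, build_string):
    """Return True if build_string satisfies the spec's build glob (or there is none)."""
    _, _, build_glob = split_spec(spec)
    return build_glob is None or fnmatchcase(build_string, build_glob)


if __name__ == "__main__":
    pin = "jaxlib>=0.4.23=cuda120*"
    print(split_spec(pin))
    # Hypothetical build strings, for illustration only.
    print(build_allowed(pin, "cuda120py311_0"))
    print(build_allowed(pin, "cpu_py311_0"))
```

With the unpinned `jax` line alone, any build string passes, which is consistent with the solver having been free to pick the CPU build.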

tests/test_ml-notebook.py

Lines changed: 36 additions & 6 deletions
@@ -1,7 +1,8 @@
-import pytest
 import importlib
-import sys
 import os
+import logging
+
+import pytest
 
 packages = [
     # machine learning stuff
@@ -23,10 +24,39 @@ def test_start():
     assert os.environ['PANGEO_ENV'] == 'ml-notebook'
 
 def test_jax_tf_together():
-    """ sometimes this impport fails due to sharing private symbols
-    complicated longer story, but it is better to
-    ensure they can coexist
     """
-    import tensorflow, jax
+    Sometimes this import fails due to sharing private symbols.
+    Complicated long story, but it is better to ensure they can coexist
+    """
+    import jax
+    import tensorflow
     assert int(tensorflow.__version__[0]) >= 2
     assert int(jax.__version__[0]) >= 0
+
+
+def test_jax_random_number_generator():
+    """
+    Ensure that initializing a random number generator on JaX works.
+
+    Regression test for checking that JaX and cuda-nvcc are installed and compatible on
+    GPU devices, see https://github.com/pangeo-data/pangeo-docker-images/issues/438.
+    """
+    import jax
+    import numpy as np
+    from jax import random
+
+    # Test running on CPU
+    with jax.default_device(jax.devices("cpu")[0]):
+        key = random.key(seed=42)
+        x = random.normal(key=key)
+        np.testing.assert_allclose(x, -0.18471177)
+
+    # Test running on GPU (need to run locally)
+    try:
+        gpu_device = jax.devices("gpu")[0]
+        with jax.default_device(gpu_device):
+            key = random.key(seed=24)
+            x = random.normal(key=key)
+            np.testing.assert_allclose(x, -1.168644)
+    except RuntimeError:  # Unknown backend: 'gpu' requested
+        logging.log(level=logging.INFO, msg="JAX was not tested on a GPU device")
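One thing to note about the try-except above is that it also covers the assertion lines, so a `RuntimeError` raised during the GPU computation itself would be logged as "not tested" instead of failing. A possible alternative, sketched here under assumptions, is to isolate only the device lookup. The helper `first_device_or_none` and the stand-in `fake_cpu_only_devices` are hypothetical names; the stand-in mimics the "Unknown backend" error mentioned in the diff so the sketch runs without JAX installed.

```python
def first_device_or_none(list_devices, backend):
    """Return the first device for `backend`, or None if the lookup raises RuntimeError."""
    try:
        devices = list_devices(backend)
    except RuntimeError:  # e.g. Unknown backend: 'gpu' requested
        return None
    return devices[0] if devices else None


def fake_cpu_only_devices(backend):
    """Hypothetical stand-in for jax.devices on a machine without a GPU."""
    if backend == "cpu":
        return ["cpu:0"]
    raise RuntimeError(f"Unknown backend: '{backend}' requested")


if __name__ == "__main__":
    print(first_device_or_none(fake_cpu_only_devices, "cpu"))
    print(first_device_or_none(fake_cpu_only_devices, "gpu"))
    # In the real test one might instead write (assumption, not part of this commit):
    #     gpu_device = first_device_or_none(jax.devices, "gpu")
    #     if gpu_device is None:
    #         pytest.skip("no GPU device available")
```

With this shape, the CI run would report the GPU check as skipped, and errors raised after a GPU is found would still surface as failures.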
