
Commit

Pin jaxlib to use cuda120 build with cuda-nvcc and update README note (#549)

* Pin jaxlib to use cuda120 build

The conda-lock solver was picking up the cpu build of jaxlib instead of the cuda build. Adding an explicit pin here to pick up the cuda120 build.

* Update note on cuda-nvcc workaround

* Add unit test for jax random number generator

Regression test to ensure that JAX and cuda-nvcc are installed and compatible on GPU devices.

* Wrap jax test on GPU in try-except block to pass on GH Actions CI

Since our GitHub Actions runners do not have a GPU, the JAX random number generator test cannot run on a GPU in CI, so the GPU check is wrapped in a try-except block. Also tidied up some import statements and docstrings in the test file.

* Update note in README to say cuda-nvcc workaround is no longer needed

Since `cuda-nvcc` is installed with jaxlib's cuda120 builds, the workaround to conda install cuda-nvcc should not be needed anymore.

* Add link to XLA (Accelerated Linear Algebra) page

Link to https://openxla.org/xla so that people know what XLA is.
weiji14 authored Jul 1, 2024
1 parent cbcd24e commit 09649eb
Showing 3 changed files with 44 additions and 8 deletions.
9 changes: 7 additions & 2 deletions README.md
@@ -54,7 +54,12 @@ The primary use of these Docker images is running on Pangeo Cloud deployments wi
* Since 2020.10.16, [mamba](https://github.com/mamba-org/mamba) is installed into the base-image and conda-lock environment and is used by default to solve for a compatible environment (see #146)
* For a simple list of packages for a given image, you can use a link like this: https://github.com/pangeo-data/pangeo-docker-images/blob/2020.10.08/pangeo-notebook/packages.txt
* To compare changes between two images, you can use a link like this: https://github.com/pangeo-data/pangeo-docker-images/compare/2020.10.03..2020.10.08
* Our `ml-notebook` image now contains JAX and TensorFlow with XLA enabled. Due to licensing issues, conda-forge does not have `ptxas`, but `ptxas` is needed for XLA to work correctly. Should you like to use JAX and/or TensorFlow with XLA optimization, please install `ptxas` on your own, for example, by `conda install -c nvidia cuda-nvcc`. At the time of writing (October 2022), JAX throws a compilation error if the `ptxas` version is higher than the driver version. There does not exist an easy solution for K80 GPUs, but in the case of T4 GPUs, you should install `conda install -c nvidia cuda-nvcc==11.6.*` to be safe. Alternatively for any GPU, you could set an environment variable to resolve the error caused by JAX: `XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1"`. The aforementioned error will be removed (and likely turned into a warning) in a future version of JAX. See https://github.com/google/jax/issues/12776#issuecomment-1276649134
* As of 2024.05.21, the `ml-notebook` and `pytorch-notebook` docker images contain
machine learning libraries built with CUDA 12. In previous versions, we suggested
that `ml-notebook` users install `cuda-nvcc` manually to obtain JAX and/or TensorFlow
with [XLA](https://openxla.org/xla) optimization, but this workaround should no longer
be needed if you are using `ml-notebook` 2024.06.02 or newer, which comes with
`cuda-nvcc` pre-installed.
* There used to be a `pangeo/forge` image, built for use with [pangeo-forge](https://pangeo-forge.org/). It is
no longer actively maintained or used, but you can still use the [historical tags](https://quay.io/repository/pangeo/forge?tab=tags)
if you wish.
if you wish.
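
To check whether the workaround is still needed in a given image, one option is to look for `ptxas` (which the older README note says is obtained by installing `cuda-nvcc`) on the `PATH`. The following is a minimal sketch using only the Python standard library; the helper names `find_ptxas` and `ptxas_version` are illustrative and not part of this repository.

```python
import shutil
import subprocess


def find_ptxas():
    """Return the full path to `ptxas` if it is on PATH, otherwise None."""
    return shutil.which("ptxas")


def ptxas_version(path):
    """Return the raw text printed by `ptxas --version`."""
    result = subprocess.run(
        [path, "--version"], capture_output=True, text=True, check=False
    )
    return result.stdout.strip()


ptxas_path = find_ptxas()
if ptxas_path is None:
    # Hypothetical message; XLA itself reports its own errors.
    print("ptxas not found on PATH; cuda-nvcc may not be installed")
else:
    print(f"ptxas found at {ptxas_path}")
    print(ptxas_version(ptxas_path))
```

Finding `ptxas` on the `PATH` only shows that the binary is present; it does not by itself confirm that its version is compatible with the GPU driver.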
1 change: 1 addition & 0 deletions ml-notebook/environment.yml
@@ -8,6 +8,7 @@ dependencies:
- cuda-version>=12.0
- flax>=0.8.0
- jax
- jaxlib>=0.4.23=cuda120*
- jupyterlab-nvdashboard
- keras-cv
- tensorflow>=2.15.0=cuda120*
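
The pin above has three parts: a package name, a version constraint, and a build-string glob. As a rough illustration of why `=cuda120*` steers the solver away from the cpu build, the sketch below splits such a spec and filters candidate build strings with a glob match. It handles only the simple `name<op>version=build` form, not conda's full match-spec grammar, and the build strings in `candidates` are made up for the example.

```python
from fnmatch import fnmatchcase


def split_spec(spec):
    """Split 'name>=version=build' into (name, version_constraint, build_glob)."""
    # The package name ends at the first comparison-operator character.
    for i, ch in enumerate(spec):
        if ch in "<>=!":
            name, rest = spec[:i], spec[i:]
            break
    else:
        # No constraint at all, e.g. "jax": any version, any build.
        return spec, "", "*"
    # Skip the operator itself so '>=' or '==' is not mistaken for the
    # single '=' that separates the version from the build glob.
    j = 0
    while j < len(rest) and rest[j] in "<>=!":
        j += 1
    version, sep, build = rest[j:].partition("=")
    return name, rest[:j] + version, build if sep else "*"


# Hypothetical build strings, for illustration only.
candidates = ["cpu_py311h1234567_0", "cuda120py311h7654321_200"]

name, version, build_glob = split_spec("jaxlib>=0.4.23=cuda120*")
matching = [b for b in candidates if fnmatchcase(b, build_glob)]
print(name, version, build_glob)
print(matching)
```

Without the build glob, both candidates would satisfy the spec, which is consistent with the commit message's observation that the solver picked the cpu build.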
42 changes: 36 additions & 6 deletions tests/test_ml-notebook.py
@@ -1,7 +1,8 @@
import pytest
import importlib
import sys
import os
import logging

import pytest

packages = [
# machine learning stuff
@@ -23,10 +24,39 @@ def test_start():
assert os.environ['PANGEO_ENV'] == 'ml-notebook'

def test_jax_tf_together():
""" sometimes this impport fails due to sharing private symbols
complicated longer story, but it is better to
ensure they can coexist
"""
import tensorflow, jax
Sometimes this import fails due to sharing private symbols.
Complicated long story, but it is better to ensure they can coexist
"""
import jax
import tensorflow
assert int(tensorflow.__version__[0]) >= 2
assert int(jax.__version__[0]) >= 0


def test_jax_random_number_generator():
    """
    Ensure that initializing a random number generator on JaX works.
    Regression test for checking that JaX and cuda-nvcc are installed and compatible on
    GPU devices, see https://github.com/pangeo-data/pangeo-docker-images/issues/438.
    """
    import jax
    import numpy as np
    from jax import random

    # Test running on CPU
    with jax.default_device(jax.devices("cpu")[0]):
        key = random.key(seed=42)
        x = random.normal(key=key)
        np.testing.assert_allclose(x, -0.18471177)

    # Test running on GPU (need to run locally)
    try:
        gpu_device = jax.devices("gpu")[0]
        with jax.default_device(gpu_device):
            key = random.key(seed=24)
            x = random.normal(key=key)
            np.testing.assert_allclose(x, -1.168644)
    except RuntimeError:  # Unknown backend: 'gpu' requested
        logging.log(level=logging.INFO, msg="JAX was not tested on a GPU device")
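
The try-except pattern in the test above can also be factored into a small helper that only guards the device lookup, so that the numerical assertion stays outside the `try` block. The sketch below is one possible shape, not the approach taken in this commit. `first_device_or_none` and `fake_devices` are hypothetical names, and the device-listing function is passed in as an argument so the example runs without JAX or a GPU.

```python
import logging


def first_device_or_none(list_devices, backend):
    """Return the first device for `backend`, or None if it is unavailable.

    `list_devices` is any callable shaped like `jax.devices(backend)`.
    """
    try:
        return list_devices(backend)[0]
    except RuntimeError:  # e.g. "Unknown backend: 'gpu' requested"
        logging.info("No %s device available; device check not run", backend)
        return None


def fake_devices(backend):
    """Hypothetical stand-in for `jax.devices` on a CPU-only machine."""
    if backend != "cpu":
        raise RuntimeError(f"Unknown backend: '{backend}' requested")
    return ["cpu:0"]


print(first_device_or_none(fake_devices, "cpu"))
print(first_device_or_none(fake_devices, "gpu"))
```

With JAX installed, the call would be `first_device_or_none(jax.devices, "gpu")`. In a pytest suite, a `None` result could be turned into `pytest.skip(...)` so that the missing GPU coverage shows up in the test report instead of only in the log.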
