Pin jaxlib to use cuda120 build with cuda-nvcc and update README note #549

Merged
merged 8 commits into from
Jul 1, 2024
9 changes: 7 additions & 2 deletions README.md
@@ -54,7 +54,12 @@ The primary use of these Docker images is running on Pangeo Cloud deployments wi
* Since 2020.10.16, [mamba](https://github.com/mamba-org/mamba) is installed into the base-image and conda-lock environment and is used by default to solve for a compatible environment (see #146)
* For a simple list of packages for a given image, you can use a link like this: https://github.com/pangeo-data/pangeo-docker-images/blob/2020.10.08/pangeo-notebook/packages.txt
* To compare changes between two images, you can use a link like this: https://github.com/pangeo-data/pangeo-docker-images/compare/2020.10.03..2020.10.08
-* Our `ml-notebook` image now contains JAX and TensorFlow with XLA enabled. Due to licensing issues, conda-forge does not have `ptxas`, but `ptxas` is needed for XLA to work correctly. Should you like to use JAX and/or TensorFlow with XLA optimization, please install `ptxas` on your own, for example, by `conda install -c nvidia cuda-nvcc`. At the time of writing (October 2022), JAX throws a compilation error if the `ptxas` version is higher than the driver version. There does not exist an easy solution for K80 GPUs, but in the case of T4 GPUs, you should install `conda install -c nvidia cuda-nvcc==11.6.*` to be safe. Alternatively for any GPU, you could set an environment variable to resolve the error caused by JAX: `XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1"`. The aforementioned error will be removed (and likely turned into a warning) in a future version of JAX. See https://github.com/google/jax/issues/12776#issuecomment-1276649134
+* As of 2024.05.21, the `ml-notebook` and `pytorch-notebook` docker images contain
+  machine learning libraries built with CUDA 12. In previous versions, we have suggested
+  `ml-notebook` users to install `cuda-nvcc` manually to obtain JAX and/or TensorFlow
+  with [XLA](https://openxla.org/xla) optimization, but this workaround should no longer
+  be needed if you are using `ml-notebook` 2024.06.02 or newer that comes with
+  `cuda-nvcc` pre-installed.
* There used to be a `pangeo/forge` image, built for use with [pangeo-forge](https://pangeo-forge.org/). It is
no longer actively maintained or used, but you can still use the [historical tags](https://quay.io/repository/pangeo/forge?tab=tags)
-  if you wish.
+  if you wish.
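
One way to check the README note on a running image is to look for `ptxas` (shipped by `cuda-nvcc`) on the `PATH` and confirm that the libraries are importable. The following is a minimal sketch using only the Python standard library. The helper name `xla_toolchain_report` is made up for illustration and is not part of the images, JAX, or TensorFlow.

```python
import importlib.util
import shutil


def xla_toolchain_report() -> dict:
    """Collect a few facts relevant to XLA compilation on GPU (hypothetical helper)."""
    return {
        # Full path to the ptxas executable, or None if it is not on PATH.
        "ptxas": shutil.which("ptxas"),
        # Whether each library can be found, without actually importing it.
        "jax_installed": importlib.util.find_spec("jax") is not None,
        "tensorflow_installed": importlib.util.find_spec("tensorflow") is not None,
    }


report = xla_toolchain_report()
for name, value in report.items():
    print(f"{name}: {value}")
```

On an image that matches the note, one would expect `ptxas` to resolve to a path inside the conda environment. This only shows that the binary is present, not that its version is compatible with the GPU driver.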
1 change: 1 addition & 0 deletions ml-notebook/environment.yml
@@ -8,6 +8,7 @@ dependencies:
   - cuda-version>=12.0
   - flax>=0.8.0
   - jax
+  - jaxlib>=0.4.23=cuda120*
Member Author

This pin on a cuda120* build of jaxlib might need to go away if we're planning on adding multi-arch builds of ml-notebook (see #399 (review) for context), but I will deal with this later, since the same pin is present on TensorFlow.

   - jupyterlab-nvdashboard
   - keras-cv
   - tensorflow>=2.15.0=cuda120*
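
The pins above combine a version constraint with a build-string glob: in `jaxlib>=0.4.23=cuda120*`, the part after the second `=` restricts which builds are acceptable. The sketch below illustrates how such a spec splits into its parts. It is a deliberately simplified parser written for this example, not conda's or mamba's actual matching logic, and the names `Pin`, `parse_pin`, and `build_matches`, as well as the build strings used, are hypothetical.

```python
import re
from fnmatch import fnmatchcase
from typing import NamedTuple, Optional


class Pin(NamedTuple):
    name: str
    version: Optional[str]  # e.g. ">=0.4.23", or None if unconstrained
    build: Optional[str]    # glob on the build string, e.g. "cuda120*"


# Package name, optional comparison operator, then the remainder of the spec.
_SPEC = re.compile(r"^(?P<name>[A-Za-z0-9_.\-]+)(?P<op>>=|<=|==|!=|>|<|=)?(?P<rest>.*)$")


def parse_pin(spec: str) -> Pin:
    """Split 'name[op version][=build]' into its parts (simplified)."""
    match = _SPEC.match(spec.strip())
    if match is None:
        raise ValueError(f"cannot parse spec: {spec!r}")
    op = match["op"] or ""
    # The first '=' after the version separates it from the build glob.
    version, _, build = match["rest"].partition("=")
    return Pin(match["name"], (op + version) or None, build or None)


def build_matches(pin: Pin, build_string: str) -> bool:
    """True if build_string satisfies the pin's build glob (or there is none)."""
    return pin.build is None or fnmatchcase(build_string, pin.build)


jaxlib_pin = parse_pin("jaxlib>=0.4.23=cuda120*")
print(jaxlib_pin)
# The build strings below are invented for illustration.
print(build_matches(jaxlib_pin, "cuda120py311_0"))
print(build_matches(jaxlib_pin, "cpu_py311_0"))
```

Under this reading, a CPU-only build would be rejected by the `cuda120*` glob, which is why the pin may need revisiting if builds for other architectures are added.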
42 changes: 36 additions & 6 deletions tests/test_ml-notebook.py
@@ -1,7 +1,8 @@
-import pytest
 import importlib
 import sys
 import os
+import logging

+import pytest

packages = [
# machine learning stuff
@@ -23,10 +24,39 @@ def test_start():
     assert os.environ['PANGEO_ENV'] == 'ml-notebook'

 def test_jax_tf_together():
-    """ sometimes this impport fails due to sharing private symbols
-    complicated longer story, but it is better to
-    ensure they can coexist
-    """
-    import tensorflow, jax
+    """
+    Sometimes this import fails due to sharing private symbols.
+    Complicated long story, but it is better to ensure they can coexist
+    """
+    import jax
+    import tensorflow
     assert int(tensorflow.__version__[0]) >= 2
     assert int(jax.__version__[0]) >= 0


+def test_jax_random_number_generator():
+    """
+    Ensure that initializing a random number generator on JaX works.
+    Regression test for checking that JaX and cuda-nvcc are installed and compatible on
+    GPU devices, see https://github.com/pangeo-data/pangeo-docker-images/issues/438.
+    """
+    import jax
+    import numpy as np
+    from jax import random
+
+    # Test running on CPU
+    with jax.default_device(jax.devices("cpu")[0]):
+        key = random.key(seed=42)
+        x = random.normal(key=key)
+        np.testing.assert_allclose(x, -0.18471177)
+
+    # Test running on GPU (need to run locally)
+    try:
+        gpu_device = jax.devices("gpu")[0]
+        with jax.default_device(gpu_device):
+            key = random.key(seed=24)
+            x = random.normal(key=key)
+            np.testing.assert_allclose(x, -1.168644)
+    except RuntimeError:  # Unknown backend: 'gpu' requested
+        logging.log(level=logging.INFO, msg="JAX was not tested on a GPU device")
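
The GPU branch above relies on `jax.devices("gpu")` raising `RuntimeError` when no GPU backend exists. The same fallback can be isolated in a small helper, sketched below under some assumptions: `first_device_or_none` and `fake_devices` are hypothetical names, and `fake_devices` only imitates the behaviour described in the test's comment so that the example runs without JAX or a GPU.

```python
from typing import Callable, Optional, Sequence


def first_device_or_none(
    get_devices: Callable[[str], Sequence[object]], backend: str
) -> Optional[object]:
    """Return the first device for `backend`, or None if it is unavailable."""
    try:
        devices = get_devices(backend)
    except RuntimeError:  # e.g. "Unknown backend: 'gpu' requested"
        return None
    return devices[0] if devices else None


def fake_devices(backend: str) -> Sequence[object]:
    """Stand-in for a device lookup on a machine with no GPU (illustration only)."""
    if backend == "cpu":
        return ["cpu:0"]
    raise RuntimeError(f"Unknown backend: '{backend}' requested")


cpu = first_device_or_none(fake_devices, "cpu")
gpu = first_device_or_none(fake_devices, "gpu")
print(cpu, gpu)
```

In the real test, one would presumably pass `jax.devices` as `get_devices`. Calling `pytest.skip(...)` when the result is `None` would make the missing GPU coverage visible in the test report instead of only in the log, though that is a design choice rather than something this PR does.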
Comment on lines +54 to +62
Member Author

The GPU branch of this test can't run on GitHub Actions, which has no GPU, so I ran it locally using:

mamba create -n ml-notebook --file https://raw.githubusercontent.com/pangeo-data/pangeo-docker-images/2024.06.02/ml-notebook/conda-linux-64.lock
mamba activate ml-notebook
pytest --verbose tests/test_ml-notebook.py

The tests passed on my computer with an NVIDIA GPU, so this should be OK, I think.

Member
