Pin jaxlib to use cuda120 build with cuda-nvcc and update README note #549

Merged on Jul 1, 2024 (8 commits)
4 changes: 2 additions & 2 deletions README.md
@@ -54,7 +54,7 @@ The primary use of these Docker images is running on Pangeo Cloud deployments wi
* Since 2020.10.16, [mamba](https://github.com/mamba-org/mamba) is installed into the base-image and conda-lock environment and is used by default to solve for a compatible environment (see #146)
* For a simple list of packages for a given image, you can use a link like this: https://github.com/pangeo-data/pangeo-docker-images/blob/2020.10.08/pangeo-notebook/packages.txt
* To compare changes between two images, you can use a link like this: https://github.com/pangeo-data/pangeo-docker-images/compare/2020.10.03..2020.10.08
- * Our `ml-notebook` image now contains JAX and TensorFlow with XLA enabled. Due to licensing issues, conda-forge does not have `ptxas`, but `ptxas` is needed for XLA to work correctly. Should you like to use JAX and/or TensorFlow with XLA optimization, please install `ptxas` on your own, for example, by `conda install -c nvidia cuda-nvcc`. At the time of writing (October 2022), JAX throws a compilation error if the `ptxas` version is higher than the driver version. There does not exist an easy solution for K80 GPUs, but in the case of T4 GPUs, you should install `conda install -c nvidia cuda-nvcc==11.6.*` to be safe. Alternatively for any GPU, you could set an environment variable to resolve the error caused by JAX: `XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1"`. The aforementioned error will be removed (and likely turned into a warning) in a future version of JAX. See https://github.com/google/jax/issues/12776#issuecomment-1276649134
+ * Our `ml-notebook` image now contains JAX and TensorFlow with XLA enabled. Due to licensing issues, conda-forge does not have `ptxas`, but `ptxas` is needed for XLA to work correctly. Should you like to use JAX and/or TensorFlow with XLA optimization, please install `ptxas` on your own, for example, by `conda install -c nvidia cuda-nvcc`. At the time of writing (October 2022), JAX throws a compilation error if the `ptxas` version is higher than the driver version. There does not exist an easy solution for K80 GPUs, but in the case of T4 GPUs or newer, you should install `conda install -c nvidia cuda-nvcc==12.*` to be safe. Alternatively for any GPU, you could set an environment variable to resolve the error caused by JAX: `XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1"`. The aforementioned error will be removed (and likely turned into a warning) in a future version of JAX. See https://github.com/google/jax/issues/12776#issuecomment-1276649134
Member Author

It might be ok to remove this note, once we have cuda-nvcc pulled in as a dependency of jaxlib. Ideally, we'll add a unit test to https://github.com/pangeo-data/pangeo-docker-images/blob/master/tests/test_ml-notebook.py (maybe using the snippet from #387 (comment)) to ensure that jax works.
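
A minimal sketch of what such a smoke test could look like, assuming a plain pytest-style test function. This is not the snippet from #387, and the names `jax_smoke_check` and `test_jax_works` are hypothetical. It JIT-compiles a small matrix product so that XLA compilation is actually exercised, and returns `None` when JAX cannot be found at all:

```python
import importlib.util


def jax_smoke_check():
    """Return None if jax is not installed, else whether a jitted matmul is correct."""
    if importlib.util.find_spec("jax") is None:
        return None  # jax is missing, so there is nothing to check here

    import jax
    import jax.numpy as jnp

    # jax.jit triggers an XLA compilation of the function on its first call.
    matmul = jax.jit(lambda a: a @ a)
    result = matmul(jnp.ones((4, 4)))

    # Every entry of ones((4, 4)) @ ones((4, 4)) equals 4.0.
    return bool(jnp.allclose(result, 4.0))


def test_jax_works():
    # Hypothetical test name. Inside the ml-notebook image jax should always
    # be installed, so a None result would itself point to a broken image.
    assert jax_smoke_check() is True
```

On a machine without a GPU this only exercises the CPU backend, so it would catch a broken install but not necessarily a `ptxas` mismatch.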

Member

For the uninitiated, I'm not even sure what XLA is :) In light of that, could you add a sentence here to give context for why this matters (or just link https://openxla.org/xla)?

Member Author

To be honest, I didn't know what the XLA acronym stands for either 😆 I've added a link to that XLA page in commit a5fb3e9, though users shouldn't need to dig into this too much since we're using XLA-enabled builds of TensorFlow by default now.
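
For readers who do hit the compilation error described in the README note, here is a rough sketch of applying the documented workaround from Python. The flag string comes from the README; the helper names `find_ptxas` and `add_xla_flag` are hypothetical, and the sketch assumes the variable has to be set before JAX is first imported in the process:

```python
import os
import shutil

# Flag quoted in the README note.
WORKAROUND_FLAG = "--xla_gpu_force_compilation_parallelism=1"


def find_ptxas():
    """Return the path of the ptxas executable on PATH, or None if absent."""
    return shutil.which("ptxas")


def add_xla_flag(environ, flag=WORKAROUND_FLAG):
    """Append `flag` to XLA_FLAGS in `environ` unless it is already there.

    Existing flags are preserved. Returns the resulting XLA_FLAGS value.
    """
    current = environ.get("XLA_FLAGS", "").split()
    if flag not in current:
        current.append(flag)
    environ["XLA_FLAGS"] = " ".join(current)
    return environ["XLA_FLAGS"]


if __name__ == "__main__":
    print("ptxas:", find_ptxas() or "not found on PATH")
    # Run this before `import jax` so the flag is visible at initialization.
    print("XLA_FLAGS:", add_xla_flag(os.environ))
```

Passing the environment mapping explicitly keeps the helper easy to test with a plain dict instead of the real `os.environ`.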

* There used to be a `pangeo/forge` image, built for use with [pangeo-forge](https://pangeo-forge.org/). It is
no longer actively maintained or used, but you can still use the [historical tags](https://quay.io/repository/pangeo/forge?tab=tags)
- if you wish.
+ if you wish.
1 change: 1 addition & 0 deletions ml-notebook/environment.yml
@@ -8,6 +8,7 @@ dependencies:
- cuda-version>=12.0
- flax>=0.8.0
- jax
+ - jaxlib>=0.4.23=cuda120*
Member Author

This pin on using a cuda120* build of jaxlib might actually need to go away if we're planning on adding multi-arch builds of ml-notebook (see #399 (review) for context), but we'll deal with this later since the same pin is present on TensorFlow.

- jupyterlab-nvdashboard
- keras-cv
- tensorflow>=2.15.0=cuda120*
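
The last field of specs such as `jaxlib>=0.4.23=cuda120*` constrains the package build string with a glob. As an illustration only, the sketch below approximates that selection with Python's `fnmatch`; it is not conda's own matching code, and the build strings and the `build_matches` helper are hypothetical:

```python
from fnmatch import fnmatchcase


def build_matches(build_string, pattern="cuda120*"):
    """Approximate a conda build-string glob with a case-sensitive fnmatch."""
    return fnmatchcase(build_string, pattern)


# Hypothetical build strings, made up for illustration.
candidates = [
    "cuda120py311h0000000_200",
    "cuda118py311h0000000_200",
    "cpu_py311h0000000_0",
]

# Only builds whose string starts with "cuda120" satisfy the pin.
selected = [build for build in candidates if build_matches(build)]
print(selected)
```

Because the glob is anchored at the start of the build string, CPU-only builds and builds for other CUDA versions are excluded by the pin.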