-
Notifications
You must be signed in to change notification settings - Fork 95
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pin jaxlib to use cuda120 build with cuda-nvcc and update README note #549
Conversation
The conda-lock solver was picking up the cpu build of jaxlib instead of the cuda build. Adding a explicit pin here to pick up the cuda120 build version.
/condalock |
Pulling in INFO:conda_lock.conda_lock:Using virtual packages from virtual-packages.yml
Locking dependencies for ['linux-64']...
INFO:conda_lock.conda_solver:linux-64 using specs ['cuda-version >=12.0', 'flax >=0.8.0', 'jax', 'jaxlib >=0.4.23 cuda120*', 'jupyterlab-nvdashboard', 'keras-cv', 'tensorflow >=2.15.0 cuda120*', 'adlfs', 'argopy', 'awscli', 'black', 'boto3', 'bottleneck', 'cartopy', 'cdsapi', 'cfgrib', 'cf_xarray', 'ciso', 'cmocean', 'dask-ml', 'datashader', 'descartes', 'earthaccess', 'eofs', 'erddapy', 'esmpy', 'fastjmd95', 'flox', 'fsspec', 'gcm_filters', 'gcsfs', 'gh', 'gh-scoped-creds', 'geocube', 'geopandas', 'geopy', 'geoviews-core', 'git-lfs', 'gsw', 'h5netcdf', 'h5py', 'holoviews', 'hvplot', 'intake', 'intake-esm', 'intake-geopandas', 'intake-stac', 'intake-xarray', 'ipdb', 'ipykernel', 'ipyleaflet', 'ipytree', 'ipywidgets', 'jupyterlab_code_formatter', 'jupyterlab-git', 'jupyterlab-lsp', 'jupyterlab-myst', 'jupyter-panel-proxy', 'jupyter-resource-usage', 'kerchunk', 'line_profiler', 'lxml', 'lz4', 'matplotlib-base', 'memory_profiler', 'metpy', 'nb_conda_kernels', 'nbstripout', 'nc-time-axis', 'netcdf4', 'numbagg', 'numcodecs', 'numpy', 'numpy_groupies', 'odc-stac', 'pandas', 'panel', 'parcels', 'param', 'pop-tools', 'pyarrow', 'pycamhd', 'pydap', 'pystac', 'pystac-client', 'python-blosc', 'python-gist', 'python-graphviz', 'python-lsp-ruff', 'python-xxhash', 'rasterio', 'rechunker', 'rio-cogeo', 'rioxarray', 'ruff', 's3fs', 'satpy', 'scikit-image', 'scikit-learn', 'scipy', 'seaborn', 'sparse', 'snakeviz', 'stackstac', 'tiledb-py', 'timezonefinder', 'watermark', 'xarray', 'xarrayutils', 'xarray-datatree', 'xarray_leaflet', 'xarray-spatial', 'xbatcher', 'xcape', 'xclim', 'xesmf', 'xgboost', 'xgcm', 'xhistogram', 'xmip', 'xmitgcm', 'xpublish', 'xrft', 'xskillscore', 'xxhash', 'zarr', 'python 3.11.*', 'pangeo-notebook 2024.05.20.*', 'pip']
Failed to parse json, Expecting value: line 1 column 1 (char 0)
Could not lock the environment for platform linux-64
Could not solve for environment specs
The following packages are incompatible
├─ jaxlib >=0.4.23 cuda120* is installable with the potential options
│ ├─ jaxlib 0.4.23 would require
│ │ └─ libabseil >=20240116.1,<20240117.0a0 , which can be installed;
│ ├─ jaxlib 0.4.23 would require
│ │ └─ libabseil >=20240116.2,<20240117.0a0 , which can be installed;
│ └─ jaxlib 0.4.23 would require
│ └─ python_abi 3.12.* *_cp312, which requires
│ └─ python 3.12.* *_cpython, which can be installed;
├─ python 3.11** is not installable because it conflicts with any installable versions previously reported;
└─ tensorflow >=2.15.0 cuda120* is not installable because it requires
└─ tensorflow-base [2.15.0 cuda120py310heceb7ac_2|2.15.0 cuda120py310heceb7ac_3|...|2.15.0 cuda120py39hf42b710_3], which requires
└─ libabseil >=20230802.1,<20230803.0a0 , which conflicts with any installable versions previously reported.
{
"success": false
} Need to wait for newer version of |
README.md
Outdated
@@ -54,7 +54,7 @@ The primary use of these Docker images is running on Pangeo Cloud deployments wi | |||
* Since 2020.10.16, [mamba](https://github.com/mamba-org/mamba) is installed into the base-image and conda-lock environment and is used by default to solve for a compatible environment (see #146) | |||
* For a simple list of packages for a given image, you can use a link like this: https://github.com/pangeo-data/pangeo-docker-images/blob/2020.10.08/pangeo-notebook/packages.txt | |||
* To compare changes between two images, you can use a link like this: https://github.com/pangeo-data/pangeo-docker-images/compare/2020.10.03..2020.10.08 | |||
* Our `ml-notebook` image now contains JAX and TensorFlow with XLA enabled. Due to licensing issues, conda-forge does not have `ptxas`, but `ptxas` is needed for XLA to work correctly. Should you like to use JAX and/or TensorFlow with XLA optimization, please install `ptxas` on your own, for example, by `conda install -c nvidia cuda-nvcc`. At the time of writing (October 2022), JAX throws a compilation error if the `ptxas` version is higher than the driver version. There does not exist an easy solution for K80 GPUs, but in the case of T4 GPUs, you should install `conda install -c nvidia cuda-nvcc==11.6.*` to be safe. Alternatively for any GPU, you could set an environment variable to resolve the error caused by JAX: `XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1"`. The aforementioned error will be removed (and likely turned into a warning) in a future version of JAX. See https://github.com/google/jax/issues/12776#issuecomment-1276649134 | |||
* Our `ml-notebook` image now contains JAX and TensorFlow with XLA enabled. Due to licensing issues, conda-forge does not have `ptxas`, but `ptxas` is needed for XLA to work correctly. Should you like to use JAX and/or TensorFlow with XLA optimization, please install `ptxas` on your own, for example, by `conda install -c nvidia cuda-nvcc`. At the time of writing (October 2022), JAX throws a compilation error if the `ptxas` version is higher than the driver version. There does not exist an easy solution for K80 GPUs, but in the case of T4 GPUs or newer, you should install `conda install -c nvidia cuda-nvcc==12.*` to be safe. Alternatively for any GPU, you could set an environment variable to resolve the error caused by JAX: `XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1"`. The aforementioned error will be removed (and likely turned into a warning) in a future version of JAX. See https://github.com/google/jax/issues/12776#issuecomment-1276649134 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be ok to remove this note, once we have cuda-nvcc
pulled in as a dependency of jaxlib
. Ideally, we'll add a unit test to https://github.com/pangeo-data/pangeo-docker-images/blob/master/tests/test_ml-notebook.py (maybe using the snippet from #387 (comment)) to ensure that jax
works.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the uninitiated, I'm not even sure what XLA is :) In light of that, could add a sentence here to give context for why this matters (or just link https://openxla.org/xla).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To be honest, I didn't know what the XLA acronym stands for either 😆 I've added a link to that XLA page in commit a5fb3e9, though users shouldn't need to dig into this too much since we're using XLA-enabled builds of Tensorflow by default now.
Regression test to ensure that JaX and cuda-nvcc are installed and compatible on GPU devices.
Since we're using a version of GitHub Actions without a GPU, the JaX random number generator test on GPU cannot be properly tested, so wrapping the check in a try-except block. Also tidied up some import statements and docstrings in the test file.
Since `cuda-nvcc` is installed with jaxlib's cuda120 builds, the workaround to conda install cuda-nvcc should not be needed anymore.
# Test running on GPU (need to run locally) | ||
try: | ||
gpu_device = jax.devices("gpu")[0] | ||
with jax.default_device(gpu_device): | ||
key = random.key(seed=24) | ||
x = random.normal(key=key) | ||
np.testing.assert_allclose(x, -1.168644) | ||
except RuntimeError: # Unknown backend: 'gpu' requested | ||
logging.log(level=logging.INFO, msg="JAX was not tested on a GPU device") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't work on GitHub Actions without a GPU, but I ran this locally using:
mamba create -n ml-notebook --file https://raw.githubusercontent.com/pangeo-data/pangeo-docker-images/2024.06.02/ml-notebook/conda-linux-64.lock
mamba activate ml-notebook
pytest --verbose tests/test_ml-notebook.py
The tests passed on my computer with an NVIDIA GPU, so should be ok I think.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If there were Azure credits for Pangeo we could add GPU runners https://docs.github.com/en/actions/using-github-hosted-runners/about-larger-runners/managing-larger-runners !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the notes @weiji14, sorry I overlooked approving this until now. merge when you're happy with it!
Link to https://openxla.org/xla so that people know what XLA is.
@@ -8,6 +8,7 @@ dependencies: | |||
- cuda-version>=12.0 | |||
- flax>=0.8.0 | |||
- jax | |||
- jaxlib>=0.4.23=cuda120* |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This pin on using a cuda120*
build of jaxlib might actually need to go away if we're planning on adding multi-arch builds of ml-notebook
(see #399 (review) for context), but will deal with this later since the same pin is present on Tensorflow.
Adding a explicit pin on
jaxlib=*=cuda120*
to pick up the cuda120 build version,as the conda-lock solver was picking up the cpu build of jaxlib instead of the cuda build in #514.See:pangeo-docker-images/ml-notebook/conda-lock.yml
Lines 4166 to 4182 in 6d4c2ab
Edit: the
cuda120
build is actually picked up automatically now as of the2024.06.02
tag, but setting thecuda120
pin still to be sure that this doesn't break in the future.Note that
jaxlib-0.4.23-cuda120py*
has an explicit runtime dependency oncuda-nvcc
since conda-forge/jaxlib-feedstock#241, so this should mean users won't have to installcuda-nvcc
explicitly anymore.The
cuda-nvcc
workaround in the main README.md file has also been removed, in place of a message recommending users to useml-notebook>=2024.06.02
.Fixes #438