
Commit f4f7f3c

ci: remove Jax version constraint (#3309)
Removes the JAX version constraint and adds a warning about the JAX version in mrVI training, since jax <=0.4.35 is not supported by newer CUDA versions, which are becoming more common and are now the default.
1 parent b473c04 commit f4f7f3c

File tree

4 files changed: +28 -9 lines changed


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@ to [Semantic Versioning]. Full commit history is available in the

 #### Removed

+- Removed Jax version constraint for mrVI training. {pr}`3309`.
+
 ### 1.3.0 (2025-02-28)

 #### Added

docs/installation.md

Lines changed: 14 additions & 5 deletions
@@ -49,11 +49,20 @@ scvi-tools depends on PyTorch and JAX for accelerated computing. If you don't pl
 an accelerated device, we recommend installing scvi-tools directly and letting these dependencies
 be installed automatically by your package manager of choice.

-If you plan on taking advantage of an accelerated device (e.g. Nvidia GPU or Apple Silicon), we
-recommend installing PyTorch and JAX _before_ installing scvi-tools. Please follow the respective
-installation instructions for [PyTorch](https://pytorch.org/get-started/locally/) and
-[JAX](https://jax.readthedocs.io/en/latest/installation.html) compatible with your system and
-device type.
+If you plan on taking advantage of an accelerated device (e.g. Nvidia GPU or Apple Silicon), scvi-tools supports it.
+In order to install scvi-tools with Nvidia GPU CUDA support use:
+```bash
+pip install -U scvi-tools[cuda]
+```
+And for Apple Silicon metal (MPS) support:
+```bash
+pip install -U scvi-tools[metal]
+```
+
+However, there might be cases where the GPU HW is not supporting the latest installation of PyTorch and Jax.
+In this case we recommend installing PyTorch and JAX _before_ installing scvi-tools.
+Please follow the respective installation instructions for [PyTorch](https://pytorch.org/get-started/locally/) and
+[JAX](https://jax.readthedocs.io/en/latest/installation.html) compatible with your system and device type.

 ## Optional dependencies
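
For context, a minimal sanity check (not part of this commit) that an installation with the `cuda` or `metal` extra above actually exposes the accelerator to both PyTorch and JAX; it only uses standard introspection calls from the two libraries.

```python
# Quick post-install check (illustrative sketch, not from the commit):
# confirm that PyTorch and JAX both see the accelerated device.
import jax
import torch

# PyTorch: CUDA on Nvidia GPUs, MPS on Apple Silicon.
print("torch CUDA available:", torch.cuda.is_available())
print("torch MPS available:", torch.backends.mps.is_available())

# JAX: lists the devices it will dispatch to, plus the installed version
# (relevant to the mrVI warning added in this commit).
print("jax devices:", jax.devices())
print("jax version:", jax.__version__)
```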

pyproject.toml

Lines changed: 3 additions & 3 deletions
@@ -35,8 +35,8 @@ dependencies = [
     "anndata>=0.11",
     "docrep>=0.3.2",
     "flax",
-    "jax<=0.4.35",
-    "jaxlib<=0.4.35",
+    "jax",
+    "jaxlib",
     "lightning>=2.0",
     "ml-collections>=0.1.1",
     "mudata>=0.1.2",
@@ -62,7 +62,7 @@ tests = ["pytest", "pytest-pretty", "coverage", "scvi-tools[optional]"]
 editing = ["jupyter", "pre-commit"]
 dev = ["scvi-tools[editing,tests]"]
 test = ["scvi-tools[tests]"]
-cuda = ["torchvision", "torchaudio", "jax[cuda]<=0.4.35"]
+cuda = ["torchvision", "torchaudio", "jax[cuda12]"]
 metal = ["torchvision", "torchaudio", "jax-metal"]

 docs = [

src/scvi/external/mrvi/_model.py

Lines changed: 9 additions & 1 deletion
@@ -10,7 +10,7 @@
 import xarray as xr
 from tqdm import tqdm

-from scvi import REGISTRY_KEYS
+from scvi import REGISTRY_KEYS, settings
 from scvi.data import AnnDataManager, fields
 from scvi.external.mrvi._module import MRVAE
 from scvi.external.mrvi._types import MRVIReduction
@@ -248,6 +248,14 @@ def train(
         train_kwargs["plan_kwargs"] = dict(
             deepcopy(DEFAULT_TRAIN_KWARGS["plan_kwargs"]), **plan_kwargs
         )
+        from packaging import version
+
+        if version.parse(jax.__version__) > version.parse("0.4.35"):
+            warnings.warn(
+                "Running mrVI with Jax version larger 0.4.35 can cause performance issues",
+                UserWarning,
+                stacklevel=settings.warnings_stacklevel,
+            )
         super().train(**train_kwargs)

     def get_latent_representation(
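
For illustration, a hypothetical usage sketch (not part of the commit) of how the new warning surfaces during mrVI training with jax > 0.4.35, and how a user could silence it explicitly; it assumes an existing `AnnData` object `adata` with a sample column named `"sample"`.

```python
# Illustrative sketch only; `adata` and its "sample" column are assumed inputs.
import warnings

from scvi.external import MRVI

MRVI.setup_anndata(adata, sample_key="sample")
model = MRVI(adata)

# With jax > 0.4.35 installed, train() now emits a UserWarning about possible
# performance issues; it can be filtered if the trade-off is acceptable.
with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        message="Running mrVI with Jax version larger 0.4.35",
        category=UserWarning,
    )
    model.train(max_epochs=1)
```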

0 commit comments
