Commit c4cab3b

ori-kron-wis, canergen, and pre-commit-ci[bot] authored
Custom dataloader registry support (#2932)
Co-authored-by: Can Ergen <canergen.ac@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ced87df commit c4cab3b

File tree

23 files changed: +2785 / -232 lines
New GitHub Actions workflow file

Lines changed: 72 additions & 0 deletions

```yaml
name: test (custom dataloaders)

on:
  push:
    branches: [main, "[0-9]+.[0-9]+.x"]
  pull_request:
    branches: [main, "[0-9]+.[0-9]+.x"]
    types: [labeled, synchronize, opened]
  schedule:
    - cron: "0 10 * * *"  # runs at 10:00 UTC (03:00 PST) every day
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  test:
    # if PR has label "custom_dataloader" or "all tests", or if scheduled or manually triggered
    if: >-
      (
        contains(github.event.pull_request.labels.*.name, 'custom_dataloader') ||
        contains(github.event.pull_request.labels.*.name, 'all tests') ||
        contains(github.event_name, 'schedule') ||
        contains(github.event_name, 'workflow_dispatch')
      )

    runs-on: ${{ matrix.os }}

    defaults:
      run:
        shell: bash -e {0}  # -e to fail on error

    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        python: ["3.12"]

    name: integration

    env:
      OS: ${{ matrix.os }}
      PYTHON: ${{ matrix.python }}

    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}
          cache: "pip"
          cache-dependency-path: "**/pyproject.toml"

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip wheel uv
          python -m uv pip install --system "scvi-tools[tests] @ ."

      - name: Run specific custom dataloader pytest
        env:
          MPLBACKEND: agg
          PLATFORM: ${{ matrix.os }}
          DISPLAY: :42
          COLUMNS: 120
        run: |
          coverage run -m pytest -v --color=yes --custom-dataloader-tests
          coverage report

      - uses: codecov/codecov-action@v4
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
```
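The run step passes a custom `--custom-dataloader-tests` flag to pytest. How that flag is wired up is not shown on this page; below is a minimal sketch of the usual `conftest.py` pattern, with everything except the flag name assumed:

```python
# conftest.py: hypothetical sketch. Only the --custom-dataloader-tests flag
# name comes from the workflow above; the rest is a standard pytest pattern.
import pytest


def pytest_addoption(parser):
    # Register the opt-in flag the CI workflow passes to pytest.
    parser.addoption(
        "--custom-dataloader-tests",
        action="store_true",
        default=False,
        help="Run tests that exercise the custom dataloaders (Lamin, TileDB).",
    )


def pytest_collection_modifyitems(config, items):
    # Without the flag, skip any test carrying a dataloader marker
    # (the marker name here is assumed, not taken from this commit).
    if config.getoption("--custom-dataloader-tests"):
        return
    skip = pytest.mark.skip(reason="need --custom-dataloader-tests to run")
    for item in items:
        if "dataloader" in item.keywords:
            item.add_marker(skip)
```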

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -16,6 +16,9 @@ to [Semantic Versioning]. Full commit history is available in the
 - Add get normalized function model property for any generative model {pr}`3238` and changed
   get_accessibility_estimates to get_normalized_accessibility, where needed.
 - Add {class}`scvi.external.TOTALANVI`. {pr}`3259`.
+- Add Custom Dataloaders registry support, {pr}`2932`.
+- Add support for using Census and LaminAI custom dataloaders for {class}`scvi.model.SCVI`
+  and {class}`scvi.model.SCANVI`, {pr}`2932`.
 - Add Early stopping KL warmup steps. {pr}`3262`.
 - Add Minification option to {class}`~scvi.model.LinearSCVI` {pr}`3294`.
```

docs/tutorials/index_use_cases.md

Lines changed: 2 additions & 0 deletions
````diff
@@ -6,4 +6,6 @@
 notebooks/use_cases/autotune_scvi
 notebooks/use_cases/minification
 notebooks/use_cases/interpretability
+notebooks/use_cases/custom_dl/tiledb
+notebooks/use_cases/custom_dl/lamin
 ```
````

docs/tutorials/notebooks

docs/user_guide/use_case/custom_dataloaders.md

Lines changed: 52 additions & 17 deletions
````diff
@@ -21,25 +21,34 @@ Pros:
 - Optimized for ML Workflows: If your dataset is structured as tables (rows and columns), LaminDB’s format aligns well with SCVI's expectations, potentially reducing the need for complex transformations.

 ```python
-os.system("lamin init --storage ./test-registries")
 import lamindb as ln
+from scvi.dataloaders import MappedCollectionDataModule
+import scvi
+import os
+
+os.system("lamin init --storage ./test-registries")

 ln.setup.init(name="lamindb_instance_name", storage=save_path)

 # a test for mapped collection
-collection = ln.Collection.get(name="covid_normal_lung")
+collection = ln.Collection.using("laminlabs/cellxgene").get(name="covid_normal_lung")
 artifacts = collection.artifacts.all()
 artifacts.df()

 datamodule = MappedCollectionDataModule(
-    collection, batch_key="assay", batch_size=1024, join="inner"
+    collection,
+    batch_key="assay",
+    batch_size=1024,
+    join="inner",
+    shuffle=True,
 )
 model = scvi.model.SCVI(adata=None, registry=datamodule.registry)
+model.train(max_epochs=1, batch_size=1024, datamodule=datamodule.inference_dataloader())
 ...
 ```
 LaminDB may not be as efficient or flexible as TileDB for handling complex multi-dimensional data

-2. [CZI](https://chanzuckerberg.com/) based [tiledb](https://tiledb.com/) custom dataloader is based on CensusSCVIDataModule and can run large multi-dimensional datasets that are stored in TileDB’s format.
+2. [CZI](https://chanzuckerberg.com/) based [tiledb](https://tiledb.com/) custom dataloader is based on TileDBDataModule and can run large multi-dimensional datasets that are stored in TileDB’s format.

 TileDB is a general-purpose, multi-dimensional array storage engine designed for high-performance, scalable data access. It supports various data types, including dense and sparse arrays, and is optimized for handling large datasets efficiently. TileDB’s strength lies in its ability to store and query data across multiple dimensions and scale efficiently with large volumes of data.

@@ -52,9 +61,10 @@ Scalability: Handles large datasets that exceed your system's memory capacity, m
 ```python
 import cellxgene_census
 import tiledbsoma as soma
-from cellxgene_census.experimental.ml import experiment_dataloader
-from cellxgene_census.experimental.ml.datamodule import CensusSCVIDataModule
+import tiledbsoma_ml
+from scvi.dataloaders import TileDBDataModule
 import numpy as np
+import scvi

 # this test checks the local custom dataloader made by CZI and runs several tests with it
 census = cellxgene_census.open_soma(census_version="stable")
@@ -66,25 +76,48 @@ obs_value_filter = (

 hv_idx = np.arange(100)  # just to make it smaller and faster for debug

-# this is the CZI part, to be taken once all is ready
-batch_keys = ["dataset_id", "assay", "suspension_type", "donor_id"]
-datamodule = CensusSCVIDataModule(
-    census["census_data"][experiment_name],
+# For HVG, we can use the highly_variable_genes function provided in cellxgene_census,
+# which can compute HVGs in constant memory:
+hvg_query = census["census_data"][experiment_name].axis_query(
     measurement_name="RNA",
-    X_name="raw",
     obs_query=soma.AxisQuery(value_filter=obs_value_filter),
     var_query=soma.AxisQuery(coords=(list(hv_idx),)),
+)
+
+# this is the CZI part, to be taken once all is ready
+batch_keys = ["dataset_id", "assay", "suspension_type", "donor_id"]
+label_keys = ["tissue_general"]
+datamodule = TileDBDataModule(
+    hvg_query,
+    layer_name="raw",
     batch_size=1024,
     shuffle=True,
-    batch_keys=batch_keys,
+    seed=42,
+    batch_column_names=batch_keys,
+    label_keys=label_keys,
+    train_size=0.9,
+    unlabeled_category="label_0",
     dataloader_kwargs={"num_workers": 0, "persistent_workers": False},
 )

+# We can now create the scVI model object and train it:
+model = scvi.model.SCVI(
+    adata=None,
+    registry=datamodule.registry,
+    gene_likelihood="nb",
+    encode_covariates=False,
+)
+
+# creating the dataloader for the train set
+datamodule.setup()

-# basically we should mimic everything below for any model census in scvi
-adata_orig = synthetic_iid()
-scvi.model.SCVI.setup_anndata(adata_orig, batch_key="batch")
-model = scvi.model.SCVI(adata_orig)
+model.train(
+    datamodule=datamodule,
+    max_epochs=1,
+    batch_size=1024,
+    train_size=0.9,
+    early_stopping=False,
+)
 ...
 ```
 Key Differences between them in terms of Custom Dataloaders:
@@ -110,6 +143,8 @@ When to Use Each:
 Writing custom dataloaders requires a good understanding of PyTorch’s DataLoader class and how to integrate it with SCVI, which may be difficult for beginners.
 It will also require maintenance: if the data format or preprocessing needs change, you’ll have to modify and maintain the custom dataloader code. But it can be a great addition to the model pipeline, in terms of runtime and how much data we can digest.

+See the relevant tutorials on this subject for further examples.
+
 :::{note}
-As for SCVI-Tools v1.3.0 Custom Dataloaders are experimental.
+As of SCVI-Tools v1.3.0, Custom Dataloaders are experimental and only supported for adata and the SCVI and SCANVI models.
 :::
````
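Trained this way, the model has no AnnData attached, so downstream queries also need a dataloader. A minimal sketch of the follow-up step, assuming the model's query methods accept a `dataloader=` argument (implied by the new `_validate_adata_dataloader_input` helper further down this diff, not demonstrated in the docs themselves):

```python
# Hypothetical continuation of the LaminDB example above. `datamodule` and
# `model` are the objects built there; `inference_dataloader()` appears in
# the diff, while passing `dataloader=` to query methods is an assumption.
inference_dl = datamodule.inference_dataloader()

# With a datamodule-backed registry, `adata` stays None and the dataloader
# carries the data (the validation helper rejects providing both at once).
latent = model.get_latent_representation(dataloader=inference_dl)
print(latent.shape)
```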

pyproject.toml

Lines changed: 7 additions & 6 deletions
```diff
@@ -82,11 +82,9 @@ docs = [
 docsbuild = ["scvi-tools[docs,optional]"]

 # scvi.autotune
-autotune = ["hyperopt>=0.2", "ray[tune]","scib-metrics"]
+autotune = ["hyperopt>=0.2", "ray[tune]", "scib-metrics"]
 # scvi.hub.HubModel.pull_from_s3
 aws = ["boto3"]
-# scvi.data.cellxgene
-census = ["cellxgene-census", "numpy<2.0"]
 # scvi.hub dependencies
 hub = ["huggingface_hub", "igraph", "leidenalg", "dvc[s3]"]
 # scvi.data.add_dna_sequence
@@ -96,13 +94,15 @@ scanpy = ["scanpy>=1.10", "scikit-misc"]
 # for convenient files sharing
 file_sharing = ["pooch", "cellxgene-census"]
 # for parallelization engine
-parallel = ["dask[array]>=2023.5.1,<2024.8.0"]
+parallel = ["dask[array]>=2023.5.1,<2024.8.0", "zarr<3.0.0"]
 # for supervised models interpretability
-interpretability = ["captum","shap"]
+interpretability = ["captum", "shap"]
+# for custom dataloaders
+dataloaders = ["lamindb>=1.3.0", "biomart", "bionty", "cellxgene_lamin", "cellxgene-census", "numpy<2.0", "tiledbsoma", "tiledb", "tiledbsoma_ml", "torchdata==0.9.0"]


 optional = [
-    "scvi-tools[autotune,aws,hub,file_sharing,regseq,scanpy,parallel,interpretability]"
+    "scvi-tools[autotune,aws,hub,file_sharing,regseq,scanpy,parallel,interpretability,dataloaders]"
 ]
 tutorials = [
     "cell2location",
@@ -137,6 +137,7 @@ markers = [
     "private: mark tests that use private keys, like HF",
     "multigpu: mark tests that are used to check multi GPU performance",
     "autotune: mark tests that are used to check ray autotune capabilities",
+    "custom dataloaders: mark tests that are used to check different custom data loaders",
 ]

 [tool.ruff]
```
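With the new extra in place, the custom dataloader stack can be pulled in via `pip install "scvi-tools[dataloaders]"` (it is also folded into the `optional` umbrella extra). A small smoke check that the pieces resolve, using only the standard library; the import names are inferred from the pinned package names above, so treat any mismatch as an assumption:

```python
# Smoke-check the `dataloaders` extra. Import names inferred from the
# pyproject entries (pip name cellxgene-census imports as cellxgene_census).
import importlib

for pkg in ("lamindb", "cellxgene_census", "tiledbsoma", "tiledbsoma_ml", "torchdata"):
    try:
        importlib.import_module(pkg)
        print(f"{pkg}: OK")
    except ImportError as err:
        print(f"{pkg}: missing ({err})")
```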

src/scvi/data/_utils.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -21,6 +21,8 @@
 from . import _constants

 if TYPE_CHECKING:
+    from collections.abc import Iterator
+
     import numpy.typing as npt
     from pandas.api.types import CategoricalDtype
     from torch import Tensor
@@ -361,3 +363,21 @@ def _check_fragment_counts(
     )  # True if there are more 2s than 1s
     ret = not (non_fragments or binary)
     return ret
+
+
+def _validate_adata_dataloader_input(
+    model,
+    adata: AnnOrMuData | None = None,
+    dataloader: Iterator[dict[str, Tensor | None]] | None = None,
+):
+    """Validate that the model uses either adata or a custom dataloader."""
+    if adata is not None and dataloader is not None:
+        raise ValueError("Only one of `adata` or `dataloader` can be provided.")
+    elif (
+        hasattr(model, "registry")
+        and "setup_method_name" in model.registry.keys()
+        and model.registry["setup_method_name"] == "setup_datamodule"
+        and dataloader is None
+    ):
+        raise ValueError("`dataloader` must be provided.")
+    return
```
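The helper enforces an exclusive choice between the two data sources. A quick illustrative sketch of its branches; the model object is a stand-in whose registry only mirrors the keys checked above:

```python
# Stand-in model: just enough registry for _validate_adata_dataloader_input.
from types import SimpleNamespace

from scvi.data._utils import _validate_adata_dataloader_input

model = SimpleNamespace(registry={"setup_method_name": "setup_datamodule"})
dummy_loader = iter([{"X": None}])  # any iterator of tensor dicts will do

# Exactly one source provided -> passes silently.
_validate_adata_dataloader_input(model, dataloader=dummy_loader)

try:
    # Datamodule-backed registry but no dataloader -> error.
    _validate_adata_dataloader_input(model)
except ValueError as err:
    print(err)  # `dataloader` must be provided.
```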

src/scvi/dataloaders/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -3,6 +3,7 @@

 from ._ann_dataloader import AnnDataLoader
 from ._concat_dataloader import ConcatDataLoader
+from ._custom_dataloders import MappedCollectionDataModule, TileDBDataModule
 from ._data_splitting import (
     DataSplitter,
     DeviceBackedDataSplitter,
@@ -20,4 +21,6 @@
     "DataSplitter",
     "SemiSupervisedDataSplitter",
     "BatchDistributedSampler",
+    "MappedCollectionDataModule",
+    "TileDBDataModule",
 ]
```
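Both datamodules are re-exported at the package level, so user code never has to touch the private module path (note the file is spelled `_custom_dataloders` in this commit). A quick check:

```python
# The public names resolve to the private module added by this commit.
from scvi.dataloaders import MappedCollectionDataModule, TileDBDataModule

print(MappedCollectionDataModule.__module__)  # scvi.dataloaders._custom_dataloders
print(TileDBDataModule.__module__)
```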
