72 changes: 72 additions & 0 deletions .github/workflows/test_linux_custom_dataloader.yml
@@ -0,0 +1,72 @@
name: test (custom dataloaders)

on:
  push:
    branches: [main, "[0-9]+.[0-9]+.x"]
  pull_request:
    branches: [main, "[0-9]+.[0-9]+.x"]
    types: [labeled, synchronize, opened]
  schedule:
    - cron: "0 10 * * *" # runs at 10:00 UTC (02:00 PST / 03:00 PDT) every day
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  test:
    # if PR has label "custom_dataloader" or "all tests" or if scheduled or manually triggered
    if: >-
      (
        contains(github.event.pull_request.labels.*.name, 'custom_dataloader') ||
        contains(github.event.pull_request.labels.*.name, 'all tests') ||
        contains(github.event_name, 'schedule') ||
        contains(github.event_name, 'workflow_dispatch')
      )

    runs-on: ${{ matrix.os }}

    defaults:
      run:
        shell: bash -e {0} # -e to fail on error

    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        python: ["3.12"]

    name: integration

    env:
      OS: ${{ matrix.os }}
      PYTHON: ${{ matrix.python }}

    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}
          cache: "pip"
          cache-dependency-path: "**/pyproject.toml"

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip wheel uv
          python -m uv pip install --system "scvi-tools[tests] @ ."

      - name: Run specific custom dataloader pytest
        env:
          MPLBACKEND: agg
          PLATFORM: ${{ matrix.os }}
          DISPLAY: :42
          COLUMNS: 120
        run: |
          coverage run -m pytest -v --color=yes --custom-dataloader-tests
          coverage report

      - uses: codecov/codecov-action@v4
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
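The `--custom-dataloader-tests` flag passed to pytest above implies an option registered in the test suite's conftest.py. A minimal sketch of such wiring, assumed rather than shown in this diff (the option name matches the workflow; the marker name and skip logic are illustrative), could look like:

```python
# conftest.py — assumed wiring for the workflow's --custom-dataloader-tests
# flag; only the option name is taken from the workflow above
import pytest


def pytest_addoption(parser):
    parser.addoption(
        "--custom-dataloader-tests",
        action="store_true",
        default=False,
        help="run tests that exercise custom dataloaders",
    )


def pytest_collection_modifyitems(config, items):
    if config.getoption("--custom-dataloader-tests"):
        return  # run everything, including custom dataloader tests
    skip = pytest.mark.skip(reason="needs --custom-dataloader-tests")
    for item in items:
        if "dataloader" in item.keywords:  # illustrative marker name
            item.add_marker(skip)
```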
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,9 @@ to [Semantic Versioning]. Full commit history is available in the
- Add a get-normalized function model property for any generative model {pr}`3238`, and change
  get_accessibility_estimates to get_normalized_accessibility where needed.
- Add {class}`scvi.external.TOTALANVI`. {pr}`3259`.
- Add custom dataloaders registry support {pr}`2932`.
- Add support for using Census and LaminDB custom dataloaders with {class}`scvi.model.SCVI`
  and {class}`scvi.model.SCANVI` {pr}`2932`.
- Add Early stopping KL warmup steps. {pr}`3262`.
- Add Minification option to {class}`~scvi.model.LinearSCVI` {pr}`3294`.

2 changes: 2 additions & 0 deletions docs/tutorials/index_use_cases.md
@@ -6,4 +6,6 @@
notebooks/use_cases/autotune_scvi
notebooks/use_cases/minification
notebooks/use_cases/interpretability
notebooks/use_cases/custom_dl/tiledb
notebooks/use_cases/custom_dl/lamin
```
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
69 changes: 52 additions & 17 deletions docs/user_guide/use_case/custom_dataloaders.md
@@ -21,25 +21,34 @@ Pros:
- Optimized for ML workflows: If your dataset is structured as tables (rows and columns), LaminDB's format aligns well with SCVI's expectations, potentially reducing the need for complex transformations.

```python
os.system("lamin init --storage ./test-registries")
import lamindb as ln
from scvi.dataloaders import MappedCollectionDataModule
import scvi
import os

os.system("lamin init --storage ./test-registries")

ln.setup.init(name="lamindb_instance_name", storage=save_path)

# a test for mapped collection
collection = ln.Collection.get(name="covid_normal_lung")
collection = ln.Collection.using("laminlabs/cellxgene").get(name="covid_normal_lung")
artifacts = collection.artifacts.all()
artifacts.df()

datamodule = MappedCollectionDataModule(
collection, batch_key="assay", batch_size=1024, join="inner"
collection,
batch_key="assay",
batch_size=1024,
join="inner",
shuffle=True,
)
model = scvi.model.SCVI(adata=None, registry=datamodule.registry)
model.train(max_epochs=1, batch_size=1024, datamodule=datamodule.inference_dataloader())
...
```
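A natural follow-up, sketched here under the assumption that inference methods accept a `dataloader` argument (consistent with the `adata`/`dataloader` validation helper added in `src/scvi/data/_utils.py` below), is to pull latent representations through the same datamodule:

```python
# assumed usage, not shown in the snippet above: obtaining latent
# representations through the datamodule's inference dataloader
inference_dataloader = datamodule.inference_dataloader()
latent = model.get_latent_representation(dataloader=inference_dataloader)
```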
That said, LaminDB may not be as efficient or flexible as TileDB for handling complex multi-dimensional data.

2. The [CZI](https://chanzuckerberg.com/)-backed [TileDB](https://tiledb.com/) custom dataloader is based on TileDBDataModule and can handle large multi-dimensional datasets stored in TileDB's format.

TileDB is a general-purpose, multi-dimensional array storage engine designed for high-performance, scalable data access. It supports various data types, including dense and sparse arrays, and is optimized for handling large datasets efficiently. TileDB’s strength lies in its ability to store and query data across multiple dimensions and scale efficiently with large volumes of data.

@@ -52,9 +61,10 @@ Scalability: Handles large datasets that exceed your system's memory capacity, m
```python
import cellxgene_census
import tiledbsoma as soma
import tiledbsoma_ml
from scvi.dataloaders import TileDBDataModule
import numpy as np
import scvi

# this example uses the local custom dataloader made by CZI and runs several checks with it
census = cellxgene_census.open_soma(census_version="stable")
@@ -66,25 +76,48 @@ obs_value_filter = (

hv_idx = np.arange(100)  # just to make it smaller and faster for debugging

# For HVG, we can use the highly_variable_genes function provided in cellxgene_census,
# which can compute HVGs in constant memory:
hvg_query = census["census_data"][experiment_name].axis_query(
    measurement_name="RNA",
    obs_query=soma.AxisQuery(value_filter=obs_value_filter),
    var_query=soma.AxisQuery(coords=(list(hv_idx),)),
)

# this is the CZI part, to be taken over once everything is ready
batch_keys = ["dataset_id", "assay", "suspension_type", "donor_id"]
label_keys = ["tissue_general"]
datamodule = TileDBDataModule(
    hvg_query,
    layer_name="raw",
    batch_size=1024,
    shuffle=True,
    seed=42,
    batch_column_names=batch_keys,
    label_keys=label_keys,
    train_size=0.9,
    unlabeled_category="label_0",
    dataloader_kwargs={"num_workers": 0, "persistent_workers": False},
)

# We can now create the scVI model object and train it:
model = scvi.model.SCVI(
    adata=None,
    registry=datamodule.registry,
    gene_likelihood="nb",
    encode_covariates=False,
)

# create the dataloaders for the train set
datamodule.setup()

model.train(
    datamodule=datamodule,
    max_epochs=1,
    early_stopping=False,
)
...
```
Key differences between them in terms of custom dataloaders:
@@ -110,6 +143,8 @@ When to Use Each:
Writing custom dataloaders requires a good understanding of PyTorch’s DataLoader class and how to integrate it with SCVI, which may be difficult for beginners.
It will also require maintenance: if the data format or preprocessing needs change, you'll have to modify and maintain the custom dataloader code. But it can be a great addition to the model pipeline, in terms of runtime and the amount of data we can ingest.
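As a rough illustration of the convention involved, here is a minimal sketch of a PyTorch dataset and dataloader that yield the kind of batch dictionaries SCVI-style models consume. The tensor keys (`"X"`, `"batch"`), shapes, and wiring are assumptions for illustration, not scvi-tools' documented API:

```python
# minimal sketch, assuming dict-of-tensors batches keyed like scvi's registry
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class DictDataset(Dataset):
    """Wraps a dense count matrix and batch codes into SCVI-style batch dicts."""

    def __init__(self, counts: np.ndarray, batch_codes: np.ndarray):
        self.counts = torch.as_tensor(counts, dtype=torch.float32)
        # batch indices reshaped to (n_obs, 1), a common scvi convention
        self.batch = torch.as_tensor(batch_codes, dtype=torch.int64).reshape(-1, 1)

    def __len__(self):
        return self.counts.shape[0]

    def __getitem__(self, idx):
        # keys assumed to mirror scvi-tools' registry conventions
        return {"X": self.counts[idx], "batch": self.batch[idx]}


# usage sketch: 200 cells x 50 genes, 2 batches
counts = np.random.poisson(1.0, size=(200, 50))
batches = np.random.randint(0, 2, size=200)
loader = DataLoader(DictDataset(counts, batches), batch_size=64, shuffle=True)
first = next(iter(loader))
print(first["X"].shape, first["batch"].shape)  # (64, 50) and (64, 1)
```

A real implementation would need to match the exact fields recorded in the model's registry.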

See the relevant tutorials on this subject for further examples.

:::{note}
As of scvi-tools v1.3.0, custom dataloaders are experimental and are only supported for AnnData and the SCVI and SCANVI models.
:::
13 changes: 7 additions & 6 deletions pyproject.toml
@@ -82,11 +82,9 @@ docs = [
docsbuild = ["scvi-tools[docs,optional]"]

# scvi.autotune
autotune = ["hyperopt>=0.2", "ray[tune]", "scib-metrics"]
# scvi.hub.HubModel.pull_from_s3
aws = ["boto3"]
# scvi.hub dependencies
hub = ["huggingface_hub", "igraph", "leidenalg", "dvc[s3]"]
# scvi.data.add_dna_sequence
Expand All @@ -96,13 +94,15 @@ scanpy = ["scanpy>=1.10", "scikit-misc"]
# for convenient file sharing
file_sharing = ["pooch", "cellxgene-census"]
# for parallelization engine
parallel = ["dask[array]>=2023.5.1,<2024.8.0", "zarr<3.0.0"]
# for supervised models interpretability
interpretability = ["captum", "shap"]
# for custom dataloaders
dataloaders = ["lamindb>=1.3.0", "biomart", "bionty", "cellxgene_lamin", "cellxgene-census", "numpy<2.0", "tiledbsoma", "tiledb", "tiledbsoma_ml", "torchdata==0.9.0"]

optional = [
    "scvi-tools[autotune,aws,hub,file_sharing,regseq,scanpy,parallel,interpretability,dataloaders]"
]
tutorials = [
    "cell2location",
@@ -137,6 +137,7 @@ markers = [
    "private: mark tests that use private keys, like HF",
    "multigpu: mark tests that are used to check multi GPU performance",
    "autotune: mark tests that are used to check ray autotune capabilities",
    "custom dataloaders: mark tests that are used to check different custom data loaders",
]

[tool.ruff]
20 changes: 20 additions & 0 deletions src/scvi/data/_utils.py
@@ -21,6 +21,8 @@
from . import _constants

if TYPE_CHECKING:
    from collections.abc import Iterator

    import numpy.typing as npt
    from pandas.api.types import CategoricalDtype
    from torch import Tensor
@@ -361,3 +363,21 @@ def _check_fragment_counts(
    )  # True if there are more 2s than 1s
    ret = not (non_fragments or binary)
    return ret


def _validate_adata_dataloader_input(
    model,
    adata: AnnOrMuData | None = None,
    dataloader: Iterator[dict[str, Tensor | None]] | None = None,
):
    """Validate that the model uses either ``adata`` or a custom dataloader."""
    if adata is not None and dataloader is not None:
        raise ValueError("Only one of `adata` or `dataloader` can be provided.")
    elif (
        hasattr(model, "registry")
        and "setup_method_name" in model.registry.keys()
        and model.registry["setup_method_name"] == "setup_datamodule"
        and dataloader is None
    ):
        raise ValueError("`dataloader` must be provided.")
    return
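For context, a hypothetical call site inside a model method might look like the sketch below; the class and method names are illustrative, and only `_validate_adata_dataloader_input` comes from this diff:

```python
# hypothetical call site; only _validate_adata_dataloader_input is real here
from scvi.data._utils import _validate_adata_dataloader_input


class SomeModel:
    def get_latent_representation(self, adata=None, dataloader=None):
        # enforce the adata-XOR-dataloader contract before any computation
        _validate_adata_dataloader_input(self, adata=adata, dataloader=dataloader)
        ...
```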
3 changes: 3 additions & 0 deletions src/scvi/dataloaders/__init__.py
@@ -3,6 +3,7 @@

from ._ann_dataloader import AnnDataLoader
from ._concat_dataloader import ConcatDataLoader
from ._custom_dataloders import MappedCollectionDataModule, TileDBDataModule
from ._data_splitting import (
    DataSplitter,
    DeviceBackedDataSplitter,
@@ -20,4 +21,6 @@
"DataSplitter",
"SemiSupervisedDataSplitter",
"BatchDistributedSampler",
"MappedCollectionDataModule",
"TileDBDataModule",
]