Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Mudata support for MultiVI #3038

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ to [Semantic Versioning]. Full commit history is available in the

#### Added

- Experimental MuData support for {class}`~scvi.model.MULTIVI` via the method
{meth}`~scvi.model.MULTIVI.setup_mudata` {pr}`3038`.

#### Fixed

- Breaking Change: Fix `get_outlier_cell_sample_pairs` function in {class}`scvi.external.MRVI`
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,26 +88,26 @@ census = ["cellxgene-census"]
hub = ["huggingface_hub"]
# scvi.model.utils.mde dependencies
pymde = ["pymde"]
# mudata dependencies
muon = ["muon"]
# scvi.data.add_dna_sequence
regseq = ["biopython>=1.81", "genomepy"]
# read loom
loompy = ["loompy>=3.0.6"]
# scvi.criticism and read 10x
scanpy = ["scanpy>=1.6"]
scanpy = ["scanpy>=1.6","scikit-misc"]

optional = [
"scvi-tools[autotune,aws,hub,loompy,pymde,regseq,scanpy]"
"scvi-tools[autotune,aws,hub,loompy,muon,pymde,regseq,scanpy]"
]
tutorials = [
"cell2location",
"jupyter",
"leidenalg",
"muon",
"plotnine",
"pooch",
"pynndescent",
"igraph",
"scikit-misc",
"scrublet",
"scib-metrics",
"scvi-tools[optional]",
Expand Down
182 changes: 157 additions & 25 deletions src/scvi/model/_multivi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch.distributions import Normal

from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager
from scvi.data import AnnDataManager, fields
from scvi.data.fields import (
CategoricalJointObsField,
CategoricalObsField,
Expand Down Expand Up @@ -44,8 +44,9 @@
from typing import Literal

from anndata import AnnData
from mudata import MuData

from scvi._types import Number
from scvi._types import AnnOrMuData, Number

logger = logging.getLogger(__name__)

Expand All @@ -59,7 +60,8 @@ class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin):
Parameters
----------
adata
AnnData object that has been registered via :meth:`~scvi.model.MULTIVI.setup_anndata`.
AnnData/MuData object that has been registered via
:meth:`~scvi.model.MULTIVI.setup_anndata` or :meth:`~scvi.model.MULTIVI.setup_mudata`.
n_genes
The number of gene expression features (genes).
n_regions
Expand Down Expand Up @@ -116,13 +118,15 @@ class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin):
--------
>>> adata_rna = anndata.read_h5ad(path_to_rna_anndata)
>>> adata_atac = scvi.data.read_10x_atac(path_to_atac_anndata)
>>> adata_multi = scvi.data.read_10x_multiome(path_to_multiomic_anndata)
>>> adata_mvi = scvi.data.organize_multiome_anndatas(adata_multi, adata_rna, adata_atac)
>>> scvi.model.MULTIVI.setup_anndata(adata_mvi, batch_key="modality")
>>> vae = scvi.model.MULTIVI(adata_mvi)
>>> adata_protein = anndata.read_h5ad(path_to_protein_anndata)
>>> mdata = MuData({"rna": adata_rna, "protein": adata_protein, "atac": adata_atac})
>>> scvi.model.MULTIVI.setup_mudata(mdata, batch_key="batch",
>>> modalities={"rna_layer": "rna", "protein_layer": "protein", "batch_key": "rna",
>>> "atac_layer": "atac"})
>>> vae = scvi.model.MULTIVI(mdata)
>>> vae.train()

Notes
Notes (for using setup_anndata)
-----
* The model assumes that the features are organized so that all expression features are
consecutive, followed by all accessibility features. For example, if the data has 100 genes
Expand All @@ -140,7 +144,7 @@ class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin):

def __init__(
self,
adata: AnnData,
adata: AnnOrMuData,
n_genes: int,
n_regions: int,
modality_weights: Literal["equal", "cell", "universal"] = "equal",
Expand Down Expand Up @@ -359,7 +363,7 @@ def train(
@torch.inference_mode()
def get_library_size_factors(
self,
adata: AnnData | None = None,
adata: AnnOrMuData | None = None,
indices: Sequence[int] = None,
batch_size: int = 128,
) -> dict[str, np.ndarray]:
Expand All @@ -368,8 +372,8 @@ def get_library_size_factors(
Parameters
----------
adata
AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
AnnData object used to initialize the model.
AnnOrMuData object with equivalent structure to initial AnnData. If `None`, defaults
to the AnnOrMuData object used to initialize the model.
indices
Indices of cells in adata to use. If `None`, all cells are used.
batch_size
Expand Down Expand Up @@ -408,7 +412,7 @@ def get_region_factors(self) -> np.ndarray:
@torch.inference_mode()
def get_latent_representation(
self,
adata: AnnData | None = None,
adata: AnnOrMuData | None = None,
modality: Literal["joint", "expression", "accessibility"] = "joint",
indices: Sequence[int] | None = None,
give_mean: bool = True,
Expand All @@ -419,8 +423,8 @@ def get_latent_representation(
Parameters
----------
adata
AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
AnnData object used to initialize the model.
AnnOrMuData object with equivalent structure to initial AnnData. If `None`, defaults
to the AnnOrMuData object used to initialize the model.
modality
Return modality specific or joint latent representation.
indices
Expand Down Expand Up @@ -478,7 +482,7 @@ def get_latent_representation(
@torch.inference_mode()
def get_accessibility_estimates(
self,
adata: AnnData | None = None,
adata: AnnOrMuData | None = None,
indices: Sequence[int] = None,
n_samples_overall: int | None = None,
region_list: Sequence[str] | None = None,
Expand All @@ -499,8 +503,8 @@ def get_accessibility_estimates(
Parameters
----------
adata
AnnData object that has been registered with scvi. If `None`, defaults to the
AnnData object used to initialize the model.
AnnOrMuData object that has been registered with scvi. If `None`, defaults to the
AnnOrMuData object used to initialize the model.
indices
Indices of cells in adata to use. If `None`, all cells are used.
n_samples_overall
Expand Down Expand Up @@ -588,13 +592,15 @@ def get_accessibility_estimates(
return pd.DataFrame(
imputed,
index=adata.obs_names[indices],
columns=adata.var_names[self.n_genes :][region_mask],
columns=adata["rna"].var_names[self.n_genes :][region_mask]
if isinstance(adata, MuData)
else adata.var_names[self.n_genes :][region_mask],
)

@torch.inference_mode()
def get_normalized_expression(
self,
adata: AnnData | None = None,
adata: AnnOrMuData | None = None,
indices: Sequence[int] | None = None,
n_samples_overall: int | None = None,
transform_batch: Sequence[Number | str] | None = None,
Expand All @@ -612,8 +618,8 @@ def get_normalized_expression(
Parameters
----------
adata
AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
AnnData object used to initialize the model.
AnnOrMuData object with equivalent structure to initial AnnData. If `None`, defaults
to the AnnOrMuData object used to initialize the model.
indices
Indices of cells in adata to use. If `None`, all cells are used.
n_samples_overall
Expand Down Expand Up @@ -928,7 +934,7 @@ def differential_expression(
@torch.no_grad()
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
def get_protein_foreground_probability(
self,
adata: AnnData | None = None,
adata: AnnOrMuData | None = None,
indices: Sequence[int] | None = None,
transform_batch: Sequence[Number | str] | None = None,
protein_list: Sequence[str] | None = None,
Expand All @@ -945,8 +951,8 @@ def get_protein_foreground_probability(
Parameters
----------
adata
AnnData object with equivalent structure to initial AnnData. If ``None``, defaults to
the AnnData object used to initialize the model.
AnnOrMuData object with equivalent structure to initial AnnData. If ``None``, defaults
to the AnnOrMuData object used to initialize the model.
indices
Indices of cells in adata to use. If `None`, all cells are used.
transform_batch
Expand Down Expand Up @@ -1080,6 +1086,12 @@ def setup_anndata(
`adata.obsm[protein_expression_obsm_key]` if it is a DataFrame, else will assign
sequential names to proteins.
"""
warnings.warn(
"MULTIVI is suppose to work with MuData. the use of anndata is "
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
"deprecated and will be remove in scvi-tools 1.4. Please use setup_mudata",
DeprecationWarning,
stacklevel=settings.warnings_stacklevel,
)
setup_method_args = cls._get_setup_method_args(**locals())
adata.obs["_indices"] = np.arange(adata.n_obs)
batch_field = CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key)
Expand Down Expand Up @@ -1117,3 +1129,123 @@ def _check_adata_modality_weights(self, adata):
"""
if (adata is not None) and (self.module.modality_weights == "cell"):
raise RuntimeError("Held out data not permitted when using per cell weights")

@classmethod
@setup_anndata_dsp.dedent
def setup_mudata(
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
cls,
mdata: MuData,
rna_layer: str | None = None,
atac_layer: str | None = None,
protein_layer: str | None = None,
batch_key: str | None = None,
size_factor_key: str | None = None,
categorical_covariate_keys: list[str] | None = None,
continuous_covariate_keys: list[str] | None = None,
idx_layer: str | None = None,
modalities: dict[str, str] | None = None,
**kwargs,
):
"""%(summary_mdata)s.

Parameters
----------
%(param_mdata)s
rna_layer
RNA layer key. If `None`, will use `.X` of specified modality key.
protein_layer
Protein layer key. If `None`, will use `.X` of specified modality key.
atac_layer
ATAC layer key. If `None`, will use `.X` of specified modality key.
%(param_batch_key)s
%(param_size_factor_key)s
%(param_cat_cov_keys)s
%(param_cont_cov_keys)s
%(idx_layer)s
%(param_modalities)s

Examples
--------
>>> mdata = muon.read_10x_h5("filtered_feature_bc_matrix.h5")
>>> scvi.model.MULTIVI.setup_mudata(
mdata, modalities={"rna_layer": "rna", "protein_layer": "atac"}
)
>>> vae = scvi.model.MULTIVI(mdata)
"""
setup_method_args = cls._get_setup_method_args(**locals())

if modalities is None:
raise ValueError("Modalities cannot be None.")
modalities = cls._create_modalities_attr_dict(modalities, setup_method_args)
mdata.obs["_indices"] = np.arange(mdata.n_obs)

batch_field = fields.MuDataCategoricalObsField(
REGISTRY_KEYS.BATCH_KEY,
batch_key,
mod_key=modalities.batch_key,
)
mudata_fields = [
batch_field,
fields.MuDataCategoricalObsField(
REGISTRY_KEYS.LABELS_KEY,
None,
mod_key=None,
),
fields.MuDataNumericalObsField(
REGISTRY_KEYS.SIZE_FACTOR_KEY,
size_factor_key,
mod_key=modalities.size_factor_key,
required=False,
),
fields.MuDataCategoricalJointObsField(
REGISTRY_KEYS.CAT_COVS_KEY,
categorical_covariate_keys,
mod_key=modalities.categorical_covariate_keys,
),
fields.MuDataNumericalJointObsField(
REGISTRY_KEYS.CONT_COVS_KEY,
continuous_covariate_keys,
mod_key=modalities.continuous_covariate_keys,
),
fields.MuDataNumericalObsField(
REGISTRY_KEYS.INDICES_KEY,
"_indices",
mod_key=modalities.idx_layer,
required=False,
),
]
if modalities.rna_layer is not None:
mudata_fields.append(
fields.MuDataLayerField(
REGISTRY_KEYS.X_KEY,
rna_layer,
mod_key=modalities.rna_layer,
is_count_data=True,
mod_required=True,
)
)
if modalities.atac_layer is not None:
mudata_fields.append(
fields.MuDataLayerField(
REGISTRY_KEYS.X_KEY,
atac_layer,
mod_key=modalities.atac_layer,
is_count_data=True,
mod_required=True,
)
)
if modalities.protein_layer is not None:
mudata_fields.append(
fields.MuDataProteinLayerField(
REGISTRY_KEYS.PROTEIN_EXP_KEY,
protein_layer,
mod_key=modalities.protein_layer,
use_batch_mask=True,
batch_field=batch_field,
is_count_data=True,
mod_required=True,
)
)
adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(mdata, **kwargs)
cls.register_manager(adata_manager)
15 changes: 11 additions & 4 deletions src/scvi/model/_totalvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from anndata import AnnData
from mudata import MuData

from scvi._types import Number
from scvi._types import AnnOrMuData, Number

logger = logging.getLogger(__name__)

Expand All @@ -46,7 +46,8 @@ class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):
Parameters
----------
adata
AnnData object that has been registered via :meth:`~scvi.model.TOTALVI.setup_anndata`.
AnnData/MuData object that has been registered via
:meth:`~scvi.model.TOTALVI.setup_anndata` or :meth:`~scvi.model.TOTALVI.setup_mudata`.
n_latent
Dimensionality of the latent space.
gene_dispersion
Expand Down Expand Up @@ -108,7 +109,7 @@ class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):

def __init__(
self,
adata: AnnData,
adata: AnnOrMuData,
n_latent: int = 20,
gene_dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
protein_dispersion: Literal["protein", "protein-batch", "protein-label"] = "protein",
Expand Down Expand Up @@ -1214,6 +1215,12 @@ def setup_anndata(
-------
%(returns)s
"""
warnings.warn(
"TOTALVI is suppose to work with MuData. the use of anndata is "
"deprecated and will be remove in scvi-tools 1.4. Please use setup_mudata",
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
DeprecationWarning,
stacklevel=settings.warnings_stacklevel,
)
setup_method_args = cls._get_setup_method_args(**locals())
batch_field = fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key)
anndata_fields = [
Expand Down Expand Up @@ -1275,7 +1282,7 @@ def setup_mudata(
--------
>>> mdata = muon.read_10x_h5("pbmc_10k_protein_v3_filtered_feature_bc_matrix.h5")
>>> scvi.model.TOTALVI.setup_mudata(
mdata, modalities={"rna_layer": "rna": "protein_layer": "prot"}
mdata, modalities={"rna_layer": "rna", "protein_layer": "prot"}
)
>>> vae = scvi.model.TOTALVI(mdata)
"""
Expand Down
Loading
Loading