Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Mudata support for MultiVI #3038

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 151 additions & 18 deletions src/scvi/model/_multivi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import warnings
from collections.abc import Iterable as IterableClass
from functools import partial
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -117,13 +118,15 @@ class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin):
--------
>>> adata_rna = anndata.read_h5ad(path_to_rna_anndata)
>>> adata_atac = scvi.data.read_10x_atac(path_to_atac_anndata)
>>> adata_multi = scvi.data.read_10x_multiome(path_to_multiomic_anndata)
>>> adata_mvi = scvi.data.organize_multiome_anndatas(adata_multi, adata_rna, adata_atac)
>>> scvi.model.MULTIVI.setup_anndata(adata_mvi, batch_key="modality")
>>> vae = scvi.model.MULTIVI(adata_mvi)
>>> adata_protein = anndata.read_h5ad(path_to_protein_anndata)
>>> mdata = MuData({"rna": adata_rna, "protein": adata_protein, "atac": adata_atac})
>>> scvi.model.MULTIVI.setup_mudata(mdata, batch_key="batch",
>>> modalities={"rna_layer": "rna", "protein_layer": "protein", "batch_key": "rna",
>>> "atac_layer": "atac"})
>>> vae = scvi.model.MULTIVI(mdata)
>>> vae.train()

Notes
Notes (for using setup_anndata)
-----
* The model assumes that the features are organized so that all expression features are
consecutive, followed by all accessibility features. For example, if the data has 100 genes
Expand Down Expand Up @@ -360,7 +363,7 @@ def train(
@torch.inference_mode()
def get_library_size_factors(
self,
adata: AnnData | None = None,
adata: AnnOrMuData | None = None,
indices: Sequence[int] = None,
batch_size: int = 128,
) -> dict[str, np.ndarray]:
Expand All @@ -369,8 +372,8 @@ def get_library_size_factors(
Parameters
----------
adata
AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
AnnData object used to initialize the model.
AnnOrMuData object with equivalent structure to initial AnnData. If `None`, defaults
to the AnnOrMuData object used to initialize the model.
indices
Indices of cells in adata to use. If `None`, all cells are used.
batch_size
Expand Down Expand Up @@ -409,7 +412,7 @@ def get_region_factors(self) -> np.ndarray:
@torch.inference_mode()
def get_latent_representation(
self,
adata: AnnData | None = None,
adata: AnnOrMuData | None = None,
modality: Literal["joint", "expression", "accessibility"] = "joint",
indices: Sequence[int] | None = None,
give_mean: bool = True,
Expand All @@ -420,8 +423,8 @@ def get_latent_representation(
Parameters
----------
adata
AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
AnnData object used to initialize the model.
AnnOrMuData object with equivalent structure to initial AnnData. If `None`, defaults
to the AnnOrMuData object used to initialize the model.
modality
Return modality specific or joint latent representation.
indices
Expand Down Expand Up @@ -479,7 +482,7 @@ def get_latent_representation(
@torch.inference_mode()
def get_accessibility_estimates(
self,
adata: AnnData | None = None,
adata: AnnOrMuData | None = None,
indices: Sequence[int] = None,
n_samples_overall: int | None = None,
region_list: Sequence[str] | None = None,
Expand All @@ -500,8 +503,8 @@ def get_accessibility_estimates(
Parameters
----------
adata
AnnData object that has been registered with scvi. If `None`, defaults to the
AnnData object used to initialize the model.
AnnOrMuData object that has been registered with scvi. If `None`, defaults to the
AnnOrMuData object used to initialize the model.
indices
Indices of cells in adata to use. If `None`, all cells are used.
n_samples_overall
Expand Down Expand Up @@ -590,14 +593,14 @@ def get_accessibility_estimates(
imputed,
index=adata.obs_names[indices],
columns=adata["rna"].var_names[self.n_genes :][region_mask]
if type(adata).__name__ == "MuData"
if isinstance(adata, MuData)
else adata.var_names[self.n_genes :][region_mask],
)

@torch.inference_mode()
def get_normalized_expression(
self,
adata: AnnData | None = None,
adata: AnnOrMuData | None = None,
indices: Sequence[int] | None = None,
n_samples_overall: int | None = None,
transform_batch: Sequence[Number | str] | None = None,
Expand All @@ -615,8 +618,8 @@ def get_normalized_expression(
Parameters
----------
adata
AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
AnnData object used to initialize the model.
AnnOrMuData object with equivalent structure to initial AnnData. If `None`, defaults
to the AnnOrMuData object used to initialize the model.
indices
Indices of cells in adata to use. If `None`, all cells are used.
n_samples_overall
Expand Down Expand Up @@ -928,6 +931,130 @@ def differential_expression(

return result

@torch.no_grad()
def get_protein_foreground_probability(
self,
adata: AnnOrMuData | None = None,
indices: Sequence[int] | None = None,
transform_batch: Sequence[Number | str] | None = None,
protein_list: Sequence[str] | None = None,
n_samples: int = 1,
batch_size: int | None = None,
use_z_mean: bool = True,
return_mean: bool = True,
return_numpy: bool | None = None,
):
r"""Returns the foreground probability for proteins.

This is denoted as :math:`(1 - \pi_{nt})` in the totalVI paper.

Parameters
----------
adata
AnnOrMuData object with equivalent structure to initial AnnData. If ``None``, defaults
to the AnnOrMuData object used to initialize the model.
indices
Indices of cells in adata to use. If `None`, all cells are used.
transform_batch
Batch to condition on.
If transform_batch is:

* ``None`` - real observed batch is used
* ``int`` - batch transform_batch is used
* ``List[int]`` - average over batches in list
protein_list
Return protein expression for a subset of genes.
This can save memory when working with large datasets and few genes are
of interest.
n_samples
Number of posterior samples to use for estimation.
batch_size
Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
return_mean
Whether to return the mean of the samples.
return_numpy
Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame
includes gene names as columns. If either ``n_samples=1`` or ``return_mean=True``,
defaults to ``False``. Otherwise, it defaults to `True`.

Returns
-------
- **foreground_probability** - probability foreground for each protein

If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
Otherwise, shape is `(cells, genes)`. In this case, return type is
:class:`~pandas.DataFrame` unless `return_numpy` is True.
"""
adata = self._validate_anndata(adata)
post = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)

if protein_list is None:
protein_mask = slice(None)
else:
all_proteins = self.scvi_setup_dict_["protein_names"]
protein_mask = [True if p in protein_list else False for p in all_proteins]

if n_samples > 1 and return_mean is False:
if return_numpy is False:
warnings.warn(
"`return_numpy` must be `True` if `n_samples > 1` and `return_mean` is "
"`False`, returning an `np.ndarray`.",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)
return_numpy = True
if indices is None:
indices = np.arange(adata.n_obs)

py_mixings = []
if not isinstance(transform_batch, IterableClass):
transform_batch = [transform_batch]

transform_batch = _get_batch_code_from_category(self.adata_manager, transform_batch)
for tensors in post:
y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY]
py_mixing = torch.zeros_like(y[..., protein_mask])
if n_samples > 1:
py_mixing = torch.stack(n_samples * [py_mixing])
for _ in transform_batch:
# generative_kwargs = dict(transform_batch=b)
generative_kwargs = {"use_z_mean": use_z_mean}
inference_kwargs = {"n_samples": n_samples}
_, generative_outputs = self.module.forward(
tensors=tensors,
inference_kwargs=inference_kwargs,
generative_kwargs=generative_kwargs,
compute_loss=False,
)
py_mixing += torch.sigmoid(generative_outputs["py_"]["mixing"])[
..., protein_mask
].cpu()
py_mixing /= len(transform_batch)
py_mixings += [py_mixing]
if n_samples > 1:
# concatenate along batch dimension -> result shape = (samples, cells, features)
py_mixings = torch.cat(py_mixings, dim=1)
# (cells, features, samples)
py_mixings = py_mixings.permute(1, 2, 0)
else:
py_mixings = torch.cat(py_mixings, dim=0)

if return_mean is True and n_samples > 1:
py_mixings = torch.mean(py_mixings, dim=-1)

py_mixings = py_mixings.cpu().numpy()

if return_numpy is True:
return 1 - py_mixings
else:
pro_names = self.protein_state_registry.column_names
foreground_prob = pd.DataFrame(
1 - py_mixings,
columns=pro_names[protein_mask],
index=adata.obs_names[indices],
)
return foreground_prob

@classmethod
@setup_anndata_dsp.dedent
def setup_anndata(
Expand Down Expand Up @@ -959,6 +1086,12 @@ def setup_anndata(
`adata.obsm[protein_expression_obsm_key]` if it is a DataFrame, else will assign
sequential names to proteins.
"""
warnings.warn(
"MULTIVI is suppose to work with MuData. the use of anndata is "
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
"deprecated and will be remove in scvi-tools 1.4. Please use setup_mudata",
DeprecationWarning,
stacklevel=settings.warnings_stacklevel,
)
setup_method_args = cls._get_setup_method_args(**locals())
adata.obs["_indices"] = np.arange(adata.n_obs)
batch_field = CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key)
Expand Down
6 changes: 6 additions & 0 deletions src/scvi/model/_totalvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,6 +1215,12 @@ def setup_anndata(
-------
%(returns)s
"""
warnings.warn(
"TOTALVI is suppose to work with MuData. the use of anndata is "
"deprecated and will be remove in scvi-tools 1.4. Please use setup_mudata",
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
DeprecationWarning,
stacklevel=settings.warnings_stacklevel,
)
setup_method_args = cls._get_setup_method_args(**locals())
batch_field = fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key)
anndata_fields = [
Expand Down
Loading