Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(model): refactor out common minified mode methods #2883

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions src/scvi/data/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

class _ADATA_MINIFY_TYPE_NT(NamedTuple):
    """String identifiers for the supported kinds of minified AnnData."""

    # Mode checked by the models' ``minify_adata`` methods; their docstrings
    # describe it as removing the count data and storing the latent posterior
    # parameters instead.
    LATENT_POSTERIOR: str = "latent_posterior_parameters"
    # NOTE(review): presumably the variant that stores the latent posterior
    # parameters while keeping the original counts -- confirm against the
    # code that consumes this value.
    LATENT_POSTERIOR_WITH_COUNTS: str = "latent_posterior_parameters_with_counts"


# Shared instance so callers can write ``ADATA_MINIFY_TYPE.LATENT_POSTERIOR``.
ADATA_MINIFY_TYPE = _ADATA_MINIFY_TYPE_NT()
Expand Down
87 changes: 0 additions & 87 deletions src/scvi/model/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,11 @@
import numpy as np
import pandas as pd
import torch
from anndata import AnnData

from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager
from scvi.data._constants import (
_ADATA_MINIFY_TYPE_UNS_KEY,
_SETUP_ARGS_KEY,
ADATA_MINIFY_TYPE,
)
from scvi.data._utils import _get_adata_minify_type, _is_minified, get_anndata_attribute
from scvi.data.fields import (
Expand All @@ -25,12 +22,9 @@
LayerField,
NumericalJointObsField,
NumericalObsField,
ObsmField,
StringUnsField,
)
from scvi.dataloaders import SemiSupervisedDataSplitter
from scvi.model._utils import _init_library_size, get_max_epochs_heuristic
from scvi.model.utils import get_minified_adata_scrna
from scvi.module import SCANVAE
from scvi.train import SemiSupervisedTrainingPlan, TrainRunner
from scvi.train._callbacks import SubSampleLabels
Expand All @@ -45,11 +39,6 @@

from anndata import AnnData

from scvi._types import MinifiedDataType
from scvi.data.fields import (
BaseAnnDataField,
)

from ._scvi import SCVI

_SCANVI_LATENT_QZM = "_scanvi_latent_qzm"
Expand Down Expand Up @@ -485,79 +474,3 @@ def setup_anndata(
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)

@staticmethod
def _get_fields_for_adata_minification(
    minified_data_type: MinifiedDataType,
) -> list[BaseAnnDataField]:
    """Build the AnnData fields that must be registered for a minified adata.

    Parameters
    ----------
    minified_data_type
        The minification mode. Only ``ADATA_MINIFY_TYPE.LATENT_POSTERIOR`` is handled.

    Returns
    -------
    A list holding, in order: the two ``ObsmField`` entries for the latent ``qzm`` and
    ``qzv`` parameters, the ``NumericalObsField`` for the observed library size, and the
    ``StringUnsField`` that records the minify type.

    Raises
    ------
    NotImplementedError
        If ``minified_data_type`` is anything other than the latent-posterior mode.
    """
    # Reject unsupported modes up front so the happy path is a single literal.
    if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
        raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")

    return [
        ObsmField(REGISTRY_KEYS.LATENT_QZM_KEY, _SCANVI_LATENT_QZM),
        ObsmField(REGISTRY_KEYS.LATENT_QZV_KEY, _SCANVI_LATENT_QZV),
        NumericalObsField(REGISTRY_KEYS.OBSERVED_LIB_SIZE, _SCANVI_OBSERVED_LIB_SIZE),
        # The minify-type marker is registered regardless of the mode-specific fields.
        StringUnsField(REGISTRY_KEYS.MINIFY_TYPE_KEY, _ADATA_MINIFY_TYPE_UNS_KEY),
    ]

def minify_adata(
    self,
    minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR,
    use_latent_qzm_key: str = "X_latent_qzm",
    use_latent_qzv_key: str = "X_latent_qzv",
) -> None:
    """Minify the model's adata.

    Minifies the adata, and registers new anndata fields: latent qzm, latent qzv, adata uns
    containing minified-adata type, and library size.
    This also sets the appropriate property on the module to indicate that the adata is
    minified.

    Parameters
    ----------
    minified_data_type
        How to minify the data. Currently only supports `latent_posterior_parameters`.
        If minified_data_type == `latent_posterior_parameters`:

        * the original count data is removed (`adata.X`, `adata.raw`, and any layers)
        * the parameters of the latent representation of the original data are stored
        * everything else is left untouched
    use_latent_qzm_key
        Key to use in `adata.obsm` where the latent qzm params are stored
    use_latent_qzv_key
        Key to use in `adata.obsm` where the latent qzv params are stored

    Raises
    ------
    NotImplementedError
        If ``minified_data_type`` is not `latent_posterior_parameters`.
    ValueError
        If the module has ``use_observed_lib_size`` set to ``False``.

    Notes
    -----
    The modification is not done inplace -- instead the model is assigned a new (minified)
    version of the adata.
    """
    if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
        raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")

    if self.module.use_observed_lib_size is False:
        raise ValueError("Cannot minify the data if `use_observed_lib_size` is False")

    minified_adata = get_minified_adata_scrna(self.adata, minified_data_type)

    # Copy the latent posterior parameters from the user-chosen obsm keys to the
    # fixed keys that the minification fields are registered under.
    minified_adata.obsm[_SCANVI_LATENT_QZM] = self.adata.obsm[use_latent_qzm_key]
    minified_adata.obsm[_SCANVI_LATENT_QZV] = self.adata.obsm[use_latent_qzv_key]

    # The library size is taken from the registered counts of the current
    # (un-minified) adata and stored as an obs column on the minified copy.
    counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
    # asarray + squeeze flatten the per-row sums -- presumably needed because the
    # counts may be a sparse matrix whose sum is 2-D; confirm against callers.
    minified_adata.obs[_SCANVI_OBSERVED_LIB_SIZE] = np.squeeze(np.asarray(counts.sum(axis=1)))

    self._update_adata_and_manager_post_minification(minified_adata, minified_data_type)
    self.module.minified_data_type = minified_data_type
88 changes: 0 additions & 88 deletions src/scvi/model/_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,18 @@
import warnings
from typing import TYPE_CHECKING

import numpy as np

from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager
from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE
from scvi.data._utils import _get_adata_minify_type
from scvi.data.fields import (
CategoricalJointObsField,
CategoricalObsField,
LayerField,
NumericalJointObsField,
NumericalObsField,
ObsmField,
StringUnsField,
)
from scvi.model._utils import _init_library_size
from scvi.model.base import EmbeddingMixin, UnsupervisedTrainingMixin
from scvi.model.utils import get_minified_adata_scrna
from scvi.module import VAE
from scvi.utils import setup_anndata_dsp

Expand All @@ -32,10 +26,6 @@

from anndata import AnnData

from scvi._types import MinifiedDataType
from scvi.data.fields import (
BaseAnnDataField,
)

_SCVI_LATENT_QZM = "_scvi_latent_qzm"
_SCVI_LATENT_QZV = "_scvi_latent_qzv"
Expand Down Expand Up @@ -231,81 +221,3 @@ def setup_anndata(
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)

@staticmethod
def _get_fields_for_adata_minification(
    minified_data_type: MinifiedDataType,
) -> list[BaseAnnDataField]:
    """Build the AnnData fields that must be registered for a minified adata.

    Parameters
    ----------
    minified_data_type
        The minification mode. Only ``ADATA_MINIFY_TYPE.LATENT_POSTERIOR`` is handled.

    Returns
    -------
    A list holding, in order: the two ``ObsmField`` entries for the latent ``qzm`` and
    ``qzv`` parameters, the ``NumericalObsField`` for the observed library size, and the
    ``StringUnsField`` that records the minify type.

    Raises
    ------
    NotImplementedError
        If ``minified_data_type`` is anything other than the latent-posterior mode.
    """
    # Reject unsupported modes up front so the happy path is a single literal.
    if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
        raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")

    return [
        ObsmField(REGISTRY_KEYS.LATENT_QZM_KEY, _SCVI_LATENT_QZM),
        ObsmField(REGISTRY_KEYS.LATENT_QZV_KEY, _SCVI_LATENT_QZV),
        NumericalObsField(REGISTRY_KEYS.OBSERVED_LIB_SIZE, _SCVI_OBSERVED_LIB_SIZE),
        # The minify-type marker is registered regardless of the mode-specific fields.
        StringUnsField(REGISTRY_KEYS.MINIFY_TYPE_KEY, _ADATA_MINIFY_TYPE_UNS_KEY),
    ]

def minify_adata(
    self,
    minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR,
    use_latent_qzm_key: str = "X_latent_qzm",
    use_latent_qzv_key: str = "X_latent_qzv",
) -> None:
    """Replace the model's adata with a minified version of it.

    The minified adata gets new registered fields: latent qzm, latent qzv, the adata uns
    entry naming the minified-adata type, and the library size. The module is also flagged
    with the minified data type so that it knows the adata is minified.

    Parameters
    ----------
    minified_data_type
        Minification mode. Only `latent_posterior_parameters` is supported at the moment.
        In that mode:

        * the original count data is removed (`adata.X`, `adata.raw`, and any layers)
        * the parameters of the latent representation of the original data are stored
        * everything else is left untouched
    use_latent_qzm_key
        Key in `adata.obsm` under which the latent qzm params are stored
    use_latent_qzv_key
        Key in `adata.obsm` under which the latent qzv params are stored

    Raises
    ------
    NotImplementedError
        If ``minified_data_type`` is not `latent_posterior_parameters`.
    ValueError
        If the module has ``use_observed_lib_size`` set to ``False``.

    Notes
    -----
    Nothing is modified in place -- the model is handed a new (minified) adata instead.
    """
    # TODO(adamgayoso): Add support for a scenario where we want to cache the latent posterior
    # without removing the original counts.
    if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
        raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")

    if self.module.use_observed_lib_size is False:
        raise ValueError("Cannot minify the data if `use_observed_lib_size` is False")

    minified_adata = get_minified_adata_scrna(self.adata, minified_data_type)

    # Move the latent posterior parameters from the caller-chosen obsm keys onto
    # the fixed keys used by the minification fields.
    latent_mean = self.adata.obsm[use_latent_qzm_key]
    minified_adata.obsm[_SCVI_LATENT_QZM] = latent_mean
    latent_var = self.adata.obsm[use_latent_qzv_key]
    minified_adata.obsm[_SCVI_LATENT_QZV] = latent_var

    # Observed library size: per-row sum of the registered counts, flattened
    # before being stored as an obs column.
    counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
    row_totals = np.asarray(counts.sum(axis=1))
    minified_adata.obs[_SCVI_OBSERVED_LIB_SIZE] = np.squeeze(row_totals)

    self._update_adata_and_manager_post_minification(minified_adata, minified_data_type)
    self.module.minified_data_type = minified_data_type
Loading
Loading