Skip to content

decode_latent_samples returns non-library-size-normalised expression values #46

@marie-minaeva

Description

@marie-minaeva

Report

Hey!

We have observed that `decode_latent_samples` returns gene expression values that are not corrected for library size. Here is the code snippet we used for testing:

import sys
import yaml
from functools import partial

import scanpy as sc
import scipy as sp
import anndata as ad
import numpy as np

import torch
import drvi
from typing import Any
import scvi
from scvi import REGISTRY_KEYS
from drvi.nn_modules.noise_model import library_size_correction


# Load the AnnData object and copy the `Seq.Run` column into `obs["batch"]`.
adata_gex = sc.read_h5ad("barcodedData/all_dc_dataset_ready_to_integr.h5ad")
adata_gex.obs["batch"] = adata_gex.obs["Seq.Run"].tolist()

# Drop cells whose donor is "A"; a missing donor compares unequal to "A", so
# those cells are kept.
keep_mask = ~(adata_gex.obs["donor"] == "A")
adata_gex = adata_gex[keep_mask, :].copy()

# Give missing donor labels an explicit empty-string category.
donor_with_blank = adata_gex.obs["donor"].cat.add_categories("")
adata_gex.obs["donor"] = donor_with_blank.fillna("")

# Custom decoding function
def iterate_on_decoded_latent_samples(
    vae,
    z: torch.Tensor,
    lib: np.ndarray | None = None,
    cat_values: np.ndarray | None = None,
    cont_values: np.ndarray | None = None,
    batch_size: int = scvi.settings.batch_size,
    map_cat_values: bool = False,
) -> torch.Tensor:
    """Decode latent samples in minibatches and return the stacked ``px.log_m`` tensors.

    Each minibatch of ``z`` is pushed through ``vae.module.generative`` under
    ``torch.no_grad()``. The ``log_m`` attribute of the ``"px"`` output is
    collected for each minibatch, and the results are concatenated along dim 0.

    Parameters
    ----------
    vae
        Trained model exposing ``module`` (with ``eval``, ``parameters``,
        ``_get_generative_input`` and ``generative``) and ``adata_manager``.
    z
        Latent samples with shape (n_samples, n_latent). Must be non-empty.
    lib
        Library size per sample, shape (n_samples,). If None, 1e4 is used for
        every sample. NOTE(review): confirm whether the module expects a raw
        or a log library size under the ``"library"`` key.
    cat_values
        Categorical covariates, shape (n_samples, n_cat_covs). When
        ``map_cat_values`` is True, a 1-D array is treated as one column.
    cont_values
        Continuous covariates, shape (n_samples, n_cont_covs).
    batch_size
        Number of latent samples decoded per forward pass. Must be positive.
    map_cat_values
        If True, translate ``cat_values`` labels into integer codes using the
        categorical-covariate mappings registered in ``vae.adata_manager``.

    Returns
    -------
    torch.Tensor
        Concatenation of ``gen_output["px"].log_m`` over all minibatches.
        NOTE(review): whether ``log_m`` already includes library-size
        scaling depends on the noise model. Confirm this before comparing
        against ``decode_latent_samples``.

    Raises
    ------
    ValueError
        If ``z`` is empty, ``batch_size`` is not positive, ``cat_values`` has a
        different number of columns than the registered categorical covariates,
        or ``cat_values`` contains a label missing from those mappings.
    """
    n_samples = z.shape[0]
    if n_samples == 0:
        # Without this guard, torch.cat([]) / `del gen_output` fail obscurely.
        raise ValueError("z must contain at least one latent sample.")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}.")

    vae.module.eval()

    if cat_values is not None and map_cat_values:
        cat_values = np.asarray(cat_values)
        if cat_values.ndim == 1:  # tolerate a single covariate passed as a 1-D array
            cat_values = cat_values.reshape(-1, 1)
        mappings = vae.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)["mappings"]
        if cat_values.shape[1] != len(mappings):
            raise ValueError(
                f"cat_values has {cat_values.shape[1]} column(s) but the model registers "
                f"{len(mappings)} categorical covariate(s): {list(mappings)}."
            )
        # Allocate an integer array up front. np.zeros_like(cat_values) would keep
        # the fixed-width string dtype of the labels, so a code such as 10 written
        # into a '<U1' array was silently truncated to "1" before the int cast.
        mapped_values = np.empty(cat_values.shape, dtype=np.int32)
        for col, (label, map_keys) in enumerate(mappings.items()):
            cat_mapping = {key: code for code, key in enumerate(map_keys)}
            column = cat_values[:, col].tolist()
            unknown = set(column).difference(cat_mapping)
            if unknown:
                raise ValueError(
                    f"Unknown value(s) {sorted(map(str, unknown))} for categorical "
                    f"covariate {label!r}."
                )
            mapped_values[:, col] = [cat_mapping[value] for value in column]
        cat_values = mapped_values

    # Build inputs on the device holding the module's weights (CPU if it has
    # none), so a GPU-resident model does not receive CPU tensors.
    device = next(vae.module.parameters(), torch.empty(0)).device

    store: list[torch.Tensor] = []
    for start in range(0, n_samples, batch_size):
        stop = min(start + batch_size, n_samples)
        rows = slice(start, stop)
        z_tensor = torch.as_tensor(z[rows]).to(device)
        if lib is None:
            lib_tensor = torch.full((stop - start,), 1e4, device=device)
        else:
            lib_tensor = torch.as_tensor(lib[rows]).to(device)
        cat_tensor = torch.as_tensor(cat_values[rows]).to(device) if cat_values is not None else None
        cont_tensor = torch.as_tensor(cont_values[rows]).to(device) if cont_values is not None else None

        with torch.no_grad():
            gen_input = vae.module._get_generative_input(
                tensors={
                    REGISTRY_KEYS.BATCH_KEY: None,
                    REGISTRY_KEYS.LABELS_KEY: None,
                    REGISTRY_KEYS.CONT_COVS_KEY: cont_tensor,
                    REGISTRY_KEYS.CAT_COVS_KEY: cat_tensor,
                },
                inference_outputs={
                    "z": z_tensor,
                    "library": lib_tensor,
                    "gene_likelihood_additional_info": {},
                },
            )
            gen_output = vae.module.generative(**gen_input)
        store.append(gen_output["px"].log_m)

    result = torch.cat(store, dim=0)
    # Drop references to per-batch outputs before asking CUDA to release cached memory.
    del store, gen_output
    torch.cuda.empty_cache()
    return result


# Load trained model
vae = drvi.model.DRVI.load("./models/", adata_gex, prefix="all_dc_128_no_ha_")

# Random latent samples (128 latent dims, matching the model prefix); the
# seeded generator makes reruns comparable.
z = torch.from_numpy(np.random.default_rng(0).standard_normal((100, 128))).to(dtype=torch.float32)

# Use the SAME covariates for both decoders. Previously one call received donor
# "H.B" and the other "B", so the outputs could differ for that reason alone.
# NOTE(review): confirm which donor label is intended.
n_samples = z.shape[0]
cat_values = np.column_stack([np.full(n_samples, "10x"), np.full(n_samples, "H.B")])

# Decode using original decoding function
out_1 = vae.decode_latent_samples(z, cat_values=cat_values, cont_values=None, map_cat_values=True)

# Decode using custom decoding function
out_2 = iterate_on_decoded_latent_samples(vae, z, cat_values=cat_values, cont_values=None, map_cat_values=True)

# Compare outputs. NOTE(review): out_2 holds px.log_m; confirm that out_1 is in
# the same space before attributing a mismatch to library-size handling.
out_1_tensor = torch.as_tensor(out_1).cpu()
out_2_tensor = out_2.cpu()
print(torch.allclose(out_1_tensor, out_2_tensor))
print("max abs diff:", (out_1_tensor - out_2_tensor).abs().max().item())

>> False

Version information

-----
anndata             0.10.9
drvi                0.1.8
numpy               1.26.4
scanpy              1.9.5
scipy               1.12.0
scvi                1.0.4
session_info        v1.0.1
torch               2.3.1+cu121
yaml                6.0.2
-----
PIL                         11.2.1
absl                        2.3.0
aiohappyeyeballs            2.6.1
aiohttp                     3.12.6
aiosignal                   1.3.2
alembic                     1.16.2
annotated_types             0.7.0
antlr4                      NA
anyio                       NA
arrow                       1.3.0
asttokens                   NA
attr                        25.3.0
attrs                       25.3.0
babel                       2.15.0
backoff                     2.2.1
bs4                         4.12.3
certifi                     2025.04.26
charset_normalizer          3.4.2
chex                        0.1.7
click                       8.2.1
cloudpickle                 3.1.1
cmaes                       0.11.1
colorlog                    NA
comm                        0.2.2
croniter                    NA
cuda                        12.9.0
cupy                        13.4.1
cupy_backends               NA
cupyx                       NA
cycler                      0.12.1
cython_runtime              NA
dask                        2025.2.0
dateutil                    2.9.0.post0
debugpy                     1.8.14
decorator                   5.2.1
deepdiff                    7.0.1
defusedxml                  0.7.1
dockerpycreds               NA
docrep                      0.3.2
etils                       1.12.2
executing                   2.2.0
fastapi                     0.115.12
fastjsonschema              NA
fastrlock                   0.8.3
flax                        0.8.4
fqdn                        NA
frozenlist                  1.6.0
fsspec                      2024.12.0
git                         3.1.44
gitdb                       4.0.12
google                      NA
greenlet                    3.2.3
h11                         0.14.0
h5py                        3.13.0
httpcore                    1.0.5
idna                        3.10
igraph                      0.11.8
importlib_resources         NA
iniconfig                   NA
ipykernel                   6.29.5
ipywidgets                  8.1.7
isoduration                 NA
jaraco                      NA
jax                         0.4.20
jaxlib                      0.4.20
jedi                        0.19.2
jinja2                      3.1.6
joblib                      1.5.1
json5                       0.9.25
jsonpointer                 3.0.0
jsonschema                  4.22.0
jsonschema_specifications   NA
jupyter_events              0.10.0
jupyter_server              2.14.1
jupyterlab_server           2.27.2
kiwisolver                  1.4.8
leidenalg                   0.10.2
lightning                   2.0.9.post0
lightning_cloud             0.5.70
lightning_utilities         0.14.3
llvmlite                    0.43.0
mako                        1.3.10
markupsafe                  3.0.2
matplotlib                  3.10.3
ml_collections              1.1.0
ml_dtypes                   0.5.1
more_itertools              10.3.0
mpl_toolkits                NA
mpmath                      1.3.0
msgpack                     1.1.0
mudata                      0.3.2
multidict                   6.4.4
multipledispatch            0.6.0
natsort                     8.4.0
nbformat                    5.10.4
numba                       0.60.0
numpyro                     0.15.0
omegaconf                   2.3.0
opt_einsum                  3.4.0
optax                       0.2.1
optuna                      2.10.1
ordered_set                 4.1.0
overrides                   NA
packaging                   24.2
pandas                      2.2.3
parso                       0.8.4
pexpect                     4.9.0
pkg_resources               NA
platformdirs                4.3.8
pluggy                      1.6.0
prometheus_client           NA
prompt_toolkit              3.0.51
propcache                   0.3.1
psutil                      6.1.1
ptyprocess                  0.7.0
pure_eval                   0.2.3
py                          NA
pyarrow                     19.0.1
pydantic                    2.1.1
pydantic_core               2.4.0
pydev_ipython               NA
pydevconsole                NA
pydevd                      3.2.3
pydevd_file_utils           NA
pydevd_plugins              NA
pydevd_tracing              NA
pygments                    2.19.1
pynvml                      NA
pyparsing                   3.2.3
pyro                        1.9.1
pytest                      8.4.1
python_multipart            0.0.20
pythonjsonlogger            NA
pytz                        2025.2
rapids_dask_dependency      NA
referencing                 NA
requests                    2.32.3
rfc3339_validator           0.1.4
rfc3986_validator           0.1.1
rich                        NA
rpds                        NA
send2trash                  NA
sentry_sdk                  2.29.1
setuptools                  80.8.0
six                         1.17.0
sklearn                     1.6.1
smmap                       5.0.2
sniffio                     1.3.1
soupsieve                   2.5
sparse                      0.17.0
sqlalchemy                  2.0.41
stack_data                  0.6.3
starlette                   0.46.2
sympy                       1.14.0
tblib                       3.1.0
texttable                   1.7.0
threadpoolctl               3.6.0
tlz                         1.0.0
tomllib                     NA
toolz                       1.0.0
torchgen                    NA
torchmetrics                1.7.2
torchvision                 0.18.1+cu121
tornado                     6.5.1
tqdm                        4.67.1
traitlets                   5.14.3
tree                        0.1.9
typing_extensions           NA
uri_template                NA
urllib3                     1.26.20
uvicorn                     0.34.2
vscode                      NA
wandb                       0.19.11
wandb_gql                   NA
wandb_graphql               1.1
wandb_promise               2.3
wandb_watchdog              NA
wcwidth                     0.2.13
webcolors                   24.6.0
websocket                   1.8.0
websockets                  12.0
wrapt                       1.17.2
xarray                      2025.4.0
yarl                        1.20.0
zmq                         26.4.0
zoneinfo                    NA
-----
IPython             9.2.0
jupyter_client      8.6.2
jupyter_core        5.8.1
jupyterlab          4.4.5
notebook            7.4.4
-----
Python 3.12.10 | packaged by conda-forge | (main, Apr 10 2025, 22:21:13) [GCC 13.3.0]
Linux-5.14.0-570.25.1.el9_6.x86_64-x86_64-with-glibc2.34

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions