-
Couldn't load subscription status.
- Fork 5
Open
Labels
bug: Something isn't working
Description
Report
Hey!
We have observed that `decode_latent_samples` returns gene expression values that are not corrected for library size. Here is the code snippet we used for testing:
import sys
import yaml
from functools import partial
import scanpy as sc
import scipy as sp
import anndata as ad
import numpy as np
import torch
import drvi
from typing import Any
import scvi
from scvi import REGISTRY_KEYS
from drvi.nn_modules.noise_model import library_size_correction
# Read the dataset and expose the sequencing run as the "batch" column.
adata_gex = sc.read_h5ad("barcodedData/all_dc_dataset_ready_to_integr.h5ad")
adata_gex.obs["batch"] = adata_gex.obs["Seq.Run"].tolist()

# Drop every cell whose donor is "A" (cells with a missing donor are kept).
is_donor_a = adata_gex.obs["donor"] == "A"
adata_gex = adata_gex[~is_donor_a, :].copy()

# Register "" as a category first so missing donors can be filled with it.
adata_gex.obs["donor"] = adata_gex.obs["donor"].cat.add_categories("").fillna("")
# Custom decoding function
def iterate_on_decoded_latent_samples(
    vae,
    z: torch.Tensor,
    lib: np.ndarray | None = None,
    cat_values: np.ndarray | None = None,
    cont_values: np.ndarray | None = None,
    batch_size: int = scvi.settings.batch_size,
    map_cat_values: bool = False,
) -> torch.Tensor:
    """Decode latent samples in minibatches and collect ``px.log_m``.

    The latent samples are pushed through ``vae.module.generative`` one
    minibatch at a time with gradients disabled; the ``log_m`` attribute of
    each batch's ``"px"`` output is stored and the batches are concatenated
    along the first axis.

    Parameters
    ----------
    vae
        Trained model exposing ``module`` (with ``eval``,
        ``_get_generative_input`` and ``generative``) and ``adata_manager``.
    z
        Latent samples with shape (n_samples, n_latent).
    lib
        Library size array with shape (n_samples,).
        If None, 1e4 is used for every sample.
    cat_values
        Categorical covariates with shape (n_samples, n_cat_covs).
        A 1-D array is treated as a single covariate column when
        ``map_cat_values`` is True.
    cont_values
        Continuous covariates with shape (n_samples, n_cont_covs).
    batch_size
        Minibatch size used when feeding samples to the decoder.
    map_cat_values
        Whether to map categorical covariate labels to integer codes using
        the mappings registered in the model's AnnData manager.

    Returns
    -------
    torch.Tensor
        ``gen_output["px"].log_m`` for all samples, concatenated along dim 0.
        NOTE(review): presumably the library-size-corrected log mean —
        confirm against the drvi noise model.

    Raises
    ------
    ValueError
        If ``map_cat_values`` is True and ``cat_values`` has a different
        number of columns than the registered categorical covariates, or
        contains a label that is not registered for its covariate.
    """
    vae.module.eval()

    if cat_values is not None and map_cat_values:
        if cat_values.ndim == 1:  # Tolerate a single covariate passed as 1-D.
            cat_values = cat_values.reshape(-1, 1)
        mappings = vae.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)["mappings"]
        if cat_values.shape[1] != len(mappings):
            raise ValueError(
                f"cat_values has {cat_values.shape[1]} column(s) but the model "
                f"registers {len(mappings)} categorical covariate(s)."
            )
        # Write codes straight into an integer buffer: a buffer with the
        # labels' string dtype would round-trip the codes through text.
        mapped_values = np.zeros(cat_values.shape, dtype=np.int32)
        for col, (label, map_keys) in enumerate(mappings.items()):
            cat_mapping = {key: code for code, key in enumerate(map_keys)}
            column = cat_values[:, col]
            unknown = sorted({str(value) for value in column if value not in cat_mapping})
            if unknown:
                raise ValueError(f"Unknown categories for covariate {label!r}: {unknown}")
            mapped_values[:, col] = [cat_mapping[value] for value in column]
        cat_values = mapped_values

    store: list[torch.Tensor] = []
    n_samples = z.shape[0]
    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            idx = np.arange(start, min(start + batch_size, n_samples))
            # Index-array selection already yields a fresh copy, so
            # as_tensor avoids a second copy (and the copy-construct
            # warning torch.tensor emits when handed a tensor).
            z_tensor = torch.as_tensor(z[idx])
            if lib is None:
                lib_tensor = torch.full((idx.shape[0],), 1e4)
            else:
                lib_tensor = torch.as_tensor(lib[idx])
            cat_tensor = torch.as_tensor(cat_values[idx]) if cat_values is not None else None
            cont_tensor = torch.as_tensor(cont_values[idx]) if cont_values is not None else None

            gen_input = vae.module._get_generative_input(
                tensors={
                    REGISTRY_KEYS.BATCH_KEY: None,
                    REGISTRY_KEYS.LABELS_KEY: None,
                    REGISTRY_KEYS.CONT_COVS_KEY: cont_tensor,
                    REGISTRY_KEYS.CAT_COVS_KEY: cat_tensor,
                },
                inference_outputs={
                    "z": z_tensor,
                    "library": lib_tensor,
                    "gene_likelihood_additional_info": {},
                },
            )
            gen_output = vae.module.generative(**gen_input)
            store.append(gen_output["px"].log_m)

    result = torch.cat(store, dim=0)
    # Drop the per-batch references before asking CUDA to release cached blocks.
    del store, gen_output
    torch.cuda.empty_cache()
    return result
# Restore the trained model from disk.
vae = drvi.model.DRVI.load("./models/", adata_gex, prefix="all_dc_128_no_ha_")

# Draw random latent samples and cast them to float32 once for both decoders.
z = np.random.randn(100, 128)
n_cells = z.shape[0]
z_float = torch.from_numpy(z).to(dtype=torch.float32)

# NOTE(review): the two decoders below receive different labels in the second
# covariate column ("H.B" vs "B") — confirm this is intended; if not, the
# comparison mixes a covariate difference with the decoding difference.

# Decode with the model's built-in method.
cat_values_builtin = np.vstack([np.array(["10x"] * n_cells), np.array(["H.B"] * n_cells)]).T
out_1 = vae.decode_latent_samples(
    z_float,
    cat_values=cat_values_builtin,
    cont_values=None,
    map_cat_values=True,
)

# Decode with the custom function defined in this script.
cat_values_custom = np.vstack([np.array(["10x"] * n_cells), np.array(["B"] * n_cells)]).T
out_2 = iterate_on_decoded_latent_samples(
    vae,
    z_float,
    cat_values=cat_values_custom,
    cont_values=None,
    map_cat_values=True,
)

# Report whether the two decoders agree element-wise.
outputs_match = torch.allclose(torch.from_numpy(out_1).cpu(), out_2.cpu())
print(outputs_match)
>> False
Version information
-----
anndata 0.10.9
drvi 0.1.8
numpy 1.26.4
scanpy 1.9.5
scipy 1.12.0
scvi 1.0.4
session_info v1.0.1
torch 2.3.1+cu121
yaml 6.0.2
-----
PIL 11.2.1
absl 2.3.0
aiohappyeyeballs 2.6.1
aiohttp 3.12.6
aiosignal 1.3.2
alembic 1.16.2
annotated_types 0.7.0
antlr4 NA
anyio NA
arrow 1.3.0
asttokens NA
attr 25.3.0
attrs 25.3.0
babel 2.15.0
backoff 2.2.1
bs4 4.12.3
certifi 2025.04.26
charset_normalizer 3.4.2
chex 0.1.7
click 8.2.1
cloudpickle 3.1.1
cmaes 0.11.1
colorlog NA
comm 0.2.2
croniter NA
cuda 12.9.0
cupy 13.4.1
cupy_backends NA
cupyx NA
cycler 0.12.1
cython_runtime NA
dask 2025.2.0
dateutil 2.9.0.post0
debugpy 1.8.14
decorator 5.2.1
deepdiff 7.0.1
defusedxml 0.7.1
dockerpycreds NA
docrep 0.3.2
etils 1.12.2
executing 2.2.0
fastapi 0.115.12
fastjsonschema NA
fastrlock 0.8.3
flax 0.8.4
fqdn NA
frozenlist 1.6.0
fsspec 2024.12.0
git 3.1.44
gitdb 4.0.12
google NA
greenlet 3.2.3
h11 0.14.0
h5py 3.13.0
httpcore 1.0.5
idna 3.10
igraph 0.11.8
importlib_resources NA
iniconfig NA
ipykernel 6.29.5
ipywidgets 8.1.7
isoduration NA
jaraco NA
jax 0.4.20
jaxlib 0.4.20
jedi 0.19.2
jinja2 3.1.6
joblib 1.5.1
json5 0.9.25
jsonpointer 3.0.0
jsonschema 4.22.0
jsonschema_specifications NA
jupyter_events 0.10.0
jupyter_server 2.14.1
jupyterlab_server 2.27.2
kiwisolver 1.4.8
leidenalg 0.10.2
lightning 2.0.9.post0
lightning_cloud 0.5.70
lightning_utilities 0.14.3
llvmlite 0.43.0
mako 1.3.10
markupsafe 3.0.2
matplotlib 3.10.3
ml_collections 1.1.0
ml_dtypes 0.5.1
more_itertools 10.3.0
mpl_toolkits NA
mpmath 1.3.0
msgpack 1.1.0
mudata 0.3.2
multidict 6.4.4
multipledispatch 0.6.0
natsort 8.4.0
nbformat 5.10.4
numba 0.60.0
numpyro 0.15.0
omegaconf 2.3.0
opt_einsum 3.4.0
optax 0.2.1
optuna 2.10.1
ordered_set 4.1.0
overrides NA
packaging 24.2
pandas 2.2.3
parso 0.8.4
pexpect 4.9.0
pkg_resources NA
platformdirs 4.3.8
pluggy 1.6.0
prometheus_client NA
prompt_toolkit 3.0.51
propcache 0.3.1
psutil 6.1.1
ptyprocess 0.7.0
pure_eval 0.2.3
py NA
pyarrow 19.0.1
pydantic 2.1.1
pydantic_core 2.4.0
pydev_ipython NA
pydevconsole NA
pydevd 3.2.3
pydevd_file_utils NA
pydevd_plugins NA
pydevd_tracing NA
pygments 2.19.1
pynvml NA
pyparsing 3.2.3
pyro 1.9.1
pytest 8.4.1
python_multipart 0.0.20
pythonjsonlogger NA
pytz 2025.2
rapids_dask_dependency NA
referencing NA
requests 2.32.3
rfc3339_validator 0.1.4
rfc3986_validator 0.1.1
rich NA
rpds NA
send2trash NA
sentry_sdk 2.29.1
setuptools 80.8.0
six 1.17.0
sklearn 1.6.1
smmap 5.0.2
sniffio 1.3.1
soupsieve 2.5
sparse 0.17.0
sqlalchemy 2.0.41
stack_data 0.6.3
starlette 0.46.2
sympy 1.14.0
tblib 3.1.0
texttable 1.7.0
threadpoolctl 3.6.0
tlz 1.0.0
tomllib NA
toolz 1.0.0
torchgen NA
torchmetrics 1.7.2
torchvision 0.18.1+cu121
tornado 6.5.1
tqdm 4.67.1
traitlets 5.14.3
tree 0.1.9
typing_extensions NA
uri_template NA
urllib3 1.26.20
uvicorn 0.34.2
vscode NA
wandb 0.19.11
wandb_gql NA
wandb_graphql 1.1
wandb_promise 2.3
wandb_watchdog NA
wcwidth 0.2.13
webcolors 24.6.0
websocket 1.8.0
websockets 12.0
wrapt 1.17.2
xarray 2025.4.0
yarl 1.20.0
zmq 26.4.0
zoneinfo NA
-----
IPython 9.2.0
jupyter_client 8.6.2
jupyter_core 5.8.1
jupyterlab 4.4.5
notebook 7.4.4
-----
Python 3.12.10 | packaged by conda-forge | (main, Apr 10 2025, 22:21:13) [GCC 13.3.0]
Linux-5.14.0-570.25.1.el9_6.x86_64-x86_64-with-glibc2.34
Metadata
Metadata
Assignees
Labels
bug: Something isn't working