Skip to content

Commit b4f8590

Browse files
ori-kron-wispre-commit-ci[bot]ethanweinbergerEthan Weinbergergithub-actions[bot]
authored
ci: Ori 1.3.x resolvi fixes backport (scverse#3313)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ethan Weinberger <[email protected]> Co-authored-by: Ethan Weinberger <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: ori-kron-wis <[email protected]> Co-authored-by: Justin Hong <[email protected]> Co-authored-by: Can Ergen <[email protected]>
1 parent 64af263 commit b4f8590

File tree

7 files changed

+71
-31
lines changed

7 files changed

+71
-31
lines changed

CHANGELOG.md

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,13 @@ to [Semantic Versioning]. Full commit history is available in the
2323
- Add consideration for missing monitor set during early stopping. {pr}`3226`.
2424
- Fix bug in SysVI get_normalized_expression function. {pr}`3255`.
2525
- Add support for IntegratedGradients for multimodal models. {pr}`3264`.
26+
- Fix bug in resolVI get_normalized expression function. {pr}`3308`.
27+
- Fix bug in resolVI gene-assay dispersion. {pr}`3308`.
2628

2729
#### Changed
2830

2931
- Updated Scvi-Tools AWS hub to Weizmann instead of Berkeley. {pr}`3246`.
32+
- Updated resolVI to use rapids-singlecell. {pr}`3308`.
3033

3134
#### Removed
3235

@@ -53,7 +56,7 @@ to [Semantic Versioning]. Full commit history is available in the
5356
- Add scib-metrics support for {class}`scvi.autotune.AutotuneExperiment` and
5457
{class}`scvi.train._callbacks.ScibCallback` for autotune for scib metrics {pr}`3168`.
5558
- Add Support of dask arrays in AnnTorchDataset. {pr}`3193`.
56-
- Add a common use cases section in the docs user guide. {pr}`3200`.
59+
- Add a {doc}`/user_guide/use_case` section in the docs, {pr}`3200`.
5760
- Add {class}`scvi.external.SysVI` for cycle consistency loss and VampPrior {pr}`3195`.
5861

5962
#### Fixed
@@ -111,7 +114,7 @@ to [Semantic Versioning]. Full commit history is available in the
111114
- Added adaptive handling for last training minibatch of 1-2 cells in case of
112115
`datasplitter_kwargs={"drop_last": False}` and `train_size = None` by moving them into
113116
validation set, if available. {pr}`3036`.
114-
- Add `batch_key` and `labels_key` to {meth}`scvi.external.SCAR.setup_anndata`. {pr}`3045`.
117+
- Add `batch_key` and `labels_key` to `scvi.external.SCAR.setup_anndata`. {pr}`3045`.
115118
- Implemented variance of ZINB distribution. {pr}`3044`.
116119
- Support for minified mode while retaining counts to skip the encoder.
117120
- New Trainingplan argument `update_only_decoder` to use stored latent codes and skip training of
@@ -125,7 +128,7 @@ to [Semantic Versioning]. Full commit history is available in the
125128
- Breaking Change: Fix `get_outlier_cell_sample_pairs` function in {class}`scvi.external.MRVI`
126129
to correctly compute the maxmimum log-density across in-sample cells rather than the
127130
aggregated posterior log-density {pr}`3007`.
128-
- Fix references to `scvi.external` in {meth}`scvi.external.SCAR.setup_anndata`.
131+
- Fix references to `scvi.external` in `scvi.external.SCAR.setup_anndata`.
129132
- Fix gimVI to append mini batches first into CPU during get_imputed and get_latent operations {pr}`3058`.
130133

131134
#### Changed
@@ -137,9 +140,9 @@ to [Semantic Versioning]. Full commit history is available in the
137140
#### Added
138141

139142
- Add support for Python 3.12 {pr}`2966`.
140-
- Add support for categorial covariates in scArches in {class}`scvi.model.base.ArchesMixin` {pr}`2936`.
143+
- Add support for categorial covariates in scArches in `scvi.model.archesmixin` {pr}`2936`.
141144
- Add assertion error in cellAssign for checking duplicates in celltype markers {pr}`2951`.
142-
- Add {meth}`scvi.external.POISSONVI.get_region_factors` {pr}`2940`.
145+
- Add `scvi.external.poissonvi.get_region_factors` {pr}`2940`.
143146
- {attr}`scvi.settings.dl_persistent_workers` allows using persistent workers in
144147
{class}`scvi.dataloaders.AnnDataLoader` {pr}`2924`.
145148
- Add option for using external indexes in data splitting classes that are under `scvi.dataloaders`

src/scvi/external/resolvi/_model.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import importlib.util
43
import logging
54
from functools import partial
65
from typing import TYPE_CHECKING
@@ -343,7 +342,11 @@ def setup_anndata(
343342
cls.register_manager(adata_manager)
344343

345344
@staticmethod
346-
def _prepare_data(adata, n_neighbors=10, spatial_rep="X_spatial", batch_key=None, **kwargs):
345+
def _prepare_data(
346+
adata, n_neighbors=10, spatial_rep="X_spatial", batch_key=None, slice_key=None, **kwargs
347+
):
348+
if slice_key is not None:
349+
batch_key = slice_key
347350
try:
348351
import scanpy
349352
from sklearn.neighbors._base import _kneighbors_from_graph
@@ -365,13 +368,15 @@ def _prepare_data(adata, n_neighbors=10, spatial_rep="X_spatial", batch_key=None
365368

366369
for index in indices:
367370
sub_data = adata[index].copy()
368-
if importlib.util.find_spec("cuml") is not None:
369-
method = "rapids"
370-
else:
371-
method = "umap"
372-
scanpy.pp.neighbors(
373-
sub_data, n_neighbors=n_neighbors + 5, use_rep=spatial_rep, method=method
374-
)
371+
try:
372+
import rapids_singlecell
373+
374+
print("RAPIDS SingleCell is installed and can be imported")
375+
rapids_singlecell.pp.neighbors(
376+
sub_data, n_neighbors=n_neighbors + 5, use_rep=spatial_rep
377+
)
378+
except ImportError:
379+
scanpy.pp.neighbors(sub_data, n_neighbors=n_neighbors + 5, use_rep=spatial_rep)
375380
distances = sub_data.obsp["distances"] ** 2
376381

377382
distance_neighbor[index, :], index_neighbor_batch = _kneighbors_from_graph(

src/scvi/external/resolvi/_module.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,7 @@ def __init__(
163163
init_px_r = torch.full([n_input, n_batch], 0.01)
164164
else:
165165
raise ValueError(
166-
"dispersion must be one of ['gene', 'gene-batch', 'gene-label'], but input was "
167-
"{}.format(self.dispersion)"
166+
f"dispersion must be one of ['gene', 'gene-batch'], but input was {dispersion}."
168167
)
169168
self.register_buffer("px_r", init_px_r)
170169

@@ -751,8 +750,7 @@ def __init__(
751750
init_px_r = torch.full([n_input, n_batch], 0.01)
752751
else:
753752
raise ValueError(
754-
"dispersion must be one of ['gene', 'gene-batch', 'gene-label'], but input was "
755-
"{}.format(dispersion)"
753+
f"dispersion must be one of ['gene', 'gene-batch'], but input was {dispersion}."
756754
)
757755
self.register_buffer("px_r", init_px_r)
758756
self.register_buffer("per_neighbor_diffusion_init", torch.zeros([n_obs, n_neighbors]))
@@ -868,7 +866,10 @@ def forward( # not used arguments to have same set of arguments in model and gu
868866

869867
if self.dispersion == "gene-batch":
870868
px_r_inv = F.linear(
871-
torch.nn.functional.one_hot(batch_index.flatten(), self.n_batch), px_r_mle
869+
torch.nn.functional.one_hot(batch_index.flatten(), self.n_batch).to(
870+
px_r_mle.dtype
871+
),
872+
px_r_mle,
872873
)
873874
elif self.dispersion == "gene":
874875
px_r_inv = px_r_mle

src/scvi/external/resolvi/_utils.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def get_normalized_expression(
229229
library_size
230230
Scale the expression frequencies to a common library size.
231231
This allows gene expression levels to be interpreted on a common scale of relevant
232-
magnitude. If set to `"latent"`, use the latent library size.
232+
magnitude.
233233
n_samples
234234
Number of posterior samples to use for estimation.
235235
n_samples_overall
@@ -301,32 +301,28 @@ def get_normalized_expression(
301301
kwargs["batch_index"],
302302
*categorical_input,
303303
)
304-
z = torch.distributions.Normal(qz_m, qz_v.sqrt()).sample(
305-
[
306-
n_samples,
307-
]
308-
)
304+
z = torch.distributions.Normal(qz_m, qz_v.sqrt()).sample([n_samples])
309305

310306
if kwargs["cat_covs"] is not None:
311307
categorical_input = list(torch.split(kwargs["cat_covs"], 1, dim=1))
312308
else:
313309
categorical_input = ()
314310
if batch is not None:
315-
batch = torch.full_like(kwargs["batch"], batch)
311+
batch = torch.full_like(kwargs["batch_index"], batch)
316312
else:
317313
batch = kwargs["batch_index"]
318314

319315
px_scale, _, px_rate, _ = self.module.model.decoder(
320316
self.module.model.dispersion, z, kwargs["library"], batch, *categorical_input
321317
)
322318
if library_size is not None:
323-
exp_ = library_size * px_scale.reshape(-1, px_scale.shape[-1])
319+
exp_ = library_size * px_scale
324320
else:
325-
exp_ = px_rate.reshape(-1, px_scale.shape[-1])
321+
exp_ = px_rate
326322

327323
exp_ = exp_[..., gene_mask]
328324
per_batch_exprs.append(exp_[None].cpu())
329-
per_batch_exprs = torch.cat(per_batch_exprs, dim=0).numpy()
325+
per_batch_exprs = torch.cat(per_batch_exprs, dim=0).mean(0).numpy()
330326
exprs.append(per_batch_exprs)
331327

332328
exprs = np.concatenate(exprs, axis=1)

src/scvi/train/_trainer.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,4 +211,13 @@ def fit(self, *args, **kwargs):
211211
category=UserWarning,
212212
message="`LightningModule.configure_optimizers` returned `None`",
213213
)
214-
super().fit(*args, **kwargs)
214+
try:
215+
super().fit(*args, **kwargs)
216+
except NameError:
217+
import gc
218+
219+
gc.collect()
220+
import torch
221+
222+
if torch.cuda.is_available():
223+
torch.cuda.empty_cache()

src/scvi/train/_trainrunner.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,16 @@ def __call__(self):
109109
if hasattr(self.data_splitter, "n_val"):
110110
self.training_plan.n_obs_validation = self.data_splitter.n_val
111111

112-
self.trainer.fit(self.training_plan, self.data_splitter)
112+
try:
113+
self.trainer.fit(self.training_plan, self.data_splitter)
114+
except NameError:
115+
import gc
116+
117+
gc.collect()
118+
import torch
119+
120+
if torch.cuda.is_available():
121+
torch.cuda.empty_cache()
113122
self._update_history()
114123

115124
# data splitter only gets these attrs after fit

tests/external/resolvi/test_resolvi.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ def test_resolvi_train(adata):
2323
model.train(
2424
max_epochs=2,
2525
)
26+
model = RESOLVI(adata, dispersion="gene-batch")
27+
model.train(
28+
max_epochs=2,
29+
)
2630

2731

2832
def test_resolvi_save_load(adata):
@@ -52,8 +56,21 @@ def test_resolvi_downstream(adata):
5256
)
5357
latent = model.get_latent_representation()
5458
assert latent.shape == (adata.n_obs, model.module.n_latent)
59+
counts = model.get_normalized_expression(n_samples=31, library_size=10000)
60+
counts = model.get_normalized_expression_importance(n_samples=30, library_size=10000)
61+
print("FFFFFF", counts.shape)
5562
model.differential_expression(groupby="labels")
5663
model.differential_expression(groupby="labels", weights="importance")
64+
model.sample_posterior(
65+
model=model.module.model_residuals,
66+
num_samples=30,
67+
return_samples=False,
68+
return_sites=None,
69+
batch_size=1000,
70+
)
71+
model.sample_posterior(
72+
model=model.module.model_residuals, num_samples=30, return_samples=False, batch_size=1000
73+
)
5774
model_query = model.load_query_data(reference_model=model, adata=adata)
5875
model_query = model.load_query_data(reference_model="test_resolvi", adata=adata)
5976
model_query.train(

0 commit comments

Comments
 (0)