Skip to content

Commit a4143f5

Browse files
committed
added some fixes based on custom data loader test
1 parent 17282cd commit a4143f5

File tree

4 files changed

+268
-139
lines changed

4 files changed

+268
-139
lines changed

src/scvi/model/base/_archesmixin.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ def load_query_data(
9797
validate_single_device=True,
9898
)
9999

100-
attr_dict, var_names, load_state_dict = _get_loaded_data(reference_model, device=device)
100+
attr_dict, var_names, load_state_dict = _get_loaded_data(
101+
reference_model, device=device, adata=adata
102+
)
101103

102104
if adata is not None:
103105
if isinstance(adata, MuData):
@@ -216,7 +218,7 @@ def prepare_query_anndata(
216218
Query adata ready to use in `load_query_data` unless `return_reference_var_names`
217219
in which case a pd.Index of reference var names is returned.
218220
"""
219-
_, var_names, _ = _get_loaded_data(reference_model, device="cpu")
221+
_, var_names, _ = _get_loaded_data(reference_model, device="cpu", adata=adata)
220222
var_names = pd.Index(var_names)
221223

222224
if return_reference_var_names:
@@ -364,15 +366,19 @@ def requires_grad(key):
364366
par.requires_grad = False
365367

366368

367-
def _get_loaded_data(reference_model, device=None):
369+
def _get_loaded_data(reference_model, device=None, adata=None):
368370
if isinstance(reference_model, str):
369371
attr_dict, var_names, load_state_dict, _ = _load_saved_files(
370372
reference_model, load_adata=False, map_location=device
371373
)
372374
else:
373375
attr_dict = reference_model._get_user_attributes()
374376
attr_dict = {a[0]: a[1] for a in attr_dict if a[0][-1] == "_"}
375-
var_names = _get_var_names(reference_model.adata)
377+
var_names = (
378+
_get_var_names(reference_model.adata)
379+
if attr_dict["registry_"]["setup_method_name"] != "setup_datamodule"
380+
else _get_var_names(adata)
381+
)
376382
load_state_dict = deepcopy(reference_model.module.state_dict())
377383

378384
return attr_dict, var_names, load_state_dict

tests/dataloaders/test_custom_dataloader.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import os
4+
from pprint import pprint
45

56
import numpy as np
67
import scanpy as sc
@@ -41,6 +42,11 @@
4142
# Loading the model (just as a compariosn)
4243
model_orig_loaded = scvi.model.SCVI.load(model_dir, adata=adata)
4344

45+
# when loading from disk
46+
scvi.model.SCVI.prepare_query_anndata(adata, model_dir)
47+
# O
48+
scvi.model.SCVI.prepare_query_anndata(adata, model_orig_loaded)
49+
4450
# Obtaining model outputs
4551
SCVI_LATENT_KEY = "X_scVI"
4652
latent = model_orig.get_latent_representation()
@@ -53,6 +59,8 @@
5359
# adata_manager.get_state_registry(SCVI.REGISTRY_KEYS.X_KEY).to_dict()
5460
adata_manager.registry[_constants._FIELD_REGISTRIES_KEY]
5561

62+
pprint(adata_manager.registry)
63+
5664
# Plot UMAP and save the figure for later check
5765
sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi")
5866
sc.tl.umap(adata, neighbors_key="scvi")

0 commit comments

Comments
 (0)