Skip to content

Commit 17282cd

Browse files
committed
Fixed attr_dict
1 parent 14f343d commit 17282cd

File tree

1 file changed

+37
-29
lines changed

1 file changed

+37
-29
lines changed

src/scvi/model/base/_archesmixin.py

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pandas as pd
99
import torch
1010
from anndata import AnnData
11+
from lightning import LightningDataModule
1112
from mudata import MuData
1213
from scipy.sparse import csr_matrix
1314

@@ -39,8 +40,9 @@ class ArchesMixin:
3940
@devices_dsp.dedent
4041
def load_query_data(
4142
cls,
42-
adata: AnnOrMuData,
43-
reference_model: Union[str, BaseModelClass],
43+
adata: None | AnnOrMuData = None,
44+
reference_model: Union[str, BaseModelClass] = None,
45+
datamodule: None | LightningDataModule = None,
4446
inplace_subset_query_vars: bool = False,
4547
accelerator: str = "auto",
4648
device: Union[int, str] = "auto",
@@ -83,6 +85,11 @@ def load_query_data(
8385
freeze_classifier
8486
Whether to freeze classifier completely. Only applies to `SCANVI`.
8587
"""
88+
if reference_model is None:
89+
raise ValueError("Please provide a reference model as string or loaded model.")
90+
if adata is None and datamodule is None:
91+
raise ValueError("Please provide either an AnnData or a datamodule.")
92+
8693
_, _, device = parse_device_args(
8794
accelerator=accelerator,
8895
devices=device,
@@ -92,44 +99,45 @@ def load_query_data(
9299

93100
attr_dict, var_names, load_state_dict = _get_loaded_data(reference_model, device=device)
94101

95-
if isinstance(adata, MuData):
96-
for modality in adata.mod:
102+
if adata is not None:
103+
if isinstance(adata, MuData):
104+
for modality in adata.mod:
105+
if inplace_subset_query_vars:
106+
logger.debug(f"Subsetting {modality} query vars to reference vars.")
107+
adata[modality]._inplace_subset_var(var_names[modality])
108+
_validate_var_names(adata[modality], var_names[modality])
109+
110+
else:
97111
if inplace_subset_query_vars:
98-
logger.debug(f"Subsetting {modality} query vars to reference vars.")
99-
adata[modality]._inplace_subset_var(var_names[modality])
100-
_validate_var_names(adata[modality], var_names[modality])
112+
logger.debug("Subsetting query vars to reference vars.")
113+
adata._inplace_subset_var(var_names)
114+
_validate_var_names(adata, var_names)
101115

102-
else:
103116
if inplace_subset_query_vars:
104117
logger.debug("Subsetting query vars to reference vars.")
105118
adata._inplace_subset_var(var_names)
106119
_validate_var_names(adata, var_names)
107120

108-
if inplace_subset_query_vars:
109-
logger.debug("Subsetting query vars to reference vars.")
110-
adata._inplace_subset_var(var_names)
111-
_validate_var_names(adata, var_names)
121+
registry = attr_dict.pop("registry_")
122+
if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__:
123+
raise ValueError("It appears you are loading a model from a different class.")
112124

113-
registry = attr_dict.pop("registry_")
114-
if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__:
115-
raise ValueError("It appears you are loading a model from a different class.")
125+
if _SETUP_ARGS_KEY not in registry:
126+
raise ValueError(
127+
"Saved model does not contain original setup inputs. "
128+
"Cannot load the original setup."
129+
)
116130

117-
if _SETUP_ARGS_KEY not in registry:
118-
raise ValueError(
119-
"Saved model does not contain original setup inputs. "
120-
"Cannot load the original setup."
131+
setup_method = getattr(cls, registry[_SETUP_METHOD_NAME])
132+
setup_method(
133+
adata,
134+
source_registry=registry,
135+
extend_categories=True,
136+
allow_missing_labels=True,
137+
**registry[_SETUP_ARGS_KEY],
121138
)
122139

123-
setup_method = getattr(cls, registry[_SETUP_METHOD_NAME])
124-
setup_method(
125-
adata,
126-
source_registry=registry,
127-
extend_categories=True,
128-
allow_missing_labels=True,
129-
**registry[_SETUP_ARGS_KEY],
130-
)
131-
132-
model = _initialize_model(cls, adata, attr_dict)
140+
model = _initialize_model(cls, adata, datamodule, attr_dict)
133141
adata_manager = model.get_anndata_manager(adata, required=True)
134142

135143
if REGISTRY_KEYS.CAT_COVS_KEY in adata_manager.data_registry:

0 commit comments

Comments
 (0)