88import pandas as pd
99import torch
1010from anndata import AnnData
11+ from lightning import LightningDataModule
1112from mudata import MuData
1213from scipy .sparse import csr_matrix
1314
@@ -39,8 +40,9 @@ class ArchesMixin:
3940 @devices_dsp .dedent
4041 def load_query_data (
4142 cls ,
42- adata : AnnOrMuData ,
43- reference_model : Union [str , BaseModelClass ],
43+ adata : None | AnnOrMuData = None ,
44+ reference_model : Union [str , BaseModelClass ] = None ,
45+ datamodule : None | LightningDataModule = None ,
4446 inplace_subset_query_vars : bool = False ,
4547 accelerator : str = "auto" ,
4648 device : Union [int , str ] = "auto" ,
@@ -83,6 +85,11 @@ def load_query_data(
8385 freeze_classifier
8486 Whether to freeze classifier completely. Only applies to `SCANVI`.
8587 """
88+ if reference_model is None :
89+ raise ValueError ("Please provide a reference model as string or loaded model." )
90+ if adata is None and datamodule is None :
91+ raise ValueError ("Please provide either an AnnData or a datamodule." )
92+
8693 _ , _ , device = parse_device_args (
8794 accelerator = accelerator ,
8895 devices = device ,
@@ -92,44 +99,45 @@ def load_query_data(
9299
93100 attr_dict , var_names , load_state_dict = _get_loaded_data (reference_model , device = device )
94101
95- if isinstance (adata , MuData ):
96- for modality in adata .mod :
102+ if adata is not None :
103+ if isinstance (adata , MuData ):
104+ for modality in adata .mod :
105+ if inplace_subset_query_vars :
106+ logger .debug (f"Subsetting { modality } query vars to reference vars." )
107+ adata [modality ]._inplace_subset_var (var_names [modality ])
108+ _validate_var_names (adata [modality ], var_names [modality ])
109+
110+ else :
97111 if inplace_subset_query_vars :
98- logger .debug (f "Subsetting { modality } query vars to reference vars." )
99- adata [ modality ] ._inplace_subset_var (var_names [ modality ] )
100- _validate_var_names (adata [ modality ] , var_names [ modality ] )
112+ logger .debug ("Subsetting query vars to reference vars." )
113+ adata ._inplace_subset_var (var_names )
114+ _validate_var_names (adata , var_names )
101115
102- else :
103116 if inplace_subset_query_vars :
104117 logger .debug ("Subsetting query vars to reference vars." )
105118 adata ._inplace_subset_var (var_names )
106119 _validate_var_names (adata , var_names )
107120
108- if inplace_subset_query_vars :
109- logger .debug ("Subsetting query vars to reference vars." )
110- adata ._inplace_subset_var (var_names )
111- _validate_var_names (adata , var_names )
121+ registry = attr_dict .pop ("registry_" )
122+ if _MODEL_NAME_KEY in registry and registry [_MODEL_NAME_KEY ] != cls .__name__ :
123+ raise ValueError ("It appears you are loading a model from a different class." )
112124
113- registry = attr_dict .pop ("registry_" )
114- if _MODEL_NAME_KEY in registry and registry [_MODEL_NAME_KEY ] != cls .__name__ :
115- raise ValueError ("It appears you are loading a model from a different class." )
125+ if _SETUP_ARGS_KEY not in registry :
126+ raise ValueError (
127+ "Saved model does not contain original setup inputs. "
128+ "Cannot load the original setup."
129+ )
116130
117- if _SETUP_ARGS_KEY not in registry :
118- raise ValueError (
119- "Saved model does not contain original setup inputs. "
120- "Cannot load the original setup."
131+ setup_method = getattr (cls , registry [_SETUP_METHOD_NAME ])
132+ setup_method (
133+ adata ,
134+ source_registry = registry ,
135+ extend_categories = True ,
136+ allow_missing_labels = True ,
137+ ** registry [_SETUP_ARGS_KEY ],
121138 )
122139
123- setup_method = getattr (cls , registry [_SETUP_METHOD_NAME ])
124- setup_method (
125- adata ,
126- source_registry = registry ,
127- extend_categories = True ,
128- allow_missing_labels = True ,
129- ** registry [_SETUP_ARGS_KEY ],
130- )
131-
132- model = _initialize_model (cls , adata , attr_dict )
140+ model = _initialize_model (cls , adata , datamodule , attr_dict )
133141 adata_manager = model .get_anndata_manager (adata , required = True )
134142
135143 if REGISTRY_KEYS .CAT_COVS_KEY in adata_manager .data_registry :
0 commit comments