@@ -97,7 +97,9 @@ def load_query_data(
9797 validate_single_device = True ,
9898 )
9999
100- attr_dict , var_names , load_state_dict = _get_loaded_data (reference_model , device = device )
100+ attr_dict , var_names , load_state_dict = _get_loaded_data (
101+ reference_model , device = device , adata = adata
102+ )
101103
102104 if adata is not None :
103105 if isinstance (adata , MuData ):
@@ -216,7 +218,7 @@ def prepare_query_anndata(
216218 Query adata ready to use in `load_query_data` unless `return_reference_var_names`
217219 in which case a pd.Index of reference var names is returned.
218220 """
219- _ , var_names , _ = _get_loaded_data (reference_model , device = "cpu" )
221+ _ , var_names , _ = _get_loaded_data (reference_model , device = "cpu" , adata = adata )
220222 var_names = pd .Index (var_names )
221223
222224 if return_reference_var_names :
@@ -364,15 +366,19 @@ def requires_grad(key):
364366 par .requires_grad = False
365367
366368
367- def _get_loaded_data (reference_model , device = None ):
369+ def _get_loaded_data (reference_model , device = None , adata = None ):
368370 if isinstance (reference_model , str ):
369371 attr_dict , var_names , load_state_dict , _ = _load_saved_files (
370372 reference_model , load_adata = False , map_location = device
371373 )
372374 else :
373375 attr_dict = reference_model ._get_user_attributes ()
374376 attr_dict = {a [0 ]: a [1 ] for a in attr_dict if a [0 ][- 1 ] == "_" }
375- var_names = _get_var_names (reference_model .adata )
377+ var_names = (
378+ _get_var_names (reference_model .adata )
379+ if attr_dict ["registry_" ]["setup_method_name" ] != "setup_datamodule"
380+ else _get_var_names (adata )
381+ )
376382 load_state_dict = deepcopy (reference_model .module .state_dict ())
377383
378384 return attr_dict , var_names , load_state_dict
0 commit comments