@@ -35,6 +35,8 @@ def __init__(
3535 inducing_num : int = 512 ,
3636 normalize_spatial : bool = True ,
3737 ):
38+
39+ self .keys = keys
3840 # Source data
3941 source_adata = source_adata .copy ()
4042 source_adata .X = source_adata .X if layer == "X" else source_adata .layers [layer ]
@@ -70,7 +72,6 @@ def __init__(
7072 self .train_y = self .train_y .squeeze ()
7173
7274 self .nx = ot .backend .get_backend (self .train_x , self .train_y )
73-
7475 self .normalize_spatial = normalize_spatial
7576 if self .normalize_spatial :
7677 self .train_x = self .normalize_coords (self .train_x )
@@ -96,15 +97,14 @@ def __init__(
9697
9798 self .PCA_reduction = False
9899 self .info_keys = {"obs_keys" : obs_keys , "var_keys" : var_keys }
99- print (self .info_keys )
100100
101101 # Target data
102102 self .target_points = torch .from_numpy (target_points ).float ()
103103 self .target_points = self .target_points .cpu () if self .device == "cpu" else self .target_points .cuda ()
104104
105105 def normalize_coords (self , data : Union [np .ndarray , torch .Tensor ], given_normalize : bool = False ):
106106 if not given_normalize :
107- self .mean_data = _unsqueeze ( self .nx )( self . nx . mean (data , axis = 0 ), 0 )
107+ self .mean_data = self .nx . mean (data , axis = 0 )[ None , :]
108108 data = data - self .mean_data
109109 if not given_normalize :
110110 self .variance = self .nx .sqrt (self .nx .sum (data ** 2 ) / data .shape [0 ])
@@ -114,12 +114,16 @@ def normalize_coords(self, data: Union[np.ndarray, torch.Tensor], given_normaliz
114114 def inference (
115115 self ,
116116 training_iter : int = 50 ,
117+ verbose : bool = True ,
117118 ):
118119 self .likelihood = GaussianLikelihood ()
119120 if self .method == "SVGP" :
120121 self .GPR_model = Approx_GPModel (inducing_points = self .inducing_points )
121122 elif self .method == "ExactGP" :
122- self .GPR_model = Exact_GPModel (self .train_x , self .train_y , self .likelihood )
123+ self .GPR_models = [
124+ Exact_GPModel (self .train_x , self .train_y [:, i ], self .likelihoods [i ])
125+ for i in range (self .train_y .shape [1 ])
126+ ]
123127 # if to convert to GPU
124128 if self .device != "cpu" :
125129 self .GPR_model = self .GPR_model .cuda ()
@@ -134,6 +138,8 @@ def inference(
134138 method = self .method ,
135139 N = self .N ,
136140 device = self .device ,
141+ verbose = verbose ,
142+ keys = self .keys ,
137143 )
138144
139145 self .GPR_model .eval ()
@@ -181,6 +187,7 @@ def gp_interpolation(
181187 batch_size : int = 1024 ,
182188 shuffle : bool = True ,
183189 inducing_num : int = 512 ,
190+ verbose : bool = True ,
184191) -> AnnData :
185192 """
186193 Learn a continuous mapping from space to gene expression pattern with the Gaussian Process method.
@@ -197,36 +204,60 @@ def gp_interpolation(
197204 Returns:
198205 interp_adata: an anndata object that has interpolated expression.
199206 """
207+ assert keys != None , "`keys` cannot be None."
208+ keys = [keys ] if isinstance (keys , str ) else keys
209+ obs_keys = [key for key in keys if key in source_adata .obs .keys ()]
210+ var_keys = [key for key in keys if key in source_adata .var_names .tolist ()]
211+ info_keys = {"obs_keys" : obs_keys , "var_keys" : var_keys }
212+ print (info_keys )
213+ obs_data = []
214+ var_data = []
215+ if len (obs_keys ) != 0 :
216+ for key in obs_keys :
217+ GPR = Imputation_GPR (
218+ source_adata = source_adata ,
219+ target_points = target_points ,
220+ keys = [key ],
221+ spatial_key = spatial_key ,
222+ layer = layer ,
223+ device = device ,
224+ method = method ,
225+ batch_size = batch_size ,
226+ shuffle = shuffle ,
227+ inducing_num = inducing_num ,
228+ )
229+ GPR .inference (training_iter = training_iter , verbose = verbose )
200230
201- # Inference
202- GPR = Imputation_GPR (
203- source_adata = source_adata ,
204- target_points = target_points ,
205- keys = keys ,
206- spatial_key = spatial_key ,
207- layer = layer ,
208- device = device ,
209- method = method ,
210- batch_size = batch_size ,
211- shuffle = shuffle ,
212- inducing_num = inducing_num ,
213- )
214- GPR .inference (training_iter = training_iter )
231+ # Interpolation
232+ target_info_data = GPR .interpolate (use_chunk = True )
233+ obs_data .append (target_info_data [:, None ])
234+ if len (var_keys ) != 0 :
235+ for key in var_keys :
236+ GPR = Imputation_GPR (
237+ source_adata = source_adata ,
238+ target_points = target_points ,
239+ keys = [key ],
240+ spatial_key = spatial_key ,
241+ layer = layer ,
242+ device = device ,
243+ method = method ,
244+ batch_size = batch_size ,
245+ shuffle = shuffle ,
246+ inducing_num = inducing_num ,
247+ )
248+ GPR .inference (training_iter = training_iter , verbose = verbose )
249+
250+ # Interpolation
251+ target_info_data = GPR .interpolate (use_chunk = True )
252+ var_data .append (target_info_data [:, None ])
215253
216- # Interpolation
217- target_info_data = GPR .interpolate (use_chunk = True )
218- target_info_data = target_info_data [:, None ]
219254 # Output interpolated anndata
220255 lm .main_info ("Creating an adata object with the interpolated expression..." )
221-
222- obs_keys = GPR .info_keys ["obs_keys" ]
223256 if len (obs_keys ) != 0 :
224- obs_data = target_info_data [:, : len ( obs_keys )]
257+ obs_data = np . concatenate ( obs_data , axis = 1 )
225258 obs_data = pd .DataFrame (obs_data , columns = obs_keys )
226-
227- var_keys = GPR .info_keys ["var_keys" ]
228259 if len (var_keys ) != 0 :
229- X = target_info_data [:, len ( obs_keys ) :]
260+ X = np . concatenate ( var_data , axis = 1 )
230261 var_data = pd .DataFrame (index = var_keys )
231262
232263 interp_adata = AnnData (
0 commit comments