11import random
22
3+ import networkx
34import numpy as np
45import ot
6+ import scipy .sparse as sp
57import torch
68from anndata import AnnData
79
3638 check_rep_layer ,
3739 check_spatial_coords ,
3840 con_K ,
41+ con_K_graph ,
42+ construct_knn_graph ,
3943 filter_common_genes ,
4044 get_P_core ,
4145 get_rep ,
@@ -94,7 +98,7 @@ class Morpho_pairwise:
9498 K (Union[int, float]): Number of sparse inducing points used for Nyström approximation for the kernel. Default is 15.
9599 kernel_type (str): Type of kernel used. Default is "euc".
96100 sigma2_init_scale (Optional[Union[int, float]]): Initial value for the spatial dispersion level. Default is 0.1.
97- partial_robust_level (float): Robust level of partial alignment. Default is 25 .
101+ partial_robust_level (float): Robust level of partial alignment. Default is 10 .
98102
99103 normalize_c (bool): Whether to normalize spatial coordinates. Default is True.
100104 normalize_g (bool): Whether to normalize gene expression. Default is True.
@@ -138,6 +142,8 @@ def __init__(
138142 beta : Union [int , float ] = 0.01 ,
139143 K : Union [int , float ] = 15 ,
140144 kernel_type : str = "euc" ,
145+ graph : Optional [networkx .Graph ] = None ,
146+ graph_knn : int = 10 ,
141147 sigma2_init_scale : Optional [Union [int , float ]] = 0.1 ,
142148 sigma2_end : Optional [Union [int , float ]] = None ,
143149 gamma_a : float = 1.0 ,
@@ -193,6 +199,8 @@ def __init__(
193199 self .K = K
194200 self .kernel_type = kernel_type
195201 self .kernel_bandwidth = beta
202+ self .graph = graph
203+ self .graph_knn = graph_knn
196204 self .sigma2_init_scale = sigma2_init_scale
197205 self .sigma2_end = sigma2_end
198206 self .partial_robust_level = partial_robust_level
@@ -840,13 +848,12 @@ def _construct_kernel(
840848 else None
841849 )
842850 elif self .kernel_type == "geodist" :
843- pass
844851 # TODO: finish this
845- # if self.graph is None:
846- # self.graph = _construct_graph (self.coordsA, self.knn )
847- # self.GammaSparse = con_K_graph(self.graph, inducing_variables_idx , inducing_variables_idx, beta=self.kernel_bandwidth)
848- # self.U = con_K_graph( self.graph, np.arange(self.NA), inducing_variables_idx, beta=self.kernel_bandwidth)
849-
852+ if self .graph is None :
853+ self .graph = construct_knn_graph (self .coordsA , self .graph_knn )
854+ self .U = con_K_graph (self .graph , inducing_variables_idx , beta = self .kernel_bandwidth )
855+ self .GammaSparse = self .U [ inducing_variables_idx , :]
856+ self . U_I = None # currently not support as the gudiance points is not in the graph
850857 else :
851858 raise NotImplementedError (f"Kernel type '{ self .kernel_type } ' is not implemented." )
852859
@@ -1163,10 +1170,17 @@ def _update_assignment_P(
11631170 self .Sp_sigma2 = Sp_sigma2
11641171
11651172 if self .sparse_calculation_mode :
1166- self .K_NA = self .K_NA .to_dense ()
1167- self .K_NB = self .K_NB .to_dense ()
1168- self .K_NA_spatial = self .K_NA_spatial .to_dense ()
1169- self .K_NA_sigma2 = self .K_NA_sigma2 .to_dense ()
1173+ if nx_torch (self .nx ):
1174+ self .K_NA = self .K_NA .to_dense ()
1175+ self .K_NB = self .K_NB .to_dense ()
1176+ self .K_NA_spatial = self .K_NA_spatial .to_dense ()
1177+ self .K_NA_sigma2 = self .K_NA_sigma2 .to_dense ()
1178+
1179+ else :
1180+ self .K_NA = self .K_NA .A .squeeze (- 1 )
1181+ self .K_NB = self .K_NB .A .squeeze (0 )
1182+ self .K_NA_spatial = self .K_NA_spatial
1183+ self .K_NA_sigma2 = self .K_NA_sigma2
11701184
11711185 self .sigma2_related = sigma2_related / (self .Dim * self .Sp_sigma2 )
11721186
@@ -1234,38 +1248,38 @@ def _update_nonrigid(
12341248
12351249 """
12361250
1237- SigmaInv = self .sigma2 * self .lambdaVF * self .GammaSparse + _dot ( self .nx ) (
1251+ SigmaInv = self .sigma2 * self .lambdaVF * self .GammaSparse + self .nx . dot (
12381252 self .U .T , self .nx .einsum ("ij,i->ij" , self .U , self .K_NA )
12391253 )
12401254 if self .SVI_mode :
1241- PXB_term = _dot ( self .nx ) (self .P , self .coordsB [self .batch_idx , :]) - self .nx .einsum (
1255+ PXB_term = self .nx . dot (self .P , self .coordsB [self .batch_idx , :]) - self .nx .einsum (
12421256 "ij,i->ij" , self .RnA , self .K_NA
12431257 )
12441258 self .SigmaInv = self .step_size * SigmaInv + (1 - self .step_size ) * self .SigmaInv
12451259 self .PXB_term = self .step_size * PXB_term + (1 - self .step_size ) * self .PXB_term
12461260 else :
1247- self .PXB_term = _dot ( self .nx ) (self .P , self .coordsB ) - self .nx .einsum ("ij,i->ij" , self .RnA , self .K_NA )
1261+ self .PXB_term = self .nx . dot (self .P , self .coordsB ) - self .nx .einsum ("ij,i->ij" , self .RnA , self .K_NA )
12481262 self .SigmaInv = SigmaInv
12491263
1250- UPXB_term = _dot ( self .nx ) (self .U .T , self .PXB_term )
1264+ UPXB_term = self .nx . dot (self .U .T , self .PXB_term )
12511265
12521266 # TODO: can we store these kernel multiple results? They are fixed
12531267 if self .guidance and ((self .guidance_effect == "nonrigid" ) or (self .guidance_effect == "both" )):
1254- self .SigmaInv += (self .sigma2 * self .guidance_weight * self .Sp / self .U_I .shape [0 ]) * _dot ( self .nx ) (
1268+ self .SigmaInv += (self .sigma2 * self .guidance_weight * self .Sp / self .U_I .shape [0 ]) * self .nx . dot (
12551269 self .U_I .T , self .U_I
12561270 )
1257- UPXB_term += (self .sigma2 * self .guidance_weight * self .Sp / self .U_I .shape [0 ]) * _dot ( self .nx ) (
1271+ UPXB_term += (self .sigma2 * self .guidance_weight * self .Sp / self .U_I .shape [0 ]) * self .nx . dot (
12581272 self .U_I .T , self .X_BI - self .R_AI
12591273 )
12601274
12611275 Sigma = _pinv (self .nx )(self .SigmaInv )
1262- self .Coff = _dot ( self .nx ) (Sigma , UPXB_term )
1276+ self .Coff = self .nx . dot (Sigma , UPXB_term )
12631277
1264- self .VnA = _dot ( self .nx ) (self .U , self .Coff )
1278+ self .VnA = self .nx . dot (self .U , self .Coff )
12651279 if self .guidance and ((self .guidance_effect == "nonrigid" ) or (self .guidance_effect == "both" )):
1266- self .V_AI = _dot ( self .nx ) (self .U_I , self .Coff )
1280+ self .V_AI = self .nx . dot (self .U_I , self .Coff )
12671281 self .SigmaDiag = self .sigma2 * self .nx .einsum (
1268- "ij->i" , self .nx .einsum ("ij,ji->ij" , self .U , _dot ( self .nx ) (Sigma , self .U .T ))
1282+ "ij->i" , self .nx .einsum ("ij,ji->ij" , self .U , self .nx . dot (Sigma , self .U .T ))
12691283 )
12701284
12711285 def _update_rigid (
@@ -1281,11 +1295,11 @@ def _update_rigid(
12811295 """
12821296
12831297 PXA , PVA , PXB = (
1284- _dot ( self .nx ) (self .K_NA , self .coordsA )[None , :],
1285- _dot ( self .nx ) (self .K_NA , self .VnA )[None , :],
1286- _dot ( self .nx ) (self .K_NB , self .coordsB [self .batch_idx , :])[None , :]
1298+ self .nx . dot (self .K_NA , self .coordsA )[None , :],
1299+ self .nx . dot (self .K_NA , self .VnA )[None , :],
1300+ self .nx . dot (self .K_NB , self .coordsB [self .batch_idx , :])[None , :]
12871301 if self .SVI_mode
1288- else _dot ( self .nx ) (self .K_NB , self .coordsB )[None , :],
1302+ else self .nx . dot (self .K_NB , self .coordsB )[None , :],
12891303 )
12901304 # solve rotation using SVD formula
12911305 mu_XB , mu_XA , mu_Vn = PXB , PXA , PVA
@@ -1297,10 +1311,10 @@ def _update_rigid(
12971311 mu_X_deno += (self .sigma2 * self .guidance_weight * self .Sp / self .X_BI .shape [0 ]) * self .X_BI .shape [0 ]
12981312 mu_Vn_deno += (self .sigma2 * self .guidance_weight * self .Sp / self .X_BI .shape [0 ]) * self .X_BI .shape [0 ]
12991313 if self .nn_init :
1300- mu_XB += (self .sigma2 * self .nn_init_weight * self .Sp / self .nx .sum (self .inlier_P )) * _dot ( self .nx ) (
1314+ mu_XB += (self .sigma2 * self .nn_init_weight * self .Sp / self .nx .sum (self .inlier_P )) * self .nx . dot (
13011315 self .inlier_P .T , self .inlier_B
13021316 )
1303- mu_XA += (self .sigma2 * self .nn_init_weight * self .Sp / self .nx .sum (self .inlier_P )) * _dot ( self .nx ) (
1317+ mu_XA += (self .sigma2 * self .nn_init_weight * self .Sp / self .nx .sum (self .inlier_P )) * self .nx . dot (
13041318 self .inlier_P .T , self .inlier_A
13051319 )
13061320 mu_X_deno += (self .sigma2 * self .nn_init_weight * self .Sp / self .nx .sum (self .inlier_P )) * self .nx .sum (
@@ -1323,44 +1337,43 @@ def _update_rigid(
13231337 if self .nn_init :
13241338 inlier_A_hat = self .inlier_A - mu_XA
13251339 inlier_B_hat = self .inlier_B - mu_XB
1326-
13271340 A = - (
1328- _dot ( self .nx ) (XA_hat .T , self .nx .einsum ("ij,i->ij" , VnA_hat , self .K_NA ))
1329- - _dot ( self .nx )( _dot ( self .nx ) (XA_hat .T , self .P ), XB_hat )
1341+ self .nx . dot (XA_hat .T , self .nx .einsum ("ij,i->ij" , VnA_hat , self .K_NA ))
1342+ - self .nx . dot ( self .nx . dot (XA_hat .T , self .P ), XB_hat )
13301343 ).T
13311344
13321345 if self .guidance_effect in ("rigid" , "both" ):
1333- A -= (self .sigma2 * self .guidance_weight * self .Sp / self .X_BI .shape [0 ]) * _dot ( self .nx ) (
1346+ A -= (self .sigma2 * self .guidance_weight * self .Sp / self .X_BI .shape [0 ]) * self .nx . dot (
13341347 X_AI_hat .T , V_AI_hat - X_BI_hat
13351348 ).T
13361349
13371350 if self .nn_init :
1338- A -= (self .sigma2 * self .nn_init_weight * self .Sp / self .nx .sum (self .inlier_P )) * _dot ( self .nx ) (
1351+ A -= (self .sigma2 * self .nn_init_weight * self .Sp / self .nx .sum (self .inlier_P )) * self .nx . dot (
13391352 (inlier_A_hat * self .inlier_P ).T , - inlier_B_hat
13401353 ).T
13411354
13421355 svdU , svdS , svdV = _linalg (self .nx ).svd (A )
1343- self .C [- 1 , - 1 ] = _linalg (self .nx ).det (_dot ( self .nx ) (svdU , svdV ))
1356+ self .C [- 1 , - 1 ] = _linalg (self .nx ).det (self .nx . dot (svdU , svdV ))
13441357
1345- R = _dot ( self .nx )( _dot ( self .nx ) (svdU , self .C ), svdV )
1358+ R = self .nx . dot ( self .nx . dot (svdU , self .C ), svdV )
13461359 if self .SVI_mode and self .step_size < 1 :
13471360 self .R = self .step_size * R + (1 - self .step_size ) * self .R
13481361 else :
13491362 self .R = R
13501363
13511364 # solve translation using SVD formula
1352- t_numerator = PXB - PVA - _dot ( self .nx ) (PXA , self .R .T )
1365+ t_numerator = PXB - PVA - self .nx . dot (PXA , self .R .T )
13531366 t_deno = _copy (self .nx , self .Sp )
13541367
13551368 if self .guidance and (self .guidance_effect in ("rigid" , "both" )):
13561369 t_numerator += (self .sigma2 * self .guidance_weight * self .Sp / self .X_BI .shape [0 ]) * self .nx .sum (
1357- self .X_BI - self .V_AI - _dot ( self .nx ) (self .X_AI , self .R .T ), axis = 0
1370+ self .X_BI - self .V_AI - self .nx . dot (self .X_AI , self .R .T ), axis = 0
13581371 )
13591372 t_deno += (self .sigma2 * self .guidance_weight * self .Sp / self .X_BI .shape [0 ]) * self .X_BI .shape [0 ]
13601373
13611374 if self .nn_init :
1362- t_numerator += (self .sigma2 * self .nn_init_weight * self .Sp / self .nx .sum (self .inlier_P )) * _dot ( self .nx ) (
1363- self .inlier_P .T , self .inlier_B - _dot ( self .nx ) (self .inlier_A , self .R .T )
1375+ t_numerator += (self .sigma2 * self .nn_init_weight * self .Sp / self .nx .sum (self .inlier_P )) * self .nx . dot (
1376+ self .inlier_P .T , self .inlier_B - self .nx . dot (self .inlier_A , self .R .T )
13641377 )
13651378 t_deno += (self .sigma2 * self .nn_init_weight * self .Sp / self .nx .sum (self .inlier_P )) * self .nx .sum (
13661379 self .inlier_P
@@ -1372,11 +1385,11 @@ def _update_rigid(
13721385 else :
13731386 self .t = t
13741387
1375- self .RnA = _dot ( self .nx ) (self .coordsA , self .R .T ) + self .t
1388+ self .RnA = self .nx . dot (self .coordsA , self .R .T ) + self .t
13761389 if self .nn_init :
1377- self .inlier_R = _dot ( self .nx ) (self .inlier_A , self .R .T ) + self .t
1390+ self .inlier_R = self .nx . dot (self .inlier_A , self .R .T ) + self .t
13781391 if self .guidance :
1379- self .R_AI = _dot ( self .nx ) (self .R_AI , self .R .T ) + self .t
1392+ self .R_AI = self .nx . dot (self .R_AI , self .R .T ) + self .t
13801393
13811394 def _update_sigma2 (
13821395 self ,
@@ -1420,23 +1433,24 @@ def _get_optimal_R(
14201433 """
14211434
14221435 mu_XnA , mu_XnB = (
1423- _dot ( self .nx ) (self .K_NA , self .coordsA ) / self .Sp ,
1424- _dot ( self .nx ) (self .K_NB , self .coordsB [self .batch_idx , :]) / self .Sp
1436+ self .nx . dot (self .K_NA , self .coordsA ) / self .Sp ,
1437+ self .nx . dot (self .K_NB , self .coordsB [self .batch_idx , :]) / self .Sp
14251438 if self .SVI_mode
1426- else _dot ( self .nx ) (self .K_NB , self .coordsB ) / self .Sp ,
1439+ else self .nx . dot (self .K_NB , self .coordsB ) / self .Sp ,
14271440 )
14281441 XnABar , XnBBar = (
14291442 self .coordsA - mu_XnA ,
14301443 self .coordsB [self .batch_idx , :] - mu_XnB if self .SVI_mode else self .coordsB - mu_XnB ,
14311444 )
1432- A = _dot ( self .nx )( _dot ( self .nx ) (self .P , XnBBar ).T , XnABar )
1445+ A = self .nx . dot ( self .nx . dot (self .P , XnBBar ).T , XnABar )
14331446
14341447 # get the optimal rotation matrix R
14351448 svdU , svdS , svdV = _linalg (self .nx ).svd (A )
1436- self .C [- 1 , - 1 ] = _linalg (self .nx ).det (_dot (self .nx )(svdU , svdV ))
1437- self .optimal_R = _dot (self .nx )(_dot (self .nx )(svdU , self .C ), svdV )
1438- self .optimal_t = mu_XnB - _dot (self .nx )(mu_XnA , self .optimal_R .T )
1439- self .optimal_RnA = _dot (self .nx )(self .coordsA , self .optimal_R .T ) + self .optimal_t
1449+
1450+ self .C [- 1 , - 1 ] = _linalg (self .nx ).det (self .nx .dot (svdU , svdV ))
1451+ self .optimal_R = self .nx .dot (self .nx .dot (svdU , self .C ), svdV )
1452+ self .optimal_t = mu_XnB - self .nx .dot (mu_XnA , self .optimal_R .T )
1453+ self .optimal_RnA = self .nx .dot (self .coordsA , self .optimal_R .T ) + self .optimal_t
14401454
14411455 def _wrap_output (
14421456 self ,
@@ -1468,7 +1482,7 @@ def _wrap_output(
14681482 norm_dict = {
14691483 "mean_transformed" : self .nx .to_numpy (self .normalize_means [0 ]),
14701484 "mean_fixed" : self .nx .to_numpy (self .normalize_means [1 ]),
1471- "scale" : self .nx .to_numpy (self .normalize_scales [1 ]),
1485+ "scale" : self .nx .to_numpy (self .normalize_scales [0 ]),
14721486 "scale_transformed" : self .nx .to_numpy (self .normalize_scales [0 ]),
14731487 "scale_fixed" : self .nx .to_numpy (self .normalize_scales [1 ]),
14741488 }
@@ -1493,4 +1507,5 @@ def _wrap_output(
14931507 "sigma2_variance" : self .nx .to_numpy (self .sigma2_variance ),
14941508 "method" : "Spateo" ,
14951509 "norm_dict" : norm_dict ,
1510+ "kernel_type" : self .kernel_type ,
14961511 }
0 commit comments