Skip to content

Commit 45aa2c1

Browse files
committed
Fix the bugs with cpu + sparse calculation
1 parent e1df648 commit 45aa2c1

File tree

7 files changed

+289
-61
lines changed

7 files changed

+289
-61
lines changed

spateo/align.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
generate_label_transfer_prior,
1010
get_optimal_mapping_relationship,
1111
grid_deformation,
12+
group_pca,
1213
morpho_align,
1314
morpho_align_ref,
1415
paste_align,
1516
paste_align_ref,
1617
paste_transform,
1718
solve_RT_by_correspondence,
1819
split_slice,
20+
tps_deformation,
1921
)

spateo/alignment/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
generate_label_transfer_prior,
1515
get_labels_based_on_coords,
1616
get_optimal_mapping_relationship,
17+
group_pca,
1718
mapping_aligned_coords,
1819
mapping_center_coords,
1920
solve_RT_by_correspondence,
2021
split_slice,
22+
tps_deformation,
2123
)

spateo/alignment/methods/backend.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import numpy as np
2929
import scipy
3030
import scipy.linalg
31+
import scipy.sparse as sp
3132
import scipy.special as special
3233
from scipy.sparse import coo_matrix, csr_matrix, issparse
3334

@@ -1044,10 +1045,18 @@ def log(self, a):
10441045
return np.log(a)
10451046

10461047
def concatenate(self, arrays, axis=0):
1047-
return np.concatenate(arrays, axis)
1048+
if all(issparse(arr) for arr in arrays):
1049+
return sp.vstack(arrays) if axis == 0 else sp.hstack(arrays)
1050+
elif all(isinstance(arr, np.ndarray) for arr in arrays):
1051+
return np.concatenate(arrays, axis)
1052+
else:
1053+
raise ValueError("All arrays should be of the same type")
10481054

10491055
def sum(self, a, axis=None, keepdims=False):
1050-
return np.sum(a, axis, keepdims=keepdims)
1056+
if issparse(a):
1057+
return a.sum(axis=axis)
1058+
else:
1059+
return np.sum(a, axis, keepdims=keepdims)
10511060

10521061
def arange(self, stop, start=0, step=1, type_as=None):
10531062
return np.arange(start, stop, step)
@@ -1071,7 +1080,15 @@ def power(self, a, exponents):
10711080
return np.power(a, exponents)
10721081

10731082
def dot(self, a, b):
1074-
return np.dot(a, b)
1083+
if sp.issparse(a):
1084+
if sp.issparse(b):
1085+
return a.dot(b)
1086+
else:
1087+
return a.dot(b)
1088+
elif sp.issparse(b):
1089+
return b.T.dot(a.T).T
1090+
else:
1091+
return np.dot(a, b)
10751092

10761093
def prod(self, a, axis=0):
10771094
return np.prod(a, axis=axis)

spateo/alignment/methods/morpho_class.py

Lines changed: 65 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import random
22

3+
import networkx
34
import numpy as np
45
import ot
6+
import scipy.sparse as sp
57
import torch
68
from anndata import AnnData
79

@@ -36,6 +38,8 @@
3638
check_rep_layer,
3739
check_spatial_coords,
3840
con_K,
41+
con_K_graph,
42+
construct_knn_graph,
3943
filter_common_genes,
4044
get_P_core,
4145
get_rep,
@@ -94,7 +98,7 @@ class Morpho_pairwise:
9498
K (Union[int, float]): Number of sparse inducing points used for Nyström approximation for the kernel. Default is 15.
9599
kernel_type (str): Type of kernel used. Default is "euc".
96100
sigma2_init_scale (Optional[Union[int, float]]): Initial value for the spatial dispersion level. Default is 0.1.
97-
partial_robust_level (float): Robust level of partial alignment. Default is 25.
101+
partial_robust_level (float): Robust level of partial alignment. Default is 10.
98102
99103
normalize_c (bool): Whether to normalize spatial coordinates. Default is True.
100104
normalize_g (bool): Whether to normalize gene expression. Default is True.
@@ -138,6 +142,8 @@ def __init__(
138142
beta: Union[int, float] = 0.01,
139143
K: Union[int, float] = 15,
140144
kernel_type: str = "euc",
145+
graph: Optional[networkx.Graph] = None,
146+
graph_knn: int = 10,
141147
sigma2_init_scale: Optional[Union[int, float]] = 0.1,
142148
sigma2_end: Optional[Union[int, float]] = None,
143149
gamma_a: float = 1.0,
@@ -193,6 +199,8 @@ def __init__(
193199
self.K = K
194200
self.kernel_type = kernel_type
195201
self.kernel_bandwidth = beta
202+
self.graph = graph
203+
self.graph_knn = graph_knn
196204
self.sigma2_init_scale = sigma2_init_scale
197205
self.sigma2_end = sigma2_end
198206
self.partial_robust_level = partial_robust_level
@@ -840,13 +848,12 @@ def _construct_kernel(
840848
else None
841849
)
842850
elif self.kernel_type == "geodist":
843-
pass
844851
# TODO: finish this
845-
# if self.graph is None:
846-
# self.graph = _construct_graph(self.coordsA, self.knn)
847-
# self.GammaSparse = con_K_graph(self.graph, inducing_variables_idx, inducing_variables_idx, beta=self.kernel_bandwidth)
848-
# self.U = con_K_graph(self.graph, np.arange(self.NA), inducing_variables_idx, beta=self.kernel_bandwidth)
849-
852+
if self.graph is None:
853+
self.graph = construct_knn_graph(self.coordsA, self.graph_knn)
854+
self.U = con_K_graph(self.graph, inducing_variables_idx, beta=self.kernel_bandwidth)
855+
self.GammaSparse = self.U[inducing_variables_idx, :]
856+
self.U_I = None # currently not support as the gudiance points is not in the graph
850857
else:
851858
raise NotImplementedError(f"Kernel type '{self.kernel_type}' is not implemented.")
852859

@@ -1163,10 +1170,17 @@ def _update_assignment_P(
11631170
self.Sp_sigma2 = Sp_sigma2
11641171

11651172
if self.sparse_calculation_mode:
1166-
self.K_NA = self.K_NA.to_dense()
1167-
self.K_NB = self.K_NB.to_dense()
1168-
self.K_NA_spatial = self.K_NA_spatial.to_dense()
1169-
self.K_NA_sigma2 = self.K_NA_sigma2.to_dense()
1173+
if nx_torch(self.nx):
1174+
self.K_NA = self.K_NA.to_dense()
1175+
self.K_NB = self.K_NB.to_dense()
1176+
self.K_NA_spatial = self.K_NA_spatial.to_dense()
1177+
self.K_NA_sigma2 = self.K_NA_sigma2.to_dense()
1178+
1179+
else:
1180+
self.K_NA = self.K_NA.A.squeeze(-1)
1181+
self.K_NB = self.K_NB.A.squeeze(0)
1182+
self.K_NA_spatial = self.K_NA_spatial
1183+
self.K_NA_sigma2 = self.K_NA_sigma2
11701184

11711185
self.sigma2_related = sigma2_related / (self.Dim * self.Sp_sigma2)
11721186

@@ -1234,38 +1248,38 @@ def _update_nonrigid(
12341248
12351249
"""
12361250

1237-
SigmaInv = self.sigma2 * self.lambdaVF * self.GammaSparse + _dot(self.nx)(
1251+
SigmaInv = self.sigma2 * self.lambdaVF * self.GammaSparse + self.nx.dot(
12381252
self.U.T, self.nx.einsum("ij,i->ij", self.U, self.K_NA)
12391253
)
12401254
if self.SVI_mode:
1241-
PXB_term = _dot(self.nx)(self.P, self.coordsB[self.batch_idx, :]) - self.nx.einsum(
1255+
PXB_term = self.nx.dot(self.P, self.coordsB[self.batch_idx, :]) - self.nx.einsum(
12421256
"ij,i->ij", self.RnA, self.K_NA
12431257
)
12441258
self.SigmaInv = self.step_size * SigmaInv + (1 - self.step_size) * self.SigmaInv
12451259
self.PXB_term = self.step_size * PXB_term + (1 - self.step_size) * self.PXB_term
12461260
else:
1247-
self.PXB_term = _dot(self.nx)(self.P, self.coordsB) - self.nx.einsum("ij,i->ij", self.RnA, self.K_NA)
1261+
self.PXB_term = self.nx.dot(self.P, self.coordsB) - self.nx.einsum("ij,i->ij", self.RnA, self.K_NA)
12481262
self.SigmaInv = SigmaInv
12491263

1250-
UPXB_term = _dot(self.nx)(self.U.T, self.PXB_term)
1264+
UPXB_term = self.nx.dot(self.U.T, self.PXB_term)
12511265

12521266
# TODO: can we store these kernel multiple results? They are fixed
12531267
if self.guidance and ((self.guidance_effect == "nonrigid") or (self.guidance_effect == "both")):
1254-
self.SigmaInv += (self.sigma2 * self.guidance_weight * self.Sp / self.U_I.shape[0]) * _dot(self.nx)(
1268+
self.SigmaInv += (self.sigma2 * self.guidance_weight * self.Sp / self.U_I.shape[0]) * self.nx.dot(
12551269
self.U_I.T, self.U_I
12561270
)
1257-
UPXB_term += (self.sigma2 * self.guidance_weight * self.Sp / self.U_I.shape[0]) * _dot(self.nx)(
1271+
UPXB_term += (self.sigma2 * self.guidance_weight * self.Sp / self.U_I.shape[0]) * self.nx.dot(
12581272
self.U_I.T, self.X_BI - self.R_AI
12591273
)
12601274

12611275
Sigma = _pinv(self.nx)(self.SigmaInv)
1262-
self.Coff = _dot(self.nx)(Sigma, UPXB_term)
1276+
self.Coff = self.nx.dot(Sigma, UPXB_term)
12631277

1264-
self.VnA = _dot(self.nx)(self.U, self.Coff)
1278+
self.VnA = self.nx.dot(self.U, self.Coff)
12651279
if self.guidance and ((self.guidance_effect == "nonrigid") or (self.guidance_effect == "both")):
1266-
self.V_AI = _dot(self.nx)(self.U_I, self.Coff)
1280+
self.V_AI = self.nx.dot(self.U_I, self.Coff)
12671281
self.SigmaDiag = self.sigma2 * self.nx.einsum(
1268-
"ij->i", self.nx.einsum("ij,ji->ij", self.U, _dot(self.nx)(Sigma, self.U.T))
1282+
"ij->i", self.nx.einsum("ij,ji->ij", self.U, self.nx.dot(Sigma, self.U.T))
12691283
)
12701284

12711285
def _update_rigid(
@@ -1281,11 +1295,11 @@ def _update_rigid(
12811295
"""
12821296

12831297
PXA, PVA, PXB = (
1284-
_dot(self.nx)(self.K_NA, self.coordsA)[None, :],
1285-
_dot(self.nx)(self.K_NA, self.VnA)[None, :],
1286-
_dot(self.nx)(self.K_NB, self.coordsB[self.batch_idx, :])[None, :]
1298+
self.nx.dot(self.K_NA, self.coordsA)[None, :],
1299+
self.nx.dot(self.K_NA, self.VnA)[None, :],
1300+
self.nx.dot(self.K_NB, self.coordsB[self.batch_idx, :])[None, :]
12871301
if self.SVI_mode
1288-
else _dot(self.nx)(self.K_NB, self.coordsB)[None, :],
1302+
else self.nx.dot(self.K_NB, self.coordsB)[None, :],
12891303
)
12901304
# solve rotation using SVD formula
12911305
mu_XB, mu_XA, mu_Vn = PXB, PXA, PVA
@@ -1297,10 +1311,10 @@ def _update_rigid(
12971311
mu_X_deno += (self.sigma2 * self.guidance_weight * self.Sp / self.X_BI.shape[0]) * self.X_BI.shape[0]
12981312
mu_Vn_deno += (self.sigma2 * self.guidance_weight * self.Sp / self.X_BI.shape[0]) * self.X_BI.shape[0]
12991313
if self.nn_init:
1300-
mu_XB += (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * _dot(self.nx)(
1314+
mu_XB += (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * self.nx.dot(
13011315
self.inlier_P.T, self.inlier_B
13021316
)
1303-
mu_XA += (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * _dot(self.nx)(
1317+
mu_XA += (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * self.nx.dot(
13041318
self.inlier_P.T, self.inlier_A
13051319
)
13061320
mu_X_deno += (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * self.nx.sum(
@@ -1323,44 +1337,43 @@ def _update_rigid(
13231337
if self.nn_init:
13241338
inlier_A_hat = self.inlier_A - mu_XA
13251339
inlier_B_hat = self.inlier_B - mu_XB
1326-
13271340
A = -(
1328-
_dot(self.nx)(XA_hat.T, self.nx.einsum("ij,i->ij", VnA_hat, self.K_NA))
1329-
- _dot(self.nx)(_dot(self.nx)(XA_hat.T, self.P), XB_hat)
1341+
self.nx.dot(XA_hat.T, self.nx.einsum("ij,i->ij", VnA_hat, self.K_NA))
1342+
- self.nx.dot(self.nx.dot(XA_hat.T, self.P), XB_hat)
13301343
).T
13311344

13321345
if self.guidance_effect in ("rigid", "both"):
1333-
A -= (self.sigma2 * self.guidance_weight * self.Sp / self.X_BI.shape[0]) * _dot(self.nx)(
1346+
A -= (self.sigma2 * self.guidance_weight * self.Sp / self.X_BI.shape[0]) * self.nx.dot(
13341347
X_AI_hat.T, V_AI_hat - X_BI_hat
13351348
).T
13361349

13371350
if self.nn_init:
1338-
A -= (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * _dot(self.nx)(
1351+
A -= (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * self.nx.dot(
13391352
(inlier_A_hat * self.inlier_P).T, -inlier_B_hat
13401353
).T
13411354

13421355
svdU, svdS, svdV = _linalg(self.nx).svd(A)
1343-
self.C[-1, -1] = _linalg(self.nx).det(_dot(self.nx)(svdU, svdV))
1356+
self.C[-1, -1] = _linalg(self.nx).det(self.nx.dot(svdU, svdV))
13441357

1345-
R = _dot(self.nx)(_dot(self.nx)(svdU, self.C), svdV)
1358+
R = self.nx.dot(self.nx.dot(svdU, self.C), svdV)
13461359
if self.SVI_mode and self.step_size < 1:
13471360
self.R = self.step_size * R + (1 - self.step_size) * self.R
13481361
else:
13491362
self.R = R
13501363

13511364
# solve translation using SVD formula
1352-
t_numerator = PXB - PVA - _dot(self.nx)(PXA, self.R.T)
1365+
t_numerator = PXB - PVA - self.nx.dot(PXA, self.R.T)
13531366
t_deno = _copy(self.nx, self.Sp)
13541367

13551368
if self.guidance and (self.guidance_effect in ("rigid", "both")):
13561369
t_numerator += (self.sigma2 * self.guidance_weight * self.Sp / self.X_BI.shape[0]) * self.nx.sum(
1357-
self.X_BI - self.V_AI - _dot(self.nx)(self.X_AI, self.R.T), axis=0
1370+
self.X_BI - self.V_AI - self.nx.dot(self.X_AI, self.R.T), axis=0
13581371
)
13591372
t_deno += (self.sigma2 * self.guidance_weight * self.Sp / self.X_BI.shape[0]) * self.X_BI.shape[0]
13601373

13611374
if self.nn_init:
1362-
t_numerator += (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * _dot(self.nx)(
1363-
self.inlier_P.T, self.inlier_B - _dot(self.nx)(self.inlier_A, self.R.T)
1375+
t_numerator += (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * self.nx.dot(
1376+
self.inlier_P.T, self.inlier_B - self.nx.dot(self.inlier_A, self.R.T)
13641377
)
13651378
t_deno += (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * self.nx.sum(
13661379
self.inlier_P
@@ -1372,11 +1385,11 @@ def _update_rigid(
13721385
else:
13731386
self.t = t
13741387

1375-
self.RnA = _dot(self.nx)(self.coordsA, self.R.T) + self.t
1388+
self.RnA = self.nx.dot(self.coordsA, self.R.T) + self.t
13761389
if self.nn_init:
1377-
self.inlier_R = _dot(self.nx)(self.inlier_A, self.R.T) + self.t
1390+
self.inlier_R = self.nx.dot(self.inlier_A, self.R.T) + self.t
13781391
if self.guidance:
1379-
self.R_AI = _dot(self.nx)(self.R_AI, self.R.T) + self.t
1392+
self.R_AI = self.nx.dot(self.R_AI, self.R.T) + self.t
13801393

13811394
def _update_sigma2(
13821395
self,
@@ -1420,23 +1433,24 @@ def _get_optimal_R(
14201433
"""
14211434

14221435
mu_XnA, mu_XnB = (
1423-
_dot(self.nx)(self.K_NA, self.coordsA) / self.Sp,
1424-
_dot(self.nx)(self.K_NB, self.coordsB[self.batch_idx, :]) / self.Sp
1436+
self.nx.dot(self.K_NA, self.coordsA) / self.Sp,
1437+
self.nx.dot(self.K_NB, self.coordsB[self.batch_idx, :]) / self.Sp
14251438
if self.SVI_mode
1426-
else _dot(self.nx)(self.K_NB, self.coordsB) / self.Sp,
1439+
else self.nx.dot(self.K_NB, self.coordsB) / self.Sp,
14271440
)
14281441
XnABar, XnBBar = (
14291442
self.coordsA - mu_XnA,
14301443
self.coordsB[self.batch_idx, :] - mu_XnB if self.SVI_mode else self.coordsB - mu_XnB,
14311444
)
1432-
A = _dot(self.nx)(_dot(self.nx)(self.P, XnBBar).T, XnABar)
1445+
A = self.nx.dot(self.nx.dot(self.P, XnBBar).T, XnABar)
14331446

14341447
# get the optimal rotation matrix R
14351448
svdU, svdS, svdV = _linalg(self.nx).svd(A)
1436-
self.C[-1, -1] = _linalg(self.nx).det(_dot(self.nx)(svdU, svdV))
1437-
self.optimal_R = _dot(self.nx)(_dot(self.nx)(svdU, self.C), svdV)
1438-
self.optimal_t = mu_XnB - _dot(self.nx)(mu_XnA, self.optimal_R.T)
1439-
self.optimal_RnA = _dot(self.nx)(self.coordsA, self.optimal_R.T) + self.optimal_t
1449+
1450+
self.C[-1, -1] = _linalg(self.nx).det(self.nx.dot(svdU, svdV))
1451+
self.optimal_R = self.nx.dot(self.nx.dot(svdU, self.C), svdV)
1452+
self.optimal_t = mu_XnB - self.nx.dot(mu_XnA, self.optimal_R.T)
1453+
self.optimal_RnA = self.nx.dot(self.coordsA, self.optimal_R.T) + self.optimal_t
14401454

14411455
def _wrap_output(
14421456
self,
@@ -1468,7 +1482,7 @@ def _wrap_output(
14681482
norm_dict = {
14691483
"mean_transformed": self.nx.to_numpy(self.normalize_means[0]),
14701484
"mean_fixed": self.nx.to_numpy(self.normalize_means[1]),
1471-
"scale": self.nx.to_numpy(self.normalize_scales[1]),
1485+
"scale": self.nx.to_numpy(self.normalize_scales[0]),
14721486
"scale_transformed": self.nx.to_numpy(self.normalize_scales[0]),
14731487
"scale_fixed": self.nx.to_numpy(self.normalize_scales[1]),
14741488
}
@@ -1493,4 +1507,5 @@ def _wrap_output(
14931507
"sigma2_variance": self.nx.to_numpy(self.sigma2_variance),
14941508
"method": "Spateo",
14951509
"norm_dict": norm_dict,
1510+
"kernel_type": self.kernel_type,
14961511
}

0 commit comments

Comments
 (0)