Skip to content

Commit f43f7fb

Browse files
committed
Merge branch 'yifan'
2 parents 2485c62 + b7becb3 commit f43f7fb

File tree

9 files changed

+152
-85
lines changed

9 files changed

+152
-85
lines changed
Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,26 @@
11
# Spatial transcriptomics alignment
22

3-
This section describes the technical details behind Spateo's spatial transcriptomics alignment pipeline.
3+
This section describes the technical details behind Spateo's spatial transcriptomics alignment pipeline.
4+
5+
## Background
6+
7+
The sequential slicing and subsequent spatial transcriptomic profiling at the whole embryo level offer us an unprecedented opportunity to reconstruct the molecular hologram of the entire 3D embryo structure. However, conventional sectioning and downstream library preparation can rotate, transform, deform, and introduce missing regions in each profiled tissue section. In addition, with advancements in technology, spatial transcriptomics techniques with single-cell and even subcellular resolution are gradually emerging, and a single slice often contains hundreds of thousands of cells. Therefore, it is in general necessary to develop scalable and robust algorithms to reconstruct 3D structures to recover the relative spatial locations of single cells across different slices while allowing local distortion within the same slice.
8+
9+
10+
## Methodology
11+
Consider a series of spatially-resolved transcriptomics samples, such as consecutive tissue sections from the same embryo, denote as $\mathcal{D} = \{\mathcal{S}^i\}_{i=1}^k$, where $\mathcal{S}^i=$ is the $i$-th section
12+
13+
14+
### Problem formulation
15+
16+
### Generative process
17+
18+
### Transformation model
19+
20+
### Define prior distributions
21+
22+
### Variational Bayesian Inference
23+
24+
## Function Design
25+
426

spateo/align.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
BA_transform_and_assignment,
44
Mesh_correction,
55
align_preprocess,
6+
calc_distance,
67
calc_exp_dissimilarity,
78
generate_label_transfer_dict,
89
generate_label_transfer_prior,

spateo/alignment/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .methods import (
33
Mesh_correction,
44
align_preprocess,
5+
calc_distance,
56
calc_exp_dissimilarity,
67
generate_label_transfer_dict,
78
)

spateo/alignment/methods/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
_power,
2424
_prod,
2525
_unsqueeze,
26+
calc_distance,
2627
check_backend,
2728
check_exp,
2829
con_K,

spateo/alignment/methods/backend.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ def diag(self, a, k=0):
528528
"""
529529
raise NotImplementedError()
530530

531-
def unique(self, a, return_inverse=False):
531+
def unique(self, a, return_index=False, return_inverse=False):
532532
r"""
533533
Finds unique elements of given tensor.
534534
@@ -1058,8 +1058,8 @@ def data(self, a, type_as=None):
10581058
else:
10591059
return np.asarray(a, dtype=type_as.dtype)
10601060

1061-
def unique(self, a, return_inverse=False, axis=None):
1062-
return np.unique(a, return_inverse=return_inverse, axis=axis)
1061+
def unique(self, a, return_index, return_inverse=False, axis=None):
1062+
return np.unique(a, return_index=return_index, return_inverse=return_inverse, axis=axis)
10631063

10641064
def unsqueeze(self, a, axis=-1):
10651065
return np.expand_dims(a, axis)
@@ -1325,8 +1325,27 @@ def data(self, a, type_as=None):
13251325
else:
13261326
return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device)
13271327

1328-
def unique(self, a, return_inverse=False, axis=None):
1329-
return torch.unique(a, return_inverse=return_inverse, dim=axis)
1328+
def unique(self, a, return_index=False, return_inverse=False, axis=None):
1329+
unique_values, inverse_indices = torch.unique(a, sorted=False, return_inverse=True, dim=axis)
1330+
1331+
result = [unique_values]
1332+
1333+
if return_index:
1334+
x_sort, idx_sorted = torch.sort(inverse_indices)
1335+
return_index = idx_sorted[
1336+
torch.hstack(
1337+
[
1338+
torch.where((x_sort[1:] - x_sort[:-1]) != 0)[0],
1339+
torch.tensor([len(inverse_indices) - 1], device=x_sort.device),
1340+
]
1341+
)
1342+
]
1343+
result.append(return_index)
1344+
1345+
if return_inverse:
1346+
result.append(inverse_indices)
1347+
1348+
return tuple(result) if len(result) > 1 else result[0]
13301349

13311350
def unsqueeze(self, a, axis=-1):
13321351
return torch.unsqueeze(a, axis)

spateo/alignment/methods/morpho_class.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,6 @@ def run(
244244
Returns:
245245
np.ndarray: The final cell-cell assignment matrix.
246246
"""
247-
# print(f'summin: {self.nx.sum(self.exp_layers_A[0], axis=1, keepdims=True).min()}')
248247
if self.nn_init:
249248
self._coarse_rigid_alignment()
250249

@@ -560,7 +559,6 @@ def _guidance_pair_preprocess(
560559
self.X_AI = self.nx.from_numpy(self.guidance_pair[1], type_as=self.type_as)
561560
self.V_AI = self.nx.zeros(self.X_AI.shape, type_as=self.type_as)
562561
self.R_AI = self.nx.zeros(self.X_AI.shape, type_as=self.type_as)
563-
# print(self.V_AI)
564562

565563
if self.normalize_c:
566564
# Normalize the guidance pairs
@@ -789,7 +787,6 @@ def _init_probability_parameters(
789787
Y=exp_B[sub_sample_B],
790788
metric=d_s,
791789
)
792-
# print(exp_A[sub_sample_A])
793790
min_exp_dist = self.nx.min(exp_dist, 1)
794791
self.probability_parameters[i] = self.nx.maximum(
795792
min_exp_dist[self.nx.argsort(min_exp_dist)[int(sub_sample_A.shape[0] * 0.05)]] / 5,
@@ -821,13 +818,16 @@ def _construct_kernel(
821818
NotImplementedError: If the specified kernel type is not implemented.
822819
"""
823820

824-
unique_spatial_coords = _unique(self.nx, self.coordsA, 0)
821+
# unique_spatial_coords = _unique(self.nx, self.coordsA, 0)
822+
unique_spatial_coords, unique_idx = self.nx.unique(self.coordsA, return_index=True, axis=0)
825823
inducing_variables_idx = (
826824
np.random.choice(unique_spatial_coords.shape[0], inducing_variables_num, replace=False)
827825
if unique_spatial_coords.shape[0] > inducing_variables_num
828826
else np.arange(unique_spatial_coords.shape[0])
829827
)
830-
self.inducing_variables = unique_spatial_coords[inducing_variables_idx, :]
828+
inducing_variables_idx = unique_idx[inducing_variables_idx]
829+
self.inducing_variables = self.coordsA[inducing_variables_idx, :]
830+
# self.inducing_variables = unique_spatial_coords[inducing_variables_idx, :]
831831
# (self.inducing_variables, _) = sample(
832832
# X=unique_spatial_coords, n_sampling=inducing_variables_num, sampling_method=sampling_method
833833
# )
@@ -839,6 +839,14 @@ def _construct_kernel(
839839
if self.guidance_effect in ["nonrigid", "both"]
840840
else None
841841
)
842+
elif self.kernel_type == "geodist":
843+
pass
844+
# TODO: finish this
845+
# if self.graph is None:
846+
# self.graph = _construct_graph(self.coordsA, self.knn)
847+
# self.GammaSparse = con_K_graph(self.graph, inducing_variables_idx, inducing_variables_idx, beta=self.kernel_bandwidth)
848+
# self.U = con_K_graph(self.graph, np.arange(self.NA), inducing_variables_idx, beta=self.kernel_bandwidth)
849+
842850
else:
843851
raise NotImplementedError(f"Kernel type '{self.kernel_type}' is not implemented.")
844852

@@ -1097,7 +1105,6 @@ def _update_assignment_P(
10971105
exp_layer_dist = calc_distance(
10981106
self.exp_layers_A, exp_layer_B_chunk, self.dissimilarity, self.label_transfer
10991107
)
1100-
11011108
P, K_NA_spatial_chunk, K_NA_sigma2_chunk, sigma2_related_chunk = get_P_core(
11021109
spatial_dist=spatial_dist, exp_dist=exp_layer_dist, **common_kwargs
11031110
)
@@ -1120,7 +1127,6 @@ def _update_assignment_P(
11201127
Y=self.coordsB[self.batch_idx, :] if self.SVI_mode else self.coordsB,
11211128
metric="euc",
11221129
) # NA x batch_size (SVI_mode) / NA x NB (not SVI_mode)
1123-
# print(self.pre_compute_dist)
11241130
if self.pre_compute_dist:
11251131
exp_layer_dist = (
11261132
[exp_layer_d[:, self.batch_idx] for exp_layer_d in self.exp_layer_dist]
@@ -1163,7 +1169,6 @@ def _update_assignment_P(
11631169
self.K_NA_sigma2 = self.K_NA_sigma2.to_dense()
11641170

11651171
self.sigma2_related = sigma2_related / (self.Dim * self.Sp_sigma2)
1166-
# print(self.sigma2_related)
11671172

11681173
def _update_gamma(
11691174
self,
@@ -1282,7 +1287,6 @@ def _update_rigid(
12821287
if self.SVI_mode
12831288
else _dot(self.nx)(self.K_NB, self.coordsB)[None, :],
12841289
)
1285-
# print(self.Sp)
12861290
# solve rotation using SVD formula
12871291
mu_XB, mu_XA, mu_Vn = PXB, PXA, PVA
12881292
mu_X_deno, mu_Vn_deno = _copy(self.nx, self.Sp), _copy(self.nx, self.Sp)
@@ -1462,11 +1466,11 @@ def _wrap_output(
14621466

14631467
if not (self.vecfld_key_added is None):
14641468
norm_dict = {
1465-
"mean_transformed": self.nx.to_numpy(self.normalize_means[1]),
1466-
"mean_fixed": self.nx.to_numpy(self.normalize_means[0]),
1469+
"mean_transformed": self.nx.to_numpy(self.normalize_means[0]),
1470+
"mean_fixed": self.nx.to_numpy(self.normalize_means[1]),
14671471
"scale": self.nx.to_numpy(self.normalize_scales[1]),
1468-
"scale_transformed": self.nx.to_numpy(self.normalize_scales[1]),
1469-
"scale_fixed": self.nx.to_numpy(self.normalize_scales[0]),
1472+
"scale_transformed": self.nx.to_numpy(self.normalize_scales[0]),
1473+
"scale_fixed": self.nx.to_numpy(self.normalize_scales[1]),
14701474
}
14711475

14721476
self.vecfld = {
@@ -1488,4 +1492,5 @@ def _wrap_output(
14881492
"NA": self.NA,
14891493
"sigma2_variance": self.nx.to_numpy(self.sigma2_variance),
14901494
"method": "Spateo",
1495+
"norm_dict": norm_dict,
14911496
}

spateo/alignment/methods/morpho_mesh_correction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
try:
3232
from .libfastpd import fastpd
3333
except ImportError:
34-
print("fastpd is not installed. Please compile the fastpd library.")
34+
print("fastpd is not installed. If you need mesh correction, please compile the fastpd library.")
3535

3636

3737
# TODO: add str as the input type for the models

0 commit comments

Comments
 (0)