Skip to content

Commit 62fbc4c

Browse files
committed
upgrade to v1.2.1
1 parent d166c11 commit 62fbc4c

File tree

14 files changed

+399
-149
lines changed

14 files changed

+399
-149
lines changed

build/lib/paste/PASTE.py

Lines changed: 83 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,53 @@
1+
from typing import List, Tuple, Optional
12
import numpy as np
2-
import anndata
3+
from anndata import AnnData
34
import ot
45
from sklearn.decomposition import NMF
5-
from .helper import kl_divergence, intersect, kl_divergence_backend, to_dense_array, extract_data_matrix
6-
import time
6+
from .helper import intersect, kl_divergence_backend, to_dense_array, extract_data_matrix
77

8-
def pairwise_align(sliceA, sliceB, alpha = 0.1, dissimilarity='kl', use_rep = None, G_init = None, a_distribution = None, b_distribution = None, norm = False, numItermax = 200, backend=ot.backend.NumpyBackend(), use_gpu = False, return_obj = False, verbose = False, gpu_verbose = True, **kwargs):
8+
def pairwise_align(
9+
sliceA: AnnData,
10+
sliceB: AnnData,
11+
alpha: float = 0.1,
12+
dissimilarity: str ='kl',
13+
use_rep: Optional[str] = None,
14+
G_init = None,
15+
a_distribution = None,
16+
b_distribution = None,
17+
norm: bool = False,
18+
numItermax: int = 200,
19+
backend = ot.backend.NumpyBackend(),
20+
use_gpu: bool = False,
21+
return_obj: bool = False,
22+
verbose: bool = False,
23+
gpu_verbose: bool = True,
24+
**kwargs) -> Tuple[np.ndarray, Optional[int]]:
925
"""
1026
Calculates and returns optimal alignment of two slices.
1127
12-
param: sliceA - AnnData object of spatial slice
13-
param: sliceB - AnnData object of spatial slice
14-
param: alpha - Alignment tuning parameter. Note: 0 ≤ alpha ≤ 1
15-
param: dissimilarity - Expression dissimilarity measure: 'kl' or 'euclidean'
16-
param: use_rep - If none, uses slice.X to calculate dissimilarity between spots, otherwise uses the representation given by slice.obsm[use_rep]
17-
param: G_init - initial mapping to be used in FGW-OT, otherwise default is uniform mapping
18-
param: a_distribution - distribution of sliceA spots (1-d numpy array), otherwise default is uniform
19-
param: b_distribution - distribution of sliceB spots (1-d numpy array), otherwise default is uniform
20-
param: numItermax - max number of iterations during FGW-OT
21-
param: norm - scales spatial distances such that neighboring spots are at distance 1 if True, otherwise spatial distances remain unchanged
22-
param: backend - type of backend to run calculations. For list of backends available on system: ot.backend.get_backend_list()
23-
param: use_gpu - Whether to run on gpu or cpu. Currently we only have gpu support for Pytorch.
24-
param: return_obj - returns objective function output of FGW-OT if True, nothing if False
25-
param: verbose - FGW-OT is verbose if True, nothing if False
26-
param: gpu_verbose - Print whether gpu is being used to user, nothing if False
28+
Args:
29+
sliceA: Slice A to align.
30+
sliceB: Slice B to align.
31+
alpha: Alignment tuning parameter. Note: 0 <= alpha <= 1.
32+
dissimilarity: Expression dissimilarity measure: ``'kl'`` or ``'euclidean'``.
33+
use_rep: If ``None``, uses ``slice.X`` to calculate dissimilarity between spots, otherwise uses the representation given by ``slice.obsm[use_rep]``.
34+
G_init (array-like, optional): Initial mapping to be used in FGW-OT, otherwise default is uniform mapping.
35+
a_distribution (array-like, optional): Distribution of sliceA spots, otherwise default is uniform.
36+
b_distribution (array-like, optional): Distribution of sliceB spots, otherwise default is uniform.
37+
numItermax: Max number of iterations during FGW-OT.
38+
norm: If ``True``, scales spatial distances such that neighboring spots are at distance 1. Otherwise, spatial distances remain unchanged.
39+
backend: Type of backend to run calculations. For list of backends available on system: ``ot.backend.get_backend_list()``.
40+
use_gpu: If ``True``, use gpu. Otherwise, use cpu. Currently we only have gpu support for Pytorch.
41+
return_obj: If ``True``, additionally returns objective function output of FGW-OT.
42+
verbose: If ``True``, FGW-OT is verbose.
43+
gpu_verbose: If ``True``, print whether gpu is being used to user.
2744
28-
29-
return: pi - alignment of spots
30-
return: log['fgw_dist'] - objective function output of FGW-OT
45+
Returns:
46+
- Alignment of spots.
47+
48+
If ``return_obj = True``, additionally returns:
49+
50+
- Objective function output of FGW-OT.
3151
"""
3252

3353
# Determine if gpu or cpu is being used
@@ -131,31 +151,47 @@ def pairwise_align(sliceA, sliceB, alpha = 0.1, dissimilarity='kl', use_rep = No
131151
return pi
132152

133153

134-
def center_align(A, slices, lmbda = None, alpha = 0.1, n_components = 15, threshold = 0.001, max_iter = 10, dissimilarity='kl', use_rep = None, norm = False, random_seed = None, pis_init = None, distributions=None, backend = ot.backend.NumpyBackend(), use_gpu = False, verbose = False, gpu_verbose = True):
154+
def center_align(
155+
A: AnnData,
156+
slices: List[AnnData],
157+
lmbda = None,
158+
alpha: float = 0.1,
159+
n_components: int = 15,
160+
threshold: float = 0.001,
161+
max_iter: int = 10,
162+
dissimilarity: str ='kl',
163+
norm: bool = False,
164+
random_seed: Optional[int] = None,
165+
pis_init: Optional[List[np.ndarray]] = None,
166+
distributions = None,
167+
backend = ot.backend.NumpyBackend(),
168+
use_gpu: bool = False,
169+
verbose: bool = False,
170+
gpu_verbose: bool = True) -> Tuple[AnnData, List[np.ndarray]]:
135171
"""
136172
Computes center alignment of slices.
137173
138-
param: A - Initialization of starting AnnData Spatial Object; Make sure to include gene expression AND spatial info
139-
param: slices - List of slices (AnnData objects) used to calculate center alignment
140-
param: lmbda - List of probability weights assigned to each slice; default is uniform weights
141-
param: n_components - Number of components in NMF decomposition
142-
param: threshold - Threshold for convergence of W and H
143-
param: max_iter - maximum number of iterations for solving for center slice
144-
param: dissimilarity - Expression dissimilarity measure: 'kl' or 'euclidean'
145-
param: use_rep - If none, uses slice.X to calculate dissimilarity between spots, otherwise uses the representation given by slice.obsm[use_rep]
146-
param: norm - scales spatial distances such that neighboring spots are at distance 1 if True, otherwise spatial distances remain unchanged
147-
param: random_seed - set random seed for reproducibility
148-
param: pis_init - initial list of mappings between 'A' and 'slices' to solver, otherwise will calculate default mappings
149-
param: distributions - distributions of spots for each slice (list of 1-d numpy array), otherwise default is uniform
150-
param: backend - type of backend to run calculations. For list of backends available on system: ot.backend.get_backend_list()
151-
param: use_gpu - Whether to run on gpu or cpu. Currently we only have gpu support for Pytorch.
152-
param: verbose - FGW-OT is verbose if True, nothing if False
153-
param: gpu_verbose - Print whether gpu is being used to user, nothing if False
174+
Args:
175+
A: Slice to use as the initialization for center alignment; Make sure to include gene expression and spatial information.
176+
slices: List of slices to use in the center alignment.
177+
lmbda (array-like, optional): List of probability weights assigned to each slice; If ``None``, use uniform weights.
178+
alpha: Alignment tuning parameter. Note: 0 <= alpha <= 1.
179+
n_components: Number of components in NMF decomposition.
180+
threshold: Threshold for convergence of W and H during NMF decomposition.
181+
max_iter: Maximum number of iterations for our center alignment algorithm.
182+
dissimilarity: Expression dissimilarity measure: ``'kl'`` or ``'euclidean'``.
183+
norm: If ``True``, scales spatial distances such that neighboring spots are at distance 1. Otherwise, spatial distances remain unchanged.
184+
random_seed: Set random seed for reproducibility.
185+
pis_init: Initial list of mappings between 'A' and 'slices' to solver. Otherwise, default will automatically calculate mappings.
186+
distributions (List[array-like], optional): Distributions of spots for each slice. Otherwise, default is uniform.
187+
backend: Type of backend to run calculations. For list of backends available on system: ``ot.backend.get_backend_list()``.
188+
use_gpu: If ``True``, use gpu. Otherwise, use cpu. Currently we only have gpu support for Pytorch.
189+
verbose: If ``True``, FGW-OT is verbose.
190+
gpu_verbose: If ``True``, print whether gpu is being used to user.
154191
155-
156-
return: center_slice - inferred center slice (AnnData object) with full and low dimensional representations (W, H) of
157-
the gene expression matrix
158-
return: pi - List of pairwise alignment mappings of the center slice (rows) to each input slice (columns)
192+
Returns:
193+
- Inferred center slice with full and low dimensional representations (W, H) of the gene expression matrix.
194+
- List of pairwise alignment mappings of the center slice (rows) to each input slice (columns).
159195
"""
160196

161197
# Determine if gpu or cpu is being used
@@ -213,10 +249,10 @@ def center_align(A, slices, lmbda = None, alpha = 0.1, n_components = 15, thresh
213249
center_coordinates = A.obsm['spatial']
214250

215251
if not isinstance(center_coordinates, np.ndarray):
216-
print("Warning: A.obsm['spatial'] is not of type numpy array .")
252+
print("Warning: A.obsm['spatial'] is not of type numpy array.")
217253

218254
# Initialize center_slice
219-
center_slice = anndata.AnnData(np.dot(W,H))
255+
center_slice = AnnData(np.dot(W,H))
220256
center_slice.var.index = common_genes
221257
center_slice.obs.index = A.obs.index
222258
center_slice.obsm['spatial'] = center_coordinates
@@ -246,7 +282,7 @@ def center_align(A, slices, lmbda = None, alpha = 0.1, n_components = 15, thresh
246282
#--------------------------- HELPER METHODS -----------------------------------
247283

248284
def center_ot(W, H, slices, center_coordinates, common_genes, alpha, backend, use_gpu, dissimilarity = 'kl', norm = False, G_inits = None, distributions=None, verbose = False):
249-
center_slice = anndata.AnnData(np.dot(W,H))
285+
center_slice = AnnData(np.dot(W,H))
250286
center_slice.var.index = common_genes
251287
center_slice.obsm['spatial'] = center_coordinates
252288

@@ -313,4 +349,4 @@ def df(G):
313349
return res, log
314350

315351
else:
316-
return ot.gromov.cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
352+
return ot.gromov.cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)

build/lib/paste/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .PASTE import pairwise_align, center_align
2-
from .helper import kl_divergence, kl_divergence_backend, intersect, match_spots_using_spatial_heuristic, filter_for_common_genes
2+
from .helper import match_spots_using_spatial_heuristic
33
from .visualization import plot_slice, stack_slices_pairwise, stack_slices_center

build/lib/paste/helper.py

Lines changed: 48 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,44 @@
22
import scipy
33
import ot
44

5-
def filter_for_common_genes(slices):
5+
def match_spots_using_spatial_heuristic(
6+
X,
7+
Y,
8+
use_ot: bool = True) -> np.ndarray:
69
"""
7-
param: slices - list of slices (AnnData objects)
8-
"""
9-
assert len(slices) > 0, "Cannot have empty list."
10+
Calculates and returns a mapping of spots using a spatial heuristic.
11+
12+
Args:
13+
X (array-like): Coordinates for spots X.
14+
Y (array-like): Coordinates for spots Y.
15+
use_ot: If ``True``, use optimal transport ``ot.emd()`` to calculate mapping. Otherwise, use Scipy's ``min_weight_full_bipartite_matching()`` algorithm.
1016
11-
common_genes = slices[0].var.index
12-
for s in slices:
13-
common_genes = intersect(common_genes, s.var.index)
14-
for i in range(len(slices)):
15-
slices[i] = slices[i][:, common_genes]
16-
print('Filtered all slices for common genes. There are ' + str(len(common_genes)) + ' common genes.')
17+
Returns:
18+
Mapping of spots using a spatial heuristic.
19+
"""
20+
n1,n2=len(X),len(Y)
21+
X,Y = norm_and_center_coordinates(X),norm_and_center_coordinates(Y)
22+
dist = scipy.spatial.distance_matrix(X,Y)
23+
if use_ot:
24+
pi = ot.emd(np.ones(n1)/n1, np.ones(n2)/n2, dist)
25+
else:
26+
row_ind, col_ind = scipy.sparse.csgraph.min_weight_full_bipartite_matching(scipy.sparse.csr_matrix(dist))
27+
pi = np.zeros((n1,n2))
28+
pi[row_ind, col_ind] = 1/max(n1,n2)
29+
if n1<n2: pi[:, [(j not in col_ind) for j in range(n2)]] = 1/(n1*n2)
30+
elif n2<n1: pi[[(i not in row_ind) for i in range(n1)], :] = 1/(n1*n2)
31+
return pi
1732

1833
def kl_divergence(X, Y):
1934
"""
2035
Returns pairwise KL divergence (over all pairs of samples) of two matrices X and Y.
2136
22-
param: X - np array with dim (n_samples by n_features)
23-
param: Y - np array with dim (m_samples by n_features)
37+
Args:
38+
X: np array with dim (n_samples by n_features)
39+
Y: np array with dim (m_samples by n_features)
2440
25-
return: D - np array with dim (n_samples by m_samples). Pairwise KL divergence matrix.
41+
Returns:
42+
D: np array with dim (n_samples by m_samples). Pairwise KL divergence matrix.
2643
"""
2744
assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features."
2845

@@ -40,10 +57,12 @@ def kl_divergence_backend(X, Y):
4057
4158
Takes advantage of POT backend to speed up computation.
4259
43-
param: X - np array with dim (n_samples by n_features)
44-
param: Y - np array with dim (m_samples by n_features)
60+
Args:
61+
X: np array with dim (n_samples by n_features)
62+
Y: np array with dim (m_samples by n_features)
4563
46-
return: D - np array with dim (n_samples by m_samples). Pairwise KL divergence matrix.
64+
Returns:
65+
D: np array with dim (n_samples by m_samples). Pairwise KL divergence matrix.
4766
"""
4867
assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features."
4968

@@ -61,10 +80,14 @@ def kl_divergence_backend(X, Y):
6180

6281
def intersect(lst1, lst2):
6382
"""
64-
param: lst1 - list
65-
param: lst2 - list
83+
Gets and returns intersection of two lists.
84+
85+
Args:
86+
lst1: List
87+
lst2: List
6688
67-
return: list of common elements
89+
Returns:
90+
lst3: List of common elements.
6891
"""
6992

7093
temp = set(lst2)
@@ -73,33 +96,17 @@ def intersect(lst1, lst2):
7396

7497
def norm_and_center_coordinates(X):
7598
"""
76-
param: X - numpy array
99+
Normalizes and centers coordinates at the origin.
100+
101+
Args:
102+
X: Numpy array
77103
78-
return:
104+
Returns:
105+
X_new: Updated coordinates.
79106
"""
80107
return (X-X.mean(axis=0))/min(scipy.spatial.distance.pdist(X))
81108

82109

83-
def match_spots_using_spatial_heuristic(X,Y,use_ot=True):
84-
"""
85-
param: X - numpy array
86-
param: Y - numpy array
87-
88-
return: pi- mapping of spots using spatial heuristic
89-
"""
90-
n1,n2=len(X),len(Y)
91-
X,Y = norm_and_center_coordinates(X),norm_and_center_coordinates(Y)
92-
dist = scipy.spatial.distance_matrix(X,Y)
93-
if use_ot:
94-
pi = ot.emd(np.ones(n1)/n1, np.ones(n2)/n2, dist)
95-
else:
96-
row_ind, col_ind = scipy.sparse.csgraph.min_weight_full_bipartite_matching(scipy.sparse.csr_matrix(dist))
97-
pi = np.zeros((n1,n2))
98-
pi[row_ind, col_ind] = 1/max(n1,n2)
99-
if n1<n2: pi[:, [(j not in col_ind) for j in range(n2)]] = 1/(n1*n2)
100-
elif n2<n1: pi[[(i not in row_ind) for i in range(n1)], :] = 1/(n1*n2)
101-
return pi
102-
103110
## Convert a sparse matrix into a dense np array
104111
to_dense_array = lambda X: X.toarray() if isinstance(X,scipy.sparse.csr.spmatrix) else np.array(X)
105112

0 commit comments

Comments
 (0)