1+ from typing import List , Tuple , Optional
12import numpy as np
2- import anndata
3+ from anndata import AnnData
34import ot
45from sklearn .decomposition import NMF
5- from .helper import kl_divergence , intersect , kl_divergence_backend , to_dense_array , extract_data_matrix
6- import time
6+ from .helper import intersect , kl_divergence_backend , to_dense_array , extract_data_matrix
77
8- def pairwise_align (sliceA , sliceB , alpha = 0.1 , dissimilarity = 'kl' , use_rep = None , G_init = None , a_distribution = None , b_distribution = None , norm = False , numItermax = 200 , backend = ot .backend .NumpyBackend (), use_gpu = False , return_obj = False , verbose = False , gpu_verbose = True , ** kwargs ):
8+ def pairwise_align (
9+ sliceA : AnnData ,
10+ sliceB : AnnData ,
11+ alpha : float = 0.1 ,
12+ dissimilarity : str = 'kl' ,
13+ use_rep : Optional [str ] = None ,
14+ G_init = None ,
15+ a_distribution = None ,
16+ b_distribution = None ,
17+ norm : bool = False ,
18+ numItermax : int = 200 ,
19+ backend = ot .backend .NumpyBackend (),
20+ use_gpu : bool = False ,
21+ return_obj : bool = False ,
22+ verbose : bool = False ,
23+ gpu_verbose : bool = True ,
24+ ** kwargs ) -> Tuple [np .ndarray , Optional [int ]]:
925 """
1026 Calculates and returns optimal alignment of two slices.
1127
12- param: sliceA - AnnData object of spatial slice
13- param: sliceB - AnnData object of spatial slice
14- param: alpha - Alignment tuning parameter. Note: 0 ≤ alpha ≤ 1
15- param: dissimilarity - Expression dissimilarity measure: 'kl' or 'euclidean'
16- param: use_rep - If none, uses slice.X to calculate dissimilarity between spots, otherwise uses the representation given by slice.obsm[use_rep]
17- param: G_init - initial mapping to be used in FGW-OT, otherwise default is uniform mapping
18- param: a_distribution - distribution of sliceA spots (1-d numpy array), otherwise default is uniform
19- param: b_distribution - distribution of sliceB spots (1-d numpy array), otherwise default is uniform
20- param: numItermax - max number of iterations during FGW-OT
21- param: norm - scales spatial distances such that neighboring spots are at distance 1 if True, otherwise spatial distances remain unchanged
22- param: backend - type of backend to run calculations. For list of backends available on system: ot.backend.get_backend_list()
23- param: use_gpu - Whether to run on gpu or cpu. Currently we only have gpu support for Pytorch.
24- param: return_obj - returns objective function output of FGW-OT if True, nothing if False
25- param: verbose - FGW-OT is verbose if True, nothing if False
26- param: gpu_verbose - Print whether gpu is being used to user, nothing if False
28+ Args:
29+ sliceA: Slice A to align.
30+ sliceB: Slice B to align.
31+ alpha: Alignment tuning parameter. Note: 0 <= alpha <= 1.
32+ dissimilarity: Expression dissimilarity measure: ``'kl'`` or ``'euclidean'``.
33+ use_rep: If ``None``, uses ``slice.X`` to calculate dissimilarity between spots, otherwise uses the representation given by ``slice.obsm[use_rep]``.
34+ G_init (array-like, optional): Initial mapping to be used in FGW-OT, otherwise default is uniform mapping.
35+ a_distribution (array-like, optional): Distribution of sliceA spots, otherwise default is uniform.
36+ b_distribution (array-like, optional): Distribution of sliceB spots, otherwise default is uniform.
37+ numItermax: Max number of iterations during FGW-OT.
38+ norm: If ``True``, scales spatial distances such that neighboring spots are at distance 1. Otherwise, spatial distances remain unchanged.
39+ backend: Type of backend to run calculations. For list of backends available on system: ``ot.backend.get_backend_list()``.
40+ use_gpu: If ``True``, use gpu. Otherwise, use cpu. Currently we only have gpu support for Pytorch.
41+ return_obj: If ``True``, additionally returns objective function output of FGW-OT.
42+ verbose: If ``True``, FGW-OT is verbose.
43+ gpu_verbose: If ``True``, print whether gpu is being used to user.
2744
28-
29- return: pi - alignment of spots
30- return: log['fgw_dist'] - objective function output of FGW-OT
45+ Returns:
46+ - Alignment of spots.
47+
48+ If ``return_obj = True``, additionally returns:
49+
50+ - Objective function output of FGW-OT.
3151 """
3252
3353 # Determine if gpu or cpu is being used
@@ -131,31 +151,47 @@ def pairwise_align(sliceA, sliceB, alpha = 0.1, dissimilarity='kl', use_rep = No
131151 return pi
132152
133153
134- def center_align (A , slices , lmbda = None , alpha = 0.1 , n_components = 15 , threshold = 0.001 , max_iter = 10 , dissimilarity = 'kl' , use_rep = None , norm = False , random_seed = None , pis_init = None , distributions = None , backend = ot .backend .NumpyBackend (), use_gpu = False , verbose = False , gpu_verbose = True ):
154+ def center_align (
155+ A : AnnData ,
156+ slices : List [AnnData ],
157+ lmbda = None ,
158+ alpha : float = 0.1 ,
159+ n_components : int = 15 ,
160+ threshold : float = 0.001 ,
161+ max_iter : int = 10 ,
162+ dissimilarity : str = 'kl' ,
163+ norm : bool = False ,
164+ random_seed : Optional [int ] = None ,
165+ pis_init : Optional [List [np .ndarray ]] = None ,
166+ distributions = None ,
167+ backend = ot .backend .NumpyBackend (),
168+ use_gpu : bool = False ,
169+ verbose : bool = False ,
170+ gpu_verbose : bool = True ) -> Tuple [AnnData , List [np .ndarray ]]:
135171 """
136172 Computes center alignment of slices.
137173
138- param: A - Initialization of starting AnnData Spatial Object; Make sure to include gene expression AND spatial info
139- param: slices - List of slices (AnnData objects) used to calculate center alignment
140- param: lmbda - List of probability weights assigned to each slice; default is uniform weights
141- param: n_components - Number of components in NMF decomposition
142- param: threshold - Threshold for convergence of W and H
143- param: max_iter - maximum number of iterations for solving for center slice
144- param: dissimilarity - Expression dissimilarity measure: 'kl' or 'euclidean'
145- param: use_rep - If none, uses slice.X to calculate dissimilarity between spots, otherwise uses the representation given by slice.obsm[use_rep]
146- param: norm - scales spatial distances such that neighboring spots are at distance 1 if True, otherwise spatial distances remain unchanged
147- param: random_seed - set random seed for reproducibility
148- param: pis_init - initial list of mappings between 'A' and 'slices' to solver, otherwise will calculate default mappings
149- param: distributions - distributions of spots for each slice (list of 1-d numpy array), otherwise default is uniform
150- param: backend - type of backend to run calculations. For list of backends available on system: ot.backend.get_backend_list()
151- param: use_gpu - Whether to run on gpu or cpu. Currently we only have gpu support for Pytorch.
152- param: verbose - FGW-OT is verbose if True, nothing if False
153- param: gpu_verbose - Print whether gpu is being used to user, nothing if False
174+ Args:
175+ A: Slice to use as the initialization for center alignment; Make sure to include gene expression and spatial information.
176+ slices: List of slices to use in the center alignment.
177+ lmbda (array-like, optional): List of probability weights assigned to each slice; If ``None``, use uniform weights.
178+ alpha: Alignment tuning parameter. Note: 0 <= alpha <= 1.
179+ n_components: Number of components in NMF decomposition.
180+ threshold: Threshold for convergence of W and H during NMF decomposition.
181+ max_iter: Maximum number of iterations for our center alignment algorithm.
182+ dissimilarity: Expression dissimilarity measure: ``'kl'`` or ``'euclidean'``.
183+ norm: If ``True``, scales spatial distances such that neighboring spots are at distance 1. Otherwise, spatial distances remain unchanged.
184+ random_seed: Set random seed for reproducibility.
185+ pis_init: Initial list of mappings between 'A' and 'slices' to solver. Otherwise, default will automatically calculate mappings.
186+ distributions (List[array-like], optional): Distributions of spots for each slice. Otherwise, default is uniform.
187+ backend: Type of backend to run calculations. For list of backends available on system: ``ot.backend.get_backend_list()``.
188+ use_gpu: If ``True``, use gpu. Otherwise, use cpu. Currently we only have gpu support for Pytorch.
189+ verbose: If ``True``, FGW-OT is verbose.
190+ gpu_verbose: If ``True``, print whether gpu is being used to user.
154191
155-
156- return: center_slice - inferred center slice (AnnData object) with full and low dimensional representations (W, H) of
157- the gene expression matrix
158- return: pi - List of pairwise alignment mappings of the center slice (rows) to each input slice (columns)
192+ Returns:
193+ - Inferred center slice with full and low dimensional representations (W, H) of the gene expression matrix.
194+ - List of pairwise alignment mappings of the center slice (rows) to each input slice (columns).
159195 """
160196
161197 # Determine if gpu or cpu is being used
@@ -213,10 +249,10 @@ def center_align(A, slices, lmbda = None, alpha = 0.1, n_components = 15, thresh
213249 center_coordinates = A .obsm ['spatial' ]
214250
215251 if not isinstance (center_coordinates , np .ndarray ):
216- print ("Warning: A.obsm['spatial'] is not of type numpy array ." )
252+ print ("Warning: A.obsm['spatial'] is not of type numpy array." )
217253
218254 # Initialize center_slice
219- center_slice = anndata . AnnData (np .dot (W ,H ))
255+ center_slice = AnnData (np .dot (W ,H ))
220256 center_slice .var .index = common_genes
221257 center_slice .obs .index = A .obs .index
222258 center_slice .obsm ['spatial' ] = center_coordinates
@@ -246,7 +282,7 @@ def center_align(A, slices, lmbda = None, alpha = 0.1, n_components = 15, thresh
246282#--------------------------- HELPER METHODS -----------------------------------
247283
248284def center_ot (W , H , slices , center_coordinates , common_genes , alpha , backend , use_gpu , dissimilarity = 'kl' , norm = False , G_inits = None , distributions = None , verbose = False ):
249- center_slice = anndata . AnnData (np .dot (W ,H ))
285+ center_slice = AnnData (np .dot (W ,H ))
250286 center_slice .var .index = common_genes
251287 center_slice .obsm ['spatial' ] = center_coordinates
252288
@@ -313,4 +349,4 @@ def df(G):
313349 return res , log
314350
315351 else :
316- return ot .gromov .cg (p , q , (1 - alpha ) * M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
352+ return ot .gromov .cg (p , q , (1 - alpha ) * M , alpha , f , df , G0 , armijo = armijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
0 commit comments