@@ -728,6 +728,96 @@ def align_preprocess(
728728 )
729729
730730
def paste_align_preprocess(
    samples: List[AnnData],
    genes: Optional[Union[list, np.ndarray]] = None,
    spatial_key: str = "spatial",
    layer: str = "X",
    use_rep: Optional[str] = None,
    normalize_c: bool = False,
    normalize_g: bool = False,
    select_high_exp_genes: Union[bool, float, int] = False,
    dtype: str = "float64",
    device: str = "cpu",
    verbose: bool = True,
    **kwargs,
) -> Tuple[
    Union[TorchBackend, NumpyBackend],
    Union[torch.Tensor, np.ndarray],
    list,
    list,
    list,
    Optional[float],
    Optional[list],
]:
    """
    Data preprocessing before alignment.

    Args:
        samples: A list of anndata object.
        genes: Genes used for calculation. If None, use all common genes for calculation.
        spatial_key: The key in `.obsm` that corresponds to the raw spatial coordinates.
        layer: If `'X'`, uses ``sample.X`` to calculate dissimilarity between spots, otherwise uses the representation given by ``sample.layers[layer]``.
        use_rep: Key in `.obsm` holding a precomputed representation. It is only used when it is a string that is
            present in the `.obsm` of every sample; in that case the representation matrices are placed first in the
            returned matrix list, followed by the expression matrices. Otherwise only expression matrices are returned.
        normalize_c: Whether to normalize spatial coordinates.
        normalize_g: Whether to normalize gene expression. Only applied when ``use_rep`` is not in use.
        select_high_exp_genes: Whether to select genes with high differences in gene expression. ``False`` disables
            the selection, ``True`` uses a threshold of 10, and a number is used directly as the threshold.
        dtype: The floating-point number type. Only float32 and float64.
        device: Equipment used to run the program. You can also set the specified GPU for running. E.g.: '0'.
        verbose: If ``True``, print progress updates.
        **kwargs: Accepted for call compatibility; not used by this function.

    Returns:
        A tuple ``(nx, type_as, new_samples, exp_matrices, spatial_coords, normalize_scale_list,
        normalize_mean_list)``: the backend and reference object returned by ``check_backend``, the samples
        restricted to the selected genes, the list of matrices described under ``use_rep``, the list of spatial
        coordinates, and the two values returned by ``normalize_coords`` (both ``None`` when ``normalize_c`` is
        ``False``).
    """

    # Determine if gpu or cpu is being used
    nx, type_as = check_backend(device=device, dtype=dtype)

    # Subset for common genes
    new_samples = [s.copy() for s in samples]
    all_samples_genes = [s[0].var.index for s in new_samples]
    common_genes = filter_common_genes(*all_samples_genes, verbose=verbose)
    if genes is not None:
        common_genes = intersect_lsts(common_genes, genes)
    new_samples = [s[:, common_genes] for s in new_samples]

    # The representation is only usable when every sample provides it; checking all samples (rather than just the
    # first two) avoids an IndexError for a single sample and a KeyError for later samples that lack the key.
    use_rep_available = isinstance(use_rep, str) and all(use_rep in s.obsm.keys() for s in samples)

    # Gene expression matrix of all samples
    if use_rep_available:
        rep_matrices = [nx.from_numpy(s.obsm[use_rep], type_as=type_as) for s in new_samples]
        exp_matrices = rep_matrices + [
            nx.from_numpy(check_exp(sample=s, layer=layer), type_as=type_as) for s in new_samples
        ]
    else:
        exp_matrices = [nx.from_numpy(check_exp(sample=s, layer=layer), type_as=type_as) for s in new_samples]

    if select_high_exp_genes is not False:
        # NOTE(review): when `use_rep` is in use, the `.obsm` representations are part of `exp_matrices` and are
        # concatenated and filtered together with the expression matrices here -- confirm this is intended.
        expression_data = _cat(nx=nx, x=exp_matrices, dim=0)
        expression_var = _var(nx, expression_data, 0)
        # `True` selects the default threshold; any other non-False value is the threshold itself.
        exp_threshold = 10 if isinstance(select_high_exp_genes, bool) else select_high_exp_genes
        evidence_expression = nx.where(expression_var > exp_threshold)[0]
        exp_matrices = [exp_matrix[:, evidence_expression] for exp_matrix in exp_matrices]
        if verbose:
            lm.main_info(message=f"Evidence expression number: {len(evidence_expression)} .")

    # Spatial coordinates of all samples
    spatial_coords = [
        nx.from_numpy(check_spatial_coords(sample=s, spatial_key=spatial_key), type_as=type_as) for s in new_samples
    ]
    coords_dims = nx.unique(_data(nx, [c.shape[1] for c in spatial_coords], type_as))
    assert len(coords_dims) == 1, "Spatial coordinate dimensions are different, please check again."

    normalize_scale_list, normalize_mean_list = None, None
    if normalize_c:
        spatial_coords, normalize_scale_list, normalize_mean_list = normalize_coords(
            coords=spatial_coords, nx=nx, verbose=verbose
        )
    if normalize_g and not use_rep_available:
        exp_matrices = normalize_exps(matrices=exp_matrices, nx=nx, verbose=verbose)

    return (
        nx,
        type_as,
        new_samples,
        exp_matrices,
        spatial_coords,
        normalize_scale_list,
        normalize_mean_list,
    )
819+
820+
731821# Finished
732822def guidance_pair_preprocess (
733823 nx : Union [TorchBackend , NumpyBackend ],
0 commit comments