diff --git a/spateo/alignment/methods/__init__.py b/spateo/alignment/methods/__init__.py index 9723e231..eb8c797f 100644 --- a/spateo/alignment/methods/__init__.py +++ b/spateo/alignment/methods/__init__.py @@ -2,6 +2,7 @@ # from .morpho_sparse import BA_align_sparse from .backend import NumpyBackend, TorchBackend from .deprecated_utils import ( + paste_align_preprocess, align_preprocess, cal_dist, cal_dot, diff --git a/spateo/alignment/methods/backend.py b/spateo/alignment/methods/backend.py index 671fe6d8..559b28dd 100644 --- a/spateo/alignment/methods/backend.py +++ b/spateo/alignment/methods/backend.py @@ -1067,7 +1067,7 @@ def data(self, a, type_as=None): else: return np.asarray(a, dtype=type_as.dtype) - def unique(self, a, return_index, return_inverse=False, axis=None): + def unique(self, a, return_index=False, return_inverse=False, axis=None): return np.unique(a, return_index=return_index, return_inverse=return_inverse, axis=axis) def unsqueeze(self, a, axis=-1): diff --git a/spateo/alignment/methods/deprecated_utils.py b/spateo/alignment/methods/deprecated_utils.py index 5594e4d1..aa6e9986 100755 --- a/spateo/alignment/methods/deprecated_utils.py +++ b/spateo/alignment/methods/deprecated_utils.py @@ -728,6 +728,96 @@ def align_preprocess( ) +def paste_align_preprocess( + samples: List[AnnData], + genes: Optional[Union[list, np.ndarray]] = None, + spatial_key: str = "spatial", + layer: str = "X", + use_rep: Optional[str] = None, + normalize_c: bool = False, + normalize_g: bool = False, + select_high_exp_genes: Union[bool, float, int] = False, + dtype: str = "float64", + device: str = "cpu", + verbose: bool = True, + **kwargs, +) -> Tuple[ + TorchBackend or NumpyBackend, + torch.Tensor or np.ndarray, + list, + list, + list, + Optional[float], + Optional[list], +]: + """ + Data preprocessing before alignment. + + Args: + samples: A list of anndata object. + genes: Genes used for calculation. If None, use all common genes for calculation. + spatial_key: The key in `.obsm` that corresponds to the raw spatial coordinates. + layer: If `'X'`, uses ``sample.X`` to calculate dissimilarity between spots, otherwise uses the representation given by ``sample.layers[layer]``. + normalize_c: Whether to normalize spatial coordinates. + normalize_g: Whether to normalize gene expression. + select_high_exp_genes: Whether to select genes with high differences in gene expression. + dtype: The floating-point number type. Only float32 and float64. + device: Equipment used to run the program. You can also set the specified GPU for running. E.g.: '0'. + verbose: If ``True``, print progress updates. + """ + + # Determine if gpu or cpu is being used + nx, type_as = check_backend(device=device, dtype=dtype) + # Subset for common genes + new_samples = [s.copy() for s in samples] + all_samples_genes = [s[0].var.index for s in new_samples] + common_genes = filter_common_genes(*all_samples_genes, verbose=verbose) + common_genes = common_genes if genes is None else intersect_lsts(common_genes, genes) + new_samples = [s[:, common_genes] for s in new_samples] + + # Gene expression matrix of all samples + if (use_rep is None) or (not isinstance(use_rep, str)) or (use_rep not in samples[0].obsm.keys()) or (use_rep not in samples[1].obsm.keys()): + exp_matrices = [nx.from_numpy(check_exp(sample=s, layer=layer), type_as=type_as) for s in new_samples] + else: + exp_matrices = [nx.from_numpy(s.obsm[use_rep], type_as=type_as) for s in new_samples] + [nx.from_numpy(check_exp(sample=s, layer=layer), type_as=type_as) for s in new_samples] + if not (select_high_exp_genes is False): + # Select significance genes if select_high_exp_genes is True + ExpressionData = _cat(nx=nx, x=exp_matrices, dim=0) + + ExpressionVar = _var(nx, ExpressionData, 0) + exp_threshold = 10 if isinstance(select_high_exp_genes, bool) else select_high_exp_genes + EvidenceExpression = nx.where(ExpressionVar > exp_threshold)[0] + exp_matrices = [exp_matrix[:, EvidenceExpression] for exp_matrix in exp_matrices] + if verbose: + lm.main_info(message=f"Evidence expression number: {len(EvidenceExpression)}.") + + # Spatial coordinates of all samples + spatial_coords = [ + nx.from_numpy(check_spatial_coords(sample=s, spatial_key=spatial_key), type_as=type_as) for s in new_samples + ] + coords_dims = nx.unique(_data(nx, [c.shape[1] for c in spatial_coords], type_as)) + # coords_dims = np.unique(np.asarray([c.shape[1] for c in spatial_coords])) + assert len(coords_dims) == 1, "Spatial coordinate dimensions are different, please check again." + + normalize_scale_list, normalize_mean_list = None, None + if normalize_c: + spatial_coords, normalize_scale_list, normalize_mean_list = normalize_coords( + coords=spatial_coords, nx=nx, verbose=verbose + ) + if normalize_g and ((use_rep is None) or (not isinstance(use_rep, str)) or (use_rep not in samples[0].obsm.keys()) or (use_rep not in samples[1].obsm.keys())): + exp_matrices = normalize_exps(matrices=exp_matrices, nx=nx, verbose=verbose) + + return ( + nx, + type_as, + new_samples, + exp_matrices, + spatial_coords, + normalize_scale_list, + normalize_mean_list, + ) + + # Finished def guidance_pair_preprocess( nx: Union[TorchBackend, NumpyBackend], diff --git a/spateo/alignment/methods/paste.py b/spateo/alignment/methods/paste.py index cc1723be..738ee6b1 100644 --- a/spateo/alignment/methods/paste.py +++ b/spateo/alignment/methods/paste.py @@ -9,6 +9,7 @@ from spateo.logging import logger_manager as lm from .deprecated_utils import ( + paste_align_preprocess, align_preprocess, calc_exp_dissimilarity, check_exp, @@ -70,7 +71,7 @@ def paste_pairwise_align( """ # Preprocessing - (nx, type_as, new_samples, exp_matrices, spatial_coords, normalize_scale, normalize_mean_list,) = align_preprocess( + (nx, type_as, new_samples, exp_matrices, spatial_coords, normalize_scale, normalize_mean_list,) = paste_align_preprocess( samples=[sampleA, sampleB], genes=genes, spatial_key=spatial_key, @@ -116,7 +117,7 @@ def paste_pairwise_align( except ImportError: from ot.gromov import cg - pi, log = ot.gromov.cg( + pi, log = cg( a, b, (1 - alpha) * M,