# Extracted from patch d40edbf75440201d5cd7fcc54e6a226bee61d850
# ("update backbone clustering", Yao-14, 2023-09-13), which adds ``backbone_scc``
# to spateo/tdr/models/models_backbone/backbone.py and adds it to the import
# lists of spateo/tdr/models/__init__.py and
# spateo/tdr/models/models_backbone/__init__.py.
#
# The same patch adds these module-level imports to backbone.py; the function
# below relies on them (together with the already-present ``AnnData``,
# ``PolyData``, ``Optional`` and ``Literal`` names):
#   import anndata as ad
#   import pandas as pd
#   from scipy.sparse import issparse


def backbone_scc(
    adata: AnnData,
    backbone: PolyData,
    genes: Optional[list] = None,
    adata_nodes_key: str = "backbone_nodes",
    backbone_nodes_key: str = "updated_nodes",
    key_added: Optional[str] = "backbone_scc",
    layer: Optional[str] = None,
    e_neigh: int = 10,
    s_neigh: int = 6,
    cluster_method: Literal["leiden", "louvain"] = "leiden",
    resolution: Optional[float] = None,
    inplace: bool = True,
) -> Optional[AnnData]:
    """
    Spatially constrained clustering (scc) along the backbone.

    Expression values are averaged over all cells assigned to the same backbone node, the resulting node-level matrix
    is normalized and log1p-transformed, ``scc`` is run on the nodes, and every cell finally inherits the cluster
    label of its node.

    Args:
        adata: The anndata object. ``adata.obs[adata_nodes_key]`` must hold the backbone node of each cell.
        backbone: The backbone model. ``backbone.point_data[backbone_nodes_key]`` must hold the node labels of its
            points.
        genes: The list of genes that will be used to subset the data for clustering. If ``genes = None``, all genes
            will be used.
        adata_nodes_key: The key that corresponds to the nodes in the adata.
        backbone_nodes_key: The key that corresponds to the nodes in the backbone.
        key_added: adata.obs key under which to add the cluster labels. ``None`` falls back to ``"backbone_scc"``.
        layer: The layer that will be used to retrieve data for dimension reduction and clustering. If
            ``layer = None``, ``.X`` is used.
        e_neigh: The number of nearest neighbors in gene expression space.
        s_neigh: The number of nearest neighbors in physical space.
        cluster_method: The method that will be used to cluster the backbone nodes.
        resolution: The resolution parameter forwarded to ``scc`` for the selected clustering method.
        inplace: If ``True``, modify ``adata`` in place and return ``None``; otherwise work on (and return) a copy.

    Returns:
        ``None`` if ``inplace`` is ``True``, otherwise the copied ``AnnData`` object. In both cases the clustering
        results are stored in ``.obs[key_added]``; cells without a backbone node (missing value) receive a missing
        label. An empty ``.uns["pp"]`` dict is created on the processed object if it does not exist yet.

    Raises:
        KeyError: If ``adata_nodes_key`` is not a column of ``adata.obs``.
    """
    import dynamo as dyn
    from dynamo.tools.utils import fetch_X_data

    from ....tools import scc

    # Fail early (and before a potentially expensive copy) with a clear message.
    if adata_nodes_key not in adata.obs:
        raise KeyError(f"`{adata_nodes_key}` was not found in `adata.obs`.")
    # The annotation allows None; fall back to the documented default instead of writing to `adata.obs[None]`.
    key_added = "backbone_scc" if key_added is None else key_added

    adata = adata if inplace else adata.copy()
    if "pp" not in adata.uns.keys():
        adata.uns["pp"] = {}

    genes, X_data = fetch_X_data(adata, genes, layer)
    # `toarray()` is available on both sparse matrices and sparse arrays, unlike the `.A` shorthand.
    X_data = X_data.toarray() if issparse(X_data) else X_data

    # Average the expression of all cells that belong to the same backbone node. The node labels are passed as an
    # external grouper so they can never collide with a gene name; `observed=True` keeps unused categories of a
    # categorical label column from producing all-NaN nodes.
    node_labels = adata.obs[adata_nodes_key].values
    X_data = pd.DataFrame(X_data, columns=genes).groupby(by=node_labels, observed=True).mean()
    backbone_nodes = X_data.index

    # Coordinates of the backbone points, reordered to match the rows of the node-level expression matrix.
    X_spatial = pd.DataFrame(backbone.points, index=backbone.point_data[backbone_nodes_key])
    X_spatial = X_spatial.loc[backbone_nodes, :].values

    backbone_adata = ad.AnnData(
        X=X_data.values,
        var=pd.DataFrame(index=X_data.columns),
        obs=pd.DataFrame(backbone_nodes, columns=[adata_nodes_key]),
        obsm={"spatial": X_spatial},
        uns={"__type": "UMI", "pp": {}},
    )

    dyn.pp.normalize(backbone_adata)
    dyn.pp.log1p(backbone_adata)
    # The normalized node-level expression itself is handed to `scc` as its expression-space representation.
    backbone_adata.obsm["X_backbone"] = backbone_adata.X
    scc(
        backbone_adata,
        spatial_key="spatial",
        pca_key="X_backbone",
        e_neigh=e_neigh,
        s_neigh=s_neigh,
        resolution=resolution,
        key_added="scc",
        cluster_method=cluster_method,
    )

    # Propagate each node's cluster label back to the cells assigned to that node. Mapping through the dict
    # (instead of indexing it inside a lambda) yields a missing label for cells without a node rather than raising.
    cluster_dict = dict(zip(backbone_adata.obs[adata_nodes_key], backbone_adata.obs["scc"]))
    adata.obs[key_added] = adata.obs[adata_nodes_key].map(cluster_dict)
    return None if inplace else adata