Skip to content

Commit b7d315c

Browse files
Re-Added cluster_method arg in scc
1 parent 5b68e1a commit b7d315c

File tree

4 files changed

+84
-10
lines changed

4 files changed

+84
-10
lines changed

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ nbconvert
1919
networkx>=2.6.3
2020
# ngs_tools>=1.6.0
2121
numba>=0.46.0
22-
numpy>=1.18.1,<=1.23.5
22+
numpy>=1.18.1
2323
opencv-python>=4.5.4.60
2424
# pandana
25-
pandas>=0.25.1,<=1.5.3
25+
pandas>=0.25.1
2626
# paste-bio>=1.4.0
2727
plotly>=5.1.0
2828
POT>=0.8.1

spateo/tools/cluster/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .cluster_spagcn import spagcn_vanilla
2-
from .find_clusters import optimize_cluster, scc, spagcn_pyg
2+
from .find_clusters import scc, smooth, spagcn_pyg
33
from .utils import (
44
compute_pca_components,
55
ecp_silhouette,

spateo/tools/cluster/find_clusters.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from scipy.spatial import distance
1313

1414
from ...configuration import SKM
15-
from .leiden import calculate_leiden_partition
15+
from .leiden import calculate_leiden_partition, calculate_louvain_partition
1616
from .spagcn_utils import *
1717
from .utils import spatial_adj
1818

@@ -195,6 +195,7 @@ def scc(
195195
e_neigh: int = 30,
196196
s_neigh: int = 6,
197197
resolution: Optional[float] = None,
198+
cluster_method="louvain",
198199
) -> Optional[anndata.AnnData]:
199200
"""Spatially constrained clustering (scc) to identify continuous tissue domains.
200201
@@ -213,7 +214,7 @@ def scc(
213214
pca_key: label for the .obsm key containing PCA information (without the potential prefix "X_")
214215
e_neigh: the number of nearest neighbor in gene expression space.
215216
s_neigh: the number of nearest neighbor in physical space.
216-
resolution: the resolution parameter of the louvain clustering algorithm.
217+
resolution: the resolution parameter of the leiden clustering algorithm.
217218
218219
Returns:
219220
adata: An `~anndata.AnnData` object with cluster info in .obs.
@@ -229,10 +230,16 @@ def scc(
229230
)
230231

231232
# Perform Leiden clustering:
232-
clusters = calculate_leiden_partition(
233-
adj=adj,
234-
resolution=resolution,
235-
)
233+
if cluster_method == "louvain":
234+
clusters = calculate_louvain_partition(
235+
adj=adj,
236+
resolution=resolution,
237+
)
238+
else:
239+
clusters = calculate_leiden_partition(
240+
adj=adj,
241+
resolution=resolution,
242+
)
236243

237244
adata.obs[key_added] = clusters
238245
adata.obs[key_added] = adata.obs[key_added].astype(str)
@@ -241,7 +248,7 @@ def scc(
241248

242249

243250
@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE)
244-
def optimize_cluster(adata: anndata.AnnData, radius: int = 50, key: str = "label") -> list:
251+
def smooth(adata: anndata.AnnData, radius: int = 50, key: str = "label") -> list:
245252
"""
246253
Optimize the label by majority voting in the neighborhood.
247254

spateo/tools/cluster/leiden.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,70 @@ def calculate_leiden_partition(
121121
clusters = np.array(partition.membership, dtype=int)
122122
logger.finish_progress(progress_name="Community clustering with %s" % ("leiden"))
123123
return clusters
124+
125+
126+
def calculate_louvain_partition(
127+
adj: Optional[Union[scipy.sparse.spmatrix, np.ndarray]] = None,
128+
input_mat: Optional[np.ndarray] = None,
129+
num_neighbors: int = 10,
130+
graph_type: Literal["distance", "embedding"] = "distance",
131+
resolution: float = 1.0,
132+
n_iterations: int = -1,
133+
) -> np.ndarray:
134+
"""Performs Louvain clustering on a given dataset.
135+
136+
Args:
137+
adj: Optional precomputed adjacency matrix
138+
input_mat: Optional, will be used only if 'adj' is not given. The input data, will be interepreted as either a
139+
distance matrix (if :param `graph_type` is "distance" or an embedding matrix (if :param `graph_type` is
140+
"embedding")
141+
num_neighbors: Only used if 'adj' is not given- the number of nearest neighbors for constructing the graph
142+
graph_type: Only used if 'adj' is not given- specifies the input type, either 'distance' or 'embedding'
143+
resolution: The resolution parameter for the Louvain algorithm
144+
n_iterations: The number of iterations for the Louvain algorithm (-1 for unlimited iterations)
145+
146+
Returns:
147+
clusters: Array containing cluster assignments
148+
"""
149+
import louvain
150+
151+
from ...logging import logger_manager as lm
152+
153+
logger = lm.get_main_logger()
154+
if adj is None and input_mat is None:
155+
raise ValueError("Either `adj` or `input_mat` must be specified")
156+
157+
logger.info("using adj_matrix from arg for clustering...")
158+
159+
if adj is not None:
160+
if scipy.sparse.issparse(adj):
161+
pass
162+
else:
163+
adj = scipy.sparse.csr_matrix(adj)
164+
sources, targets = adj.nonzero()
165+
weights = adj[sources, targets]
166+
if isinstance(weights, np.matrix):
167+
weights = weights.A1
168+
G = igraph.Graph(directed=None)
169+
G.add_vertices(adj.shape[0]) # this adds adjacency.shape[0] vertices
170+
G.add_edges(list(zip(sources, targets)))
171+
try:
172+
G.es["weight"] = weights
173+
except KeyError:
174+
pass
175+
if G.vcount() != adj.shape[0]:
176+
print(
177+
f"The constructed graph has only {G.vcount()} nodes. "
178+
"Your adjacency matrix contained redundant nodes."
179+
)
180+
else:
181+
if graph_type == "distance":
182+
G = distance_knn_graph(input_mat, num_neighbors)
183+
elif graph_type == "embedding":
184+
G = embedding_knn_graph(input_mat, num_neighbors)
185+
logger.info("Converting graph_sparse_matrix to igraph object", indent_level=2)
186+
partition_kwargs = {"resolution_parameter": resolution, "seed": 42, "weights": G.es["weight"]}
187+
partition = louvain.find_partition(G, louvain.RBConfigurationVertexPartition, **partition_kwargs)
188+
clusters = np.array(partition.membership, dtype=int)
189+
logger.finish_progress(progress_name="Community clustering with %s" % ("louvain"))
190+
return clusters

0 commit comments

Comments
 (0)