diff --git a/requirements.txt b/requirements.txt index 1cff1227..dcf30e08 100755 --- a/requirements.txt +++ b/requirements.txt @@ -10,9 +10,11 @@ gpytorch ipywidgets>=7.7.1 kneed>=0.7.0 kornia>=0.6.4 +leidenalg>=0.10.0 loompy>=3.0.5 mapclassify>=2.4.2 matplotlib<=3.5.3 +mpi4py nbconvert networkx>=2.6.3 # ngs_tools>=1.6.0 diff --git a/spateo/plotting/static/interactions.py b/spateo/plotting/static/interactions.py index 1064f362..e3c1aaec 100644 --- a/spateo/plotting/static/interactions.py +++ b/spateo/plotting/static/interactions.py @@ -4,6 +4,10 @@ """ from typing import Any, Dict, List, Mapping, Optional, Tuple, Union +from matplotlib.collections import PolyCollection +from matplotlib.ticker import StrMethodFormatter +from mpl_toolkits.axes_grid1 import make_axes_locatable + try: from typing import Literal except ImportError: @@ -17,708 +21,721 @@ import numpy as np import pandas as pd import scipy -import seaborn as sns from anndata import AnnData from matplotlib import rcParams -from matplotlib.collections import PolyCollection -from matplotlib.ticker import StrMethodFormatter -from mpl_toolkits.axes_grid1 import make_axes_locatable from scipy.cluster import hierarchy as sch from ...configuration import SKM, config_spateo_rcParams, set_pub_style from ...logging import logger_manager as lm from ...plotting.static.dotplot import CCDotplot -from ...tools.find_neighbors import generate_spatial_weights_fixed_nbrs +from ...tools.find_neighbors import neighbors from ...tools.labels import Label, interlabel_connections from .utils import _dendrogram_sig, save_return_show_fig_utils @SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE, "adata") -def plot_connections( +def ligrec( adata: AnnData, - cat_key: str, - spatial_key: str = "spatial", - n_spatial_neighbors: Union[None, int] = 6, - spatial_weights_matrix: Union[None, scipy.sparse.csr_matrix, np.ndarray] = None, - expr_weights_matrix: Union[None, scipy.sparse.csr_matrix, np.ndarray] = None, - reverse_expr_plot_orientation: bool = False, - ax: Union[None, mpl.axes.Axes] = None, - figsize: tuple = (3, 3), - zero_self_connections: bool = True, - normalize_by_self_connections: bool = False, - shapes_style: bool = True, - label_outline: bool = False, - max_scale: float = 0.46, - colormap: Union[str, dict, "mpl.colormap"] = "Spectral", - title_str: Union[None, str] = None, - title_fontsize: Union[None, float] = None, - label_fontsize: Union[None, float] = None, + dict_key: str, + source_groups: Union[None, str, List[str]] = None, + target_groups: Union[None, str, List[str]] = None, + means_range: Tuple[float, float] = (-np.inf, np.inf), + pvalue_threshold: float = 1.0, + remove_empty_interactions: bool = True, + remove_nonsig_interactions: bool = False, + dendrogram: Union[None, str] = None, + alpha: float = 0.001, + swap_axes: bool = False, + title: Union[None, str] = None, + figsize: Union[None, Tuple[float, float]] = None, save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show", save_kwargs: Optional[dict] = {}, + **kwargs, ): - """Plot spatial_connections between labels- visualization of how closely labels are colocalized + """ + Dotplot for visualizing results of ligand-receptor interaction analysis - Args: - adata: AnnData object - cat_key: Key in .obs containing categorical grouping labels. Colocalization will be assessed - for pairwise combinations of these labels. - spatial_key: Key in .obsm containing coordinates in the physical space. Not used unless - 'spatial_weights_matrix' is None, in which case this is required. 
Defaults to "spatial". - n_spatial_neighbors: Optional, number of neighbors in the physical space for each cell. Not used unless - 'spatial_weights_matrix' is None. - spatial_weights_matrix: Spatial distance matrix, weighted by distance between spots. If not given, - will compute at runtime. - expr_weights_matrix: Gene expression distance matrix, weighted by distance in transcriptomic or PCA space. - If not given, only the spatial distance matrix will be plotted. If given, will plot the spatial distance - matrix in the left plot and the gene expression distance matrix in the right plot. - reverse_expr_plot_orientation: If True, plot the gene expression connections in the form of a lower right - triangle. If False, gene expression connections will be an upper left triangle just like the spatial - connections. - ax: Existing axes object, if applicable - figsize: Width x height of desired figure window in inches - zero_self_connections: If True, ignores intra-label interactions - normalize_by_self_connections: Only used if 'zero_self_connections' is False. If True, normalize intra-label - connections by the number of spots of that label - shapes_style: If True plots squares, if False plots heatmap - label_outline: If True, gives dark outline to axis tick label text - max_scale: Only used for the case that 'shape_style' is True, gives maximum size of square - colormap: Specifies colors to use for plotting. If dictionary, keys should be numerical labels corresponding - to those of the Label object. - title_str: Optionally used to give plot a title - title_fontsize: Size of plot title- only used if 'title_str' is given. - label_fontsize: Size of labels along the axes of the graph - save_show_or_return: Whether to save, show or return the figure. - If "both", it will save and plot the figure at the same time. If "all", the figure will be saved, displayed - and the associated axis and other object will be return. - save_kwargs: A dictionary that will passed to the save_fig function. - By default it is an empty dictionary and the save_fig function will use the - {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, - "verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modifies those - keys according to your needs. + For each L:R pair on the dotplot, molecule 1 is sent from the cluster(s) labeled on the top of the plot (or on the + right, if 'swap_axes' is True), whereas molecule 2 is the receptor on the cluster(s) labeled on the bottom. - Returns: - (fig, ax): Returns plot and axis object if 'save_show_or_return' is "all" + Args: + adata: Object of :class `anndata.AnnData` + dict_key: Key in .uns to dictionary containing cell-cell communication information. Should contain keys labeled + "means" and "pvalues", with values being dataframes for the mean cell type-cell type L:R product and + significance values. + source_groups: Source interaction clusters. If `None`, select all clusters. + target_groups: Target interaction clusters. If `None`, select all clusters. + means_range: Only show interactions whose means are within this **closed** interval + pvalue_threshold: Only show interactions with p-value <= `pvalue_threshold` + remove_empty_interactions: Remove rows and columns that contain NaN values + remove_nonsig_interactions: Remove rows and columns that only contain interactions that are larger than `alpha` + dendrogram: How to cluster based on the p-values. 
Valid options are: + - None (no input) - do not perform clustering. + - `'interacting_molecules'` - cluster the interacting molecules. + - `'interacting_clusters'` - cluster the interacting clusters. + - `'both'` - cluster both rows and columns. Note that in this case, the dendrogram is not shown. + alpha: Significance threshold. All elements with p-values <= `alpha` will be marked by tori instead of dots. + swap_axes: Whether to show the cluster combinations as rows and the interacting pairs as columns + title: Title of the plot + figsize: The width and height of a figure + save_show_or_return: Options: "save", "show", "return", "both", "all" + - "both" for save and show + save_kwargs: A dictionary that will passed to the save_fig function. By default it is an empty dictionary + and the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + "transparent": True, "close": True, "verbose": True} as its parameters. But to change any of these + parameters, this dictionary can be used to do so. + kwargs : + Keyword arguments for :func `style` or :func `legend` of :class `Dotplot` """ - from ...plotting.static.utils import save_fig - from ...tools.utils import update_dict - logger = lm.get_main_logger() + config_spateo_rcParams() - title_fontsize = rcParams.get("axes.titlesize") if title_fontsize is None else title_fontsize - label_fontsize = rcParams.get("axes.labelsize") if label_fontsize is None else label_fontsize + set_pub_style() - if ax is None: - if expr_weights_matrix is not None: - figsize = (figsize[0] * 2.25, figsize[1]) - fig, axes = plt.subplots(1, 2, figsize=figsize) - ax_sp, ax_expr = axes[0], axes[1] + if figsize is None: + figsize = rcParams.get("figure.figsize") - if reverse_expr_plot_orientation: - # Allow subplot boundaries to technically be partially overlapping (for better visual) - box = ax_expr.get_position() - box.x0 = box.x0 - 0.3 - box.x1 = box.x1 - 0.3 - ax_expr.set_position(box) - else: - fig, ax_sp = plt.subplots(1, 1, figsize=figsize) - else: - ax = ax - if len(ax) > 1: - ax_sp, ax_expr = ax[0], ax[1] - else: - ax_sp = ax - fig = ax.get_figure() + if title is None: + title = "Ligand-Receptor Inference" - # Convert cell type labels to numerical using Label object: - categories_str_cat = np.unique(adata.obs[cat_key].values) - categories_num_cat = range(len(categories_str_cat)) - map_dict = dict(zip(categories_num_cat, categories_str_cat)) - categories_str = adata.obs[cat_key] - categories_num = adata.obs[cat_key].replace(categories_str_cat, categories_num_cat) + dict = adata.uns[dict_key] - label = Label(categories_num.to_numpy(), str_map=map_dict) + def filter_values( + pvals: pd.DataFrame, means: pd.DataFrame, *, mask: pd.DataFrame, kind: str + ) -> Tuple[pd.DataFrame, pd.DataFrame]: + mask_rows = mask.any(axis=1) + pvals = pvals.loc[mask_rows] + means = means.loc[mask_rows] - # If spatial weights matrix is not given, compute it. 'spatial_key' needs to be present in the AnnData object: - if spatial_weights_matrix is None: - if spatial_key not in adata.obsm_keys(): - logger.error( - f"Given 'spatial_key' {spatial_key} does not exist as key in adata.obsm. Options: " - f"{adata.obsm_keys()}." 
- ) - spatial_weights_matrix, _, _ = generate_spatial_weights_fixed_nbrs( - adata, spatial_key=spatial_key, num_neighbors=n_spatial_neighbors, decay_type="reciprocal" - ) + if pvals.empty: + raise ValueError(f"After removing rows with only {kind} interactions, none remain.") - # Compute spatial connections array: - spatial_connections = interlabel_connections(label, spatial_weights_matrix) + mask_cols = mask.any(axis=0) + pvals = pvals.loc[:, mask_cols] + means = means.loc[:, mask_cols] - if zero_self_connections: - np.fill_diagonal(spatial_connections, 0) - elif normalize_by_self_connections: - spatial_connections /= spatial_connections.diagonal()[:, np.newaxis] + if pvals.empty: + raise ValueError(f"After removing columns with only {kind} interactions, none remain.") - spatial_connections_max = np.amax(spatial_connections) + return pvals, means - # Optionally, compute gene expression connections array: - if expr_weights_matrix is not None: - expr_connections = interlabel_connections(label, expr_weights_matrix) + def get_dendrogram(adata: AnnData, linkage: str = "complete") -> Mapping[str, Any]: + z_var = sch.linkage( + adata.X, + metric="correlation", + method=linkage, + # Unlikely to ever be profiling this many LR pairings, but cap at 1500 + optimal_ordering=adata.n_obs <= 1500, + ) + dendro_info = sch.dendrogram(z_var, labels=adata.obs_names.values, no_plot=True) + # this is what the DotPlot requires + return { + "linkage": z_var, + "cat_key": ["groups"], + "cor_method": "pearson", + "use_rep": None, + "linkage_method": linkage, + "categories_ordered": dendro_info["ivl"], + "categories_idx_ordered": dendro_info["leaves"], + "dendrogram_info": dendro_info, + } - if zero_self_connections: - np.fill_diagonal(expr_connections, 0) - elif normalize_by_self_connections: - expr_connections /= expr_connections.diagonal()[:, np.newaxis] + if len(means_range) != 2: + logger.error(f"Expected `means_range` to be a sequence of size `2`, found `{len(means_range)}`.") + means_range = tuple(sorted(means_range)) - expr_connections_max = np.amax(expr_connections) + if alpha is not None and not (0 <= alpha <= 1): + logger.error(f"Expected `alpha` to be in range `[0, 1]`, found `{alpha}`.") - # Set label colors: - if isinstance(colormap, str): - cmap = mpl.cm.get_cmap(colormap) - else: - cmap = colormap + if source_groups is None: + source_groups = dict["pvalues"].columns.get_level_values(0) + elif isinstance(source_groups, str): + source_groups = (source_groups,) - # If colormap is given, map label ID to points along the colormap. If dictionary is given, instead map each label - # to a color using the dictionary keys as guides. 
- if isinstance(cmap, dict): - if type(list(cmap.keys())[0]) == str: - id_colors = {n_id: cmap[id] for n_id, id in zip(label.ids, label.str_ids)} - else: - id_colors = {id: cmap[id] for id in label.ids} - else: - id_colors = {id: cmap(id / label.max_id) for id in label.ids} + if target_groups is None: + target_groups = dict["pvalues"].columns.get_level_values(1) + if isinstance(target_groups, str): + target_groups = (target_groups,) - # -------------------------------- Spatial Connections Plot- Setup -------------------------------- # - if shapes_style: - # Cell types/labels will be represented using triangles: - left_triangle = np.array( - ( - (-1.0, 1.0), - # (1., 1.), - (1.0, -1.0), - (-1.0, -1.0), - ) - ) + # Get specified source and target groups from the dictionary: + pvals: pd.DataFrame = dict["pvalues"].loc[:, (source_groups, target_groups)] + means: pd.DataFrame = dict["means"].loc[:, (source_groups, target_groups)] - right_triangle = np.array( - ( - (-1.0, 1.0), - (1.0, 1.0), - (1.0, -1.0), - # (-1., -1.) - ) - ) + if pvals.empty: + raise ValueError("No valid clusters have been selected.") - polygon_list = [] - color_list = [] + means = means[(means >= means_range[0]) & (means <= means_range[1])] + pvals = pvals[pvals <= pvalue_threshold] - ax_sp.set_ylim(-0.55, label.num_labels - 0.45) - ax_sp.set_xlim(-0.55, label.num_labels - 0.45) + if remove_empty_interactions: + pvals, means = filter_values(pvals, means, mask=~(pd.isnull(means) | pd.isnull(pvals)), kind="NaN") + if remove_nonsig_interactions and alpha is not None: + pvals, means = filter_values(pvals, means, mask=pvals <= alpha, kind="non-significant") - for label_1 in range(spatial_connections.shape[0]): - for label_2 in range(spatial_connections.shape[1]): - if label_1 <= label_2: - for triangle in [left_triangle, right_triangle]: - center = np.array((label_1, label_2))[np.newaxis, :] - scale_factor = spatial_connections[label_1, label_2] / spatial_connections_max - offsets = triangle * max_scale * scale_factor - polygon_list.append(center + offsets) + start, label_ranges = 0, {} - color_list += (id_colors[label.ids[label_2]], id_colors[label.ids[label_1]]) + if dendrogram == "interacting_clusters": + # Set rows to be cluster combinations, not LR pairs: + pvals = pvals.T + means = means.T - collection = PolyCollection(polygon_list, facecolors=color_list, edgecolors="face", linewidths=0) + for cls, size in (pvals.groupby(level=0, axis=1)).size().to_dict().items(): + label_ranges[cls] = (start, start + size - 1) + start += size + label_ranges = {k: label_ranges[k] for k in sorted(label_ranges.keys())} - ax_sp.add_collection(collection) + pvals = pvals[label_ranges.keys()].astype("float") + # Add minimum value to p-values to avoid value error- 3.0 will be the largest possible value: + pvals = -np.log10(pvals + min(1e-3, alpha if alpha is not None else 1e-3)).fillna(0) - # Remove ticks - ax_sp.tick_params(labelbottom=False, labeltop=True, top=False, bottom=False, left=False) - ax_sp.xaxis.set_tick_params(pad=-2) - else: - # Heatmap of connection strengths - heatmap = ax_sp.imshow(spatial_connections, cmap=colormap, interpolation="nearest") + pvals.columns = map(" | ".join, pvals.columns.to_flat_index()) + pvals.index = map(" | ".join, pvals.index.to_flat_index()) - divider = make_axes_locatable(ax_sp) - cax = divider.append_axes("right", size="5%", pad=0.1) + means = means[label_ranges.keys()].fillna(0) + means.columns = map(" | ".join, means.columns.to_flat_index()) + means.index = map(" | ".join, 
means.index.to_flat_index()) + means = np.log2(means + 1) - fig.colorbar(heatmap, cax=cax) - cax.tick_params(axis="both", which="major", labelsize=6, rotation=-45) + var = pd.DataFrame(pvals.columns) + var = var.set_index(var.columns[0]) - # Change formatting if values too small - if spatial_connections_max < 0.001: - cax.yaxis.set_major_formatter(StrMethodFormatter("{x:,.1e}")) + # Instantiate new AnnData object containing plot values: + adata = AnnData(pvals.values, obs={"groups": pd.Categorical(pvals.index)}, var=var, dtype=pvals.values.dtype) + adata.obs_names = pvals.index + minn = np.nanmin(adata.X) + delta = np.nanmax(adata.X) - minn + adata.X = (adata.X - minn) / delta + # To satisfy conditional check that happens on instantiating dotplot: + adata.uns["__type"] = "UMI" - # Formatting adjustments - ax_sp.set_aspect("equal") + try: + if dendrogram == "both": + row_order, col_order, _, _ = _dendrogram_sig( + adata.X, method="complete", metric="correlation", optimal_ordering=adata.n_obs <= 1500 + ) + adata = adata[row_order, :][:, col_order] + pvals = pvals.iloc[row_order, :].iloc[:, col_order] + means = means.iloc[row_order, :].iloc[:, col_order] + elif dendrogram is not None: + adata.uns["dendrogram"] = get_dendrogram(adata) + except Exception as e: + logger.warning(f"Unable to create a dendrogram. Reason: `{e}`. Will display without one.") + dendrogram = None - ax_sp.set_xticks( - np.arange(label.num_labels), - ) - text_outline = [PathEffects.Stroke(linewidth=0.5, foreground="black", alpha=0.8)] if label_outline else None + kwargs["dot_edge_lw"] = 0 + kwargs.setdefault("cmap", "magma") + kwargs.setdefault("grid", True) + kwargs.pop("color_on", None) - # If label has categorical labels associated, use those to label the axes instead: - if label.str_map is not None: - ax_sp.set_xticklabels( - label.str_ids, - fontsize=label_fontsize, - fontweight="bold", - rotation=90, - path_effects=text_outline, - ) - else: - ax_sp.set_xticklabels( - label.ids, - fontsize=label_fontsize, - fontweight="bold", - rotation=0, - path_effects=text_outline, - ) + # Set style and legend kwargs: + dotplot_style_params = {k for k in signature(CCDotplot.style).parameters.keys()} + dotplot_style_kwargs = {k: v for k, v in kwargs.items() if k in dotplot_style_params} + dotplot_legend_params = {k for k in signature(CCDotplot.legend).parameters.keys()} + dotplot_legend_kwargs = {k: v for k, v in kwargs.items() if k in dotplot_legend_params} - ax_sp.set_yticks(np.arange(label.num_labels)) - if label.str_map is not None: - ax_sp.set_yticklabels( - label.str_ids, - fontsize=label_fontsize, - fontweight="bold", - path_effects=text_outline, + dp = ( + CCDotplot( + delta=delta, + minn=minn, + alpha=alpha, + adata=adata, + var_names=adata.var_names, + cat_key="groups", + dot_color_df=means, + dot_size_df=pvals, + title=title, + var_group_labels=None if dendrogram == "both" else list(label_ranges.keys()), + var_group_positions=None if dendrogram == "both" else list(label_ranges.values()), + standard_scale=None, + figsize=figsize, ) - else: - ax_sp.set_yticklabels( - label.ids, - fontsize=label_fontsize, - fontweight="bold", - path_effects=text_outline, + .style(**dotplot_style_kwargs) + .legend( + size_title=r"$-\log_{10} ~ P$", + colorbar_title=r"$log_2(molecule_1 * molecule_2 + 1)$", + **dotplot_legend_kwargs, ) + ) + if dendrogram in ["interacting_molecules", "interacting_clusters"]: + dp.add_dendrogram(size=1.6, dendrogram_key="dendrogram") + if swap_axes: + dp.swap_axes() - for ticklabels in [ax_sp.get_xticklabels(), 
ax_sp.get_yticklabels()]: - for n, id in enumerate(label.ids): - ticklabels[n].set_color(id_colors[id]) - - title_str_sp = "Spatial Connections" if title_str is None else title_str - ax_sp.set_title(title_str_sp, fontsize=title_fontsize, fontweight="bold") - - # ------------------------------ Optional Gene Expression Connections Plot- Setup ------------------------------ # - if expr_weights_matrix is not None: - if shapes_style: - polygon_list = [] - color_list = [] - - ax_expr.set_ylim(-0.55, label.num_labels - 0.45) - ax_expr.set_xlim(-0.55, label.num_labels - 0.45) - - for label_1 in range(expr_connections.shape[0]): - for label_2 in range(expr_connections.shape[1]): - if label_1 <= label_2: - for triangle in [left_triangle, right_triangle]: - center = np.array((label_1, label_2))[np.newaxis, :] - scale_factor = expr_connections[label_1, label_2] / expr_connections_max - offsets = triangle * max_scale * scale_factor - polygon_list.append(center + offsets) - - color_list += (id_colors[label.ids[label_2]], id_colors[label.ids[label_1]]) - - # Remove ticks - if reverse_expr_plot_orientation: - ax_expr.tick_params( - labelbottom=True, - labeltop=False, - labelleft=False, - labelright=True, - top=False, - bottom=False, - left=False, - ) - # Flip x- and y-axes of the expression plot: - ax_expr.invert_xaxis() - ax_expr.invert_yaxis() - else: - ax_expr.tick_params(labelbottom=False, labeltop=True, top=False, bottom=False, left=False) - ax_expr.xaxis.set_tick_params(pad=-2) - - collection = PolyCollection(polygon_list, facecolors=color_list, edgecolors="face", linewidths=0) + dp.make_figure() - ax_expr.add_collection(collection) + if dendrogram != "both": + # Remove the target part in: source | target + labs = dp.ax_dict["mainplot_ax"].get_yticklabels() if swap_axes else dp.ax_dict["mainplot_ax"].get_xticklabels() + for text in labs: + text.set_text(text.get_text().split(" | ")[1]) + if swap_axes: + dp.ax_dict["mainplot_ax"].set_yticklabels(labs) else: - # Heatmap of connection strengths - heatmap = ax_expr.imshow(expr_connections, cmap=colormap, interpolation="nearest") + dp.ax_dict["mainplot_ax"].set_xticklabels(labs) - divider = make_axes_locatable(ax_expr) - cax = divider.append_axes("right", size="5%", pad=0.1) + if alpha is not None: + yy, xx = np.where((pvals.values + alpha) >= -np.log10(alpha)) + if len(xx) and len(yy): + # for dendrogram='both', they are already re-ordered + mapper = ( + np.argsort(adata.uns["dendrogram"]["categories_idx_ordered"]) + if "dendrogram" in adata.uns + else np.arange(len(pvals)) + ) + logger.info(f"Found `{len(yy)}` significant interactions at level `{alpha}`") + ss = 0.33 * (adata.X[yy, xx] * (dp.largest_dot - dp.smallest_dot) + dp.smallest_dot) - fig.colorbar(heatmap, cax=cax) - cax.tick_params(axis="both", which="major", labelsize=6, rotation=-45) + yy = np.array([mapper[y] for y in yy]) + if swap_axes: + xx, yy = yy, xx + dp.ax_dict["mainplot_ax"].scatter( + xx + 0.5, + yy + 0.5, + color="white", + edgecolor=kwargs["dot_edge_color"], + linewidth=kwargs["dot_edge_lw"], + s=ss, + lw=0, + ) - # Change formatting if values too small - if spatial_connections_max < 0.001: - cax.yaxis.set_major_formatter(StrMethodFormatter("{x:,.1e}")) + # Save, show or return figures: + return save_return_show_fig_utils( + save_show_or_return=save_show_or_return, + # Doesn't matter what show_legend is for this plotting function + show_legend=False, + background="white", + prefix="dotplot", + save_kwargs=save_kwargs, + total_panels=1, + fig=dp.fig, + axes=dp.ax_dict, + # Return 
parameters ('return_all', 'return_all_list') are for returning multiple values for 'axes', but this function uses a single dictionary
+        return_all=False,
+        return_all_list=None,
+    )
-        # Formatting adjustments
-        ax_expr.set_aspect("equal")
-        ax_expr.set_xticks(
-            np.arange(label.num_labels),
-        )
-        if reverse_expr_plot_orientation:
-            # Despine both spatial connections & gene expression connections plots:
-            ax_sp.spines["right"].set_visible(False)
-            ax_sp.spines["top"].set_visible(False)
-            ax_sp.spines["left"].set_visible(False)
-            ax_sp.spines["bottom"].set_visible(False)
+@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE, "adata")
+def plot_connections(
+    adata: AnnData,
+    cat_key: str,
+    spatial_key: str = "spatial",
+    n_spatial_neighbors: Union[None, int] = 6,
+    spatial_weights_matrix: Union[None, scipy.sparse.csr_matrix, np.ndarray] = None,
+    expr_weights_matrix: Union[None, scipy.sparse.csr_matrix, np.ndarray] = None,
+    reverse_expr_plot_orientation: bool = True,
+    ax: Union[None, mpl.axes.Axes] = None,
+    figsize: tuple = (3, 3),
+    zero_self_connections: bool = True,
+    normalize_by_self_connections: bool = False,
+    shapes_style: bool = True,
+    label_outline: bool = False,
+    max_scale: float = 0.46,
+    colormap: Union[str, dict, "mpl.colormap"] = "Spectral",
+    title_str: Union[None, str] = None,
+    title_fontsize: Union[None, float] = None,
+    label_fontsize: Union[None, float] = None,
+    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
+    save_kwargs: Optional[dict] = {},
+):
+    """Plot spatial connections between labels - a visualization of how closely labels are colocalized.
-            ax_expr.spines["right"].set_visible(False)
-            ax_expr.spines["top"].set_visible(False)
-            ax_expr.spines["left"].set_visible(False)
-            ax_expr.spines["bottom"].set_visible(False)
+    Args:
+        adata: AnnData object
+        cat_key: Key in .obs containing categorical grouping labels. Colocalization will be assessed
+            for pairwise combinations of these labels.
+        spatial_key: Key in .obsm containing coordinates in the physical space. Not used unless
+            'spatial_weights_matrix' is None, in which case this is required. Defaults to "spatial".
+        n_spatial_neighbors: Optional, number of neighbors in the physical space for each cell. Not used unless
+            'spatial_weights_matrix' is None.
+        spatial_weights_matrix: Spatial distance matrix, weighted by distance between spots. If not given,
+            will compute at runtime.
+        expr_weights_matrix: Gene expression distance matrix, weighted by distance in transcriptomic or PCA space.
+            If not given, only the spatial distance matrix will be plotted. If given, will plot the spatial distance
+            matrix in the left plot and the gene expression distance matrix in the right plot.
+        reverse_expr_plot_orientation: If True, plot the gene expression connections in the form of a lower right
+            triangle. If False, gene expression connections will be an upper left triangle just like the spatial
+            connections.
+        ax: Existing axes object, if applicable
+        figsize: Width x height of desired figure window in inches
+        zero_self_connections: If True, ignores intra-label interactions
+        normalize_by_self_connections: Only used if 'zero_self_connections' is False. If True, normalize intra-label
+            connections by the number of spots of that label
+        shapes_style: If True plots squares, if False plots heatmap
+        label_outline: If True, gives dark outline to axis tick label text
+        max_scale: Only used for the case that 'shapes_style' is True, gives maximum size of square
+        colormap: Specifies colors to use for plotting.
If dictionary, keys should be numerical labels corresponding
+            to those of the Label object.
+        title_str: Optionally used to give plot a title
+        title_fontsize: Size of plot title - only used if 'title_str' is given.
+        label_fontsize: Size of labels along the axes of the graph
+        save_show_or_return: Whether to save, show or return the figure.
+            If "both", it will save and plot the figure at the same time. If "all", the figure will be saved, displayed
+            and the associated axes and other objects will be returned.
+        save_kwargs: A dictionary that will be passed to the save_fig function.
+            By default it is an empty dictionary and the save_fig function will use the
+            {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True,
+            "verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modifies those
+            keys according to your needs.
-        text_outline = [PathEffects.Stroke(linewidth=0.5, foreground="black", alpha=0.8)] if label_outline else None
+    Returns:
+        (fig, ax): Returns plot and axis object if 'save_show_or_return' is "all"
+    """
+    from ...plotting.static.utils import save_fig
+    from ...tools.utils import update_dict
-        # If label has categorical labels associated, use those to label the axes instead:
-        if label.str_map is not None:
-            ax_expr.set_xticklabels(
-                label.str_ids,
-                fontsize=label_fontsize,
-                fontweight="bold",
-                rotation=90,
-                path_effects=text_outline,
-            )
-        else:
-            ax_expr.set_xticklabels(
-                label.ids,
-                fontsize=label_fontsize,
-                fontweight="bold",
-                rotation=0,
-                path_effects=text_outline,
-            )
+    logger = lm.get_main_logger()
+    config_spateo_rcParams()
+    title_fontsize = rcParams.get("axes.titlesize") if title_fontsize is None else title_fontsize
+    label_fontsize = rcParams.get("axes.labelsize") if label_fontsize is None else label_fontsize
-        ax_expr.set_yticks(np.arange(label.num_labels))
-        if label.str_map is not None:
-            ax_expr.set_yticklabels(
-                label.str_ids,
-                fontsize=label_fontsize,
-                fontweight="bold",
-                path_effects=text_outline,
-            )
+    if ax is None:
+        if expr_weights_matrix is not None:
+            figsize = (figsize[0] * 2.25, figsize[1])
+            fig, axes = plt.subplots(1, 2, figsize=figsize)
+            ax_sp, ax_expr = axes[0], axes[1]
+
+            if reverse_expr_plot_orientation:
+                # Allow subplot boundaries to technically be partially overlapping (for better visual)
+                box = ax_expr.get_position()
+                box.x0 = box.x0 - 0.4
+                box.x1 = box.x1 - 0.3
+                ax_expr.set_position(box)
         else:
-            ax_expr.set_yticklabels(
-                label.ids,
-                fontsize=label_fontsize,
-                fontweight="bold",
-                path_effects=text_outline,
+            fig, ax_sp = plt.subplots(1, 1, figsize=figsize)
+    else:
+        if len(ax) > 1:
+            ax_sp, ax_expr = ax[0], ax[1]
+        else:
+            ax_sp = ax
+        fig = ax.get_figure()
+
+    # Convert cell type labels to numerical using Label object:
+    # Remove cell types with fewer than 30 cells:
+    logger.info("Filtering out cell types with fewer than 30 cells...")
+    categories_str = adata.obs[cat_key]
+    # Count occurrences for each category
+    category_counts = categories_str.value_counts()
+    # Filter categories with at least 30 occurrences
+    filtered_categories = category_counts[category_counts >= 30].index
+    # Update the AnnData object to include only the filtered categories
+    adata = adata[categories_str.isin(filtered_categories)].copy()
+    categories_str_cat = np.unique(adata.obs[cat_key].values)
+    categories_num_cat = range(len(categories_str_cat))
+    map_dict = dict(zip(categories_num_cat, categories_str_cat))
+    categories_num = 
adata.obs[cat_key].replace(categories_str_cat, categories_num_cat) + + # Update expression weights matrix if applicable to only include filtered categories: + if expr_weights_matrix is not None: + mask = categories_str.isin(filtered_categories) + indices_to_retain = np.where(mask)[0] + expr_weights_matrix = expr_weights_matrix[indices_to_retain, :][:, indices_to_retain] + + label = Label(categories_num.to_numpy(), str_map=map_dict) + + # If spatial weights matrix is not given, compute it. 'spatial_key' needs to be present in the AnnData object: + if spatial_weights_matrix is None: + if spatial_key not in adata.obsm_keys(): + logger.error( + f"Given 'spatial_key' {spatial_key} does not exist as key in adata.obsm. Options: " + f"{adata.obsm_keys()}." ) + _, adata = neighbors(adata, basis="spatial", spatial_key=spatial_key, n_neighbors=n_spatial_neighbors) + spatial_weights_matrix = adata.obsp["connectivities"] - for ticklabels in [ax_expr.get_xticklabels(), ax_expr.get_yticklabels()]: - for n, id in enumerate(label.ids): - ticklabels[n].set_color(id_colors[id]) + # Compute spatial connections array: + spatial_connections = interlabel_connections(label, spatial_weights_matrix) - title_str_expr = "Gene Expression Similarity" if title_str is None else title_str - if reverse_expr_plot_orientation: - if label_fontsize <= 8: - y = -0.35 - elif label_fontsize > 8: - y = -0.45 - else: - y = None - ax_expr.set_title(title_str_expr, fontsize=title_fontsize, fontweight="bold", y=y) + if zero_self_connections: + np.fill_diagonal(spatial_connections, 0) + elif normalize_by_self_connections: + spatial_connections /= spatial_connections.diagonal()[:, np.newaxis] - prefix = "spatial_connections" if expr_weights_matrix is None else "spatial_and_expr_connections" - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": prefix, - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - s_kwargs = update_dict(s_kwargs, save_kwargs) + spatial_connections_max = np.amax(spatial_connections) - save_fig(**s_kwargs) + # Optionally, compute gene expression connections array: + if expr_weights_matrix is not None: + expr_connections = interlabel_connections(label, expr_weights_matrix) - elif save_show_or_return in ["show", "both", "all"]: - plt.show() - elif save_show_or_return in ["return", "all"]: - if expr_weights_matrix is not None: - ax = axes - else: - ax = ax_sp - return (fig, ax) + if zero_self_connections: + np.fill_diagonal(expr_connections, 0) + elif normalize_by_self_connections: + expr_connections /= expr_connections.diagonal()[:, np.newaxis] + expr_connections_max = np.amax(expr_connections) -@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE, "adata") -def ligrec( - adata: AnnData, - dict_key: str, - source_groups: Union[None, str, List[str]] = None, - target_groups: Union[None, str, List[str]] = None, - means_range: Tuple[float, float] = (-np.inf, np.inf), - pvalue_threshold: float = 1.0, - remove_empty_interactions: bool = True, - remove_nonsig_interactions: bool = False, - dendrogram: Union[None, str] = None, - alpha: float = 0.001, - swap_axes: bool = False, - title: Union[None, str] = None, - figsize: Union[None, Tuple[float, float]] = None, - save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show", - save_kwargs: Optional[dict] = {}, - **kwargs, -): - """ - Dotplot for visualizing results of ligand-receptor interaction analysis + # Set label colors: + if isinstance(colormap, str): + cmap = 
mpl.cm.get_cmap(colormap) + else: + cmap = colormap - For each L:R pair on the dotplot, molecule 1 is sent from the cluster(s) labeled on the top of the plot (or on the - right, if 'swap_axes' is True), whereas molecule 2 is the receptor on the cluster(s) labeled on the bottom. + # If colormap is given, map label ID to points along the colormap. If dictionary is given, instead map each label + # to a color using the dictionary keys as guides. + if isinstance(cmap, dict): + if type(list(cmap.keys())[0]) == str: + id_colors = {n_id: cmap[id] for n_id, id in zip(label.ids, label.str_ids)} + else: + id_colors = {id: cmap[id] for id in label.ids} + else: + id_colors = {id: cmap(id / label.max_id) for id in label.ids} - Args: - adata: Object of :class `anndata.AnnData` - dict_key: Key in .uns to dictionary containing cell-cell communication information. Should contain keys labeled - "means" and "pvalues", with values being dataframes for the mean cell type-cell type L:R product and - significance values. - source_groups: Source interaction clusters. If `None`, select all clusters. - target_groups: Target interaction clusters. If `None`, select all clusters. - means_range: Only show interactions whose means are within this **closed** interval - pvalue_threshold: Only show interactions with p-value <= `pvalue_threshold` - remove_empty_interactions: Remove rows and columns that contain NaN values - remove_nonsig_interactions: Remove rows and columns that only contain interactions that are larger than `alpha` - dendrogram: How to cluster based on the p-values. Valid options are: - - None (no input) - do not perform clustering. - - `'interacting_molecules'` - cluster the interacting molecules. - - `'interacting_clusters'` - cluster the interacting clusters. - - `'both'` - cluster both rows and columns. Note that in this case, the dendrogram is not shown. - alpha: Significance threshold. All elements with p-values <= `alpha` will be marked by tori instead of dots. - swap_axes: Whether to show the cluster combinations as rows and the interacting pairs as columns - title: Title of the plot - figsize: The width and height of a figure - save_show_or_return: Options: "save", "show", "return", "both", "all" - - "both" for save and show - save_kwargs: A dictionary that will passed to the save_fig function. By default it is an empty dictionary - and the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', - "transparent": True, "close": True, "verbose": True} as its parameters. But to change any of these - parameters, this dictionary can be used to do so. - kwargs : - Keyword arguments for :func `style` or :func `legend` of :class `Dotplot` - """ - logger = lm.get_main_logger() + # -------------------------------- Spatial Connections Plot- Setup -------------------------------- # + if shapes_style: + # Cell types/labels will be represented using triangles: + left_triangle = np.array( + ( + (-1.0, 1.0), + # (1., 1.), + (1.0, -1.0), + (-1.0, -1.0), + ) + ) - config_spateo_rcParams() - set_pub_style() + right_triangle = np.array( + ( + (-1.0, 1.0), + (1.0, 1.0), + (1.0, -1.0), + # (-1., -1.) 
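+                # (lower-left vertex omitted: left_triangle and right_triangle together tile one square per label pair)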
+ ) + ) - if figsize is None: - figsize = rcParams.get("figure.figsize") + polygon_list = [] + color_list = [] - if title is None: - title = "Ligand-Receptor Inference" + ax_sp.set_ylim(-0.55, label.num_labels - 0.45) + ax_sp.set_xlim(-0.55, label.num_labels - 0.45) - dict = adata.uns[dict_key] + for label_1 in range(spatial_connections.shape[0]): + for label_2 in range(spatial_connections.shape[1]): - def filter_values( - pvals: pd.DataFrame, means: pd.DataFrame, *, mask: pd.DataFrame, kind: str - ) -> Tuple[pd.DataFrame, pd.DataFrame]: - mask_rows = mask.any(axis=1) - pvals = pvals.loc[mask_rows] - means = means.loc[mask_rows] + if label_1 <= label_2: - if pvals.empty: - raise ValueError(f"After removing rows with only {kind} interactions, none remain.") + for triangle in [left_triangle, right_triangle]: + center = np.array((label_1, label_2))[np.newaxis, :] + scale_factor = np.sqrt(spatial_connections[label_1, label_2] / spatial_connections_max) + offsets = triangle * max_scale * scale_factor + polygon_list.append(center + offsets) - mask_cols = mask.any(axis=0) - pvals = pvals.loc[:, mask_cols] - means = means.loc[:, mask_cols] + color_list += (id_colors[label.ids[label_2]], id_colors[label.ids[label_1]]) - if pvals.empty: - raise ValueError(f"After removing columns with only {kind} interactions, none remain.") + collection = PolyCollection(polygon_list, facecolors=color_list, edgecolors="face", linewidths=0) - return pvals, means + ax_sp.add_collection(collection) - def get_dendrogram(adata: AnnData, linkage: str = "complete") -> Mapping[str, Any]: - z_var = sch.linkage( - adata.X, - metric="correlation", - method=linkage, - # Unlikely to ever be profiling this many LR pairings, but cap at 1500 - optimal_ordering=adata.n_obs <= 1500, - ) - dendro_info = sch.dendrogram(z_var, labels=adata.obs_names.values, no_plot=True) - # this is what the DotPlot requires - return { - "linkage": z_var, - "cat_key": ["groups"], - "cor_method": "pearson", - "use_rep": None, - "linkage_method": linkage, - "categories_ordered": dendro_info["ivl"], - "categories_idx_ordered": dendro_info["leaves"], - "dendrogram_info": dendro_info, - } + # Remove ticks + ax_sp.tick_params(labelbottom=False, labeltop=True, top=False, bottom=False, left=False) + ax_sp.xaxis.set_tick_params(pad=-2) + else: + # Heatmap of connection strengths + heatmap = ax_sp.imshow(spatial_connections, cmap=colormap, interpolation="nearest") - if len(means_range) != 2: - logger.error(f"Expected `means_range` to be a sequence of size `2`, found `{len(means_range)}`.") - means_range = tuple(sorted(means_range)) + divider = make_axes_locatable(ax_sp) + cax = divider.append_axes("right", size="5%", pad=0.1) - if alpha is not None and not (0 <= alpha <= 1): - logger.error(f"Expected `alpha` to be in range `[0, 1]`, found `{alpha}`.") + fig.colorbar(heatmap, cax=cax) + cax.tick_params(axis="both", which="major", labelsize=6, rotation=-45) - if source_groups is None: - source_groups = dict["pvalues"].columns.get_level_values(0) - elif isinstance(source_groups, str): - source_groups = (source_groups,) + # Change formatting if values too small + if spatial_connections_max < 0.001: + cax.yaxis.set_major_formatter(StrMethodFormatter("{x:,.1e}")) - if target_groups is None: - target_groups = dict["pvalues"].columns.get_level_values(1) - if isinstance(target_groups, str): - target_groups = (target_groups,) + # Formatting adjustments + ax_sp.set_aspect("equal") - # Get specified source and target groups from the dictionary: - pvals: pd.DataFrame = 
dict["pvalues"].loc[:, (source_groups, target_groups)] - means: pd.DataFrame = dict["means"].loc[:, (source_groups, target_groups)] + ax_sp.set_xticks( + np.arange(label.num_labels), + ) + text_outline = [PathEffects.Stroke(linewidth=0.5, foreground="black", alpha=0.8)] if label_outline else None - if pvals.empty: - raise ValueError("No valid clusters have been selected.") + # If label has categorical labels associated, use those to label the axes instead: + if label.str_map is not None: + ax_sp.set_xticklabels( + label.str_ids, + fontsize=label_fontsize, + fontweight="bold", + rotation=90, + path_effects=text_outline, + ) + else: + ax_sp.set_xticklabels( + label.ids, + fontsize=label_fontsize, + fontweight="bold", + rotation=0, + path_effects=text_outline, + ) + + ax_sp.set_yticks(np.arange(label.num_labels)) + if label.str_map is not None: + ax_sp.set_yticklabels( + label.str_ids, + fontsize=label_fontsize, + fontweight="bold", + path_effects=text_outline, + ) + else: + ax_sp.set_yticklabels( + label.ids, + fontsize=label_fontsize, + fontweight="bold", + path_effects=text_outline, + ) + + for ticklabels in [ax_sp.get_xticklabels(), ax_sp.get_yticklabels()]: + for n, id in enumerate(label.ids): + ticklabels[n].set_color(id_colors[id]) - means = means[(means >= means_range[0]) & (means <= means_range[1])] - pvals = pvals[pvals <= pvalue_threshold] + title_str_sp = "Spatial Connections" if title_str is None else title_str + ax_sp.set_title(title_str_sp, fontsize=title_fontsize, fontweight="bold") - if remove_empty_interactions: - pvals, means = filter_values(pvals, means, mask=~(pd.isnull(means) | pd.isnull(pvals)), kind="NaN") - if remove_nonsig_interactions and alpha is not None: - pvals, means = filter_values(pvals, means, mask=pvals <= alpha, kind="non-significant") + # ------------------------------ Optional Gene Expression Connections Plot- Setup ------------------------------ # + if expr_weights_matrix is not None: + if shapes_style: + polygon_list = [] + color_list = [] - start, label_ranges = 0, {} + ax_expr.set_ylim(-0.55, label.num_labels - 0.45) + ax_expr.set_xlim(-0.55, label.num_labels - 0.45) - if dendrogram == "interacting_clusters": - # Set rows to be cluster combinations, not LR pairs: - pvals = pvals.T - means = means.T + for label_1 in range(expr_connections.shape[0]): + for label_2 in range(expr_connections.shape[1]): - for cls, size in (pvals.groupby(level=0, axis=1)).size().to_dict().items(): - label_ranges[cls] = (start, start + size - 1) - start += size - label_ranges = {k: label_ranges[k] for k in sorted(label_ranges.keys())} + if label_1 <= label_2: + for triangle in [left_triangle, right_triangle]: + center = np.array((label_1, label_2))[np.newaxis, :] + scale_factor = np.sqrt(expr_connections[label_1, label_2] / expr_connections_max) + offsets = triangle * max_scale * scale_factor + polygon_list.append(center + offsets) - pvals = pvals[label_ranges.keys()].astype("float") - # Add minimum value to p-values to avoid value error- 3.0 will be the largest possible value: - pvals = -np.log10(pvals + min(1e-3, alpha if alpha is not None else 1e-3)).fillna(0) + color_list += (id_colors[label.ids[label_2]], id_colors[label.ids[label_1]]) - pvals.columns = map(" | ".join, pvals.columns.to_flat_index()) - pvals.index = map(" | ".join, pvals.index.to_flat_index()) + # Remove ticks + if reverse_expr_plot_orientation: + ax_expr.tick_params( + labelbottom=True, + labeltop=False, + labelleft=False, + labelright=True, + top=False, + bottom=False, + left=False, + ) + # Flip x- 
and y-axes of the expression plot: + ax_expr.invert_xaxis() + ax_expr.invert_yaxis() + else: + ax_expr.tick_params(labelbottom=False, labeltop=True, top=False, bottom=False, left=False) + ax_expr.xaxis.set_tick_params(pad=-2) - means = means[label_ranges.keys()].fillna(0) - means.columns = map(" | ".join, means.columns.to_flat_index()) - means.index = map(" | ".join, means.index.to_flat_index()) - means = np.log2(means + 1) + collection = PolyCollection(polygon_list, facecolors=color_list, edgecolors="face", linewidths=0) - var = pd.DataFrame(pvals.columns) - var = var.set_index(var.columns[0]) + ax_expr.add_collection(collection) + else: + # Heatmap of connection strengths + heatmap = ax_expr.imshow(expr_connections, cmap=colormap, interpolation="nearest") - # Instantiate new AnnData object containing plot values: - adata = AnnData(pvals.values, obs={"groups": pd.Categorical(pvals.index)}, var=var, dtype=pvals.values.dtype) - adata.obs_names = pvals.index - minn = np.nanmin(adata.X) - delta = np.nanmax(adata.X) - minn - adata.X = (adata.X - minn) / delta - # To satisfy conditional check that happens on instantiating dotplot: - adata.uns["__type"] = "UMI" + divider = make_axes_locatable(ax_expr) + cax = divider.append_axes("right", size="5%", pad=0.1) - try: - if dendrogram == "both": - row_order, col_order, _, _ = _dendrogram_sig( - adata.X, method="complete", metric="correlation", optimal_ordering=adata.n_obs <= 1500 - ) - adata = adata[row_order, :][:, col_order] - pvals = pvals.iloc[row_order, :].iloc[:, col_order] - means = means.iloc[row_order, :].iloc[:, col_order] - elif dendrogram is not None: - adata.uns["dendrogram"] = get_dendrogram(adata) - except Exception as e: - logger.warning(f"Unable to create a dendrogram. Reason: `{e}`. Will display without one.") - dendrogram = None + fig.colorbar(heatmap, cax=cax) + cax.tick_params(axis="both", which="major", labelsize=6, rotation=-45) - kwargs["dot_edge_lw"] = 0 - kwargs.setdefault("cmap", "magma") - kwargs.setdefault("grid", True) - kwargs.pop("color_on", None) + # Change formatting if values too small + if spatial_connections_max < 0.001: + cax.yaxis.set_major_formatter(StrMethodFormatter("{x:,.1e}")) - # Set style and legend kwargs: - dotplot_style_params = {k for k in signature(CCDotplot.style).parameters.keys()} - dotplot_style_kwargs = {k: v for k, v in kwargs.items() if k in dotplot_style_params} - dotplot_legend_params = {k for k in signature(CCDotplot.legend).parameters.keys()} - dotplot_legend_kwargs = {k: v for k, v in kwargs.items() if k in dotplot_legend_params} + # Formatting adjustments + ax_expr.set_facecolor("none") + ax_expr.set_aspect("equal") - dp = ( - CCDotplot( - delta=delta, - minn=minn, - alpha=alpha, - adata=adata, - var_names=adata.var_names, - cat_key="groups", - dot_color_df=means, - dot_size_df=pvals, - title=title, - var_group_labels=None if dendrogram == "both" else list(label_ranges.keys()), - var_group_positions=None if dendrogram == "both" else list(label_ranges.values()), - standard_scale=None, - figsize=figsize, - ) - .style(**dotplot_style_kwargs) - .legend( - size_title=r"$-\log_{10} ~ P$", - colorbar_title=r"$log_2(molecule_1 * molecule_2 + 1)$", - **dotplot_legend_kwargs, + ax_expr.set_xticks( + np.arange(label.num_labels), ) - ) - if dendrogram in ["interacting_molecules", "interacting_clusters"]: - dp.add_dendrogram(size=1.6, dendrogram_key="dendrogram") - if swap_axes: - dp.swap_axes() + if reverse_expr_plot_orientation: + # Despine both spatial connections & gene expression connections 
plots: + ax_sp.spines["right"].set_visible(False) + ax_sp.spines["top"].set_visible(False) + ax_sp.spines["left"].set_visible(False) + ax_sp.spines["bottom"].set_visible(False) - dp.make_figure() + ax_expr.spines["right"].set_visible(False) + ax_expr.spines["top"].set_visible(False) + ax_expr.spines["left"].set_visible(False) + ax_expr.spines["bottom"].set_visible(False) - if dendrogram != "both": - # Remove the target part in: source | target - labs = dp.ax_dict["mainplot_ax"].get_yticklabels() if swap_axes else dp.ax_dict["mainplot_ax"].get_xticklabels() - for text in labs: - text.set_text(text.get_text().split(" | ")[1]) - if swap_axes: - dp.ax_dict["mainplot_ax"].set_yticklabels(labs) - else: - dp.ax_dict["mainplot_ax"].set_xticklabels(labs) + text_outline = [PathEffects.Stroke(linewidth=0.5, foreground="black", alpha=0.8)] if label_outline else None - if alpha is not None: - yy, xx = np.where((pvals.values + alpha) >= -np.log10(alpha)) - if len(xx) and len(yy): - # for dendrogram='both', they are already re-ordered - mapper = ( - np.argsort(adata.uns["dendrogram"]["categories_idx_ordered"]) - if "dendrogram" in adata.uns - else np.arange(len(pvals)) + # If label has categorical labels associated, use those to label the axes instead: + if label.str_map is not None: + ax_expr.set_xticklabels( + label.str_ids, + fontsize=label_fontsize, + fontweight="bold", + rotation=90, + path_effects=text_outline, + ) + else: + ax_expr.set_xticklabels( + label.ids, + fontsize=label_fontsize, + fontweight="bold", + rotation=0, + path_effects=text_outline, ) - logger.info(f"Found `{len(yy)}` significant interactions at level `{alpha}`") - ss = 0.33 * (adata.X[yy, xx] * (dp.largest_dot - dp.smallest_dot) + dp.smallest_dot) - yy = np.array([mapper[y] for y in yy]) - if swap_axes: - xx, yy = yy, xx - dp.ax_dict["mainplot_ax"].scatter( - xx + 0.5, - yy + 0.5, - color="white", - edgecolor=kwargs["dot_edge_color"], - linewidth=kwargs["dot_edge_lw"], - s=ss, - lw=0, + ax_expr.set_yticks(np.arange(label.num_labels)) + if label.str_map is not None: + ax_expr.set_yticklabels( + label.str_ids, + fontsize=label_fontsize, + fontweight="bold", + path_effects=text_outline, + ) + else: + ax_expr.set_yticklabels( + label.ids, + fontsize=label_fontsize, + fontweight="bold", + path_effects=text_outline, ) - # Save, show or return figures: - return save_return_show_fig_utils( - save_show_or_return=save_show_or_return, - # Doesn't matter what show_legend is for this plotting function - show_legend=False, - background="white", - prefix="dotplot", - save_kwargs=save_kwargs, - total_panels=1, - fig=dp.fig, - axes=dp.ax_dict, - # Return all parameters are for returning multiple values for 'axes', but this function uses a single dictionary - return_all=False, - return_all_list=None, - ) + for ticklabels in [ax_expr.get_xticklabels(), ax_expr.get_yticklabels()]: + for n, id in enumerate(label.ids): + ticklabels[n].set_color(id_colors[id]) + + title_str_expr = "Gene Expression Similarity" if title_str is None else title_str + if reverse_expr_plot_orientation: + if label_fontsize <= 8: + y = -0.3 + elif label_fontsize > 8: + y = -0.35 + else: + y = None + ax_expr.set_title(title_str_expr, fontsize=title_fontsize, fontweight="bold", y=y) + + prefix = "spatial_connections" if expr_weights_matrix is None else "spatial_and_expr_connections" + if save_show_or_return in ["save", "both", "all"]: + s_kwargs = { + "path": None, + "prefix": prefix, + "dpi": None, + "ext": "pdf", + "transparent": True, + "close": True, + "verbose": True, + } 
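+        # Overwrite these defaults with any user-supplied save_kwargs: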
+        s_kwargs = update_dict(s_kwargs, save_kwargs)
+
+        save_fig(**s_kwargs)
+
+    if save_show_or_return in ["show", "both", "all"]:
+        plt.show()
+    if save_show_or_return in ["return", "all"]:
+        if expr_weights_matrix is not None:
+            ax = axes
+        else:
+            ax = ax_sp
+        return (fig, ax)
diff --git a/spateo/tools/ST_regression/__init__.py b/spateo/tools/ST_regression/__init__.py
deleted file mode 100644
index 9643c2ef..00000000
--- a/spateo/tools/ST_regression/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-"""
-General and generalized linear modeling of spatial transcriptomics
-
-Option to call functions from ST_regression (e.g. st.tl.ST_regression.Niche_Model) or directly from Spateo (e.g.
-st.tl.Niche_Model).
-"""
-from .spatial_regression import (
-    Category_Model,
-    Lagged_Model,
-    Niche_LR_Model,
-    Niche_Model,
-)
diff --git a/spateo/tools/ST_regression/generalized_lm.py b/spateo/tools/ST_regression/generalized_lm.py
deleted file mode 100644
index 98bdf0d8..00000000
--- a/spateo/tools/ST_regression/generalized_lm.py
+++ /dev/null
@@ -1,995 +0,0 @@
-"""
-Generalized linear model regression for spatially-aware regression of spatial transcriptomic (gene expression) data.
-Rather than assuming the response variable necessarily follows the normal distribution, instead allows the
-specification of models whose response variable follows different distributions (e.g. Poisson or Gamma),
-although allows also for normal (Gaussian) modeling.
-Additionally features capability to perform elastic net regularized regression.
-"""
-import time
-from typing import List, Tuple, Union
-
-try:
-    from typing import Literal
-except ImportError:
-    from typing_extensions import Literal
-
-import numpy as np
-import pandas as pd
-from anndata import AnnData
-from scipy.sparse import diags, issparse
-from scipy.special import expit, loggamma
-from sklearn.base import BaseEstimator
-from sklearn.model_selection import GridSearchCV
-from sklearn.utils.validation import check_array, check_is_fitted, check_X_y
-
-from ...configuration import SKM
-from ...logging import logger_manager as lm
-from ...preprocessing.normalize import normalize_total
-from ...preprocessing.transform import log1p
-from ...tools.find_neighbors import transcriptomic_connectivity
-from .regression_utils import L1_L2_penalty, softplus
-
-
-# ---------------------------------------------------------------------------------------------------
-# Intermediates for generalized linear modeling
-# ---------------------------------------------------------------------------------------------------
-def calc_z(beta0: float, beta: np.ndarray, X: np.ndarray, fit_intercept: bool) -> np.ndarray:
-    """Computes z, an intermediate comprising the result of a linear regression, just before non-linearity is applied.
-
-    Args:
-        beta0: The intercept
-        beta: Array of shape [n_features,]; learned model coefficients
-        X: Array of shape [n_samples, n_features]; input data
-        fit_intercept: Specifies if a constant (a.k.a. bias or intercept) should be added to the decision function.
-            Defaults to True.
-
-    Returns:
-        z: Array of shape [n_samples, n_features]; prediction of the target values
-    """
-    if fit_intercept:
-        z = beta0 + np.dot(X, beta)
-    else:
-        z = np.dot(X, np.r_[beta0, beta])
-    return z
-
-
-def apply_nonlinear(
-    distr: Literal["gaussian", "poisson", "softplus", "neg-binomial", "gamma"],
-    z: np.ndarray,
-    eta: float,
-    fit_intercept: bool,
-) -> np.ndarray:
-    """Applies nonlinear operation to linear estimation.
- - Args: - distr: Distribution family- can be "gaussian", "poisson", "softplus", "neg-binomial", or "gamma" - z: Array of shape [n_samples, n_features]; prediction of the target values - eta: A threshold parameter that linearizes the exp() function above threshold eta - fit_intercept: Specifies if a constant (a.k.a. bias or intercept) should be added to the decision function. - Defaults to True. - - Returns: - nl: An array of size [n_samples, n_features]; result following application of the nonlinear layer - """ - logger = lm.get_main_logger() - - if distr in ["softplus", "gamma", "neg-binomial"]: - nl = softplus(z) - elif distr == "poisson": - nl = z.copy() - beta0 = (1 - eta) * np.exp(eta) if fit_intercept else 0.0 - nl[z > eta] = z[z > eta] * np.exp(eta) + beta0 - nl[z <= eta] = np.exp(z[z <= eta]) - elif distr == "gaussian": - nl = z - - return nl - - -def nonlinear_gradient( - distr: Literal["gaussian", "poisson", "softplus", "neg-binomial", "gamma"], z: np.ndarray, eta: float -): - """Derivative of the non-linearity. - - Args: - distr: Distribution family- can be "gaussian", "poisson", "softplus", "neg-binomial", or "gamma" - z: Array of shape [n_samples, n_features]; prediction of the target values - eta: A threshold parameter that linearizes the exp() function above threshold eta - - Returns: - nl_grad: Array of size [n_samples, n_features]; first derivative of each parameter estimate - """ - logger = lm.get_main_logger() - - if distr in ["softplus", "gamma", "neg-binomial"]: - nl_grad = expit(z) - elif distr == "poisson": - nl_grad = z.copy() - nl_grad[z > eta] = np.ones_like(z)[z > eta] * np.exp(eta) - nl_grad[z <= eta] = np.exp(z[z <= eta]) - elif distr == "gaussian": - nl_grad = np.ones_like(z) - - return nl_grad - - -# --------------------------------------------------------------------------------------------------- -# Gradient -# --------------------------------------------------------------------------------------------------- -def batch_grad( - distr: Literal["gaussian", "poisson", "softplus", "neg-binomial", "gamma"], - alpha: float, - reg_lambda: float, - X: np.ndarray, - y: np.ndarray, - beta: np.ndarray, - Tau: Union[None, np.ndarray] = None, - eta: float = 2.0, - theta: float = 1.0, - fit_intercept: bool = True, -) -> np.ndarray: - """Computes the gradient (for parameter updating) via batch gradient descent - - Args: - distr: Distribution family- can be "gaussian", "softplus", "poisson", "neg-binomial", or "gamma". Case - sensitive. - alpha: The weighting between L1 penalty (alpha=1.) and L2 penalty (alpha=0.) - term of the loss function - reg_lambda: Regularization parameter :math:`\\lambda` of penalty term - X: Array of shape [n_samples, n_features]; input data - y: Array of shape [n_samples, 1]; labels or targets for the data - beta: Array of shape [n_features,]; learned model coefficients - Tau: optional array of shape [n_features, n_features]; the Tikhonov matrix for ridge regression. If not - provided, Tau will default to the identity matrix. - eta: A threshold parameter that linearizes the exp() function above threshold eta - theta: Shape parameter of the negative binomial distribution (number of successes before the first failure). - Used only if 'distr' is "neg-binomial" - fit_intercept: Specifies if a constant (a.k.a. bias or intercept) should be added to the decision function. - Defaults to True. 
-# ---------------------------------------------------------------------------------------------------
-# Gradient
-# ---------------------------------------------------------------------------------------------------
-def batch_grad(
- distr: Literal["gaussian", "poisson", "softplus", "neg-binomial", "gamma"],
- alpha: float,
- reg_lambda: float,
- X: np.ndarray,
- y: np.ndarray,
- beta: np.ndarray,
- Tau: Union[None, np.ndarray] = None,
- eta: float = 2.0,
- theta: float = 1.0,
- fit_intercept: bool = True,
-) -> np.ndarray:
- """Computes the gradient (for parameter updating) via batch gradient descent
-
- Args:
- distr: Distribution family- can be "gaussian", "softplus", "poisson", "neg-binomial", or "gamma". Case
- sensitive.
- alpha: The weighting between L1 penalty (alpha=1.) and L2 penalty (alpha=0.) term of the loss function
- reg_lambda: Regularization parameter :math:`\\lambda` of penalty term
- X: Array of shape [n_samples, n_features]; input data
- y: Array of shape [n_samples, 1]; labels or targets for the data
- beta: Array of shape [n_features,]; learned model coefficients
- Tau: optional array of shape [n_features, n_features]; the Tikhonov matrix for ridge regression. If not
- provided, Tau will default to the identity matrix.
- eta: A threshold parameter that linearizes the exp() function above threshold eta
- theta: Shape parameter of the negative binomial distribution (number of successes before the first failure).
- Used only if 'distr' is "neg-binomial"
- fit_intercept: Specifies if a constant (a.k.a. bias or intercept) should be added to the decision function.
- Defaults to True.
-
- Returns:
- g: Gradient for each parameter
- """
- n_samples, n_features = X.shape
- n_samples = float(n_samples)
-
- if Tau is None:
- if fit_intercept:
- Tau = np.eye(beta[1:].shape[0])
- else:
- Tau = np.eye(beta.shape[0])
- InvCov = np.dot(Tau.T, Tau)
-
- # Compute linear intermediate, nonlinearity, first derivative of the nonlinearity
- z = calc_z(beta[0], beta[1:], X, fit_intercept)
- nl = apply_nonlinear(distr, z, eta, fit_intercept)
- nl_grad = nonlinear_gradient(distr, z, eta)
-
- # Initialize gradient:
- grad_beta0 = 0.0
-
- if distr in ["poisson", "softplus"]:
- if fit_intercept:
- grad_beta0 = np.sum(nl_grad) - np.sum(y * nl_grad / nl)
- grad_beta = (np.dot(nl_grad.T, X) - np.dot((y * nl_grad / nl).T, X)).T
-
- elif distr == "gamma":
- # Degrees of freedom (one because the parameter array is 1D)
- nu = 1.0
- grad_logl = (y / nl**2 - 1 / nl) * nl_grad
- if fit_intercept:
- grad_beta0 = -nu * np.sum(grad_logl)
- grad_beta = -nu * np.dot(grad_logl.T, X).T
-
- elif distr == "neg-binomial":
- partial_beta_0 = nl_grad * ((theta + y) / (nl + theta) - y / nl)
- if fit_intercept:
- grad_beta0 = np.sum(partial_beta_0)
- grad_beta = np.dot(partial_beta_0.T, X)
-
- elif distr == "gaussian":
- if fit_intercept:
- grad_beta0 = np.sum((nl - y) * nl_grad)
- grad_beta = np.dot((nl - y).T, X * nl_grad[:, None]).T
-
- grad_beta0 *= 1.0 / n_samples
- grad_beta *= 1.0 / n_samples
- if fit_intercept:
- grad_beta += reg_lambda * (1 - alpha) * np.dot(InvCov, beta[1:])  # + reg_lambda * alpha * np.sign(beta[1:])
- g = np.zeros((n_features + 1,))
- g[0] = grad_beta0
- g[1:] = grad_beta
- else:
- grad_beta += reg_lambda * (1 - alpha) * np.dot(InvCov, beta)  # + reg_lambda * alpha * np.sign(beta)
- g = grad_beta
-
- return g
-
-
-# ---------------------------------------------------------------------------------------------------
-# Objective function
-# ---------------------------------------------------------------------------------------------------
-def log_likelihood(
- distr: Literal["gaussian", "poisson", "softplus", "neg-binomial", "gamma"],
- y: np.ndarray,
- y_hat: Union[np.ndarray, float],
- theta: float = 1.0,
-) -> float:
- """Computes the log-likelihood, based on observed values and predictions from the regression.
-
- Args:
- distr: Distribution family- can be "gaussian", "poisson", "softplus", "neg-binomial", or "gamma". Case
- sensitive.
- y: Target values
- y_hat: Predicted values, either array of predictions or scalar value
- theta: Shape parameter of the negative binomial distribution (number of successes before the first failure).
- Used only if 'distr' is "neg-binomial"
-
- Returns:
- logL: Numerical value for the log-likelihood
- """
- if distr in ["poisson", "softplus"]:
- eps = np.spacing(1)
- logL = np.sum(y * np.log(y_hat + eps) - y_hat)
-
- elif distr == "gamma":
- nu = 1.0  # shape parameter, exponential for now
- logL = np.sum(nu * (-y / y_hat - np.log(y_hat)))
-
- elif distr == "neg-binomial":
- logL = np.sum(
- loggamma(y + theta)
- - loggamma(theta)
- - loggamma(y + 1)
- + theta * np.log(theta)
- + y * np.log(y_hat)
- - (theta + y) * np.log(y_hat + theta)
- )
-
- elif distr == "gaussian":
- logL = -0.5 * np.sum((y - y_hat) ** 2)
-
- return logL
-
-
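`_loss` below combines the negated, sample-averaged log-likelihood with the elastic-net penalty. A hand-rolled sketch for the Poisson case with `Tau=None` (identity Tikhonov matrix); all numbers are made up:

```python
import numpy as np

y = np.array([0.0, 1.0, 3.0, 2.0])      # observed counts (toy data)
y_hat = np.array([0.5, 1.2, 2.5, 2.0])  # model predictions (toy data)
beta = np.array([0.8, 0.0, -0.3])       # coefficients, intercept excluded
alpha, reg_lambda = 0.5, 0.1

eps = np.spacing(1)
logL = np.sum(y * np.log(y_hat + eps) - y_hat)  # Poisson branch above

# Elastic net: 0.5*(1-alpha)*||beta||_2^2 + alpha*||beta||_1
penalty = 0.5 * (1 - alpha) * np.sum(beta**2) + alpha * np.sum(np.abs(beta))
loss = -logL / y.size + reg_lambda * penalty
print(loss)
```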
-def _loss(
- distr: Literal["gaussian", "poisson", "softplus", "neg-binomial", "gamma"],
- alpha: float,
- reg_lambda: float,
- X: np.ndarray,
- y: np.ndarray,
- beta: np.ndarray,
- Tau: Union[None, np.ndarray] = None,
- eta: float = 2.0,
- theta: float = 1.0,
- fit_intercept: bool = True,
-) -> float:
- """Objective function, comprised of a combination of the log-likelihood and regularization losses.
-
- Args:
- distr: Distribution family- can be "gaussian", "poisson", "softplus", "neg-binomial", or "gamma". Case
- sensitive.
- alpha: The weighting between L1 penalty (alpha=1.) and L2 penalty (alpha=0.) term of the loss function
- reg_lambda: Regularization parameter :math:`\\lambda` of penalty term
- X: Array of shape [n_samples, n_features]; input data
- y: Array of shape [n_samples, 1]; labels or targets for the data
- beta: Array of shape [n_features,]; learned model coefficients
- Tau: optional array of shape [n_features, n_features]; the Tikhonov matrix for ridge regression. If not
- provided, Tau will default to the identity matrix.
- eta: A threshold parameter that linearizes the exp() function above threshold eta
- theta: Shape parameter of the negative binomial distribution (number of successes before the first failure).
- Used only if 'distr' is "neg-binomial"
- fit_intercept: Specifies if a constant (a.k.a. bias or intercept) should be added to the decision function.
- Defaults to True.
-
- Returns:
- loss: Numerical value for loss
- """
- n_samples, n_features = X.shape
- z = calc_z(beta[0], beta[1:], X, fit_intercept)
- y_hat = apply_nonlinear(distr, z, eta, fit_intercept)
- ll = 1.0 / n_samples * log_likelihood(distr, y, y_hat, theta)
-
- if fit_intercept:
- P = L1_L2_penalty(alpha, beta[1:], Tau)
- # P = 0.5 * (1 - alpha) * L2_penalty(beta[1:], Tau)
- else:
- P = L1_L2_penalty(alpha, beta, Tau)
- # P = 0.5 * (1 - alpha) * L2_penalty(beta, Tau)
-
- loss = -ll + reg_lambda * P
- return loss
-
-
-# ---------------------------------------------------------------------------------------------------
-# Custom metric
-# ---------------------------------------------------------------------------------------------------
-def pseudo_r2(
- y: np.ndarray,
- yhat: np.ndarray,
- ynull_: float,
- distr: Literal["gaussian", "poisson", "softplus", "neg-binomial", "gamma"],
- theta: float,
-):
- """Compute r^2 using log-likelihood, taking into account the observed and predicted distributions as well as the
- observed and predicted values.
-
- Args:
- y: Array of shape [n_samples,]; target values for regression
- yhat: Predicted targets of shape [n_samples,]
- ynull_: Mean of the target labels (null model prediction)
- distr: Distribution family- can be "gaussian", "poisson", "softplus", "neg-binomial", or "gamma". Case
- sensitive.
- theta: Shape parameter of the negative binomial distribution (number of successes before the first
- failure). It is used only if 'distr' is equal to "neg-binomial", otherwise it is ignored.
-
- Returns:
- score: Pseudo-R^2 value, computed from the ratio of log-likelihoods of the fitted and null models
- """
- if distr in ["poisson", "neg-binomial", "softplus"]:
- LS = log_likelihood(distr, y, y, theta=theta)
- else:
- LS = 0
-
- L0 = log_likelihood(distr, y, ynull_, theta=theta)
- L1 = log_likelihood(distr, y, yhat, theta=theta)
-
- if distr in ["poisson", "neg-binomial", "softplus"]:
- score = 1 - (LS - L1) / (LS - L0)
- else:
- score = 1 - L1 / L0
- return score
-
-
-def deviance(
- y: np.ndarray,
- yhat: np.ndarray,
- distr: Literal["gaussian", "poisson", "softplus", "neg-binomial", "gamma"],
- theta: float,
-):
- """Deviance goodness-of-fit
-
- Args:
- y: Array of shape [n_samples,]; target values for regression
- yhat: Predicted targets of shape [n_samples,]
- distr: Distribution family- can be "gaussian", "poisson", "softplus", "neg-binomial", or "gamma". Case
- sensitive.
- theta: Shape parameter of the negative binomial distribution (number of successes before the first
- failure).
It is used only if 'distr' is equal to "neg-binomial", otherwise it is ignored.
-
- Returns:
- score: Deviance of the predicted labels
- """
- if distr in ["poisson", "neg-binomial", "softplus"]:
- LS = log_likelihood(distr, y, y, theta=theta)
- else:
- LS = 0
-
- L1 = log_likelihood(distr, y, yhat, theta=theta)
- score = -2 * (L1 - LS)
- return score
-
-
-# ---------------------------------------------------------------------------------------------------
-# Generalized linear modeling master class
-# ---------------------------------------------------------------------------------------------------
-class GLM(BaseEstimator):
- """Fitting generalized linear models (Gaussian, Poisson, negative binomial, gamma) for modeling gene expression.
-
- NOTES: 'Tau' is the Tikhonov matrix (a square factorization of the inverse covariance matrix), used to set the
- degree to which the algorithm tends towards solutions with smaller norms. If not given, defaults to the ridge (
- L2) penalty.
-
- Args:
- distr: Distribution family- can be "gaussian", "poisson", "softplus", "neg-binomial", or "gamma". Case
- sensitive.
- alpha: The weighting between L1 penalty (alpha=1.) and L2 penalty (alpha=0.) term of the loss function
- Tau: optional array of shape [n_features, n_features]; the Tikhonov matrix for ridge regression. If not
- provided, Tau will default to the identity matrix.
- reg_lambda: Regularization parameter :math:`\\lambda` of penalty term
- learning_rate: Governs the magnitude of parameter updates for the gradient descent algorithm
- max_iter: Maximum number of iterations for the solver
- tol: Convergence threshold or stopping criteria. Optimization loop will stop when relative change in
- parameter norm is below the threshold.
- eta: A threshold parameter that linearizes the exp() function above eta.
- clip_coeffs: Coefficients of lower absolute value than this threshold are set to zero.
- score_metric: Scoring metric. Options:
- - "deviance": Uses the difference between the saturated (perfectly predictive) model and the true model.
- - "pseudo_r2": Uses the coefficient of determination b/w the true and predicted values.
- fit_intercept: Specifies if a constant (a.k.a. bias or intercept) should be added to the decision function
- random_seed: Seed of the random number generator used to initialize the solution. Default: 888
- theta: Shape parameter of the negative binomial distribution (number of successes before the first
- failure). It is used only if 'distr' is equal to "neg-binomial", otherwise it is ignored.
- verbose: If True, will display information about number of iterations until convergence. Defaults to True.
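A minimal usage sketch of this estimator (the import path below is the pre-removal location being deleted in this diff, and the data are synthetic, so treat this as illustrative only):

```python
import numpy as np
from spateo.tools.ST_regression.generalized_lm import GLM  # pre-removal path

rng = np.random.default_rng(888)
X = rng.normal(size=(200, 5))
true_beta = np.array([0.4, -0.2, 0.1, 0.0, 0.3])
y = rng.poisson(lam=np.exp(X @ true_beta))       # synthetic Poisson response

model = GLM(distr="poisson", alpha=0.5, reg_lambda=0.1, score_metric="pseudo_r2")
model.fit(X, y)
print(model.beta0_, model.beta_)   # fitted intercept and coefficients
print(model.score(X, y))           # pseudo-R^2 on the training data
```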
-
- Attributes:
- beta0_: The intercept
- beta_: Learned parameters
- n_iter_: Number of iterations
- """
-
- def __init__(
- self,
- distr: Literal["gaussian", "poisson", "softplus", "neg-binomial", "gamma"] = "poisson",
- alpha: float = 0.5,
- Tau: Union[None, np.ndarray] = None,
- reg_lambda: float = 0.1,
- learning_rate: float = 0.2,
- max_iter: int = 1000,
- tol: float = 1e-6,
- eta: float = 2.0,
- clip_coeffs: float = 0.01,
- score_metric: Literal["deviance", "pseudo_r2"] = "deviance",
- fit_intercept: bool = True,
- random_seed: int = 888,
- theta: float = 1.0,
- verbose: bool = True,
- ):
- self.logger = lm.get_main_logger()
- allowable_dists = ["gaussian", "poisson", "softplus", "neg-binomial", "gamma"]
- if distr not in allowable_dists:
- self.logger.error(f"'distr' must be one of {', '.join(allowable_dists)}, got {distr}.")
- if not isinstance(max_iter, int):
- self.logger.error("'max_iter' must be an integer.")
- if not isinstance(fit_intercept, bool):
- self.logger.error(f"'fit_intercept' must be Boolean, got {type(fit_intercept)}")
-
- self.distr = distr
- self.alpha = alpha
- self.reg_lambda = reg_lambda
- self.Tau = Tau
- self.learning_rate = learning_rate
- self.max_iter = max_iter
- self.tol = tol
- self.eta = eta
- self.clip_coeffs = clip_coeffs
- self.score_metric = score_metric
- self.fit_intercept = fit_intercept
- # Seed into instance of np.random.RandomState
- self.random_state = np.random.RandomState(random_seed)
- self.theta = theta
- self.verbose = verbose
-
- def __repr__(self):
- reg_lambda = self.reg_lambda
- s = "<\nDistribution | %s" % self.distr
- s += "\nalpha | %0.2f" % self.alpha
- s += "\nmax_iter | %0.2f" % self.max_iter
- s += "\nlambda: %0.2f\n>" % reg_lambda
- return s
-
- def _prox(self, beta: np.ndarray, thresh: float) -> np.ndarray:
- """Proximal operator (soft-thresholding), handling the L1 portion of the penalty."""
- return np.sign(beta) * (np.abs(beta) - thresh) * (np.abs(beta) > thresh)
-
- def fit(self, X: np.ndarray, y: np.ndarray):
- """The fit function.
-
- Args:
- X: 2D array of shape [n_samples, n_features]; input data
- y: 1D array of shape [n_samples,]; target data
-
- Returns:
- self: Fitted instance of class GLM
- """
- X, y = check_X_y(X, y, accept_sparse=False)
-
- self.beta0_ = None
- self.beta_ = None
- self.ynull_ = None
- self.n_iter_ = 0
-
- if not (isinstance(X, np.ndarray) and isinstance(y, np.ndarray)):
- self.logger.error("Input must be ndarray. Got {} and {}".format(type(X), type(y)))
-
- if X.ndim != 2:
- self.logger.error(f"X must be a 2D array, got {X.ndim}")
-
- if y.ndim != 1:
- self.logger.error(f"y must be 1D, got {y.ndim}")
-
- n_observations, n_features = X.shape
-
- if n_observations != len(y):
- self.logger.error("Shape mismatch."
+ "X has {} observations, y has {}.".format(n_observations, len(y))) - - # Initialize parameters - beta = np.zeros((n_features + int(self.fit_intercept),)) - if self.fit_intercept: - if self.beta0_ is None and self.beta_ is None: - beta[0] = 1 / (n_features + 1) * self.random_state.normal(0.0, 1.0, 1) - beta[1:] = 1 / (n_features + 1) * self.random_state.normal(0.0, 1.0, (n_features,)) - else: - beta[0] = self.beta0_ - beta[1:] = self.beta_ - - tol = self.tol - alpha = self.alpha - reg_lambda = self.reg_lambda - - self._convergence = list() - train_iterations = range(self.max_iter) - - # Iterative updates - for t in train_iterations: - self.n_iter_ += 1 - beta_old = beta.copy() - grad = batch_grad( - self.distr, alpha, reg_lambda, X, y, beta, self.Tau, self.eta, self.theta, self.fit_intercept - ) - beta = beta - self.learning_rate * grad - - # Apply proximal operator - if self.fit_intercept: - beta[1:] = self._prox(beta[1:], self.learning_rate * reg_lambda * alpha) - else: - beta = self._prox(beta, self.learning_rate * reg_lambda * alpha) - - # Convergence by relative parameter change tolerance - norm_update = np.linalg.norm(beta - beta_old) - norm_update /= np.linalg.norm(beta) - self._convergence.append(norm_update) - if t > 1 and self._convergence[-1] < tol and self.verbose: - self.logger.info("\tParameter update tolerance. " + "Converged in {0:d} iterations".format(t)) - break - - if self.n_iter_ == self.max_iter and self.verbose: - self.logger.warning("Reached max number of iterations without convergence.") - - # Update the estimated variables - if self.fit_intercept: - self.beta0_ = beta[0] - self.beta_ = beta[1:] - else: - self.beta0_ = 0 - self.beta_ = beta - self.ynull_ = np.mean(y) - self.is_fitted_ = True - - # Clip small nonzero values w/ absolute value below the provided threshold: - self.beta_[np.abs(self.beta_) < self.clip_coeffs] = 0 - - return self - - def predict(self, X: np.ndarray) -> np.ndarray: - """Given predictor values, reconstruct expression of dependent/target variables. - - Args: - X: Array of shape [n_samples, n_features]; input data for prediction - - Returns: - yhat: Predicted targets of shape [n_samples,] - """ - X = check_array(X, accept_sparse=False) - check_is_fitted(self, "is_fitted_") - - if not isinstance(X, np.ndarray): - self.logger.error(f"Input data should be of type ndarray (got {type(X)}).") - - # Compute intermediate state and then apply nonlinearity: - z = calc_z(self.beta0_, self.beta_, X, self.fit_intercept) - yhat = apply_nonlinear(self.distr, z, self.eta, self.fit_intercept) - yhat = np.asarray(yhat) - return yhat - - def fit_predict(self, X: np.ndarray, y: np.ndarray): - """Fit the model and predict on the same data. - - Args: - X: array of shape [n_samples, n_features]; input data to fit and predict - y: array of shape [n_samples,]; target values for regression - - Returns: - yhat: Predicted targets of shape [n_samples,] - """ - yhat = self.fit(X, y).predict(X) - return yhat - - def score(self, X: np.ndarray, y: np.ndarray): - """Score model by computing either the deviance or R^2 for predicted values. 
-
- def score(self, X: np.ndarray, y: np.ndarray):
- """Score model by computing either the deviance or R^2 for predicted values.
-
- Args:
- X: array of shape [n_samples, n_features]; input data to fit and predict
- y: array of shape [n_samples,]; target values for regression
-
- Returns:
- score: Value of chosen metric (any pos number for deviance, 0-1 for R^2)
- """
- check_is_fitted(self, "is_fitted_")
- valid_metrics = ["deviance", "pseudo_r2"]
- if self.score_metric not in valid_metrics:
- self.logger.error(f"score_metric has to be one of: {','.join(valid_metrics)}")
- # Model must be fit before scoring:
- if not hasattr(self, "ynull_"):
- self.logger.error("Model must be fit before prediction can be scored.")
-
- y = np.asarray(y).ravel()
- yhat = self.predict(X)
-
- if self.score_metric == "deviance":
- score = deviance(y, yhat, self.distr, self.theta)
- elif self.score_metric == "pseudo_r2":
- score = pseudo_r2(y, yhat, self.ynull_, self.distr, self.theta)
- return score
-
-
-class GLMCV(BaseEstimator):
- """For estimating regularized generalized linear models (GLM) along a regularization path with warm restarts.
-
- Args:
- distr: Distribution family- can be "gaussian", "poisson", "softplus", "neg-binomial", or "gamma". Case
- sensitive.
- alpha: The weighting between L1 penalty (alpha=1.) and L2 penalty (alpha=0.) term of the loss function
- Tau: optional array of shape [n_features, n_features]; the Tikhonov matrix for ridge regression. If not
- provided, Tau will default to the identity matrix.
- reg_lambda: Regularization parameter :math:`\\lambda` of penalty term
- n_lambdas: Number of lambdas along the regularization path. Defaults to 25.
- cv: Number of cross-validation folds
- learning_rate: Governs the magnitude of parameter updates for the gradient descent algorithm
- max_iter: Maximum number of iterations for the solver
- tol: Convergence threshold or stopping criteria. Optimization loop will stop when relative change in
- parameter norm is below the threshold.
- eta: A threshold parameter that linearizes the exp() function above eta.
- clip_coeffs: Absolute value below which to set coefficients to zero.
- score_metric: Scoring metric. Options:
- - "deviance": Uses the difference between the saturated (perfectly predictive) model and the true model.
- - "pseudo_r2": Uses the coefficient of determination b/w the true and predicted values.
- fit_intercept: Specifies if a constant (a.k.a. bias or intercept) should be added to the decision function
- random_seed: Seed of the random number generator used to initialize the solution. Default: 888
- theta: Shape parameter of the negative binomial distribution (number of successes before the first
- failure). It is used only if 'distr' is equal to "neg-binomial", otherwise it is ignored.
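For context, the default path constructed in `__init__` below spans 25 log-spaced values from 0.1 down to 1e-6; a quick standalone check:

```python
import numpy as np

# Default regularization path when 'reg_lambda' is not supplied:
# 25 log-spaced values from 0.1 down to 1e-6 (base-e logspace).
n_lambdas = 25
reg_lambda = np.logspace(np.log(0.1), np.log(1e-6), n_lambdas, base=np.exp(1))
print(reg_lambda[0], reg_lambda[-1])  # ~0.1 ... ~1e-6
```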
-
- Attributes:
- beta0_: The intercept
- beta_: Learned parameters
- glm_: The GLM object with the best score
- reg_lambda_opt_: The value of reg_lambda for the best GLM model
- """
-
- def __init__(
- self,
- distr: Literal["gaussian", "poisson", "softplus", "neg-binomial", "gamma"] = "poisson",
- alpha: float = 0.5,
- Tau: Union[None, np.ndarray] = None,
- reg_lambda: Union[None, List[float]] = None,
- n_lambdas: int = 25,
- cv: int = 5,
- learning_rate: float = 0.2,
- max_iter: int = 1000,
- tol: float = 1e-6,
- eta: float = 2.0,
- clip_coeffs: float = 0.01,
- score_metric: Literal["deviance", "pseudo_r2"] = "deviance",
- fit_intercept: bool = True,
- random_seed: int = 888,
- theta: float = 1.0,
- ):
- if reg_lambda is None:
- reg_lambda = np.logspace(np.log(0.1), np.log(1e-6), n_lambdas, base=np.exp(1))
- if not isinstance(reg_lambda, (list, np.ndarray)):
- reg_lambda = [reg_lambda]
-
- self.logger = lm.get_main_logger()
- allowable_dists = ["gaussian", "poisson", "softplus", "neg-binomial", "gamma"]
- if distr not in allowable_dists:
- self.logger.error(f"'distr' must be one of {', '.join(allowable_dists)}, got {distr}.")
- if not isinstance(max_iter, int):
- self.logger.error("'max_iter' must be an integer.")
- if not isinstance(fit_intercept, bool):
- self.logger.error(f"'fit_intercept' must be Boolean, got {type(fit_intercept)}")
-
- self.distr = distr
- self.alpha = alpha
- self.reg_lambda = reg_lambda
- self.n_lambdas = n_lambdas
- self.cv = cv
- self.Tau = Tau
- self.learning_rate = learning_rate
- self.max_iter = max_iter
- self.beta0_ = None
- self.beta_ = None
- self.reg_lambda_opt_ = None
- self.glm_ = None
- self.scores_ = None
- self.ynull_ = None
- self.tol = tol
- self.eta = eta
- self.clip_coeffs = clip_coeffs
- self.theta = theta
- self.score_metric = score_metric
- self.fit_intercept = fit_intercept
- self.random_seed = random_seed
-
- def __repr__(self):
- reg_lambda = self.reg_lambda
- s = "<\nDistribution | %s" % self.distr
- s += "\nalpha | %0.2f" % self.alpha
- s += "\nmax_iter | %0.2f" % self.max_iter
- if len(reg_lambda) > 1:
- s += "\nlambda: %0.2f to %0.2f\n>" % (reg_lambda[0], reg_lambda[-1])
- else:
- s += "\nlambda: %0.2f\n>" % reg_lambda[0]
- return s
-
- def fit(self, X: np.ndarray, y: np.ndarray):
- """The fit function.
-
- Args:
- X: 2D array of shape [n_samples, n_features]; input data
- y: 1D array of shape [n_samples,]; target data
-
- Returns:
- self: Fitted instance of class GLMCV
- """
- glms, scores = list(), list()
- self.ynull_ = np.mean(y)
-
- idxs = np.arange(y.shape[0])
- np.random.shuffle(idxs)
- # Ensure dataset is large enough for cross-validation; if not, adjust number of folds to number of data
- # points - 1 for leave-one-out cross validation:
- if idxs.shape[0] < self.cv:
- self.logger.info(
- f"Too few samples for {self.cv}-fold cross-validation- performing leave-one-out cross "
- f"validation instead."
- ) - n_folds = idxs.shape[0] - 1 - else: - n_folds = self.cv - cv_splits = np.array_split(idxs, n_folds) - - cv_training_iterations = self.reg_lambda - - for idx, rl in enumerate(cv_training_iterations): - glm = GLM( - distr=self.distr, - alpha=self.alpha, - Tau=self.Tau, - reg_lambda=rl, - learning_rate=self.learning_rate, - max_iter=self.max_iter, - tol=self.tol, - eta=self.eta, - clip_coeffs=self.clip_coeffs, - theta=self.theta, - score_metric=self.score_metric, - fit_intercept=self.fit_intercept, - random_seed=self.random_seed, - verbose=False, - ) - - scores_fold = list() - for fold in range(n_folds): - val = cv_splits[fold] - train = np.setdiff1d(idxs, val) - # Initialize parameters: - if idx == 0: - glm.beta0_, glm.beta_ = self.beta0_, self.beta_ - else: - glm.beta0_, glm.beta_ = glms[-1].beta0_, glms[-1].beta_ - - glm.n_iter_ = 0 - glm.fit(X[train], y[train]) - scores_fold.append(glm.score(X[val], y[val])) - avg_score = np.mean(scores_fold) - scores.append(avg_score) - - # Extract final parameters for this value of lambda: - if idx == 0: - glm.beta0_, glm.beta_ = self.beta0_, self.beta_ - else: - glm.beta0_, glm.beta_ = glms[-1].beta0_, glms[-1].beta_ - - glm.n_iter_ = 0 - glm.fit(X, y) - glms.append(glm) - - # Find the lambda that maximizes (for r-squared) or minimizes (for deviance) the scoring metric: - if self.score_metric == "deviance": - opt = np.array(scores).argmin() - opt_score = np.array(scores).min() - elif self.score_metric == "pseudo_r2": - opt = np.array(scores).argmax() - opt_score = np.array(scores).max() - else: - self.logger.error(f"Unknown score_metric: {self.score_metric}") - - self.beta0_, self.beta_ = glms[opt].beta0_, glms[opt].beta_ - self.reg_lambda_opt_ = self.reg_lambda[opt] - self.glm_ = glms[opt] - self.scores_ = scores - # Optimal score: - self.opt_score = opt_score - return self - - def predict(self, X: np.ndarray) -> np.ndarray: - """Using the best scoring model, predict target values. - - Args: - X: Array of shape [n_samples, n_features]; input data for prediction - - Returns: - yhat: Predicted targets based on the model with optimal reg_lambda, of shape [n_samples,] - """ - self.logger = lm.get_main_logger() - - if not hasattr(self, "beta_"): - self.logger.error("Error: model of :class `GLMCV` not yet fitted. Call :func `fit()` method.") - X = check_array(X) - - yhat = self.glm_.predict(X) - return yhat - - def fit_predict(self, X: np.ndarray, y: np.ndarray): - """Fit the model and, after finding the best model, predict on the same data using that model. - - Args: - X: array of shape [n_samples, n_features]; input data to fit and predict - y: array of shape [n_samples,]; target values for regression - - Returns: - yhat: Predicted targets based on the model with optimal reg_lambda, of shape [n_samples,] - """ - self.fit(X, y) - yhat = self.predict(X) - return yhat - - def score(self, X: np.ndarray, y: np.ndarray): - """Score model by computing either the deviance or R^2 for predicted values. 
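(Aside on the fold construction used by `fit()` above: shuffled indices are split into roughly equal folds with `np.array_split`, each fold serving once as the validation set. A toy sketch:)

```python
import numpy as np

rng = np.random.default_rng(0)
idxs = np.arange(20)
rng.shuffle(idxs)

n_folds = 5
cv_splits = np.array_split(idxs, n_folds)
for fold in range(n_folds):
    val = cv_splits[fold]              # held-out fold
    train = np.setdiff1d(idxs, val)    # remaining samples
    # train and val never overlap:
    assert len(np.intersect1d(train, val)) == 0
```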
- - Args: - X: array of shape [n_samples, n_features]; input data to fit and predict - y: array of shape [n_samples,]; target values for regression - - Returns: - score: Value of chosen metric (any pos number for deviance, 0-1 for R^2) for the optimal reg_lambda - """ - score = self.glm_.score(X, y) - return score - - -# --------------------------------------------------------------------------------------------------- -# Wrapper for GLM CV, with parameter optimization -# --------------------------------------------------------------------------------------------------- -@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE, "adata") -def fit_glm( - X: Union[np.ndarray, pd.DataFrame], - adata: AnnData, - y_feat, - calc_first_moment: bool = True, - log_transform: bool = True, - gs_params: Union[None, dict] = None, - n_gs_cv: Union[None, int] = None, - return_model: bool = True, - **kwargs, -) -> Tuple[np.ndarray, np.ndarray, float, np.ndarray, Union[None, GLMCV]]: - """Wrapper for fitting a generalized elastic net linear model to large biological data, with automated finding of - optimum lambda regularization parameter and optional further grid search for parameter optimization. - - Args: - X: Array or DataFrame containing data for fitting- all columns in this array will be used as independent - variables - adata: AnnData object from which dependent variable gene expression values will be taken from - y_feat: Name of the feature in 'adata' corresponding to the dependent variable - log_transform: If True, will log transform expression. Defaults to True. - calc_first_moment: If True, will alleviate dropout effects by computing the first moment of each gene across - cells, consistent with the method used by the original RNA velocity method (La Manno et al., - 2018). Defaults to True. - gs_params: Optional dictionary where keys are variable names for either the classifier or the regressor and - values are lists of potential values for which to find the best combination using grid search. - Classifier parameters should be given in the following form: 'classifier__{parameter name}'. - n_gs_cv: Number of folds for cross-validation, will only be used if gs_params is not None. If None, - will default to a 5-fold cross-validation. - return_model: If True, returns fitted model. Defaults to True. - kwargs: Additional named arguments that will be provided to :class `GLMCV`. Valid options are: - - distr: Distribution family- can be "gaussian", "poisson", "neg-binomial", or "gamma". Case sensitive. - - alpha: The weighting between L1 penalty (alpha=1.) and L2 penalty (alpha=0.) term of the loss function - - Tau: optional array of shape [n_features, n_features]; the Tikhonov matrix for ridge regression. If not - provided, Tau will default to the identity matrix. - - reg_lambda: Regularization parameter :math:`\\lambda` of penalty term - - n_lambdas: Number of lambdas along the regularization path. Only used if 'reg_lambda' is not given. - - cv: Number of cross-validation repeats - - learning_rate: Governs the magnitude of parameter updates for the gradient descent algorithm - - max_iter: Maximum number of iterations for the solver - - tol: Convergence threshold or stopping criteria. Optimization loop will stop when relative change in - parameter norm is below the threshold. - - eta: A threshold parameter that linearizes the exp() function above eta. - - score_metric: Scoring metric. Options: - - "deviance": Uses the difference between the saturated (perfectly predictive) model and the true model. 
- - "pseudo_r2": Uses the coefficient of determination b/w the true and predicted values. - - fit_intercept: Specifies if a constant (a.k.a. bias or intercept) should be added to the decision function - - random_seed: Seed of the random number generator used to initialize the solution. Default: 888 - - theta: Shape parameter of the negative binomial distribution (number of successes before the first - failure). It is used only if 'distr' is equal to "neg-binomial", otherwise it is ignored. - - Returns: - Beta: Array of shape [n_parameters, 1], contains weight for each parameter - rex: Array of shape [n_samples, 1]. Reconstructed independent variable values. - reg: Instance of regression model. Returned only if 'return_model' is True. - """ - logger = lm.get_main_logger() - if not "distr" in kwargs: - kwargs["distr"] = "poisson" - if not "score_metric" in kwargs: - kwargs["score_metric"] = "pseudo_r2" - - if kwargs["distr"] in ["poisson", "softplus", "neg-binomial"]: - if calc_first_moment or log_transform: - logger.info( - f"With a {kwargs['distr']} assumption, it is recommended to fit to raw counts. Setting all " - f"preprocessing settings to False." - ) - calc_first_moment = False - log_transform = False - - if calc_first_moment: - normalize_total(adata) - _, adata = transcriptomic_connectivity(adata, n_neighbors_method="ball_tree") - conn = adata.obsp["expression_connectivities"] - adata_smooth_norm, _ = calc_1nd_moment(adata.X, conn, normalize_W=True) - adata.layers["M_s"] = adata_smooth_norm - - adata.layers["raw"] = adata.X - adata.X = adata.layers["M_s"] - - if log_transform: - log1p(adata) - - y = adata[:, y_feat].X.toarray() - if isinstance(X, pd.DataFrame): - X = X.values - - # logger.info(" 3: - logger.info( - "Beginning grid search procedure. Temporarily running on reduced range of lambda values for " - "conciseness." - ) - kwargs["reg_lambda"] = [0.1, 1e-4] - else: - logger.info("Beginning grid search procedure.") - - reg = GLMCV(**kwargs) - grid = GridSearchCV(estimator=reg, param_grid=gs_params, cv=n_gs_cv) - grid.fit(X, y) - logger.info(f"Grid search finished for {y_feat}. 
Elapsed time: {time.time()-start_gs_time}s.") - best_params = grid.best_params_ - msg = f"Grid search best parameters for {y_feat}: " - for k, v in best_params.items(): - msg += f"\n{k}: {v}" - logger.info(msg) - - # Select parameters in the classifier signature to update classifier keyword arguments: - for param, value in best_params.items(): - kwargs[param] = value - # Restore lambda to its original configuration: - kwargs["reg_lambda"] = reg_lambda_given - - reg = GLMCV(**kwargs) - - rex = reg.fit_predict(X, y) - logger.info(f"Optimal lambda regularization value for {y_feat}: {reg.reg_lambda_opt_}.") - intercept = reg.beta0_ - Beta = reg.beta_ - opt_score = reg.opt_score - # Returns: intercept, coefficients, metric for the optimum lambda, reconstruction, optionally model object - if return_model: - return intercept, Beta, opt_score, rex, reg - else: - return intercept, Beta, opt_score, rex - - -def calc_1nd_moment(X, W, normalize_W=True): - if normalize_W: - if type(W) == np.ndarray: - d = np.sum(W, 1).flatten() - else: - d = np.sum(W, 1).A.flatten() - W = diags(1 / d) @ W if issparse(W) else np.diag(1 / d) @ W - return W @ X, W - else: - return W @ X diff --git a/spateo/tools/ST_regression/regression_utils.py b/spateo/tools/ST_regression/regression_utils.py deleted file mode 100644 index dbd75dc2..00000000 --- a/spateo/tools/ST_regression/regression_utils.py +++ /dev/null @@ -1,477 +0,0 @@ -""" -Auxiliary functions to aid in the interpretation functions for the spatial and spatially-lagged regression models. -""" -from typing import List, Tuple, Union - -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - -import numpy as np -import pandas as pd -import scipy -import statsmodels.stats.multitest -from anndata import AnnData - -from ...configuration import SKM -from ...logging import logger_manager as lm -from ...preprocessing.transform import log1p - - -# --------------------------------------------------------------------------------------------------- -# Nonlinearity -# --------------------------------------------------------------------------------------------------- -def softplus(z): - """Numerically stable version of log(1 + exp(z)).""" - nl = z.copy() - nl[z > 35] = z[z > 35] - nl[z < -10] = np.exp(z[z < -10]) - nl[(z >= -10) & (z <= 35)] = log1p(np.exp(z[(z >= -10) & (z <= 35)])) - return nl - - -# --------------------------------------------------------------------------------------------------- -# Regularization -# --------------------------------------------------------------------------------------------------- -def L1_penalty(beta: np.ndarray) -> float: - """ - Implementation of the L1 penalty that penalizes based on absolute value of coefficient magnitude. - - Args: - beta: Array of shape [n_features,]; learned model coefficients - - Returns: - L1penalty: float, value for the regularization parameter (typically stylized by lambda) - """ - # Lasso-like penalty- max(sum(abs(beta), axis=0)) - L1penalty = np.linalg.norm(beta, 1) - return L1penalty - - -def L2_penalty(beta: np.ndarray, Tau: Union[None, np.ndarray] = None) -> float: - """Implementation of the L2 penalty that penalizes based on the square of coefficient magnitudes. - - Args: - beta: Array of shape [n_features,]; learned model coefficients - Tau: optional array of shape [n_features, n_features]; the Tikhonov matrix for ridge regression. If not - provided, Tau will default to the identity matrix. 
- """ - if Tau is None: - # Ridge=like penalty - L2penalty = np.linalg.norm(beta, 2) ** 2 - else: - # Tikhonov penalty - if Tau.shape[0] != beta.shape[0] or Tau.shape[1] != beta.shape[0]: - raise ValueError("Tau should be (n_features x n_features)") - else: - L2penalty = np.linalg.norm(np.dot(Tau, beta), 2) ** 2 - - return L2penalty - - -def L1_L2_penalty( - alpha: float, - beta: np.ndarray, - Tau: Union[None, np.ndarray] = None, -) -> float: - """ - Combination of the L1 and L2 penalties. - - Args: - alpha: The weighting between L1 penalty (alpha=1.) and L2 penalty (alpha=0.) term of the loss function. - beta: Array of shape [n_features,]; learned model coefficients - Tau: optional array of shape [n_features, n_features]; the Tikhonov matrix for ridge regression. If not - provided, Tau will default to the identity matrix. - - Returns: - P: Value for the regularization parameter - """ - P = 0.5 * (1 - alpha) * L2_penalty(beta, Tau) + alpha * L1_penalty(beta) - return P - - -# --------------------------------------------------------------------------------------------------- -# Significance Testing -# --------------------------------------------------------------------------------------------------- -def get_fisher_inverse(x: np.ndarray, y: np.ndarray) -> np.ndarray: - """Computes the Fisher matrix that measures the amount of information each feature in x provides about y- that is, - whether the log-likelihood is sensitive to change in the parameter x. - - Function from diffxpy: https://github.com/theislab/diffxpy - - Args: - x: Independent variable array - y: Dependent variable array - - Returns: - inverse_fisher : np.ndarray - """ - - var = np.var(y, axis=0) - fisher = np.expand_dims(np.matmul(x.T, x), axis=0) / np.expand_dims(var, axis=[1, 2]) - - fisher = np.nan_to_num(fisher) - - inverse_fisher = np.array([np.linalg.pinv(fisher[i, :, :]) for i in range(fisher.shape[0])]) - return inverse_fisher - - -def wald_test(theta_mle: np.ndarray, theta_sd: np.ndarray, theta0: Union[float, np.ndarray] = 0) -> np.ndarray: - """Perform single-coefficient Wald test, informing whether a given coefficient deviates significantly from the - supplied reference value (theta0), based on the standard deviation of the posterior of the parameter estimate. - - Function from diffxpy: https://github.com/theislab/diffxpy - - Args: - theta_mle: Maximum likelihood estimation of given parameter by feature - theta_sd: Standard deviation of the maximum likelihood estimation - theta0: Value(s) to test theta_mle against. Must be either a single number or an array w/ equal number of - entries to theta_mle. 
- - Returns: - pvals : np.ndarray - """ - - if np.size(theta0) == 1: - theta0 = np.broadcast_to(theta0, theta_mle.shape) - - if theta_mle.shape[0] != theta_sd.shape[0]: - raise ValueError("stats.wald_test(): theta_mle and theta_sd have to contain the same number of entries") - if theta0.shape[0] > 1: - if theta_mle.shape[0] != theta0.shape[0]: - raise ValueError("stats.wald_test(): theta_mle and theta0 have to contain the same number of entries") - - theta_sd = np.nextafter(0, np.inf, out=theta_sd, where=theta_sd < np.nextafter(0, np.inf)) - wald_statistic = np.abs(np.divide(theta_mle - theta0, theta_sd)) - pvals = 2 * (1 - scipy.stats.norm(loc=0, scale=1).cdf(wald_statistic)) # two-tailed test - return pvals - - -def multitesting_correction(pvals: np.ndarray, method: str = "fdr_bh", alpha: float = 0.05) -> np.ndarray: - """In the case of testing multiple hypotheses from the same experiment, perform multiple test correction to adjust - q-values. - - Function from diffxpy: https://github.com/theislab/diffxpy - - Args: - pvals: Uncorrected p-values; must be given as a one-dimensional array - method: Method to use for correction. Available methods can be found in the documentation for - statsmodels.stats.multitest.multipletests(), and are also listed below (in correct case) for convenience: - - Named methods: - - bonferroni - - sidak - - holm-sidak - - holm - - simes-hochberg - - hommel - - Abbreviated methods: - - fdr_bh: Benjamini-Hochberg correction - - fdr_by: Benjamini-Yekutieli correction - - fdr_tsbh: Two-stage Benjamini-Hochberg - - fdr_tsbky: Two-stage Benjamini-Krieger-Yekutieli method - alpha: Family-wise error rate (FWER) - - Returns - qval: p-values post-correction - """ - - qval = np.zeros([pvals.shape[0]]) + np.nan - qval[np.isnan(pvals) == False] = statsmodels.stats.multitest.multipletests( - pvals=pvals[np.isnan(pvals) == False], alpha=alpha, method=method, is_sorted=False, returnsorted=False - )[1] - - return qval - - -def get_p_value(variables: np.array, fisher_inv: np.array, coef_loc: int) -> np.ndarray: - """Computes p-value for differential expression for a target feature - - Function from diffxpy: https://github.com/theislab/diffxpy - - Args: - variables: Array where each column corresponds to a feature - fisher_inv: Inverse Fisher information matrix - coef_loc: Numerical column of the array corresponding to the coefficient to test - - Returns: - pvalues: Array of identical shape to variables, where each element is a p-value for that instance of that - feature - """ - - theta_mle = variables[coef_loc] - theta_sd = fisher_inv[:, coef_loc, coef_loc] - theta_sd = np.nextafter(0, np.inf, out=theta_sd, where=theta_sd < np.nextafter(0, np.inf)) - theta_sd = np.sqrt(theta_sd) - - pvalues = wald_test(theta_mle, theta_sd, theta0=0.0) - return pvalues - - -def compute_wald_test( - params: np.ndarray, fisher_inv: np.ndarray, significance_threshold: float = 0.01 -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Function from diffxpy: https://github.com/theislab/diffxpy - - Args: - params: Array of shape [n_features, n_params] - fisher_inv: Inverse Fisher information matrix - significance_threshold: Upper threshold to be considered significant - - Returns: - significance: Array of identical shape to variables, where each element is True or False if it meets the - threshold for significance - pvalues: Array of identical shape to variables, where each element is a p-value for that instance of that - feature - qvalues: Array of identical shape to variables, where each element is 
a q-value for that instance of that
- feature
- """
-
- pvalues = []
-
- # Compute p-values for each feature, store in temporary list:
- for idx in range(params.T.shape[0]):
- pvals = get_p_value(params.T, fisher_inv, idx)
- pvalues.append(pvals)
-
- pvalues = np.concatenate(pvalues)
- # Multiple testing correction w/ Benjamini-Hochberg procedure and FWER 0.05
- qvalues = multitesting_correction(pvalues)
- pvalues = np.reshape(pvalues, (-1, params.T.shape[1]))
- qvalues = np.reshape(qvalues, (-1, params.T.shape[1]))
- significance = qvalues < significance_threshold
-
- return significance, pvalues, qvalues
-
-
-# ---------------------------------------------------------------------------------------------------
-# Regression Metrics
-# ---------------------------------------------------------------------------------------------------
-def mae(y_true, y_pred) -> float:
- """Mean absolute error- in this context, actually log1p mean absolute error
-
- Args:
- y_true: Observed values for the dependent variable
- y_pred: Regression model predictions
-
- Returns:
- mae: Mean absolute error value across all samples
- """
- abs_diff = np.abs(y_true - y_pred)
- mean = np.mean(abs_diff)
- return mean
-
-
-def mse(y_true, y_pred) -> float:
- """Mean squared error- in this context, actually log1p mean squared error
-
- Args:
- y_true: Observed values for the dependent variable
- y_pred: Regression model predictions
-
- Returns:
- mse: Mean squared error value across all samples
- """
- se = np.square(y_true - y_pred)
- mse_val = np.mean(se, axis=-1)
- return mse_val
-
-
-# ---------------------------------------------------------------------------------------------------
-# Testing Model Accuracy
-# ---------------------------------------------------------------------------------------------------
-@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE, "adata")
-def plot_prior_vs_data(
- reconst: pd.DataFrame,
- adata: AnnData,
- kind: str = "barplot",
- target_name: Union[None, str] = None,
- title: Union[None, str] = None,
- figsize: Union[None, Tuple[float, float]] = None,
- save_show_or_return: Literal["save", "show", "return", "both", "all"] = "save",
- save_kwargs: dict = {},
-):
- """Plots distribution of observed vs. predicted counts in the form of a comparative density barplot.
-
- Args:
- reconst: DataFrame containing values for reconstruction/prediction of targets of a regression model
- adata: AnnData object containing observed counts
- kind: Kind of plot to generate. Options: "barplot", "scatterplot". Case sensitive, defaults to "barplot".
- target_name: Optional, can be:
- - Column name in DataFrame/AnnData object: name of gene to subset to
- - "sum": computes sum over all features present in 'reconst' to compare to the corresponding subset of
- 'adata'.
- - "mean": computes mean over all features present in 'reconst' to compare to the corresponding subset of
- 'adata'.
- If not given, will subset AnnData to features in 'reconst' and flatten both arrays to compare all values.
- title: Optional title for the plot
- figsize: Width x height of desired figure window in inches
- save_show_or_return: Whether to save, show or return the figure.
- If "both", it will save and plot the figure at the same time. If "all", the figure will be saved,
- displayed and the associated axis and other object will be returned.
- save_kwargs: A dictionary that will be passed to the save_fig function.
- By default it is an empty dictionary and the save_fig function will use the - {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, - "verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modifies those - keys according to your needs. - """ - import matplotlib.pyplot as plt - from matplotlib import rcParams - - from ...configuration import config_spateo_rcParams - from ...plotting.static.utils import save_return_show_fig_utils - - logger = lm.get_main_logger() - - config_spateo_rcParams() - if figsize is None: - figsize = rcParams.get("figure.figsize") - - if target_name == "sum": - predicted = reconst.sum(axis=1).values.reshape(-1, 1) - observed = ( - adata[:, reconst.columns].X.toarray() if scipy.sparse.issparse(adata.X) else adata[:, reconst.columns].X - ) - observed = np.sum(observed, axis=1).reshape(-1, 1) - elif target_name == "mean": - predicted = reconst.mean(axis=1).values.reshape(-1, 1) - observed = ( - adata[:, reconst.columns].X.toarray() if scipy.sparse.issparse(adata.X) else adata[:, reconst.columns].X - ) - observed = np.mean(observed, axis=1).reshape(-1, 1) - elif target_name is not None: - observed = adata[:, target_name].X.toarray() if scipy.sparse.issparse(adata.X) else adata[:, target_name].X - observed = observed.reshape(-1, 1) - predicted = reconst[target_name].values.reshape(-1, 1) - else: - # Flatten arrays: - observed = ( - adata[:, reconst.columns].X.toarray() if scipy.sparse.issparse(adata.X) else adata[:, reconst.columns].X - ) - observed = observed.flatten().reshape(-1, 1) - predicted = reconst.values.flatten().reshape(-1, 1) - - obs_pred = np.hstack((observed, predicted)) - # Upper limit along the x-axis (99th percentile to prevent outliers from affecting scale too badly): - xmax = np.percentile(obs_pred, 99) - # Lower limit along the x-axis: - xmin = np.min(observed) - # Divide x-axis into pieces for purposes of setting x labels: - xrange, step = np.linspace(xmin, xmax, num=10, retstep=True) - - fig, ax = plt.subplots(1, 1, figsize=figsize) - if target_name is None: - target_name = "Total Counts" - - if kind == "barplot": - ax.hist( - obs_pred, - xrange, - alpha=0.7, - label=[f"Observed {target_name}", f"Predicted {target_name}"], - density=True, - color=["#FFA07A", "#20B2AA"], - ) - - plt.legend(loc="upper right", fontsize=9) - - ax.set_xticks(ticks=[i + 0.5 * step for i in xrange[:-1]], labels=[np.round(l, 3) for l in xrange[:-1]]) - plt.xlabel("Counts", size=9) - plt.ylabel("Normalized Proportion of Cells", size=9) - if title is not None: - plt.title(title, size=9) - plt.tight_layout() - - elif kind == "scatterplot": - from scipy.stats import spearmanr - - observed = observed.flatten() - predicted = predicted.flatten() - slope, intercept = np.polyfit(observed, predicted, 1) - - # Extract residuals: - predicted_model = np.polyval([slope, intercept], observed) - observed_mean = np.mean(observed) - predicted_mean = np.mean(predicted) - n = observed.size # number of samples - m = 2 # number of parameters - dof = n - m # degrees of freedom - # Students statistic of interval confidence: - t = scipy.stats.t.ppf(0.975, dof) - residual = observed - predicted_model - # Standard deviation of the error: - std_error = (np.sum(residual**2) / dof) ** 0.5 - - # Calculate spearman correlation and coefficient of determination: - s = spearmanr(observed, predicted)[0] - numerator = np.sum((observed - observed_mean) * (predicted - predicted_mean)) - denominator = (np.sum((observed - 
observed_mean) ** 2) * np.sum((predicted - predicted_mean) ** 2)) ** 0.5 - correlation_coef = numerator / denominator - r2 = correlation_coef**2 - - # Plot best fit line: - observed_line = np.linspace(np.min(observed), np.max(observed), 100) - predicted_line = np.polyval([slope, intercept], observed_line) - - # Confidence interval and prediction interval: - ci = ( - t - * std_error - * (1 / n + (observed_line - observed_mean) ** 2 / np.sum((observed - observed_mean) ** 2)) ** 0.5 - ) - pi = ( - t - * std_error - * (1 + 1 / n + (observed_line - observed_mean) ** 2 / np.sum((observed - observed_mean) ** 2)) ** 0.5 - ) - - ax.plot(observed, predicted, "o", ms=3, color="royalblue", alpha=0.7) - ax.plot(observed_line, predicted_line, color="royalblue", alpha=0.7) - ax.fill_between( - observed_line, predicted_line + pi, predicted_line - pi, color="lightcyan", label="95% prediction interval" - ) - ax.fill_between( - observed_line, predicted_line + ci, predicted_line - ci, color="skyblue", label="95% confidence interval" - ) - ax.spines["right"].set_visible(False) - ax.spines["top"].set_visible(False) - - ax.set_xlabel(f"Observed {target_name}") - ax.set_ylabel(f"Predicted {target_name}") - title = title if title is not None else "Observed and Predicted {}".format(target_name) - ax.set_title(title) - - # Display r^2, Spearman correlation, mean absolute error on plot as well: - r2s = str(np.round(r2, 2)) - spearman = str(np.round(s, 2)) - ma_err = mae(observed, predicted) - mae_s = str(np.round(ma_err, 2)) - - # Place text at slightly above the minimum x_line value and maximum y_line value to avoid obscuring the plot: - ax.text( - 1.01 * np.min(observed), - 1.01 * np.max(predicted), - "$r^2$ = " + r2s + ", Spearman $r$ = " + spearman + ", MAE = " + mae_s, - fontsize=8, - ) - plt.legend(loc="lower center", bbox_to_anchor=(0.5, -0.4), fontsize=8) - - else: - logger.info( - ":func `plot_prior_vs_data` error: Invalid input given to 'kind'. Options: 'barplot', " "'scatterplot'." - ) - - save_return_show_fig_utils( - save_show_or_return=save_show_or_return, - show_legend=True, - background="white", - prefix="parameters", - save_kwargs=save_kwargs, - total_panels=1, - fig=fig, - axes=ax, - return_all=False, - return_all_list=None, - ) diff --git a/spateo/tools/ST_regression/spatial_regression.py b/spateo/tools/ST_regression/spatial_regression.py deleted file mode 100644 index ea4a07c4..00000000 --- a/spateo/tools/ST_regression/spatial_regression.py +++ /dev/null @@ -1,2020 +0,0 @@ -""" -Suite of tools for spatially-aware as well as spatially-lagged linear regression - -Also performs downstream characterization following spatially-informed regression to characterize niche impact on gene -expression - -Note to self: current set up --> each of the spatial regression classes can be called either through cell_interaction ( -e.g. st.cell_interaction.NicheModel) or standalone (e.g. st.NicheModel)- the same is true for all -functions besides the general regression ones (e.g. fit_glm, which must be called w/ st.fit_glm). 
-""" -import os -import time -from itertools import product -from random import sample -from typing import List, Optional, Tuple, Union - -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import scipy -import seaborn as sns -from anndata import AnnData -from joblib import Parallel, delayed -from matplotlib import rcParams -from patsy import dmatrix -from scipy.sparse import diags, issparse -from tqdm import tqdm - -from ...configuration import config_spateo_rcParams, shiftedColorMap -from ...logging import logger_manager as lm -from ...plotting.static.utils import save_return_show_fig_utils -from ...preprocessing.normalize import normalize_total -from ...preprocessing.transform import log1p -from ...tools.find_neighbors import construct_nn_graph, transcriptomic_connectivity -from ...tools.utils import update_dict -from .generalized_lm import fit_glm -from .regression_utils import compute_wald_test, get_fisher_inverse - - -# --------------------------------------------------------------------------------------------------- -# Wrapper classes for model running -# --------------------------------------------------------------------------------------------------- -class Base_Model: - """Basis class for all spatially-aware and spatially-lagged regression models that can be implemented through this - toolkit. Includes necessary methods for data loading and preparation, computation of spatial weights matrices, - computation of evaluation metrics and more. - - Args: - adata: object of class `anndata.AnnData` - group_key: Key in .obs where group (e.g. cell type) information can be found - spatial_key: Key in .obsm where x- and y-coordinates are stored - distr: Can optionally provide distribution family to specify the type of model that should be fit at the time - of initializing this class rather than after calling :func `GLMCV_fit_predict`- can be "gaussian", - "poisson", "softplus", "neg-binomial", or "gamma". Case sensitive. - genes: Subset to genes of interest: will be used as dependent variables in non-ligand-based regression analyses, - will be independent variables in ligand-based regression analyses - drop_dummy: Name of the category to be dropped (the "dummy variable") in the regression. The dummy category - can aid in interpretation as all model coefficients can be taken to be in reference to the dummy - category. If None, will randomly select a few samples to constitute the dummy group. - layer: Entry in .layers to use instead of .X when fitting model- all other operations will use .X. - cci_dir: Full path to the directory containing cell-cell communication databases. Only used in the case of - models that use ligands for prediction. - normalize: Perform library size normalization, to set total counts in each cell to the same number (adjust - for cell size) - smooth: To correct for dropout effects, leverage gene expression neighborhoods to smooth expression - log_transform: Set True if log-transformation should be applied to expression (otherwise, will assume - preprocessing/log-transform was computed beforehand) - niche_compute_indicator: Only used if 'mod_type' is "niche" or "niche_lr". If True, for the "niche" model, for - the connections array encoding the cell type-cell type interactions that occur within each niche, - threshold all nonzero values to 1, to reflect the presence of a pairwise cell type interaction. 
- Otherwise, will fit on the normalized number of pairwise interactions within each niche. For the - "niche_lr" model, for the cell type pair interactions array, threshold all nonzero values to 1 to reflect - the presence of an interaction between the two cell types within each niche. Otherwise, will fit on - normalized data. - weights_mode: Options "knn", "kernel", "band"; sets whether to use K-nearest neighbors, a kernel-based - method, or distance band to compute spatial weights, respectively. - data_id: If given, will save pairwise distance arrays & nearest neighbor arrays to folder in the working - directory, under './neighbors/{data_id}_distance.csv' and './neighbors/{data_id}_adj.csv'. Will also - check for existing files under these names to avoid re-computing these arrays. If not given, will not save. - kwargs : Provides additional spatial weight-finding arguments. Note that these must specifically match the name - that the function will look for (case sensitive). For reference: - - n_neighbors : int - Number of nearest neighbors for KNN - - p : int - Minkowski p-norm for KNN and distance band methods - - distance_metric : str - Pairwise distance metric for KNN - - bandwidth : float or array-like of floats - Sets kernel width for kernel method - - fixed : bool - Allow bandwidth to vary across observations for kernel method - - n_neighbors_bandwidth : int - Number of nearest neighbors for determining bandwidth for kernel method - - kernel_function : str - "triangular", "uniform", "quadratic", "quartic" or "gaussian". Rule for setting how spatial - weight decays with distance - - threshold : float - Distance for which to consider spots "neighbors" for each spot in distance band method (typically - in units of pixels) - - alpha : float - Should be less than 0; can be used to set weights to decay with distance for distance band method - """ - - def __init__( - self, - adata: AnnData, - spatial_key: str = "spatial", - distr: Union[None, Literal["gaussian", "poisson", "softplus", "neg-binomial", "gamma"]] = None, - group_key: Union[None, str] = None, - genes: Union[None, List] = None, - drop_dummy: Union[None, str] = None, - layer: Union[None, str] = None, - cci_dir: Union[None, str] = None, - normalize: bool = True, - smooth: bool = False, - log_transform: bool = False, - niche_compute_indicator: bool = True, - weights_mode: str = "knn", - data_id: Union[None, str] = None, - **kwargs, - ): - self.logger = lm.get_main_logger() - - self.adata = adata - self.cell_names = self.adata.obs_names - # Sort cell type categories (to keep order consistent for downstream applications): - self.celltype_names = sorted(list(set(adata.obs[group_key]))) - - self.spatial_key = spatial_key - self.distr = distr - self.group_key = group_key - self.genes = genes - self.logger.info( - "Note: argument provided to 'genes' represents the dependent variables for non-ligand-based " - "analysis, but are used as independent variables for ligand-based analysis." - ) - self.drop_dummy = drop_dummy - self.layer = layer - self.cci_dir = cci_dir - self.normalize = normalize - self.smooth = smooth - self.log_transform = log_transform - self.niche_compute_indicator = niche_compute_indicator - self.weights_mode = weights_mode - self.data_id = data_id - - # Kwargs can be used to adjust how spatial weights/spatial neighbors are found. 
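To illustrate that kwargs plumbing, a hypothetical instantiation (assuming this module's namespace for `Base_Model`; the AnnData contents and key names are placeholders):

```python
import numpy as np
import pandas as pd
from anndata import AnnData

# Hypothetical minimal AnnData: 100 cells, 50 genes, random labels/coordinates.
adata = AnnData(np.random.poisson(1.0, size=(100, 50)).astype(float))
adata.obs["cell_type"] = pd.Categorical(np.random.choice(["A", "B", "C"], 100))
adata.obsm["spatial"] = np.random.uniform(0, 100, size=(100, 2))

model = Base_Model(
    adata,
    group_key="cell_type",
    distr="poisson",
    weights_mode="knn",
    n_neighbors=8,               # forwarded into sp_kwargs, overriding the default of 10
    distance_metric="euclidean",
)
```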
- # Default values for all possible spatial weight/spatial neighbor parameters: - self.sp_kwargs = { - "n_neighbors": 10, - "p": 2, - "distance_metric": "euclidean", - "bandwidth": None, - "fixed": True, - "n_neighbors_bandwidth": 6, - "kernel_function": "triangular", - "threshold": 50, - "alpha": -1.0, - } - - # Update using user input: - self.sp_kwargs = update_dict(self.sp_kwargs, kwargs) - - def preprocess_data( - self, - normalize: Union[None, bool] = None, - smooth: Union[None, bool] = None, - log_transform: Union[None, bool] = None, - ): - """Normalization and transformation of input data. Can manually specify whether to normalize, scale, - etc. data- any arguments not given this way will default to values passed on instantiation of the Interpreter - object. - - Returns: - None, all preprocessing operates inplace on the object's input AnnData. - """ - if normalize is None: - normalize = self.normalize - if smooth is None: - smooth = self.smooth - if log_transform is None: - log_transform = self.log_transform - - if self.distr in ["poisson", "softplus", "neg-binomial"]: - if normalize or smooth or log_transform: - self.logger.info( - f"With a {self.distr} assumption, it is recommended to fit to raw counts. Computing normalizations " - f"and transforms if applicable, but storing the results for later and fitting to the raw counts." - ) - self.adata.layers["raw"] = self.adata.X - - # Normalize to size factor: - if normalize: - if self.distr not in ["poisson", "softplus", "neg-binomial"]: - self.logger.info("Setting total counts in each cell to 1e4 inplace.") - normalize_total(self.adata) - else: - self.logger.info("Setting total counts in each cell to 1e4, storing in adata.layers['X_norm'].") - dat = normalize_total(self.adata, inplace=False) - self.adata.layers["X_norm"] = dat["X"] - self.adata.obs["norm_factor"] = dat["norm_factor"] - self.adata.layers["stored_processed"] = dat["X"] - - # Smooth data if 'smooth' is True and log-transform data matrix if 'log_transform' is True: - if smooth: - if self.distr not in ["poisson", "softplus", "neg-binomial"]: - self.logger.info("Smoothing gene expression inplace...") - # Compute connectivity matrix if not already existing: - try: - conn = self.adata.obsp["expression_connectivities"] - except: - _, adata = transcriptomic_connectivity(self.adata, n_neighbors_method="ball_tree") - conn = adata.obsp["expression_connectivities"] - adata_smooth_norm, _ = calc_1nd_moment(self.adata.X, conn, normalize_W=True) - self.adata.layers["M_s"] = adata_smooth_norm - - # Use smoothed layer for downstream processing: - self.adata.layers["raw"] = self.adata.X - self.adata.X = self.adata.layers["M_s"] - - else: - self.logger.info( - "Smoothing gene expression inplace and storing in in adata.layers['M_s'] or " - "adata.layers['normed_M_s'] if normalization was first performed." - ) - adata_temp = self.adata.copy() - # Check if normalized expression is present- if 'distr' is one of the indicated distributions AND - # 'normalize' is True, AnnData will not have been updated in place, with the normalized array - # instead being stored in the object. 
- try: - adata_temp.X = adata_temp.layers["X_norm"] - norm = True - except: - norm = False - pass - - try: - conn = self.adata.obsp["expression_connectivities"] - except: - _, adata = transcriptomic_connectivity(adata_temp, n_neighbors_method="ball_tree") - conn = adata.obsp["expression_connectivities"] - adata_smooth_norm, _ = calc_1nd_moment(adata_temp.X, conn, normalize_W=True) - if norm: - self.adata.layers["norm_M_s"] = adata_smooth_norm - else: - self.adata.layers["M_s"] = adata_smooth_norm - self.adata.layers["stored_processed"] = adata_smooth_norm - - if log_transform: - if self.distr not in ["poisson", "softplus", "neg-binomial"]: - self.logger.info("Log-transforming expression inplace...") - log1p(self.adata) - else: - self.logger.info( - "Log-transforming expression and storing in adata.layers['X_log1p'], " - "adata.layers['X_norm_log1p'], adata.layers['X_M_s_log1p'], or adata.layers[" - "'X_norm_M_s_log1p'], depending on the normalizations and transforms that were " - "specified." - ) - adata_temp = self.adata.copy() - # Check if normalized expression is present- if 'distr' is one of the indicated distributions AND - # 'normalize' and/or 'smooth' is True, AnnData will not have been updated in place, - # with the normalized array instead being stored in the object. - if "norm_M_s" in adata_temp.layers.keys(): - layer = "norm_M_s" - adata_temp.X = adata_temp.layers["norm_M_s"] - norm, smoothed = True, True - elif "M_s" in adata_temp.layers.keys(): - layer = "M_s" - adata_temp.X = adata_temp.layers["M_s"] - norm, smoothed = False, True - elif "X_norm" in adata_temp.layers.keys(): - layer = "X_norm" - adata_temp.X = adata_temp.layers["X_norm"] - norm, smoothed = True, False - else: - layer = None - norm, smoothed = False, False - - if layer is not None: - log1p(adata_temp.layers[layer]) - else: - log1p(adata_temp) - - if norm and smoothed: - self.adata.layers["X_norm_M_s_log1p"] = adata_temp.X - elif norm: - self.adata.layers["X_norm_log1p"] = adata_temp.X - elif smoothed: - self.adata.layers["X_M_s_log1p"] = adata_temp.X - else: - self.adata.layers["X_log1p"] = adata_temp.X - self.adata.layers["stored_processed"] = adata_temp.X - - def prepare_data( - self, - mod_type: str = "category", - lig: Union[None, List[str]] = None, - rec: Union[None, List[str]] = None, - niche_lr_r_lag: bool = True, - use_ds: bool = True, - rec_ds: Union[None, List[str]] = None, - species: Literal["human", "mouse", "axolotl"] = "human", - ): - """Handles any necessary data preparation, starting from given source AnnData object - - Args: - mod_type: The type of model that will be employed- this dictates how the data will be processed and - prepared. - Options: - - category: spatially-aware, for each sample, computes category prevalence within the spatial - neighborhood and uses these as independent variables - - niche: spatially-aware, uses spatial connections between samples as independent variables - - ligand_lag: spatially-lagged, from database uses select ligand genes to perform regression on - select receptor and/or receptor-downstream genes, and additionally considers neighbor - expression of the ligands - - niche_lr: spatially-aware, uses a coupling of spatial category connections, ligand expression - and receptor expression to perform regression on select receptor-downstream genes - lig: Only used if 'mod_type' contains "ligand". Provides the list of ligands to use as predictors. If not - given, will attempt to subset self.genes - rec: Only used if 'mod_type' contains "ligand". 
Provides the list of receptors to investigate. If not - given, will search through database for all genes that correspond to the provided genes from 'ligands'. - niche_lr_r_lag: Only used if 'mod_type' is "niche_lr". Uses the spatial lag of the receptor as the - dependent variable rather than each spot's unique receptor expression. Defaults to True. - use_ds: If True, uses receptor-downstream genes in addition to ligands and receptors. - rec_ds: Only used if 'mod_type' is "niche_lr" or "ligand_lag". Can be used to optionally manually define a - list of genes shown to be (or thought to potentially be) downstream of one or more of the provided - L:R pairs. If not given, will find receptor-downstream genes from database based on input to 'lig' - and 'rec'. - species: Selects the cell-cell communication database the relevant ligands will be drawn from. Options: - "human", "mouse", "axolotl". - """ - # Can provide either a single L:R pair or multiple of ligands and/or receptors: - if lig is not None: - if not isinstance(lig, list): - lig = [lig] - if rec is not None: - if not isinstance(rec, list): - rec = [rec] - if rec_ds is not None: - if not isinstance(rec_ds, list): - rec_ds = [rec_ds] - - self.preprocess_data() - - # General preprocessing required by multiple model types (for the models that use cellular niches): - # Convert groups/categories into one-hot array: - group_name = self.adata.obs[self.group_key] - db = pd.DataFrame({"group": group_name}) - categories = np.array(self.adata.obs[self.group_key].unique().tolist()) - db["group"] = pd.Categorical(db["group"], categories=categories) - - self.logger.info("Preparing data: converting categories to one-hot labels for all samples.") - X = pd.get_dummies(data=db, drop_first=False) - # Ensure columns are in order: - X = X.reindex(sorted(X.columns), axis=1) - - # Compute adjacency matrix- use the KNN value in 'sp_kwargs' (which may have been passed as an - # argument when initializing the interpreter): - if self.data_id is not None: - self.logger.info(f"Checking for pre-computed adjacency matrix for dataset {self.data_id}...") - try: - self.adata.obsp["adj"] = pd.read_csv( - os.path.join(os.getcwd(), f"neighbors/{self.data_id}_neighbors.csv"), index_col=0 - ).values - self.logger.info(f"Adjacency matrix loaded from file.") - except: - self.logger.info(f"Pre-computed adjacency matrix not found. Computing adjacency matrix.") - start = time.time() - construct_nn_graph( - self.adata, - spatial_key=self.spatial_key, - n_neighbors=self.sp_kwargs["n_neighbors"], - exclude_self=True, - ) - self.logger.info(f"Computed adjacency matrix, time elapsed: {time.time() - start}s.") - - # Create 'neighbors' directory, if necessary: - if not os.path.exists("./neighbors"): - os.makedirs("./neighbors") - # And save computed adjacency matrix: - self.logger.info(f"Saving adjacency matrix to path neighbors/{self.data_id}_neighbors.csv") - adj = pd.DataFrame(self.adata.obsp["adj"], index=self.adata.obs_names, columns=self.adata.obs_names) - adj.to_csv(os.path.join(os.getcwd(), f"neighbors/{self.data_id}_neighbors.csv")) - else: - self.logger.info(f"Path to pre-computed adjacency matrix not given. 
Computing adjacency matrix.")
-            start = time.time()
-            construct_nn_graph(
-                self.adata,
-                spatial_key=self.spatial_key,
-                n_neighbors=self.sp_kwargs["n_neighbors"],
-                exclude_self=True,
-            )
-            self.logger.info(f"Computed adjacency matrix, time elapsed: {time.time() - start}s.")
-            # No 'data_id' was given, so (per the docstring) the computed adjacency matrix is not saved to file.
-
-        adj = self.adata.obsp["adj"]
-
-        # Construct category adjacency matrix (n_samples x n_categories array that records how many neighbors of
-        # each category are present within the neighborhood of each sample):
-        dmat_neighbors = (adj > 0).astype("int").dot(X.values)
-
-        # Construct the category interaction matrix (1D array w/ n_categories ** 2 elements, encodes the niche of
-        # each sample by documenting the category-category spatial connections within the niche- specifically,
-        # for each sample, records the category identity of its neighbors in space):
-        data = {"categories": X, "dmat_neighbours": dmat_neighbors}
-        connections = np.asarray(dmatrix("categories:dmat_neighbours-1", data))
-
-        # Set connections array to indicator array:
-        if self.niche_compute_indicator:
-            connections[connections > 1] = 1
-        else:
-            # Min-max scale the connections array:
-            connections = (connections - connections.min()) / (connections.max() - connections.min())
-
-        # Specific preprocessing for each model type:
-        if "category" in mod_type:
-            self.logger.info(f"Using {self.group_key} values to predict feature expression...")
-            # First, convert groups/categories into one-hot array:
-            group_num = self.adata.obs[self.group_key].value_counts()
-            max_group, min_group, min_group_ncells = (
-                group_num.index[0],
-                group_num.index[-1],
-                group_num.values[-1],
-            )
-
-            group_name = self.adata.obs[self.group_key]
-            db = pd.DataFrame({"group": group_name})
-            categories = np.array(self.adata.obs[self.group_key].unique().tolist() + ["others"])
-            db["group"] = pd.Categorical(db["group"], categories=categories)
-
-            # Solve the dummy variable trap by dropping a dummy category (relabeling those rows as "others" to
-            # avoid issues with multicollinearity):
-            if self.drop_dummy is None:
-                # Leave some samples from every group intact
-                db.iloc[sample(np.arange(self.adata.n_obs).tolist(), min_group_ncells), :] = "others"
-            elif self.drop_dummy in categories:
-                group_inds = np.where(db["group"] == self.drop_dummy)[0]
-                db.iloc[group_inds, :] = "others"
-                db = db["group"].cat.remove_unused_categories()
-            else:
-                raise ValueError(
-                    f"Dummy category ({self.drop_dummy}) provided is not in the " f"adata.obs[{self.group_key}]."
- ) - drop_columns = ["group_others"] - - self.logger.info("Preparing data: converting categories to one-hot labels for all samples.") - X = pd.get_dummies(data=db, drop_first=False) - # Ensure columns are in order: - X = X.reindex(sorted(X.columns), axis=1) - - # Construct category adjacency matrix (n_samples x n_categories array that records how many neighbors of - # each category are present within the neighborhood of each sample): - dmat_neighbors = (adj > 0).astype("int").dot(X.values) - self.X = pd.DataFrame(dmat_neighbors, columns=X.columns, index=self.adata.obs_names) - self.X = self.X.reindex(sorted(self.X.columns), axis=1) - self.n_features = self.X.shape[1] - - # To index all but the dummy column when fitting model: - self.variable_names = self.X.columns.difference(drop_columns).to_list() - - # Get the names of all remaining groups: - self.param_labels, group_name = ( - set(group_name).difference([self.drop_dummy]), - group_name.to_list(), - ) - - self.param_labels = list(np.sort(list(self.param_labels))) - - elif mod_type == "niche" or mod_type == "niche_lag": - # If mod_type is 'niche' or 'niche_lag', use the connections matrix as independent variables in the - # regression: - connections_cols = list(product(X.columns, X.columns)) - connections_cols.sort(key=lambda x: x[1]) - connections_cols = [f"{i[0]}-{i[1]}" for i in connections_cols] - self.X = pd.DataFrame(connections, columns=connections_cols, index=self.adata.obs_names) - - # Set self.param_labels to reflect inclusion of the interactions: - self.param_labels = self.variable_names = self.X.columns - - elif "ligand_lag" in mod_type: - ligands = lig - receiving_genes = rec - - # Load signaling network file and find appropriate subset (if 'species' is axolotl, use the human portion - # of the network): - signet = pd.read_csv(os.path.join(self.cci_dir, "human_mouse_signaling_network.csv"), index_col=0) - if species not in ["human", "mouse", "axolotl"]: - self.logger.error("Invalid input to 'species'. Options: 'human', 'mouse', 'axolotl'.") - if species == "axolotl": - species = "human" - axolotl_lr = pd.read_csv(os.path.join(self.cci_dir, "lr_network_axolotl.csv"), index_col=0) - axolotl_l = set(axolotl_lr["human_ligand"]) - sig_net = signet[signet["species"] == species.title()] - lig_available = set(sig_net["src"]) - if "axolotl_l" in locals(): - lig_available = lig_available.union(axolotl_l) - - # Set predictors and target- for consistency with field conventions, set ligands and ligand-downstream - # gene names to uppercase (the AnnData object is assumed to follow this convention as well): - # Use the argument provided to 'ligands' to set the predictor block: - if ligands is None: - ligands = [g for g in self.genes if g in lig_available] - else: - # Filter provided ligands to those that can be found in the database: - ligands = [l for l in ligands if l in lig_available] - self.logger.info("Proceeding with analysis using ligands {}".format(",".join(ligands))) - - # Filter ligands to those that can be found in the database: - ligands = [l for l in ligands if l in self.adata.var_names] - if len(ligands) == 0: - self.logger.error( - "None of the ligands could be found in AnnData variable names. " - "Check that AnnData index names match database entries. " - "Also possible to have selected only ligands that can't be found in AnnData- " - "select different ligands." 
- ) - self.n_ligands = len(ligands) - - ligands_expr = pd.DataFrame( - self.adata[:, ligands].X.toarray() if scipy.sparse.issparse(self.adata.X) else self.adata[:, ligands].X, - index=self.adata.obs_names, - columns=ligands, - ) - - self.X = ligands_expr - - if receiving_genes is None: - # Append all receptors (direct connections to ligands): - # Note that the database contains repeat L:R pairs- each pair is listed more than once if it is part - # of more than one pathway. Furthermore, if two ligands bind the same receptor, the receptor will be - # listed twice. Since we're looking for just the names, take the set of receptors/downstream genes - # to get only unique molecules: - receptors = set(list(sig_net.loc[sig_net["src"].isin(ligands)]["dest"].values)) - receiving_genes = list(receptors) - self.logger.info( - "List of receptors was not provided- found these receptors from the provided " - f"ligands: {(', ').join(receiving_genes)}" - ) - - if rec_ds is not None: - # If specific list of downstream genes (indirect connections to ligands) is provided, append: - receiving_genes.extend(rec_ds) - elif use_ds: - # Optionally append all downstream genes from the database (direct connections to receptors, - # indirect connections to ligands): - self.logger.info( - "Downstream genes were not manually provided with 'rec_ds'...automatically " - "searching for downstream genes associated with the discovered 'receptors'." - ) - receiver_ds = list(set(list(sig_net.loc[sig_net["src"].isin(receiving_genes)]["dest"].values))) - self.logger.info( - "List of receptor-downstream genes was not provided- found these genes from the " - f"current list of receivers: {(', ').join(receiver_ds)}" - ) - receiving_genes.extend(receiver_ds) - receiving_genes = list(set(receiving_genes)) - - # Filter receiving genes for those that can be found in the dataset: - receiving_genes = [r for r in receiving_genes if r in self.adata.var_names] - - self.genes = receiving_genes - - # All ligands will have associated parameters and be used as variables in the model - self.param_labels = self.variable_names = ligands - - elif mod_type == "niche_lr": - # Load LR database based on input to 'species': - if species == "human": - lr_network = pd.read_csv(os.path.join(self.cci_dir, "lr_network_human.csv"), index_col=0) - elif species == "mouse": - lr_network = pd.read_csv(os.path.join(self.cci_dir, "lr_network_mouse.csv"), index_col=0) - elif species == "axolotl": - lr_network = pd.read_csv(os.path.join(self.cci_dir, "lr_network_axolotl.csv"), index_col=0) - else: - self.logger.error("Invalid input given to 'species'. Options: 'human', 'mouse', or 'axolotl'.") - - if lig is None: - self.logger.error("For 'mod_type' = 'niche_lr', ligands must be provided.") - - # If no receptors are given, search database for matches w/ the ligand: - if rec is None: - rec = set(list(lr_network.loc[lr_network["from"].isin(lig)]["to"].values)) - - self.logger.info( - "List of receptors was not provided- found these receptors from the provided " - f"ligands: {(', ').join(rec)}" - ) - - # Filter ligand and receptor lists to those that can be found in the data: - lig = [l for l in lig if l in self.adata.var_names] - if len(lig) == 0: - self.logger.error( - "None of the ligands could be found in AnnData variable names. " - "Check that AnnData index names match database entries." - "Also possible to have selected only ligands that can't be found in AnnData- " - "select different ligands." 
-                )
-            rec = [r for r in rec if r in self.adata.var_names]
-
-            # Convert groups/categories into one-hot array:
-            group_name = self.adata.obs[self.group_key]
-            db = pd.DataFrame({"group": group_name})
-            categories = np.array(self.adata.obs[self.group_key].unique().tolist())
-            n_categories = len(categories)
-            db["group"] = pd.Categorical(db["group"], categories=categories)
-
-            self.logger.info("Preparing data: converting categories to one-hot labels for all samples.")
-            X = pd.get_dummies(data=db, drop_first=False)
-            # Ensure columns are in order:
-            X = X.reindex(sorted(X.columns), axis=1)
-
-            # 'lig' and 'rec' must be matched, and so must be the same length, unless it is a case of one ligand that
-            # can bind multiple receptors or vice versa:
-            if len(lig) != len(rec):
-                self.logger.warning(
-                    "Length of the provided list of ligands (input to 'lig') does not match the length "
-                    "of the provided list of receptors (input to 'rec'). This is fine, so long as all ligands and "
-                    "all receptors have at least one match in the other list."
-                )
-
-            pairs = []
-            # This analysis takes ligand and receptor expression to predict expression of downstream genes- make sure
-            # (1) inputs to 'rec' are listed as receptors in the appropriate database, and (2) for each input to
-            # 'rec', there is a matched input in 'lig':
-            for ligand in lig:
-                lig_key = "from" if species != "axolotl" else "human_ligand"
-                rec_key = "to" if species != "axolotl" else "human_receptor"
-                possible_receptors = set(lr_network.loc[lr_network[lig_key] == ligand][rec_key])
-
-                if not any(receptor in possible_receptors for receptor in rec):
-                    self.logger.error(
-                        "No record of {} interaction with any of {}. Ensure provided lists contain "
-                        "paired ligand-receptors.".format(ligand, (",".join(rec)))
-                    )
-                found_receptors = list(set(possible_receptors).intersection(set(rec)))
-                lig_pairs = list(product([ligand], found_receptors))
-                pairs.extend(lig_pairs)
-            self.n_pairs = len(pairs)
-            self.logger.info(f"Setting up Niche-L:R model using the following ligand-receptor pairs: {pairs}")
-
-            self.logger.info(
-                "Starting from {} ligands and {} receptors, found {} ligand-receptor "
-                "pairs.".format(len(lig), len(rec), len(pairs))
-            )
-
-            # Since the features are combinatorial in the number of categories, it is not recommended to specify too
-            # many ligand-receptor pairs:
-            if len(pairs) > 200 / n_categories**2:
-                self.logger.warning(
-                    "Regression model has many predictors- consider measuring fewer ligands and receptors."
- ) - - # Each ligand-receptor pair will have an associated niche matrix: - self.niche_mats = {} - - # Copy of AnnData to avoid modifying in-place: - expr = self.adata.copy() - # Look for normalized and/or transformed values if "poisson", "softplus" or "neg-binomial" were given as the - # distribution to fit to- Niche LR dependent variables draw from the gene expression: - try: - expr.X = expr.layers["stored_processed"] - except: - pass - - for lr_pair in pairs: - lig, rec = lr_pair[0], lr_pair[1] - lig_expr_values = expr[:, lig].X.toarray() if scipy.sparse.issparse(expr.X) else expr[:, lig].X - rec_expr_values = expr[:, rec].X.toarray() if scipy.sparse.issparse(expr.X) else expr[:, rec].X - # Optionally, compute the spatial lag of the receptor: - if niche_lr_r_lag: - if not hasattr(self, "w"): - self.compute_spatial_weights() - from pysal.model import spreg - - rec_lag = spreg.utils.lag_spatial(self.w, rec_expr_values) - expr.obs[f"{rec}_lag"] = rec_lag - # Multiply one-hot category array by the expression of select receptor within that cell: - if not niche_lr_r_lag: - rec_vals = rec_expr_values - else: - rec_vals = expr.obs[f"{rec}_lag"].values - rec_expr = np.multiply(X.values, np.tile(rec_vals.reshape(-1, 1), X.shape[1])) - - # Separately multiply by the expression of select ligand such that an expression value only exists - # for one cell type per row: - lig_vals = lig_expr_values - lig_expr = np.multiply(X.values, np.tile(lig_vals, X.shape[1])) - # Multiply adjacency matrix by the cell-specific expression of select ligand: - nbhd_lig_expr = (adj > 0).astype("int").dot(lig_expr) - - # Construct the category interaction matrix (1D array w/ n_categories ** 2 elements, encodes the - # ligand-receptor niches of each sample by documenting the cell type-specific L:R enrichment within - # the niche: - data = {"category_rec_expr": rec_expr, "neighborhood_lig_expr": nbhd_lig_expr} - lr_connections = np.asarray(dmatrix("category_rec_expr:neighborhood_lig_expr-1", data)) - - lr_connections_cols = list(product(X.columns, X.columns)) - lr_connections_cols.sort(key=lambda x: x[1]) - # Swap sending & receiving cell types because we're looking at receptor expression in the "source" cell - # and ligand expression in the surrounding cells. - lr_connections_cols = [f"{i[1]}-{i[0]}_{lig}-{rec}" for i in lr_connections_cols] - self.niche_mats[f"{lig}-{rec}"] = pd.DataFrame(lr_connections, columns=lr_connections_cols) - self.niche_mats = {key: value for key, value in sorted(self.niche_mats.items())} - - # Define set of variables to regress on- genes downstream of the receptor. 
Can use custom provided list - # or create the list from database: - if rec_ds is not None: - # If specific list of downstream genes (indirect connections to ligands) is provided, append: - ds = rec_ds - else: - # Optionally append all downstream genes from the database (direct connections to receptors, - # indirect connections to ligands): - receptors = set([pair[1] for pair in pairs]) - signet = pd.read_csv(os.path.join(self.cci_dir, "human_mouse_signaling_network.csv"), index_col=0) - if species == "axolotl": - species = "human" - sig_net = signet[signet["species"] == species.title()] - - receiver_ds = set(list(sig_net.loc[sig_net["src"].isin(receptors)]["dest"].values)) - ds = list(receiver_ds) - self.logger.info( - "List of receptor-downstream genes was not provided- found these genes from the " - f"provided receptors: {(', ').join(ds)}" - ) - self.genes = ds - self.X = pd.concat(self.niche_mats, axis=1) - self.X.columns = self.X.columns.droplevel() - self.X.index = self.adata.obs_names - # Drop all-zero columns (represent cell type pairs with no spatial coupled L/R expression): - self.X = self.X.loc[:, (self.X != 0).any(axis=0)] - - if self.niche_compute_indicator: - self.X[self.X > 0] = 1 - else: - # Minmax-scale columns to minimize the external impact of intercellular differences in ligand/receptor - # expression: - self.X = (self.X - self.X.min()) / (self.X.max() - self.X.min()) - - self.param_labels = self.variable_names = self.X.columns - - else: - self.logger.error("Invalid argument to 'mod_type'.") - - # Save model type as an attribute so it can be accessed by other methods: - self.mod_type = mod_type - - # If 'genes' is given, can take the minimum necessary portion of AnnData object- otherwise, use all genes: - if self.genes is not None: - self.genes = list(self.adata.var.index.intersection(self.genes)) - else: - self.genes = list(self.adata.var.index) - # self.adata = self.adata[:, self.genes] - - def compute_spatial_weights(self): - """Generates matrix of pairwise spatial distances, used in spatially-lagged models""" - # Choose how weights are computed: - from pysal.lib import weights - - if self.weights_mode == "knn": - self.w = weights.distance.KNN.from_array(self.adata.obsm[self.spatial_key], k=self.sp_kwargs["n_neighbors"]) - elif self.weights_mode == "kernel": - self.w = weights.distance.Kernel.from_array( - self.adata.obsm[self.spatial_key], - bandwidth=self.sp_kwargs["bandwidth"], - fixed=self.sp_kwargs["fixed"], - k=self.sp_kwargs["n_neighbors_bandwidth"], - function=self.sp_kwargs["kernel_function"], - ) - elif self.weights_mode == "band": - self.w = weights.distance.DistanceBand.from_array( - self.adata.obsm[self.spatial_key], threshold=self.sp_kwargs["threshold"], alpha=self.sp_kwargs["alpha"] - ) - else: - self.logger.error("Invalid input to 'weights_mode'. 
Options: 'knn', 'kernel', 'band'.")
-
-        # Row standardize spatial weights matrix:
-        self.w.transform = "R"
-
-    # ---------------------------------------------------------------------------------------------------
-    # Computing parameters for spatially-aware and lagged models- generalized linear models
-    # ---------------------------------------------------------------------------------------------------
-    def GLMCV_fit_predict(
-        self,
-        gs_params: Union[None, dict] = None,
-        n_gs_cv: Union[None, int] = None,
-        n_jobs: int = 30,
-        cat_key: Union[None, str] = None,
-        categories: Union[None, str, List[str]] = None,
-        **kwargs,
-    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
-        """Wrapper for fitting a predictive generalized linear regression model.
-
-        Args:
-            gs_params: Optional dictionary where keys are variable names for the regressor and
-                values are lists of potential values for which to find the best combination using grid search.
-                Classifier parameters should be given in the following form: 'classifier__{parameter name}'.
-            n_gs_cv: Number of folds for grid search cross-validation, will only be used if gs_params is not None. If
-                None, will default to a 5-fold cross-validation.
-            n_jobs: For parallel processing, number of tasks to run at once
-            cat_key: Optional, name of key in .obs containing categorical (e.g. cell type) information
-            categories: Optional, names of categories to subset to for the regression. In cases where the exogenous
-                block is exceptionally heterogeneous, can be used to narrow down the search space.
-            kwargs: Additional named arguments that will be provided to :class `GLMCV`.
-
-        Returns:
-            coeffs: Contains fitted parameters for each feature
-            reconst: Contains predicted expression for each feature
-        """
-        X = self.X[self.variable_names]
-        kwargs["distr"] = self.distr
-
-        if categories is not None:
-            self.categories = categories
-            # Flag indicating that the resultant parameters matrix is not pairwise (i.e. that there's not one
-            # parameter for each cell type combination):
-            self.square = False
-            if not isinstance(self.categories, list):
-                self.categories = [self.categories]
-
-            if cat_key is None:
-                self.logger.error(
-                    ":func `GLMCV_fit_predict` error: 'categories' were given, but not 'cat_key' "
-                    "specifying where in .obs to look."
-                )
-            # Filter adata for rows annotated as being any category in 'categories', and the X block for columns
-            # annotated with any of the categories in 'categories'.
-            self.adata = self.adata[self.adata.obs[cat_key].isin(self.categories)]
-            self.cell_names = self.adata.obs_names
-            X = X.filter(regex="|".join(self.categories))
-            X = X.loc[self.adata.obs_names]
-        else:
-            self.square = True
-
-        # Set preprocessing parameters to False- :func `prepare_data` handles these steps.
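As an illustration of the 'classifier__{parameter name}' convention described in the docstring above, a hypothetical grid-search specification might look like the following; the parameter names are assumptions for illustration and must match the underlying regressor's own arguments:

```python
# Hypothetical grid over regularization strength and iteration budget:
gs_params = {
    "classifier__alpha": [0.01, 0.1, 1.0],
    "classifier__max_iter": [1000, 5000],
}

coeffs, reconst = model.GLMCV_fit_predict(gs_params=gs_params, n_gs_cv=5, n_jobs=8)
```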
- results = Parallel(n_jobs)( - delayed(fit_glm)( - X, - self.adata, - cur_g, - calc_first_moment=False, - log_transform=False, - gs_params=gs_params, - n_gs_cv=n_gs_cv, - return_model=False, - **kwargs, - ) - for cur_g in self.genes - ) - intercepts = [item[0] for item in results] - coeffs = [item[1] for item in results] - opt_scores = [item[2] for item in results] - reconst = [item[3] for item in results] - - coeffs = pd.DataFrame(coeffs, index=self.genes, columns=X.columns) - for cn in coeffs.columns: - self.adata.var.loc[:, cn] = coeffs[cn] - self.adata.uns["pseudo_r2"] = dict(zip(self.genes, opt_scores)) - self.adata.uns["intercepts"] = dict(zip(self.genes, intercepts)) - # Nested list transforms into dataframe rows- instantiate and transpose to get to correct shape: - reconst = pd.DataFrame(reconst, index=self.genes, columns=self.cell_names).T - return coeffs, reconst - - # --------------------------------------------------------------------------------------------------- - # Downstream interpretation - # --------------------------------------------------------------------------------------------------- - def visualize_params( - self, - coeffs: pd.DataFrame, - subset_cols: Union[None, str, List[str]] = None, - cmap: str = "autumn", - zero_center_cmap: bool = False, - mask_threshold: Union[None, float] = None, - mask_zero: bool = True, - transpose: bool = False, - title: Union[None, str] = None, - xlabel: Union[None, str] = None, - ylabel: Union[None, str] = None, - figsize: Union[None, Tuple[float, float]] = None, - annot_kws: dict = {}, - save_show_or_return: Literal["save", "show", "return", "both", "all"] = "save", - save_kwargs: dict = {}, - ): - """Generates heatmap of parameter values for visualization - - Args: - coeffs: Contains coefficients (and any other relevant statistics that were computed) from regression for - each variable - subset_cols: String or list of strings that can be used to subset coeffs DataFrame such that only columns - with names containing the provided key strings are displayed on heatmap. For example, can use "coeff" to - plot only the linear regression coefficients, "zstat" for the z-statistic, etc. Or can use the full - name of the column to select specific columns. - cmap: Name of the colormap to use - zero_center_cmap: Set True to set colormap intensity midpoint to zero. - mask_threshold: Optional, sets lower absolute value thresholds for parameters to be assigned color in - heatmap (will compare absolute value of each element against this threshold) - mask_zero: Set True to not assign color to zeros (representing neither a positive or negative interaction) - transpose: Set True to reverse the dataframe's orientation before plotting - title: Optional, provides title for plot. If not given, will use default "Spatial Parameters". - xlabel: Optional, provides label for x-axis. If not given, will use default "Predictor Features". - ylabel: Optional, provides label for y-axis. If not given, will use default "Target Features". - figsize: Can be used to set width and height of figure window, in inches. If not given, will use Spateo - default. - annot_kws: Optional dictionary that can be used to set qualities of the axis/tick labels. For example, - can set 'size': 9, 'weight': 'bold', etc. - save_show_or_return: Whether to save, show or return the figure. - If "both", it will save and plot the figure at the same time. If "all", the figure will be saved, - displayed and the associated axis and other object will be return. 
- save_kwargs: A dictionary that will passed to the save_fig function. - By default it is an empty dictionary and the save_fig function will use the - {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, - "verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modifies those - keys according to your needs. - """ - config_spateo_rcParams() - if figsize is None: - self.figsize = rcParams.get("figure.figsize") - else: - self.figsize = figsize - if len(annot_kws) == 0: - annot_kws = {"size": 6, "weight": "bold"} - - # Reformat column names for better visual: - coeffs.columns = coeffs.columns.str.replace("group_", "") - coeffs.columns = coeffs.columns.str.replace("_", ":") - - if subset_cols is not None: - if isinstance(subset_cols, str): - subset_cols = [subset_cols] - col_subset = [col for col in coeffs.columns if any(key in col for key in subset_cols)] - coeffs = coeffs[col_subset] - # Remove rows with no nonzero values: - coeffs = coeffs.loc[(coeffs != 0).any(axis=1)] - - if transpose: - coeffs = coeffs.T - - if mask_threshold is not None: - mask = np.abs(coeffs) < mask_threshold - # Drop columns in which all elements fail to meet mask threshold criteria (then recompute mask w/ - # potentially smaller dataframe): - coeffs = coeffs.loc[:, (mask == 0).any(axis=0)] - mask = np.abs(coeffs) < mask_threshold - elif mask_zero: - mask = coeffs == 0 - # Drop columns in which all elements fail to meet mask threshold criteria (then recompute mask w/ - # potentially smaller dataframe): - coeffs = coeffs.loc[:, (mask == 0).any(axis=0)] - mask = coeffs == 0 - else: - mask = None - - # If "zero_center_cmap", find percentile corresponding to zero and set colormap midpoint to this value: - if zero_center_cmap: - cmap = plt.get_cmap(cmap) - coeffs_max = np.max(coeffs.values) - zero_point = 1 - coeffs_max / (coeffs_max + abs(np.min(coeffs.values))) - print(zero_point) - cmap = shiftedColorMap(cmap, midpoint=zero_point) - - xtick_labels = list(coeffs.columns) - ytick_labels = list(coeffs.index) - - fig, ax = plt.subplots(1, 1, figsize=self.figsize) - res = sns.heatmap( - coeffs, - cmap=cmap, - square=True, - yticklabels=ytick_labels, - linecolor="grey", - linewidths=0.3, - annot_kws=annot_kws, - xticklabels=xtick_labels, - mask=mask, - ax=ax, - ) - # Outer frame: - for _, spine in res.spines.items(): - spine.set_visible(True) - spine.set_linewidth(0.75) - - plt.title(title if title is not None else "Spatial Parameters") - if xlabel is not None: - plt.xlabel(xlabel, size=6) - if ylabel is not None: - plt.ylabel(ylabel, size=6) - ax.set_xticklabels(xtick_labels, rotation=90, ha="center") - plt.tight_layout() - - save_return_show_fig_utils( - save_show_or_return=save_show_or_return, - show_legend=True, - background="white", - prefix="parameters", - save_kwargs=save_kwargs, - total_panels=1, - fig=fig, - axes=ax, - return_all=False, - return_all_list=None, - ) - - def compute_coeff_significance( - self, - coeffs: pd.DataFrame, - significance_threshold: float = 0.05, - only_positive: bool = False, - only_negative: bool = False, - ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: - """ - Computes statistical significance for fitted coefficients. 
-
-        Args:
-            coeffs: Contains coefficients from regression for each variable
-            significance_threshold: p-value needed to call a sender-receiver relationship significant
-            only_positive: Set True to find significance/pvalues/qvalues only for the subset of coefficients that is
-                positive (representing possible mechanisms of positive regulation).
-            only_negative: Set True to find significance/pvalues/qvalues only for the subset of coefficients that is
-                negative (representing possible mechanisms of negative regulation).
-
-        Returns:
-            is_significant: Dataframe of identical shape to coeffs, where each element is True or False if it meets the
-                threshold for significance
-            pvalues: Dataframe of identical shape to coeffs, where each element is a p-value for that instance of that
-                feature
-            qvalues: Dataframe of identical shape to coeffs, where each element is a q-value for that instance of that
-                feature
-        """
-        # If Poisson or softplus, use log-transformed values for downstream applications (model ultimately uses a
-        # linear combination of independent variables to predict log-transformed dependent):
-        if self.distr in ["poisson", "softplus"]:
-            try:
-                log_key = [key for key in self.adata.layers.keys() if "log1p" in key][0]
-                self.adata.X = self.adata.layers[log_key]
-                self.logger.info(
-                    "With Poisson distribution assumed for dependent variable, using log-transformed data "
-                    f"to compute sender-receiver effects...log key found in adata under key {log_key}."
-                )
-            except:
-                self.logger.info(
-                    "With Poisson distribution assumed for dependent variable, using log-transformed data "
-                    "to compute sender-receiver effects...log key not found in adata, manually computing."
-                )
-                log1p(self.adata)
-                self.logger.info("Data log-transformed.")
-
-        coeffs_np = coeffs.values
-        # Save labels of indices and columns (correspond to features & parameters, respectively, for the coeffs
-        # DataFrame, will be columns & indices respectively for the other arrays generated by this function):
-        feature_labels = coeffs.index
-        param_labels = coeffs.columns
-
-        # Get the inverse Fisher information matrix, with the y block containing all features that were used in
-        # regression:
-        y = (
-            self.adata[:, self.genes].X.toarray()
-            if scipy.sparse.issparse(self.adata.X)
-            else self.adata[:, self.genes].X
-        )
-        inverse_fisher = get_fisher_inverse(self.X.values, y)
-
-        # Compute significance for each parameter:
-        is_significant, pvalues, qvalues = compute_wald_test(
-            params=coeffs_np, fisher_inv=inverse_fisher, significance_threshold=significance_threshold
-        )
-
-        is_significant = pd.DataFrame(is_significant, index=param_labels, columns=feature_labels)
-        pvalues = pd.DataFrame(pvalues, index=param_labels, columns=feature_labels)
-        qvalues = pd.DataFrame(qvalues, index=param_labels, columns=feature_labels)
-
-        # If 'only_positive' or 'only_negative' are set, set all elements corresponding to negative or positive
-        # coefficients (respectively) to False and all pvalues/qvalues to 1:
-        if only_positive:
-            is_significant[coeffs.T <= 0] = False
-            pvalues[coeffs.T <= 0] = 1
-            qvalues[coeffs.T <= 0] = 1
-        elif only_negative:
-            is_significant[coeffs.T >= 0] = False
-            pvalues[coeffs.T >= 0] = 1
-            qvalues[coeffs.T >= 0] = 1
-
-        return is_significant, pvalues, qvalues
-
-    def get_effect_sizes(
-        self,
-        coeffs: pd.DataFrame,
-        only_positive: bool = False,
-        only_negative: bool = False,
-        significance_threshold: float = 0.05,
-        lr_pair: Union[None, str] = None,
-        save_prefix: Union[None, str] = None,
-    ):
-        """For each predictor and each
feature, determine if the influence of said predictor in predicting said
-        feature is significant.
-
-        Additionally, for each feature and each sender-receiver category pair, determines the effect size that
-        the sender induces in the feature for the receiver.
-
-        Only valid if the model specified uses the connections between categories as variables for the regression-
-        thus can be applied to 'mod_type' "niche" or "niche_lr".
-
-        Args:
-            coeffs: Contains coefficients from regression for each variable
-            only_positive: Set True to find significance/pvalues/qvalues only for the subset of coefficients that is
-                positive (representing possible mechanisms of positive regulation).
-            only_negative: Set True to find significance/pvalues/qvalues only for the subset of coefficients that is
-                negative (representing possible mechanisms of negative regulation).
-            significance_threshold: p-value needed to call a sender-receiver relationship significant
-            lr_pair: Required if (and used only in the case that) coefficients came from a Niche-LR model; used to
-                subset the coefficients array to the specific ligand-receptor pair of interest. Takes the form
-                "{ligand}-{receptor}" and should match one of the keys in :dict `self.niche_mats`. If not given,
-                will default to the first key in this dictionary.
-            save_prefix: If provided, saves all relevant dataframes to :path `./regression_outputs` under the name
-                `{prefix}_{coeffs/pvalues, etc.}.csv`. If not provided, will not save.
-        """
-        # If "Poisson" given as the distributional assumption, check for log-transformed data:
-        if self.distr == "poisson":
-            if not any("log1p" in key for key in self.adata.layers.keys()):
-                self.logger.info(
-                    "With Poisson distribution assumed for dependent variable, using log-transformed data "
-                    "to compute sender-receiver effects...log key not found in adata, manually computing."
-                )
-                self.preprocess_data(log_transform=True)
-
-        if "niche" not in self.mod_type:
-            self.logger.error(
-                "Type coupling analysis only valid if connections between categories are used as the "
-                "predictor variable."
-            )
-
-        coeffs_np = coeffs.values
-
-        is_significant, pvalues, qvalues = self.compute_coeff_significance(
-            coeffs,
-            only_positive=only_positive,
-            only_negative=only_negative,
-            significance_threshold=significance_threshold,
-        )
-
-        # If 'save_prefix' is given, save the complete coefficients, significance, p-value and q-value matrices:
-        if save_prefix is not None:
-            if not os.path.exists("./regression_outputs"):
-                os.makedirs("./regression_outputs")
-            is_significant.to_csv(f"./regression_outputs/{save_prefix}_is_sign.csv")
-            pvalues.to_csv(f"./regression_outputs/{save_prefix}_pvalues.csv")
-            qvalues.to_csv(f"./regression_outputs/{save_prefix}_qvalues.csv")
-            coeffs.to_csv(f"./regression_outputs/{save_prefix}_coeffs.csv")
-
-        # If niche-LR model, extract the portion corresponding to the interaction terms for a specific L-R pair:
-        if self.mod_type == "niche_lr":
-            if lr_pair is None:
-                self.logger.warning(
-                    "'lr_pair' not specified- defaulting to the first L:R pair that was used for the "
-                    "model. For reference, all L:R pairs used for the "
-                    f"model: {list(self.niche_mats.keys())}"
-                )
-                lr_pair = list(self.niche_mats.keys())[0]
-            if lr_pair not in self.niche_mats.keys():
-                self.logger.warning(
-                    "Input to 'lr_pair' not recognized- proceeding with the first L:R pair that was "
-                    "used for the model. 
For reference, all L:R pairs used for the " - f"model: {list(self.niche_mats.keys())}" - ) - lr_pair = list(self.niche_mats.keys())[0] - - is_significant = is_significant.filter(lr_pair, axis="index") - pvalues = pvalues.filter(lr_pair, axis="index") - qvalues = qvalues.filter(lr_pair, axis="index") - - # Coefficients, etc. will also be a subset of the complete array: - coeffs = coeffs.filter(lr_pair, axis="columns") - coeffs_np = coeffs.values - # Significance, pvalues, qvalues filtered above - - # If the 'square' flag is set- that is, if the original parameters constitute at least one pairwise - # combination of cell types (which is not true if 'categories' are given to the fit function): - if self.square: - self.effect_size = np.concatenate( - np.expand_dims( - np.split(coeffs_np.T, indices_or_sections=np.sqrt(coeffs_np.T.shape[0]), axis=0), axis=0 - ), - axis=0, - ) - else: - self.effect_size = coeffs - - # Else if connection-based model, all regression coefficients already correspond to the interaction terms: - else: - if self.square: - self.effect_size = np.concatenate( - np.expand_dims( - np.split(coeffs_np.T, indices_or_sections=np.sqrt(coeffs_np.T.shape[0]), axis=0), axis=0 - ), - axis=0, - ) - else: - self.effect_size = coeffs - - if self.square: - # Split array such that an nxn matrix is created, where n is 'n_features' (the number of cell type - # categories) - self.pvalues = np.concatenate( - np.expand_dims(np.split(pvalues, indices_or_sections=np.sqrt(pvalues.shape[0]), axis=0), axis=0), - axis=0, - ) - self.qvalues = np.concatenate( - np.expand_dims(np.split(qvalues, indices_or_sections=np.sqrt(qvalues.shape[0]), axis=0), axis=0), - axis=0, - ) - self.is_significant = np.concatenate( - np.expand_dims( - np.split(is_significant, indices_or_sections=np.sqrt(is_significant.shape[0]), axis=0), axis=0 - ), - axis=0, - ) - else: - self.pvalues = pvalues.T - self.qvalues = qvalues.T - self.is_significant = is_significant.T - - def niche_differential_expression( - self, - cmap: str = "Reds", - fontsize: Union[None, int] = None, - figsize: Union[None, Tuple[float, float]] = None, - ignore_self: bool = True, - save_show_or_return: Literal["save", "show", "return", "both", "all"] = "save", - save_kwargs: dict = {}, - ): - """Generates heatmap of spatially differentially-expressed features for each pair of sender and receiver - categories. Only valid if the model specified uses the connections between categories as variables for the - regression. - - A high number of differentially-expressed genes between a given sender-receiver pair means that the sender - being in the neighborhood of the receiver tends to correlate with differential expression levels of many of - the genes within the selection- much of the cellular variation in the receiver cell type can be attributed to - being in proximity with the sender. - - Args: - cmap: Name of Matplotlib color map to use - fontsize: Size of figure title and axis labels - figsize: Width and height of plotting window - save_show_or_return: Options: "save", "show", "return", "both", "all" - - "both" for save and show - ignore_self: If True, will ignore the effect of cell type in proximity to other cells of the same type- - will record the number of DEGs only if the two cell types are different. - save_show_or_return: Whether to save, show or return the figure. - If "both", it will save and plot the figure at the same time. If "all", the figure will be saved, - displayed and the associated axis and other object will be return. 
- save_kwargs: A dictionary that will passed to the save_fig function. - By default it is an empty dictionary and the save_fig function will use the - {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, - "verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modifies those - keys according to your needs. - """ - if fontsize is None: - self.fontsize = rcParams.get("font.size") - else: - self.fontsize = fontsize - if figsize is None: - self.figsize = rcParams.get("figure.figsize") - else: - self.figsize = figsize - - if not hasattr(self, "is_significant"): - self.logger.warning("Significance dataframe does not exist- please run :func `get_effect_sizes` " "first.") - - if not hasattr(self, "square"): - self.logger.error( - ":func `type_coupling_analysis` can only be run if the design matrix can be made square- that is, " - "if all pairwise combinations of cell types are represented." - ) - - sig_df = pd.DataFrame( - np.sum(self.is_significant, axis=-1), columns=self.celltype_names, index=self.celltype_names - ) - if ignore_self: - np.fill_diagonal(sig_df.values, 0) - fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize) - res = sns.heatmap(sig_df, square=True, linecolor="grey", linewidths=0.3, cmap=cmap, mask=(sig_df == 0), ax=ax) - # Outer frame: - for _, spine in res.spines.items(): - spine.set_visible(True) - spine.set_linewidth(0.75) - plt.xlabel("Receiving Cell") - plt.ylabel("Sending Cell") - - title = ( - "Niche-Associated Differential Expression" - if self.mod_type == "niche" - else "Cell Type-Specific Ligand:Receptor-Associated Differential Expression" - ) - plt.title(title) - plt.tight_layout() - - save_return_show_fig_utils( - save_show_or_return=save_show_or_return, - show_legend=True, - background="white", - prefix="type_coupling", - save_kwargs=save_kwargs, - total_panels=1, - fig=fig, - axes=ax, - return_all=False, - return_all_list=None, - ) - - def sender_effect_on_all_receivers( - self, - sender: str, - gene_subset: Union[None, List[str]] = None, - significance_threshold: float = 0.05, - cut_pvals: float = -5, - fontsize: Union[None, int] = None, - figsize: Union[None, Tuple[float, float]] = None, - cmap: str = "seismic", - save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show", - save_kwargs: Optional[dict] = {}, - ): - """Evaluates and visualizes the effect that the given sender cell type has on expression/abundance in each - possible receiver cell type. - - Args: - sender: sender cell type label - gene_subset: Names of genes to subset for plot. If None, will use all genes that were used in the - regression. - significance_threshold: Set non-significant effect sizes to zero, where the threshold is given here - cut_pvals: Minimum allowable log10(pval)- anything below this will be clipped to this value - fontsize: Size of figure title and axis labels - figsize: Width and height of plotting window - cmap: Name of matplotlib colormap specifying colormap to use - save_show_or_return: Whether to save, show or return the figure. - If "both", it will save and plot the figure at the same time. If "all", the figure will be saved, - displayed and the associated axis and other object will be return. - save_kwargs: A dictionary that will passed to the save_fig function. 
- By default it is an empty dictionary and the save_fig function will use the - {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, - "verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modifies those - keys according to your needs. - """ - logger = lm.get_main_logger() - config_spateo_rcParams() - - if fontsize is None: - self.fontsize = rcParams.get("font.size") - else: - self.fontsize = fontsize - if figsize is None: - self.figsize = rcParams.get("figure.figsize") - else: - self.figsize = figsize - - if self.square: - sender_idx = self.celltype_names.index(sender) - - arr = self.effect_size[sender_idx, :, :].copy() - arr[np.where(self.qvalues[sender_idx, :, :] > significance_threshold)] = 0 - df = pd.DataFrame(arr, index=self.celltype_names, columns=self.genes) - if gene_subset: - df = df.drop(index=sender)[gene_subset] - vmax = np.max(np.abs(df.values)) - - fig, ax = plt.subplots(nrows=1, ncols=1, figsize=self.figsize) - es = sns.heatmap( - df.T, - square=True, - linecolor="grey", - linewidths=0.3, - cbar_kws={"label": "Effect size", "location": "top"}, - cmap=cmap, - vmin=-vmax, - vmax=vmax, - ax=ax, - ) - - # Outer frame: - for _, spine in es.spines.items(): - spine.set_visible(True) - spine.set_linewidth(0.75) - - else: - logger.error("Invalid input to 'plot_mode'. Options: 'qvals', 'effect_size'.") - - else: - if sender not in self.categories: - self.logger.error( - "Adata was subset to categories of interest and fit on those categories, " - "but the group provided to 'sender' is not one of those categories." - ) - - sender_cols = [col for col in self.effect_size.columns if sender in col.split("-")[1]] - # (note that what should be considered the "receiver" is the first cell type listed) - - df = self.effect_size[sender_cols].copy() - # Reformat columns for visual purposes: - receivers = [ct[0] for ct in df.columns.str.split("-")] - df.columns = [col.split("_")[1] for col in receivers] - df.values[np.where(self.qvalues[sender_cols] > significance_threshold)] = 0 - if gene_subset: - df = df.loc[gene_subset] - vmax = np.max(np.abs(df.values)) - - fig, ax = plt.subplots(nrows=1, ncols=1, figsize=self.figsize) - es = sns.heatmap( - df, - square=True, - linecolor="grey", - linewidths=0.3, - cbar_kws={"label": "Effect size", "location": "top"}, - cmap=cmap, - vmin=-vmax, - vmax=vmax, - ax=ax, - ) - - # Outer frame: - for _, spine in es.spines.items(): - spine.set_visible(True) - spine.set_linewidth(0.75) - - else: - logger.error("Invalid input to 'plot_mode'. 
Options: 'qvals', 'effect_size'.")
-
-        plt.xlabel("Receiver cell type", fontsize=9)
-        plt.title("{} effects on receivers".format(sender), fontsize=9)
-        plt.tight_layout()
-
-        save_return_show_fig_utils(
-            save_show_or_return=save_show_or_return,
-            show_legend=True,
-            background="white",
-            prefix="{}_effects_on_receivers".format(sender),
-            save_kwargs=save_kwargs,
-            total_panels=1,
-            fig=fig,
-            axes=ax,
-            return_all=False,
-            return_all_list=None,
-        )
-
-    def all_senders_effect_on_receiver(
-        self,
-        receiver: str,
-        gene_subset: Union[None, List[str]] = None,
-        significance_threshold: float = 0.05,
-        cut_pvals: float = -5,
-        fontsize: Union[None, int] = None,
-        figsize: Union[None, Tuple[float, float]] = None,
-        cmap: str = "seismic",
-        save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
-        save_kwargs: Optional[dict] = {},
-    ):
-        """Evaluates and visualizes the effect that each possible sender cell type has on expression/abundance in a
-        selected receiver cell type.
-
-        Args:
-            receiver: Receiver cell type label
-            gene_subset: Names of genes to subset for plot. If None, will use all genes that were used in the
-                regression.
-            significance_threshold: Set non-significant effect sizes to zero, where the threshold is given here
-            cut_pvals: Minimum allowable log10(pval)- anything below this will be clipped to this value
-            fontsize: Size of figure title and axis labels
-            figsize: Width and height of plotting window
-            cmap: Name of matplotlib colormap specifying colormap to use
-            save_show_or_return: Whether to save, show or return the figure.
-                If "both", it will save and plot the figure at the same time. If "all", the figure will be saved,
-                displayed and the associated axis and other object will be returned.
-            save_kwargs: A dictionary that will be passed to the save_fig function.
-                By default it is an empty dictionary and the save_fig function will use the
-                {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True,
-                "verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modifies those
-                keys according to your needs.
-        """
-        logger = lm.get_main_logger()
-        config_spateo_rcParams()
-
-        if fontsize is None:
-            self.fontsize = rcParams.get("font.size")
-        else:
-            self.fontsize = fontsize
-        if figsize is None:
-            self.figsize = rcParams.get("figure.figsize")
-        else:
-            self.figsize = figsize
-
-        if self.square:
-            receiver_idx = self.celltype_names.index(receiver)
-
-            arr = self.effect_size[:, receiver_idx, :].copy()
-            arr[np.where(self.qvalues[:, receiver_idx, :] > significance_threshold)] = 0
-            df = pd.DataFrame(arr, index=self.celltype_names, columns=self.genes)
-            if gene_subset:
-                df = df.drop(index=receiver)[gene_subset]
-            vmax = np.max(np.abs(df.values))
-
-            fig, ax = plt.subplots(nrows=1, ncols=1, figsize=self.figsize)
-            es = sns.heatmap(
-                df.T,
-                square=True,
-                linecolor="grey",
-                linewidths=0.3,
-                cbar_kws={"label": "Effect size", "location": "top"},
-                cmap=cmap,
-                vmin=-vmax,
-                vmax=vmax,
-                ax=ax,
-            )
-
-            # Outer frame:
-            for _, spine in es.spines.items():
-                spine.set_visible(True)
-                spine.set_linewidth(0.75)
-
-            else:
-                logger.error("Invalid input to 'plot_mode'. 
Options: 'qvals', 'effect_size'.") - - else: - if receiver not in self.categories: - self.logger.error( - "Adata was subset to categories of interest and fit on those categories, " - "but the provided group to 'receiver' is not one of those categories." - ) - - receiver_cols = [col for col in self.effect_size.columns if receiver in col.split("-")[0]] - # (note that what should be considered the "receiver" is the first cell type listed) - - df = self.effect_size[receiver_cols].copy() - # Reformat columns for visual purposes: - senders = [ct[1] for ct in df.columns.str.split("-")] - df.columns = [col.split("_")[1] for col in senders] - df.values[np.where(self.qvalues[receiver_cols] > significance_threshold)] = 0 - if gene_subset: - df = df.loc[gene_subset] - vmax = np.max(np.abs(df.values)) - - fig, ax = plt.subplots(nrows=1, ncols=1, figsize=self.figsize) - es = sns.heatmap( - df, - square=True, - linecolor="grey", - linewidths=0.3, - cbar_kws={"label": "Effect size", "location": "top"}, - cmap=cmap, - vmin=-vmax, - vmax=vmax, - ax=ax, - ) - - # Outer frame: - for _, spine in es.spines.items(): - spine.set_visible(True) - spine.set_linewidth(0.75) - - else: - logger.error("Invalid input to 'plot_mode'. Options: 'qvals', 'effect_size'.") - - plt.xlabel("Sender cell type", fontsize=9) - plt.title("Sender Effects on " + receiver, fontsize=9) - plt.tight_layout() - - save_return_show_fig_utils( - save_show_or_return=save_show_or_return, - show_legend=True, - background="white", - prefix="sender_effects_on_{}".format(receiver), - save_kwargs=save_kwargs, - total_panels=1, - fig=fig, - axes=ax, - return_all=False, - return_all_list=None, - ) - - def sender_receiver_effect_volcanoplot( - self, - sender: str, - receiver: str, - significance_threshold: float = 0.05, - effect_size_threshold: Union[None, float] = None, - fontsize: Union[None, int] = None, - figsize: Union[None, Tuple[float, float]] = (4.5, 7.0), - save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show", - save_kwargs: Optional[dict] = {}, - ): - """Volcano plot to identify differentially expressed genes of a given receiver cell type in the presence of a - given sender cell type. - - Args: - sender: Sender cell type label - receiver: Receiver cell type label - significance_threshold: Set non-significant effect sizes (given by q-values) to zero, where the - threshold is given here - effect_size_threshold: Set absolute value effect-size threshold beyond which observations are marked as - interesting. If not given, will take the 95th percentile fold-change as the cutoff. - fontsize: Size of figure title and axis labels - figsize: Width and height of plotting window - save_show_or_return: Whether to save, show or return the figure. - If "both", it will save and plot the figure at the same time. If "all", the figure will be saved, - displayed and the associated axis and other object will be return. - save_kwargs: A dictionary that will passed to the save_fig function. - By default it is an empty dictionary and the save_fig function will use the - {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, - "verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modifies those - keys according to your needs. 
- """ - logger = lm.get_main_logger() - config_spateo_rcParams() - - if fontsize is None: - self.fontsize = rcParams.get("font.size") - else: - self.fontsize = fontsize - if figsize is None: - self.figsize = rcParams.get("figure.figsize") - else: - self.figsize = figsize - - # Set fold-change threshold if not already provided: - if effect_size_threshold is None: - effect_size_threshold = np.percentile(self.effect_size, 95) - - fig, ax = plt.subplots(1, 1, figsize=self.figsize) - ax.grid(False) - - if self.square: - receiver_idx = self.celltype_names.index(receiver) - sender_idx = self.celltype_names.index(sender) - - # All non-significant features: - qval_filter = np.where(self.qvalues[sender_idx, receiver_idx, :] >= significance_threshold) - vmax = np.max(np.abs(self.effect_size[sender_idx, receiver_idx, :])) - - if qval_filter[0].size > 0: - sns.scatterplot( - x=self.effect_size[sender_idx, receiver_idx, :][qval_filter], - y=-np.log10(self.qvalues[sender_idx, receiver_idx, :])[qval_filter], - color="white", - edgecolor="black", - s=50, - ax=ax, - ) - - # Identify subset that may be significant, but which doesn't pass the fold-change threshold: - qval_filter = np.where(self.qvalues[sender_idx, receiver_idx, :] < significance_threshold) - x = self.effect_size[sender_idx, receiver_idx, :][qval_filter] - y = -np.nan_to_num(np.log10(self.qvalues[sender_idx, receiver_idx, :])[qval_filter], posinf=10, neginf=-10) - fc_filter = np.where(x < effect_size_threshold) - if qval_filter[0].size > 0: - sns.scatterplot(x=x[fc_filter], y=y[fc_filter], color="darkgrey", edgecolor="black", s=50, ax=ax) - - # Identify subset that are significantly downregulated: - dreg_color = matplotlib.cm.get_cmap("winter")(0) - y = -np.nan_to_num(np.log10(self.qvalues[sender_idx, receiver_idx, :])[qval_filter], posinf=10, neginf=-10) - fc_filter = np.where(x <= -effect_size_threshold) - if qval_filter[0].size > 0: - sns.scatterplot(x=x[fc_filter], y=y[fc_filter], color=dreg_color, edgecolor="black", s=50, ax=ax) - - # Identify subset that are significantly upregulated: - ureg_color = matplotlib.cm.get_cmap("autumn")(0) - fc_filter = np.where(x >= effect_size_threshold) - if qval_filter[0].size > 0: - sns.scatterplot(x=x[fc_filter], y=y[fc_filter], color=ureg_color, edgecolor="black", s=50, ax=ax) - - else: - if sender not in self.categories and receiver not in self.categories: - self.logger.error( - "Adata was subset to categories of interest and fit on those categories, " - "but neither the sender nor the receiver group are of those categories." 
- ) - - # All non-significant features: - sender_receiver_cols = [ - col for col in self.effect_size.columns if sender in col.split("-")[1] and receiver in col.split("-")[0] - ] - qval_filter = np.where(self.qvalues[sender_receiver_cols] >= significance_threshold) - vmax = np.max(np.abs(self.effect_size[sender_receiver_cols].values)) - - if qval_filter[0].size > 0: - sns.scatterplot( - x=self.effect_size[sender_receiver_cols].values[qval_filter], - y=-np.log10(self.qvalues[sender_receiver_cols].values)[qval_filter], - color="white", - edgecolor="black", - s=50, - ax=ax, - ) - - qval_filter = np.where(self.qvalues[sender_receiver_cols] < significance_threshold) - x = self.effect_size[sender_receiver_cols].values[qval_filter] - y = -np.nan_to_num(np.log10(self.qvalues[sender_receiver_cols].values)[qval_filter], posinf=10, neginf=-10) - - # Identify subset that may be significant, but which doesn't pass the effect size threshold: - fc_filter = np.where(x < effect_size_threshold) - if qval_filter[0].size > 0: - sns.scatterplot(x=x[fc_filter], y=y[fc_filter], color="darkgrey", edgecolor="black", s=50, ax=ax) - - # Identify subset that are significantly downregulated: - dreg_color = matplotlib.cm.get_cmap("winter")(0) - fc_filter = np.where(x <= -effect_size_threshold) - y = -np.nan_to_num(np.log10(self.qvalues[sender_receiver_cols].values)[qval_filter], posinf=10, neginf=-10) - if qval_filter[0].size > 0: - sns.scatterplot(x=x[fc_filter], y=y[fc_filter], color=dreg_color, edgecolor="black", s=50, ax=ax) - - # Identify subset that are significantly upregulated: - ureg_color = matplotlib.cm.get_cmap("autumn")(0) - fc_filter = np.where(x >= effect_size_threshold) - if qval_filter[0].size > 0: - sns.scatterplot(x=x[fc_filter], y=y[fc_filter], color=ureg_color, edgecolor="black", s=50, ax=ax) - - # Plot configuration: - ax.set_xlim((-vmax * 1.1, vmax * 1.1)) - ax.set_xlabel("Effect size", fontsize=9) - ax.set_ylabel("$-\log_{10}$ FDR-corrected pvalues", fontsize=9) - ax.tick_params(axis="both", labelsize=8) - plt.axvline(-effect_size_threshold, color="darkgrey", linestyle="--", linewidth=0.9) - plt.axvline(effect_size_threshold, color="darkgrey", linestyle="--", linewidth=0.9) - plt.axhline(-np.log10(significance_threshold), linestyle="--", color="darkgrey", linewidth=0.9) - - plt.tight_layout() - save_return_show_fig_utils( - save_show_or_return=save_show_or_return, - show_legend=True, - background="white", - prefix="effect_of_{}_on_{}".format(sender, receiver), - save_kwargs=save_kwargs, - total_panels=1, - fig=fig, - axes=ax, - return_all=False, - return_all_list=None, - ) - - -class Category_Model(Base_Model): - """Wraps all necessary methods for data loading and preparation, model initialization, parameterization, - evaluation and prediction when instantiating a model for spatially-aware (but not spatially lagged) regression - using categorical variables (specifically, the prevalence of categories within spatial neighborhoods) to predict - the value of gene expression. - - Arguments passed to :class `Base_Model`. The only keyword argument that is used for this class is - 'n_neighbors'. - - Args: - args: Positional arguments to :class `Base_Model` - kwargs: Keyword arguments to :class `Base_Model` - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - assert self.group_key is not None, "Categorical labels required for this model." 
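- # (Illustrative sketch, not from the original source: per the class docstring above, the "category" design matrix encodes neighborhood label prevalence- conceptually W @ one_hot(labels) for a row-normalized spatial weights matrix W, as computed by calc_1nd_moment below; the exact construction lives in prepare_data.)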
- - # Prepare data: - self.prepare_data(mod_type="category") - - -class Niche_Model(Base_Model): - """Wraps all necessary methods for data loading and preparation, model initialization, parameterization, - evaluation and prediction when instantiating a model for spatially-aware regression using both the prevalence of - and connections between categories within spatial neighborhoods to predict the value of gene expression. - - Arguments passed to :class `Base_Model`. - - Args: - args: Positional arguments to :class `Base_Model` - kwargs: Keyword arguments to :class `Base_Model` - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - assert self.group_key is not None, "Categorical labels required for this model." - - self.prepare_data(mod_type="niche") - - -class Lagged_Model(Base_Model): - """Wraps all necessary methods for data loading and preparation, model initialization, parameterization, - evaluation and prediction when instantiating a model for spatially-lagged regression. - - Can specify one of two models: "ligand", which uses the spatial lag of ligand genes and the spatial lag of the - regression target to predict the regression target, or "niche", which uses the spatial lag of cell type - colocalization and the spatial lag of the regression target to predict the regression target. - - If "ligand" is specified, arguments to `lig` must be given, and it is recommended to provide `species` as well- - default for this is human. - - Arguments passed to :class `Base_Model`. - - Args: - model_type: Either "ligand" or "niche", specifies whether to fit a model that incorporates the spatial lag of - ligand expression or the spatial lag of cell type colocalization. - lig: Name(s) of ligands to use as predictors - rec: Name(s) of receptors to use as regression targets. If not given, will search through database for all - genes that correspond to the provided genes from 'ligands'. - rec_ds: Name(s) of receptor-downstream genes to use as regression targets. If not given, will search through - database for all genes that correspond to receptor-downstream genes. - species: Specifies L:R database to use - normalize: Perform library size normalization, to set total counts in each cell to the same number (adjust - for cell size) - smooth: To correct for dropout effects, leverage gene expression neighborhoods to smooth expression - log_transform: Set True if log-transformation should be applied to expression (otherwise, will assume - preprocessing/log-transform was computed beforehand) - args: Additional positional arguments to :class `Base_Model` - kwargs: Additional keyword arguments to :class `Base_Model` - """ - - def __init__( - self, - model_type: str = "ligand", - lig: Union[None, str, List[str]] = None, - rec: Union[None, str, List[str]] = None, - rec_ds: Union[None, str, List[str]] = None, - species: Literal["human", "mouse", "axolotl"] = "human", - normalize: bool = True, - smooth: bool = False, - log_transform: bool = True, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - - if self.distr != "gaussian": - self.logger.info( - "We recommend applying spatially-lagged models to processed data, for which normality " - "can be assumed- in this case `distr` can be set to 'gaussian'." - ) - - if model_type == "ligand": - if lig is None: - self.logger.error( - "From instantiation of :class `Lagged_Model`: `model_type` was given as 'ligand', " - "but ligands were not provided using parameter 'lig'." 
- ) - # Optional data preprocessing: - self.preprocess_data(normalize, smooth, log_transform) - self.prepare_data(mod_type="ligand_lag", lig=lig, rec=rec, rec_ds=rec_ds, species=species) - elif model_type == "niche": - # Optional data preprocessing: - self.preprocess_data(normalize, smooth, log_transform) - self.prepare_data(mod_type="niche_lag") - - def run_GM_lag(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: - """Runs spatially lagged two-stage least squares model""" - if not hasattr(self, "w"): - self.logger.info( - "Called 'run_GM_lag' before computing spatial weights array- computing spatial weights " - "array before proceeding..." - ) - self.compute_spatial_weights() - - # Regress on one gene at a time: - all_values, all_pred, all_resid = [], [], [] - for i in tqdm(range(len(self.genes))): - cur_g = self.genes[i] - values, pred, resid = self.single( - cur_g, self.X, self.variable_names, self.param_labels, self.adata, self.w, self.layer - ) - all_values.append(values) - all_pred.append(pred) - all_resid.append(resid) - - # Coefficients and their significance: - coeffs = pd.DataFrame(np.vstack(all_values)) - coeffs.columns = self.adata.var.loc[self.genes, :].columns - - pred = pd.DataFrame(np.hstack(all_pred), index=self.adata.obs_names, columns=self.genes) - resid = pd.DataFrame(np.hstack(all_resid), index=self.adata.obs_names, columns=self.genes) - - # Update AnnData object: - self.adata.obsm["ypred"] = pred - self.adata.obsm["resid"] = resid - - for cn in coeffs.columns: - self.adata.var.loc[:, cn] = coeffs[cn] - - return coeffs, pred, resid - - def single( - self, - cur_g: str, - X: pd.DataFrame, - X_variable_names: List[str], - param_labels: List[str], - adata: AnnData, - w: np.ndarray, - layer: Union[None, str] = None, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Defines model run process for a single feature- not callable by the user, all arguments populated by - arguments passed on instantiation of :class `Base_Model`. - - Args: - cur_g: Name of the feature to regress on - X: Values used for the regression - X_variable_names: Names of the variables used for the regression - param_labels: Names of categories- each computed parameter corresponds to a single element in - param_labels - adata: AnnData object to store results in - w: Spatial weights array - layer: Specifies layer in AnnData to use- if None, will use .X. 
- - Returns: - coeffs: Coefficients for each categorical group for each feature - pred: Predicted values from regression for each feature - resid: Residual values from regression for each feature - """ - if layer is None: - X["log_expr"] = adata[:, cur_g].X.A - else: - X["log_expr"] = adata[:, cur_g].layers[layer].A - - try: - from pysal.model import spreg - - model = spreg.GM_Lag( - X[["log_expr"]].values, - X[X_variable_names].values, - w=w, - name_y="log_expr", - name_x=X_variable_names, - ) - self.logger.info(f"Printing model summary for regression on {cur_g}: \n") - print(model.summary) - y_pred = model.predy - resid = model.u - - # Coefficients for each cell type: - a = pd.DataFrame(model.betas, model.name_x + ["W_log_exp"], columns=["coef"]) - b = pd.DataFrame( - model.z_stat, - model.name_x + ["W_log_exp"], - columns=["z_stat", "p_val"], - ) - - df = a.merge(b, left_index=True, right_index=True) - - for ind, g in enumerate(["const"] + param_labels + ["W_log_exp"]): - adata.var.loc[cur_g, str(g) + "_GM_lag_coeff"] = df.iloc[ind, 0] - adata.var.loc[cur_g, str(g) + "_GM_lag_zstat"] = df.iloc[ind, 1] - adata.var.loc[cur_g, str(g) + "_GM_lag_pval"] = df.iloc[ind, 2] - - # If the lagged model fails to fit for this gene, fall back to NaN outputs: - except Exception: - y_pred = np.full((X.shape[0],), np.nan) - resid = np.full((X.shape[0],), np.nan) - - for ind, g in enumerate(["const"] + param_labels + ["W_log_exp"]): - adata.var.loc[cur_g, str(g) + "_GM_lag_coeff"] = np.nan - adata.var.loc[cur_g, str(g) + "_GM_lag_zstat"] = np.nan - adata.var.loc[cur_g, str(g) + "_GM_lag_pval"] = np.nan - - # Outputs for a single gene: - return adata.var.loc[cur_g, :].values, y_pred.reshape(-1, 1), resid.reshape(-1, 1) - - class Niche_LR_Model(Base_Model): - """Wraps all necessary methods for data loading and preparation, model initialization, parameterization, - evaluation and prediction when instantiating a model for spatially-aware regression using the prevalence of and - connections between categories within spatial neighborhoods and the cell type-specific expression of ligands and - receptors to predict the regression target. - - Arguments passed to :class `Base_Model`. - - Args: - lig: Name(s) of ligands to use as predictors - rec: Name(s) of receptors to use as regression targets. If not given, will search through database for all - genes that correspond to the provided genes from 'ligands' - rec_ds: Name(s) of receptor-downstream genes to use as regression targets. If not given, will search through - database for all genes that correspond to receptor-downstream genes - species: Specifies L:R database to use - niche_lr_r_lag: Only used if 'mod_type' is "niche_lr". Uses the spatial lag of the receptor as the - dependent variable rather than each spot's unique receptor expression. Defaults to True. - args: Additional positional arguments to :class `Base_Model` - kwargs: Additional keyword arguments to :class `Base_Model` - """ - - def __init__( - self, - lig: Union[None, str, List[str]], - rec: Union[None, str, List[str]] = None, - rec_ds: Union[None, str, List[str]] = None, - species: Literal["human", "mouse", "axolotl"] = "human", - niche_lr_r_lag: bool = True, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.logger.info( - "Predictor arrays for :class `Niche_LR_Model` are extremely sparse. It is recommended " - "to provide categories to subset for :func `GLMCV_fit_predict`."
- ) - - self.prepare_data( - mod_type="niche_lr", lig=lig, rec=rec, rec_ds=rec_ds, species=species, niche_lr_r_lag=niche_lr_r_lag - ) - - -def calc_1nd_moment(X, W, normalize_W=True) -> Tuple[np.ndarray, Optional[np.ndarray]]: - if normalize_W: - if type(W) == np.ndarray: - d = np.sum(W, 1).flatten() - else: - d = np.sum(W, 1).A.flatten() - W = diags(1 / d) @ W if issparse(W) else np.diag(1 / d) @ W - return W @ X, W - else: - return W @ X diff --git a/spateo/tools/__init__.py b/spateo/tools/__init__.py index f62712f4..7dc6b9a7 100755 --- a/spateo/tools/__init__.py +++ b/spateo/tools/__init__.py @@ -5,6 +5,7 @@ find_spatially_related_genes, get_genes_from_spatial_archetype, ) +from .CCI_effects_modeling import * from .cci_two_cluster import ( find_cci_two_group, prepare_cci_cellpair_adata, @@ -21,31 +22,10 @@ ) from .cluster_lasso import * from .coarse_align import AffineTrans, align_slices_pca, pca_align, procrustes -from .find_neighbors import ( - construct_geodesic_distance_matrix, - construct_nn_graph, - construct_spatial_distance_matrix, - generate_spatial_weights_fixed_nbrs, - generate_spatial_weights_fixed_radius, - weighted_expr_neighbors_graph, - weighted_spatial_graph, -) +from .find_neighbors import construct_nn_graph from .glm import glm_degs from .labels import Label, create_label_class from .lisa import GM_lag_model, lisa_geo_df, local_moran_i from .live_wire import LiveWireSegmentation, compute_shortest_path, live_wire from .spatial_degs import cellbin_morani, moran_i -from .spatial_smooth import * -from .spatial_smooth.run_smoothing import smooth_and_downsample -from .spatial_smooth.smooth import STGNN from .spatially_variable_gene_ot import cal_wass_dis_bs -from .ST_regression import * -from .ST_regression.generalized_lm import fit_glm -from .ST_regression.regression_utils import plot_prior_vs_data -from .ST_regression.spatial_regression import ( - Category_Model, - Lagged_Model, - Niche_LR_Model, - Niche_Model, -) -from .utils import cellbin_select diff --git a/spateo/tools/spatial_smooth/__init__.py b/spateo/tools/spatial_smooth/__init__.py deleted file mode 100644 index 5634c76e..00000000 --- a/spateo/tools/spatial_smooth/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .run_smoothing import smooth_and_downsample -from .smooth import STGNN diff --git a/spateo/tools/spatial_smooth/run_smoothing.py b/spateo/tools/spatial_smooth/run_smoothing.py deleted file mode 100644 index b20b67a7..00000000 --- a/spateo/tools/spatial_smooth/run_smoothing.py +++ /dev/null @@ -1,156 +0,0 @@ -""" -Wrapper function to run generative modeling for count denoising and imputation. 
-""" -from typing import List, Optional, Tuple, Union - -import anndata -import numpy as np -import scipy - -from ...configuration import SKM -from ...logging import logger_manager as lm -from ...plotting.static.space import space -from ...preprocessing.filter import filter_genes -from ...preprocessing.normalize import normalize_total -from ...preprocessing.transform import log1p -from ...tools.spatial_degs import moran_i -from .smooth import STGNN - - -@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE) -def smooth_and_downsample( - adata: anndata.AnnData, - filter_by_moran: bool = False, - spatial_key: str = "spatial", - n_neighbors: Optional[int] = 20, - positive_ratio_cutoff: float = 0.1, - imputation: bool = True, - n_ds: Optional[int] = None, - to_visualize: Union[None, str, List[str]] = None, - cmap: str = "magma", - device: str = "cpu", - **kwargs, -) -> Tuple[anndata.AnnData, anndata.AnnData]: - """Smooth gene expression distributions and downsample a spatial sample by selecting representative points from - this smoothed slice. - - Args: - adata: AnnData object to model - filter_by_moran: Set True to split - for samples with highly uniform expression patterns, simple spatial - smoothing will be used. For samples with localized patterns, graph neural network will be used for - smoothing. If False, graph neural network will be applied to all genes. - spatial_key: Only used if 'filter_by_moran' is True; key in .obsm where x- and y-coordinates are stored. - n_neighbors: Number of neighbors for each - positive_ratio_cutoff: Filter condition for genes- each gene must be present in higher than this proportion - of the total number of cells to be retained - imputation: Set True to perform imputation. If False, will only downsample. - n_ds: Optional number of cells to downsample to- if not given, will not perform downsampling - kwargs: Additional arguments that can be provided to :func `STGNN.train_STGNN`. Options for kwargs: - - learn_rate: Float, controls magnitude of gradient for network learning - - dropout: Float between 0 and 1, proportion of weights in each layer to set to 0 - - act: String specifying activation function for each encoder layer. Options: "sigmoid", "tanh", "relu", - "elu" - - clip: Float between 0 and 1, threshold below which imputed feature values will be set to 0, - as a percentile. Recommended between 0 and 0.1. - - weight_decay: Float, controls degradation rate of parameters - - epochs: Int, number of iterations of training loop to perform - - dim_output: Int, dimensionality of the output representation - - alpha: Float, controls influence of reconstruction loss in representation learning - - beta: Float, weight factor to control the influence of contrastive loss in representation learning - - theta: Float, weight factor to control the influence of the regularization term in representation learning - - add_regularization: Bool, adds penalty term to representation learning - - Returns: - adata_orig: Input AnnData object - (optional) adata_rex: - (optional) adata: AnnData subsetted down to downsampled buckets. - """ - import spreg - from libpysal import weights - - logger = lm.get_main_logger() - if n_ds is None and not imputation: - logger.error( - "Neither downsampling nor imputation will be done (no integer has been provided to 'n_ds' and " - "'imputation' is currently False)- exiting program." 
- ) - - adata_orig = adata.copy() - normalize_total(adata_orig, 1e4) - log1p(adata_orig) - min_cells = int(adata_orig.n_obs * positive_ratio_cutoff) - filter_genes(adata_orig, min_cells=min_cells) - - if imputation: - if filter_by_moran: - # Keep genes with significant Moran's I q-value (threshold = 0.05): - m_degs = moran_i(adata_orig) - m_uniform = m_degs[m_degs.moran_q_val >= 0.05].index - m_degs = m_degs[m_degs.moran_q_val < 0.05].index - - adata_m_filt_out = adata_orig[:, m_uniform] - adata_m_filt = adata_orig[:, m_degs] - - # For the genes with nonsignificant Moran's index, perform spatial smoothing: - adata_m_filt_out_rex = adata_m_filt_out.copy() - n_neighbors = int(np.ceil(0.01 * adata_m_filt_out.n_obs)) if n_neighbors is None else n_neighbors - w = weights.distance.KNN.from_array(adata_m_filt_out.obsm[spatial_key], k=n_neighbors) - rec_lag = scipy.sparse.csr_matrix(spreg.utils.lag_spatial(w, adata_m_filt_out.X)) - rec_lag.eliminate_zeros() - adata_m_filt_out_rex.X = rec_lag - - # For the genes with significant Moran's index, perform smoothing w/ generative modeling: - model = STGNN(adata_m_filt, spatial_key, random_seed=50, add_regularization=False, device=device) - adata_m_filt_rex = model.train_STGNN(**kwargs) - # Set default layer to 'X_smooth_gcn' (the reconstruction): - adata_m_filt_rex.X = adata_m_filt_rex.layers["X_smooth_gcn"] - - # Final smoothing: - w = weights.distance.KNN.from_array(adata_m_filt_rex.obsm["spatial"], k=n_neighbors) - rec_lag = scipy.sparse.csr_matrix(spreg.utils.lag_spatial(w, adata_m_filt_rex.X)) - rec_lag.eliminate_zeros() - adata_m_filt_rex.X = rec_lag - - adata_rex = anndata.concat([adata_m_filt_rex, adata_m_filt_out_rex], axis=1) - # .uns, .varm and .obsm are ignored by the concat operation- add back to the concatenated object: - adata_rex.uns = adata_m_filt_rex.uns - adata_rex.varm = adata_m_filt_rex.varm - adata_rex.obsm = adata_m_filt_rex.obsm - - if to_visualize is not None: - for feat in to_visualize: - # For plotting, normalize each queried gene in the original data such that the maximum value is 1: - feat_idx = adata_orig.var_names.get_loc(feat) - adata_orig.X[:, feat_idx] /= np.max(adata_orig.X[:, feat_idx]) - - # Generate two plots: one for observed data and one for imputed: - print(f"{feat} Observed") - size = 100 / adata_orig.n_obs - space(adata_orig, color=feat, cmap=cmap, figsize=(2.5, 2.5), dpi=300, pointsize=size, alpha=0.9) - - print(f"{feat} Imputed") - size = 100 / adata_orig.n_obs - space(adata_rex, color=feat, cmap=cmap, figsize=(2.5, 2.5), dpi=300, pointsize=size, alpha=0.9) - - else: - # Smooth all genes using generative modeling: - model = STGNN(adata_orig, spatial_key, random_seed=50, add_regularization=False, device=device) - adata_rex = model.train_STGNN(**kwargs) - # Set default layer to 'X_smooth_gcn' (the reconstruction): - adata_rex.X = adata_rex.layers["X_smooth_gcn"] - - if to_visualize is not None: - for feat in to_visualize: - # Generate two plots: one for observed data and one for imputed: - print(f"{feat} Observed") - size = 100 / adata_orig.n_obs - space(adata_orig, color=feat, cmap=cmap, figsize=(5, 5), dpi=300, pointsize=size, alpha=0.9) - - print(f"{feat} Imputed") - size = 100 / adata_orig.n_obs - space(adata_rex, color=feat, cmap=cmap, figsize=(5, 5), dpi=300, pointsize=size, alpha=0.9) - - # Add downsampling later: - - if imputation: - return adata_rex, adata_orig diff --git a/spateo/tools/spatial_smooth/smooth.py deleted file mode 100644 index a4387bb1..00000000 ---
a/spateo/tools/spatial_smooth/smooth.py +++ /dev/null @@ -1,300 +0,0 @@ -""" -Denoising and imputation of sparse spatial transcriptomics data - - -Note that this functionality requires PyTorch >= 1.8.0 -Also note that this provides an alternative method for finding spatial domains (not yet fully implemented) -""" -import os -import random -from typing import Union - -import numpy as np -import scipy -import torch -import torch.nn.functional as F -from anndata import AnnData -from torch import FloatTensor, Tensor, nn -from torch.backends import cudnn -from tqdm import tqdm - -from ...configuration import SKM -from ...logging import logger_manager as lm -from ..find_neighbors import construct_nn_graph, normalize_adj -from .smooth_model import Encoder - - -# -------------------------------------------- Tensor operations -------------------------------------------- # -def permutation(feature: FloatTensor) -> Tensor: - """Given counts matrix in tensor form, return counts matrix with scrambled rows/spot names""" - ids = np.arange(feature.shape[0]) - ids = np.random.permutation(ids) - feature_permutated = feature[ids] - - return feature_permutated - - -@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE, "adata") -def get_aug_feature(adata: AnnData, highly_variable: bool = False): - """From AnnData object, get counts matrix, augment it and store both as .obsm entries - - Args: - adata: Source AnnData object - highly_variable: Set True to subset to highly-variable genes - """ - if highly_variable: - adata_Vars = adata[:, adata.var["highly_variable"]] - else: - adata_Vars = adata - - if isinstance(adata_Vars.X, scipy.sparse.csc_matrix) or isinstance(adata_Vars.X, scipy.sparse.csr_matrix): - expr = adata_Vars.X.toarray()[ - :, - ] - else: - expr = adata_Vars.X[ - :, - ] - - # Data augmentation: - expr_permuted = permutation(expr) - - adata.obsm["expr"] = expr - adata.obsm["expr_permuted"] = expr_permuted - - -def fix_seed(seed: int = 888): - """Set seeds for all random number generators using 'seed' parameter (defaults to 888)""" - os.environ["PYTHONHASHSEED"] = str(seed) - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - cudnn.deterministic = True - cudnn.benchmark = False - - os.environ["PYTHONHASHSEED"] = str(seed) - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - - -def add_contrastive_label(adata): - """Creates array with 1 and 0 labels for each spot- for contrastive learning""" - n_spot = adata.n_obs - one_matrix = np.ones([n_spot, 1]) - zero_matrix = np.zeros([n_spot, 1]) - label_contrastive = np.concatenate([one_matrix, zero_matrix], axis=1) - adata.obsm["label_contrastive"] = label_contrastive - - -class STGNN: - """ - Graph neural network for representation learning of spatial transcriptomics data from only the gene expression - matrix. Wraps preprocessing and training. - - adata: class `anndata.AnnData` - spatial_key: Key in .obsm where x- and y-coordinates are stored - random_seed: Sets seed for all random number generators - add_regularization: Set True to include weight-based penalty term in representation learning. - device: Options: 'cpu', 'cuda:_'. Perform computations on CPU or GPU. 
If GPU, provide the name of the device to run - computations - """ - - def __init__( - self, - adata: AnnData, - spatial_key: str = "spatial", - random_seed: int = 50, - add_regularization: bool = True, - device: str = "cpu", - ): - self.adata = adata.copy() - self.random_seed = random_seed - self.add_regularization = add_regularization - self.device = torch.device(device) - - fix_seed(self.random_seed) - construct_nn_graph(self.adata, spatial_key=spatial_key) - add_contrastive_label(self.adata) - - self.adata_output = self.adata.copy() - - def train_STGNN(self, **kwargs): - """ - Args: - kwargs: Arguments that can be passed to :class `Trainer`. - - Returns: - adata_output: AnnData object with the smoothed values stored in a layer, either "X_smooth_gcn" or - "X_smooth_gcn_reg". - - """ - # Activation function for GNN: - act = kwargs.get("act", "relu") - # Dictionary to convert string input to 'act' to PyTorch activation function: - act_dict = {"linear": F.linear, "sigmoid": F.sigmoid, "tanh": F.tanh, "relu": F.relu, "elu": F.elu} - kwargs["act"] = act_dict[act] - - if self.add_regularization: - # Compute two versions of embedding and store as separate entries in .obsm: - adata = self.adata_output.copy() - get_aug_feature(adata) - model = Trainer(adata, device=self.device) - # Override matching Trainer attributes with the provided kwargs: - for var in kwargs.keys(): - if var in vars(model).keys(): - setattr(model, var, kwargs[var]) - - emb = model.train() - - # Save reconstruction to layers: - self.adata_output.layers["X_smooth_gcn"] = emb - - # Reset random seed so that the model follows the same initialization procedures - fix_seed(self.random_seed) - adata = self.adata_output.copy() - get_aug_feature(adata) - model = Trainer(adata, add_regularization=True, device=self.device) - emb_regularization = model.train() - - self.adata_output.layers["X_smooth_gcn_reg"] = emb_regularization - - else: - get_aug_feature(self.adata_output) - model = Trainer(self.adata_output, device=self.device) - # Override matching Trainer attributes with the provided kwargs: - for var in kwargs.keys(): - if var in vars(model).keys(): - setattr(model, var, kwargs[var]) - - emb = model.train() - - self.adata_output.layers["X_smooth_gcn"] = emb - - return self.adata_output - - class Trainer: - """ - Graph neural network training module. - - Args: - adata: class `anndata.AnnData` - device: torch.device object - learn_rate: Controls magnitude of gradient for network learning - dropout: Proportion of weights in each layer to set to 0 - act: String specifying activation function for each encoder layer.
Options: "linear", "sigmoid", "tanh", - "relu", "elu" - clip: Threshold below which imputed feature values will be set to 0, as a percentile - weight_decay: Controls degradation rate of parameters - epochs: Number of iterations of training loop to perform - dim_output: Dimensionality of the output representation - gamma_1: Controls influence of reconstruction loss in representation learning - gamma_2: Weight factor to control the influence of contrastive loss in representation learning - gamma_3: Weight factor to control the influence of the regularization term in representation learning - add_regularization: Adds penalty term to representation learning - """ - - def __init__( - self, - adata: AnnData, - device: "torch.device", - learn_rate: float = 0.001, - dropout: float = 0.0, - act=F.relu, - clip: Union[None, float] = 0.25, - weight_decay: float = 0.00, - epochs: int = 1000, - dim_output: int = 64, - gamma_1: float = 10, - gamma_2: float = 1, - gamma_3: float = 0.1, - add_regularization: bool = False, - ): - self.adata = adata.copy() - self.device = device - self.learn_rate = learn_rate - self.dropout = dropout - self.act = act - self.clip = clip - self.weight_decay = weight_decay - self.epochs = epochs - self.gamma_1 = gamma_1 - self.gamma_2 = gamma_2 - self.gamma_3 = gamma_3 - self.add_regularization = add_regularization - - self.expr = torch.FloatTensor(adata.obsm["expr"].copy()).to(self.device) - self.expr_permuted = torch.FloatTensor(adata.obsm["expr_permuted"].copy()).to(self.device) - self.label_contrastive = torch.FloatTensor(adata.obsm["label_contrastive"]).to(self.device) - self.adj = adata.obsm["adj"] - self.graph_neigh = torch.FloatTensor(adata.obsm["graph_neigh"].copy() + np.eye(self.adj.shape[0])).to( - self.device - ) - - self.dim_input = self.expr.shape[1] - self.dim_output = dim_output - - # Further preprocessing on the adjacency matrix: - self.adj = normalize_adj(self.adj) - self.adj = torch.FloatTensor(self.adj).to(self.device) - - def train(self): - """ - Returns - ------- - emb_rec : np.ndarray - Reconstruction of the counts matrix - """ - logger = lm.get_main_logger() - logger.info( - f"Training graph neural network model with learn rate: {self.learn_rate} for {self.epochs} epochs, " - f"dropout rate: {self.dropout} and clipping threshold percentile: {self.clip}." 
- ) - - self.model = Encoder(self.dim_input, self.dim_output, self.graph_neigh, self.dropout, self.act, self.clip).to( - self.device - ) - self.loss_contrastive = nn.BCEWithLogitsLoss() - - self.optimizer = torch.optim.Adam(self.model.parameters(), self.learn_rate, weight_decay=self.weight_decay) - - self.model.train() - - for epoch in tqdm(range(self.epochs)): - self.model.train() - # Construct augmented graph (negative pair w/ the target graph) and then feed augmented graph and - # original graph through the model: - self.expr_a = permutation(self.expr) - self.hidden_feat, self.emb, norm_graph, permuted = self.model(self.expr, self.expr_a, self.adj) - - self.loss_cont_true_graph = self.loss_contrastive(norm_graph, self.label_contrastive) - self.loss_cont_permuted = self.loss_contrastive(permuted, self.label_contrastive) - self.loss_feat = F.mse_loss(self.expr, self.emb) - - if self.add_regularization: - self.loss_norm = 0 - for name, parameters in self.model.named_parameters(): - if name in ["weight1", "weight2"]: - self.loss_norm = self.loss_norm + torch.norm(parameters, p=2) - - loss = ( - self.gamma_1 * self.loss_feat - + self.gamma_2 * (self.loss_cont_true_graph + self.loss_cont_permuted) - + self.gamma_3 * self.loss_norm - ) - else: - loss = self.gamma_1 * self.loss_feat + self.gamma_2 * ( - self.loss_cont_true_graph + self.loss_cont_permuted - ) - - self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() - - with torch.no_grad(): - self.model.eval() - # Return reconstruction: - self.emb_rec = self.model(self.expr, self.expr_a, self.adj)[1].detach().cpu().numpy() - - return self.emb_rec diff --git a/spateo/tools/spatial_smooth/smooth_model.py b/spateo/tools/spatial_smooth/smooth_model.py deleted file mode 100644 index 51149041..00000000 --- a/spateo/tools/spatial_smooth/smooth_model.py +++ /dev/null @@ -1,164 +0,0 @@ -from typing import Union - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import FloatTensor -from torch.nn.modules.module import Module -from torch.nn.parameter import Parameter - - -class Discriminator(nn.Module): - """Module that learns associations between graph embeddings and their positively-labeled augmentations - - Args: - nf: Dimensionality (along the feature axis) of the input array - """ - - def __init__(self, nf: int): - super(Discriminator, self).__init__() - self.f_k = nn.Bilinear(nf, nf, 1) - - for m in self.modules(): - self.weights_init(m) - - def weights_init(self, m): - if isinstance(m, nn.Bilinear): - torch.nn.init.xavier_uniform_(m.weight.data) - if m.bias is not None: - m.bias.data.fill_(0.0) - - def forward(self, g_repr: FloatTensor, g_pos: FloatTensor, g_neg: FloatTensor): - """Feeds data forward through network and computes graph representations - - Args: - g_repr: Representation of source graph, with aggregated neighborhood representations - g_pos : Representation of augmentation of the source graph that can be considered a positive pairing, - with aggregated neighborhood representations - g_neg: Representation of augmentation of the source graph that can be considered a negative pairing, - with aggregated neighborhood representations - - Returns: - logits: Similarity score for the positive and negative paired graphs - """ - c_x = g_repr.expand_as(g_pos) - - sc_1 = self.f_k(g_pos, c_x) - sc_2 = self.f_k(g_neg, c_x) - - logits = torch.cat((sc_1, sc_2), 1) - - return logits - - -class AvgReadout(nn.Module): - """ - Aggregates graph embedding information over graph neighborhoods 
to obtain global representation of the graph - """ - - def __init__(self): - super(AvgReadout, self).__init__() - - def forward(self, emb: FloatTensor, mask: FloatTensor): - """ - Args: - emb : float tensor - Graph embedding - mask : float tensor - Selects elements to aggregate for each row - """ - vsum = torch.mm(mask, emb) - row_sum = torch.sum(mask, 1) - row_sum = row_sum.expand((vsum.shape[1], row_sum.shape[0])).T - global_emb = vsum / row_sum - - return F.normalize(global_emb, p=2, dim=1) - - -class Encoder(Module): - """Representation learning for spatial transcriptomics data - - Args: - in_features: Number of features in the dataset - out_features: Size of the desired encoding - graph_neigh: Pairwise adjacency matrix indicating which spots are neighbors of which other spots - dropout: Proportion of weights in each layer to set to 0 - act: object of class `torch.nn.functional`, default `F.relu`. Activation function for each encoder layer - clip: Threshold below which imputed feature values will be set to 0, as a percentile of the max value - """ - - def __init__( - self, - in_features: int, - out_features: int, - graph_neigh: FloatTensor, - dropout: float = 0.0, - act=F.relu, - clip: Union[None, float] = None, - ): - super(Encoder, self).__init__() - - self.in_features = in_features - self.out_features = out_features - self.graph_neigh = graph_neigh - self.dropout = dropout - self.act = act - self.clip = clip - - self.weight1 = Parameter(torch.FloatTensor(self.in_features, self.out_features)) - self.weight2 = Parameter(torch.FloatTensor(self.out_features, self.in_features)) - self.reset_parameters() - - self.disc = Discriminator(self.out_features) - - self.sigm = nn.Sigmoid() - self.read = AvgReadout() - - def reset_parameters(self): - torch.nn.init.xavier_uniform_(self.weight1) - torch.nn.init.xavier_uniform_(self.weight2) - - def forward(self, feat: FloatTensor, feat_a: FloatTensor, adj: FloatTensor): - """ - Args: - feat: Counts matrix - feat_a: Counts matrix following permutation and augmentation - adj: Pairwise distance matrix - """ - z = F.dropout(feat, self.dropout, self.training) - z = torch.mm(z, self.weight1) - z = torch.mm(adj, z) - - hidden_emb = z - - h = torch.mm(z, self.weight2) - h = torch.mm(adj, h) - - # Clipping constraint: - if self.clip is not None: - thresh = torch.quantile(h, self.clip, dim=0) - mask = h < thresh - h[mask] = 0 - # Non-negativity constraint: - nz_mask = h < 0 - h[nz_mask] = 0 - - emb = self.act(z) - - # Adversarial learning: - z_a = F.dropout(feat_a, self.dropout, self.training) - z_a = torch.mm(z_a, self.weight1) - z_a = torch.mm(adj, z_a) - emb_a = self.act(z_a) - - g = self.read(emb, self.graph_neigh) - g = self.sigm(g) - - g_a = self.read(emb_a, self.graph_neigh) - g_a = self.sigm(g_a) - - ret = self.disc(g, emb, emb_a) - ret_a = self.disc(g_a, emb_a, emb) - - return hidden_emb, h, ret, ret_a
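A closing note on the operation these removals share: the deleted calc_1nd_moment helper and the deleted Moran-filter smoothing fallback in run_smoothing.py (spreg.utils.lag_spatial over a libpysal KNN graph) both reduce to a spatial lag W @ X. Below is a minimal, self-contained sketch of the row-normalized form on hypothetical data; scikit-learn's NearestNeighbors stands in here for the libpysal KNN builder used in the removed code, and the row normalization mirrors calc_1nd_moment (libpysal weights are only averaged like this when row-standardized).

import numpy as np
from scipy.sparse import csr_matrix, diags
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
coords = rng.uniform(0, 100, size=(500, 2))         # hypothetical spot coordinates
X = rng.poisson(1.0, size=(500, 50)).astype(float)  # hypothetical counts, 50 genes

k = 6
# Query k + 1 neighbors because each spot is returned as its own nearest neighbor:
_, idx = NearestNeighbors(n_neighbors=k + 1).fit(coords).kneighbors(coords)
rows = np.repeat(np.arange(coords.shape[0]), k)
cols = idx[:, 1:].ravel()
W = csr_matrix((np.ones(rows.size), (rows, cols)), shape=(coords.shape[0], coords.shape[0]))

# Row-normalize W, then take the first spatial moment (cf. calc_1nd_moment):
d = np.asarray(W.sum(axis=1)).ravel()
X_smooth = (diags(1.0 / d) @ W) @ X  # each row becomes the mean of its k nearest neighbors

Each row of X_smooth is the average expression over that spot's k nearest spatial neighbors, which is the simple smoothing the removed pipeline applied to low-Moran's I genes before handing localized genes to the graph neural network.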