diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index a4b5ab86..dad8cdff 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit a4b5ab864ea6ceb83eea3f22ba002aba892bc069 +Subproject commit dad8cdff03ecee6a4b96a4248528de9f97dcc073 diff --git a/requirements.txt b/requirements.txt index 27d01cd4..71a7204e 100755 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,6 @@ numpy>=1.18.1<=2.10.0 opencv-python>=4.5.4.60 pandana pandas>=0.25.1 -PATSY>=0.5.1 plotly>=5.1.0 POT>=0.8.1 pynndescent>=0.4.8 diff --git a/spateo/configuration.py b/spateo/configuration.py index 85913945..1f38703f 100755 --- a/spateo/configuration.py +++ b/spateo/configuration.py @@ -570,8 +570,8 @@ def set_figure_params( ): """Set resolution/size, styling and format of figures. This function is adapted from: https://github.com/theislab/scanpy/blob/f539870d7484675876281eb1c475595bf4a69bdb/scanpy/_settings.py - Arguments - --------- + + Args: spateo: `bool` (default: `True`) Init default values for :obj:`matplotlib.rcParams` suited for spateo. background: `str` (default: `white`) diff --git a/spateo/plotting/static/baseplot.py b/spateo/plotting/static/baseplot.py new file mode 100644 index 00000000..50189412 --- /dev/null +++ b/spateo/plotting/static/baseplot.py @@ -0,0 +1,182 @@ +from typing import Optional, Union + +import matplotlib.pyplot as plt +import numpy as np +from anndata import AnnData +from matplotlib import rcParams +from matplotlib.colors import to_hex + +from spateo.tools.utils import update_dict + +from ...configuration import SKM +from .utils import _select_font_color, save_fig + + +@SKM.check_adata_is_type(SKM.ADATA_AGG_TYPE) +class BasePlot: + def __init__( + self, + adata: AnnData, + color: Union[str, list] = "ntr", + layer: Union[str, list] = "X", + basis: Union[str, list] = "umap", + slices: Union[str, list] = None, + slices_split: bool = False, + slices_key: str = "slices", + stack_colors=False, + stack_colors_threshold=0.001, + stack_colors_title="stacked colors", + stack_colors_legend_size=2, + stack_colors_cmaps=None, + ncols: int = 4, + aspect: str = "auto", + axis_on: bool = False, + background: Optional[str] = None, + dpi: int = 100, + figsize: tuple = (6, 4), + gridspec: bool = True, + pointsize: Optional[int] = None, + save_show_or_return: str = "show", + save_kwargs: Optional[dict] = None, + show_legend="on data", + theme: Optional[str] = None, + ): + self.adata = adata.copy() + self.stack_colors = stack_colors + self.stack_colors_threshold = stack_colors_threshold + self.stack_colors_title = stack_colors_title + self.stack_colors_legend_size = stack_colors_legend_size + self.show_legend = show_legend + self.aspect = aspect + self.axis_on = axis_on + self.dpi = dpi + self.figsize = figsize + self.save_show_or_return = save_show_or_return + self.save_kwargs = save_kwargs + self.slices_split = slices_split + self.slices_key = slices_key + self.theme = theme + self.basis = self._check_iterable(basis) + self.color = self._check_iterable(color) + self.layer = self._check_iterable(layer) + if slices is None and slices_split: + self.slices = self.adata.obs[self.slices_key].unique().tolist() + self.slices = self._check_iterable(slices) + self.prefix = "baseplot" + + if background is None: + _background = rcParams.get("figure.facecolor") + self._background = to_hex(_background) if type(_background) is tuple else _background + else: + self._background = background + self.font_color = _select_font_color(self._background) + + if stack_colors and stack_colors_cmaps is None: + self.stack_colors_cmaps = [ + "Greys", + "Purples", + "Blues", + "Greens", + "Oranges", + "Reds", + "YlOrBr", + "YlOrRd", + "OrRd", + "PuRd", + "RdPu", + "BuPu", + "GnBu", + "PuBu", + "YlGnBu", + "PuBuGn", + "BuGn", + "YlGn", + ] + self.stack_legend_handles = [] + if stack_colors: + self.color_key = None + + n_s = len(self.slices) if slices_split else 1 + n_c = len(self.color) if not stack_colors else 1 + n_l = len(self.layer) + n_b = len(self.basis) + total_panels, ncols = ( + n_s * n_c * n_l * n_b, + min(max([n_s, n_c, n_l, n_b]), ncols), + ) + nrow, ncol = int(np.ceil(total_panels / ncols)), ncols + + if pointsize is None: + self.pointsize = 16000.0 / np.sqrt(adata.shape[0]) + else: + self.pointsize = 16000.0 / np.sqrt(adata.shape[0]) * pointsize + + if gridspec: + if total_panels > 1: + self.fig = plt.figure( + None, + (figsize[0] * ncol, figsize[1] * nrow), + facecolor=self._background, + dpi=self.dpi, + ) + self.gs = plt.GridSpec(nrow, ncol, wspace=0.12) + else: + self.fig, ax = plt.subplots(figsize=figsize) + self.gs = [ax] + self.ax_index = 0 + + def plot(self): + if self.slices_split: + for cur_s in self.slices: + adata = self.adata[self.adata.obs[self.slices_key] == cur_s, :] + for cur_b in self.basis: + for cur_l in self.layer: + for cur_c in self.color: + self._plot_basis_layer(adata, cur_c, cur_b, cur_l) + if not self.stack_colors: + self.ax_index += 1 + if self.stack_colors: + self.ax_index += 1 + + else: + for cur_b in self.basis: + for cur_l in self.layer: + for cur_c in self.color: + self._plot_basis_layer(self.adata, cur_c, cur_b, cur_l) + if not self.stack_colors: + self.ax_index += 1 + if self.stack_colors: + self.ax_index += 1 + + clf = self._save_show_or_return() + return clf + + def _plot_basis_layer(self, *args, **kwargs): + raise NotImplementedError + + def _save_show_or_return(self): + if self.save_show_or_return in ["save", "both", "all"]: + s_kwargs = { + "path": None, + "prefix": self.prefix, + "dpi": self.dpi, + "ext": "pdf", + "transparent": True, + "close": True, + "verbose": True, + } + s_kwargs = update_dict(s_kwargs, self.save_kwargs) + + save_fig(**s_kwargs) + elif self.save_show_or_return in ["show", "both", "all"]: + if self.show_legend: + plt.subplots_adjust(right=0.85) + plt.show() + elif self.save_show_or_return in ["return", "all"]: + return plt.clf() + + def _check_iterable(self, arg): + if arg is None or isinstance(arg, str): + return [arg] + else: + return list(arg) diff --git a/spateo/plotting/static/heatmap.py b/spateo/plotting/static/heatmap.py new file mode 100644 index 00000000..c83bc0ac --- /dev/null +++ b/spateo/plotting/static/heatmap.py @@ -0,0 +1,169 @@ +from typing import Optional, Union + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from anndata import AnnData + +from ...configuration import SKM +from .baseplot import BasePlot +from .utils import _to_hex + + +@SKM.check_adata_is_type(SKM.ADATA_AGG_TYPE) +class HeatMap(BasePlot): + def __init__( + self, + adata: AnnData, + markers: list, + group: Optional[str] = None, + group_mean: bool = False, + group_cmap: str = "tab20", + col_cluster: bool = False, + row_cluster: bool = False, + layer: Union[str, list] = "X", + slices: Union[str, list] = None, + slices_split: bool = False, + slices_key: str = "slices", + background: Optional[str] = None, + dpi: int = 100, + figsize: tuple = (11, 5), + save_show_or_return: str = "show", + save_kwargs: Optional[dict] = None, + swap_axis: bool = False, + cbar_pos: Optional[tuple] = None, + theme: Optional[str] = None, + cmap: str = "viridis", + **kwargs + ): + super().__init__( + adata=adata, + color=[markers], + basis=group, + layer=layer, + slices=slices, + slices_split=slices_split, + slices_key=slices_key, + background=background, + dpi=dpi, + figsize=figsize, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + theme=theme, + gridspec=False, + ) + self.group_mean = group_mean + self.group_cmap = group_cmap + self.col_cluster = col_cluster + self.row_cluster = row_cluster + self.cmap = cmap + self.cbar_pos = cbar_pos + self.swap_axis = swap_axis + self.kwargs = kwargs + + def _plot_basis_layer(self, adata: AnnData, markers, cells_group, cur_l): + value_df, colors = self._fetch_data(adata, markers, cells_group, cur_l) + + if self.swap_axis: + value_df = value_df.T + + heatmap_kwargs = dict( + xticklabels=1, + yticklabels=False, + col_colors=colors if self.swap_axis else None, + row_colors=None if self.swap_axis else colors, + row_linkage=None, + col_linkage=None, + method="average", + metric="euclidean", + z_score=None, + standard_scale=None, + cbar_pos=self.cbar_pos, + ) + if self.kwargs is not None: + heatmap_kwargs.update(self.kwargs) + + sns_heatmap = sns.clustermap( + value_df, + col_cluster=self.col_cluster, + row_cluster=self.row_cluster, + cmap=self.cmap, + figsize=self.figsize, + **heatmap_kwargs, + ) + + # if not self.show_legend: + # sns_heatmap.cax.set_visible(False) + + def _fetch_data(self, adata: AnnData, markers, cells_group, cur_l): + layer = None if cur_l == "X" else cur_l + value_df = pd.DataFrame() + for i, marker in enumerate(markers): + v = adata.obs_vector(marker, layer=layer) + value_df[marker] = v + value_df.index = adata.obs.index + colors = None + if cells_group is not None: + value_df[cells_group] = adata.obs_vector(cells_group, layer=layer) + value_df = value_df.sort_values(cells_group) + if self.group_mean: + value_df = value_df.groupby(cells_group, as_index=False).mean() + num_labels = len(value_df[cells_group].unique()) + + color_key = _to_hex(plt.get_cmap(self.group_cmap)(np.linspace(0, 1, num_labels))) + cell_lut = dict(zip(value_df[cells_group].unique().tolist(), color_key)) + colors = value_df[cells_group].map(cell_lut) + value_df = value_df.drop(cells_group, axis=1) + + return value_df, colors + + +@SKM.check_adata_is_type(SKM.ADATA_AGG_TYPE) +def heatmap( + adata: AnnData, + markers: list, + group: Optional[str] = None, + group_mean: bool = False, + group_cmap: str = "tab20", + col_cluster: bool = False, + row_cluster: bool = False, + layer: Union[str, list] = "X", + slices: Union[str, list] = None, + slices_split: bool = False, + slices_key: str = "slices", + background: Optional[str] = None, + dpi: int = 100, + figsize: tuple = (11, 5), + save_show_or_return: str = "show", + save_kwargs: Optional[dict] = None, + swap_axis: bool = False, + cbar_pos: Optional[tuple] = None, + theme: Optional[str] = None, + cmap: str = "viridis", + **kwargs +): + hm = HeatMap( + adata=adata, + markers=markers, + group=group, + group_mean=group_mean, + group_cmap=group_cmap, + col_cluster=col_cluster, + row_cluster=row_cluster, + layer=layer, + slices=slices, + slices_split=slices_split, + slices_key=slices_key, + background=background, + dpi=dpi, + figsize=figsize, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + swap_axis=swap_axis, + cbar_pos=cbar_pos, + theme=theme, + cmap=cmap, + **kwargs, + ) + return hm.plot() diff --git a/spateo/plotting/static/space.py b/spateo/plotting/static/space.py index f51d8fa7..100a22f1 100644 --- a/spateo/plotting/static/space.py +++ b/spateo/plotting/static/space.py @@ -40,10 +40,9 @@ def space( *args, **kwargs ): - """\ - Scatter plot for physical coordinates of each cell. - Parameters - ---------- + """Scatter plot for physical coordinates of each cell. + + Args: adata: an Annodata object that contain the physical coordinates for each bin/cell, etc. genes: @@ -83,8 +82,8 @@ def space( ps_sample_num: `int` The number of bins / cells that will be sampled to estimate the distance between different bin / cells. %(scatters.parameters.no_adata|basis|figsize)s - Returns - ------- + + Returns: plots gene or cell feature of the adata object on the physical spatial coordinates. """ # main_info("Plotting spatial info on adata")