diff --git a/docs/_static/docstring_previews/scgen_reg_mean.png b/docs/_static/docstring_previews/scgen_reg_mean.png new file mode 100644 index 00000000..2b76a5ca Binary files /dev/null and b/docs/_static/docstring_previews/scgen_reg_mean.png differ diff --git a/pertpy/plot/_augur.py b/pertpy/plot/_augur.py index 8b818647..1ce0b686 100644 --- a/pertpy/plot/_augur.py +++ b/pertpy/plot/_augur.py @@ -16,7 +16,7 @@ class AugurpyPlot: """Plotting functions for Augurpy.""" @staticmethod - def dp_scatter(results: pd.DataFrame, top_n=None, ax: Axes = None, return_figure: bool = False) -> Figure | Axes: + def dp_scatter(results: pd.DataFrame, top_n=None, ax: Axes = None) -> Figure | Axes: """Plot result of differential prioritization. Args: @@ -56,11 +56,11 @@ def dp_scatter(results: pd.DataFrame, top_n=None, ax: Axes = None, return_figure ag = Augur("random_forest_classifier") - return ag.plot_dp_scatter(results=results, top_n=top_n, ax=ax, return_figure=return_figure) + return ag.plot_dp_scatter(results=results, top_n=top_n, ax=ax) @staticmethod def important_features( - data: dict[str, Any], key: str = "augurpy_results", top_n=10, ax: Axes = None, return_figure: bool = False + data: dict[str, Any], key: str = "augurpy_results", top_n=10, ax: Axes = None ) -> Figure | Axes: """Plot a lollipop plot of the n features with largest feature importances. @@ -95,12 +95,10 @@ def important_features( ag = Augur("random_forest_classifier") - return ag.plot_important_features(data=data, key=key, top_n=top_n, ax=ax, return_figure=return_figure) + return ag.plot_important_features(data=data, key=key, top_n=top_n, ax=ax) @staticmethod - def lollipop( - data: dict[str, Any], key: str = "augurpy_results", ax: Axes = None, return_figure: bool = False - ) -> Figure | Axes: + def lollipop(data: dict[str, Any], key: str = "augurpy_results", ax: Axes = None) -> Figure | Axes | None: """Plot a lollipop plot of the mean augur values. Args: @@ -133,12 +131,10 @@ def lollipop( ag = Augur("random_forest_classifier") - return ag.plot_lollipop(data=data, key=key, ax=ax, return_figure=return_figure) + return ag.plot_lollipop(data=data, key=key, ax=ax) @staticmethod - def scatterplot( - results1: dict[str, Any], results2: dict[str, Any], top_n=None, return_figure: bool = False - ) -> Figure | Axes: + def scatterplot(results1: dict[str, Any], results2: dict[str, Any], top_n=None) -> Figure | Axes: """Create scatterplot with two augur results. Args: @@ -172,4 +168,4 @@ def scatterplot( ag = Augur("random_forest_classifier") - return ag.plot_scatterplot(results1=results1, results2=results2, top_n=top_n, return_figure=return_figure) + return ag.plot_scatterplot(results1=results1, results2=results2, top_n=top_n) diff --git a/pertpy/plot/_coda.py b/pertpy/plot/_coda.py index 9ff2d688..1464f670 100644 --- a/pertpy/plot/_coda.py +++ b/pertpy/plot/_coda.py @@ -429,8 +429,7 @@ def draw_tree( # pragma: no cover show=show, save=save, units=units, - h=h, - w=w, + figsize=(w, h), dpi=dpi, ) @@ -517,8 +516,7 @@ def draw_effects( # pragma: no cover show=show, save=save, units=units, - h=h, - w=w, + figsize=(w, h), dpi=dpi, ) diff --git a/pertpy/tools/_augur.py b/pertpy/tools/_augur.py index efc60b19..ac135fbf 100644 --- a/pertpy/tools/_augur.py +++ b/pertpy/tools/_augur.py @@ -976,8 +976,13 @@ def predict_differential_prioritization( return delta def plot_dp_scatter( - self, results: pd.DataFrame, top_n: int = None, ax: Axes = None, return_figure: bool = False - ) -> Figure | Axes: + self, + results: pd.DataFrame, + top_n: int = None, + ax: Axes = None, + show: bool | None = None, + save: str | bool | None = None, + ) -> Figure | Axes | None: """Plot scatterplot of differential prioritization. Args: @@ -1034,7 +1039,14 @@ def plot_dp_scatter( legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5)) ax.add_artist(legend1) - return fig if return_figure else ax + if show: + plt.show() + return None + if save: + plt.savefig(save, bbox_inches="tight") + return None + elif not show or show is None: + return ax def plot_important_features( self, @@ -1042,8 +1054,9 @@ def plot_important_features( key: str = "augurpy_results", top_n: int = 10, ax: Axes = None, - return_figure: bool = False, - ) -> Figure | Axes: + show: bool | None = None, + save: str | bool | None = None, + ) -> Figure | Axes | None: """Plot a lollipop plot of the n features with largest feature importances. Args: @@ -1095,11 +1108,23 @@ def plot_important_features( plt.ylabel("Gene") plt.yticks(y_axes_range, n_features["genes"]) - return fig if return_figure else ax + if show: + plt.show() + return None + if save: + plt.savefig(save, bbox_inches="tight") + return None + elif not show or show is None: + return ax def plot_lollipop( - self, data: dict[str, Any], key: str = "augurpy_results", ax: Axes = None, return_figure: bool = False - ) -> Figure | Axes: + self, + data: dict[str, Any], + key: str = "augurpy_results", + ax: Axes = None, + show: bool | None = None, + save: str | bool | None = None, + ) -> Figure | Axes | None: """Plot a lollipop plot of the mean augur values. Args: @@ -1147,11 +1172,23 @@ def plot_lollipop( plt.ylabel("Cell Type") plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns) - return fig if return_figure else ax + if show: + plt.show() + return None + if save: + plt.savefig(save, bbox_inches="tight") + return None + elif not show or show is None: + return ax def plot_scatterplot( - self, results1: dict[str, Any], results2: dict[str, Any], top_n: int = None, return_figure: bool = False - ) -> Figure | Axes: + self, + results1: dict[str, Any], + results2: dict[str, Any], + top_n: int = None, + show: bool | None = None, + save: str | bool | None = None, + ) -> Figure | Axes | None: """Create scatterplot with two augur results. Args: @@ -1207,4 +1244,11 @@ def plot_scatterplot( plt.xlabel("Augur scores 1") plt.ylabel("Augur scores 2") - return fig if return_figure else ax + if show: + plt.show() + return None + if save: + plt.savefig(save, bbox_inches="tight") + return None + elif not show or show is None: + return ax diff --git a/pertpy/tools/_coda/_base_coda.py b/pertpy/tools/_coda/_base_coda.py index 4738a7d3..e4a9f301 100644 --- a/pertpy/tools/_coda/_base_coda.py +++ b/pertpy/tools/_coda/_base_coda.py @@ -26,11 +26,14 @@ from scipy.cluster import hierarchy as sp_hierarchy if TYPE_CHECKING: + from collections.abc import Sequence + import numpyro as npy import toytree as tt from ete3 import Tree from jax._src.typing import Array from matplotlib.axes import Axes + from matplotlib.colors import Colormap config.update("jax_enable_x64", True) @@ -1185,11 +1188,11 @@ def plot_stacked_barplot( # pragma: no cover data: AnnData | MuData, feature_name: str, modality_key: str = "coda", - figsize: tuple[float, float] | None = None, - dpi: int | None = 100, palette: ListedColormap | None = cm.tab20, show_legend: bool | None = True, level_order: list[str] = None, + figsize: tuple[float, float] | None = None, + dpi: int | None = 100, ax: plt.Axes | None = None, show: bool | None = None, save: str | bool | None = None, @@ -1288,11 +1291,11 @@ def plot_effects_barplot( # pragma: no cover plot_facets: bool = True, plot_zero_covariate: bool = True, plot_zero_cell_type: bool = False, - figsize: tuple[float, float] | None = None, - dpi: int | None = 100, palette: str | ListedColormap | None = cm.tab20, level_order: list[str] = None, args_barplot: dict | None = None, + figsize: tuple[float, float] | None = None, + dpi: int | None = 100, ax: plt.Axes | None = None, show: bool | None = None, save: str | bool | None = None, @@ -1492,11 +1495,11 @@ def plot_boxplots( # pragma: no cover cell_types: list | None = None, args_boxplot: dict | None = None, args_swarmplot: dict | None = None, - figsize: tuple[float, float] | None = None, - dpi: int | None = 100, palette: str | None = "Blues", show_legend: bool | None = True, level_order: list[str] = None, + figsize: tuple[float, float] | None = None, + dpi: int | None = 100, ax: plt.Axes | None = None, show: bool | None = None, save: str | bool | None = None, @@ -1738,7 +1741,7 @@ def plot_rel_abundance_dispersion_plot( # pragma: no cover ax: plt.Axes | None = None, show: bool | None = None, save: str | bool | None = None, - ) -> plt.Axes: + ) -> plt.Axes | None: """Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type. If the count of the cell type is larger than 0 in more than abundant_threshold percent of all samples, the cell type will be marked in a different color. @@ -1858,12 +1861,11 @@ def plot_draw_tree( # pragma: no cover tight_text: bool | None = False, show_scale: bool | None = False, units: Literal["px", "mm", "in"] | None = "px", - h: float | None = None, - w: float | None = None, - dpi: int | None = 90, + figsize: tuple[float, float] | None = None, + dpi: int | None = 100, show: bool | None = True, save: str | bool | None = None, - ): + ) -> Tree | None: """Plot a tree using input ete3 tree object. Args: @@ -1882,11 +1884,9 @@ def plot_draw_tree( # pragma: no cover file_name: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG. Output image can be saved whether show is True or not. Defaults to None. - units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches. - Defaults to "px". - h: Height of the image in units. Defaults to None. - w: Width of the image in units. Defaults to None. - dpi: Dots per inches. Defaults to 90. + units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches. Defaults to "px". + figsize: Figure size. Defaults to None. + dpi: Dots per inches. Defaults to 100. Returns: Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or plot the tree inline (`show = False`) @@ -1931,10 +1931,11 @@ def my_layout(node): tree_style.show_leaf_name = False tree_style.layout_fn = my_layout tree_style.show_scale = show_scale + if save is not None: - tree.render(save, tree_style=tree_style, units=units, w=w, h=h, dpi=dpi) # type: ignore + tree.render(save, tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore if show: - return tree.render("%%inline", tree_style=tree_style, units=units, w=w, h=h, dpi=dpi) # type: ignore + return tree.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore else: return tree, tree_style @@ -1948,12 +1949,11 @@ def plot_draw_effects( # pragma: no cover show_leaf_effects: bool | None = False, tight_text: bool | None = False, show_scale: bool | None = False, + units: Literal["px", "mm", "in"] | None = "px", + figsize: tuple[float, float] | None = None, + dpi: int | None = 100, show: bool | None = True, save: str | None = None, - units: Literal["px", "mm", "in"] | None = "in", - h: float | None = None, - w: float | None = None, - dpi: int | None = 90, ): """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects. @@ -1975,10 +1975,9 @@ def plot_draw_effects( # pragma: no cover show: If True, plot the tree inline. If false, return tree and tree_style objects. Defaults to True. file_name: Path to the output image file. valid extensions are .SVG, .PDF, .PNG. Output image can be saved whether show is True or not. Defaults to None. - units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches. Default is "in". Defaults to "in". - h: Height of the image in units. Defaults to None. - w: Width of the image in units. Defaults to None. - dpi: Dots per inches. Defaults to 90. + units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches. Defaults to "px". + figsize: Figure size. Defaults to None. + dpi: Dots per inches. Defaults to 100. Returns: Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) @@ -2130,21 +2129,23 @@ def my_layout(node): tree2.render(save, tree_style=tree_style, units=units) if show: if not show_leaf_effects: - return tree2.render("%%inline", tree_style=tree_style, units=units, w=w, h=h, dpi=dpi) + return tree2.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) else: if not show_leaf_effects: return tree2, tree_style def plot_effects_umap( # pragma: no cover self, - data: MuData, + mdata: MuData, effect_name: str | list | None, cluster_key: str, modality_key_1: str = "rna", modality_key_2: str = "coda", + color_map: Colormap | str | None = None, + palette: str | Sequence[str] | None = None, + ax: Axes = None, show: bool = None, save: str | bool | None = None, - ax: Axes = None, **kwargs, ) -> plt.Axes | None: """Plot a UMAP visualization colored by effect strength. @@ -2153,7 +2154,7 @@ def plot_effects_umap( # pragma: no cover (default is data['rna']) depending on the cluster they were assigned to. Args: - data: AnnData object or MuData object. + mudata: MuData object. effect_name: The name of the effect results in .varm of aggregated sample-level AnnData to plot cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']). To assign cell types' effects to original cells. @@ -2207,8 +2208,8 @@ def plot_effects_umap( # pragma: no cover .. image:: /_static/docstring_previews/tasccoda_effects_umap.png """ # TODO: Add effect_name parameter and cluster_key and test the example - data_rna = data[modality_key_1] - data_coda = data[modality_key_2] + data_rna = mdata[modality_key_1] + data_coda = mdata[modality_key_2] if isinstance(effect_name, str): effect_name = [effect_name] for _, effect in enumerate(effect_name): @@ -2224,7 +2225,18 @@ def plot_effects_umap( # pragma: no cover else: vmax = max(data_rna.obs[effect].max() for _, effect in enumerate(effect_name)) - return sc.pl.umap(data_rna, color=effect_name, vmax=vmax, vmin=vmin, ax=ax, show=show, save=save, **kwargs) + return sc.pl.umap( + data_rna, + color=effect_name, + vmax=vmax, + vmin=vmin, + palette=palette, + color_map=color_map, + ax=ax, + show=show, + save=save, + **kwargs, + ) def get_a( diff --git a/pertpy/tools/_dialogue.py b/pertpy/tools/_dialogue.py index 02750629..d51e4a27 100644 --- a/pertpy/tools/_dialogue.py +++ b/pertpy/tools/_dialogue.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Literal import anndata as ad +import matplotlib.pyplot as plt import numpy as np import pandas as pd import scanpy as sc @@ -1069,7 +1070,10 @@ def plot_split_violins( celltype_key: str, split_which: tuple[str, str] = None, mcp: str = "mcp_0", - ) -> Axes: + ax: Axes | None = None, + save: bool | str | None = None, + show: bool | None = None, + ) -> Axes | None: """Plots split violin plots for a given MCP and split variable. Any cells with a value for split_key not in split_which are removed from the plot. @@ -1107,10 +1111,24 @@ def plot_split_violins( ax.set_xticklabels(ax.get_xticklabels(), rotation=90) - return ax + if show: + plt.show() + return None + if save: + plt.savefig(save, bbox_inches="tight") + return None + elif not show or show is None: + return ax def plot_pairplot( - self, adata: AnnData, celltype_key: str, color: str, sample_id: str, mcp: str = "mcp_0" + self, + adata: AnnData, + celltype_key: str, + color: str, + sample_id: str, + mcp: str = "mcp_0", + show: bool | None = None, + save: bool | str | None = None, ) -> PairGrid: """Generate a pairplot visualization for multi-cell perturbation (MCP) data. @@ -1150,4 +1168,11 @@ def plot_pairplot( mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1) ax = sns.pairplot(mcp_pivot, hue=color, corner=True) - return ax + if show: + plt.show() + return None + if save: + plt.savefig(save, bbox_inches="tight") + return None + elif not show or show is None: + return ax diff --git a/pertpy/tools/_enrichment.py b/pertpy/tools/_enrichment.py index f713af4d..724d2ab7 100644 --- a/pertpy/tools/_enrichment.py +++ b/pertpy/tools/_enrichment.py @@ -7,6 +7,7 @@ import pandas as pd import scanpy as sc from anndata import AnnData +from matplotlib.axes import Axes from scanpy.plotting import DotPlot from scanpy.tools._score_genes import _sparse_nanmean from scipy.sparse import issparse @@ -292,6 +293,9 @@ def plot_dotplot( categories: Sequence[str] = None, groupby: str = None, key: str = "pertpy_enrichment", + ax: Axes | None = None, + save: bool | str | None = None, + show: bool | None = None, **kwargs, ) -> DotPlot | dict | None: """Plots a dotplot by groupby and categories. @@ -369,7 +373,9 @@ def plot_dotplot( "var_group_labels": var_group_labels, } - return sc.pl.dotplot(enrichment_score_adata, groupby=groupby, swap_axes=True, **plot_args, **kwargs) + return sc.pl.dotplot( + enrichment_score_adata, groupby=groupby, swap_axes=True, ax=ax, save=save, show=show, **plot_args, **kwargs + ) def plot_gsea( self, @@ -413,5 +419,5 @@ def plot_gsea( n=n, interactive_plot=interactive_plot, ) - fig.suptitle(cluster) + fig.subtitle(cluster) fig.show() diff --git a/pertpy/tools/_milo.py b/pertpy/tools/_milo.py index 6f2eaf4a..0e55d3a6 100644 --- a/pertpy/tools/_milo.py +++ b/pertpy/tools/_milo.py @@ -17,6 +17,9 @@ if TYPE_CHECKING: from collections.abc import Sequence + from matplotlib.axes import Axes + from matplotlib.colors import Colormap + try: from rpy2.robjects import conversion, numpy2ri, pandas2ri from rpy2.robjects.packages import STAP, PackageNotInstalledError, importr @@ -703,6 +706,9 @@ def plot_nhood_graph( min_size: int = 10, plot_edges: bool = False, title: str = "DA log-Fold Change", + color_map: Colormap | str | None = None, + palette: str | Sequence[str] | None = None, + ax: Axes | None = None, show: bool | None = None, save: bool | str | None = None, **kwargs, @@ -774,6 +780,9 @@ def plot_nhood_graph( vmax=vmax, vmin=vmin, title=title, + color_map=color_map, + palette=palette, + ax=ax, show=show, save=save, **kwargs, @@ -785,6 +794,9 @@ def plot_nhood( ix: int, feature_key: str | None = "rna", basis: str = "X_umap", + color_map: Colormap | str | None = None, + palette: str | Sequence[str] | None = None, + ax: Axes | None = None, show: bool | None = None, save: bool | str | None = None, **kwargs, @@ -815,7 +827,17 @@ def plot_nhood( """ mdata[feature_key].obs["Nhood"] = mdata[feature_key].obsm["nhoods"][:, ix].toarray().ravel() sc.pl.embedding( - mdata[feature_key], basis, color="Nhood", size=30, title="Nhood" + str(ix), show=show, save=save, **kwargs + mdata[feature_key], + basis, + color="Nhood", + size=30, + title="Nhood" + str(ix), + color_map=color_map, + palette=palette, + ax=ax, + show=show, + save=save, + **kwargs, ) def plot_da_beeswarm( @@ -826,6 +848,8 @@ def plot_da_beeswarm( alpha: float = 0.1, subset_nhoods: list[str] = None, palette: str | Sequence[str] | dict[str, str] | None = None, + save: bool | str | None = None, + show: bool | None = None, ) -> None: """Plot beeswarm plot of logFC against nhood labels @@ -931,12 +955,23 @@ def plot_da_beeswarm( plt.legend(loc="upper left", title=f"< {int(alpha * 100)}% SpatialFDR", bbox_to_anchor=(1, 1), frameon=False) plt.axvline(x=0, ymin=0, ymax=1, color="black", linestyle="--") + if save: + plt.savefig(save, bbox_inches="tight") + return None + if show: + plt.show() + return None + elif not show or show is None: + return plt.gcf() + def plot_nhood_counts_by_cond( self, mdata: MuData, test_var: str, subset_nhoods: list[str] = None, log_counts: bool = False, + save: bool | str | None = None, + show: bool | None = None, ) -> None: """Plot boxplot of cell numbers vs condition of interest. @@ -972,3 +1007,12 @@ def plot_nhood_counts_by_cond( plt.xticks(rotation=90) plt.xlabel(test_var) + + if save: + plt.savefig(save, bbox_inches="tight") + return None + if show: + plt.show() + return None + elif not show or show is None: + return plt.gcf() diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index f87e00d5..d6393c51 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -4,11 +4,11 @@ from collections import OrderedDict from typing import TYPE_CHECKING, Literal +import matplotlib.pyplot as plt import numpy as np import pandas as pd import scanpy as sc import seaborn as sns -from matplotlib import pyplot as pl from scanpy import get from scanpy._settings import settings from scanpy._utils import _check_use_raw, sanitize_anndata @@ -24,6 +24,7 @@ from anndata import AnnData from matplotlib.axes import Axes + from matplotlib.colors import Colormap from scipy import sparse @@ -509,6 +510,7 @@ def plot_barplot( # pragma: no cover axis_title_size: int = 8, legend_title_size: int = 8, legend_text_size: int = 8, + ax: Axes | None = None, show: bool | None = None, save: bool | str | None = None, ): @@ -558,7 +560,7 @@ def plot_barplot( # pragma: no cover if not show: color_mapping = {"KO": "salmon", "NP": "lightgray", "NT": "grey"} unique_genes = NP_KO_cells["gene"].unique() - fig, axs = pl.subplots(int(len(unique_genes) / 5), 5, figsize=(25, 25), sharey=True) + fig, axs = plt.subplots(int(len(unique_genes) / 5), 5, figsize=(25, 25), sharey=True) for i, gene in enumerate(unique_genes): ax = axs[int(i / 5), i % 5] grouped_df = ( @@ -592,10 +594,13 @@ def plot_barplot( # pragma: no cover fontsize=legend_text_size, title_fontsize=legend_title_size, ) - pl.tight_layout() + plt.tight_layout() _utils.savefig_or_show("mixscape_barplot", show=show, save=save) - return ax + if not show or show is None: + return ax + else: + return None def plot_heatmap( # pragma: no cover self, @@ -608,6 +613,8 @@ def plot_heatmap( # pragma: no cover subsample_number: int | None = 900, vmin: float | None = -2, vmax: float | None = 2, + fontsize: int | None = 8, + ax: Axes | None = None, show: bool | None = None, save: bool | str | None = None, **kwds, @@ -660,6 +667,8 @@ def plot_heatmap( # pragma: no cover vmax=vmax, n_genes=20, groups=["NT"], + fontsize=fontsize, + ax=ax, show=show, save=save, **kwds, @@ -670,12 +679,15 @@ def plot_perturbscore( # pragma: no cover adata: AnnData, labels: str, target_gene: str, - mixscape_class="mixscape_class", - color="orange", + mixscape_class: str = "mixscape_class", + color: str = "orange", palette: dict[str, str] = None, split_by: str = None, - before_mixscape=False, + before_mixscape: bool = False, perturbation_type: str = "KO", + ax: Axes | None = None, + show: bool | None = None, + save: bool | str | None = None, ) -> None: """Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function. @@ -727,7 +739,7 @@ def plot_perturbscore( # pragma: no cover palette = {gd: "#7d7d7d", target_gene: color} plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=labels, fill=False, common_norm=False) top_r = max(plot_dens.get_lines()[cond].get_data()[1].max() for cond in range(len(plot_dens.get_lines()))) - pl.close() + plt.close() perturbation_score["y_jitter"] = perturbation_score["pvec"] rng = np.random.default_rng() perturbation_score.loc[perturbation_score[labels] == gd, "y_jitter"] = rng.uniform( @@ -738,7 +750,7 @@ def plot_perturbscore( # pragma: no cover ) # If split_by is provided, split densities based on the split_by if split_by is not None: - sns.set(style="whitegrid") + sns.set_theme(style="whitegrid") g = sns.FacetGrid( data=perturbation_score, col=split_by, hue=split_by, palette=palette, height=5, sharey=False ) @@ -750,26 +762,37 @@ def plot_perturbscore( # pragma: no cover # If split_by is not provided, create a single plot else: - sns.set(style="whitegrid") + sns.set_theme(style="whitegrid") sns.kdeplot( data=perturbation_score, x="pvec", hue="gene_target", fill=True, common_norm=False, palette=palette ) sns.scatterplot( data=perturbation_score, x="pvec", y="y_jitter", hue="gene_target", palette=palette, s=10, alpha=0.5 ) - pl.xlabel("Perturbation score", fontsize=16) - pl.ylabel("Cell density", fontsize=16) - pl.title("Density Plot", fontsize=18) - pl.legend(title="gene_target", title_fontsize=14, fontsize=12) + plt.xlabel("Perturbation score", fontsize=16) + plt.ylabel("Cell density", fontsize=16) + plt.title("Density Plot", fontsize=18) + plt.legend(title="gene_target", title_fontsize=14, fontsize=12) sns.despine() + ax = plt.gca() + + if save: + plt.savefig(save, bbox_inches="tight") + return None + if show: + plt.show() + return None + elif not show or show is None: + return ax + # If before_mixscape is False, split densities based on mixscape classifications else: if palette is None: palette = {gd: "#7d7d7d", f"{target_gene} NP": "#c9c9c9", f"{target_gene} {perturbation_type}": color} plot_dens = sns.kdeplot(data=perturbation_score, x="pvec", hue=labels, fill=False, common_norm=False) top_r = max(plot_dens.get_lines()[i].get_data()[1].max() for i in range(len(plot_dens.get_lines()))) - pl.close() + plt.close() perturbation_score["y_jitter"] = perturbation_score["pvec"] rng = np.random.default_rng() gd2 = list( @@ -788,7 +811,7 @@ def plot_perturbscore( # pragma: no cover ) # If split_by is provided, split densities based on the split_by if split_by is not None: - sns.set(style="whitegrid") + sns.set_theme(style="whitegrid") g = sns.FacetGrid( data=perturbation_score, col=split_by, hue="mix", palette=palette, height=5, sharey=False ) @@ -800,7 +823,7 @@ def plot_perturbscore( # pragma: no cover # If split_by is not provided, create a single plot else: - sns.set(style="whitegrid") + sns.set_theme(style="whitegrid") sns.kdeplot( data=perturbation_score, x="pvec", @@ -813,12 +836,23 @@ def plot_perturbscore( # pragma: no cover sns.scatterplot( data=perturbation_score, x="pvec", y="y_jitter", hue="mix", palette=palette, s=10, alpha=0.5 ) - pl.xlabel("Perturbation score", fontsize=16) - pl.ylabel("Cell density", fontsize=16) - pl.title("Density", fontsize=18) - pl.legend(title="mixscape class", title_fontsize=14, fontsize=12) + plt.xlabel("Perturbation score", fontsize=16) + plt.ylabel("Cell density", fontsize=16) + plt.title("Density", fontsize=18) + plt.legend(title="mixscape class", title_fontsize=14, fontsize=12) sns.despine() + ax = plt.gca() + + if save: + plt.savefig(save, bbox_inches="tight") + return None + if show: + plt.show() + return None + elif not show or show is None: + return ax + def plot_violin( # pragma: no cover self, adata: AnnData, @@ -838,9 +872,9 @@ def plot_violin( # pragma: no cover xlabel: str = "", ylabel: str | Sequence[str] | None = None, rotation: float | None = None, + ax: Axes | None = None, show: bool | None = None, save: bool | str | None = None, - ax: Axes | None = None, **kwargs, ): """Violin plot using mixscape results. @@ -1011,7 +1045,7 @@ def plot_violin( # pragma: no cover show = settings.autoshow if show is None else show if hue is not None and stripplot is True: - pl.legend(handles, labels) + plt.legend(handles, labels) _utils.savefig_or_show("mixscape_violin", show=show, save=save) if not show: @@ -1031,6 +1065,9 @@ def plot_lda( # pragma: no cover perturbation_type: str | None = "KO", lda_key: str | None = "mixscape_lda", n_components: int | None = None, + color_map: Colormap | str | None = None, + palette: str | Sequence[str] | None = None, + ax: Axes | None = None, show: bool | None = None, save: bool | str | None = None, **kwds, @@ -1076,4 +1113,13 @@ def plot_lda( # pragma: no cover n_components = adata_subset.uns[lda_key].shape[1] sc.pp.neighbors(adata_subset, use_rep=lda_key) sc.tl.umap(adata_subset, n_components=n_components) - sc.pl.umap(adata_subset, color=mixscape_class, show=show, save=save, **kwds) + sc.pl.umap( + adata_subset, + color=mixscape_class, + palette=palette, + color_map=color_map, + show=show, + save=save, + ax=ax, + **kwds, + ) diff --git a/pertpy/tools/_scgen/_scgen.py b/pertpy/tools/_scgen/_scgen.py index befe202a..3fc42c3a 100644 --- a/pertpy/tools/_scgen/_scgen.py +++ b/pertpy/tools/_scgen/_scgen.py @@ -3,13 +3,13 @@ from typing import TYPE_CHECKING, Any import jax.numpy as jnp +import matplotlib.pyplot as plt import numpy as np import pandas as pd import scanpy as sc from adjustText import adjust_text from anndata import AnnData from jax import Array -from matplotlib import pyplot from scipy import stats from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager @@ -373,20 +373,19 @@ def get_latent_representation( def plot_reg_mean_plot( self, adata, - condition_key, - axis_keys, - labels, - path_to_save="./reg_mean.pdf", - save=True, - gene_list=None, - show=False, - top_100_genes=None, - verbose=False, - legend=True, - title=None, - x_coeff=0.30, - y_coeff=0.8, - fontsize=14, + condition_key: str, + axis_keys: dict[str, str], + labels: dict[str, str], + save: str | bool | None = None, + gene_list: list[str] = None, + show: bool = False, + top_100_genes: list[str] = None, + verbose: bool = False, + legend: bool = True, + title: str = None, + x_coeff: float = 0.30, + y_coeff: float = 0.8, + fontsize: float = 14, **kwargs, ) -> tuple[float, float] | float: """Plots mean matching for a set of specified genes. @@ -423,11 +422,14 @@ def plot_reg_mean_plot( >>> eval_adata = data[data.obs['cell_type'] == 'CD4 T cells'].copy().concatenate(pred) >>> r2_value = scg.plot_reg_mean_plot(eval_adata, condition_key='label', axis_keys={"x": "pred", "y": "stim"}, \ labels={"x": "predicted", "y": "ground truth"}, save=False, show=True) + + Preview: + .. image:: /_static/docstring_previews/scgen_reg_mean.png """ import seaborn as sns - sns.set() - sns.set(color_codes=True) + sns.set_theme() + sns.set_theme(color_codes=True) diff_genes = top_100_genes stim = adata[adata.obs[condition_key] == axis_keys["y"]] @@ -463,11 +465,11 @@ def plot_reg_mean_plot( j = adata.var_names.tolist().index(i) x_bar = x[j] y_bar = y[j] - texts.append(pyplot.text(x_bar, y_bar, i, fontsize=11, color="black")) - pyplot.plot(x_bar, y_bar, "o", color="red", markersize=5) + texts.append(plt.text(x_bar, y_bar, i, fontsize=11, color="black")) + plt.plot(x_bar, y_bar, "o", color="red", markersize=5) # if "y1" in axis_keys.keys(): # y1_bar = y1[j] - # pyplot.text(x_bar, y1_bar, i, fontsize=11, color="black") + # plt.text(x_bar, y1_bar, i, fontsize=11, color="black") if gene_list is not None: adjust_text( texts, @@ -477,11 +479,11 @@ def plot_reg_mean_plot( force_static=(0.0, 0.0), ) if legend: - pyplot.legend(loc="center left", bbox_to_anchor=(1, 0.5)) + plt.legend(loc="center left", bbox_to_anchor=(1, 0.5)) if title is None: - pyplot.title("", fontsize=fontsize) + plt.title("", fontsize=fontsize) else: - pyplot.title(title, fontsize=fontsize) + plt.title(title, fontsize=fontsize) ax.text( max(x) - max(x) * x_coeff, max(y) - y_coeff * max(y), @@ -496,10 +498,10 @@ def plot_reg_mean_plot( fontsize=kwargs.get("textsize", fontsize), ) if save: - pyplot.savefig(f"{path_to_save}", bbox_inches="tight", dpi=100) + plt.savefig(save, bbox_inches="tight") if show: - pyplot.show() - pyplot.close() + plt.show() + plt.close() if diff_genes is not None: return r_value**2, r_value_diff**2 else: @@ -508,20 +510,19 @@ def plot_reg_mean_plot( def plot_reg_var_plot( self, adata, - condition_key, - axis_keys, - labels, - path_to_save="./reg_var.pdf", - save=True, - gene_list=None, - top_100_genes=None, - show=False, - legend=True, - title=None, - verbose=False, - x_coeff=0.30, - y_coeff=0.8, - fontsize=14, + condition_key: str, + axis_keys: dict[str, str], + labels: dict[str, str], + save: str | bool | None = None, + gene_list: list[str] = None, + top_100_genes: list[str] = None, + show: bool = False, + legend: bool = True, + title: str = None, + verbose: bool = False, + x_coeff: float = 0.3, + y_coeff: float = 0.8, + fontsize: float = 14, **kwargs, ) -> tuple[float, float] | float: """Plots variance matching for a set of specified genes. @@ -548,8 +549,8 @@ def plot_reg_var_plot( """ import seaborn as sns - sns.set() - sns.set(color_codes=True) + sns.set_theme() + sns.set_theme(color_codes=True) sc.tl.rank_genes_groups(adata, groupby=condition_key, n_genes=100, method="wilcoxon") diff_genes = top_100_genes @@ -580,13 +581,13 @@ def plot_reg_var_plot( start, stop, step = kwargs.get("range") ax.set_xticks(np.arange(start, stop, step)) ax.set_yticks(np.arange(start, stop, step)) - # _p1 = pyplot.scatter(x, y, marker=".", label=f"{axis_keys['x']}-{axis_keys['y']}") - # pyplot.plot(x, m * x + b, "-", color="green") + # _p1 = plt.scatter(x, y, marker=".", label=f"{axis_keys['x']}-{axis_keys['y']}") + # plt.plot(x, m * x + b, "-", color="green") ax.set_xlabel(labels["x"], fontsize=fontsize) ax.set_ylabel(labels["y"], fontsize=fontsize) if "y1" in axis_keys.keys(): y1 = np.asarray(np.var(real_stim.X, axis=0)).ravel() - _ = pyplot.scatter( + _ = plt.scatter( x, y1, marker="*", @@ -599,17 +600,17 @@ def plot_reg_var_plot( j = adata.var_names.tolist().index(i) x_bar = x[j] y_bar = y[j] - pyplot.text(x_bar, y_bar, i, fontsize=11, color="black") - pyplot.plot(x_bar, y_bar, "o", color="red", markersize=5) + plt.text(x_bar, y_bar, i, fontsize=11, color="black") + plt.plot(x_bar, y_bar, "o", color="red", markersize=5) if "y1" in axis_keys.keys(): y1_bar = y1[j] - pyplot.text(x_bar, y1_bar, "*", color="black", alpha=0.5) + plt.text(x_bar, y1_bar, "*", color="black", alpha=0.5) if legend: - pyplot.legend(loc="center left", bbox_to_anchor=(1, 0.5)) + plt.legend(loc="center left", bbox_to_anchor=(1, 0.5)) if title is None: - pyplot.title("", fontsize=12) + plt.title("", fontsize=12) else: - pyplot.title(title, fontsize=12) + plt.title(title, fontsize=12) ax.text( max(x) - max(x) * x_coeff, max(y) - y_coeff * max(y), @@ -625,25 +626,25 @@ def plot_reg_var_plot( ) if save: - pyplot.savefig(f"{path_to_save}", bbox_inches="tight", dpi=100) + plt.savefig(save, bbox_inches="tight") if show: - pyplot.show() - pyplot.close() + plt.show() + plt.close() if diff_genes is not None: return r_value**2, r_value_diff**2 else: return r_value**2 - def plot_plot_binary_classifier( + def plot_binary_classifier( self, - scgen, - adata, - delta, - ctrl_key, - stim_key, - path_to_save, - save=True, - fontsize=14, + scgen: SCGEN, + adata: AnnData | None, + delta: np.ndarray, + ctrl_key: str, + stim_key: str, + show: bool = False, + save: str | bool | None = None, + fontsize: float = 14, ) -> None: """Plots the dot product between delta and latent representation of a linear classifier. @@ -663,7 +664,7 @@ def plot_plot_binary_classifier( save: Specify if the plot should be saved or not. fontsize: Set the font size of the plot. """ - pyplot.close("all") + plt.close("all") adata = scgen._validate_anndata(adata) condition_key = scgen.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key cd = adata[adata.obs[condition_key] == ctrl_key, :] @@ -676,21 +677,26 @@ def plot_plot_binary_classifier( dot_cd[ind] = np.dot(delta, vec) for ind, vec in enumerate(all_latent_stim): dot_sal[ind] = np.dot(delta, vec) - pyplot.hist( + plt.hist( dot_cd, label=ctrl_key, bins=50, ) - pyplot.hist(dot_sal, label=stim_key, bins=50) - pyplot.axvline(0, color="k", linestyle="dashed", linewidth=1) - pyplot.title(" ", fontsize=fontsize) - pyplot.xlabel(" ", fontsize=fontsize) - pyplot.ylabel(" ", fontsize=fontsize) - pyplot.xticks(fontsize=fontsize) - pyplot.yticks(fontsize=fontsize) - ax = pyplot.gca() + plt.hist(dot_sal, label=stim_key, bins=50) + plt.axvline(0, color="k", linestyle="dashed", linewidth=1) + plt.title(" ", fontsize=fontsize) + plt.xlabel(" ", fontsize=fontsize) + plt.ylabel(" ", fontsize=fontsize) + plt.xticks(fontsize=fontsize) + plt.yticks(fontsize=fontsize) + ax = plt.gca() ax.grid(False) if save: - pyplot.savefig(f"{path_to_save}", bbox_inches="tight", dpi=100) - pyplot.show() + plt.savefig(save, bbox_inches="tight") + return None + if show: + plt.show() + return None + elif not show or show is None: + return ax