From b2fd0992412d94fc8c481d4feceafb49b9337f86 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Apr 2024 10:05:18 +0200 Subject: [PATCH 1/2] [pre-commit.ci] pre-commit autoupdate (#581) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/pre-commit/pre-commit-hooks: v4.5.0 → v4.6.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.5.0...v4.6.0) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index eab19000..c6be3c25 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,7 +17,7 @@ repos: args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes] - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: detect-private-key - id: check-ast From 5b000eafa20e3826d7cf1680bb4ac74ee22774e5 Mon Sep 17 00:00:00 2001 From: Altana Namsaraeva <99650244+namsaraeva@users.noreply.github.com> Date: Sat, 13 Apr 2024 00:03:26 +0200 Subject: [PATCH 2/2] Check and fix plotting functions (#579) * fix coda * change augur show/save semantics * fix tasccoda figsize and edit semantics * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * dialogue * fix pre-commit * fix pre-commit * fix mixscape * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add return types and simplify code * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * scgen * augur * dialogue * milo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * mixscape * augur return statements * the rest * precomm * simplify if * augur if * coda return fig * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add augur return fig * dialogue return fig * milo return fig * mixscape * returnn types coda * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * return types augur * gialogue return types * milo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * mixscape * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix pre comm * milo --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pertpy/plot/_coda.py | 10 ++--- pertpy/tools/_augur.py | 57 ++++++++++++++----------- pertpy/tools/_coda/_base_coda.py | 72 +++++++++++++++++++------------- pertpy/tools/_dialogue.py | 29 +++++++------ pertpy/tools/_milo.py | 16 ++++--- pertpy/tools/_mixscape.py | 36 ++++++++++++---- pertpy/tools/_scgen/_scgen.py | 7 ++-- 7 files changed, 138 insertions(+), 89 deletions(-) diff --git a/pertpy/plot/_coda.py b/pertpy/plot/_coda.py index 1464f670..f2e50566 100644 --- a/pertpy/plot/_coda.py +++ b/pertpy/plot/_coda.py @@ -361,8 +361,7 @@ def draw_tree( # pragma: no cover show: bool | None = True, save: str | None = None, units: Literal["px", "mm", "in"] | None = "px", - h: float | None = None, - w: float | None = None, + figsize: tuple[float, float] | None = (None, None), dpi: int | None = 90, ): """Plot a tree using input ete3 tree object. @@ -429,7 +428,7 @@ def draw_tree( # pragma: no cover show=show, save=save, units=units, - figsize=(w, h), + figsize=figsize, dpi=dpi, ) @@ -446,8 +445,7 @@ def draw_effects( # pragma: no cover show: bool | None = True, save: str | None = None, units: Literal["px", "mm", "in"] | None = "in", - h: float | None = None, - w: float | None = None, + figsize: tuple[float, float] | None = (None, None), dpi: int | None = 90, ): """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects. @@ -516,7 +514,7 @@ def draw_effects( # pragma: no cover show=show, save=save, units=units, - figsize=(w, h), + figsize=figsize, dpi=dpi, ) diff --git a/pertpy/tools/_augur.py b/pertpy/tools/_augur.py index ac135fbf..78eccd0c 100644 --- a/pertpy/tools/_augur.py +++ b/pertpy/tools/_augur.py @@ -979,17 +979,17 @@ def plot_dp_scatter( self, results: pd.DataFrame, top_n: int = None, + return_fig: bool | None = None, ax: Axes = None, show: bool | None = None, save: str | bool | None = None, - ) -> Figure | Axes | None: + ) -> Axes | Figure | None: """Plot scatterplot of differential prioritization. Args: results: Results after running differential prioritization. top_n: optionally, the number of top prioritized cell types to label in the plot ax: optionally, axes used to draw plot - return_figure: if `True` returns figure of the plot Returns: Axes of the plot. @@ -1039,24 +1039,26 @@ def plot_dp_scatter( legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5)) ax.add_artist(legend1) - if show: - plt.show() - return None if save: plt.savefig(save, bbox_inches="tight") - return None - elif not show or show is None: + if show: + plt.show() + if return_fig: + return plt.gcf() + if not (show or save): return ax + return None def plot_important_features( self, data: dict[str, Any], key: str = "augurpy_results", top_n: int = 10, + return_fig: bool | None = None, ax: Axes = None, show: bool | None = None, save: str | bool | None = None, - ) -> Figure | Axes | None: + ) -> Axes | None: """Plot a lollipop plot of the n features with largest feature importances. Args: @@ -1108,23 +1110,25 @@ def plot_important_features( plt.ylabel("Gene") plt.yticks(y_axes_range, n_features["genes"]) - if show: - plt.show() - return None if save: plt.savefig(save, bbox_inches="tight") - return None - elif not show or show is None: + if show: + plt.show() + if return_fig: + return plt.gcf() + if not (show or save): return ax + return None def plot_lollipop( self, data: dict[str, Any], key: str = "augurpy_results", + return_fig: bool | None = None, ax: Axes = None, show: bool | None = None, save: str | bool | None = None, - ) -> Figure | Axes | None: + ) -> Axes | Figure | None: """Plot a lollipop plot of the mean augur values. Args: @@ -1172,23 +1176,25 @@ def plot_lollipop( plt.ylabel("Cell Type") plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns) - if show: - plt.show() - return None if save: plt.savefig(save, bbox_inches="tight") - return None - elif not show or show is None: + if show: + plt.show() + if return_fig: + return plt.gcf() + if not (show or save): return ax + return None def plot_scatterplot( self, results1: dict[str, Any], results2: dict[str, Any], top_n: int = None, + return_fig: bool | None = None, show: bool | None = None, save: str | bool | None = None, - ) -> Figure | Axes | None: + ) -> Axes | Figure | None: """Create scatterplot with two augur results. Args: @@ -1244,11 +1250,12 @@ def plot_scatterplot( plt.xlabel("Augur scores 1") plt.ylabel("Augur scores 2") - if show: - plt.show() - return None if save: plt.savefig(save, bbox_inches="tight") - return None - elif not show or show is None: + if show: + plt.show() + if return_fig: + return plt.gcf() + if not (show or save): return ax + return None diff --git a/pertpy/tools/_coda/_base_coda.py b/pertpy/tools/_coda/_base_coda.py index e4a9f301..13757a9c 100644 --- a/pertpy/tools/_coda/_base_coda.py +++ b/pertpy/tools/_coda/_base_coda.py @@ -34,6 +34,7 @@ from jax._src.typing import Array from matplotlib.axes import Axes from matplotlib.colors import Colormap + from matplotlib.figure import Figure config.update("jax_enable_x64", True) @@ -1193,11 +1194,12 @@ def plot_stacked_barplot( # pragma: no cover level_order: list[str] = None, figsize: tuple[float, float] | None = None, dpi: int | None = 100, + return_fig: bool | None = None, ax: plt.Axes | None = None, show: bool | None = None, save: str | bool | None = None, **kwargs, - ) -> plt.Figure | None: + ) -> plt.Axes | plt.Figure | None: """Plots a stacked barplot for all levels of a covariate or all samples (if feature_name=="samples"). Args: @@ -1275,12 +1277,13 @@ def plot_stacked_barplot( # pragma: no cover if save: plt.savefig(save, bbox_inches="tight") - return None if show: plt.show() - return None - elif not show or show is None: + if return_fig: + return plt.gcf() + if not (show or save): return ax + return None def plot_effects_barplot( # pragma: no cover self, @@ -1296,10 +1299,11 @@ def plot_effects_barplot( # pragma: no cover args_barplot: dict | None = None, figsize: tuple[float, float] | None = None, dpi: int | None = 100, + return_fig: bool | None = None, ax: plt.Axes | None = None, show: bool | None = None, save: str | bool | None = None, - ) -> plt.Axes | sns.axisgrid.FacetGrid | None: + ) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None: """Barplot visualization for effects. The effect results for each covariate are shown as a group of barplots, with intra--group separation by cell types. @@ -1399,7 +1403,7 @@ def plot_effects_barplot( # pragma: no cover # If plot as facets, create a FacetGrid and map barplot to it. if plot_facets: if isinstance(palette, ListedColormap): - palette = np.array([palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]) + palette = np.array([palette(i % palette.N) for i in range(len(plot_df["Cell Type"].unique()))]).tolist() if figsize is not None: height = figsize[0] aspect = np.round(figsize[1] / figsize[0], 2) @@ -1437,12 +1441,13 @@ def plot_effects_barplot( # pragma: no cover if save: plt.savefig(save, bbox_inches="tight") - return None if show: plt.show() - return None - elif not show or show is None: + if return_fig: + return plt.gcf() + if not (show or save): return g + return None # If not plot as facets, call barplot to plot cell types on the x-axis. else: @@ -1477,12 +1482,13 @@ def plot_effects_barplot( # pragma: no cover if save: plt.savefig(save, bbox_inches="tight") - return None if show: plt.show() - return None - if not show or show is None: + if return_fig: + return plt.gcf() + if not (show or save): return ax + return None def plot_boxplots( # pragma: no cover self, @@ -1500,10 +1506,11 @@ def plot_boxplots( # pragma: no cover level_order: list[str] = None, figsize: tuple[float, float] | None = None, dpi: int | None = 100, + return_fig: bool | None = None, ax: plt.Axes | None = None, show: bool | None = None, save: str | bool | None = None, - ) -> plt.Axes | sns.axisgrid.FacetGrid | None: + ) -> plt.Axes | plt.Figure | sns.axisgrid.FacetGrid | None: """Grouped boxplot visualization. The cell counts for each cell type are shown as a group of boxplots @@ -1649,12 +1656,13 @@ def plot_boxplots( # pragma: no cover if save: plt.savefig(save, bbox_inches="tight") - return None if show: plt.show() - return None - elif not show or show is None: + if return_fig: + return plt.gcf() + if not (show or save): return g + return None # If not plot as facets, call boxplot to plot cell types on the x-axis. else: @@ -1721,12 +1729,13 @@ def plot_boxplots( # pragma: no cover if save: plt.savefig(save, bbox_inches="tight") - return None if show: plt.show() - return None - elif not show or show is None: + if return_fig: + return plt.gcf() + if not (show or save): return ax + return None def plot_rel_abundance_dispersion_plot( # pragma: no cover self, @@ -1738,10 +1747,11 @@ def plot_rel_abundance_dispersion_plot( # pragma: no cover label_cell_types: bool = True, figsize: tuple[float, float] | None = None, dpi: int | None = 100, + return_fig: bool | None = None, ax: plt.Axes | None = None, show: bool | None = None, save: str | bool | None = None, - ) -> plt.Axes | None: + ) -> plt.Axes | plt.Figure | None: """Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type. If the count of the cell type is larger than 0 in more than abundant_threshold percent of all samples, the cell type will be marked in a different color. @@ -1846,12 +1856,13 @@ def label_point(x, y, val, ax): if save: plt.savefig(save, bbox_inches="tight") - return None if show: plt.show() - return None - elif not show or show is None: + if return_fig: + return plt.gcf() + if not (show or save): return ax + return None def plot_draw_tree( # pragma: no cover self, @@ -1861,7 +1872,7 @@ def plot_draw_tree( # pragma: no cover tight_text: bool | None = False, show_scale: bool | None = False, units: Literal["px", "mm", "in"] | None = "px", - figsize: tuple[float, float] | None = None, + figsize: tuple[float, float] | None = (None, None), dpi: int | None = 100, show: bool | None = True, save: str | bool | None = None, @@ -1950,11 +1961,11 @@ def plot_draw_effects( # pragma: no cover tight_text: bool | None = False, show_scale: bool | None = False, units: Literal["px", "mm", "in"] | None = "px", - figsize: tuple[float, float] | None = None, + figsize: tuple[float, float] | None = (None, None), dpi: int | None = 100, show: bool | None = True, save: str | None = None, - ): + ) -> Tree | None: """Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects. Args: @@ -2133,6 +2144,7 @@ def my_layout(node): else: if not show_leaf_effects: return tree2, tree_style + return None def plot_effects_umap( # pragma: no cover self, @@ -2143,11 +2155,12 @@ def plot_effects_umap( # pragma: no cover modality_key_2: str = "coda", color_map: Colormap | str | None = None, palette: str | Sequence[str] | None = None, + return_fig: bool | None = None, ax: Axes = None, show: bool = None, save: str | bool | None = None, **kwargs, - ) -> plt.Axes | None: + ) -> plt.Axes | plt.Figure | None: """Plot a UMAP visualization colored by effect strength. Effect results in .varm of aggregated sample-level AnnData (default is data['coda']) are assigned to cell-level AnnData @@ -2232,6 +2245,7 @@ def plot_effects_umap( # pragma: no cover vmin=vmin, palette=palette, color_map=color_map, + return_fig=return_fig, ax=ax, show=show, save=save, @@ -2620,14 +2634,14 @@ def from_scanpy( covariate_df_ = covariate_df_.join(covariate_df_uns, how="left") if covariate_obs: - is_unique = adata.obs.groupby(sample_identifier).transform(lambda x: x.nunique() == 1) + is_unique = adata.obs.groupby(sample_identifier, observed=True).transform(lambda x: x.nunique() == 1) unique_covariates = is_unique.columns[is_unique.all()].tolist() if len(unique_covariates) < len(covariate_obs): skipped = set(covariate_obs) - set(unique_covariates) print(f"[bold yellow]Covariates {skipped} have non-unique values! Skipping...") if unique_covariates: - covariate_df_obs = adata.obs.groupby(sample_identifier).first()[unique_covariates] + covariate_df_obs = adata.obs.groupby(sample_identifier, observed=True).first()[unique_covariates] covariate_df_ = covariate_df_.join(covariate_df_obs, how="left") if covariate_df is not None: diff --git a/pertpy/tools/_dialogue.py b/pertpy/tools/_dialogue.py index d51e4a27..63ed5d97 100644 --- a/pertpy/tools/_dialogue.py +++ b/pertpy/tools/_dialogue.py @@ -27,6 +27,7 @@ if TYPE_CHECKING: from matplotlib.axes import Axes + from matplotlib.figure import Figure class Dialogue: @@ -1070,10 +1071,11 @@ def plot_split_violins( celltype_key: str, split_which: tuple[str, str] = None, mcp: str = "mcp_0", + return_fig: bool | None = None, ax: Axes | None = None, save: bool | str | None = None, show: bool | None = None, - ) -> Axes | None: + ) -> Axes | Figure | None: """Plots split violin plots for a given MCP and split variable. Any cells with a value for split_key not in split_which are removed from the plot. @@ -1111,14 +1113,15 @@ def plot_split_violins( ax.set_xticklabels(ax.get_xticklabels(), rotation=90) - if show: - plt.show() - return None if save: plt.savefig(save, bbox_inches="tight") - return None - elif not show or show is None: + if show: + plt.show() + if return_fig: + return plt.gcf() + if not (show or save): return ax + return None def plot_pairplot( self, @@ -1127,9 +1130,10 @@ def plot_pairplot( color: str, sample_id: str, mcp: str = "mcp_0", + return_fig: bool | None = None, show: bool | None = None, save: bool | str | None = None, - ) -> PairGrid: + ) -> PairGrid | Figure | None: """Generate a pairplot visualization for multi-cell perturbation (MCP) data. Computes the mean of a specified MCP feature (mcp) for each combination of sample and cell type, @@ -1168,11 +1172,12 @@ def plot_pairplot( mcp_pivot = pd.concat([mcp_pivot, aggstats[color]], axis=1) ax = sns.pairplot(mcp_pivot, hue=color, corner=True) - if show: - plt.show() - return None if save: plt.savefig(save, bbox_inches="tight") - return None - elif not show or show is None: + if show: + plt.show() + if return_fig: + return plt.gcf() + if not (show or save): return ax + return None diff --git a/pertpy/tools/_milo.py b/pertpy/tools/_milo.py index 0e55d3a6..a45873e5 100644 --- a/pertpy/tools/_milo.py +++ b/pertpy/tools/_milo.py @@ -796,6 +796,7 @@ def plot_nhood( basis: str = "X_umap", color_map: Colormap | str | None = None, palette: str | Sequence[str] | None = None, + return_fig: bool | None = None, ax: Axes | None = None, show: bool | None = None, save: bool | str | None = None, @@ -834,6 +835,7 @@ def plot_nhood( title="Nhood" + str(ix), color_map=color_map, palette=palette, + return_fig=return_fig, ax=ax, show=show, save=save, @@ -848,6 +850,7 @@ def plot_da_beeswarm( alpha: float = 0.1, subset_nhoods: list[str] = None, palette: str | Sequence[str] | dict[str, str] | None = None, + return_fig: bool | None = None, save: bool | str | None = None, show: bool | None = None, ) -> None: @@ -957,12 +960,12 @@ def plot_da_beeswarm( if save: plt.savefig(save, bbox_inches="tight") - return None if show: plt.show() - return None - elif not show or show is None: + if return_fig: return plt.gcf() + if (not show and not save) or (show is None and save is None): + return plt.gca() def plot_nhood_counts_by_cond( self, @@ -970,6 +973,7 @@ def plot_nhood_counts_by_cond( test_var: str, subset_nhoods: list[str] = None, log_counts: bool = False, + return_fig: bool | None = None, save: bool | str | None = None, show: bool | None = None, ) -> None: @@ -1010,9 +1014,9 @@ def plot_nhood_counts_by_cond( if save: plt.savefig(save, bbox_inches="tight") - return None if show: plt.show() - return None - elif not show or show is None: + if return_fig: return plt.gcf() + if not (show or save): + return plt.gca() diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index 85b201da..3d8dde74 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -510,6 +510,7 @@ def plot_barplot( # pragma: no cover axis_title_size: int = 8, legend_title_size: int = 8, legend_text_size: int = 8, + return_fig: bool | None = None, ax: Axes | None = None, show: bool | None = None, save: bool | str | None = None, @@ -557,7 +558,8 @@ def plot_barplot( # pragma: no cover all_cells_percentage["guide_number"] = all_cells_percentage[guide_rna_column].str.rsplit("g", expand=True)[1] all_cells_percentage["guide_number"] = "g" + all_cells_percentage["guide_number"] NP_KO_cells = all_cells_percentage[all_cells_percentage["gene"] != "NT"] - if not show: + + if show: color_mapping = {"KO": "salmon", "NP": "lightgray", "NT": "grey"} unique_genes = NP_KO_cells["gene"].unique() fig, axs = plt.subplots(int(len(unique_genes) / 5), 5, figsize=(25, 25), sharey=True) @@ -594,13 +596,9 @@ def plot_barplot( # pragma: no cover fontsize=legend_text_size, title_fontsize=legend_title_size, ) - plt.tight_layout() - _utils.savefig_or_show("mixscape_barplot", show=show, save=save) - if not show or show is None: - return ax - else: - return None + plt.tight_layout() + _utils.savefig_or_show("mixscape_barplot", show=show, save=save) def plot_heatmap( # pragma: no cover self, @@ -613,6 +611,7 @@ def plot_heatmap( # pragma: no cover subsample_number: int | None = 900, vmin: float | None = -2, vmax: float | None = 2, + return_fig: bool | None = None, show: bool | None = None, save: bool | str | None = None, **kwds, @@ -665,6 +664,7 @@ def plot_heatmap( # pragma: no cover vmax=vmax, n_genes=20, groups=["NT"], + return_fig=return_fig, show=show, save=save, **kwds, @@ -681,6 +681,7 @@ def plot_perturbscore( # pragma: no cover split_by: str = None, before_mixscape: bool = False, perturbation_type: str = "KO", + return_fig: bool | None = None, ax: Axes | None = None, show: bool | None = None, save: bool | str | None = None, @@ -730,6 +731,7 @@ def plot_perturbscore( # pragma: no cover perturbation_score = pd.concat([perturbation_score, perturbation_score_temp]) perturbation_score["mix"] = adata.obs[mixscape_class][perturbation_score.index] gd = list(set(perturbation_score[labels]).difference({target_gene}))[0] + # If before_mixscape is True, split densities based on original target gene classification if before_mixscape is True: palette = {gd: "#7d7d7d", target_gene: color} @@ -771,6 +773,15 @@ def plot_perturbscore( # pragma: no cover plt.legend(title="gene_target", title_fontsize=14, fontsize=12) sns.despine() + if save: + plt.savefig(save, bbox_inches="tight") + if show: + plt.show() + if return_fig: + return plt.gcf() + if not (show or save): + return plt.gca() + # If before_mixscape is False, split densities based on mixscape classifications else: if palette is None: @@ -827,6 +838,15 @@ def plot_perturbscore( # pragma: no cover plt.legend(title="mixscape class", title_fontsize=14, fontsize=12) sns.despine() + if save: + plt.savefig(save, bbox_inches="tight") + if show: + plt.show() + if return_fig: + return plt.gcf() + if not (show or save): + return plt.gca() + def plot_violin( # pragma: no cover self, adata: AnnData, @@ -1041,6 +1061,7 @@ def plot_lda( # pragma: no cover n_components: int | None = None, color_map: Colormap | str | None = None, palette: str | Sequence[str] | None = None, + return_fig: bool | None = None, ax: Axes | None = None, show: bool | None = None, save: bool | str | None = None, @@ -1092,6 +1113,7 @@ def plot_lda( # pragma: no cover color=mixscape_class, palette=palette, color_map=color_map, + return_fig=return_fig, show=show, save=save, ax=ax, diff --git a/pertpy/tools/_scgen/_scgen.py b/pertpy/tools/_scgen/_scgen.py index 3fc42c3a..90647b25 100644 --- a/pertpy/tools/_scgen/_scgen.py +++ b/pertpy/tools/_scgen/_scgen.py @@ -645,7 +645,7 @@ def plot_binary_classifier( show: bool = False, save: str | bool | None = None, fontsize: float = 14, - ) -> None: + ) -> plt.Axes | None: """Plots the dot product between delta and latent representation of a linear classifier. Builds a linear classifier based on the dot product between @@ -694,9 +694,8 @@ def plot_binary_classifier( if save: plt.savefig(save, bbox_inches="tight") - return None if show: plt.show() - return None - elif not show or show is None: + if not (show or save): return ax + return None