Skip to content

Commit

Permalink
Harmonize plots (#574)
Browse files Browse the repository at this point in the history
* add scgen reg plot and fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add changes to coda

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix pre commit

* fix pyplot import

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix pre-commit

* mixscape

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix2

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix3

* milo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* enrichment

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* dialogue

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* augur

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* scgen

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix pre-commit

* precomm fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update pertpy/tools/_coda/_base_coda.py

Co-authored-by: Lukas Heumos <lukas.heumos@posteo.net>

* adressed feedback

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix precommit

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Lukas Heumos <lukas.heumos@posteo.net>
  • Loading branch information
3 people authored Mar 30, 2024
1 parent 03bd173 commit 0f888e8
Show file tree
Hide file tree
Showing 10 changed files with 344 additions and 167 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
20 changes: 8 additions & 12 deletions pertpy/plot/_augur.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class AugurpyPlot:
"""Plotting functions for Augurpy."""

@staticmethod
def dp_scatter(results: pd.DataFrame, top_n=None, ax: Axes = None, return_figure: bool = False) -> Figure | Axes:
def dp_scatter(results: pd.DataFrame, top_n=None, ax: Axes = None) -> Figure | Axes:
"""Plot result of differential prioritization.
Args:
Expand Down Expand Up @@ -56,11 +56,11 @@ def dp_scatter(results: pd.DataFrame, top_n=None, ax: Axes = None, return_figure

ag = Augur("random_forest_classifier")

return ag.plot_dp_scatter(results=results, top_n=top_n, ax=ax, return_figure=return_figure)
return ag.plot_dp_scatter(results=results, top_n=top_n, ax=ax)

@staticmethod
def important_features(
data: dict[str, Any], key: str = "augurpy_results", top_n=10, ax: Axes = None, return_figure: bool = False
data: dict[str, Any], key: str = "augurpy_results", top_n=10, ax: Axes = None
) -> Figure | Axes:
"""Plot a lollipop plot of the n features with largest feature importances.
Expand Down Expand Up @@ -95,12 +95,10 @@ def important_features(

ag = Augur("random_forest_classifier")

return ag.plot_important_features(data=data, key=key, top_n=top_n, ax=ax, return_figure=return_figure)
return ag.plot_important_features(data=data, key=key, top_n=top_n, ax=ax)

@staticmethod
def lollipop(
data: dict[str, Any], key: str = "augurpy_results", ax: Axes = None, return_figure: bool = False
) -> Figure | Axes:
def lollipop(data: dict[str, Any], key: str = "augurpy_results", ax: Axes = None) -> Figure | Axes | None:
"""Plot a lollipop plot of the mean augur values.
Args:
Expand Down Expand Up @@ -133,12 +131,10 @@ def lollipop(

ag = Augur("random_forest_classifier")

return ag.plot_lollipop(data=data, key=key, ax=ax, return_figure=return_figure)
return ag.plot_lollipop(data=data, key=key, ax=ax)

@staticmethod
def scatterplot(
results1: dict[str, Any], results2: dict[str, Any], top_n=None, return_figure: bool = False
) -> Figure | Axes:
def scatterplot(results1: dict[str, Any], results2: dict[str, Any], top_n=None) -> Figure | Axes:
"""Create scatterplot with two augur results.
Args:
Expand Down Expand Up @@ -172,4 +168,4 @@ def scatterplot(

ag = Augur("random_forest_classifier")

return ag.plot_scatterplot(results1=results1, results2=results2, top_n=top_n, return_figure=return_figure)
return ag.plot_scatterplot(results1=results1, results2=results2, top_n=top_n)
6 changes: 2 additions & 4 deletions pertpy/plot/_coda.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,7 @@ def draw_tree( # pragma: no cover
show=show,
save=save,
units=units,
h=h,
w=w,
figsize=(w, h),
dpi=dpi,
)

Expand Down Expand Up @@ -517,8 +516,7 @@ def draw_effects( # pragma: no cover
show=show,
save=save,
units=units,
h=h,
w=w,
figsize=(w, h),
dpi=dpi,
)

Expand Down
68 changes: 56 additions & 12 deletions pertpy/tools/_augur.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,8 +976,13 @@ def predict_differential_prioritization(
return delta

def plot_dp_scatter(
self, results: pd.DataFrame, top_n: int = None, ax: Axes = None, return_figure: bool = False
) -> Figure | Axes:
self,
results: pd.DataFrame,
top_n: int = None,
ax: Axes = None,
show: bool | None = None,
save: str | bool | None = None,
) -> Figure | Axes | None:
"""Plot scatterplot of differential prioritization.
Args:
Expand Down Expand Up @@ -1034,16 +1039,24 @@ def plot_dp_scatter(
legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
ax.add_artist(legend1)

return fig if return_figure else ax
if show:
plt.show()
return None
if save:
plt.savefig(save, bbox_inches="tight")
return None
elif not show or show is None:
return ax

def plot_important_features(
self,
data: dict[str, Any],
key: str = "augurpy_results",
top_n: int = 10,
ax: Axes = None,
return_figure: bool = False,
) -> Figure | Axes:
show: bool | None = None,
save: str | bool | None = None,
) -> Figure | Axes | None:
"""Plot a lollipop plot of the n features with largest feature importances.
Args:
Expand Down Expand Up @@ -1095,11 +1108,23 @@ def plot_important_features(
plt.ylabel("Gene")
plt.yticks(y_axes_range, n_features["genes"])

return fig if return_figure else ax
if show:
plt.show()
return None
if save:
plt.savefig(save, bbox_inches="tight")
return None
elif not show or show is None:
return ax

def plot_lollipop(
self, data: dict[str, Any], key: str = "augurpy_results", ax: Axes = None, return_figure: bool = False
) -> Figure | Axes:
self,
data: dict[str, Any],
key: str = "augurpy_results",
ax: Axes = None,
show: bool | None = None,
save: str | bool | None = None,
) -> Figure | Axes | None:
"""Plot a lollipop plot of the mean augur values.
Args:
Expand Down Expand Up @@ -1147,11 +1172,23 @@ def plot_lollipop(
plt.ylabel("Cell Type")
plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)

return fig if return_figure else ax
if show:
plt.show()
return None
if save:
plt.savefig(save, bbox_inches="tight")
return None
elif not show or show is None:
return ax

def plot_scatterplot(
self, results1: dict[str, Any], results2: dict[str, Any], top_n: int = None, return_figure: bool = False
) -> Figure | Axes:
self,
results1: dict[str, Any],
results2: dict[str, Any],
top_n: int = None,
show: bool | None = None,
save: str | bool | None = None,
) -> Figure | Axes | None:
"""Create scatterplot with two augur results.
Args:
Expand Down Expand Up @@ -1207,4 +1244,11 @@ def plot_scatterplot(
plt.xlabel("Augur scores 1")
plt.ylabel("Augur scores 2")

return fig if return_figure else ax
if show:
plt.show()
return None
if save:
plt.savefig(save, bbox_inches="tight")
return None
elif not show or show is None:
return ax
78 changes: 45 additions & 33 deletions pertpy/tools/_coda/_base_coda.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@
from scipy.cluster import hierarchy as sp_hierarchy

if TYPE_CHECKING:
from collections.abc import Sequence

import numpyro as npy
import toytree as tt
from ete3 import Tree
from jax._src.typing import Array
from matplotlib.axes import Axes
from matplotlib.colors import Colormap

config.update("jax_enable_x64", True)

Expand Down Expand Up @@ -1185,11 +1188,11 @@ def plot_stacked_barplot( # pragma: no cover
data: AnnData | MuData,
feature_name: str,
modality_key: str = "coda",
figsize: tuple[float, float] | None = None,
dpi: int | None = 100,
palette: ListedColormap | None = cm.tab20,
show_legend: bool | None = True,
level_order: list[str] = None,
figsize: tuple[float, float] | None = None,
dpi: int | None = 100,
ax: plt.Axes | None = None,
show: bool | None = None,
save: str | bool | None = None,
Expand Down Expand Up @@ -1288,11 +1291,11 @@ def plot_effects_barplot( # pragma: no cover
plot_facets: bool = True,
plot_zero_covariate: bool = True,
plot_zero_cell_type: bool = False,
figsize: tuple[float, float] | None = None,
dpi: int | None = 100,
palette: str | ListedColormap | None = cm.tab20,
level_order: list[str] = None,
args_barplot: dict | None = None,
figsize: tuple[float, float] | None = None,
dpi: int | None = 100,
ax: plt.Axes | None = None,
show: bool | None = None,
save: str | bool | None = None,
Expand Down Expand Up @@ -1492,11 +1495,11 @@ def plot_boxplots( # pragma: no cover
cell_types: list | None = None,
args_boxplot: dict | None = None,
args_swarmplot: dict | None = None,
figsize: tuple[float, float] | None = None,
dpi: int | None = 100,
palette: str | None = "Blues",
show_legend: bool | None = True,
level_order: list[str] = None,
figsize: tuple[float, float] | None = None,
dpi: int | None = 100,
ax: plt.Axes | None = None,
show: bool | None = None,
save: str | bool | None = None,
Expand Down Expand Up @@ -1738,7 +1741,7 @@ def plot_rel_abundance_dispersion_plot( # pragma: no cover
ax: plt.Axes | None = None,
show: bool | None = None,
save: str | bool | None = None,
) -> plt.Axes:
) -> plt.Axes | None:
"""Plots total variance of relative abundance versus minimum relative abundance of all cell types for determination of a reference cell type.
If the count of the cell type is larger than 0 in more than abundant_threshold percent of all samples, the cell type will be marked in a different color.
Expand Down Expand Up @@ -1858,12 +1861,11 @@ def plot_draw_tree( # pragma: no cover
tight_text: bool | None = False,
show_scale: bool | None = False,
units: Literal["px", "mm", "in"] | None = "px",
h: float | None = None,
w: float | None = None,
dpi: int | None = 90,
figsize: tuple[float, float] | None = None,
dpi: int | None = 100,
show: bool | None = True,
save: str | bool | None = None,
):
) -> Tree | None:
"""Plot a tree using input ete3 tree object.
Args:
Expand All @@ -1882,11 +1884,9 @@ def plot_draw_tree( # pragma: no cover
file_name: Path to the output image file. Valid extensions are .SVG, .PDF, .PNG.
Output image can be saved whether show is True or not.
Defaults to None.
units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches.
Defaults to "px".
h: Height of the image in units. Defaults to None.
w: Width of the image in units. Defaults to None.
dpi: Dots per inches. Defaults to 90.
units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches. Defaults to "px".
figsize: Figure size. Defaults to None.
dpi: Dots per inches. Defaults to 100.
Returns:
Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`) or plot the tree inline (`show = False`)
Expand Down Expand Up @@ -1931,10 +1931,11 @@ def my_layout(node):
tree_style.show_leaf_name = False
tree_style.layout_fn = my_layout
tree_style.show_scale = show_scale

if save is not None:
tree.render(save, tree_style=tree_style, units=units, w=w, h=h, dpi=dpi) # type: ignore
tree.render(save, tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
if show:
return tree.render("%%inline", tree_style=tree_style, units=units, w=w, h=h, dpi=dpi) # type: ignore
return tree.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi) # type: ignore
else:
return tree, tree_style

Expand All @@ -1948,12 +1949,11 @@ def plot_draw_effects( # pragma: no cover
show_leaf_effects: bool | None = False,
tight_text: bool | None = False,
show_scale: bool | None = False,
units: Literal["px", "mm", "in"] | None = "px",
figsize: tuple[float, float] | None = None,
dpi: int | None = 100,
show: bool | None = True,
save: str | None = None,
units: Literal["px", "mm", "in"] | None = "in",
h: float | None = None,
w: float | None = None,
dpi: int | None = 90,
):
"""Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects.
Expand All @@ -1975,10 +1975,9 @@ def plot_draw_effects( # pragma: no cover
show: If True, plot the tree inline. If false, return tree and tree_style objects. Defaults to True.
file_name: Path to the output image file. valid extensions are .SVG, .PDF, .PNG. Output image can be saved whether show is True or not.
Defaults to None.
units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches. Default is "in". Defaults to "in".
h: Height of the image in units. Defaults to None.
w: Width of the image in units. Defaults to None.
dpi: Dots per inches. Defaults to 90.
units: Unit of image sizes. “px”: pixels, “mm”: millimeters, “in”: inches. Defaults to "px".
figsize: Figure size. Defaults to None.
dpi: Dots per inches. Defaults to 100.
Returns:
Depending on `show`, returns :class:`ete3.TreeNode` and :class:`ete3.TreeStyle` (`show = False`)
Expand Down Expand Up @@ -2130,21 +2129,23 @@ def my_layout(node):
tree2.render(save, tree_style=tree_style, units=units)
if show:
if not show_leaf_effects:
return tree2.render("%%inline", tree_style=tree_style, units=units, w=w, h=h, dpi=dpi)
return tree2.render("%%inline", tree_style=tree_style, units=units, w=figsize[0], h=figsize[1], dpi=dpi)
else:
if not show_leaf_effects:
return tree2, tree_style

def plot_effects_umap( # pragma: no cover
self,
data: MuData,
mdata: MuData,
effect_name: str | list | None,
cluster_key: str,
modality_key_1: str = "rna",
modality_key_2: str = "coda",
color_map: Colormap | str | None = None,
palette: str | Sequence[str] | None = None,
ax: Axes = None,
show: bool = None,
save: str | bool | None = None,
ax: Axes = None,
**kwargs,
) -> plt.Axes | None:
"""Plot a UMAP visualization colored by effect strength.
Expand All @@ -2153,7 +2154,7 @@ def plot_effects_umap( # pragma: no cover
(default is data['rna']) depending on the cluster they were assigned to.
Args:
data: AnnData object or MuData object.
mudata: MuData object.
effect_name: The name of the effect results in .varm of aggregated sample-level AnnData to plot
cluster_key: The cluster information in .obs of cell-level AnnData (default is data['rna']).
To assign cell types' effects to original cells.
Expand Down Expand Up @@ -2207,8 +2208,8 @@ def plot_effects_umap( # pragma: no cover
.. image:: /_static/docstring_previews/tasccoda_effects_umap.png
"""
# TODO: Add effect_name parameter and cluster_key and test the example
data_rna = data[modality_key_1]
data_coda = data[modality_key_2]
data_rna = mdata[modality_key_1]
data_coda = mdata[modality_key_2]
if isinstance(effect_name, str):
effect_name = [effect_name]
for _, effect in enumerate(effect_name):
Expand All @@ -2224,7 +2225,18 @@ def plot_effects_umap( # pragma: no cover
else:
vmax = max(data_rna.obs[effect].max() for _, effect in enumerate(effect_name))

return sc.pl.umap(data_rna, color=effect_name, vmax=vmax, vmin=vmin, ax=ax, show=show, save=save, **kwargs)
return sc.pl.umap(
data_rna,
color=effect_name,
vmax=vmax,
vmin=vmin,
palette=palette,
color_map=color_map,
ax=ax,
show=show,
save=save,
**kwargs,
)


def get_a(
Expand Down
Loading

0 comments on commit 0f888e8

Please sign in to comment.