Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/theislab/pertpy
Browse files Browse the repository at this point in the history
  • Loading branch information
Zethson committed Apr 14, 2024
2 parents 41aaa6f + 5b000ea commit 05a6a5c
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 90 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ repos:
args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes]
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: detect-private-key
- id: check-ast
Expand Down
10 changes: 4 additions & 6 deletions pertpy/plot/_coda.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,7 @@ def draw_tree( # pragma: no cover
show: bool | None = True,
save: str | None = None,
units: Literal["px", "mm", "in"] | None = "px",
h: float | None = None,
w: float | None = None,
figsize: tuple[float, float] | None = (None, None),
dpi: int | None = 90,
):
"""Plot a tree using input ete3 tree object.
Expand Down Expand Up @@ -429,7 +428,7 @@ def draw_tree( # pragma: no cover
show=show,
save=save,
units=units,
figsize=(w, h),
figsize=figsize,
dpi=dpi,
)

Expand All @@ -446,8 +445,7 @@ def draw_effects( # pragma: no cover
show: bool | None = True,
save: str | None = None,
units: Literal["px", "mm", "in"] | None = "in",
h: float | None = None,
w: float | None = None,
figsize: tuple[float, float] | None = (None, None),
dpi: int | None = 90,
):
"""Plot a tree with colored circles on the nodes indicating significant effects with bar plots which indicate leave-level significant effects.
Expand Down Expand Up @@ -516,7 +514,7 @@ def draw_effects( # pragma: no cover
show=show,
save=save,
units=units,
figsize=(w, h),
figsize=figsize,
dpi=dpi,
)

Expand Down
57 changes: 32 additions & 25 deletions pertpy/tools/_augur.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,17 +979,17 @@ def plot_dp_scatter(
self,
results: pd.DataFrame,
top_n: int = None,
return_fig: bool | None = None,
ax: Axes = None,
show: bool | None = None,
save: str | bool | None = None,
) -> Figure | Axes | None:
) -> Axes | Figure | None:
"""Plot scatterplot of differential prioritization.
Args:
results: Results after running differential prioritization.
top_n: optionally, the number of top prioritized cell types to label in the plot
ax: optionally, axes used to draw plot
return_figure: if `True` returns figure of the plot
Returns:
Axes of the plot.
Expand Down Expand Up @@ -1039,24 +1039,26 @@ def plot_dp_scatter(
legend1 = ax.legend(*scatter.legend_elements(), loc="center left", title="z-scores", bbox_to_anchor=(1, 0.5))
ax.add_artist(legend1)

if show:
plt.show()
return None
if save:
plt.savefig(save, bbox_inches="tight")
return None
elif not show or show is None:
if show:
plt.show()
if return_fig:
return plt.gcf()
if not (show or save):
return ax
return None

def plot_important_features(
self,
data: dict[str, Any],
key: str = "augurpy_results",
top_n: int = 10,
return_fig: bool | None = None,
ax: Axes = None,
show: bool | None = None,
save: str | bool | None = None,
) -> Figure | Axes | None:
) -> Axes | None:
"""Plot a lollipop plot of the n features with largest feature importances.
Args:
Expand Down Expand Up @@ -1108,23 +1110,25 @@ def plot_important_features(
plt.ylabel("Gene")
plt.yticks(y_axes_range, n_features["genes"])

if show:
plt.show()
return None
if save:
plt.savefig(save, bbox_inches="tight")
return None
elif not show or show is None:
if show:
plt.show()
if return_fig:
return plt.gcf()
if not (show or save):
return ax
return None

def plot_lollipop(
self,
data: dict[str, Any],
key: str = "augurpy_results",
return_fig: bool | None = None,
ax: Axes = None,
show: bool | None = None,
save: str | bool | None = None,
) -> Figure | Axes | None:
) -> Axes | Figure | None:
"""Plot a lollipop plot of the mean augur values.
Args:
Expand Down Expand Up @@ -1172,23 +1176,25 @@ def plot_lollipop(
plt.ylabel("Cell Type")
plt.yticks(y_axes_range, results["summary_metrics"].sort_values("mean_augur_score", axis=1).columns)

if show:
plt.show()
return None
if save:
plt.savefig(save, bbox_inches="tight")
return None
elif not show or show is None:
if show:
plt.show()
if return_fig:
return plt.gcf()
if not (show or save):
return ax
return None

def plot_scatterplot(
self,
results1: dict[str, Any],
results2: dict[str, Any],
top_n: int = None,
return_fig: bool | None = None,
show: bool | None = None,
save: str | bool | None = None,
) -> Figure | Axes | None:
) -> Axes | Figure | None:
"""Create scatterplot with two augur results.
Args:
Expand Down Expand Up @@ -1244,11 +1250,12 @@ def plot_scatterplot(
plt.xlabel("Augur scores 1")
plt.ylabel("Augur scores 2")

if show:
plt.show()
return None
if save:
plt.savefig(save, bbox_inches="tight")
return None
elif not show or show is None:
if show:
plt.show()
if return_fig:
return plt.gcf()
if not (show or save):
return ax
return None
Loading

0 comments on commit 05a6a5c

Please sign in to comment.