Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Option for gradient coloring and alpha in edges #36

Merged
merged 16 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions pysankey/gradient_test.ipynb

Large diffs are not rendered by default.

85 changes: 77 additions & 8 deletions pysankey/sankey/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -72,6 +73,8 @@ def sankey(
closePlot: bool = False,
figSize: Optional[Tuple[int, int]] = None,
ax: Optional[Any] = None,
color_gradient: bool = False,
alphaDict: Optional[Dict[Union[str, Tuple[str, str]], float]] = None,
) -> Any:
"""
Make Sankey Diagram showing flow from left-->right
Expand Down Expand Up @@ -128,6 +131,20 @@ def sankey(
rightWidths, topEdge = _get_positions_and_total_widths(
data_frame, rightLabels, "right"
)
# If no alphaDict given, make one
if alphaDict is None:
alphaDict = {}
for _, label in enumerate(all_labels):
alphaDict[label] = 0.65
else:
missing = [label for label in all_labels if label not in alphaDict.keys()]
if missing:
msg = (
"The alphaDict parameter is missing values for the following labels : "
)
msg += ", ".join(missing)
raise ValueError(msg)
LOGGER.debug("The alphadict value are : %s", alphaDict)
# Total vertical extent of diagram
xMax = topEdge / aspect
draw_vertical_bars(
Expand All @@ -152,6 +169,8 @@ def sankey(
rightLabels,
rightWidths,
xMax,
alphaDict,
color_gradient,
)
if figSize is not None:
plt.gcf().set_size_inches(figSize)
Expand Down Expand Up @@ -353,7 +372,10 @@ def _create_dataframe(

def plot_strips(
ax: Any,
colorDict: Union[Dict[str, Tuple[float, float, float]], Dict[str, str]],
colorDict: Union[
Dict[Union[str, Tuple[str, str]], Tuple[float, float, float]],
Dict[Union[str, Tuple[str, str]], str],
],
dataFrame: DataFrame,
leftLabels: ndarray,
leftWidths: Dict,
Expand All @@ -363,6 +385,8 @@ def plot_strips(
rightLabels: ndarray,
rightWidths: Dict,
xMax: float64,
alphaDict: Dict[Union[str, Tuple[str, str]], float],
color_gradient: bool = False,
) -> None:
# Plot strips
for leftLabel in leftLabels:
Expand Down Expand Up @@ -398,13 +422,58 @@ def plot_strips(
# right place
leftWidths[leftLabel]["bottom"] += ns_l[leftLabel][rightLabel]
rightWidths[rightLabel]["bottom"] += ns_r[leftLabel][rightLabel]
ax.fill_between(
np.linspace(0, xMax, len(ys_d)),
ys_d,
ys_u,
alpha=0.65,
color=colorDict[label_color],
)

if (leftLabel, rightLabel) in alphaDict:
alpha = alphaDict[leftLabel, rightLabel]
else:
alpha = alphaDict[label_color]
if color_gradient:
if (leftLabel, rightLabel) in colorDict:
cleft = cright = colorDict[leftLabel, rightLabel]
else:
cleft = colorDict[leftLabel]
cright = colorDict[rightLabel]

x = list(np.linspace(0, xMax, len(ys_d)))
(poly,) = ax.fill(
x + x[::-1] + [x[0]],
list(ys_d) + list(ys_u)[::-1] + [ys_d[0]],
facecolor="none",
)

# get the extent of the axes
xmin, xmax = ax.get_xlim()
ymin, ymax = ax.get_ylim()

# create a dummy image
img_data = np.arange(xmin, xmax, (xmax - xmin) / 100.0)
img_data = img_data.reshape(img_data.size, 1).T

# plot and clip the image
im = ax.imshow(
img_data,
aspect="auto",
origin="lower",
cmap=mpl.colors.LinearSegmentedColormap.from_list(
"custom", [cleft, cright]
),
alpha=alpha,
extent=[xmin, xmax, ymin, ymax],
)

im.set_clip_path(poly)
else:
if (leftLabel, rightLabel) in colorDict:
color = colorDict[leftLabel, rightLabel]
else:
color = colorDict[label_color]
ax.fill_between(
np.linspace(0, xMax, len(ys_d)),
ys_d,
ys_u,
alpha=alpha,
color=color,
)
ax.axis("off")


Expand Down
Loading