Skip to content

Commit

Permalink
Merge pull request #95 from duneanalytics/add-sankey
Browse files Browse the repository at this point in the history
adding charting function for Sankey diagram
  • Loading branch information
msf authored Oct 3, 2023
2 parents b510c97 + 8406356 commit 84b3e94
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 2 deletions.
2 changes: 1 addition & 1 deletion dune_client/api/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def run_query_dataframe(
This is a convenience method that uses run_query_csv() + pandas.read_csv() underneath
"""
try:
import pandas # type: ignore # pylint: disable=import-outside-toplevel
import pandas # pylint: disable=import-outside-toplevel
except ImportError as exc:
raise ImportError(
"dependency failure, pandas is required but missing"
Expand Down
2 changes: 1 addition & 1 deletion dune_client/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ async def refresh_into_dataframe(
This is a convenience method that uses refresh_csv underneath
"""
try:
import pandas # type: ignore # pylint: disable=import-outside-toplevel
import pandas # pylint: disable=import-outside-toplevel
except ImportError as exc:
raise ImportError(
"dependency failure, pandas is required but missing"
Expand Down
Empty file added dune_client/viz/__init__.py
Empty file.
90 changes: 90 additions & 0 deletions dune_client/viz/graphs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""
Functions you can call to make different graphs
"""

from typing import Dict, Union

# https://github.com/plotly/colorlover/issues/35
import colorlover as cl # type: ignore[import]
import pandas as pd
import plotly.graph_objects as go # type: ignore[import]
from plotly.graph_objs import Figure # type: ignore[import]


# function to create Sankey diagram
def create_sankey(
    query_result: pd.DataFrame,
    predefined_colors: Dict[str, str],
    columns: Dict[str, str],
    viz_config: Dict[str, Union[int, float]],
    title: str = "unnamed",
) -> Figure:
    """
    Build a Sankey diagram from a dataframe of (source, target, value) rows.

    The input dataframe is only read; no helper columns are added to it.

    Args:
        query_result: Dataframe holding one flow per row.
        predefined_colors: Maps a name to a colour. A node receives that
            colour when the name occurs (case-insensitively) anywhere in the
            node's label; the first matching entry wins. Nodes without a
            match get a colour from the colorlover "Set3" qualitative scale.
        columns: Maps the conceptual names "source", "target" and "value"
            to the actual column names in ``query_result``.
        viz_config: Layout settings. The keys read are "node_pad",
            "node_thickness", "node_line_width", "font_size",
            "figure_height" and "figure_width".
        title: Title text of the figure.

    Returns:
        The plotly figure containing the Sankey trace.

    Raises:
        ValueError: If one of the mapped columns is missing from
            ``query_result`` or the value column is not numeric.
        KeyError: If ``columns`` or ``viz_config`` lacks a required key.
    """
    source_col = columns["source"]
    target_col = columns["target"]
    value_col = columns["value"]

    # Check if the dataframe contains required columns
    for col in (source_col, target_col, value_col):
        if col not in query_result.columns:
            raise ValueError(f"Error: The dataframe is missing the '{col}' column")

    # Check if 'value' column is numeric
    if not pd.api.types.is_numeric_dtype(query_result[value_col]):
        raise ValueError("Error: The 'value' column must be numeric")

    sources = query_result[source_col]
    targets = query_result[target_col]

    # Distinct node labels, in order of first appearance (sources first).
    all_nodes = list(pd.concat([sources, targets]).unique())

    # In Sankey, 'source' and 'target' must be node indices. A dict gives a
    # constant-time lookup per row, and keeping the indices in local lists
    # avoids writing extra columns into the caller's dataframe.
    node_index = {node: idx for idx, node in enumerate(all_nodes)}
    source_idx = [node_index[node] for node in sources]
    target_idx = [node_index[node] for node in targets]

    # create color map for Sankey
    palette = cl.scales["12"]["qual"]["Set3"]  # default colors
    lowered_predefined = [
        (name.lower(), color) for name, color in predefined_colors.items()
    ]
    color_map = {}
    for node in all_nodes:
        # str() so that non-string labels (e.g. integer ids) can be matched too
        node_label = str(node).lower()
        for name, color in lowered_predefined:
            if name in node_label:  # check if name exists in the node name
                color_map[node] = color
                break
        else:
            # No predefined match: cycle through the default palette.
            color_map[node] = palette[len(color_map) % len(palette)]

    fig = go.Figure(
        go.Sankey(
            node={
                "pad": viz_config["node_pad"],
                "thickness": viz_config["node_thickness"],
                "line": {"color": "black", "width": viz_config["node_line_width"]},
                "label": all_nodes,
                # customize node color
                "color": [color_map[node] for node in all_nodes],
            },
            link={
                "source": source_idx,
                "target": target_idx,
                "value": query_result[value_col],
                # each link takes the color of the node it starts from
                "color": [color_map[node] for node in sources],
            },
        )
    )
    fig.update_layout(
        title_text=title,
        font_size=viz_config["font_size"],
        height=viz_config["figure_height"],
        width=viz_config["figure_width"],
    )

    return fig
3 changes: 3 additions & 0 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
-r prod.txt
black>=23.7.0
pandas>=1.0.0
pandas-stubs>=1.0.0
pylint>=2.17.5
pytest>=7.4.1
python-dotenv>=1.0.0
mypy>=1.5.1
aiounittest>=1.4.2
colorlover>=0.3.0
plotly>=5.9.0
69 changes: 69 additions & 0 deletions tests/unit/test_viz_sankey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import unittest
from unittest.mock import patch
import pandas as pd
from dune_client.viz.graphs import create_sankey


class TestCreateSankey(unittest.TestCase):
    """Unit tests for create_sankey: input validation and figure construction."""

    def setUp(self):
        """Create a fresh two-row dataframe and the arguments shared by all tests."""
        self.df = pd.DataFrame(
            {
                "source_col": ["WBTC", "USDC"],
                "target_col": ["USDC", "WBTC"],
                "value_col": [2184, 2076],
            }
        )

        self.predefined_colors = {
            "USDC": "rgb(38, 112, 196)",
            "WBTC": "rgb(247, 150, 38)",
        }

        self.columns = {
            "source": "source_col",
            "target": "target_col",
            "value": "value_col",
        }
        self.viz_config: dict = {
            "node_pad": 15,
            "node_thickness": 20,
            "node_line_width": 0.5,
            "font_size": 10,
            "figure_height": 1000,
            "figure_width": 1500,
        }

    def test_missing_column(self):
        """A dataframe lacking one of the mapped columns is rejected."""
        # Remove a required column from dataframe
        df_without_target = self.df.drop(columns=["target_col"])
        with self.assertRaises(ValueError):
            create_sankey(
                df_without_target, self.predefined_colors, self.columns, self.viz_config
            )

    def test_value_column_not_numeric(self):
        """A non-numeric value column is rejected."""
        # Change the 'value' column to a non-numeric type
        df_with_str_values = self.df.copy()
        df_with_str_values["value_col"] = ["10", "11"]
        with self.assertRaises(ValueError):
            create_sankey(
                df_with_str_values,
                self.predefined_colors,
                self.columns,
                self.viz_config,
            )

    # Mocking the visualization creation and just testing the processing logic
    @patch("plotly.graph_objects.Figure")
    def test_mocked_visualization(self, mock_figure_cls):
        """The figure is built exactly once, laid out from viz_config, and returned."""
        result = create_sankey(
            self.df, self.predefined_colors, self.columns, self.viz_config, "test"
        )

        # Figure must have been constructed exactly once ...
        mock_figure_cls.assert_called_once()

        # ... and the object handed back must be that very figure,
        # laid out with the title and the sizes taken from viz_config.
        mock_fig = mock_figure_cls.return_value
        self.assertIs(result, mock_fig)
        mock_fig.update_layout.assert_called_once_with(
            title_text="test",
            font_size=self.viz_config["font_size"],
            height=self.viz_config["figure_height"],
            width=self.viz_config["figure_width"],
        )


# Allow running this test module directly as a script, in addition to
# discovery by a test runner.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 84b3e94

Please sign in to comment.