Skip to content

Commit

Permalink
add shap in env files and add framework for shap
Browse files Browse the repository at this point in the history
  • Loading branch information
OuyangWenyu committed Oct 19, 2023
1 parent e54eec3 commit f76de65
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 6 deletions.
1 change: 1 addition & 0 deletions env-ubuntu-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- pytorch-scatter
- tensorboard
- torchvision
- shap
- tqdm
- pytest
- black
Expand Down
1 change: 1 addition & 0 deletions env-windows-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- pytorch-scatter
- tensorboard
- torchvision
- shap
- tqdm
- pytest
- black
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ hydroutils~=0.0.2
hydrodataset~=0.1.3
pandas~=1.5.1
torch~=1.12.1
shap
hydroerr~=1.24
setuptools~=65.5.0
torchvision~=0.13.1
Expand Down
1 change: 1 addition & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ torch~=1.12.1
hydroerr~=1.24
setuptools~=65.5.0
torchvision~=0.13.1
shap
matplotlib~=3.5.3
seaborn~=0.12.1

Expand Down
199 changes: 199 additions & 0 deletions torchhydro/explainers/shap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
"""
Author: Wenyu Ouyang
Date: 2023-10-19 21:34:29
LastEditTime: 2023-10-19 21:59:34
LastEditors: Wenyu Ouyang
Description: SHAP methods for deep learning models
FilePath: /torchhydro/torchhydro/explainers/shap.py
Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.
"""
import torch
import numpy as np
import shap


def plot_summary_shap_values(shap_values: torch.Tensor, columns):
    """Build a horizontal bar chart with one aggregated SHAP magnitude per feature.

    Parameters
    ----------
    shap_values
        Named tensor. The reductions below address the dimension names
        "preds", "batches" and "observations", so the tensor must carry them.
    columns
        Feature labels used as the bar categories; must support ``[::-1]``.

    Returns
    -------
    The figure object created by ``go.Figure()`` with a single bar trace.
    """
    # NOTE(review): ``go`` is not among the imports visible at the top of this
    # module (torch, numpy, shap). Presumably plotly.graph_objects -- confirm
    # and import it, otherwise this function raises NameError.

    # Signed SHAP values are averaged over predictions and batches first;
    # only afterwards is the magnitude taken and averaged over "observations".
    signed_mean = shap_values.mean(axis=["preds", "batches"])
    importance = signed_mean.abs().mean(axis="observations")

    fig = go.Figure()
    fig.add_trace(go.Bar(y=columns, x=importance, orientation="h"))
    # The y-axis category order is pinned to the reversed column list.
    fig.update_layout(yaxis={"categoryorder": "array", "categoryarray": columns[::-1]})
    return fig


def plot_summary_shap_values_over_time_series(shap_values: torch.Tensor, columns):
    """Build a stacked horizontal bar chart of SHAP magnitudes per prediction step.

    Parameters
    ----------
    shap_values
        Named tensor. The code addresses the dimension names "batches",
        "observations" and "preds", so the tensor must carry them.
    columns
        Feature labels used as the bar categories; must support ``[::-1]``.

    Returns
    -------
    The figure object created by ``go.Figure()``; it holds one bar trace per
    entry along the "preds" dimension, named ``"time-step <i>"``.
    """
    # NOTE(review): ``go`` is not among the imports visible at the top of this
    # module (torch, numpy, shap). Presumably plotly.graph_objects -- confirm
    # and import it, otherwise this function raises NameError.

    # Average over batches, take magnitudes, then average over "observations".
    batch_mean_magnitude = shap_values.mean(axis=["batches"]).abs()
    per_prediction = batch_mean_magnitude.mean(axis="observations")

    fig = go.Figure()
    # align_to("preds", ...) moves "preds" to the front so that iterating the
    # tensor yields one slice per prediction step.
    for step, step_values in enumerate(per_prediction.align_to("preds", ...)):
        fig.add_trace(
            go.Bar(
                y=columns,
                x=step_values,
                name=f"time-step {step}",
                orientation="h",
            )
        )
    fig.update_layout(
        barmode="stack",
        yaxis={"categoryorder": "array", "categoryarray": columns[::-1]},
    )
    return fig


def plot_shap_values_from_history(shap_values: torch.Tensor, history: torch.Tensor):
    """Build one SHAP scatter figure per feature, coloured by the feature's history.

    Parameters
    ----------
    shap_values
        Named tensor. The code addresses the dimension names "preds",
        "batches" and "features", so the tensor must carry them.
    history
        Named tensor. The code addresses the dimension names "batches" and
        "features", so the tensor must carry them.

    Returns
    -------
    list
        One figure per entry along the "features" dimension, in that order.
    """
    # NOTE(review): ``go``, ``px`` and ``jitter`` are not defined or imported in
    # the visible part of this module. ``go``/``px`` are presumably
    # plotly.graph_objects / plotly.express; ``jitter`` is an unknown helper --
    # confirm and bring them into scope, otherwise this raises NameError.
    mean_shap = shap_values.mean(axis=["preds", "batches"])
    mean_history = history.mean(axis="batches")

    figs = []
    # Bring "features" to the front of both tensors so they can be walked in
    # lockstep, one feature at a time.
    per_feature = zip(
        mean_history.align_to("features", ...),
        mean_shap.align_to("features", ...),
    )
    for feature_history, feature_shap in per_feature:
        marker = dict(
            color=feature_history,
            colorbar=dict(title=dict(side="right", text="feature values")),
            colorscale=px.colors.sequential.Bluered,
        )
        scatter = go.Scatter(
            y=jitter(feature_shap),
            x=feature_shap,
            mode="markers",
            marker=marker,
        )
        fig = go.Figure()
        fig.add_trace(scatter)
        fig.update_yaxes(range=[-0.05, 0.05])
        fig.update_xaxes(title_text="shap value")
        fig.update_layout(showlegend=False)
        figs.append(fig)
    return figs


def deep_explain_model_summary_plot(deep_hydro, test_loader, datetime_start=None) -> None:
    """Generate feature summary plots for a trained deep learning model.

    Parameters
    ----------
    deep_hydro
        Object exposing the trained network as ``deep_hydro.model``.
    test_loader
        Test data holder; ``test_loader.df.columns`` supplies the feature labels.
    datetime_start
        Accepted because the trainer calls this function with a third
        positional argument (the start of the test period). Not used yet.

    Returns
    -------
    None
    """

    def _as_named_tensor(raw_shap_values):
        # Stack the explainer output into one array and force it to 4-D by
        # prepending an axis when needed, then attach the dimension names the
        # plotting helpers address.
        stacked = np.stack(raw_shap_values)
        if stacked.ndim != 4:
            stacked = np.expand_dims(stacked, axis=0)
        return torch.tensor(
            stacked, names=["preds", "batches", "observations", "features"]
        )

    deep_hydro.model.eval()

    # NOTE(review): ``history`` and ``background_tensor`` are not defined
    # anywhere in the visible part of this module, so this function raises
    # NameError as written. They presumably have to be derived from
    # ``test_loader`` (and ``datetime_start``) -- TODO confirm and implement.
    # background shape (L, N, M)
    # L - batch size, N - history length, M - feature size
    if isinstance(history, list):
        deep_hydro.model = deep_hydro.model.to("cpu")
        deep_explainer = shap.DeepExplainer(deep_hydro.model, history)
        raw_values = deep_explainer.shap_values(history)
    else:
        deep_explainer = shap.DeepExplainer(deep_hydro.model, background_tensor)
        raw_values = deep_explainer.shap_values(background_tensor)
    shap_values = _as_named_tensor(raw_values)

    columns = test_loader.df.columns
    # NOTE(review): the figures built below are neither returned nor saved.

    # summary plot shows overall feature ranking
    # by average absolute shap values
    plot_summary_shap_values(shap_values, columns)

    # summary plot for multi-step outputs
    plot_summary_shap_values_over_time_series(shap_values, columns)

    # summary plot for one prediction at datetime_start
    hist = history[0] if isinstance(history, list) else history
    history_named = torch.tensor(
        hist.cpu().numpy(), names=["batches", "observations", "features"]
    )
    shap_values = _as_named_tensor(deep_explainer.shap_values(history))
    plot_shap_values_from_history(shap_values, history_named)


def plot_shap_value_heatmaps(shap_values: torch.Tensor):
    """Build one heatmap figure per feature from batch-averaged SHAP values.

    Parameters
    ----------
    shap_values
        Named tensor. The code addresses the dimension names "batches",
        "observations", "preds" and "features", so the tensor must carry them.

    Returns
    -------
    list
        One figure per entry along the "features" dimension, in that order.
    """
    # NOTE(review): ``go`` and ``px`` are not among the imports visible at the
    # top of this module (torch, numpy, shap). Presumably plotly.graph_objects
    # and plotly.express -- confirm and import them, otherwise this function
    # raises NameError.
    batch_mean = shap_values.mean(axis="batches")

    # Axis ticks are plain 0-based indices sized from the "observations"
    # (x) and "preds" (y) dimensions of the input tensor.
    x = list(range(shap_values.align_to("observations", ...).shape[0]))
    y = list(range(shap_values.align_to("preds", ...).shape[0]))

    figs = []
    # Bring "features" to the front so iteration yields one slice per feature.
    for feature_values in batch_mean.align_to("features", ...):
        heatmap = go.Heatmap(
            z=feature_values,
            x=x,
            y=y,
            colorbar=dict(title=dict(side="right", text="feature values")),
            colorscale=px.colors.sequential.Bluered,
        )
        fig = go.Figure()
        fig.add_trace(heatmap)
        fig.update_xaxes(title_text="sequence history steps")
        fig.update_yaxes(title_text="prediction steps")
        figs.append(fig)
    return figs


def deep_explain_model_heatmap(deep_hydro, test_loader, datetime_start=None) -> None:
    """Generate per-feature SHAP heatmaps for a trained deep learning model.

    Parameters
    ----------
    deep_hydro
        Object exposing the trained network as ``deep_hydro.model``.
    test_loader
        Test data holder. Not read by the current body.
    datetime_start
        Accepted because the trainer calls this function with a third
        positional argument (the start of the test period). Not used yet.

    Returns
    -------
    None
    """

    def _as_named_tensor(raw_shap_values):
        # Stack the explainer output into one array and force it to 4-D by
        # prepending an axis when needed, then attach the dimension names the
        # heatmap helper addresses.
        stacked = np.stack(raw_shap_values)
        if stacked.ndim != 4:
            stacked = np.expand_dims(stacked, axis=0)
        return torch.tensor(
            stacked, names=["preds", "batches", "observations", "features"]
        )

    deep_hydro.model.eval()

    # NOTE(review): ``history`` and ``background_tensor`` are not defined
    # anywhere in the visible part of this module, so this function raises
    # NameError as written. They presumably have to be derived from
    # ``test_loader`` (and ``datetime_start``) -- TODO confirm and implement.
    # NOTE(review): unlike deep_explain_model_summary_plot, the list branch
    # here does not move the model to CPU -- confirm whether that is intended.
    # background shape (L, N, M)
    # L - batch size, N - history length, M - feature size
    # for each element in each N x M batch in L,
    # attribute to each prediction in forecast len
    if isinstance(history, list):
        deep_explainer = shap.DeepExplainer(deep_hydro.model, history)
        raw_values = deep_explainer.shap_values(history)
    else:
        deep_explainer = shap.DeepExplainer(deep_hydro.model, background_tensor)
        raw_values = deep_explainer.shap_values(background_tensor)
    # Original author's note on the stacked layout: forecast_len x N x L x M.
    shap_values = _as_named_tensor(raw_values)
    # NOTE(review): the figures built below are neither returned nor saved.
    plot_shap_value_heatmaps(shap_values)

    # heatmap for one prediction sequence at datetime_start
    # (seq_len * forecast_len) per feature
    shap_values = _as_named_tensor(deep_explainer.shap_values(history))
    plot_shap_value_heatmaps(shap_values)
15 changes: 9 additions & 6 deletions torchhydro/trainers/deep_hydro.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""
Author: Wenyu Ouyang
Date: 2021-12-31 11:08:29
LastEditTime: 2023-10-11 11:35:13
LastEditTime: 2023-10-19 22:08:50
LastEditors: Wenyu Ouyang
Description: HydroDL model class
FilePath: \torchhydro\torchhydro\trainers\deep_hydro.py
FilePath: /torchhydro/torchhydro/trainers/deep_hydro.py
Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.
"""

Expand All @@ -21,6 +21,10 @@
from hydrodataset import HydroDataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchhydro.explainers.shap import (
deep_explain_model_heatmap,
deep_explain_model_summary_plot,
)
from torchhydro.configs.config import update_nested_dict
from torchhydro.datasets.sampler import KuaiSampler, fl_sample_basin, fl_sample_region
from torchhydro.datasets.data_dict import datasets_dict
Expand All @@ -39,9 +43,9 @@
compute_validation,
model_infer,
torch_single_train,
cellstates_when_inference,
)
from torchhydro.trainers.train_logger import TrainLogger
from trainers.train_utils import cellstates_when_inference


class DeepHydroInterface(ABC):
Expand Down Expand Up @@ -358,13 +362,12 @@ def model_evaluate(self) -> Tuple[Dict, np.array, np.array]:
]

# Finally, try to explain model behaviour using shap
# TODO: SHAP has not been supported
is_shap = False
if is_shap:
deep_explain_model_summary_plot(
model, test_data, data_cfgs["t_range_test"][0]
self, test_data, data_cfgs["t_range_test"][0]
)
deep_explain_model_heatmap(model, test_data, data_cfgs["t_range_test"][0])
deep_explain_model_heatmap(self, test_data, data_cfgs["t_range_test"][0])

return eval_log, preds_xr, obss_xr

Expand Down

0 comments on commit f76de65

Please sign in to comment.