Skip to content

Commit

Permalink
explainability: add more visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Aug 29, 2024
1 parent 80e7fcf commit 8e0cc7b
Show file tree
Hide file tree
Showing 7 changed files with 292 additions and 140 deletions.
7 changes: 4 additions & 3 deletions molpipeline/explainability/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from molpipeline import Pipeline
from molpipeline.abstract_pipeline_elements.core import OptionalMol
from molpipeline.explainability.explanation import Explanation
from molpipeline.explainability.explanation import Explanation, SHAPExplanation
from molpipeline.explainability.fingerprint_utils import fingerprint_shap_to_atomweights
from molpipeline.mol2any import MolToMorganFP
from molpipeline.utils.subpipeline import SubpipelineExtractor
Expand Down Expand Up @@ -213,7 +213,7 @@ def _prediction_is_valid(self, prediction: Any) -> bool:
return True

# pylint: disable=C0103,W0613
def explain(self, X: Any, **kwargs: Any) -> list[Explanation]:
def explain(self, X: Any, **kwargs: Any) -> list[SHAPExplanation]:
"""Explain the predictions for the input data.
If the calculation of the SHAP values for an input sample fails, the explanation will be invalid.
Expand Down Expand Up @@ -278,14 +278,15 @@ def explain(self, X: Any, **kwargs: Any) -> list[Explanation]:
)

explanation_results.append(
Explanation(
SHAPExplanation(
feature_vector=feature_vector,
feature_names=feature_names,
molecule=molecule,
prediction=prediction,
feature_weights=feature_weights,
atom_weights=atom_weights,
bond_weights=bond_weights,
expected_value=self.explainer.expected_value,
)
)

Expand Down
12 changes: 11 additions & 1 deletion molpipeline/explainability/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from molpipeline.abstract_pipeline_elements.core import RDKitMol


@dataclasses.dataclass()
@dataclasses.dataclass(kw_only=True)
class Explanation:
"""Class representing explanations of a prediction."""

Expand Down Expand Up @@ -50,3 +50,13 @@ def is_valid(self) -> bool:
),
]
)


@dataclasses.dataclass(kw_only=True)
class SHAPExplanation(Explanation):
    """Explanation of a prediction that was computed with SHAP.

    Extends `Explanation` with information that is only available for SHAP
    explanations.
    """

    # Expected value as reported by the SHAP explainer; the explainer's
    # `explain` method fills this from `explainer.expected_value`.
    expected_value: npt.NDArray[np.float64]
33 changes: 24 additions & 9 deletions molpipeline/explainability/visualization/heatmaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,14 @@ def evaluate(self) -> None:

def map2color(
self,
c_map: colors.Colormap | str,
v_lim: Sequence[float] | None = None,
c_map: colors.Colormap,
normalizer: colors.Normalize,
) -> ColorGrid:
"""Generate a ColorGrid from `self.values` according to given colormap.
Parameters
----------
c_map: Union[colors.Colormap, str]
c_map: colors.Colormap
Colormap to be used for mapping values to colors.
normalizer: colors.Normalize
Normalizer used to scale `self.values` before the colormap is applied.
Expand All @@ -227,17 +227,32 @@ def map2color(
ColorGrid with colors corresponding to ValueGrid.
"""
color_grid = ColorGrid(self.x_lim, self.y_lim, self.x_res, self.y_res)
if not v_lim:
abs_max = np.max(np.abs(self.values))
v_lim = -abs_max, abs_max
normalizer = colors.Normalize(vmin=v_lim[0], vmax=v_lim[1])
if isinstance(c_map, str):
c_map = cm.get_cmap(c_map)
norm = normalizer(self.values)
color_grid.color_grid = np.array(c_map(norm))
return color_grid


def get_color_normalizer_from_data(
    values: npt.NDArray[np.float64],
) -> colors.Normalize:
    """Create a symmetric color normalizer from the value range of `values`.

    Parameters
    ----------
    values: npt.NDArray[np.float64]
        Data to derive the limits of the normalizer from. The maximum absolute
        value of `values` is used as limit, i.e. the normalizer spans
        [-abs_max, abs_max] so that zero lies in the middle of the range.

    Returns
    -------
    colors.Normalize
        Normalizer for colors.
    """
    abs_max = np.max(np.abs(values))
    if abs_max == 0:
        # With vmin == vmax matplotlib maps every input to 0, which would
        # paint all-zero data with the lowest color of the colormap. Fall back
        # to a unit range so zero still maps to the neutral midpoint.
        abs_max = 1.0
    return colors.Normalize(vmin=-abs_max, vmax=abs_max)


def color_canvas(canvas: Draw.MolDraw2D, color_grid: ColorGrid) -> None:
"""Draw a ColorGrid object to a RDKit Draw.MolDraw2D canvas.
Expand Down
Loading

0 comments on commit 8e0cc7b

Please sign in to comment.