diff --git a/molpipeline/explainability/explainer.py b/molpipeline/explainability/explainer.py index 59a254c5..68e8ae39 100644 --- a/molpipeline/explainability/explainer.py +++ b/molpipeline/explainability/explainer.py @@ -12,7 +12,7 @@ from molpipeline import Pipeline from molpipeline.abstract_pipeline_elements.core import OptionalMol -from molpipeline.explainability.explanation import Explanation +from molpipeline.explainability.explanation import Explanation, SHAPExplanation from molpipeline.explainability.fingerprint_utils import fingerprint_shap_to_atomweights from molpipeline.mol2any import MolToMorganFP from molpipeline.utils.subpipeline import SubpipelineExtractor @@ -213,7 +213,7 @@ def _prediction_is_valid(self, prediction: Any) -> bool: return True # pylint: disable=C0103,W0613 - def explain(self, X: Any, **kwargs: Any) -> list[Explanation]: + def explain(self, X: Any, **kwargs: Any) -> list[SHAPExplanation]: """Explain the predictions for the input data. If the calculation of the SHAP values for an input sample fails, the explanation will be invalid. @@ -278,7 +278,7 @@ def explain(self, X: Any, **kwargs: Any) -> list[Explanation]: ) explanation_results.append( - Explanation( + SHAPExplanation( feature_vector=feature_vector, feature_names=feature_names, molecule=molecule, @@ -286,6 +286,7 @@ def explain(self, X: Any, **kwargs: Any) -> list[Explanation]: feature_weights=feature_weights, atom_weights=atom_weights, bond_weights=bond_weights, + expected_value=self.explainer.expected_value, ) ) diff --git a/molpipeline/explainability/explanation.py b/molpipeline/explainability/explanation.py index 054cd904..45a44fc2 100644 --- a/molpipeline/explainability/explanation.py +++ b/molpipeline/explainability/explanation.py @@ -10,7 +10,7 @@ from molpipeline.abstract_pipeline_elements.core import RDKitMol -@dataclasses.dataclass() +@dataclasses.dataclass(kw_only=True) class Explanation: """Class representing explanations of a prediction.""" @@ -50,3 +50,13 @@ def is_valid(self) -> bool: ), ] ) + + +@dataclasses.dataclass(kw_only=True) +class SHAPExplanation(Explanation): + """Class representing SHAP explanations of a prediction. + + This Explanation holds additional information only present in SHAP explanations. + """ + + expected_value: npt.NDArray[np.float64] diff --git a/molpipeline/explainability/visualization/heatmaps.py b/molpipeline/explainability/visualization/heatmaps.py index e1a3eca6..a1fda193 100644 --- a/molpipeline/explainability/visualization/heatmaps.py +++ b/molpipeline/explainability/visualization/heatmaps.py @@ -209,14 +209,14 @@ def evaluate(self) -> None: def map2color( self, - c_map: colors.Colormap | str, - v_lim: Sequence[float] | None = None, + c_map: colors.Colormap, + normalizer: colors.Normalize, ) -> ColorGrid: """Generate a ColorGrid from `self.values` according to given colormap. Parameters ---------- - c_map: Union[colors.Colormap, str] + c_map: colors.Colormap Colormap to be used for mapping values to colors. v_lim: Optional[Tuple[float, float]] Limits for the colormap. If not given, the maximum absolute value of `self.values` is used as limit. @@ -227,17 +227,32 @@ def map2color( ColorGrid with colors corresponding to ValueGrid. """ color_grid = ColorGrid(self.x_lim, self.y_lim, self.x_res, self.y_res) - if not v_lim: - abs_max = np.max(np.abs(self.values)) - v_lim = -abs_max, abs_max - normalizer = colors.Normalize(vmin=v_lim[0], vmax=v_lim[1]) - if isinstance(c_map, str): - c_map = cm.get_cmap(c_map) norm = normalizer(self.values) color_grid.color_grid = np.array(c_map(norm)) return color_grid +def get_color_normalizer_from_data( + values: npt.NDArray[np.float64], +) -> colors.Normalize: + """Create a color normalizer based on the data distribution of 'values'. + + Parameters + ---------- + values: npt.NDArray[np.float64] + Data to derive limits of normalizer. The maximum absolute value of + values` is used as limit. + + Returns + ------- + colors.Normalize + Normalizer for colors. + """ + abs_max = np.max(np.abs(values)) + normalizer = colors.Normalize(vmin=-abs_max, vmax=abs_max) + return normalizer + + def color_canvas(canvas: Draw.MolDraw2D, color_grid: ColorGrid) -> None: """Draw a ColorGrid object to a RDKit Draw.MolDraw2D canvas. diff --git a/molpipeline/explainability/visualization/visualization.py b/molpipeline/explainability/visualization/visualization.py index 0bbd8033..86baeb09 100644 --- a/molpipeline/explainability/visualization/visualization.py +++ b/molpipeline/explainability/visualization/visualization.py @@ -13,16 +13,20 @@ import numpy as np import numpy.typing as npt from PIL import Image +from matplotlib import pyplot as plt, colors from matplotlib.colors import Colormap, ListedColormap +from matplotlib.pyplot import get_cmap from rdkit import Chem from rdkit.Chem import Draw from rdkit.Chem.Draw import rdMolDraw2D from molpipeline.abstract_pipeline_elements.core import RDKitMol +from molpipeline.explainability.explanation import SHAPExplanation from molpipeline.explainability.visualization.gauss import GaussFunctor2D from molpipeline.explainability.visualization.heatmaps import ( color_canvas, ValueGrid, + get_color_normalizer_from_data, ) RGBAtuple = tuple[float, float, float, float] @@ -113,9 +117,8 @@ def color_tuple_to_colormap( return newcmp -def _make_grid( +def _make_grid_from_mol( mol: Chem.Mol, - canvas: rdMolDraw2D.MolDraw2D, grid_resolution: Sequence[int], padding: Sequence[float], ) -> ValueGrid: @@ -125,8 +128,6 @@ def _make_grid( ---------- mol: Chem.Mol RDKit molecule object. - canvas: rdMolDraw2D.MolDraw2D - RDKit canvas. grid_resolution: Sequence[int] Resolution of the grid. padding: Sequence[float] @@ -147,14 +148,15 @@ def _make_grid( mol_width = xl[1] - xl[0] height_to_width_ratio_mol = mol_height / mol_width - height_to_width_ratio_canvas = canvas.Height() / canvas.Width() + # the grids height / weight is the canvas height / width + height_to_width_ratio_canvas = grid_resolution[1] / grid_resolution[0] if height_to_width_ratio_mol < height_to_width_ratio_canvas: - mol_height_new = canvas.Height() / canvas.Width() * mol_width + mol_height_new = height_to_width_ratio_canvas * mol_width yl[0] -= (mol_height_new - mol_height) / 2 yl[1] += (mol_height_new - mol_height) / 2 else: - mol_width_new = canvas.Width() / canvas.Height() * mol_height + mol_width_new = grid_resolution[0] / grid_resolution[1] * mol_height xl[0] -= (mol_width_new - mol_width) / 2 xl[1] += (mol_width_new - mol_width) / 2 @@ -266,17 +268,14 @@ def _add_gaussians_for_bonds( return v_map -def mapvalues2mol( +def make_sum_of_gaussians_grid( mol: Chem.Mol, + grid_resolution: Sequence[int], atom_weights: Sequence[float] | npt.NDArray[np.float64] | None = None, bond_weights: Sequence[float] | npt.NDArray[np.float64] | None = None, atom_width: float = 0.3, bond_width: float = 0.25, bond_length: float = 0.5, - canvas: rdMolDraw2D.MolDraw2D | None = None, - grid_resolution: Sequence[int] | None = None, - value_lims: Sequence[float] | None = None, - color: str | Colormap = "bwr", padding: Sequence[float] | None = None, ) -> rdMolDraw2D: """Map weights of atoms and bonds to the drawing of a RDKit molecular depiction. @@ -302,14 +301,8 @@ def mapvalues2mol( Value for the width of displayed bond weights (perpendicular to bond-axis). bond_length: float Value for the length of displayed bond weights (along the bond-axis). - canvas: rdMolDraw2D.MolDraw2D | None - RDKit canvas the molecule and heatmap are drawn onto. grid_resolution: Sequence[int] | None Number of pixels of x- and y-axis. - value_lims: Sequence[float] | None - Lower and upper limit of displayed values. Values exceeding limit are displayed as maximum (or minimum) value. - color: str | Colormap - Matplotlib colormap or string referring to a matplotlib colormap padding: Sequence[float] | None Increase of heatmap size, relative to size of molecule. Usually the heatmap is increased by 100% in each axis by padding 50% in each side. @@ -330,21 +323,6 @@ def mapvalues2mol( elif not isinstance(bond_weights, np.ndarray): bond_weights = np.array(bond_weights) - if not canvas: - canvas = rdMolDraw2D.MolDraw2DCairo(800, 450) - draw_opt = canvas.drawOptions() - draw_opt.padding = 0.2 - draw_opt.bondLineWidth = 3 - canvas.SetDrawOptions(draw_opt) - - if grid_resolution is None: - grid_resolution = [canvas.Width(), canvas.Height()] - - if padding is None: - # take padding from DrawOptions - draw_opt = canvas.drawOptions() - padding = [draw_opt.padding * 2, draw_opt.padding * 2] - # validate input if not len(atom_weights) == len(mol.GetAtoms()): raise ValueError("len(atom_weights) is not equal to number of bonds in mol") @@ -356,33 +334,50 @@ def mapvalues2mol( conf = mol.GetConformer(0) # setup grid and add functions for atoms and bonds - v_map = _make_grid(mol, canvas, grid_resolution, padding) - v_map = _add_gaussians_for_atoms(mol, conf, v_map, atom_weights, atom_width) - v_map = _add_gaussians_for_bonds( - mol, conf, v_map, bond_weights, bond_width, bond_length + value_grid = _make_grid_from_mol(mol, grid_resolution, padding) + value_grid = _add_gaussians_for_atoms( + mol, conf, value_grid, atom_weights, atom_width + ) + value_grid = _add_gaussians_for_bonds( + mol, conf, value_grid, bond_weights, bond_width, bond_length ) # evaluate all functions at pixel positions to obtain pixel values - v_map.evaluate() + value_grid.evaluate() - # create color-grid from the value grid. - c_grid = v_map.map2color(color, v_lim=value_lims) - # draw the molecule and erase it to initialize the grid - canvas.DrawMolecule(mol) - canvas.ClearDrawing() - # add the Colormap to the canvas - color_canvas(canvas, c_grid) - # add the molecule to the canvas - canvas.DrawMolecule(mol) - return canvas + return value_grid -def structure_heatmap( +def get_color_map_from_input( + color: str | Colormap | tuple[RGBAtuple, RGBAtuple, RGBAtuple] | None +) -> Colormap: + # read user definer color scheme as ColorMap + if color is None: + coolwarm = ( + (0.017, 0.50, 0.850, 0.5), + (1.0, 1.0, 1.0, 0.5), + (1.0, 0.25, 0.0, 0.5), + ) + coolwarm = (coolwarm[2], coolwarm[1], coolwarm[0]) + color = coolwarm + if isinstance(color, Colormap): + color_map = color + elif isinstance(color, tuple): + color_map = color_tuple_to_colormap(color) # type: ignore + elif isinstance(color, str): + color_map = get_cmap(color) + else: + raise ValueError("Color must be a tuple, string or ColorMap.") + return color_map + + +def _structure_heatmap( mol: RDKitMol, atom_weights: npt.NDArray[np.float64], - color_tuple: tuple[RGBAtuple, RGBAtuple, RGBAtuple] | None = None, + color: str | Colormap | tuple[RGBAtuple, RGBAtuple, RGBAtuple] | None = None, width: int = 600, height: int = 600, + color_limits: tuple[float, float] | None = None, ) -> Draw.MolDraw2D: """Create a Gaussian plot on the molecular structure, highlight atoms with weighted Gaussians. @@ -392,8 +387,8 @@ def structure_heatmap( The molecule. atom_weights: npt.NDArray[np.float64] The atom weights. - color_tuple: Tuple[RGBAtuple, RGBAtuple, RGBAtuple] - The color tuple. + color: str | Colormap | tuple[RGBAtuple, RGBAtuple, RGBAtuple] | None + The color map. width: int The width of the image in number of pixels. height: int @@ -410,36 +405,108 @@ def structure_heatmap( draw_opt = drawer.drawOptions() draw_opt.padding = 0.2 - if color_tuple is None: - coolwarm = ( - (0.017, 0.50, 0.850, 0.5), - (1.0, 1.0, 1.0, 0.5), - (1.0, 0.25, 0.0, 0.5), - ) - color_tuple = coolwarm - - color_map = color_tuple_to_colormap(color_tuple) + color_map = get_color_map_from_input(color) + # create the sums of gaussians value grid mol_copy = Chem.Mol(mol) mol_copy = Draw.PrepareMolForDrawing(mol_copy) - mapvalues2mol( + value_grid = make_sum_of_gaussians_grid( mol_copy, atom_weights=atom_weights, bond_weights=None, atom_width=0.5, # 0.4 bond_width=0.25, bond_length=0.5, - canvas=drawer, - grid_resolution=None, - value_lims=None, - color=color_map, - padding=None, + grid_resolution=[drawer.Width(), drawer.Height()], + padding=[draw_opt.padding * 2, draw_opt.padding * 2], ) + # create color-grid from the value grid. + if color_limits is None: + normalizer = get_color_normalizer_from_data(value_grid.values) + else: + normalizer = colors.Normalize(vmin=color_limits[0], vmax=color_limits[1]) + color_grid = value_grid.map2color(color_map, normalizer=normalizer) + + # draw the molecule and erase it to initialize the grid + drawer.DrawMolecule(mol) + drawer.ClearDrawing() + # add the Colormap to the canvas + color_canvas(drawer, color_grid) + # add the molecule to the canvas + drawer.DrawMolecule(mol) + drawer.FinishDrawing() + return drawer, value_grid, color_grid, normalizer, color_map + + +def structure_heatmap( + mol: RDKitMol, + atom_weights: npt.NDArray[np.float64], + color: str | Colormap | tuple[RGBAtuple, RGBAtuple, RGBAtuple] | None = None, + width: int = 600, + height: int = 600, + color_limits: tuple[float, float] | None = None, +) -> Draw.MolDraw2D: + drawer, *_ = _structure_heatmap( + mol, atom_weights, color, width, height, color_limits + ) return drawer +def structure_heatmap_shap_explanation( + explanation: SHAPExplanation, + color: str | Colormap | tuple[RGBAtuple, RGBAtuple, RGBAtuple] | None = None, + width: int = 600, + height: int = 600, + color_limits: tuple[float, float] | None = None, +) -> Draw.MolDraw2D: + # TODO this should only work if the feature vector is binary. Maybe raise an error otherwise? Or do something else? + present_shap = explanation.feature_weights[:, 1] * explanation.feature_vector + absent_shap = explanation.feature_weights[:, 1] * (1 - explanation.feature_vector) + sum_present_shap = sum(present_shap) + sum_absent_shap = sum(absent_shap) + + drawer, value_grid, color_grid, normalizer, color_map = _structure_heatmap( + explanation.molecule, + explanation.atom_weights, + color=color, + width=width, + height=height, + color_limits=color_limits, + ) + figure_bytes = drawer.GetDrawingText() + image = show_png(figure_bytes) + image_array = np.array(image) + + fig, ax = plt.subplots(figsize=(8, 8)) + + im = ax.imshow( + image_array, + cmap=color_map, + norm=normalizer, + ) + # remove ticks + ax.set_xticks([]) + ax.set_yticks([]) + # remove border + for spine in ax.spines.values(): + spine.set_visible(False) + + fig.colorbar(im, ax=ax, orientation="vertical", fraction=0.015, pad=0.0) + + text = ( + f"$P(y=1|X) = {explanation.prediction[1]:.2f}$ =" + "\n" + "\n" + f" $expected \ value={explanation.expected_value[1]:.2f}$ + " + f"$features_{{present}}= {sum_present_shap:.2f}$ + " + f"$features_{{absent}}={sum_absent_shap:.2f}$" + ) + fig.text(0.5, 0.18, text, ha="center") + return fig + + def show_png(data: bytes) -> Image.Image: """Show a PNG image from a byte stream. diff --git a/tests/test_explainability/test_visualization.py b/tests/test_explainability/test_visualization.py deleted file mode 100644 index 7a07ddbf..00000000 --- a/tests/test_explainability/test_visualization.py +++ /dev/null @@ -1,60 +0,0 @@ -"""Test visualization methods for explanations.""" - -import unittest - -import numpy as np -from sklearn.ensemble import RandomForestClassifier - -from molpipeline import Pipeline -from molpipeline.any2mol import SmilesToMol -from molpipeline.explainability import SHAPTreeExplainer -from molpipeline.explainability.visualization.visualization import ( - structure_heatmap, - show_png, -) -from molpipeline.mol2any import MolToMorganFP - -TEST_SMILES = ["CC", "CCO", "COC", "c1ccccc1(N)", "CCC(-O)O", "CCCN"] -CONTAINS_OX = [0, 1, 1, 0, 1, 0] - -_RANDOM_STATE = 67056 - - -class TestExplainabilityVisualization(unittest.TestCase): - """Test visualization methods for explanations.""" - - def test_test_fingerprint_based_atom_coloring(self) -> None: - """Test fingerprint-based atom coloring.""" - - pipeline = Pipeline( - [ - ("smi2mol", SmilesToMol()), - ("morgan", MolToMorganFP(radius=1, n_bits=1024)), - ( - "model", - RandomForestClassifier(n_estimators=2, random_state=_RANDOM_STATE), - ), - ] - ) - pipeline.fit(TEST_SMILES, CONTAINS_OX) - - explainer = SHAPTreeExplainer(pipeline) - explanations = explainer.explain(TEST_SMILES) - - for explanation in explanations: - self.assertTrue(explanation.is_valid()) - self.assertIsInstance(explanation.atom_weights, np.ndarray) - drawer = structure_heatmap( - explanation.molecule, - explanation.atom_weights, # type: ignore[arg-type] - width=128, - height=128, - ) # type: ignore[union-attr] - - self.assertIsNotNone(drawer) - - figure_bytes = drawer.GetDrawingText() - - image = show_png(figure_bytes) - - self.assertEqual(image.format, "PNG") diff --git a/tests/test_explainability/test_visualization/__init__.py b/tests/test_explainability/test_visualization/__init__.py new file mode 100644 index 00000000..5dd7b293 --- /dev/null +++ b/tests/test_explainability/test_visualization/__init__.py @@ -0,0 +1 @@ +"""Test explainability visualization.""" diff --git a/tests/test_explainability/test_visualization/test_visualization.py b/tests/test_explainability/test_visualization/test_visualization.py new file mode 100644 index 00000000..939197bb --- /dev/null +++ b/tests/test_explainability/test_visualization/test_visualization.py @@ -0,0 +1,118 @@ +"""Test visualization methods for explanations.""" + +import unittest + +import numpy as np +from rdkit import Chem +from rdkit.Chem import Draw +from sklearn.ensemble import RandomForestClassifier + +from molpipeline import Pipeline +from molpipeline.any2mol import SmilesToMol +from molpipeline.explainability import SHAPTreeExplainer, Explanation +from molpipeline.explainability.visualization.visualization import ( + structure_heatmap, + show_png, + make_sum_of_gaussians_grid, +) +from molpipeline.mol2any import MolToMorganFP + +TEST_SMILES = ["CC", "CCO", "COC", "c1ccccc1(N)", "CCC(-O)O", "CCCN"] +CONTAINS_OX = [0, 1, 1, 0, 1, 0] + +_RANDOM_STATE = 67056 + + +def _get_test_explanations() -> list[Explanation]: + """Get test explanations.""" + pipeline = Pipeline( + [ + ("smi2mol", SmilesToMol()), + ("morgan", MolToMorganFP(radius=1, n_bits=1024)), + ( + "model", + RandomForestClassifier(n_estimators=2, random_state=_RANDOM_STATE), + ), + ] + ) + pipeline.fit(TEST_SMILES, CONTAINS_OX) + + explainer = SHAPTreeExplainer(pipeline) + explanations = explainer.explain(TEST_SMILES) + return explanations + + +class TestExplainabilityVisualization(unittest.TestCase): + """Test the public interface of the visualization methods for explanations.""" + + @classmethod + def setUpClass(cls) -> None: + """Set up the tests.""" + cls.explanations = _get_test_explanations() + + def test_fingerprint_based_atom_coloring(self) -> None: + """Test fingerprint-based atom coloring.""" + + for explanation in self.explanations: + self.assertTrue(explanation.is_valid()) + self.assertIsInstance(explanation.atom_weights, np.ndarray) + drawer = structure_heatmap( + explanation.molecule, + explanation.atom_weights, # type: ignore[arg-type] + width=128, + height=128, + ) # type: ignore[union-attr] + self.assertIsNotNone(drawer) + figure_bytes = drawer.GetDrawingText() + image = show_png(figure_bytes) + self.assertEqual(image.format, "PNG") + + +class TestSumOfGaussiansGrid(unittest.TestCase): + """Test visualization methods for explanations.""" + + @classmethod + def setUpClass(cls) -> None: + """Set up the tests.""" + cls.explanations = _get_test_explanations() + + def test_grid_with_shap_atom_weights(self) -> None: + """Test grid with SHAP atom weights.""" + + for explanation in self.explanations: + self.assertTrue(explanation.is_valid()) + self.assertIsInstance(explanation.atom_weights, np.ndarray) + + mol_copy = Chem.Mol(explanation.molecule) + mol_copy = Draw.PrepareMolForDrawing(mol_copy) + value_grid = make_sum_of_gaussians_grid( + mol_copy, + atom_weights=explanation.atom_weights, + atom_width=np.inf, + grid_resolution=[64, 64], + padding=[0.4, 0.4], + ) + self.assertIsNotNone(value_grid) + self.assertEqual(value_grid.values.size, 64 * 64) + + # test that the range of summed gaussian values is as expected for SHAP + self.assertTrue(value_grid.values.min() >= -1) + self.assertTrue(value_grid.values.max() <= 1) + + # def test_color_limits(self) -> None: + # """Test color limits.""" + # + # for explanation in self.explanations: + # self.assertTrue(explanation.is_valid()) + # self.assertIsInstance(explanation.atom_weights, np.ndarray) + # drawer = structure_heatmap( + # explanation.molecule, + # explanation.atom_weights, # type: ignore[arg-type] + # width=128, + # height=128, + # color_limits=(-1, 1), + # ) + # self.assertIsNotNone(drawer) + # figure_bytes = drawer.GetDrawingText() + # image = show_png(figure_bytes) + # self.assertEqual(image.format, "PNG")