Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Nov 20, 2024
1 parent adaa3e3 commit a0ed00e
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 48 deletions.
7 changes: 5 additions & 2 deletions molpipeline/explainability/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class AbstractSHAPExplainer(abc.ABC):
@abc.abstractmethod
def explain(
self, X: Any, **kwargs: Any
) -> list[SHAPFeatureExplanation, SHAPFeatureAndAtomExplanation]:
) -> list[SHAPFeatureExplanation] | list[SHAPFeatureAndAtomExplanation]:
"""Explain the predictions for the input data.
Parameters
Expand Down Expand Up @@ -166,6 +166,8 @@ class SHAPTreeExplainer(AbstractSHAPExplainer):
None if these failed instances should not be explained.
"""

return_type: type[SHAPFeatureExplanation] | type[SHAPFeatureAndAtomExplanation]

def __init__(self, pipeline: Pipeline, **kwargs: Any) -> None:
"""Initialize the SHAPTreeExplainer.
Expand Down Expand Up @@ -204,6 +206,7 @@ def __init__(self, pipeline: Pipeline, **kwargs: Any) -> None:
if self.featurization_subpipeline is None:
raise ValueError("Could not determine the featurization subpipeline.")

# determine type of returned explanation
featurization_element = self.featurization_subpipeline.steps[-1][1] # type: ignore[union-attr]
if isinstance(featurization_element, MolToMorganFP):
self.return_type = SHAPFeatureAndAtomExplanation
Expand Down Expand Up @@ -238,7 +241,7 @@ def _prediction_is_valid(self, prediction: Any) -> bool:
# pylint: disable=C0103,W0613
def explain(
self, X: Any, **kwargs: Any
) -> list[SHAPFeatureExplanation, SHAPFeatureAndAtomExplanation]:
) -> list[SHAPFeatureExplanation] | list[SHAPFeatureAndAtomExplanation]:
"""Explain the predictions for the input data.
If the calculation of the SHAP values for an input sample fails, the explanation will be invalid.
Expand Down
5 changes: 2 additions & 3 deletions molpipeline/explainability/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class SHAPFeatureExplanation(
SHAPExplanationMixin,
_AbstractMoleculeExplanation, # base-class should be the last element https://www.ianlewis.org/en/mixins-and-python
):
"""Explanation of a molecular prediction using feature importance scores and SHAP."""
"""Explanation using feature importance scores from SHAP."""

def is_valid(self) -> bool:
"""Check if the explanation is valid.
Expand Down Expand Up @@ -94,8 +94,7 @@ class SHAPFeatureAndAtomExplanation(
AtomExplanationMixin,
_AbstractMoleculeExplanation,
):
"""Explanation of a molecular prediction using feature importance scores,
atom importance scores and SHAP."""
"""Explanation using feature and atom importance scores from SHAP."""

def is_valid(self) -> bool:
"""Check if the explanation is valid.
Expand Down
64 changes: 64 additions & 0 deletions tests/test_explainability/test_visualization/test_gaussian_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Test gaussian grid visualization."""

import unittest
from typing import ClassVar

import numpy as np
from rdkit import Chem
from rdkit.Chem import Draw

from molpipeline import Pipeline
from molpipeline.explainability import (
SHAPFeatureAndAtomExplanation,
SHAPFeatureExplanation,
SHAPTreeExplainer,
)
from molpipeline.explainability.visualization.visualization import (
make_sum_of_gaussians_grid,
)
from tests.test_explainability.test_visualization.test_visualization import (
_get_test_morgan_rf_pipeline,
)

# Small set of SMILES strings used to fit the test pipeline and to generate explanations.
TEST_SMILES = ["CC", "CCO", "COC", "c1ccccc1(N)", "CCC(-O)O", "CCCN"]
# Binary target aligned with TEST_SMILES: 1 if the SMILES contains an oxygen atom, else 0.
CONTAINS_OX = [0, 1, 1, 0, 1, 0]


class TestSumOfGaussiansGrid(unittest.TestCase):
    """Test the sum-of-gaussians grid computed from SHAP atom weights."""

    test_pipeline: ClassVar[Pipeline]
    test_explainer: ClassVar[SHAPTreeExplainer]
    test_explanations: ClassVar[
        list[SHAPFeatureAndAtomExplanation] | list[SHAPFeatureExplanation]
    ]

    @classmethod
    def setUpClass(cls) -> None:
        """Fit the test pipeline once and explain all test molecules.

        The fitted pipeline, the explainer and the explanations are stored as
        class attributes so they are shared by all test methods.
        """
        cls.test_pipeline = _get_test_morgan_rf_pipeline()
        cls.test_pipeline.fit(TEST_SMILES, CONTAINS_OX)
        cls.test_explainer = SHAPTreeExplainer(cls.test_pipeline)
        cls.test_explanations = cls.test_explainer.explain(TEST_SMILES)

    def test_grid_with_shap_atom_weights(self) -> None:
        """Test grid with SHAP atom weights.

        For every explanation a 64 x 64 grid is built and its values are
        checked to lie within [-1, 1].
        """
        for index, explanation in enumerate(self.test_explanations):
            # subTest keeps checking the remaining molecules after a failure
            # and reports which explanation was the offending one.
            with self.subTest(explanation_index=index):
                self.assertTrue(explanation.is_valid())
                self.assertIsInstance(explanation.atom_weights, np.ndarray)

                # Work on a copy so the molecule stored in the explanation is
                # not replaced by the drawing-prepared version.
                mol_copy = Chem.Mol(explanation.molecule)
                mol_copy = Draw.PrepareMolForDrawing(mol_copy)
                value_grid = make_sum_of_gaussians_grid(
                    mol_copy,
                    atom_weights=explanation.atom_weights,
                    atom_width=np.inf,
                    grid_resolution=[64, 64],
                    padding=[0.4, 0.4],
                )
                self.assertIsNotNone(value_grid)
                self.assertEqual(value_grid.values.size, 64 * 64)

                # test that the range of summed gaussian values is as expected for SHAP
                # (dedicated comparison asserts print the offending value on failure)
                self.assertGreaterEqual(value_grid.values.min(), -1)
                self.assertLessEqual(value_grid.values.max(), 1)
43 changes: 0 additions & 43 deletions tests/test_explainability/test_visualization/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@
structure_heatmap,
structure_heatmap_shap,
)
from molpipeline.explainability.visualization.visualization import (
make_sum_of_gaussians_grid,
)
from molpipeline.mol2any import MolToMorganFP

TEST_SMILES = ["CC", "CCO", "COC", "c1ccccc1(N)", "CCC(-O)O", "CCCN"]
Expand Down Expand Up @@ -131,43 +128,3 @@ def test_explicit_hydrogens(self) -> None:
) # type: ignore[union-attr]
self.assertIsNotNone(image)
self.assertEqual(image.format, "PNG")


class TestSumOfGaussiansGrid(unittest.TestCase):
    """Test the sum-of-gaussians grid computed from SHAP atom weights."""

    test_pipeline: ClassVar[Pipeline]
    test_explainer: ClassVar[SHAPTreeExplainer]
    test_explanations: ClassVar[
        list[SHAPFeatureAndAtomExplanation] | list[SHAPFeatureExplanation]
    ]

    @classmethod
    def setUpClass(cls) -> None:
        """Fit the test pipeline once and explain all test molecules.

        The fitted pipeline, the explainer and the explanations are stored as
        class attributes so they are shared by all test methods.
        """
        cls.test_pipeline = _get_test_morgan_rf_pipeline()
        cls.test_pipeline.fit(TEST_SMILES, CONTAINS_OX)
        cls.test_explainer = SHAPTreeExplainer(cls.test_pipeline)
        cls.test_explanations = cls.test_explainer.explain(TEST_SMILES)

    def test_grid_with_shap_atom_weights(self) -> None:
        """Test grid with SHAP atom weights.

        For every explanation a 64 x 64 grid is built and its values are
        checked to lie within [-1, 1].
        """
        for index, explanation in enumerate(self.test_explanations):
            # subTest keeps checking the remaining molecules after a failure
            # and reports which explanation was the offending one.
            with self.subTest(explanation_index=index):
                self.assertTrue(explanation.is_valid())
                self.assertIsInstance(explanation.atom_weights, np.ndarray)

                # Work on a copy so the molecule stored in the explanation is
                # not replaced by the drawing-prepared version.
                mol_copy = Chem.Mol(explanation.molecule)
                mol_copy = Draw.PrepareMolForDrawing(mol_copy)
                value_grid = make_sum_of_gaussians_grid(
                    mol_copy,
                    atom_weights=explanation.atom_weights,
                    atom_width=np.inf,
                    grid_resolution=[64, 64],
                    padding=[0.4, 0.4],
                )
                self.assertIsNotNone(value_grid)
                self.assertEqual(value_grid.values.size, 64 * 64)

                # test that the range of summed gaussian values is as expected for SHAP
                # (dedicated comparison asserts print the offending value on failure)
                self.assertGreaterEqual(value_grid.values.min(), -1)
                self.assertLessEqual(value_grid.values.max(), 1)

0 comments on commit a0ed00e

Please sign in to comment.