diff --git a/molpipeline/explainability/__init__.py b/molpipeline/explainability/__init__.py
new file mode 100644
index 00000000..df450bd4
--- /dev/null
+++ b/molpipeline/explainability/__init__.py
@@ -0,0 +1,6 @@
+"""Explainability module for the molpipeline package."""
+
+from molpipeline.explainability.explainer import SHAPTreeExplainer
+from molpipeline.explainability.explanation import Explanation
+
+__all__ = ["Explanation", "SHAPTreeExplainer"]
diff --git a/molpipeline/explainability/explainer.py b/molpipeline/explainability/explainer.py
new file mode 100644
index 00000000..6f5d0a97
--- /dev/null
+++ b/molpipeline/explainability/explainer.py
@@ -0,0 +1,292 @@
+"""Explainer classes for explaining predictions."""
+
+from __future__ import annotations
+
+import abc
+from typing import Any
+
+import numpy as np
+import numpy.typing as npt
+import shap
+from scipy.sparse import issparse, spmatrix
+
+from molpipeline import Pipeline
+from molpipeline.abstract_pipeline_elements.core import OptionalMol
+from molpipeline.explainability.explanation import Explanation
+from molpipeline.explainability.fingerprint_utils import fingerprint_shap_to_atomweights
+from molpipeline.mol2any import MolToMorganFP
+from molpipeline.utils.subpipeline import SubpipelineExtractor
+
+
+# pylint: disable=C0103,W0613
+def _to_dense(
+    feature_matrix: npt.NDArray[Any] | spmatrix,
+) -> npt.NDArray[Any]:
+    """Convert a possibly sparse feature matrix to a dense array, as required by SHAP.
+
+    Parameters
+    ----------
+    feature_matrix : npt.NDArray[Any] | spmatrix
+        The input features.
+
+    Returns
+    -------
+    npt.NDArray[Any]
+        The input features as a dense array.
+    """
+    if issparse(feature_matrix):
+        return feature_matrix.todense()  # type: ignore[union-attr]
+    return feature_matrix
+
+
+# This function might also be put at a more central position in the lib.
+def _get_predictions(
+    pipeline: Pipeline, feature_matrix: npt.NDArray[Any] | spmatrix
+) -> npt.NDArray[np.float_]:
+    """Get the predictions of a model.
+
+    Parameters
+    ----------
+    pipeline : Pipeline
+        The pipeline containing the model.
+    feature_matrix : npt.NDArray[Any] | spmatrix
+        The input data.
+
+    Returns
+    -------
+    npt.NDArray[np.float_]
+        The predictions.
+
+    Raises
+    ------
+    ValueError
+        If the pipeline exposes none of predict_proba, decision_function, or predict.
+    """
+    if hasattr(pipeline, "predict_proba"):
+        return pipeline.predict_proba(feature_matrix)
+    if hasattr(pipeline, "decision_function"):
+        return pipeline.decision_function(feature_matrix)
+    if hasattr(pipeline, "predict"):
+        return pipeline.predict(feature_matrix)
+    raise ValueError("Could not determine the model output predictions.")
+
+
+def _convert_shap_feature_weights_to_atom_weights(
+    feature_weights: npt.NDArray[np.float_],
+    molecule: OptionalMol,
+    featurization_element: MolToMorganFP,
+    feature_vector: npt.NDArray[np.float_],
+) -> npt.NDArray[np.float_]:
+    """Convert SHAP feature weights to atom weights.
+
+    Parameters
+    ----------
+    feature_weights : npt.NDArray[np.float_]
+        The feature weights.
+    molecule : OptionalMol
+        The molecule.
+    featurization_element : MolToMorganFP
+        The featurization element.
+    feature_vector : npt.NDArray[np.float_]
+        The feature vector.
+
+    Returns
+    -------
+    npt.NDArray[np.float_]
+        The atom weights.
+    """
+    if feature_weights.ndim == 1:
+        # regression case
+        feature_weights_present_bits_only = feature_weights.copy()
+    elif feature_weights.ndim == 2:
+        # binary classification case. Take the weights for the positive class.
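+        # SHAP yields one weight per class for classifiers, i.e. shape
+        # (n_features, n_classes); column 1 corresponds to the positive class.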
+        feature_weights_present_bits_only = feature_weights[:, 1].copy()
+    else:
+        raise ValueError(
+            "Unsupported number of dimensions for feature weights. Expected 1 or 2."
+        )
+
+    # reset shap values for bits that are not present in the molecule
+    feature_weights_present_bits_only[feature_vector == 0] = 0
+
+    atom_weights = np.array(
+        fingerprint_shap_to_atomweights(
+            molecule,
+            featurization_element,
+            feature_weights_present_bits_only,
+        )
+    )
+    return atom_weights
+
+
+# pylint: disable=R0903
+class AbstractExplainer(abc.ABC):
+    """Abstract class for explainer objects."""
+
+    # pylint: disable=C0103,W0613
+    @abc.abstractmethod
+    def explain(self, X: Any, **kwargs: Any) -> list[Explanation]:
+        """Explain the predictions for the input data.
+
+        Parameters
+        ----------
+        X : Any
+            The input data to explain.
+        kwargs : Any
+            Additional keyword arguments.
+
+        Returns
+        -------
+        list[Explanation]
+            List of explanations corresponding to the input samples.
+        """
+
+
+# pylint: disable=R0903
+class SHAPTreeExplainer(AbstractExplainer):
+    """Wrapper around SHAP's TreeExplainer for MolPipeline pipelines."""
+
+    def __init__(self, pipeline: Pipeline, **kwargs: Any) -> None:
+        """Initialize the SHAPTreeExplainer.
+
+        Parameters
+        ----------
+        pipeline : Pipeline
+            The pipeline containing the model to explain.
+        kwargs : Any
+            Additional keyword arguments for SHAP's TreeExplainer.
+        """
+        self.pipeline = pipeline
+        pipeline_extractor = SubpipelineExtractor(self.pipeline)
+
+        # extract the fitted model
+        model = pipeline_extractor.get_model_element()
+        if model is None:
+            raise ValueError("Could not determine the model to explain.")
+
+        # set up the actual explainer
+        self.explainer = shap.TreeExplainer(
+            model,
+            **kwargs,
+        )
+
+        # extract the molecule reader subpipeline
+        self.molecule_reader_subpipeline = (
+            pipeline_extractor.get_molecule_reader_subpipeline()
+        )
+        if self.molecule_reader_subpipeline is None:
+            raise ValueError("Could not determine the molecule reader subpipeline.")
+
+        # extract the featurization subpipeline
+        self.featurization_subpipeline = (
+            pipeline_extractor.get_featurization_subpipeline()
+        )
+        if self.featurization_subpipeline is None:
+            raise ValueError("Could not determine the featurization subpipeline.")
+
+        # extract fill values for checking error handling
+        self.fill_values = pipeline_extractor.get_all_filter_reinserter_fill_values()
+        self.fill_values_contain_nan = np.isnan(self.fill_values).any()
+
+    def _prediction_is_valid(self, prediction: Any) -> bool:
+        """Check if the prediction is valid using some heuristics.
+
+        Can be used to catch inputs that failed the pipeline for some reason.
+
+        Parameters
+        ----------
+        prediction : Any
+            The prediction.
+
+        Returns
+        -------
+        bool
+            Whether the prediction is valid.
+        """
+        # if no prediction could be obtained (length is 0), the pipeline failed for this input
+        if len(prediction) == 0:
+            return False
+
+        # if a value in the prediction is a fill value, we assume the pipeline failed for this input
+        if np.isin(prediction, self.fill_values).any():
+            return False
+        if self.fill_values_contain_nan and np.isnan(prediction).any():
+            # the extra NaN check is necessary because np.isin does not match NaN values
+            return False
+
+        return True
+
+    # pylint: disable=C0103,W0613
+    def explain(self, X: Any, **kwargs: Any) -> list[Explanation]:
+        """Explain the predictions for the input data.
+
+        If the calculation of the SHAP values for an input sample fails, the explanation
+        will be invalid. This can be checked with the Explanation.is_valid() method.
+
+        Parameters
+        ----------
+        X : Any
+            The input data to explain.
+        kwargs : Any
+            Additional keyword arguments for SHAP's TreeExplainer.shap_values.
+
+        Returns
+        -------
+        list[Explanation]
+            List of explanations corresponding to the input data.
+        """
+        featurization_element = self.featurization_subpipeline.steps[-1][1]  # type: ignore[union-attr]
+
+        explanation_results = []
+        for input_sample in X:
+
+            # wrap the sample into a list because the pipeline expects batches
+            input_sample = [input_sample]
+
+            # get predictions
+            prediction = _get_predictions(self.pipeline, input_sample)
+            if not self._prediction_is_valid(prediction):
+                # we use the prediction to check if the input is valid. If not, we cannot explain it.
+                explanation_results.append(Explanation())
+                continue
+
+            if prediction.ndim > 1:
+                prediction = prediction.squeeze()
+
+            # get the molecule
+            molecule = self.molecule_reader_subpipeline.transform(input_sample)[0]  # type: ignore[union-attr]
+
+            # get the feature vector
+            feature_vector = self.featurization_subpipeline.transform(input_sample)  # type: ignore[union-attr]
+            feature_vector = _to_dense(feature_vector)
+            feature_vector = np.asarray(feature_vector).squeeze()
+
+            # Feature names should also be extracted from the pipeline.
+            # But first, the names need to be added to the pipelines.
+            # Therefore, feature_names is currently always None.
+            feature_names = None
+
+            # compute the SHAP values for the features
+            feature_weights = self.explainer.shap_values(feature_vector, **kwargs)
+            feature_weights = np.asarray(feature_weights).squeeze()
+
+            atom_weights = None
+            bond_weights = None
+
+            if isinstance(featurization_element, MolToMorganFP):
+                # for Morgan fingerprints, the SHAP values can be mapped to atom weights
+                atom_weights = _convert_shap_feature_weights_to_atom_weights(
+                    feature_weights,
+                    molecule,
+                    featurization_element,
+                    feature_vector,
+                )
+
+            explanation_results.append(
+                Explanation(
+                    feature_vector=feature_vector,
+                    feature_names=feature_names,
+                    molecule=molecule,
+                    prediction=prediction,
+                    feature_weights=feature_weights,
+                    atom_weights=atom_weights,
+                    bond_weights=bond_weights,
+                )
+            )
+
+        return explanation_results
diff --git a/molpipeline/explainability/explanation.py b/molpipeline/explainability/explanation.py
new file mode 100644
index 00000000..07ab0829
--- /dev/null
+++ b/molpipeline/explainability/explanation.py
@@ -0,0 +1,52 @@
+"""Module for explanation class."""
+
+from __future__ import annotations
+
+import dataclasses
+
+import numpy as np
+import numpy.typing as npt
+
+from molpipeline.abstract_pipeline_elements.core import RDKitMol
+
+
+@dataclasses.dataclass()
+class Explanation:
+    """Class representing explanations of a prediction."""
+
+    # input data
+    feature_vector: npt.NDArray[np.float_] | None = None
+    feature_names: list[str] | None = None
+    molecule: RDKitMol | None = None
+    prediction: float | npt.NDArray[np.float_] | None = None
+
+    # explanation results mappable to the feature vector
+    feature_weights: npt.NDArray[np.float_] | None = None
+
+    # explanation results mappable to the molecule
+    atom_weights: npt.NDArray[np.float_] | None = None
+    bond_weights: npt.NDArray[np.float_] | None = None
+
+    def is_valid(self) -> bool:
+        """Check if the explanation is valid.
+
+        Returns
+        -------
+        bool
+            True if the explanation is valid, False otherwise.
+ """ + return all( + [ + self.feature_vector is not None, + # self.feature_names is not None, + self.molecule is not None, + self.prediction is not None, + any( + [ + self.feature_weights is not None, + self.atom_weights is not None, + self.bond_weights is not None, + ] + ), + ] + ) diff --git a/molpipeline/explainability/fingerprint_utils.py b/molpipeline/explainability/fingerprint_utils.py new file mode 100644 index 00000000..e91374e4 --- /dev/null +++ b/molpipeline/explainability/fingerprint_utils.py @@ -0,0 +1,83 @@ +"""Utility functions for explainability.""" + +from __future__ import annotations + +from collections import defaultdict +from typing import Sequence + +import numpy as np +import numpy.typing as npt + +from molpipeline.abstract_pipeline_elements.core import RDKitMol +from molpipeline.mol2any import MolToMorganFP +from molpipeline.utils.substructure_handling import AtomEnvironment + + +def assign_prediction_importance( + bit_dict: dict[int, Sequence[AtomEnvironment]], weights: npt.NDArray[np.float_] +) -> dict[int, float]: + """Assign the prediction importance. + + Originally from Christian W. Feldmann + https://github.com/c-feldmann/compchemkit/blob/64e5543e2b8f72e93711186b2e0b42366820fb52/compchemkit/molecular_heatmaps.py#L28 + + Parameters + ---------- + bit_dict : dict[int, Sequence[AtomEnvironment]] + The bit dictionary. + weights : npt.NDArray[np.float_] + The weights. + + Returns + ------- + dict[int, float] + The atom contribution. + """ + atom_contribution: dict[int, float] = defaultdict(lambda: 0) + for bit, atom_env_list in bit_dict.items(): # type: int, Sequence[AtomEnvironment] + n_machtes = len(atom_env_list) + for atom_set in atom_env_list: + for atom in atom_set.environment_atoms: + atom_contribution[atom] += weights[bit] / ( + len(atom_set.environment_atoms) * n_machtes + ) + if not np.isclose(sum(weights), sum(atom_contribution.values())).all(): + raise AssertionError( + f"Weights and atom contributions don't sum to the same value:" + f" {weights.sum()} != {sum(atom_contribution.values())}" + ) + return atom_contribution + + +def fingerprint_shap_to_atomweights( + mol: RDKitMol, fingerprint_element: MolToMorganFP, shap_mat: npt.NDArray[np.float_] +) -> list[float]: + """Convert SHAP values to atom weights. + + Originally from Christian W. Feldmann + https://github.com/c-feldmann/compchemkit/blob/64e5543e2b8f72e93711186b2e0b42366820fb52/compchemkit/molecular_heatmaps.py#L15 + + Parameters + ---------- + mol : RDKitMol + The molecule. + fingerprint_element : MolToMorganFP + The fingerprint element. + shap_mat : npt.NDArray[np.float_] + The SHAP values. + + Returns + ------- + list[float] + The atom weights. + """ + bit_atom_env_dict: dict[int, Sequence[AtomEnvironment]] + bit_atom_env_dict = dict( + fingerprint_element.bit2atom_mapping(mol) + ) # MyPy invariants make me do this. 
+    atom_weight_dict = assign_prediction_importance(bit_atom_env_dict, shap_mat)
+    atom_weight_list = [
+        atom_weight_dict.get(a_idx, 0) for a_idx in range(mol.GetNumAtoms())
+    ]
+    return atom_weight_list
diff --git a/molpipeline/explainability/visualization.py b/molpipeline/explainability/visualization.py
new file mode 100644
index 00000000..65da506f
--- /dev/null
+++ b/molpipeline/explainability/visualization.py
@@ -0,0 +1,164 @@
+"""Visualization functions for the explainability module."""
+
+from __future__ import annotations
+
+import io
+
+import numpy as np
+import numpy.typing as npt
+from PIL import Image
+from rdkit import Geometry
+from rdkit.Chem import Draw
+
+from molpipeline.abstract_pipeline_elements.core import RDKitMol
+
+RGBATuple = tuple[float, float, float, float]
+
+
+def get_similaritymap_from_weights(
+    mol: RDKitMol,
+    weights: npt.NDArray[np.float_] | list[float] | tuple[float, ...],
+    draw2d: Draw.MolDraw2DCairo,
+    sigma: float | None = None,
+    sigma_f: float = 0.3,
+    contour_lines: int = 10,
+    contour_params: Draw.ContourParams | None = None,
+) -> Draw.MolDraw2D:
+    """Generate the similarity map for a molecule given the atomic weights.
+
+    Strongly inspired by Chem.Draw.SimilarityMaps.
+
+    Parameters
+    ----------
+    mol: RDKitMol
+        The molecule of interest.
+    weights: npt.NDArray[np.float_] | list[float] | tuple[float, ...]
+        The atomic weights.
+    draw2d: Draw.MolDraw2DCairo
+        The drawer.
+    sigma: float | None
+        The width of the Gaussians. If None, it is derived from the length of the
+        first bond (or the distance between the first two atoms).
+    sigma_f: float
+        The factor by which the reference distance is scaled to obtain sigma.
+    contour_lines: int
+        The number of contour lines.
+    contour_params: Draw.ContourParams | None
+        The contour parameters.
+
+    Returns
+    -------
+    Draw.MolDraw2D
+        The drawer.
+    """
+    if mol.GetNumAtoms() < 2:
+        raise ValueError("Molecules with fewer than two atoms are not supported.")
+    mol = Draw.rdMolDraw2D.PrepareMolForDrawing(mol, addChiralHs=False)
+    if not mol.GetNumConformers():
+        Draw.rdDepictor.Compute2DCoords(mol)
+    if sigma is None:
+        if mol.GetNumBonds() > 0:
+            bond = mol.GetBondWithIdx(0)
+            idx1 = bond.GetBeginAtomIdx()
+            idx2 = bond.GetEndAtomIdx()
+            sigma = (
+                sigma_f
+                * (
+                    mol.GetConformer().GetAtomPosition(idx1)
+                    - mol.GetConformer().GetAtomPosition(idx2)
+                ).Length()
+            )
+        else:
+            sigma = (
+                sigma_f
+                * (
+                    mol.GetConformer().GetAtomPosition(0)
+                    - mol.GetConformer().GetAtomPosition(1)
+                ).Length()
+            )
+        sigma = round(sigma, 2)
+    sigmas = [sigma] * mol.GetNumAtoms()
+    locs = []
+    for i in range(mol.GetNumAtoms()):
+        atom_pos = mol.GetConformer().GetAtomPosition(i)
+        locs.append(Geometry.Point2D(atom_pos.x, atom_pos.y))
+    draw2d.DrawMolecule(mol)
+    draw2d.ClearDrawing()
+    if not contour_params:
+        contour_params = Draw.ContourParams()
+        contour_params.fillGrid = True
+        contour_params.gridResolution = 0.1
+        contour_params.extraGridPadding = 0.5
+    Draw.ContourAndDrawGaussians(
+        draw2d, locs, weights, sigmas, nContours=contour_lines, params=contour_params
+    )
+    draw2d.drawOptions().clearBackground = False
+    draw2d.DrawMolecule(mol)
+    return draw2d
+
+
+def rdkit_gaussplot(
+    mol: RDKitMol,
+    weights: npt.NDArray[np.float_],
+    n_contour_lines: int = 5,
+    color_tuple: tuple[RGBATuple, RGBATuple, RGBATuple] | None = None,
+) -> Draw.MolDraw2D:
+    """Create a Gaussian plot on the molecular structure, highlighting atoms with weighted Gaussians.
+
+    Parameters
+    ----------
+    mol: RDKitMol
+        The molecule.
+    weights: npt.NDArray[np.float_]
+        The weights.
+    n_contour_lines: int
+        The number of contour lines.
+    color_tuple: tuple[RGBATuple, RGBATuple, RGBATuple] | None
+        The RGBA colors for the low, middle, and high ends of the colormap.
+
+    Returns
+    -------
+    Draw.MolDraw2D
+        The configured drawer.
+    """
+    drawer = Draw.MolDraw2DCairo(600, 600)
+    # color atoms with atomic numbers 0 to 99 black
+    drawer.drawOptions().updateAtomPalette({i: (0, 0, 0, 1) for i in range(100)})
+    cps = Draw.ContourParams()
+    cps.fillGrid = True
+    cps.gridResolution = 0.02
+    cps.extraGridPadding = 1.2
+    coolwarm = ((0.017, 0.50, 0.850, 0.5), (1.0, 1.0, 1.0, 0.5), (1.0, 0.25, 0.0, 0.5))
+
+    if color_tuple is None:
+        color_tuple = coolwarm
+
+    cps.setColourMap(color_tuple)
+
+    drawer = get_similaritymap_from_weights(
+        mol,
+        weights,
+        contour_lines=n_contour_lines,
+        draw2d=drawer,
+        contour_params=cps,
+        sigma_f=0.4,
+    )
+    drawer.FinishDrawing()
+    return drawer
+
+
+def show_png(data: bytes) -> Image.Image:
+    """Load a PNG image from a byte stream.
+
+    Parameters
+    ----------
+    data: bytes
+        The image data.
+
+    Returns
+    -------
+    Image.Image
+        The image.
+    """
+    bio = io.BytesIO(data)
+    img = Image.open(bio)
+    return img
diff --git a/requirements.txt b/requirements.txt
index c6fab9f9..12d0bf87 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,4 +6,5 @@ rdkit >= 2023.9.1
 scipy
 setuptools
 scikit-learn >= 1.4.0
+shap
 typing_extensions
diff --git a/tests/test_explainability/__init__.py b/tests/test_explainability/__init__.py
new file mode 100644
index 00000000..dad24e8b
--- /dev/null
+++ b/tests/test_explainability/__init__.py
@@ -0,0 +1 @@
+"""Test explainability methods and utilities."""
diff --git a/tests/test_explainability/test_shap_tree_explainer.py b/tests/test_explainability/test_shap_tree_explainer.py
new file mode 100644
index 00000000..907d59ca
--- /dev/null
+++ b/tests/test_explainability/test_shap_tree_explainer.py
@@ -0,0 +1,341 @@
+"""Test SHAP's TreeExplainer wrapper."""
+
+import unittest
+
+import numpy as np
+from rdkit import Chem
+from sklearn.base import BaseEstimator, is_classifier, is_regressor
+from sklearn.ensemble import (
+    GradientBoostingClassifier,
+    GradientBoostingRegressor,
+    RandomForestClassifier,
+    RandomForestRegressor,
+)
+
+from molpipeline import ErrorFilter, FilterReinserter, Pipeline, PostPredictionWrapper
+from molpipeline.abstract_pipeline_elements.core import RDKitMol
+from molpipeline.any2mol import SmilesToMol
+from molpipeline.explainability.explainer import SHAPTreeExplainer
+from molpipeline.explainability.explanation import Explanation
+from molpipeline.mol2any import (
+    MolToConcatenatedVector,
+    MolToMorganFP,
+    MolToRDKitPhysChem,
+)
+from molpipeline.mol2mol import SaltRemover
+from molpipeline.utils.subpipeline import SubpipelineExtractor
+
+TEST_SMILES = ["CC", "CCO", "COC", "c1ccccc1(N)", "CCC(-O)O", "CCCN"]
+CONTAINS_OX = [0, 1, 1, 0, 1, 0]
+
+TEST_SMILES_WITH_BAD_SMILES = [
+    "CC",
+    "CCO",
+    "COC",
+    "MY_FIRST_BAD_SMILES",
+    "c1ccccc1(N)",
+    "CCC(-O)O",
+    "CCCN",
+    "BAD_SMILES_2",
+]
+CONTAINS_OX_BAD_SMILES = [0, 1, 1, 0, 0, 1, 0, 1]
+
+_RANDOM_STATE = 67056
+
+
+class TestSHAPTreeExplainer(unittest.TestCase):
+    """Test SHAP's TreeExplainer wrapper."""
+
+    def _test_valid_explanation(
+        self,
+        explanation: Explanation,
+        estimator: BaseEstimator,
+        molecule_reader_subpipeline: Pipeline,
+        nof_features: int,
+        test_smiles: str,
+        is_morgan_fingerprint: bool,
+    ) -> None:
+        """Helper method to test if the explanation is valid and has the correct shape and content.
+
+        Parameters
+        ----------
+        explanation : Explanation
+            The explanation to be tested.
+        estimator : BaseEstimator
+            The estimator used in the pipeline.
+        molecule_reader_subpipeline : Pipeline
+            The subpipeline that extracts the molecule from the input data.
+        nof_features : int
+            The number of features in the feature vector.
+        test_smiles : str
+            The SMILES string of the molecule.
+        is_morgan_fingerprint : bool
+            Whether the feature vector is a Morgan fingerprint or not.
+        """
+        self.assertTrue(explanation.is_valid())
+
+        self.assertIsInstance(explanation.feature_vector, np.ndarray)
+        self.assertEqual(
+            (nof_features,), explanation.feature_vector.shape  # type: ignore[union-attr]
+        )
+
+        # feature names are not implemented yet
+        self.assertIsNone(explanation.feature_names)
+        # self.assertEqual(len(explanation.feature_names), explanation.feature_vector.shape[0])
+
+        self.assertIsInstance(explanation.molecule, RDKitMol)
+        self.assertEqual(
+            Chem.MolToInchi(*molecule_reader_subpipeline.transform([test_smiles])),
+            Chem.MolToInchi(explanation.molecule),
+        )
+
+        self.assertIsInstance(explanation.prediction, np.ndarray)
+        self.assertIsInstance(explanation.feature_weights, np.ndarray)
+        if is_regressor(estimator):
+            self.assertEqual((1,), explanation.prediction.shape)  # type: ignore[union-attr]
+            self.assertEqual(
+                (nof_features,), explanation.feature_weights.shape  # type: ignore[union-attr]
+            )
+        elif is_classifier(estimator):
+            self.assertEqual((2,), explanation.prediction.shape)  # type: ignore[union-attr]
+            if isinstance(estimator, GradientBoostingClassifier):
+                # There is currently a bug in SHAP's TreeExplainer for GradientBoostingClassifier
+                # (https://github.com/shap/shap/issues/3177): it returns only one feature weight
+                # per feature, which is also based on log odds. This check is a workaround until
+                # the bug is fixed.
+                self.assertEqual(
+                    (nof_features,), explanation.feature_weights.shape  # type: ignore[union-attr]
+                )
+            else:
+                # normal binary classification case
+                self.assertEqual(
+                    (nof_features, 2), explanation.feature_weights.shape  # type: ignore[union-attr]
+                )
+        else:
+            raise ValueError("Error in unittest. Unsupported estimator.")
+
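+        # atom weights are only derived for Morgan fingerprints;
+        # other featurizations leave them as None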
+        if is_morgan_fingerprint:
+            self.assertIsInstance(explanation.atom_weights, np.ndarray)
+            self.assertEqual(
+                explanation.atom_weights.shape,  # type: ignore[union-attr]
+                (explanation.molecule.GetNumAtoms(),),  # type: ignore[union-attr]
+            )
+        else:
+            self.assertIsNone(explanation.atom_weights)
+
+        self.assertIsNone(
+            explanation.bond_weights
+        )  # SHAPTreeExplainer doesn't set bond weights yet
+
+    def test_explanations_fingerprint_pipeline(self) -> None:
+        """Test SHAP's TreeExplainer wrapper on MolPipeline's pipelines with fingerprints."""
+
+        estimators = [
+            RandomForestClassifier(n_estimators=2, random_state=_RANDOM_STATE),
+            RandomForestRegressor(n_estimators=2, random_state=_RANDOM_STATE),
+            GradientBoostingClassifier(n_estimators=2, random_state=_RANDOM_STATE),
+            GradientBoostingRegressor(n_estimators=2, random_state=_RANDOM_STATE),
+        ]
+        n_bits = 64
+
+        # test explanations with different estimators
+        for estimator in estimators:
+            pipeline = Pipeline(
+                [
+                    ("smi2mol", SmilesToMol()),
+                    ("morgan", MolToMorganFP(radius=1, n_bits=n_bits)),
+                    ("model", estimator),
+                ]
+            )
+            pipeline.fit(TEST_SMILES, CONTAINS_OX)
+
+            explainer = SHAPTreeExplainer(pipeline)
+            explanations = explainer.explain(TEST_SMILES)
+            self.assertEqual(len(explanations), len(TEST_SMILES))
+
+            # get the subpipeline that extracts the molecule from the input data
+            mol_reader_subpipeline = SubpipelineExtractor(
+                pipeline
+            ).get_molecule_reader_subpipeline()
+            self.assertIsInstance(mol_reader_subpipeline, Pipeline)
+
+            for i, explanation in enumerate(explanations):
+                self._test_valid_explanation(
+                    explanation,
+                    estimator,
+                    mol_reader_subpipeline,  # type: ignore[arg-type]
+                    n_bits,
+                    TEST_SMILES[i],
+                    is_morgan_fingerprint=True,
+                )
+
+    def test_explanations_pipeline_with_invalid_inputs(self) -> None:
+        """Test SHAP's TreeExplainer wrapper with invalid inputs."""
+
+        estimators = [
+            RandomForestClassifier(n_estimators=2, random_state=_RANDOM_STATE),
+            RandomForestRegressor(n_estimators=2, random_state=_RANDOM_STATE),
+            GradientBoostingClassifier(n_estimators=2, random_state=_RANDOM_STATE),
+            GradientBoostingRegressor(n_estimators=2, random_state=_RANDOM_STATE),
+        ]
+
+        n_bits = 64
+
+        for estimator in estimators:
+
+            # pipeline with ErrorFilter
+            error_filter1 = ErrorFilter()
+            pipeline1 = Pipeline(
+                [
+                    ("smi2mol", SmilesToMol()),
+                    ("salt_remover", SaltRemover()),
+                    ("error_filter", error_filter1),
+                    ("morgan", MolToMorganFP(radius=1, n_bits=n_bits)),
+                    ("model", estimator),
+                ]
+            )
+
+            # pipeline with ErrorFilter and FilterReinserter
+            error_filter2 = ErrorFilter()
+            error_reinserter2 = PostPredictionWrapper(
+                FilterReinserter.from_error_filter(error_filter2, np.nan)
+            )
+            pipeline2 = Pipeline(
+                [
+                    ("smi2mol", SmilesToMol()),
+                    ("salt_remover", SaltRemover()),
+                    ("error_filter", error_filter2),
+                    ("morgan", MolToMorganFP(radius=1, n_bits=n_bits)),
+                    ("model", estimator),
+                    ("error_reinserter", error_reinserter2),
+                ]
+            )
+
+            for pipeline in [pipeline1, pipeline2]:
+
+                pipeline.fit(TEST_SMILES_WITH_BAD_SMILES, CONTAINS_OX_BAD_SMILES)
+
+                explainer = SHAPTreeExplainer(pipeline)
+                explanations = explainer.explain(TEST_SMILES_WITH_BAD_SMILES)
+                self.assertEqual(len(explanations), len(TEST_SMILES_WITH_BAD_SMILES))
+
+                # get the subpipeline that extracts the molecule from the input data
+                mol_reader_subpipeline = SubpipelineExtractor(
+                    pipeline
+                ).get_molecule_reader_subpipeline()
+                self.assertIsNotNone(mol_reader_subpipeline)
+
+                for i, explanation in enumerate(explanations):
+
+                    # check that bad input results in invalid explanation
+                    if i in [3, 7]:
+                        self.assertFalse(explanation.is_valid())
+                        continue
+
+                    self._test_valid_explanation(
+                        explanation,
+                        estimator,
+                        mol_reader_subpipeline,  # type: ignore[arg-type]
+                        n_bits,
+                        TEST_SMILES_WITH_BAD_SMILES[i],
+                        is_morgan_fingerprint=True,
+                    )
+
+    def test_explanations_pipeline_with_physchem(self) -> None:
+        """Test SHAP's TreeExplainer wrapper on physchem feature vector."""
+
+        estimators = [
+            RandomForestClassifier(n_estimators=2, random_state=_RANDOM_STATE),
+            RandomForestRegressor(n_estimators=2, random_state=_RANDOM_STATE),
+            GradientBoostingClassifier(n_estimators=2, random_state=_RANDOM_STATE),
+            GradientBoostingRegressor(n_estimators=2, random_state=_RANDOM_STATE),
+        ]
+
+        # test explanations with different estimators
+        for estimator in estimators:
+            pipeline = Pipeline(
+                [
+                    ("smi2mol", SmilesToMol()),
+                    ("physchem", MolToRDKitPhysChem()),
+                    ("model", estimator),
+                ]
+            )
+
+            pipeline.fit(TEST_SMILES, CONTAINS_OX)
+
+            explainer = SHAPTreeExplainer(pipeline)
+            explanations = explainer.explain(TEST_SMILES)
+            self.assertEqual(len(explanations), len(TEST_SMILES))
+
+            # get the subpipeline that extracts the molecule from the input data
+            mol_reader_subpipeline = SubpipelineExtractor(
+                pipeline
+            ).get_molecule_reader_subpipeline()
+            self.assertIsNotNone(mol_reader_subpipeline)
+
+            for i, explanation in enumerate(explanations):
+                self._test_valid_explanation(
+                    explanation,
+                    estimator,
+                    mol_reader_subpipeline,  # type: ignore[arg-type]
+                    pipeline.named_steps["physchem"].n_features,
+                    TEST_SMILES[i],
+                    is_morgan_fingerprint=False,
+                )
+
+    def test_explanations_pipeline_with_concatenated_features(self) -> None:
+        """Test SHAP's TreeExplainer wrapper on concatenated feature vector."""
+
+        estimators = [
+            RandomForestClassifier(n_estimators=2, random_state=_RANDOM_STATE),
+            RandomForestRegressor(n_estimators=2, random_state=_RANDOM_STATE),
+            GradientBoostingClassifier(n_estimators=2, random_state=_RANDOM_STATE),
+            GradientBoostingRegressor(n_estimators=2, random_state=_RANDOM_STATE),
+        ]
+
+        n_bits = 64
+
+        # test explanations with different estimators
+        for estimator in estimators:
+            pipeline = Pipeline(
+                [
+                    ("smi2mol", SmilesToMol()),
+                    (
+                        "features",
+                        MolToConcatenatedVector(
+                            [
+                                (
+                                    "RDKitPhysChem",
+                                    MolToRDKitPhysChem(),
+                                ),
+                                (
+                                    "MorganFP",
+                                    MolToMorganFP(radius=1, n_bits=n_bits),
+                                ),
+                            ]
+                        ),
+                    ),
+                    ("model", estimator),
+                ]
+            )
+
+            pipeline.fit(TEST_SMILES, CONTAINS_OX)
+
+            explainer = SHAPTreeExplainer(pipeline)
+            explanations = explainer.explain(TEST_SMILES)
+            self.assertEqual(len(explanations), len(TEST_SMILES))
+
+            # get the subpipeline that extracts the molecule from the input data
+            mol_reader_subpipeline = SubpipelineExtractor(
+                pipeline
+            ).get_molecule_reader_subpipeline()
+            self.assertIsNotNone(mol_reader_subpipeline)
+
+            for i, explanation in enumerate(explanations):
+                self._test_valid_explanation(
+                    explanation,
+                    estimator,
+                    mol_reader_subpipeline,  # type: ignore[arg-type]
+                    pipeline.named_steps["features"].n_features,
+                    TEST_SMILES[i],
+                    is_morgan_fingerprint=False,
+                )
diff --git a/tests/test_explainability/test_visualization.py b/tests/test_explainability/test_visualization.py
new file mode 100644
index 00000000..4b8234d5
--- /dev/null
+++ b/tests/test_explainability/test_visualization.py
@@ -0,0 +1,55 @@
+"""Test visualization methods for explanations."""
+
+import unittest
+
+import numpy as np
+from sklearn.ensemble import RandomForestClassifier
+
+from molpipeline import Pipeline
+from molpipeline.any2mol import SmilesToMol
+from molpipeline.explainability import SHAPTreeExplainer
+from molpipeline.explainability.visualization import rdkit_gaussplot, show_png
+from molpipeline.mol2any import MolToMorganFP
+
+TEST_SMILES = ["CC", "CCO", "COC", "c1ccccc1(N)", "CCC(-O)O", "CCCN"]
+CONTAINS_OX = [0, 1, 1, 0, 1, 0]
+
+_RANDOM_STATE = 67056
+
+
+class TestExplainabilityVisualization(unittest.TestCase):
+    """Test visualization methods for explanations."""
+
+    def test_fingerprint_based_atom_coloring(self) -> None:
+        """Test fingerprint-based atom coloring."""
+
+        pipeline = Pipeline(
+            [
+                ("smi2mol", SmilesToMol()),
+                ("morgan", MolToMorganFP(radius=1, n_bits=1024)),
+                (
+                    "model",
+                    RandomForestClassifier(n_estimators=2, random_state=_RANDOM_STATE),
+                ),
+            ]
+        )
+        pipeline.fit(TEST_SMILES, CONTAINS_OX)
+
+        explainer = SHAPTreeExplainer(pipeline)
+        explanations = explainer.explain(TEST_SMILES)
+
+        for explanation in explanations:
+            self.assertTrue(explanation.is_valid())
+            self.assertIsInstance(explanation.atom_weights, np.ndarray)
+
+            drawer = rdkit_gaussplot(
+                explanation.molecule,
+                explanation.atom_weights.tolist(),  # type: ignore[union-attr]
+            )
+            self.assertIsNotNone(drawer)
+
+            figure_bytes = drawer.GetDrawingText()
+
+            image = show_png(figure_bytes)
+
+            self.assertEqual(image.format, "PNG")
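+
+
+# convenience entry point so the module can be run standalone
+# (not required when the suite is run via a test runner such as pytest)
+if __name__ == "__main__":
+    unittest.main()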