
try to add further SHAP explainers
JochenSiegWork committed Nov 21, 2024
1 parent 29af5e6 commit d6d880e
Showing 2 changed files with 155 additions and 54 deletions.
130 changes: 107 additions & 23 deletions molpipeline/explainability/explainer.py
@@ -10,6 +10,7 @@
import pandas as pd
import shap
from scipy.sparse import issparse, spmatrix
from sklearn.base import BaseEstimator

from molpipeline import Pipeline
from molpipeline.abstract_pipeline_elements.core import OptionalMol
@@ -48,6 +49,28 @@ def _to_dense(
return feature_matrix


def _get_prediction_function(pipeline: Pipeline | BaseEstimator) -> Any:
"""Get the prediction function of a model.
Parameters
----------
pipeline : Pipeline
The pipeline containing the model.
Returns
-------
Any
The prediction function.
"""
if hasattr(pipeline, "predict_proba"):
return pipeline.predict_proba
if hasattr(pipeline, "decision_function"):
return pipeline.decision_function
if hasattr(pipeline, "predict"):
return pipeline.predict
raise ValueError("Could not determine the model output predictions")


# This function might also be put at a more central position in the lib.
def _get_predictions(
pipeline: Pipeline, feature_matrix: npt.NDArray[Any] | spmatrix
@@ -68,14 +91,8 @@ def _get_predictions(
npt.NDArray[np.float64]
The predictions.
"""
if hasattr(pipeline, "predict_proba"):
prediction = pipeline.predict_proba(feature_matrix)
elif hasattr(pipeline, "decision_function"):
prediction = pipeline.decision_function(feature_matrix)
elif hasattr(pipeline, "predict"):
prediction = pipeline.predict(feature_matrix)
else:
raise ValueError("Could not determine the model output predictions")
prediction_function = _get_prediction_function(pipeline)
prediction = prediction_function(feature_matrix)
return np.array(prediction)


@@ -129,7 +146,7 @@ def _convert_shap_feature_weights_to_atom_weights(

# pylint: disable=R0903
class AbstractSHAPExplainer(abc.ABC):
"""Abstract class for explainer objects."""
"""Abstract class for SHAP explainer objects."""

# pylint: disable=C0103,W0613
@abc.abstractmethod
@@ -153,22 +170,17 @@ def explain(


# pylint: disable=R0903
class SHAPTreeExplainer(AbstractSHAPExplainer):
"""Class for SHAP's TreeExplainer wrapper.
Wraps SHAP's TreeExplainer to explain predictions of a pipeline containing a
tree-based model.
Note on failed instances:
SHAPTreeExplainer will automatically handle fill values for failed instances and
returns an invalid explanation for them. However, fill values that could be valid
predictions, e.g. 0, are not necessarily detected. Set the fill value to np.nan or
None if these failed instances should not be explained.
"""
class _SHAPExplainerAdapter(AbstractSHAPExplainer):
"""Adapter for SHAP explainer wrappers for handling molecules and pipelines."""

return_type: type[SHAPFeatureExplanation] | type[SHAPFeatureAndAtomExplanation]

def __init__(self, pipeline: Pipeline, **kwargs: Any) -> None:
def __init__(
self,
explainer_type: type[shap.Explainer],  # TreeExplainer/KernelExplainer are subclasses
pipeline: Pipeline,
**kwargs: Any,
) -> None:
"""Initialize the SHAPTreeExplainer.
Parameters
@@ -186,8 +198,10 @@ def __init__(self, pipeline: Pipeline, **kwargs: Any) -> None:
if model is None:
raise ValueError("Could not determine the model to explain.")

# WIP: currently unused; may replace the raw model below for
# model-agnostic explainers (see the commented-out argument).
prediction_function = _get_prediction_function(model)
# set up the actual explainer
self.explainer = shap.TreeExplainer(
self.explainer = explainer_type(
# prediction_function,
model,
**kwargs,
)
@@ -334,3 +348,73 @@ def explain(
explanation_results.append(self.return_type(**explanation_data))

return explanation_results


class SHAPExplainer(_SHAPExplainerAdapter):
"""Wrapper for SHAP's Explainer that can handle pipelines and molecules."""

def __init__(
self,
pipeline: Pipeline,
**kwargs: Any,
) -> None:
"""Initialize the SHAPExplainer.
Parameters
----------
pipeline : Pipeline
The pipeline containing the model to explain.
kwargs : Any
Additional keyword arguments for SHAP's Explainer.
"""
super().__init__(shap.Explainer, pipeline, **kwargs)


class SHAPTreeExplainer(_SHAPExplainerAdapter):
"""Wrapper for SHAP's TreeExplainer that can handle pipelines and molecules.
Wraps SHAP's TreeExplainer to explain predictions of a pipeline containing a
tree-based model.
Note on failed instances:
SHAPTreeExplainer will automatically handle fill values for failed instances and
returns an invalid explanation for them. However, fill values that could be valid
predictions, e.g. 0, are not necessarily detected. Set the fill value to np.nan or
None if these failed instances should not be explained.
"""

def __init__(
self,
pipeline: Pipeline,
**kwargs: Any,
) -> None:
"""Initialize the SHAPTreeExplainer.
Parameters
----------
pipeline : Pipeline
The pipeline containing the model to explain.
kwargs : Any
Additional keyword arguments for SHAP's Explainer.
"""
super().__init__(shap.TreeExplainer, pipeline, **kwargs)
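# --- Editorial sketch, not part of this commit -------------------------------
# The fill-value note above in practice: an ErrorFilter/FilterReinserter pair
# reinserts np.nan for molecules that fail earlier pipeline steps, so the
# explainer can flag those instances as invalid instead of explaining a fill
# value that looks like a real prediction. The names used (ErrorFilter,
# FilterReinserter.from_error_filter, PostPredictionWrapper, MolToMorganFP's
# import path) follow molpipeline's public API as used in the tests; treat
# the exact signatures as assumptions.
#
#   import numpy as np
#   from sklearn.ensemble import RandomForestClassifier
#   from molpipeline import (
#       ErrorFilter, FilterReinserter, Pipeline, PostPredictionWrapper,
#   )
#   from molpipeline.any2mol import SmilesToMol
#   from molpipeline.mol2any import MolToMorganFP
#
#   error_filter = ErrorFilter(filter_everything=True)
#   pipeline = Pipeline(
#       [
#           ("smi2mol", SmilesToMol()),
#           ("morgan", MolToMorganFP(radius=1, n_bits=64)),
#           ("error_filter", error_filter),
#           ("model", RandomForestClassifier(n_estimators=2)),
#           (
#               "error_reinserter",
#               PostPredictionWrapper(
#                   FilterReinserter.from_error_filter(error_filter, np.nan)
#               ),
#           ),
#       ]
#   )
#   pipeline.fit(["CCO", "CCC"], [1, 0])
#   explainer = SHAPTreeExplainer(pipeline)
#   explanations = explainer.explain(["CCO", "not-a-smiles"])  # 2nd -> invalid
# ------------------------------------------------------------------------------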


class SHAPKernelExplainer(_SHAPExplainerAdapter):
"""Wrapper for SHAP's KernelExplainer that can handle pipelines and molecules."""

def __init__(
self,
pipeline: Pipeline,
**kwargs: Any,
) -> None:
"""Initialize the SHAPKernelExplainer.
Parameters
----------
pipeline : Pipeline
The pipeline containing the model to explain.
kwargs : Any
Additional keyword arguments for SHAP's Explainer.
"""
super().__init__(shap.KernelExplainer, pipeline, **kwargs)
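# --- Editorial usage sketch, not part of this commit --------------------------
# All three wrappers share one calling convention: construct with a fitted
# molpipeline Pipeline, then call explain() on SMILES input. SHAPTreeExplainer
# targets tree ensembles; SHAPExplainer lets SHAP auto-select an algorithm;
# per SHAP's API, SHAPKernelExplainer would additionally need background data
# passed through **kwargs. The feature_weights attribute below is an assumed
# name from molpipeline.explainability.explanation.
#
#   from sklearn.ensemble import RandomForestClassifier
#   from molpipeline import Pipeline
#   from molpipeline.any2mol import SmilesToMol
#   from molpipeline.mol2any import MolToMorganFP
#
#   pipeline = Pipeline(
#       [
#           ("smi2mol", SmilesToMol()),
#           ("morgan", MolToMorganFP(radius=1, n_bits=64)),
#           ("model", RandomForestClassifier(n_estimators=2)),
#       ]
#   )
#   pipeline.fit(["CO", "CCC", "CCO"], [1, 0, 1])
#
#   explainer = SHAPTreeExplainer(pipeline)  # or SHAPExplainer(pipeline)
#   explanations = explainer.explain(["CCO"])
#   print(explanations[0].feature_weights)
# ------------------------------------------------------------------------------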
79 changes: 48 additions & 31 deletions
@@ -4,6 +4,7 @@

import numpy as np
import pandas as pd
import shap
from rdkit import Chem, rdBase
from sklearn.base import BaseEstimator, is_classifier, is_regressor
from sklearn.ensemble import (
@@ -12,11 +13,17 @@
RandomForestClassifier,
RandomForestRegressor,
)
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.svm import SVC, SVR

from molpipeline import ErrorFilter, FilterReinserter, Pipeline, PostPredictionWrapper
from molpipeline.abstract_pipeline_elements.core import RDKitMol
from molpipeline.any2mol import SmilesToMol
from molpipeline.explainability.explainer import SHAPTreeExplainer
from molpipeline.explainability.explainer import (
SHAPExplainer,
SHAPKernelExplainer,
SHAPTreeExplainer,
)
from molpipeline.explainability.explanation import (
SHAPFeatureAndAtomExplanation,
SHAPFeatureExplanation,
@@ -47,8 +54,8 @@
_RANDOM_STATE = 67056


class TestSHAPTreeExplainer(unittest.TestCase):
"""Test SHAP's TreeExplainer wrapper."""
class TestSHAPExplainers(unittest.TestCase):
"""Test SHAP's Explainer wrappers."""

def _test_valid_explanation(
self,
@@ -131,44 +138,54 @@ def _test_valid_explanation(
def test_explanations_fingerprint_pipeline(self) -> None:
"""Test SHAP's TreeExplainer wrapper on MolPipeline's pipelines with fingerprints."""

estimators = [
tree_estimators = [
RandomForestClassifier(n_estimators=2, random_state=_RANDOM_STATE),
RandomForestRegressor(n_estimators=2, random_state=_RANDOM_STATE),
GradientBoostingClassifier(n_estimators=2, random_state=_RANDOM_STATE),
GradientBoostingRegressor(n_estimators=2, random_state=_RANDOM_STATE),
]
# TODO: which estimators work with SHAP's Explainer and KernelExplainer?
# other_estimators = [SVC(kernel="rbf", probability=False), SVR(kernel="linear")]
# other_estimators = [LogisticRegression(), LinearRegression()]
other_estimators: list[BaseEstimator] = []
n_bits = 64

# test explanations with different estimators
for estimator in estimators:
pipeline = Pipeline(
[
("smi2mol", SmilesToMol()),
("morgan", MolToMorganFP(radius=1, n_bits=n_bits)),
("model", estimator),
]
)
pipeline.fit(TEST_SMILES, CONTAINS_OX)
explainer_types = [SHAPExplainer, SHAPTreeExplainer]
explainer_estimators = [tree_estimators + other_estimators, tree_estimators]
# explainer_kwargs = [{}, {}]

explainer = SHAPTreeExplainer(pipeline)
explanations = explainer.explain(TEST_SMILES)
self.assertEqual(len(explanations), len(TEST_SMILES))
for estimators, explainer_type in zip(explainer_estimators, explainer_types):

# get the subpipeline that extracts the molecule from the input data
mol_reader_subpipeline = SubpipelineExtractor(
pipeline
).get_molecule_reader_subpipeline()
self.assertIsInstance(mol_reader_subpipeline, Pipeline)

for i, explanation in enumerate(explanations):
self._test_valid_explanation(
explanation,
estimator,
mol_reader_subpipeline, # type: ignore[arg-type]
n_bits,
TEST_SMILES[i],
is_morgan_fingerprint=True,
# test explanations with different estimators
for estimator in estimators:
pipeline = Pipeline(
[
("smi2mol", SmilesToMol()),
("morgan", MolToMorganFP(radius=1, n_bits=n_bits)),
("model", estimator),
]
)
pipeline.fit(TEST_SMILES, CONTAINS_OX)

explainer = explainer_type(pipeline)
explanations = explainer.explain(TEST_SMILES)
self.assertEqual(len(explanations), len(TEST_SMILES))

# get the subpipeline that extracts the molecule from the input data
mol_reader_subpipeline = SubpipelineExtractor(
pipeline
).get_molecule_reader_subpipeline()
self.assertIsInstance(mol_reader_subpipeline, Pipeline)

for i, explanation in enumerate(explanations):
self._test_valid_explanation(
explanation,
estimator,
mol_reader_subpipeline, # type: ignore[arg-type]
n_bits,
TEST_SMILES[i],
is_morgan_fingerprint=True,
)

# pylint: disable=too-many-locals
def test_explanations_pipeline_with_invalid_inputs(self) -> None:
