Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Nov 25, 2024
1 parent 0adb1ff commit ab06fac
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 30 deletions.
63 changes: 34 additions & 29 deletions molpipeline/explainability/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,15 @@ def explain(
class _SHAPExplainerAdapter(AbstractSHAPExplainer, abc.ABC):
"""Adapter for SHAP explainer wrappers for handling molecules and pipelines."""

return_type_: type[SHAPFeatureExplanation] | type[SHAPFeatureAndAtomExplanation]
return_element_type_: (
type[SHAPFeatureExplanation] | type[SHAPFeatureAndAtomExplanation]
)
return_type_: list[SHAPFeatureExplanation] | list[SHAPFeatureAndAtomExplanation]

def __init__(
self,
pipeline: Pipeline,
explainer: SHAPTreeExplainer | SHAPKernelExplainer,
explainer: shap.TreeExplainer | shap.KernelExplainer,
**kwargs: Any,
) -> None:
"""Initialize the SHAPTreeExplainer.
Expand All @@ -188,8 +191,8 @@ def __init__(
----------
pipeline : Pipeline
The pipeline containing the model to explain.
explainer : SHAPTreeExplainer | SHAPKernelExplainer
The explainer object.
explainer : shap.TreeExplainer | shap.KernelExplainer
The shap explainer object.
kwargs : Any
Additional keyword arguments for SHAP's TreeExplainer.
"""
Expand All @@ -215,10 +218,10 @@ def __init__(
# determine type of returned explanation
featurization_element = self.featurization_subpipeline.steps[-1][1] # type: ignore[union-attr]
if isinstance(featurization_element, MolToMorganFP):
self.return_type_ = SHAPFeatureAndAtomExplanation
self.return_element_type_ = SHAPFeatureAndAtomExplanation
self.has_atom_weights_ = True
else:
self.return_type_ = SHAPFeatureExplanation
self.return_element_type_ = SHAPFeatureExplanation
self.has_atom_weights_ = False

@staticmethod
Expand Down Expand Up @@ -249,9 +252,7 @@ def _prediction_is_valid(prediction: Any) -> bool:

# pylint: disable=C0103,W0613
@override
def explain(
self, X: Any, **kwargs: Any
) -> list[SHAPFeatureExplanation] | list[SHAPFeatureAndAtomExplanation]:
def explain(self, X: Any, **kwargs: Any) -> return_type_:
"""Explain the predictions for the input data.
If the calculation of the SHAP values for an input sample fails, the explanation will be invalid.
Expand All @@ -271,7 +272,7 @@ def explain(
"""
featurization_element = self.featurization_subpipeline.steps[-1][1] # type: ignore[union-attr]

explanation_results = []
explanation_results: _SHAPExplainerAdapter.return_type_ = []
for input_sample in X:

input_sample = [input_sample]
Expand All @@ -280,7 +281,7 @@ def explain(
prediction = _get_predictions(self.pipeline, input_sample)
if not self._prediction_is_valid(prediction):
# we use the prediction to check if the input is valid. If not, we cannot explain it.
explanation_results.append(self.return_type_())
explanation_results.append(self.return_element_type_())
continue

if prediction.ndim > 1:
Expand All @@ -298,7 +299,7 @@ def explain(
# if the feature vector is empty, we cannot explain the prediction.
# This happens for failed instances in pipeline with fill values
# that could be valid predictions, like 0.
explanation_results.append(self.return_type_())
explanation_results.append(self.return_element_type_())
continue

# Feature names should also be extracted from the Pipeline.
Expand All @@ -313,9 +314,9 @@ def explain(
atom_weights = None
bond_weights = None

if issubclass(self.return_type_, AtomExplanationMixin) and isinstance(
featurization_element, MolToMorganFP
):
if issubclass(
self.return_element_type_, AtomExplanationMixin
) and isinstance(featurization_element, MolToMorganFP):
# for Morgan fingerprint, we can map the shap values to atom weights
atom_weights = _convert_shap_feature_weights_to_atom_weights(
feature_weights,
Expand All @@ -329,19 +330,19 @@ def explain(
"molecule": molecule,
"prediction": prediction,
}
if issubclass(self.return_type_, FeatureInfoMixin):
if issubclass(self.return_element_type_, FeatureInfoMixin):
explanation_data["feature_vector"] = feature_vector
explanation_data["feature_names"] = feature_names
if issubclass(self.return_type_, FeatureExplanationMixin):
if issubclass(self.return_element_type_, FeatureExplanationMixin):
explanation_data["feature_weights"] = feature_weights
if issubclass(self.return_type_, AtomExplanationMixin):
if issubclass(self.return_element_type_, AtomExplanationMixin):
explanation_data["atom_weights"] = atom_weights
if issubclass(self.return_type_, BondExplanationMixin):
if issubclass(self.return_element_type_, BondExplanationMixin):
explanation_data["bond_weights"] = bond_weights
if issubclass(self.return_type_, SHAPExplanationMixin):
if issubclass(self.return_element_type_, SHAPExplanationMixin):
explanation_data["expected_value"] = self.explainer.expected_value

explanation_results.append(self.return_type_(**explanation_data))
explanation_results.append(self.return_element_type_(**explanation_data))

return explanation_results

Expand Down Expand Up @@ -373,24 +374,26 @@ def __init__(
kwargs : Any
Additional keyword arguments for SHAP's Explainer.
"""
explainer = self._create_explainer(**kwargs)
explainer = self._create_explainer(pipeline, **kwargs)
super().__init__(pipeline, explainer, **kwargs)

@staticmethod
def _create_explainer(self, **kwargs: Any) -> Any:
def _create_explainer(pipeline: Pipeline, **kwargs: Any) -> shap.TreeExplainer:
"""Create the TreeExplainer object from shap.
Parameters
----------
pipeline : Pipeline
The pipeline containing the model to explain.
kwargs : Any
Additional keyword arguments for the explainer.
Returns
-------
Any
shap.TreeExplainer
The explainer object.
"""
model = get_model_from_pipeline(self.pipeline, raise_not_found=True)
model = get_model_from_pipeline(pipeline, raise_not_found=True)
explainer = shap.TreeExplainer(
model,
**kwargs,
Expand All @@ -415,24 +418,26 @@ def __init__(
kwargs : Any
Additional keyword arguments for SHAP's Explainer.
"""
explainer = self._create_explainer(**kwargs)
explainer = self._create_explainer(pipeline, **kwargs)
super().__init__(pipeline, explainer, **kwargs)

@staticmethod
def _create_explainer(self, **kwargs: Any) -> Any:
def _create_explainer(pipeline: Pipeline, **kwargs: Any) -> shap.KernelExplainer:
"""Create the explainer object.
Parameters
----------
pipeline : Pipeline
The pipeline containing the model to explain.
kwargs : Any
Additional keyword arguments for the explainer.
Returns
-------
Any
shap.KernelExplainer
The explainer object.
"""
model = get_model_from_pipeline(self.pipeline, raise_not_found=True)
model = get_model_from_pipeline(pipeline, raise_not_found=True)
prediction_function = _get_prediction_function(model)
explainer = shap.KernelExplainer(
prediction_function,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_explainability/test_shap_explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def test_explanations_fingerprint_pipeline(self) -> None:

self.assertTrue(explainer.has_atom_weights_)
self.assertTrue(
issubclass(explainer.return_type_, AtomExplanationMixin)
issubclass(explainer.return_element_type_, AtomExplanationMixin)
)

# get the subpipeline that extracts the molecule from the input data
Expand Down

0 comments on commit ab06fac

Please sign in to comment.