Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Aug 30, 2024
1 parent 07befa2 commit d956cef
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
2 changes: 1 addition & 1 deletion molpipeline/explainability/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from molpipeline import Pipeline
from molpipeline.abstract_pipeline_elements.core import OptionalMol
from molpipeline.explainability.explanation import Explanation, SHAPExplanation
from molpipeline.explainability.explanation import SHAPExplanation
from molpipeline.explainability.fingerprint_utils import fingerprint_shap_to_atomweights
from molpipeline.mol2any import MolToMorganFP
from molpipeline.utils.subpipeline import SubpipelineExtractor
Expand Down
2 changes: 1 addition & 1 deletion molpipeline/explainability/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,4 @@ class SHAPExplanation(Explanation):
This Explanation holds additional information only present in SHAP explanations.
"""

expected_value: npt.NDArray[np.float64]
expected_value: npt.NDArray[np.float64] = np.nan
18 changes: 13 additions & 5 deletions molpipeline/explainability/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,12 @@ def _add_gaussians_for_bonds(
def make_sum_of_gaussians_grid(
mol: Chem.Mol,
grid_resolution: Sequence[int],
padding: Sequence[float],
atom_weights: Sequence[float] | npt.NDArray[np.float64] | None = None,
bond_weights: Sequence[float] | npt.NDArray[np.float64] | None = None,
atom_width: float = 0.3,
bond_width: float = 0.25,
bond_length: float = 0.5,
padding: Sequence[float] | None = None,
) -> rdMolDraw2D:
"""Map weights of atoms and bonds to the drawing of a RDKit molecular depiction.
Expand All @@ -211,8 +211,10 @@ def make_sum_of_gaussians_grid(
----------
mol: Chem.Mol
RDKit molecule object which is displayed.
grid_resolution: Sequence[int] | None
grid_resolution: Sequence[int]
Number of pixels of x- and y-axis.
padding: Sequence[float]
Increase of heatmap size, relative to size of molecule.
atom_weights: Sequence[float] | npt.NDArray[np.float64] | None
Array of weights for atoms.
bond_weights: Sequence[float] | npt.NDArray[np.float64] | None
Expand All @@ -223,9 +225,6 @@ def make_sum_of_gaussians_grid(
Value for the width of displayed bond weights (perpendicular to bond-axis).
bond_length: float
Value for the length of displayed bond weights (along the bond-axis).
padding: Sequence[float] | None
Increase of heatmap size, relative to size of molecule. Usually the heatmap is increased by 100% in each axis
by padding 50% in each side.
Returns
-------
Expand Down Expand Up @@ -405,6 +404,15 @@ def structure_heatmap_shap(
Image
The image as PNG.
"""
if explanation.feature_weights is None:
raise ValueError("SHAPExplanation does not contain feature weights.")
if explanation.feature_vector is None:
raise ValueError("SHAPExplanation does not contain feature_vector.")
if explanation.molecule is None:
raise ValueError("SHAPExplanation does not contain molecule.")
if explanation.atom_weights is None:
raise ValueError("SHAPExplanation does not contain atom weights.")

present_shap = explanation.feature_weights[:, 1] * explanation.feature_vector
absent_shap = explanation.feature_weights[:, 1] * (1 - explanation.feature_vector)
sum_present_shap = sum(present_shap)
Expand Down

0 comments on commit d956cef

Please sign in to comment.