Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Aug 30, 2024
1 parent 6d5c2ae commit a7ff889
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions molpipeline/explainability/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,14 +450,29 @@ def structure_heatmap_shap(

fig.colorbar(im, ax=ax, orientation="vertical", fraction=0.015, pad=0.0)

text = (
f"$P(y=1|X) = {explanation.prediction[1]:.2f}$ ="
"\n"
"\n"
f" $expected \ value={explanation.expected_value[1]:.2f}$ + " # noqa: W605 # pylint: disable=W1401
f"$features_{{present}}= {sum_present_shap:.2f}$ + "
f"$features_{{absent}}={sum_absent_shap:.2f}$"
)
if isinstance(explanation.prediction, float):
# regression case
raise NotImplementedError("Regression case not yet implemented.")
elif isinstance(explanation.prediction, np.ndarray):
if len(explanation.prediction) == 2:
# binary classification case
text = (
f"$P(y=1|X) = {explanation.prediction[1]:.2f}$ ="
"\n"
"\n"
f" $expected \ value={explanation.expected_value[1]:.2f}$ + " # noqa: W605 # pylint: disable=W1401
f"$features_{{present}}= {sum_present_shap:.2f}$ + "
f"$features_{{absent}}={sum_absent_shap:.2f}$"
)
else:
raise ValueError(
"Unsupported number of classes for prediction. Only binary classification is supported."
)
else:
raise ValueError(
"Unsupported type for prediction. Expected float or np.ndarray."
)

fig.text(0.5, 0.18, text, ha="center")

image = plt_to_pil(fig)
Expand Down

0 comments on commit a7ff889

Please sign in to comment.