diff --git a/molpipeline/explainability/visualization/visualization.py b/molpipeline/explainability/visualization/visualization.py index b9ca9b0c..5276ec83 100644 --- a/molpipeline/explainability/visualization/visualization.py +++ b/molpipeline/explainability/visualization/visualization.py @@ -450,14 +450,29 @@ def structure_heatmap_shap( fig.colorbar(im, ax=ax, orientation="vertical", fraction=0.015, pad=0.0) - text = ( - f"$P(y=1|X) = {explanation.prediction[1]:.2f}$ =" - "\n" - "\n" - f" $expected \ value={explanation.expected_value[1]:.2f}$ + " # noqa: W605 # pylint: disable=W1401 - f"$features_{{present}}= {sum_present_shap:.2f}$ + " - f"$features_{{absent}}={sum_absent_shap:.2f}$" - ) + if isinstance(explanation.prediction, float): + # regression case + raise NotImplementedError("Regression case not yet implemented.") + elif isinstance(explanation.prediction, np.ndarray): + if len(explanation.prediction) == 2: + # binary classification case + text = ( + f"$P(y=1|X) = {explanation.prediction[1]:.2f}$ =" + "\n" + "\n" + f" $expected \ value={explanation.expected_value[1]:.2f}$ + " # noqa: W605 # pylint: disable=W1401 + f"$features_{{present}}= {sum_present_shap:.2f}$ + " + f"$features_{{absent}}={sum_absent_shap:.2f}$" + ) + else: + raise ValueError( + "Unsupported number of classes for prediction. Only binary classification is supported." + ) + else: + raise ValueError( + "Unsupported type for prediction. Expected float or np.ndarray." + ) + fig.text(0.5, 0.18, text, ha="center") image = plt_to_pil(fig)