diff --git a/daal4py/mb/model_builders.py b/daal4py/mb/model_builders.py index f949cff25f..a10b8f0d41 100644 --- a/daal4py/mb/model_builders.py +++ b/daal4py/mb/model_builders.py @@ -222,7 +222,9 @@ def _predict_regression( ) predict_result = predict_algo.compute(X, self.daal_model_) - if pred_interactions: + if pred_contribs: + return predict_result.prediction.ravel().reshape((-1, X.shape[1] + 1)) + elif pred_interactions: return predict_result.prediction.ravel().reshape( (-1, X.shape[1] + 1, X.shape[1] + 1) )