diff --git a/skyro/sktime.py b/skyro/sktime.py index 0e70758..98ae0a1 100644 --- a/skyro/sktime.py +++ b/skyro/sktime.py @@ -167,6 +167,9 @@ def _predict_proba(self, fh, X, marginal=True): if predictions.name is not None: warnings.warn(f"Name of the frame will be overwritten with '{columns}'!") + if len(columns) == 1: + columns = columns[0] + as_frame = predictions.to_dataframe(columns) if predictions.ndim > 2: