diff --git a/lightautoml/ml_algo/dl_model.py b/lightautoml/ml_algo/dl_model.py index ac1dae5d..0460fa05 100644 --- a/lightautoml/ml_algo/dl_model.py +++ b/lightautoml/ml_algo/dl_model.py @@ -340,7 +340,7 @@ def _init_params_on_input(self, train_valid_iterator) -> dict: target = train_valid_iterator.train.target if params["n_out"] is None: - new_params["n_out"] = 1 if task_name != "multiclass" else np.max(target) + 1 + new_params["n_out"] = 1 if task_name != "multiclass" else (np.max(target) + 1).astype(int) new_params["n_out"] = target.shape[1] if task_name in ["multi:reg", "multilabel"] else new_params["n_out"] cat_dims = []