diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py index bc2b004fc..e7c5a5328 100644 --- a/neuralprophet/configure.py +++ b/neuralprophet/configure.py @@ -23,6 +23,24 @@ @dataclass class Model: lagged_reg_layers: Optional[List[int]] + quantiles: Optional[List[float]] = None + + def setup_quantiles(self): + # convert quantiles to empty list [] if None + if self.quantiles is None: + self.quantiles = [] + # assert quantiles is a list type + assert isinstance(self.quantiles, list), "Quantiles must be provided as list." + # check if quantiles are float values in (0, 1) + assert all( + 0 < quantile < 1 for quantile in self.quantiles + ), "The quantiles specified need to be floats in-between (0, 1)." + # sort the quantiles + self.quantiles.sort() + # check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index + self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)] + # 0 is the median quantile index + self.quantiles.insert(0, 0.5) @dataclass @@ -92,10 +110,11 @@ class Train: batch_size: Optional[int] loss_func: Union[str, torch.nn.modules.loss._Loss, Callable] optimizer: Union[str, Type[torch.optim.Optimizer]] - quantiles: List[float] = field(default_factory=list) + # quantiles: List[float] = field(default_factory=list) optimizer_args: dict = field(default_factory=dict) - scheduler: Optional[Type[torch.optim.lr_scheduler.OneCycleLR]] = None + scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = None scheduler_args: dict = field(default_factory=dict) + early_stopping: Optional[bool] = False newer_samples_weight: float = 1.0 newer_samples_start: float = 0.0 reg_delay_pct: float = 0.5 @@ -103,19 +122,19 @@ class Train: trend_reg_threshold: Optional[Union[bool, float]] = None n_data: int = field(init=False) loss_func_name: str = field(init=False) - lr_finder_args: dict = field(default_factory=dict) + pl_trainer_config: dict = field(default_factory=dict) def __post_init__(self): - # assert the uncertainty estimation params and then finalize the quantiles - self.set_quantiles() assert self.newer_samples_weight >= 1.0 assert self.newer_samples_start >= 0.0 assert self.newer_samples_start < 1.0 - self.set_loss_func() - self.set_optimizer() - self.set_scheduler() + # self.set_loss_func(self.quantiles) - def set_loss_func(self): + # called in TimeNet configure_optimizers: + # self.set_optimizer() + # self.set_scheduler() + + def set_loss_func(self, quantiles: List[float]): if isinstance(self.loss_func, str): if self.loss_func.lower() in ["smoothl1", "smoothl1loss", "huber"]: # keeping 'huber' for backwards compatiblility, though not identical @@ -135,25 +154,8 @@ def set_loss_func(self): self.loss_func_name = type(self.loss_func).__name__ else: raise NotImplementedError(f"Loss function {self.loss_func} not found") - if len(self.quantiles) > 1: - self.loss_func = PinballLoss(loss_func=self.loss_func, quantiles=self.quantiles) - - def set_quantiles(self): - # convert quantiles to empty list [] if None - if self.quantiles is None: - self.quantiles = [] - # assert quantiles is a list type - assert isinstance(self.quantiles, list), "Quantiles must be in a list format, not None or scalar." - # check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index - self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)] - # check if quantiles are float values in (0, 1) - assert all( - 0 < quantile < 1 for quantile in self.quantiles - ), "The quantiles specified need to be floats in-between (0, 1)." - # sort the quantiles - self.quantiles.sort() - # 0 is the median quantile index - self.quantiles.insert(0, 0.5) + if len(quantiles) > 1: + self.loss_func = PinballLoss(loss_func=self.loss_func, quantiles=quantiles) def set_auto_batch_epoch( self, @@ -182,51 +184,88 @@ def set_optimizer(self): """ Set the optimizer and optimizer args. If optimizer is a string, then it will be converted to the corresponding torch optimizer. The optimizer is not initialized yet as this is done in configure_optimizers in TimeNet. + + Parameters + ---------- + optimizer_name : int + Object provided to NeuralProphet as optimizer. + optimizer_args : dict + Arguments for the optimizer. + """ - self.optimizer, self.optimizer_args = utils_torch.create_optimizer_from_config( - self.optimizer, self.optimizer_args - ) + if isinstance(self.optimizer, str): + if self.optimizer.lower() == "adamw": + # Tends to overfit, but reliable + self.optimizer = torch.optim.AdamW + self.optimizer_args["weight_decay"] = 1e-3 + elif self.optimizer.lower() == "sgd": + # better validation performance, but diverges sometimes + self.optimizer = torch.optim.SGD + self.optimizer_args["momentum"] = 0.9 + self.optimizer_args["weight_decay"] = 1e-4 + else: + raise ValueError( + f"The optimizer name {self.optimizer} is not supported. Please pass the optimizer class." + ) + elif not issubclass(self.optimizer, torch.optim.Optimizer): + raise ValueError("The provided optimizer is not supported.") def set_scheduler(self): """ - Set the scheduler and scheduler args. + Set the scheduler and scheduler arg depending on the user selection. The scheduler is not initialized yet as this is done in configure_optimizers in TimeNet. """ - self.scheduler = torch.optim.lr_scheduler.OneCycleLR - self.scheduler_args.update( - { - "pct_start": 0.3, - "anneal_strategy": "cos", - "div_factor": 10.0, - "final_div_factor": 10.0, - "three_phase": True, - } - ) - def set_lr_finder_args(self, dataset_size, num_batches): - """ - Set the lr_finder_args. - This is the range of learning rates to test. - """ - num_training = 100 + int(np.log10(dataset_size) * 20) - if num_batches < num_training: - log.warning( - f"Learning rate finder: The number of batches ({num_batches}) is too small than the required number \ - for the learning rate finder ({num_training}). The results might not be optimal." - ) - # num_training = num_batches - self.lr_finder_args.update( - { - "min_lr": 1e-7, - "max_lr": 10, - "num_training": num_training, - "early_stop_threshold": None, - } - ) + if self.scheduler is None: + log.warning("No scheduler specified. Falling back to ExponentialLR scheduler.") + self.scheduler = "exponentiallr" + + if isinstance(self.scheduler, str): + if self.scheduler.lower() in ["onecycle", "onecyclelr"]: + self.scheduler = torch.optim.lr_scheduler.OneCycleLR + defaults = { + "pct_start": 0.3, + "anneal_strategy": "cos", + "div_factor": 10.0, + "final_div_factor": 10.0, + "three_phase": True, + } + elif self.scheduler.lower() == "steplr": + self.scheduler = torch.optim.lr_scheduler.StepLR + defaults = { + "step_size": 10, + "gamma": 0.1, + } + elif self.scheduler.lower() == "exponentiallr": + self.scheduler = torch.optim.lr_scheduler.ExponentialLR + defaults = { + "gamma": 0.9, + } + elif self.scheduler.lower() == "cosineannealinglr": + self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR + defaults = { + "T_max": 50, + } + elif self.scheduler.lower() == "cosineannealingwarmrestarts": + self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts + defaults = { + "T_0": 5, + "T_mult": 2, + } + else: + raise NotImplementedError( + f"Scheduler {self.scheduler} is not supported from string. Please pass the scheduler class." + ) + if self.scheduler_args is not None: + defaults.update(self.scheduler_args) + self.scheduler_args = defaults + else: + assert issubclass( + self.scheduler, torch.optim.lr_scheduler.LRScheduler + ), "Scheduler must be a subclass of torch.optim.lr_scheduler.LRScheduler" - def get_reg_delay_weight(self, e, iter_progress, reg_start_pct: float = 0.66, reg_full_pct: float = 1.0): + def get_reg_delay_weight(self, progress, reg_start_pct: float = 0.66, reg_full_pct: float = 1.0): # Ignore type warning of epochs possibly being None (does not work with dataclasses) - progress = (e + iter_progress) / float(self.epochs) # type: ignore if reg_start_pct == reg_full_pct: reg_progress = float(progress > reg_start_pct) else: @@ -239,6 +278,9 @@ def get_reg_delay_weight(self, e, iter_progress, reg_start_pct: float = 0.66, re delay_weight = 1 return delay_weight + def set_batches_per_epoch(self, batches_per_epoch: int): + self.batches_per_epoch = batches_per_epoch + @dataclass class Trend: diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py index 3039ca567..51f818377 100644 --- a/neuralprophet/forecaster.py +++ b/neuralprophet/forecaster.py @@ -11,10 +11,9 @@ import torch from matplotlib import pyplot from matplotlib.axes import Axes -from pytorch_lightning.tuner.tuning import Tuner from torch.utils.data import DataLoader -from neuralprophet import configure, df_utils, np_types, time_dataset, time_net, utils, utils_metrics +from neuralprophet import configure, df_utils, np_types, time_dataset, time_net, utils, utils_lightning, utils_metrics from neuralprophet.data.process import ( _check_dataframe, _convert_raw_predictions_to_raw_df, @@ -234,7 +233,7 @@ class NeuralProphet: Train Config COMMENT learning_rate : float - Maximum learning rate setting for 1cycle policy scheduler. + Maximum learning rate setting for lr scheduler. Note ---- @@ -299,6 +298,20 @@ class NeuralProphet: >>> # use custorm torchmetrics names >>> m = NeuralProphet(collect_metrics={"MAPE": "MeanAbsolutePercentageError", "MSLE": "MeanSquaredLogError", + scheduler : str, torch.optim.lr_scheduler._LRScheduler + Type of learning rate scheduler to use. + + Options + * (default) ``OneCycleLR``: One Cycle Learning Rate scheduler + * ``StepLR``: Step Learning Rate scheduler + * ``ExponentialLR``: Exponential Learning Rate scheduler + * ``CosineAnnealingLR``: Cosine Annealing Learning Rate scheduler + + Examples + -------- + >>> from neuralprophet import NeuralProphet + >>> m = NeuralProphet(scheduler="ExponentialLR", scheduler_args={"gamma": 0.8}) + COMMENT Uncertainty Estimation COMMENT @@ -362,7 +375,7 @@ class NeuralProphet: select an available accelerator. Provide `None` to deactivate the use of accelerators. trainer_config: dict - Dictionary of additional trainer configuration parameters. + Dictionary of additional Pytorch Lighning Trainer configuration parameters. prediction_frequency: dict Set a periodic interval in which forecasts should be made. @@ -432,9 +445,11 @@ def __init__( batch_size: Optional[int] = None, loss_func: Union[str, torch.nn.modules.loss._Loss, Callable] = "SmoothL1Loss", optimizer: Union[str, Type[torch.optim.Optimizer]] = "AdamW", + scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = "onecyclelr", + scheduler_args: Optional[dict] = None, newer_samples_weight: float = 2, newer_samples_start: float = 0.0, - quantiles: List[float] = [], + quantiles: Optional[List[float]] = None, impute_missing: bool = True, impute_linear: int = 10, impute_rolling: int = 10, @@ -445,7 +460,7 @@ def __init__( global_time_normalization: bool = True, unknown_data_normalization: bool = False, accelerator: Optional[str] = None, - trainer_config: dict = {}, + trainer_config: Optional[dict] = None, prediction_frequency: Optional[dict] = None, ): self.config = locals() @@ -487,7 +502,11 @@ def __init__( self.max_lags = self.n_lags # Model - self.config_model = configure.Model(lagged_reg_layers=lagged_reg_layers) + self.config_model = configure.Model( + lagged_reg_layers=lagged_reg_layers, + quantiles=quantiles, + ) + self.config_model.setup_quantiles() # Trend self.config_trend = configure.Trend( @@ -503,15 +522,17 @@ def __init__( # Training self.config_train = configure.Train( - quantiles=quantiles, learning_rate=learning_rate, + scheduler=scheduler, + scheduler_args=scheduler_args, epochs=epochs, batch_size=batch_size, loss_func=loss_func, optimizer=optimizer, newer_samples_weight=newer_samples_weight, newer_samples_start=newer_samples_start, - trend_reg_threshold=self.config_trend.trend_reg_threshold, + early_stopping=False, + pl_trainer_config=trainer_config, ) # Seasonality @@ -549,7 +570,6 @@ def __init__( # Pytorch Lightning Trainer self.accelerator = accelerator - self.trainer_config = trainer_config # set during prediction self.future_periods = None @@ -723,7 +743,7 @@ def add_events( upper_window : int the upper window for the events in the list of events regularization : float - optional scale for regularization strength + optional scale for regularization strength (try values ~0.00001-0.001) mode : str ``additive`` (default) or ``multiplicative``. @@ -782,7 +802,7 @@ def add_country_holidays( upper_window : int the upper window for all the country holidays regularization : float - optional scale for regularization strength + optional scale for regularization strength (try values ~0.00001-0.001) mode : str ``additive`` (default) or ``multiplicative``. """ @@ -913,9 +933,11 @@ def fit( metrics_log_dir: Optional[str] = None, progress: Optional[str] = "bar", checkpointing: bool = False, - continue_training: bool = False, num_workers: int = 0, deterministic: bool = False, + scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = None, + scheduler_args: Optional[dict] = None, + trainer_config: Optional[dict] = None, ): """Train, and potentially evaluate model. @@ -960,13 +982,18 @@ def fit( * `None` checkpointing : bool Flag whether to save checkpoints during training - continue_training : bool - Flag whether to continue training from the last checkpoint num_workers : int Number of workers for data loading. If 0, data will be loaded in the main process. Note: using multiple workers and therefore distributed training might significantly increase the training time since each batch needs to be copied to each worker for each epoch. Keeping all data on the main process might be faster for most datasets. + scheduler : str + Type of learning rate scheduler to use for continued training. If None, uses ExponentialLR as + default as specified in the model config. + Options + * ``StepLR``: Step Learning Rate scheduler + * ``ExponentialLR``: Exponential Learning Rate scheduler + * ``CosineAnnealingLR``: Cosine Annealing Learning Rate scheduler Returns ------- @@ -974,28 +1001,33 @@ def fit( metrics with training and potentially evaluation metrics """ if minimal: + # overrides these settings: checkpointing = False self.metrics = False progress = None if self.fitted: - raise RuntimeError("Model has been fitted already. Please initialize a new model to fit again.") + raise RuntimeError("Model has been fitted already.") - # Train Config overrides + # Train Configuration: overwrite self.config_train with user provided values + if learning_rate is not None: + self.config_train.learning_rate = learning_rate + if scheduler is not None: + self.config_train.scheduler = scheduler + if scheduler_args is not None: + self.config_train.scheduler_args = scheduler_args if epochs is not None: self.config_train.epochs = epochs - if batch_size is not None: self.config_train.batch_size = batch_size - - if learning_rate is not None: - self.config_train.learning_rate = learning_rate - + if trainer_config is not None: + self.config_train.pl_trainer_config = trainer_config if early_stopping is not None: - self.early_stopping = early_stopping + self.config_train.early_stopping = early_stopping + self.config_train.set_loss_func(quantiles=self.config_model.quantiles) - # Warning for early stopping and regularization - if early_stopping: + # Warnings + if self.config_train.early_stopping: reg_enabled = utils.check_for_regularization( [ self.config_seasonality, @@ -1073,8 +1105,6 @@ def fit( or any(value != 1 for value in self.num_seasonalities_modelled_dict.values()) ) - if self.fitted is True and not continue_training: - log.error("Model has already been fitted. Re-fitting may break or produce different results.") self.max_lags = df_utils.get_max_num_lags( n_lags=self.n_lags, config_lagged_regressors=self.config_lagged_regressors ) @@ -1093,7 +1123,6 @@ def fit( progress_bar_enabled=bool(progress), metrics_enabled=bool(self.metrics), checkpointing_enabled=checkpointing, - continue_training=continue_training, num_workers=num_workers, deterministic=deterministic, ) @@ -1118,7 +1147,6 @@ def fit( progress_bar_enabled=bool(progress), metrics_enabled=bool(self.metrics), checkpointing_enabled=checkpointing, - continue_training=continue_training, num_workers=num_workers, deterministic=deterministic, ) @@ -1204,7 +1232,7 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, a dates=dates, predicted=predicted, n_forecasts=self.n_forecasts, - quantiles=self.config_train.quantiles, + quantiles=self.config_model.quantiles, components=components, ) if auto_extend and periods_added[df_name] > 0: @@ -1219,7 +1247,7 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, a n_forecasts=self.n_forecasts, max_lags=self.max_lags, freq=self.data_freq, - quantiles=self.config_train.quantiles, + quantiles=self.config_model.quantiles, config_lagged_regressors=self.config_lagged_regressors, ) if auto_extend and periods_added[df_name] > 0: @@ -1261,9 +1289,12 @@ def test(self, df: pd.DataFrame, verbose: bool = True): config_seasonality=self.config_seasonality, predicting=False, ) - loader = self._init_val_loader(df) + df, _, _, _ = df_utils.prep_or_copy_df(df) + df = _normalize(df=df, config_normalization=self.config_normalization) + dataset = _create_dataset(self, df, predict_mode=False) + test_loader = DataLoader(dataset, batch_size=min(1024, len(dataset)), shuffle=False, drop_last=False) # Use Lightning to calculate metrics - val_metrics = self.trainer.test(self.model, dataloaders=loader, verbose=verbose) + val_metrics = self.trainer.test(self.model, dataloaders=test_loader, verbose=verbose) val_metrics_df = pd.DataFrame(val_metrics) # TODO Check whether supported by Lightning if not self.config_normalization.global_normalization: @@ -1860,7 +1891,7 @@ def predict_trend(self, df: pd.DataFrame, quantile: float = 0.5): else: meta_name_tensor = None - quantile_index = self.config_train.quantiles.index(quantile) + quantile_index = self.config_model.quantiles.index(quantile) trend = self.model.trend(t, meta_name_tensor).detach().numpy()[:, :, quantile_index].squeeze() data_params = self.config_normalization.get_data_params(df_name) @@ -1925,7 +1956,7 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5): for name in self.config_seasonality.periods: features = inputs["seasonalities"][name] - quantile_index = self.config_train.quantiles.index(quantile) + quantile_index = self.config_model.quantiles.index(quantile) y_season = torch.squeeze( self.model.seasonality.compute_fourier(features=features, name=name, meta=meta_name_tensor)[ :, :, quantile_index @@ -2057,7 +2088,7 @@ def plot( log.info(f"Plotting data from ID {df_name}") if forecast_in_focus is None: forecast_in_focus = self.highlight_forecast_step_n - if len(self.config_train.quantiles) > 1: + if len(self.config_model.quantiles) > 1: if (self.highlight_forecast_step_n) is None and ( self.n_forecasts > 1 or self.n_lags > 0 ): # rather query if n_forecasts >1 than n_lags>1 @@ -2097,7 +2128,7 @@ def plot( if plotting_backend.startswith("plotly"): return plot_plotly( fcst=fcst, - quantiles=self.config_train.quantiles, + quantiles=self.config_model.quantiles, xlabel=xlabel, ylabel=ylabel, figsize=tuple(x * 70 for x in figsize), @@ -2108,7 +2139,7 @@ def plot( else: return plot( fcst=fcst, - quantiles=self.config_train.quantiles, + quantiles=self.config_model.quantiles, ax=ax, xlabel=xlabel, ylabel=ylabel, @@ -2177,7 +2208,7 @@ def get_latest_forecast( elif include_history_data is True: fcst = fcst fcst = utils.fcst_df_to_latest_forecast( - fcst, self.config_train.quantiles, n_last=1 + include_previous_forecasts + fcst, self.config_model.quantiles, n_last=1 + include_previous_forecasts ) return fcst @@ -2246,7 +2277,7 @@ def plot_latest_forecast( else: fcst = fcst[fcst["ID"] == df_name].copy(deep=True) log.info(f"Plotting data from ID {df_name}") - if len(self.config_train.quantiles) > 1: + if len(self.config_model.quantiles) > 1: log.warning( "Plotting latest forecasts when uncertainty estimation enabled" " plots only the median quantile forecasts." @@ -2258,7 +2289,7 @@ def plot_latest_forecast( elif plot_history_data is True: fcst = fcst fcst = utils.fcst_df_to_latest_forecast( - fcst, self.config_train.quantiles, n_last=1 + include_previous_forecasts + fcst, self.config_model.quantiles, n_last=1 + include_previous_forecasts ) # Check whether a local or global plotting backend is set. @@ -2268,7 +2299,7 @@ def plot_latest_forecast( if plotting_backend.startswith("plotly"): return plot_plotly( fcst=fcst, - quantiles=self.config_train.quantiles, + quantiles=self.config_model.quantiles, ylabel=ylabel, xlabel=xlabel, figsize=tuple(x * 70 for x in figsize), @@ -2280,7 +2311,7 @@ def plot_latest_forecast( else: return plot( fcst=fcst, - quantiles=self.config_train.quantiles, + quantiles=self.config_model.quantiles, ax=ax, ylabel=ylabel, xlabel=xlabel, @@ -2446,7 +2477,7 @@ def plot_components( m=self, fcst=fcst, plot_configuration=valid_plot_configuration, - quantile=self.config_train.quantiles[0], # plot components only for median quantile + quantile=self.config_model.quantiles[0], # plot components only for median quantile figsize=figsize, df_name=df_name, one_period_per_season=one_period_per_season, @@ -2556,11 +2587,11 @@ def plot_parameters( if not (0 < quantile < 1): raise ValueError("The quantile selected needs to be a float in-between (0,1)") # ValueError if selected quantile is out of range - if quantile not in self.config_train.quantiles: + if quantile not in self.config_model.quantiles: raise ValueError("Selected quantile is not specified in the model configuration.") else: # plot parameters for median quantile if not specified - quantile = self.config_train.quantiles[0] + quantile = self.config_model.quantiles[0] # Validate components to be plotted valid_parameters_set = [ @@ -2628,13 +2659,9 @@ def plot_parameters( ) def _init_model(self): - """Build Pytorch model with configured hyperparamters. - - Returns - ------- - TimeNet model - """ - self.model = time_net.TimeNet( + """Build Pytorch model with configured hyperparamters.""" + model = time_net.TimeNet( + config_model=self.config_model, config_train=self.config_train, config_trend=self.config_trend, config_ar=self.config_ar, @@ -2656,10 +2683,10 @@ def _init_model(self): num_seasonalities_modelled_dict=self.num_seasonalities_modelled_dict, meta_used_in_model=self.meta_used_in_model, ) - log.debug(self.model) - return self.model + log.debug(model) + return model - def _init_train_loader(self, df, num_workers=0): + def _data_setup(self, df): """Executes data preparation steps and initiates training procedure. Parameters @@ -2673,66 +2700,38 @@ def _init_train_loader(self, df, num_workers=0): ------- torch DataLoader """ - df, _, _, _ = df_utils.prep_or_copy_df(df) # TODO: Can this call be avoided? - # if not self.fitted: - self.config_normalization.init_data_params( - df=df, - config_lagged_regressors=self.config_lagged_regressors, - config_regressors=self.config_regressors, - config_events=self.config_events, - config_seasonality=self.config_seasonality, - ) - - df = _normalize(df=df, config_normalization=self.config_normalization) - # if not self.fitted: - if self.config_trend.changepoints is not None: - # scale user-specified changepoint times - df_aux = pd.DataFrame({"ds": pd.Series(self.config_trend.changepoints)}) - - df_normalized = _normalize(df=df_aux, config_normalization=self.config_normalization) - self.config_trend.changepoints = df_normalized["t"].values # type: ignore - - # df_merged, _ = df_utils.join_dataframes(df) - # df_merged = df_merged.sort_values("ds") - # df_merged.drop_duplicates(inplace=True, keep="first", subset=["ds"]) - df_merged = df_utils.merge_dataframes(df) - self.config_seasonality = utils.set_auto_seasonalities(df_merged, config_seasonality=self.config_seasonality) - if self.config_country_holidays is not None: - self.config_country_holidays.init_holidays(df_merged) - - dataset = _create_dataset( - self, df, predict_mode=False, prediction_frequency=self.prediction_frequency - ) # needs to be called after set_auto_seasonalities - - # Determine the max_number of epochs - self.config_train.set_auto_batch_epoch(n_data=len(dataset)) + df, _, _, _ = df_utils.prep_or_copy_df(df) - loader = DataLoader( - dataset, - batch_size=self.config_train.batch_size, - shuffle=True, - num_workers=num_workers, - ) + if not self.fitted: + # Initialize data normalization parameters + self.config_normalization.init_data_params( + df=df, + config_lagged_regressors=self.config_lagged_regressors, + config_regressors=self.config_regressors, + config_events=self.config_events, + config_seasonality=self.config_seasonality, + ) - return loader + if not self.fitted: + # scale user-specified changepoint times + if self.config_trend.changepoints is not None: + df_aux = pd.DataFrame({"ds": pd.Series(self.config_trend.changepoints)}) + df_aux = _normalize(df=df_aux, config_normalization=self.config_normalization) + self.config_trend.changepoints = df_aux["t"].values - def _init_val_loader(self, df): - """Executes data preparation steps and initiates evaluation procedure. + # Apply normalization to data + df = _normalize(df=df, config_normalization=self.config_normalization) - Parameters - ---------- - df : pd.DataFrame - dataframe containing column ``ds``, ``y``, and optionally``ID`` with all data + if not self.fitted: + # Temporarily merge df to set auto seasaoanlities and country holidays + df_merged = df_utils.merge_dataframes(df) + self.config_seasonality = utils.set_auto_seasonalities( + df_merged, config_seasonality=self.config_seasonality + ) + if self.config_country_holidays is not None: + self.config_country_holidays.init_holidays(df_merged) - Returns - ------- - torch DataLoader - """ - df, _, _, _ = df_utils.prep_or_copy_df(df) - df = _normalize(df=df, config_normalization=self.config_normalization) - dataset = _create_dataset(self, df, predict_mode=False) - loader = DataLoader(dataset, batch_size=min(1024, len(dataset)), shuffle=False, drop_last=False) - return loader + return df def _train( self, @@ -2741,7 +2740,6 @@ def _train( progress_bar_enabled: bool = True, metrics_enabled: bool = False, checkpointing_enabled: bool = False, - continue_training=False, num_workers=0, deterministic: bool = False, ): @@ -2760,8 +2758,6 @@ def _train( whether to collect metrics during training checkpointing_enabled : bool whether to save checkpoints during training - continue_training : bool - whether to continue training from the last checkpoint num_workers : int number of workers for data loading @@ -2770,91 +2766,98 @@ def _train( pd.DataFrame metrics """ - # Set up data the training dataloader - df, _, _, _ = df_utils.prep_or_copy_df(df) # TODO: Can this call be removed? - train_loader = self._init_train_loader(df, num_workers) - dataset_size = len(df) # train_loader.dataset - - # Internal flag to check if validation is enabled - validation_enabled = df_val is not None - - # Init the model, if not continue from checkpoint - if continue_training: - raise NotImplementedError( - "Continuing training from checkpoint is not implemented yet. This feature is planned for one of the \ - upcoming releases." - ) - else: - self.model = self._init_model() + # Set up train dataset and data dependent configurations + df = self._data_setup(df) + # Note: _create_dataset() needs to be called after set_auto_seasonalities() + dataset = _create_dataset(self, df, predict_mode=False, prediction_frequency=self.prediction_frequency) + # Determine the max_number of epochs + self.config_train.set_auto_batch_epoch(n_data=len(dataset)) + + # Set up DataLoaders: Train + loader = DataLoader( + dataset, + batch_size=self.config_train.batch_size, + shuffle=True, + num_workers=num_workers, + ) + self.config_train.set_batches_per_epoch(len(loader)) + log.info(f"Train Dataset size: {len(dataset)}") + log.info(f"Number of batches per training epoch: {len(loader)}") - self.model.train_loader = train_loader + # Set up DataLoaders: Validation + validation_enabled = df_val is not None and isinstance(df_val, pd.DataFrame) + if validation_enabled: + df_val, _, _, _ = df_utils.prep_or_copy_df(df_val) + df_val = _normalize(df=df_val, config_normalization=self.config_normalization) + dataset_val = _create_dataset(self, df_val, predict_mode=False) + loader_val = DataLoader(dataset_val, batch_size=min(1024, len(dataset_val)), shuffle=False, drop_last=False) # Init the Trainer - self.trainer, checkpoint_callback = utils.configure_trainer( + self.trainer, checkpoint_callback = utils_lightning.configure_trainer( config_train=self.config_train, - config=self.trainer_config, metrics_logger=self.metrics_logger, - early_stopping=self.early_stopping, early_stopping_target="Loss_val" if validation_enabled else "Loss", accelerator=self.accelerator, progress_bar_enabled=progress_bar_enabled, metrics_enabled=metrics_enabled, checkpointing_enabled=checkpointing_enabled, - num_batches_per_epoch=len(train_loader), + num_batches_per_epoch=len(loader), deterministic=deterministic, ) - # Tune hyperparams and train - if validation_enabled: - # Set up data the validation dataloader - df_val, _, _, _ = df_utils.prep_or_copy_df(df_val) - val_loader = self._init_val_loader(df_val) - - if not continue_training and not self.config_train.learning_rate: - # Set parameters for the learning rate finder - self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=len(train_loader)) - # Find suitable learning rate - tuner = Tuner(self.trainer) - lr_finder = tuner.lr_find( - model=self.model, - train_dataloaders=train_loader, - # val_dataloaders=val_loader, # not be used, but may lead to Lightning bug if not provided - **self.config_train.lr_finder_args, - ) - # Estimate the optimal learning rate from the loss curve - assert lr_finder is not None - _, _, self.model.learning_rate = utils.smooth_loss_and_suggest(lr_finder) - start = time.time() - self.trainer.fit( - self.model, - train_loader, - val_loader, - ckpt_path=self.metrics_logger.checkpoint_path if continue_training else None, - ) - else: - if not continue_training and not self.config_train.learning_rate: - # Set parameters for the learning rate finder - self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=len(train_loader)) - # Find suitable learning rate - tuner = Tuner(self.trainer) - lr_finder = tuner.lr_find( - model=self.model, - train_dataloaders=train_loader, - **self.config_train.lr_finder_args, - ) - assert lr_finder is not None - # Estimate the optimal learning rate from the loss curve - _, _, self.model.learning_rate = utils.smooth_loss_and_suggest(lr_finder) - start = time.time() - self.trainer.fit( - self.model, - train_loader, - ckpt_path=self.metrics_logger.checkpoint_path if continue_training else None, - ) + # Set up the model for training + if not self.fitted: + self.model = self._init_model() - log.debug("Train Time: {:8.3f}".format(time.time() - start)) + # Find suitable learning rate if not set + if self.config_train.learning_rate is None: + assert not self.fitted, "Learning rate must be provided for re-training a fitted model." + + # Init a separate Model, Loader and Trainer copy for LR finder (optional, done for safety) + # Note Leads to a CUDA issue. Needs to be fixed before enabling this feature. + # model_lr_finder = self._init_model() + # loader_lr_finder = DataLoader( + # dataset, + # batch_size=self.config_train.batch_size, + # shuffle=True, + # num_workers=num_workers, + # ) + # trainer_lr_finder, _ = utils_lightning.configure_trainer( + # config_train=self.config_train, + # metrics_logger=self.metrics_logger, + # early_stopping_target="Loss", + # accelerator=self.accelerator, + # progress_bar_enabled=progress_bar_enabled, + # metrics_enabled=False, + # checkpointing_enabled=False, + # num_batches_per_epoch=len(loader), + # deterministic=deterministic, + # ) + + # Setup and execute LR finder + suggested_lr = utils_lightning.find_learning_rate( + model=self.model, # model_lr_finder, + loader=loader, # loader_lr_finder, + trainer=self.trainer, # trainer_lr_finder, + train_epochs=self.config_train.epochs, + ) + # Clean up the LR finder copies of Model, Loader and Trainer + # del model_lr_finder, loader_lr_finder, trainer_lr_finder + + # Save the suggested learning rate + self.config_train.learning_rate = suggested_lr + self.model.finding_lr = False + + # Execute Training Loop + start = time.time() + self.trainer.fit( + model=self.model, + train_dataloaders=loader, + val_dataloaders=loader_val if validation_enabled else None, + ) + log.info("Train Time: {:8.3f}".format(time.time() - start)) - # Load best model from training + # Load best model from checkpoint if end state not best if checkpoint_callback is not None: if checkpoint_callback.best_model_score < checkpoint_callback.current_score: log.info( @@ -2863,10 +2866,12 @@ def _train( ) self.model = time_net.TimeNet.load_from_checkpoint(checkpoint_callback.best_model_path) - if not metrics_enabled: - return None - # Return metrics collected in logger as dataframe - metrics_df = pd.DataFrame(self.metrics_logger.history) + if metrics_enabled: + # Return metrics collected in logger as dataframe + metrics_df = pd.DataFrame(self.metrics_logger.history) + else: + metrics_df = None + return metrics_df def restore_trainer(self, accelerator: Optional[str] = None): @@ -2878,11 +2883,9 @@ def restore_trainer(self, accelerator: Optional[str] = None): """ Restore the trainer based on the forecaster configuration. """ - self.trainer, _ = utils.configure_trainer( + self.trainer, _ = utils_lightning.configure_trainer( config_train=self.config_train, - config=self.trainer_config, metrics_logger=self.metrics_logger, - early_stopping=self.early_stopping, accelerator=accelerator, metrics_enabled=bool(self.metrics), ) @@ -3085,7 +3088,7 @@ def conformal_predict( alpha=alpha, method=method, n_forecasts=self.n_forecasts, - quantiles=self.config_train.quantiles, + quantiles=self.config_model.quantiles, ) df_forecast = c.predict(df=df_test, df_cal=df_cal, show_all_PI=show_all_PI) diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py index a4fbfee3a..1097c8e57 100644 --- a/neuralprophet/time_net.py +++ b/neuralprophet/time_net.py @@ -42,6 +42,7 @@ class TimeNet(pl.LightningModule): def __init__( self, + config_model: configure.Model, config_seasonality: configure.ConfigSeasonality, config_train: Optional[configure.Train] = None, config_trend: Optional[configure.Trend] = None, @@ -149,6 +150,7 @@ def __init__( pass # General + self.config_model = config_model self.n_forecasts = n_forecasts # Lightning Config @@ -156,14 +158,13 @@ def __init__( self.config_normalization = config_normalization self.compute_components_flag = compute_components_flag - # Optimizer and LR Scheduler - self._optimizer = self.config_train.optimizer - self._scheduler = self.config_train.scheduler + # Manual optimization: we are responsible for calling .backward(), .step(), .zero_grad(). self.automatic_optimization = False # Hyperparameters (can be tuned using trainer.tune()) - self.learning_rate = self.config_train.learning_rate if self.config_train.learning_rate is not None else 1e-3 + self.learning_rate = self.config_train.learning_rate self.batch_size = self.config_train.batch_size + self.finding_lr = False # flag to indicate if we are in lr finder mode # Metrics Config self.metrics_enabled = bool(metrics) # yields True if metrics is not an empty dictionary @@ -200,7 +201,7 @@ def __init__( ) # Quantiles - self.quantiles = self.config_train.quantiles + self.quantiles = self.config_model.quantiles # Trend self.config_trend = config_trend @@ -761,20 +762,21 @@ def loss_func(self, inputs, predicted, targets): loss = None # Compute loss. no reduction. loss = self.config_train.loss_func(predicted, targets) - # Weigh newer samples more. - loss = loss * self._get_time_based_sample_weight(t=inputs["time"][:, self.n_lags :]) + if self.config_train.newer_samples_weight > 1.0: + # Weigh newer samples more. + loss = loss * self._get_time_based_sample_weight(t=inputs["time"][:, self.n_lags :]) loss = loss.sum(dim=2).mean() # Regularize. - if self.reg_enabled: - steps_per_epoch = math.ceil(self.trainer.estimated_stepping_batches / self.trainer.max_epochs) - progress_in_epoch = 1 - ((steps_per_epoch * (self.current_epoch + 1) - self.global_step) / steps_per_epoch) - loss, reg_loss = self._add_batch_regularizations(loss, self.current_epoch, progress_in_epoch) + if self.reg_enabled and not self.finding_lr: + loss, reg_loss = self._add_batch_regularizations(loss, self.train_progress) else: reg_loss = torch.tensor(0.0, device=self.device) return loss, reg_loss def training_step(self, batch, batch_idx): inputs, targets, meta = batch + epoch_float = self.trainer.current_epoch + batch_idx / float(self.train_steps_per_epoch) + self.train_progress = epoch_float / float(self.config_train.epochs) # Global-local if self.meta_used_in_model: meta_name_tensor = torch.tensor([self.id_dict[i] for i in meta["df_name"]], device=self.device) @@ -794,19 +796,25 @@ def training_step(self, batch, batch_idx): optimizer.step() scheduler = self.lr_schedulers() - scheduler.step() + if self.finding_lr: + scheduler.step() + else: + scheduler.step(epoch=epoch_float) - # Manually track the loss for the lr finder - self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) - self.log("reg_loss", reg_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) + if self.finding_lr: + # Manually track the loss for the lr finder + self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) + # self.log("reg_loss", reg_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) # Metrics - if self.metrics_enabled: + if self.metrics_enabled and not self.finding_lr: predicted_denorm = self.denormalize(predicted[:, :, 0]) target_denorm = self.denormalize(targets.squeeze(dim=2)) self.log_dict(self.metrics_train(predicted_denorm, target_denorm), **self.log_args) self.log("Loss", loss, **self.log_args) self.log("RegLoss", reg_loss, **self.log_args) + # self.log("TrainProgress", self.train_progress, **self.log_args) + self.log("LR", scheduler.get_last_lr()[0], **self.log_args) return loss def validation_step(self, batch, batch_idx): @@ -861,48 +869,65 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): return prediction, components def configure_optimizers(self): + self.train_steps_per_epoch = self.config_train.batches_per_epoch + # self.trainer.num_training_batches = self.train_steps_per_epoch * self.config_train.epochs + + self.config_train.set_optimizer() + self.config_train.set_scheduler() + # Optimizer - optimizer = self._optimizer(self.parameters(), lr=self.learning_rate, **self.config_train.optimizer_args) + if self.finding_lr and self.learning_rate is None: + self.learning_rate = 0.1 + optimizer = self.config_train.optimizer( + self.parameters(), + lr=self.learning_rate, + **self.config_train.optimizer_args, + ) # Scheduler - lr_scheduler = self._scheduler( - optimizer, - max_lr=self.learning_rate, - total_steps=self.trainer.estimated_stepping_batches, - **self.config_train.scheduler_args, - ) + if self.config_train.scheduler == torch.optim.lr_scheduler.OneCycleLR: + lr_scheduler = self.config_train.scheduler( + optimizer, + max_lr=self.learning_rate, + # total_steps=self.trainer.estimated_stepping_batches, # if using self.lr_schedulers().step() + total_steps=self.config_train.epochs, # if using self.lr_schedulers().step(epoch=epoch_float) + **self.config_train.scheduler_args, + ) + else: + lr_scheduler = self.config_train.scheduler( + optimizer, + **self.config_train.scheduler_args, + ) return {"optimizer": optimizer, "lr_scheduler": lr_scheduler} def _get_time_based_sample_weight(self, t): - weight = torch.ones_like(t) - if self.config_train.newer_samples_weight > 1.0: - end_w = self.config_train.newer_samples_weight - start_t = self.config_train.newer_samples_start - time = (t.detach() - start_t) / (1.0 - start_t) - time = torch.clamp(time, 0.0, 1.0) # time = 0 to 1 - time = np.pi * (time - 1.0) # time = -pi to 0 - time = 0.5 * torch.cos(time) + 0.5 # time = 0 to 1 - # scales end to be end weight times bigger than start weight - # with end weight being 1.0 - weight = (1.0 + time * (end_w - 1.0)) / end_w - return weight.unsqueeze(dim=2) # add an extra dimension for the quantiles - - def _add_batch_regularizations(self, loss, epoch, progress): + end_w = self.config_train.newer_samples_weight + start_t = self.config_train.newer_samples_start + time = (t.detach() - start_t) / (1.0 - start_t) + time = torch.clamp(time, 0.0, 1.0) # time = 0 to 1 + time = np.pi * (time - 1.0) # time = -pi to 0 + time = 0.5 * torch.cos(time) + 0.5 # time = 0 to 1 + # scales end to be end weight times bigger than start weight + # with end weight being 1.0 + weight = (1.0 + time * (end_w - 1.0)) / end_w + # add an extra dimension for the quantiles + weight = weight.unsqueeze(dim=2) + return weight + + def _add_batch_regularizations(self, loss, progress): """Add regularization terms to loss, if applicable Parameters ---------- loss : torch.Tensor, scalar current batch loss - epoch : int - current epoch number progress : float - progress within the epoch, between 0 and 1 + progress within training, across all epochs and batches, between 0 and 1 Returns ------- loss, reg_loss """ - delay_weight = self.config_train.get_reg_delay_weight(epoch, progress) + delay_weight = self.config_train.get_reg_delay_weight(progress) reg_loss = torch.zeros(1, dtype=torch.float, requires_grad=False, device=self.device) if delay_weight > 0: @@ -985,8 +1010,8 @@ def denormalize(self, ts): ts = scale_y * ts + shift_y return ts - def train_dataloader(self): - return self.train_loader + # def train_dataloader(self): + # return self.train_loader class FlatNet(nn.Module): diff --git a/neuralprophet/utils.py b/neuralprophet/utils.py index 62b9e7481..10fa63f43 100644 --- a/neuralprophet/utils.py +++ b/neuralprophet/utils.py @@ -749,241 +749,3 @@ def set_log_level(log_level: str = "INFO", include_handlers: bool = False): >>> set_log_level("ERROR") """ set_logger_level(logging.getLogger("NP"), log_level, include_handlers) - - -def smooth_loss_and_suggest(lr_finder, window=10): - """ - Smooth loss using a Hamming filter. - - Parameters - ---------- - loss : np.array - Loss values - - Returns - ------- - loss_smoothed : np.array - Smoothed loss values - lr: np.array - Learning rate values - suggested_lr: float - Suggested learning rate based on gradient - """ - lr_finder_results = lr_finder.results - lr = lr_finder_results["lr"] - loss = lr_finder_results["loss"] - # Derive window size from num lr searches, ensure window is divisible by 2 - # half_window = math.ceil(round(len(loss) * 0.1) / 2) - half_window = math.ceil(window / 2) - # Pad sequence and initialialize hamming filter - loss = np.pad(np.array(loss), pad_width=half_window, mode="edge") - window = np.hamming(half_window * 2) - # Convolve the over the loss distribution - try: - loss = np.convolve( - window / window.sum(), - loss, - mode="valid", - )[1:] - except ValueError: - log.warning( - f"The number of loss values ({len(loss)}) is too small to apply smoothing with a the window size of " - f"{window}." - ) - # Suggest the lr with steepest negative gradient - try: - # Find the steepest gradient and the minimum loss after that - suggestion = lr[np.argmin(np.gradient(loss))] - except ValueError: - log.error( - f"The number of loss values ({len(loss)}) is too small to estimate a learning rate. Increase the number of " - "samples or manually set the learning rate." - ) - raise - suggestion_default = lr_finder.suggestion(skip_begin=10, skip_end=3) - if suggestion is not None and suggestion_default is not None: - log_suggestion_smooth = np.log(suggestion) - log_suggestion_default = np.log(suggestion_default) - lr_suggestion = np.exp((log_suggestion_smooth + log_suggestion_default) / 2) - elif suggestion is None and suggestion_default is None: - log.error("Automatic learning rate test failed. Please set manually the learning rate.") - raise - else: - lr_suggestion = suggestion if suggestion is not None else suggestion_default - return (loss, lr, lr_suggestion) - - -def _smooth_loss(loss, beta=0.9): - smoothed_loss = np.zeros_like(loss) - smoothed_loss[0] = loss[0] - for i in range(1, len(loss)): - smoothed_loss[i] = smoothed_loss[i - 1] * beta + (1 - beta) * loss[i] - return smoothed_loss - - -def configure_trainer( - config_train: Train, - config: dict, - metrics_logger, - early_stopping: bool = False, - early_stopping_target: str = "Loss", - accelerator: Optional[str] = None, - progress_bar_enabled: bool = True, - metrics_enabled: bool = False, - checkpointing_enabled: bool = False, - num_batches_per_epoch: int = 100, - deterministic: bool = False, -): - """ - Configures the PyTorch Lightning trainer. - - Parameters - ---------- - config_train : Dict - dictionary containing the overall training configuration. - config : dict - dictionary containing the custom PyTorch Lightning trainer configuration. - metrics_logger : MetricsLogger - MetricsLogger object to log metrics to. - early_stopping: bool - If True, early stopping is enabled. - early_stopping_target : str - Target metric to use for early stopping. - accelerator : str - Accelerator to use for training. - progress_bar_enabled : bool - If False, no progress bar is shown. - metrics_enabled : bool - If False, no metrics are logged. Calculating metrics is computationally expensive and reduces the training - speed. - checkpointing_enabled : bool - If False, no checkpointing is performed. Checkpointing reduces the training speed. - num_batches_per_epoch : int - Number of batches per epoch. - - Returns - ------- - pl.Trainer - PyTorch Lightning trainer - checkpoint_callback - PyTorch Lightning checkpoint callback to load the best model - """ - config = config.copy() - - # Set max number of epochs - if hasattr(config_train, "epochs"): - if config_train.epochs is not None: - config["max_epochs"] = config_train.epochs - - # Configure the Ligthing-logs directory - if "default_root_dir" not in config.keys(): - config["default_root_dir"] = os.getcwd() - - # Accelerator - if isinstance(accelerator, str): - if (accelerator == "auto" and torch.cuda.is_available()) or accelerator == "gpu": - config["accelerator"] = "gpu" - config["devices"] = -1 - elif (accelerator == "auto" and hasattr(torch.backends, "mps")) or accelerator == "mps": - if torch.backends.mps.is_available(): - config["accelerator"] = "mps" - config["devices"] = 1 - elif accelerator != "auto": - config["accelerator"] = accelerator - config["devices"] = 1 - - if "accelerator" in config: - log.info(f"Using accelerator {config['accelerator']} with {config['devices']} device(s).") - else: - log.info("No accelerator available. Using CPU for training.") - - # Configure metrics - if metrics_enabled: - config["logger"] = metrics_logger - else: - config["logger"] = False - - config["deterministic"] = deterministic - - # Configure callbacks - callbacks = [] - has_custom_callbacks = True if "callbacks" in config else False - - # Configure checkpointing - has_modelcheckpoint_callback = ( - True - if has_custom_callbacks - and any(isinstance(callback, pl.callbacks.ModelCheckpoint) for callback in config["callbacks"]) - else False - ) - if has_modelcheckpoint_callback and not checkpointing_enabled: - raise ValueError( - "Checkpointing is disabled but a ModelCheckpoint callback is provided. Please enable checkpointing or " - "remove the callback." - ) - if checkpointing_enabled: - if not has_modelcheckpoint_callback: - # Callback to access both the last and best model - checkpoint_callback = pl.callbacks.ModelCheckpoint( - monitor=early_stopping_target, mode="min", save_top_k=1, save_last=True - ) - callbacks.append(checkpoint_callback) - else: - checkpoint_callback = next( - callback for callback in config["callbacks"] if isinstance(callback, pl.callbacks.ModelCheckpoint) - ) - else: - config["enable_checkpointing"] = False - checkpoint_callback = None - - # Configure the progress bar, refresh every epoch - has_progressbar_callback = ( - True - if has_custom_callbacks - and any(isinstance(callback, pl.callbacks.ProgressBar) for callback in config["callbacks"]) - else False - ) - if has_progressbar_callback and not progress_bar_enabled: - raise ValueError( - "Progress bar is disabled but a ProgressBar callback is provided. Please enable the progress bar or remove" - " the callback." - ) - if progress_bar_enabled: - if not has_progressbar_callback: - prog_bar_callback = ProgressBar(refresh_rate=num_batches_per_epoch, epochs=config_train.epochs) - callbacks.append(prog_bar_callback) - else: - config["enable_progress_bar"] = False - - # Early stopping monitor - has_earlystopping_callback = ( - True - if has_custom_callbacks - and any(isinstance(callback, pl.callbacks.EarlyStopping) for callback in config["callbacks"]) - else False - ) - if has_earlystopping_callback and not early_stopping: - raise ValueError( - "Early stopping is disabled but an EarlyStopping callback is provided. Please enable early stopping or " - "remove the callback." - ) - if early_stopping: - if not metrics_enabled: - raise ValueError("Early stopping requires metrics to be enabled.") - if not has_earlystopping_callback: - early_stop_callback = pl.callbacks.EarlyStopping( - monitor=early_stopping_target, mode="min", patience=20, divergence_threshold=5.0 - ) - callbacks.append(early_stop_callback) - - if has_custom_callbacks: - config["callbacks"].extend(callbacks) - else: - config["callbacks"] = callbacks - config["num_sanity_val_steps"] = 0 - config["enable_model_summary"] = False - # TODO: Disabling sampler_ddp brings a good speedup in performance, however, check whether this is a good idea - # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#replace-sampler-ddp - # config["replace_sampler_ddp"] = False - - return pl.Trainer(**config), checkpoint_callback diff --git a/neuralprophet/utils_lightning.py b/neuralprophet/utils_lightning.py new file mode 100644 index 000000000..2ecdb3eb3 --- /dev/null +++ b/neuralprophet/utils_lightning.py @@ -0,0 +1,307 @@ +import logging +import math +import os +from typing import Optional + +import numpy as np +import pytorch_lightning as pl +import torch +from pytorch_lightning.tuner.tuning import Tuner + +from neuralprophet.configure import Train +from neuralprophet.logger import ProgressBar + +log = logging.getLogger("NP.utils_lightning") + + +def smooth_loss_and_suggest(lr_finder, window=10): + """ + Smooth loss using a Hamming filter. + + Parameters + ---------- + loss : np.array + Loss values + + Returns + ------- + loss_smoothed : np.array + Smoothed loss values + lr: np.array + Learning rate values + suggested_lr: float + Suggested learning rate based on gradient + """ + lr_finder_results = lr_finder.results + lr = lr_finder_results["lr"] + loss = np.array(lr_finder_results["loss"]) + # Derive window size from num lr searches, ensure window is divisible by 2 + # half_window = math.ceil(round(len(loss) * 0.1) / 2) + half_window = math.ceil(window / 2) + # Pad sequence and initialialize hamming filter + loss = np.pad(loss, pad_width=half_window, mode="edge") + hamming_window = np.hamming(2 * half_window) + # Convolve the over the loss distribution + try: + loss_smooth = np.convolve( + hamming_window / hamming_window.sum(), + loss, + mode="valid", + )[1:] + except ValueError: + log.warning( + f"The number of loss values ({len(loss)}) is too small to apply smoothing with a the window size of " + f"{window}." + ) + + # Suggest the lr with steepest negative gradient + try: + # Find the steepest gradient and the minimum loss after that + suggestion_steepest = lr[np.argmin(np.gradient(loss_smooth))] + suggestion_minimum = lr[np.argmin(np.array(lr_finder_results["loss"]))] + except ValueError: + log.error( + f"The number of loss values ({len(loss)}) is too small to estimate a learning rate. Increase the number of " + "samples or manually set the learning rate." + ) + raise + # get the tuner's default suggestion + suggestion_default = lr_finder.suggestion(skip_begin=20, skip_end=10) + + log.info(f"Learning rate finder ---- default suggestion: {suggestion_default}") + log.info(f"Learning rate finder ---- steepest: {suggestion_steepest}") + log.info(f"Learning rate finder ---- minimum (not used): {suggestion_minimum}") + if suggestion_steepest is not None and suggestion_default is not None: + log_suggestion_smooth = np.log(suggestion_steepest) + log_suggestion_default = np.log(suggestion_default) + lr_suggestion = np.exp((log_suggestion_smooth + log_suggestion_default) / 2) + log.info(f"Learning rate finder ---- log-avg: {lr_suggestion}") + elif suggestion_steepest is None and suggestion_default is None: + log.error("Automatic learning rate test failed. Please set manually the learning rate.") + raise + else: + lr_suggestion = suggestion_steepest if suggestion_steepest is not None else suggestion_default + + log.info(f"Learning rate finder ---- returning: {lr_suggestion}") + log.info(f"Learning rate finder ---- LR (start): {lr[:5]}") + log.info(f"Learning rate finder ---- LR (end): {lr[-5:]}") + log.info(f"Learning rate finder ---- LOSS (start): {loss[:5]}") + log.info(f"Learning rate finder ---- LOSS (end): {loss[-5:]}") + return loss, lr, lr_suggestion + + +def _smooth_loss(loss, beta=0.9): + smoothed_loss = np.zeros_like(loss) + smoothed_loss[0] = loss[0] + for i in range(1, len(loss)): + smoothed_loss[i] = smoothed_loss[i - 1] * beta + (1 - beta) * loss[i] + return smoothed_loss + + +def configure_trainer( + config_train: Train, + metrics_logger, + early_stopping_target: str = "Loss", + accelerator: Optional[str] = None, + progress_bar_enabled: bool = True, + metrics_enabled: bool = False, + checkpointing_enabled: bool = False, + num_batches_per_epoch: int = 100, + deterministic: bool = False, +): + """ + Configures the PyTorch Lightning trainer. + + Parameters + ---------- + config_train : Dict + dictionary containing the overall training configuration. + metrics_logger : MetricsLogger + MetricsLogger object to log metrics to. + early_stopping_target : str + Target metric to use for early stopping. + accelerator : str + Accelerator to use for training. + progress_bar_enabled : bool + If False, no progress bar is shown. + metrics_enabled : bool + If False, no metrics are logged. Calculating metrics is computationally expensive and reduces the training + speed. + checkpointing_enabled : bool + If False, no checkpointing is performed. Checkpointing reduces the training speed. + num_batches_per_epoch : int + Number of batches per epoch. + + Returns + ------- + pl.Trainer + PyTorch Lightning trainer + checkpoint_callback + PyTorch Lightning checkpoint callback to load the best model + """ + if config_train.pl_trainer_config is None: + config_train.pl_trainer_config = {} + + pl_trainer_config = config_train.pl_trainer_config + # pl_trainer_config = pl_trainer_config.copy() + + # Set max number of epochs + assert hasattr(config_train, "epochs") and config_train.epochs is not None + pl_trainer_config["max_epochs"] = config_train.epochs + + # Configure the Ligthing-logs directory + if "default_root_dir" not in pl_trainer_config.keys(): + pl_trainer_config["default_root_dir"] = os.getcwd() + + # Accelerator + if isinstance(accelerator, str): + if (accelerator == "auto" and torch.cuda.is_available()) or accelerator == "gpu": + pl_trainer_config["accelerator"] = "gpu" + pl_trainer_config["devices"] = -1 + elif (accelerator == "auto" and hasattr(torch.backends, "mps")) or accelerator == "mps": + if torch.backends.mps.is_available(): + pl_trainer_config["accelerator"] = "mps" + pl_trainer_config["devices"] = 1 + elif accelerator != "auto": + pl_trainer_config["accelerator"] = accelerator + pl_trainer_config["devices"] = 1 + + if "accelerator" in pl_trainer_config: + log.info(f"Using accelerator {pl_trainer_config['accelerator']} with {pl_trainer_config['devices']} device(s).") + elif accelerator == "auto": + log.info("No accelerator available. Using CPU for training.") + + # Configure metrics + if metrics_enabled: + pl_trainer_config["logger"] = metrics_logger + else: + pl_trainer_config["logger"] = False + + pl_trainer_config["deterministic"] = deterministic + + # Configure callbacks + callbacks = [] + has_custom_callbacks = True if "callbacks" in pl_trainer_config else False + + # Configure checkpointing + has_modelcheckpoint_callback = ( + True + if has_custom_callbacks + and any(isinstance(callback, pl.callbacks.ModelCheckpoint) for callback in pl_trainer_config["callbacks"]) + else False + ) + if has_modelcheckpoint_callback and not checkpointing_enabled: + raise ValueError( + "Checkpointing is disabled but a ModelCheckpoint callback is provided. Please enable checkpointing or " + "remove the callback." + ) + if checkpointing_enabled: + if not has_modelcheckpoint_callback: + # Callback to access both the last and best model + checkpoint_callback = pl.callbacks.ModelCheckpoint( + monitor=early_stopping_target, mode="min", save_top_k=1, save_last=True + ) + callbacks.append(checkpoint_callback) + else: + checkpoint_callback = next( + callback + for callback in pl_trainer_config["callbacks"] + if isinstance(callback, pl.callbacks.ModelCheckpoint) + ) + else: + pl_trainer_config["enable_checkpointing"] = False + checkpoint_callback = None + + # Configure the progress bar, refresh every epoch + has_progressbar_callback = ( + True + if has_custom_callbacks + and any(isinstance(callback, pl.callbacks.ProgressBar) for callback in pl_trainer_config["callbacks"]) + else False + ) + if has_progressbar_callback and not progress_bar_enabled: + raise ValueError( + "Progress bar is disabled but a ProgressBar callback is provided. Please enable the progress bar or remove" + " the callback." + ) + if progress_bar_enabled: + if not has_progressbar_callback: + prog_bar_callback = ProgressBar(refresh_rate=num_batches_per_epoch, epochs=config_train.epochs) + callbacks.append(prog_bar_callback) + else: + pl_trainer_config["enable_progress_bar"] = False + + # Early stopping monitor + has_earlystopping_callback = ( + True + if has_custom_callbacks + and any(isinstance(callback, pl.callbacks.EarlyStopping) for callback in pl_trainer_config["callbacks"]) + else False + ) + if has_earlystopping_callback and not config_train.early_stopping: + raise ValueError( + "Early stopping is disabled but an EarlyStopping callback is provided. Please enable early stopping or " + "remove the callback." + ) + if config_train.early_stopping: + if not metrics_enabled: + raise ValueError("Early stopping requires metrics to be enabled.") + if not has_earlystopping_callback: + early_stop_callback = pl.callbacks.EarlyStopping( + monitor=early_stopping_target, mode="min", patience=20, divergence_threshold=5.0 + ) + callbacks.append(early_stop_callback) + + if has_custom_callbacks: + pl_trainer_config["callbacks"].extend(callbacks) + else: + pl_trainer_config["callbacks"] = callbacks + pl_trainer_config["num_sanity_val_steps"] = 0 + pl_trainer_config["enable_model_summary"] = False + # TODO: Disabling sampler_ddp brings a good speedup in performance, however, check whether this is a good idea + # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#replace-sampler-ddp + # config["replace_sampler_ddp"] = False + + return pl.Trainer(**pl_trainer_config), checkpoint_callback + + +def find_learning_rate(model, loader, trainer, train_epochs): + log.info("No Learning Rate provided. Activating learning rate finder") + + # Configure the learning rate finder args + batches_per_epoch = len(loader) + main_training_total_steps = train_epochs * batches_per_epoch + # main_training_total_steps is around 1e3 to 1e6 -> num_training 100 to 400 + num_training = 100 + int(np.log10(1 + main_training_total_steps / 1000) * 100) + if batches_per_epoch < num_training: + log.warning( + f"Learning rate finder: The number of batches per epoch ({batches_per_epoch}) is too small than the required number \ + for the learning rate finder ({num_training}). The results might not be optimal." + ) + # num_training = num_batches + lr_finder_args = { + "min_lr": 1e-7, + "max_lr": 1e1, + "num_training": num_training, + "early_stop_threshold": None, + "mode": "exponential", + } + log.info(f"Learning rate finder ---- ARGs: {lr_finder_args}") + + # Execute the learning rate range finder + tuner = Tuner(trainer) + model.finding_lr = True + # model.train_loader = loader + lr_finder = tuner.lr_find( + model=model, + train_dataloaders=loader, + # val_dataloaders=val_loader, # not used, but lead to Lightning bug if not provided in prior versions. + **lr_finder_args, + ) + model.finding_lr = False + + # Estimate the optimal learning rate from the loss curve + assert lr_finder is not None + loss_list, lr_list, lr_suggested = smooth_loss_and_suggest(lr_finder) + log.info(f"Learning rate finder suggested learning rate: {lr_suggested}") + return lr_suggested diff --git a/tests/metrics/debug-energy-price-daily.ipynb b/tests/debug/debug-energy-price-daily.ipynb similarity index 100% rename from tests/metrics/debug-energy-price-daily.ipynb rename to tests/debug/debug-energy-price-daily.ipynb diff --git a/tests/metrics/debug-energy-price-hourly.ipynb b/tests/debug/debug-energy-price-hourly.ipynb similarity index 52% rename from tests/metrics/debug-energy-price-hourly.ipynb rename to tests/debug/debug-energy-price-hourly.ipynb index 14a09c93e..f78de7a04 100644 --- a/tests/metrics/debug-energy-price-hourly.ipynb +++ b/tests/debug/debug-energy-price-hourly.ipynb @@ -16,7 +16,9 @@ "from plotly.subplots import make_subplots\n", "from plotly_resampler import unregister_plotly_resampler\n", "\n", - "from neuralprophet import NeuralProphet, set_random_seed" + "from neuralprophet import NeuralProphet, set_random_seed, set_log_level\n", + "\n", + "set_log_level(\"INFO\")" ] }, { @@ -123,7 +125,7 @@ "df[\"y\"] = pd.to_numeric(df[\"y\"], errors=\"coerce\")\n", "\n", "df = df.drop(\"ds\", axis=1)\n", - "df[\"ds\"] = pd.date_range(start=\"2015-01-01 00:00:00\", periods=len(df), freq=\"H\")\n", + "df[\"ds\"] = pd.date_range(start=\"2015-01-01 00:00:00\", periods=len(df), freq=\"h\")\n", "df[\"ID\"] = \"test\"\n", "\n", "df_id = df[[\"ds\", \"y\", \"temp\"]].copy()\n", @@ -146,7 +148,11 @@ "df[\"temp\"] = (df[\"temp\"] - 65.0) / 50.0\n", "\n", "# df\n", - "df = df[[\"ID\", \"ds\", \"y\", \"temp\", \"winter\", \"summer\"]]" + "df = df[[\"ID\", \"ds\", \"y\", \"temp\", \"winter\", \"summer\"]]\n", + "\n", + "# Split\n", + "df_train = df[df[\"ds\"] < \"2015-03-01\"]\n", + "df_test = df[df[\"ds\"] >= \"2015-03-01\"]" ] }, { @@ -158,13 +164,14 @@ "name": "stdout", "output_type": "stream", "text": [ + "quantiles: [0.01, 0.99]\n", "Using CPU\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 5, @@ -173,9 +180,6 @@ } ], "source": [ - "### Temporary Test for on-the-fly sampling - very time consuming!\n", - "\n", - "\n", "# Hyperparameter\n", "tuned_params = {\n", " \"n_lags\": 10,\n", @@ -184,12 +188,12 @@ " \"yearly_seasonality\": 10,\n", " \"weekly_seasonality\": True,\n", " \"daily_seasonality\": False, # due to conditional daily seasonality\n", - " \"batch_size\": 128,\n", + " \"batch_size\": 32,\n", " \"ar_layers\": [8, 4],\n", " \"lagged_reg_layers\": [8],\n", " # not tuned\n", " \"n_forecasts\": 5,\n", - " \"learning_rate\": 0.001,\n", + " # \"learning_rate\": 0.1,\n", " \"epochs\": 10,\n", " \"trend_global_local\": \"global\",\n", " \"season_global_local\": \"global\",\n", @@ -200,9 +204,12 @@ "# Uncertainty Quantification\n", "confidence_lv = 0.98\n", "quantile_list = [round(((1 - confidence_lv) / 2), 2), round((confidence_lv + (1 - confidence_lv) / 2), 2)]\n", + "# quantile_list = None\n", + "print(f\"quantiles: {quantile_list}\")\n", "\n", "# Check if GPU is available\n", - "use_gpu = torch.cuda.is_available()\n", + "# use_gpu = torch.cuda.is_available()\n", + "use_gpu = False\n", "\n", "# Set trainer configuration\n", "trainer_configs = {\n", @@ -234,50 +241,135 @@ "output_type": "stream", "text": [ "INFO - (NP.forecaster.fit) - When Global modeling with local normalization, metrics are displayed in normalized scale.\n", - "INFO - (NP.df_utils._infer_frequency) - Major frequency H corresponds to 99.929% of the data.\n", - "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - H\n", - "INFO - (NP.df_utils._infer_frequency) - Major frequency H corresponds to 99.929% of the data.\n", - "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - H\n", + "WARNING - (NP.forecaster.fit) - Metrics are enabled. Please provide valid metrics logging directory. Setting to CWD\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning: Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + " converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.929% of the data.\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning: Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + " converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning: Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + " converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - h\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning: Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + " converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.929% of the data.\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning: Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + " converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning: Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + " converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - h\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/time_dataset.py:692: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " contains_nan = torch.cat([torch.tensor(contains_nan), torch.ones(n_forecasts, dtype=torch.bool)])\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/time_dataset.py:692: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " contains_nan = torch.cat([torch.tensor(contains_nan), torch.ones(n_forecasts, dtype=torch.bool)])\n", + "\n", + "INFO - (NP.forecaster._train) - Dataset size: 2758\n", + "INFO - (NP.forecaster._train) - Number of batches per training epoch: 87\n", "INFO - (NP.utils.configure_trainer) - Using accelerator cpu with 1 device(s).\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "aa26aaf9191f401b9c69ebafca381bab", + "model_id": "ded17dc7d6e940bfb29321cd972603b6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: | | 0/? [00:00= \"2015-03-01\"]\n", - "\n", "# Training & Predict\n", - "metrics = m.fit(df=df_train, validation_df=df_test, freq=\"H\", num_workers=4, early_stopping=False)" + "metrics = m.fit(\n", + " df=df_train,\n", + " validation_df=df_test,\n", + " freq=\"h\",\n", + " early_stopping=False,\n", + " # scheduler=\"onecyclelr\",\n", + " # scheduler_args={\n", + " # \"pct_start\": 0.3,\n", + " # \"div_factor\": 100.0,\n", + " # \"final_div_factor\": 1000.0,\n", + " # \"anneal_strategy\": \"cos\",\n", + " # \"three_phase\": False,\n", + " # },\n", + " # scheduler=\"exponentiallr\",\n", + " # scheduler_args={\"gamma\": 0.8,},\n", + ")" ] }, { @@ -428,16 +545,16 @@ "type": "scatter", "xaxis": "x", "y": [ - 1.6991313695907593, - 1.5541504621505737, - 1.2866111993789673, - 1.0485198497772217, - 0.9603586792945862, - 0.933108389377594, - 0.9244528412818909, - 0.9177840948104858, - 0.9132021069526672, - 0.9105463027954102 + 1.0717755556106567, + 0.5511327981948853, + 0.4953157603740692, + 0.4818778932094574, + 0.4735223650932312, + 0.47175300121307373, + 0.4704058766365051, + 0.4707088768482208, + 0.46935251355171204, + 0.468860924243927 ], "yaxis": "y" }, @@ -452,16 +569,16 @@ "type": "scatter", "xaxis": "x", "y": [ - 1.9174306392669678, - 2.133635997772217, - 2.1361277103424072, - 1.954904317855835, - 1.8205108642578125, - 1.7834810018539429, - 1.7635681629180908, - 1.7493915557861328, - 1.7418491840362549, - 1.7389646768569946 + 0.5182198286056519, + 0.5441035628318787, + 0.47900277376174927, + 0.5163846015930176, + 0.4929402470588684, + 0.5097485780715942, + 0.5129396915435791, + 0.5079220533370972, + 0.5094015002250671, + 0.5108088850975037 ], "yaxis": "y" }, @@ -476,16 +593,16 @@ "type": "scatter", "xaxis": "x2", "y": [ - 2.249849557876587, - 2.062807083129883, - 1.6801131963729858, - 1.344346523284912, - 1.2270969152450562, - 1.1934525966644287, - 1.1826142072677612, - 1.1741188764572144, - 1.169130563735962, - 1.1649360656738281 + 1.3691433668136597, + 0.7346274256706238, + 0.6665279269218445, + 0.6506426334381104, + 0.6393488049507141, + 0.6376363039016724, + 0.6363534331321716, + 0.6357880234718323, + 0.6352851986885071, + 0.6348727941513062 ], "yaxis": "y2" }, @@ -500,16 +617,16 @@ "type": "scatter", "xaxis": "x2", "y": [ - 2.1282451152801514, - 2.287360668182373, - 2.3184731006622314, - 2.140346050262451, - 2.0008866786956787, - 1.962218999862671, - 1.9410110712051392, - 1.9257516860961914, - 1.9175572395324707, - 1.914405107498169 + 0.618143618106842, + 0.6197627186775208, + 0.555190920829773, + 0.5922481417655945, + 0.5693594813346863, + 0.5874571204185486, + 0.5925166010856628, + 0.5866028666496277, + 0.588126540184021, + 0.5891289114952087 ], "yaxis": "y2" }, @@ -524,16 +641,16 @@ "type": "scatter", "xaxis": "x3", "y": [ - 3.4565775394439697, - 3.047083854675293, - 2.3058581352233887, - 1.710412621498108, - 1.4448997974395752, - 1.353717565536499, - 1.3267676830291748, - 1.3102833032608032, - 1.2921112775802612, - 1.2888280153274536 + 1.1767247915267944, + 0.36462634801864624, + 0.3093106746673584, + 0.29513803124427795, + 0.2881109416484833, + 0.2860604524612427, + 0.2844979465007782, + 0.2839355170726776, + 0.28298255801200867, + 0.2833350598812103 ], "yaxis": "y3" }, @@ -548,16 +665,16 @@ "type": "scatter", "xaxis": "x3", "y": [ - 4.821254730224609, - 4.705277919769287, - 4.240411758422852, - 3.7221953868865967, - 3.4264442920684814, - 3.345188617706299, - 3.2992584705352783, - 3.2648608684539795, - 3.246990919113159, - 3.2401645183563232 + 0.4635086953639984, + 0.4854629635810852, + 0.39425128698349, + 0.4356289505958557, + 0.40521177649497986, + 0.4225311279296875, + 0.4251018166542053, + 0.4184044301509857, + 0.4201345443725586, + 0.420722633600235 ], "yaxis": "y3" }, @@ -584,6 +701,30 @@ 0 ], "yaxis": "y3" + }, + { + "legendgroup": "LR", + "line": { + "color": "#2d92ff", + "width": 2 + }, + "mode": "lines", + "name": "LR", + "type": "scatter", + "xaxis": "x4", + "y": [ + 0.0028570329304784536, + 0.009130113758146763, + 0.00918674748390913, + 0.0029136640951037407, + 0.001063643372617662, + 0.0008884650305844843, + 0.0006039389409124851, + 0.00031874445267021656, + 0.00014181611186359078, + 0.00014073456986807287 + ], + "yaxis": "y4" } ], "layout": { @@ -594,7 +735,7 @@ }, "showarrow": false, "text": "MAE", - "x": 0.14444444444444446, + "x": 0.10625, "xanchor": "center", "xref": "paper", "y": 1, @@ -607,7 +748,7 @@ }, "showarrow": false, "text": "RMSE", - "x": 0.5, + "x": 0.36875, "xanchor": "center", "xref": "paper", "y": 1, @@ -620,7 +761,20 @@ }, "showarrow": false, "text": "Loss", - "x": 0.8555555555555556, + "x": 0.6312500000000001, + "xanchor": "center", + "xref": "paper", + "y": 1, + "yanchor": "bottom", + "yref": "paper" + }, + { + "font": { + "size": 16 + }, + "showarrow": false, + "text": "LR", + "x": 0.89375, "xanchor": "center", "xref": "paper", "y": 1, @@ -1466,7 +1620,7 @@ "anchor": "y", "domain": [ 0, - 0.2888888888888889 + 0.2125 ], "linewidth": 1.5, "mirror": true, @@ -1476,8 +1630,8 @@ "xaxis2": { "anchor": "y2", "domain": [ - 0.35555555555555557, - 0.6444444444444445 + 0.2625, + 0.475 ], "linewidth": 1.5, "mirror": true, @@ -1487,7 +1641,18 @@ "xaxis3": { "anchor": "y3", "domain": [ - 0.7111111111111111, + 0.525, + 0.7375 + ], + "linewidth": 1.5, + "mirror": true, + "showgrid": false, + "showline": true + }, + "xaxis4": { + "anchor": "y4", + "domain": [ + 0.7875, 1 ], "linewidth": 1.5, @@ -1533,6 +1698,19 @@ "showgrid": false, "showline": true, "type": "log" + }, + "yaxis4": { + "anchor": "x4", + "domain": [ + 0, + 1 + ], + "linewidth": 1.5, + "mirror": true, + "rangemode": "tozero", + "showgrid": false, + "showline": true, + "type": "log" } } } @@ -1553,15 +1731,16 @@ { "data": { "text/plain": [ - "{'MAE_val': 1.7389646768569946,\n", - " 'RMSE_val': 1.914405107498169,\n", - " 'Loss_val': 3.2401645183563232,\n", + "{'MAE_val': 0.5108088850975037,\n", + " 'RMSE_val': 0.5891289114952087,\n", + " 'Loss_val': 0.420722633600235,\n", " 'RegLoss_val': 0.0,\n", " 'epoch': 9,\n", - " 'MAE': 0.9105463027954102,\n", - " 'RMSE': 1.1649360656738281,\n", - " 'Loss': 1.2888280153274536,\n", - " 'RegLoss': 0.0}" + " 'MAE': 0.468860924243927,\n", + " 'RMSE': 0.6348727941513062,\n", + " 'Loss': 0.2833350598812103,\n", + " 'RegLoss': 0.0,\n", + " 'LR': 0.00014073456986807287}" ] }, "execution_count": 8, @@ -1608,20 +1787,139 @@ " RMSE\n", " Loss\n", " RegLoss\n", + " LR\n", " \n", " \n", " \n", " \n", + " 0\n", + " 0.518220\n", + " 0.618144\n", + " 0.463509\n", + " 0.0\n", + " 0\n", + " 1.071776\n", + " 1.369143\n", + " 1.176725\n", + " 0.0\n", + " 0.002857\n", + " \n", + " \n", + " 1\n", + " 0.544104\n", + " 0.619763\n", + " 0.485463\n", + " 0.0\n", + " 1\n", + " 0.551133\n", + " 0.734627\n", + " 0.364626\n", + " 0.0\n", + " 0.009130\n", + " \n", + " \n", + " 2\n", + " 0.479003\n", + " 0.555191\n", + " 0.394251\n", + " 0.0\n", + " 2\n", + " 0.495316\n", + " 0.666528\n", + " 0.309311\n", + " 0.0\n", + " 0.009187\n", + " \n", + " \n", + " 3\n", + " 0.516385\n", + " 0.592248\n", + " 0.435629\n", + " 0.0\n", + " 3\n", + " 0.481878\n", + " 0.650643\n", + " 0.295138\n", + " 0.0\n", + " 0.002914\n", + " \n", + " \n", + " 4\n", + " 0.492940\n", + " 0.569359\n", + " 0.405212\n", + " 0.0\n", + " 4\n", + " 0.473522\n", + " 0.639349\n", + " 0.288111\n", + " 0.0\n", + " 0.001064\n", + " \n", + " \n", + " 5\n", + " 0.509749\n", + " 0.587457\n", + " 0.422531\n", + " 0.0\n", + " 5\n", + " 0.471753\n", + " 0.637636\n", + " 0.286060\n", + " 0.0\n", + " 0.000888\n", + " \n", + " \n", + " 6\n", + " 0.512940\n", + " 0.592517\n", + " 0.425102\n", + " 0.0\n", + " 6\n", + " 0.470406\n", + " 0.636353\n", + " 0.284498\n", + " 0.0\n", + " 0.000604\n", + " \n", + " \n", + " 7\n", + " 0.507922\n", + " 0.586603\n", + " 0.418404\n", + " 0.0\n", + " 7\n", + " 0.470709\n", + " 0.635788\n", + " 0.283936\n", + " 0.0\n", + " 0.000319\n", + " \n", + " \n", + " 8\n", + " 0.509402\n", + " 0.588127\n", + " 0.420135\n", + " 0.0\n", + " 8\n", + " 0.469353\n", + " 0.635285\n", + " 0.282983\n", + " 0.0\n", + " 0.000142\n", + " \n", + " \n", " 9\n", - " 1.738965\n", - " 1.914405\n", - " 3.240165\n", + " 0.510809\n", + " 0.589129\n", + " 0.420723\n", " 0.0\n", " 9\n", - " 0.910546\n", - " 1.164936\n", - " 1.288828\n", + " 0.468861\n", + " 0.634873\n", + " 0.283335\n", " 0.0\n", + " 0.000141\n", " \n", " \n", "\n", @@ -1629,10 +1927,28 @@ ], "text/plain": [ " MAE_val RMSE_val Loss_val RegLoss_val epoch MAE RMSE \\\n", - "9 1.738965 1.914405 3.240165 0.0 9 0.910546 1.164936 \n", + "0 0.518220 0.618144 0.463509 0.0 0 1.071776 1.369143 \n", + "1 0.544104 0.619763 0.485463 0.0 1 0.551133 0.734627 \n", + "2 0.479003 0.555191 0.394251 0.0 2 0.495316 0.666528 \n", + "3 0.516385 0.592248 0.435629 0.0 3 0.481878 0.650643 \n", + "4 0.492940 0.569359 0.405212 0.0 4 0.473522 0.639349 \n", + "5 0.509749 0.587457 0.422531 0.0 5 0.471753 0.637636 \n", + "6 0.512940 0.592517 0.425102 0.0 6 0.470406 0.636353 \n", + "7 0.507922 0.586603 0.418404 0.0 7 0.470709 0.635788 \n", + "8 0.509402 0.588127 0.420135 0.0 8 0.469353 0.635285 \n", + "9 0.510809 0.589129 0.420723 0.0 9 0.468861 0.634873 \n", "\n", - " Loss RegLoss \n", - "9 1.288828 0.0 " + " Loss RegLoss LR \n", + "0 1.176725 0.0 0.002857 \n", + "1 0.364626 0.0 0.009130 \n", + "2 0.309311 0.0 0.009187 \n", + "3 0.295138 0.0 0.002914 \n", + "4 0.288111 0.0 0.001064 \n", + "5 0.286060 0.0 0.000888 \n", + "6 0.284498 0.0 0.000604 \n", + "7 0.283936 0.0 0.000319 \n", + "8 0.282983 0.0 0.000142 \n", + "9 0.283335 0.0 0.000141 " ] }, "execution_count": 9, @@ -1641,7 +1957,7 @@ } ], "source": [ - "metrics.tail(1)" + "metrics" ] }, { @@ -1653,47 +1969,121 @@ "name": "stderr", "output_type": "stream", "text": [ - "INFO - (NP.df_utils._infer_frequency) - Major frequency H corresponds to 99.932% of the data.\n", - "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - H\n", - "INFO - (NP.df_utils._infer_frequency) - Major frequency H corresponds to 99.932% of the data.\n", - "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - H\n" + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "INFO - (NP.df_utils._infer_frequency) - Major frequency H corresponds to 99.932% of the data.\n", - "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - H\n", - "INFO - (NP.data.processing._handle_missing_data) - Dropped 5 rows at the end with NaNs in 'y' column.\n", - "INFO - (NP.df_utils._infer_frequency) - Major frequency H corresponds to 99.932% of the data.\n", - "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - H\n", - "INFO - (NP.data.processing._handle_missing_data) - Dropped 5 rows at the end with NaNs in 'y' column.\n" + "INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.932% of the data.\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - h\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.932% of the data.\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - h\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.932% of the data.\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - h\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Major frequency h corresponds to 99.932% of the data.\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning:\n", + "\n", + "Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", + "\n", + "\n", + "INFO - (NP.df_utils._infer_frequency) - Defined frequency is equal to major frequency - h\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/time_dataset.py:692: UserWarning:\n", + "\n", + "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + "\n", + "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "78600faef98442c3bcae260cf6a78232", + "model_id": "f1e4231ad84f4ce2a3b3152a04780df8", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Predicting: 22it [00:00, ?it/s]" + "Predicting: | | 0/? [00:00[R] yhat5 1.0% ~1h',\n", " 'type': 'scatter',\n", - " 'uid': 'aebc484d-c130-47bd-8870-268071f0b3d5',\n", + " 'uid': '1e485c1d-dae9-439f-97e8-f8960bf19265',\n", " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", " datetime.datetime(2015, 1, 2, 14, 0),\n", " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", " datetime.datetime(2015, 3, 2, 17, 0),\n", " datetime.datetime(2015, 3, 2, 18, 0),\n", " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'y': array([62.35801 , 58.90128 , 49.21923 , ..., 50.683945, 56.553596, 58.41175 ],\n", + " 'y': array([-8.401451, -8.331238, -7.641697, ..., 35.0834 , 31.378742, 26.125694],\n", " dtype=float32)},\n", " {'fill': 'tonexty',\n", " 'fillcolor': 'rgba(45, 146, 255, 0.2)',\n", @@ -1745,46 +2190,46 @@ " 'mode': 'lines',\n", " 'name': '[R] yhat5 99.0% ~1h',\n", " 'type': 'scatter',\n", - " 'uid': 'c62aca2a-cbb9-4e43-915f-156387e57092',\n", + " 'uid': 'ddeebab6-2948-4ed2-839d-80f25bcbbb46',\n", " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", " datetime.datetime(2015, 1, 2, 14, 0),\n", " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 19, 0),\n", + " datetime.datetime(2015, 3, 2, 18, 0),\n", " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'y': array([80.960884, 76.19124 , 64.98064 , ..., 55.83882 , 67.100685, 64.74074 ],\n", + " 'y': array([70.66735 , 76.33557 , 73.91716 , ..., 72.79848 , 76.60592 , 75.965355],\n", " dtype=float32)},\n", " {'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': '[R] Predicted ~1h',\n", " 'type': 'scatter',\n", - " 'uid': 'aeae0371-af61-428b-bac3-3d7c9675a881',\n", + " 'uid': '3979d61f-ad4c-49eb-92f2-0519eec19b62',\n", " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", " datetime.datetime(2015, 1, 2, 14, 0),\n", " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 18, 0),\n", + " datetime.datetime(2015, 3, 2, 19, 0),\n", " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'y': array([62.35801 , 58.90128 , 49.21923 , ..., 50.683945, 56.553596, 58.41175 ],\n", + " 'y': array([39.514854, 43.52948 , 40.765232, ..., 61.876595, 65.44407 , 61.56613 ],\n", " dtype=float32)},\n", " {'marker': {'color': 'blue', 'size': 4, 'symbol': 'x'},\n", " 'mode': 'markers',\n", " 'name': '[R] Predicted ~1h',\n", " 'type': 'scatter',\n", - " 'uid': 'fdc61ccb-a79c-4487-bdf9-b9be5d0159d8',\n", + " 'uid': 'b8c0ed1f-761f-44db-a774-93abc8ab8338',\n", " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", " datetime.datetime(2015, 1, 2, 14, 0),\n", " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 18, 0),\n", + " datetime.datetime(2015, 3, 2, 19, 0),\n", " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'y': array([62.35801 , 58.90128 , 49.21923 , ..., 50.683945, 56.553596, 58.41175 ],\n", + " 'y': array([39.514854, 43.52948 , 40.765232, ..., 61.876595, 65.44407 , 61.56613 ],\n", " dtype=float32)},\n", " {'marker': {'color': 'black', 'size': 4},\n", " 'mode': 'markers',\n", " 'name': '[R] Actual ~1h',\n", " 'type': 'scatter',\n", - " 'uid': 'f3b8bafe-c1a6-4a00-b6d8-94ad845ee178',\n", + " 'uid': '46d51f13-dea8-4bc0-b1fa-e458374a0cbc',\n", " 'x': array([datetime.datetime(2015, 1, 1, 0, 0),\n", " datetime.datetime(2015, 1, 1, 1, 0),\n", " datetime.datetime(2015, 1, 1, 2, 0), ...,\n", @@ -1822,7 +2267,7 @@ "})" ] }, - "execution_count": 13, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -1834,339 +2279,124 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "INFO - (NP.forecaster.plot_components) - Plotting data from ID test\n" + "INFO - (NP.forecaster.plot_components) - Plotting data from ID test\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/plot_forecast_plotly.py:410: FutureWarning:\n", + "\n", + "The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/.cache/pypoetry/virtualenvs/neuralprophet-CT7lk1Bv-py3.10/lib/python3.10/site-packages/plotly_resampler/figure_resampler/utils.py:177: FutureWarning:\n", + "\n", + "'H' is deprecated and will be removed in a future version. Please use 'h' instead of 'H'.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/.cache/pypoetry/virtualenvs/neuralprophet-CT7lk1Bv-py3.10/lib/python3.10/site-packages/plotly_resampler/figure_resampler/utils.py:178: FutureWarning:\n", + "\n", + "'H' is deprecated and will be removed in a future version, please use 'h' instead.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/plot_forecast_plotly.py:410: FutureWarning:\n", + "\n", + "The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/.cache/pypoetry/virtualenvs/neuralprophet-CT7lk1Bv-py3.10/lib/python3.10/site-packages/plotly_resampler/figure_resampler/utils.py:177: FutureWarning:\n", + "\n", + "'H' is deprecated and will be removed in a future version. Please use 'h' instead of 'H'.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/.cache/pypoetry/virtualenvs/neuralprophet-CT7lk1Bv-py3.10/lib/python3.10/site-packages/plotly_resampler/figure_resampler/utils.py:178: FutureWarning:\n", + "\n", + "'H' is deprecated and will be removed in a future version, please use 'h' instead.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/plot_forecast_plotly.py:410: FutureWarning:\n", + "\n", + "The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/.cache/pypoetry/virtualenvs/neuralprophet-CT7lk1Bv-py3.10/lib/python3.10/site-packages/plotly_resampler/figure_resampler/utils.py:177: FutureWarning:\n", + "\n", + "'H' is deprecated and will be removed in a future version. Please use 'h' instead of 'H'.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/.cache/pypoetry/virtualenvs/neuralprophet-CT7lk1Bv-py3.10/lib/python3.10/site-packages/plotly_resampler/figure_resampler/utils.py:178: FutureWarning:\n", + "\n", + "'H' is deprecated and will be removed in a future version, please use 'h' instead.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/plot_forecast_plotly.py:410: FutureWarning:\n", + "\n", + "The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/.cache/pypoetry/virtualenvs/neuralprophet-CT7lk1Bv-py3.10/lib/python3.10/site-packages/plotly_resampler/figure_resampler/utils.py:177: FutureWarning:\n", + "\n", + "'H' is deprecated and will be removed in a future version. Please use 'h' instead of 'H'.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/.cache/pypoetry/virtualenvs/neuralprophet-CT7lk1Bv-py3.10/lib/python3.10/site-packages/plotly_resampler/figure_resampler/utils.py:178: FutureWarning:\n", + "\n", + "'H' is deprecated and will be removed in a future version, please use 'h' instead.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/plot_forecast_plotly.py:410: FutureWarning:\n", + "\n", + "The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/.cache/pypoetry/virtualenvs/neuralprophet-CT7lk1Bv-py3.10/lib/python3.10/site-packages/plotly_resampler/figure_resampler/utils.py:177: FutureWarning:\n", + "\n", + "'H' is deprecated and will be removed in a future version. Please use 'h' instead of 'H'.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/.cache/pypoetry/virtualenvs/neuralprophet-CT7lk1Bv-py3.10/lib/python3.10/site-packages/plotly_resampler/figure_resampler/utils.py:178: FutureWarning:\n", + "\n", + "'H' is deprecated and will be removed in a future version, please use 'h' instead.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/plot_forecast_plotly.py:410: FutureWarning:\n", + "\n", + "The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/.cache/pypoetry/virtualenvs/neuralprophet-CT7lk1Bv-py3.10/lib/python3.10/site-packages/plotly_resampler/figure_resampler/utils.py:177: FutureWarning:\n", + "\n", + "'H' is deprecated and will be removed in a future version. Please use 'h' instead of 'H'.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/.cache/pypoetry/virtualenvs/neuralprophet-CT7lk1Bv-py3.10/lib/python3.10/site-packages/plotly_resampler/figure_resampler/utils.py:178: FutureWarning:\n", + "\n", + "'H' is deprecated and will be removed in a future version, please use 'h' instead.\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/plot_forecast_plotly.py:559: FutureWarning:\n", + "\n", + "The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result\n", + "\n", + "\n" ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0851c9188ffb4c94bc7948103985aee2", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "FigureWidgetResampler({\n", - " 'data': [{'line': {'color': '#2d92ff', 'width': 2},\n", - " 'mode': 'lines',\n", - " 'name': '[R] Trend ~1h',\n", - " 'showlegend': False,\n", - " 'type': 'scatter',\n", - " 'uid': 'a971f8c1-1e2e-428f-bbae-b1e366f40f84',\n", - " 'x': array([datetime.datetime(2015, 1, 2, 9, 0),\n", - " datetime.datetime(2015, 1, 2, 10, 0),\n", - " datetime.datetime(2015, 1, 2, 11, 0), ...,\n", - " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 19, 0),\n", - " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'xaxis': 'x',\n", - " 'y': array([41.138184, 41.136326, 41.134468, ..., 38.49218 , 38.488464, 38.486603],\n", - " dtype=float32),\n", - " 'yaxis': 'y'},\n", - " {'line': {'color': '#2d92ff', 'width': 2},\n", - " 'mode': 'lines',\n", - " 'name': ('[R' ... ' style=\"color:#fc9944\">~1h'),\n", - " 'showlegend': False,\n", - " 'type': 'scatter',\n", - " 'uid': '896ec7a5-4db3-4572-9615-c92ff0a440c2',\n", - " 'x': array([datetime.datetime(2015, 1, 2, 9, 0),\n", - " datetime.datetime(2015, 1, 2, 10, 0),\n", - " datetime.datetime(2015, 1, 2, 11, 0), ...,\n", - " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 19, 0),\n", - " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'xaxis': 'x2',\n", - " 'y': array([7.7610316, 7.77278 , 7.782315 , ..., 8.327494 , 8.318201 , 8.315492 ],\n", - " dtype=float32),\n", - " 'yaxis': 'y2'},\n", - " {'line': {'color': '#2d92ff', 'width': 2},\n", - " 'mode': 'lines',\n", - " 'name': ('[R' ... ' style=\"color:#fc9944\">~1h'),\n", - " 'showlegend': False,\n", - " 'type': 'scatter',\n", - " 'uid': '4e7558c0-22c3-42b6-b38b-64828e95911f',\n", - " 'x': array([datetime.datetime(2015, 1, 2, 9, 0),\n", - " datetime.datetime(2015, 1, 2, 10, 0),\n", - " datetime.datetime(2015, 1, 2, 11, 0), ...,\n", - " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 19, 0),\n", - " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'xaxis': 'x3',\n", - " 'y': array([ 0.36878857, 0.30485797, 0.2463306 , ..., -0.56539005, 0.4600458 ,\n", - " 0.93207777], dtype=float32),\n", - " 'yaxis': 'y3'},\n", - " {'line': {'color': '#2d92ff', 'width': 2},\n", - " 'mode': 'lines',\n", - " 'name': ('[R' ... ' style=\"color:#fc9944\">~1h'),\n", - " 'showlegend': False,\n", - " 'type': 'scatter',\n", - " 'uid': '421f3dfd-0361-48bd-b035-27b85837e7d1',\n", - " 'x': array([datetime.datetime(2015, 1, 2, 9, 0),\n", - " datetime.datetime(2015, 1, 2, 10, 0),\n", - " datetime.datetime(2015, 1, 2, 11, 0), ...,\n", - " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 18, 0),\n", - " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'xaxis': 'x4',\n", - " 'y': array([ 6.8369484 , 8.779529 , -0.55572075, ..., 0. , 0. ,\n", - " 0. ], dtype=float32),\n", - " 'yaxis': 'y4'},\n", - " {'line': {'color': '#2d92ff', 'width': 2},\n", - " 'mode': 'lines',\n", - " 'name': ('[R' ... ' style=\"color:#fc9944\">~1h'),\n", - " 'showlegend': False,\n", - " 'type': 'scatter',\n", - " 'uid': '34dff3a4-3054-4a91-b79e-f965aa8d3284',\n", - " 'x': array([datetime.datetime(2015, 1, 2, 9, 0),\n", - " datetime.datetime(2015, 1, 2, 10, 0),\n", - " datetime.datetime(2015, 1, 2, 11, 0), ...,\n", - " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 18, 0),\n", - " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'xaxis': 'x5',\n", - " 'y': array([0. , 0. , 0. , ..., 2.5935924, 7.5037613, 4.810857 ],\n", - " dtype=float32),\n", - " 'yaxis': 'y5'},\n", - " {'line': {'color': '#2d92ff', 'width': 2},\n", - " 'mode': 'lines',\n", - " 'name': ('[R' ... ' style=\"color:#fc9944\">~1h'),\n", - " 'showlegend': False,\n", - " 'type': 'scatter',\n", - " 'uid': 'cd6c048f-2b7f-47c5-8f34-50aed056ee0b',\n", - " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", - " datetime.datetime(2015, 1, 2, 14, 0),\n", - " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", - " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 18, 0),\n", - " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'xaxis': 'x6',\n", - " 'y': array([14.265438 , 6.3923936 , -0.08357577, ..., 0. , 0.4089267 ,\n", - " 4.4793005 ], dtype=float32),\n", - " 'yaxis': 'y6'},\n", - " {'line': {'color': '#2d92ff', 'width': 2},\n", - " 'mode': 'lines',\n", - " 'name': ('[R' ... ' style=\"color:#fc9944\">~1h'),\n", - " 'showlegend': False,\n", - " 'type': 'scatter',\n", - " 'uid': '61f91c94-2944-4efb-9907-709ee2fb6c77',\n", - " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", - " datetime.datetime(2015, 1, 2, 14, 0),\n", - " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", - " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 18, 0),\n", - " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'xaxis': 'x7',\n", - " 'y': array([1.765334 , 3.7883697, 4.1204934, ..., 1.8360679, 1.8920995, 0. ],\n", - " dtype=float32),\n", - " 'yaxis': 'y7'},\n", - " {'line': {'color': '#2d92ff', 'width': 2},\n", - " 'mode': 'lines',\n", - " 'name': ('[R' ... ' style=\"color:#fc9944\">~1h'),\n", - " 'showlegend': False,\n", - " 'type': 'scatter',\n", - " 'uid': '885dd23f-b889-4004-9dcf-70c1ca00a53d',\n", - " 'x': array([datetime.datetime(2015, 1, 2, 9, 0),\n", - " datetime.datetime(2015, 1, 2, 10, 0),\n", - " datetime.datetime(2015, 1, 2, 11, 0), ...,\n", - " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 18, 0),\n", - " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'xaxis': 'x8',\n", - " 'y': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),\n", - " 'yaxis': 'y8'},\n", - " {'fill': 'tozeroy',\n", - " 'fillcolor': 'rgba(45, 146, 255, 0.2)',\n", - " 'line': {'color': 'rgba(45, 146, 255, 0.2)', 'width': 1},\n", - " 'mode': 'lines',\n", - " 'name': '[R] yhat5 1.0% ~1h',\n", - " 'showlegend': True,\n", - " 'type': 'scatter',\n", - " 'uid': 'b180a45d-ef7d-43be-9180-b0d60ffffff7',\n", - " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", - " datetime.datetime(2015, 1, 2, 14, 0),\n", - " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", - " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 19, 0),\n", - " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'xaxis': 'x9',\n", - " 'y': array([ 37.975266 , 31.622068 , -3.693409 , ..., -14.12719 , -2.9433403,\n", - " -0.9949646], dtype=float32),\n", - " 'yaxis': 'y9'},\n", - " {'fill': 'tozeroy',\n", - " 'fillcolor': 'rgba(45, 146, 255, 0.2)',\n", - " 'line': {'color': 'rgba(45, 146, 255, 0.2)', 'width': 1},\n", - " 'mode': 'lines',\n", - " 'name': '[R] yhat5 99.0% ~1h',\n", - " 'showlegend': True,\n", - " 'type': 'scatter',\n", - " 'uid': '94ed8664-debf-4a84-b7bf-74bc18ea9a66',\n", - " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", - " datetime.datetime(2015, 1, 2, 14, 0),\n", - " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", - " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 19, 0),\n", - " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'xaxis': 'x9',\n", - " 'y': array([56.57814 , 48.912025 , 12.067997 , ..., -8.972313 , 5.452839 ,\n", - " 5.3340225], dtype=float32),\n", - " 'yaxis': 'y9'}],\n", - " 'layout': {'autosize': True,\n", - " 'barmode': 'overlay',\n", - " 'font': {'size': 10},\n", - " 'height': 1890,\n", - " 'hovermode': 'x unified',\n", - " 'legend': {'traceorder': 'reversed', 'y': 0.1},\n", - " 'margin': {'b': 0, 'l': 0, 'pad': 0, 'r': 10, 't': 10},\n", - " 'template': '...',\n", - " 'title': {'font': {'size': 12}},\n", - " 'width': 700,\n", - " 'xaxis': {'anchor': 'y',\n", - " 'domain': [0.0, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'range': [2014-12-30 10:00:00, 2015-03-05 19:00:00],\n", - " 'showline': True,\n", - " 'title': {'text': 'ds'},\n", - " 'type': 'date'},\n", - " 'xaxis2': {'anchor': 'y2',\n", - " 'domain': [0.0, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'range': [2014-12-30 10:00:00, 2015-03-05 19:00:00],\n", - " 'showline': True,\n", - " 'title': {'text': 'ds'},\n", - " 'type': 'date'},\n", - " 'xaxis3': {'anchor': 'y3',\n", - " 'domain': [0.0, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'range': [2014-12-30 10:00:00, 2015-03-05 19:00:00],\n", - " 'showline': True,\n", - " 'title': {'text': 'ds'},\n", - " 'type': 'date'},\n", - " 'xaxis4': {'anchor': 'y4',\n", - " 'domain': [0.0, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'range': [2014-12-30 10:00:00, 2015-03-05 19:00:00],\n", - " 'showline': True,\n", - " 'title': {'text': 'ds'},\n", - " 'type': 'date'},\n", - " 'xaxis5': {'anchor': 'y5',\n", - " 'domain': [0.0, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'range': [2014-12-30 10:00:00, 2015-03-05 19:00:00],\n", - " 'showline': True,\n", - " 'title': {'text': 'ds'},\n", - " 'type': 'date'},\n", - " 'xaxis6': {'anchor': 'y6',\n", - " 'domain': [0.0, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'range': [2014-12-30 14:00:00, 2015-03-05 19:00:00],\n", - " 'showline': True,\n", - " 'title': {'text': 'ds'},\n", - " 'type': 'date'},\n", - " 'xaxis7': {'anchor': 'y7',\n", - " 'domain': [0.0, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'range': [2014-12-30 14:00:00, 2015-03-05 19:00:00],\n", - " 'showline': True,\n", - " 'title': {'text': 'ds'},\n", - " 'type': 'date'},\n", - " 'xaxis8': {'anchor': 'y8',\n", - " 'domain': [0.0, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'range': [2014-12-30 10:00:00, 2015-03-05 19:00:00],\n", - " 'showline': True,\n", - " 'title': {'text': 'ds'},\n", - " 'type': 'date'},\n", - " 'xaxis9': {'anchor': 'y9',\n", - " 'domain': [0.0, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'range': [2014-12-30 14:00:00, 2015-03-05 19:00:00],\n", - " 'showline': True,\n", - " 'title': {'text': 'ds'},\n", - " 'type': 'date'},\n", - " 'yaxis': {'anchor': 'x',\n", - " 'domain': [0.9185185185185185, 1.0],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'rangemode': 'normal',\n", - " 'showline': True,\n", - " 'title': {'text': 'Trend'}},\n", - " 'yaxis2': {'anchor': 'x2',\n", - " 'domain': [0.8037037037037038, 0.8851851851851853],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'rangemode': 'tozero',\n", - " 'showline': True,\n", - " 'title': {'text': 'yearly seasonality'}},\n", - " 'yaxis3': {'anchor': 'x3',\n", - " 'domain': [0.6888888888888889, 0.7703703703703704],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'rangemode': 'tozero',\n", - " 'showline': True,\n", - " 'title': {'text': 'weekly seasonality'}},\n", - " 'yaxis4': {'anchor': 'x4',\n", - " 'domain': [0.5740740740740741, 0.6555555555555556],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'rangemode': 'tozero',\n", - " 'showline': True,\n", - " 'title': {'text': 'winter seasonality'}},\n", - " 'yaxis5': {'anchor': 'x5',\n", - " 'domain': [0.45925925925925926, 0.5407407407407407],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'rangemode': 'tozero',\n", - " 'showline': True,\n", - " 'title': {'text': 'summer seasonality'}},\n", - " 'yaxis6': {'anchor': 'x6',\n", - " 'domain': [0.34444444444444444, 0.42592592592592593],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'rangemode': 'tozero',\n", - " 'showline': True,\n", - " 'title': {'text': 'AR (5)-ahead'}},\n", - " 'yaxis7': {'anchor': 'x7',\n", - " 'domain': [0.22962962962962963, 0.3111111111111111],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'rangemode': 'tozero',\n", - " 'showline': True,\n", - " 'title': {'text': 'Lagged Regressor \"temp\" (5)-ahead'}},\n", - " 'yaxis8': {'anchor': 'x8',\n", - " 'domain': [0.11481481481481481, 0.1962962962962963],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'rangemode': 'tozero',\n", - " 'showline': True,\n", - " 'title': {'text': 'Additive Events'}},\n", - " 'yaxis9': {'anchor': 'x9',\n", - " 'domain': [0.0, 0.08148148148148149],\n", - " 'linewidth': 1.5,\n", - " 'mirror': True,\n", - " 'rangemode': 'tozero',\n", - " 'showline': True,\n", - " 'title': {'text': 'Uncertainty'}}}\n", - "})" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" + "ename": "IndexError", + "evalue": "index -1 is out of bounds for axis 0 with size 0", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mplot_components\u001b[49m\u001b[43m(\u001b[49m\u001b[43mforecast\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdf_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtest\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/github/neural_prophet/neuralprophet/forecaster.py:2465\u001b[0m, in \u001b[0;36mNeuralProphet.plot_components\u001b[0;34m(self, fcst, df_name, figsize, forecast_in_focus, plotting_backend, components, one_period_per_season)\u001b[0m\n\u001b[1;32m 2463\u001b[0m log_warning_deprecation_plotly(plotting_backend)\n\u001b[1;32m 2464\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m plotting_backend\u001b[38;5;241m.\u001b[39mstartswith(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mplotly\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m-> 2465\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mplot_components_plotly\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2466\u001b[0m \u001b[43m \u001b[49m\u001b[43mm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2467\u001b[0m \u001b[43m \u001b[49m\u001b[43mfcst\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfcst\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2468\u001b[0m \u001b[43m \u001b[49m\u001b[43mplot_configuration\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalid_plot_configuration\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2469\u001b[0m \u001b[43m \u001b[49m\u001b[43mfigsize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mtuple\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m70\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mfigsize\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mfigsize\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m700\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m210\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2470\u001b[0m \u001b[43m \u001b[49m\u001b[43mdf_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdf_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2471\u001b[0m \u001b[43m \u001b[49m\u001b[43mone_period_per_season\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mone_period_per_season\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2472\u001b[0m \u001b[43m \u001b[49m\u001b[43mresampler_active\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mplotting_backend\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mplotly-resampler\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2473\u001b[0m \u001b[43m \u001b[49m\u001b[43mplotly_static\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mplotting_backend\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mplotly-static\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2474\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2475\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2476\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m plot_components(\n\u001b[1;32m 2477\u001b[0m m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 2478\u001b[0m fcst\u001b[38;5;241m=\u001b[39mfcst,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 2483\u001b[0m one_period_per_season\u001b[38;5;241m=\u001b[39mone_period_per_season,\n\u001b[1;32m 2484\u001b[0m )\n", + "File \u001b[0;32m~/github/neural_prophet/neuralprophet/plot_forecast_plotly.py:332\u001b[0m, in \u001b[0;36mplot_components\u001b[0;34m(m, fcst, plot_configuration, df_name, one_period_per_season, figsize, resampler_active, plotly_static)\u001b[0m\n\u001b[1;32m 327\u001b[0m trace_object \u001b[38;5;241m=\u001b[39m get_forecast_component_props(\n\u001b[1;32m 328\u001b[0m fcst\u001b[38;5;241m=\u001b[39mfcst, df_name\u001b[38;5;241m=\u001b[39mdf_name, comp_name\u001b[38;5;241m=\u001b[39mcomp_name, plot_name\u001b[38;5;241m=\u001b[39mcomp[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mplot_name\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 329\u001b[0m )\n\u001b[1;32m 331\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto-regression\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m name \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlagged regressor\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m name:\n\u001b[0;32m--> 332\u001b[0m trace_object \u001b[38;5;241m=\u001b[39m \u001b[43mget_multiforecast_component_props\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfcst\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfcst\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mcomp\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 333\u001b[0m fig\u001b[38;5;241m.\u001b[39mupdate_layout(barmode\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moverlay\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 335\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m j \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n", + "File \u001b[0;32m~/github/neural_prophet/neuralprophet/plot_forecast_plotly.py:603\u001b[0m, in \u001b[0;36mget_multiforecast_component_props\u001b[0;34m(fcst, comp_name, plot_name, multiplicative, bar, focus, num_overplot, **kwargs)\u001b[0m\n\u001b[1;32m 601\u001b[0m y \u001b[38;5;241m=\u001b[39m fcst[\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcomp_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 602\u001b[0m y \u001b[38;5;241m=\u001b[39m y\u001b[38;5;241m.\u001b[39mvalues\n\u001b[0;32m--> 603\u001b[0m \u001b[43my\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 604\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m bar:\n\u001b[1;32m 605\u001b[0m traces\u001b[38;5;241m.\u001b[39mappend(\n\u001b[1;32m 606\u001b[0m go\u001b[38;5;241m.\u001b[39mBar(\n\u001b[1;32m 607\u001b[0m name\u001b[38;5;241m=\u001b[39mplot_name,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 613\u001b[0m )\n\u001b[1;32m 614\u001b[0m )\n", + "\u001b[0;31mIndexError\u001b[0m: index -1 is out of bounds for axis 0 with size 0" + ] } ], "source": [ @@ -2175,13 +2405,39 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/plot_model_parameters_plotly.py:178: FutureWarning:\n", + "\n", + "The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/plot_model_parameters_plotly.py:475: FutureWarning:\n", + "\n", + "The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/plot_model_parameters_plotly.py:508: FutureWarning:\n", + "\n", + "The behavior of DatetimeProperties.to_pydatetime is deprecated, in a future version this will return a Series containing python datetime objects instead of an ndarray. To retain the old behavior, call `np.array` on the result\n", + "\n", + "\n", + "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/plot_model_parameters_plotly.py:564: FutureWarning:\n", + "\n", + "'H' is deprecated and will be removed in a future version, please use 'h' instead.\n", + "\n", + "\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "653b60479a0244c394b8e68cea26b341", + "model_id": "4549df8821ee4b649e748ccf20cbb9a9", "version_major": 2, "version_minor": 0 }, @@ -2192,18 +2448,18 @@ " 'mode': 'lines',\n", " 'name': 'Trend',\n", " 'type': 'scatter',\n", - " 'uid': 'f6f21f4d-8199-49f7-a49a-9951dd269bd9',\n", + " 'uid': 'a87ad6c0-8302-4d80-a053-3de55909d7d9',\n", " 'x': array([datetime.datetime(2015, 1, 1, 0, 0),\n", " datetime.datetime(2015, 2, 28, 23, 0)], dtype=object),\n", " 'xaxis': 'x',\n", - " 'y': array([41.1995 , 38.57022], dtype=float32),\n", + " 'y': array([44.171093, 46.755657], dtype=float32),\n", " 'yaxis': 'y'},\n", " {'fill': 'none',\n", " 'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': 'yearly',\n", " 'type': 'scatter',\n", - " 'uid': 'f0adc090-2190-4a3b-9c9b-97c8cede02e2',\n", + " 'uid': '687528c8-278a-4dba-a432-7580315842b3',\n", " 'x': array([datetime.datetime(2017, 1, 1, 0, 0),\n", " datetime.datetime(2017, 1, 2, 0, 0),\n", " datetime.datetime(2017, 1, 3, 0, 0), ...,\n", @@ -2211,15 +2467,15 @@ " datetime.datetime(2017, 12, 30, 0, 0),\n", " datetime.datetime(2017, 12, 31, 0, 0)], dtype=object),\n", " 'xaxis': 'x2',\n", - " 'y': array([4.0829487 , 5.187225 , 6.208157 , ..., 0.19168049, 1.4080983 ,\n", - " 2.6177309 ], dtype=float32),\n", + " 'y': array([3.5837142, 3.9962187, 4.3687215, ..., 2.132554 , 2.5801535, 3.0329647],\n", + " dtype=float32),\n", " 'yaxis': 'y2'},\n", " {'fill': 'none',\n", " 'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': 'weekly',\n", " 'type': 'scatter',\n", - " 'uid': '17c4c727-ac96-43b0-8e36-acc3caab9c2d',\n", + " 'uid': '9bc616f5-17ad-4fc4-b88f-bb47bdcc4e5f',\n", " 'x': array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,\n", " 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,\n", " 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,\n", @@ -2233,117 +2489,117 @@ " 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,\n", " 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167]),\n", " 'xaxis': 'x3',\n", - " 'y': array([-4.4598384 , -4.2069 , -3.939087 , -3.623846 , -3.2824173 ,\n", - " -2.955236 , -2.6178632 , -2.2747667 , -1.9053857 , -1.5644891 ,\n", - " -1.2317951 , -0.9106357 , -0.58533347, -0.28182796, -0.02269452,\n", - " 0.21292625, 0.42130318, 0.6071059 , 0.7497983 , 0.8622262 ,\n", - " 0.9443581 , 0.99504983, 1.0032893 , 0.9738326 , 0.9262827 ,\n", - " 0.85167503, 0.74685544, 0.6115421 , 0.4558556 , 0.2885901 ,\n", - " 0.11641604, -0.07034495, -0.27396792, -0.47582942, -0.66030794,\n", - " -0.8382371 , -1.0191802 , -1.1867032 , -1.3330028 , -1.4465324 ,\n", - " -1.5369219 , -1.6096568 , -1.6520382 , -1.6572822 , -1.6354212 ,\n", - " -1.5726112 , -1.4764075 , -1.3572443 , -1.2015358 , -1.011647 ,\n", - " -0.80462766, -0.5479135 , -0.2638637 , 0.02303137, 0.33831146,\n", - " 0.64510536, 0.9981383 , 1.3464724 , 1.6830701 , 2.0331054 ,\n", - " 2.3555207 , 2.69649 , 3.008644 , 3.2858677 , 3.5552442 ,\n", - " 3.7721593 , 3.9756846 , 4.1423197 , 4.264748 , 4.3529058 ,\n", - " 4.389515 , 4.3708434 , 4.3175454 , 4.2146673 , 4.072025 ,\n", - " 3.8747027 , 3.6363995 , 3.340743 , 3.0221694 , 2.664729 ,\n", - " 2.2731943 , 1.8473105 , 1.4062825 , 0.9366955 , 0.42765933,\n", - " -0.07621501, -0.56814903, -1.059046 , -1.5562105 , -2.0614126 ,\n", - " -2.5020652 , -2.9485521 , -3.352801 , -3.7507663 , -4.1075444 ,\n", - " -4.387843 , -4.6441846 , -4.8436375 , -4.9982753 , -5.098551 ,\n", - " -5.1361027 , -5.115603 , -5.042491 , -4.910203 , -4.7105756 ,\n", - " -4.472313 , -4.1741037 , -3.8366807 , -3.4286702 , -2.970053 ,\n", - " -2.5127845 , -2.0011377 , -1.4785788 , -0.89798415, -0.33239934,\n", - " 0.24315366, 0.8609969 , 1.4341363 , 2.0365245 , 2.5912929 ,\n", - " 3.1261334 , 3.631675 , 4.12216 , 4.5922604 , 4.9916644 ,\n", - " 5.357276 , 5.6596603 , 5.918557 , 6.127722 , 6.2664776 ,\n", - " 6.3465624 , 6.3701463 , 6.3298607 , 6.2247863 , 6.0684857 ,\n", - " 5.8579984 , 5.5947313 , 5.2599936 , 4.8867702 , 4.4837136 ,\n", - " 4.040872 , 3.5451972 , 3.0077634 , 2.463368 , 1.9197478 ,\n", - " 1.3607975 , 0.7956455 , 0.19293702, -0.40335596, -0.9517328 ,\n", - " -1.4785621 , -1.9893605 , -2.4982615 , -2.9505756 , -3.3692324 ,\n", - " -3.7429237 , -4.1006875 , -4.402756 , -4.6468496 , -4.8368144 ,\n", - " -4.987736 , -5.0905485 , -5.1344304 , -5.1309776 , -5.0739946 ,\n", - " -4.993806 , -4.8492026 , -4.6650367 ], dtype=float32),\n", + " 'y': array([ 3.6775422 , 3.4448764 , 3.1795988 , 2.856169 , 2.5034661 ,\n", + " 2.132053 , 1.7310926 , 1.3034438 , 0.8253559 , 0.36245304,\n", + " -0.12688631, -0.6307145 , -1.1616575 , -1.7126511 , -2.2267792 ,\n", + " -2.7566912 , -3.2863293 , -3.8237705 , -4.32482 , -4.836178 ,\n", + " -5.361228 , -5.8546944 , -6.334524 , -6.7697473 , -7.2097106 ,\n", + " -7.6307163 , -8.015902 , -8.410988 , -8.759339 , -9.097741 ,\n", + " -9.40439 , -9.673478 , -9.938627 , -10.160235 , -10.360814 ,\n", + " -10.533252 , -10.672475 , -10.790765 , -10.871719 , -10.92612 ,\n", + " -10.949665 , -10.941065 , -10.902167 , -10.831596 , -10.7297535 ,\n", + " -10.584343 , -10.400615 , -10.213366 , -9.973349 , -9.703287 ,\n", + " -9.400256 , -9.039257 , -8.638971 , -8.218772 , -7.76986 ,\n", + " -7.29001 , -6.7592626 , -6.205373 , -5.610344 , -5.011173 ,\n", + " -4.306703 , -3.6050465 , -2.8899956 , -2.1421108 , -1.3980246 ,\n", + " -0.55479425, 0.25694594, 1.0867394 , 1.9551147 , 2.7880616 ,\n", + " 3.6983426 , 4.5652947 , 5.432634 , 6.2935643 , 7.1280026 ,\n", + " 7.9856586 , 8.809366 , 9.650012 , 10.417858 , 11.145185 ,\n", + " 11.858826 , 12.523886 , 13.137134 , 13.699438 , 14.241183 ,\n", + " 14.707007 , 15.098125 , 15.426196 , 15.693906 , 15.901334 ,\n", + " 16.017977 , 16.07477 , 16.05878 , 15.968486 , 15.792491 ,\n", + " 15.547381 , 15.236185 , 14.865028 , 14.416378 , 13.873076 ,\n", + " 13.277334 , 12.638401 , 11.962018 , 11.21623 , 10.383526 ,\n", + " 9.558632 , 8.688934 , 7.814783 , 6.8657885 , 5.8782206 ,\n", + " 4.93936 , 3.9818935 , 3.0527442 , 2.0779295 , 1.1511891 ,\n", + " 0.24442616, -0.68556976, -1.5159744 , -2.3555033 , -3.1236107 ,\n", + " -3.845236 , -4.5020337 , -5.122622 , -5.7106447 , -6.2054214 ,\n", + " -6.647813 , -7.003909 , -7.302908 , -7.544511 , -7.7055807 ,\n", + " -7.7959027 , -7.8254924 , -7.783019 , -7.673332 , -7.5097384 ,\n", + " -7.290905 , -7.017191 , -6.67524 , -6.2961516 , -5.888133 ,\n", + " -5.437577 , -4.9315596 , -4.396648 , -3.855047 , -3.313541 ,\n", + " -2.7510498 , -2.1820972 , -1.5792464 , -0.9815789 , -0.42582962,\n", + " 0.10404018, 0.6269935 , 1.1417981 , 1.6131238 , 2.0423858 ,\n", + " 2.4490814 , 2.8360126 , 3.1679304 , 3.4505346 , 3.676464 ,\n", + " 3.8737729 , 4.024116 , 4.1148095 , 4.157773 , 4.150377 ,\n", + " 4.109914 , 4.006662 , 3.8584704 ], dtype=float32),\n", " 'yaxis': 'y3'},\n", " {'fill': 'none',\n", " 'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': 'winter',\n", " 'type': 'scatter',\n", - " 'uid': '22616063-3eff-4306-a74a-dc4965de0de9',\n", + " 'uid': '4495dca3-32d1-4ba1-8c78-6e1a1b0af24c',\n", " 'x': array([ 0, 1, 2, ..., 285, 286, 287]),\n", " 'xaxis': 'x4',\n", - " 'y': array([-4.292418 , -3.5483618, -3.0230176, ..., -5.4796743, -5.2587185,\n", - " -4.8447485], dtype=float32),\n", + " 'y': array([ 0.96416605, -0.55186653, -1.687766 , ..., 3.8534515 , 3.2716925 ,\n", + " 2.1252832 ], dtype=float32),\n", " 'yaxis': 'y4'},\n", " {'fill': 'none',\n", " 'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': 'summer',\n", " 'type': 'scatter',\n", - " 'uid': '277a9945-45f4-45ca-91c5-7e8e3f338811',\n", + " 'uid': 'cceb9a1a-bc38-482b-a7d5-82b72a93a298',\n", " 'x': array([ 0, 1, 2, ..., 285, 286, 287]),\n", " 'xaxis': 'x5',\n", - " 'y': array([-1.6798731 , -2.3781397 , -2.901272 , ..., -0.19541107, -0.51879483,\n", - " -1.117872 ], dtype=float32),\n", + " 'y': array([-5.745831 , -5.2108707, -4.672284 , ..., -6.2985435, -6.2671404,\n", + " -6.0654645], dtype=float32),\n", " 'yaxis': 'y5'},\n", " {'marker': {'color': '#2d92ff'},\n", " 'name': 'AR',\n", " 'type': 'bar',\n", - " 'uid': 'fa756dff-2aaa-47cc-b16e-9266593c0172',\n", + " 'uid': 'fc2e7ac5-066d-48a0-8e16-b80235e12bfe',\n", " 'width': 0.8,\n", " 'x': array([10, 9, 8, 7, 6, 5, 4, 3, 2, 1]),\n", " 'xaxis': 'x6',\n", - " 'y': array([-0.03951903, 0.41645312, 0.02179232, -0.2604984 , -0.06300073,\n", - " -0.06662486, -0.08233377, -0.03597524, 0.08927898, -0.07381544],\n", - " dtype=float32),\n", + " 'y': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", " 'yaxis': 'y6'},\n", " {'marker': {'color': '#2d92ff'},\n", " 'name': 'Lagged Regressor \"temp\"',\n", " 'type': 'bar',\n", - " 'uid': '041656ea-aaa1-4e1b-abce-a1e942a626bb',\n", + " 'uid': '6a963e5b-a63f-45bf-ae7f-1fa296e6f623',\n", " 'width': 0.8,\n", " 'x': array([33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16,\n", " 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]),\n", " 'xaxis': 'x7',\n", - " 'y': array([ 0.20808354, 0.3050754 , 0.52504927, 0.02816455, -0.2267277 ,\n", - " -0.18377087, 0.34080964, 0.00188361, -0.14284115, 0.06430382,\n", - " 0.31131235, -0.09880974, 0.06406495, 0.25881714, 0.08779721,\n", - " -0.18321382, 0.2451885 , -0.23906691, -0.233605 , -0.05307174,\n", - " 0.17820123, 0.12141816, 0.0911953 , -0.10566162, 0.07743413,\n", - " 0.21802229, 0.35458654, 0.06151056, 0.23792064, -0.12219968,\n", - " -0.2825721 , -0.09865767, 0.25742164], dtype=float32),\n", + " 'y': array([-2.3441856e-01, -7.4248093e-01, 1.3433978e-01, 4.7361168e-01,\n", + " 4.8439783e-01, 2.8078523e-01, -1.9517194e-01, 2.2985543e-01,\n", + " 1.5531473e-01, -4.2631316e-01, 5.0868553e-01, 1.1522221e-01,\n", + " -4.8527386e-02, 2.0242128e-01, 4.4463417e-03, -2.3070528e-01,\n", + " 1.7045366e-02, -8.4169136e-05, 1.5831508e-01, -2.2444238e-01,\n", + " 1.4253077e-01, -2.9090768e-02, -1.4969027e-01, 3.8341036e-01,\n", + " -1.2710637e-01, -1.4844303e-01, 1.1406808e-01, -2.2177878e-01,\n", + " 2.8057915e-01, -3.3217099e-01, -1.6262497e-01, -3.2851827e-01,\n", + " -1.3853197e-01], dtype=float32),\n", " 'yaxis': 'y7'},\n", " {'marker': {'color': '#2d92ff'},\n", " 'name': 'Additive event',\n", " 'type': 'bar',\n", - " 'uid': '849f3240-aaa3-4374-b3fb-c1bd7cc14cfa',\n", + " 'uid': 'f045a98c-7b21-4761-b615-bc9edab8575b',\n", " 'width': 0.8,\n", - " 'x': array(['Veterans Day_+0', 'Veterans Day_+1', 'Veterans Day_-1',\n", - " \"Washington's Birthday_+0\", \"Washington's Birthday_+1\",\n", - " \"Washington's Birthday_-1\", 'Christmas Day_+0', 'Christmas Day_+1',\n", - " 'Christmas Day_-1', 'Thanksgiving_+0', 'Thanksgiving_+1',\n", - " 'Thanksgiving_-1', 'Martin Luther King Jr. Day_+0',\n", - " 'Martin Luther King Jr. Day_+1', 'Martin Luther King Jr. Day_-1',\n", - " 'Memorial Day_+0', 'Memorial Day_+1', 'Memorial Day_-1',\n", + " 'x': array([\"Washington's Birthday_+0\", \"Washington's Birthday_+1\",\n", + " \"Washington's Birthday_-1\", 'Thanksgiving_+0', 'Thanksgiving_+1',\n", + " 'Thanksgiving_-1', 'Labor Day_+0', 'Labor Day_+1', 'Labor Day_-1',\n", + " 'Veterans Day_+0', 'Veterans Day_+1', 'Veterans Day_-1',\n", " \"New Year's Day_+0\", \"New Year's Day_+1\", \"New Year's Day_-1\",\n", - " 'Labor Day_+0', 'Labor Day_+1', 'Labor Day_-1', 'Independence Day_+0',\n", - " 'Independence Day_+1', 'Independence Day_-1', 'Columbus Day_+0',\n", - " 'Columbus Day_+1', 'Columbus Day_-1'], dtype=object),\n", + " 'Independence Day_+0', 'Independence Day_+1', 'Independence Day_-1',\n", + " 'Martin Luther King Jr. Day_+0', 'Martin Luther King Jr. Day_+1',\n", + " 'Martin Luther King Jr. Day_-1', 'Memorial Day_+0', 'Memorial Day_+1',\n", + " 'Memorial Day_-1', 'Columbus Day_+0', 'Columbus Day_+1',\n", + " 'Columbus Day_-1', 'Christmas Day_+0', 'Christmas Day_+1',\n", + " 'Christmas Day_-1'], dtype=object),\n", " 'xaxis': 'x8',\n", - " 'y': [1.7690346240997314, -4.356875419616699, -2.5583579540252686,\n", - " 3.7520101070404053, 1.3547093868255615, -1.4862573146820068,\n", - " 4.024331092834473, -0.7799521684646606, -1.7819913625717163,\n", - " -2.080281972885132, 0.33075717091560364, 4.571771144866943,\n", - " 2.3425700664520264, 1.175431251525879, 2.4367449283599854,\n", - " -2.1346323490142822, 3.684549331665039, 0.6624831557273865,\n", - " -2.1663002967834473, -2.142958164215088, 5.068490505218506,\n", - " -0.09585778415203094, 2.920788288116455, 3.8810973167419434,\n", - " 0.36290690302848816, -1.381648063659668, 1.097022533416748,\n", - " 2.787872552871704, 1.5658684968948364, 1.4216945171356201],\n", + " 'y': [-5.999994277954102, 0.15511037409305573, -0.6804019212722778,\n", + " -0.8969926834106445, -4.350093841552734, 2.10798978805542,\n", + " -4.097671031951904, 5.030608177185059, 3.1227762699127197,\n", + " 3.44264817237854, 4.6125640869140625, 0.7293226718902588,\n", + " 2.7135281562805176, -0.2420026659965515, 4.3908257484436035,\n", + " -7.856958866119385, -5.952345848083496, 5.613704204559326,\n", + " -7.10869026184082, -2.1775100231170654, 2.4739584922790527,\n", + " 0.04653293639421463, 1.881555438041687, -0.2442491501569748,\n", + " 1.7328133583068848, 3.332047462463379, -4.845430850982666,\n", + " 0.990510880947113, 3.7318599224090576, -1.215183973312378],\n", " 'yaxis': 'y8'}],\n", " 'layout': {'autosize': True,\n", " 'font': {'size': 10},\n", @@ -2488,7 +2744,7 @@ "})" ] }, - "execution_count": 16, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -2521,7 +2777,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.0rc1" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/tests/metrics/debug-yosemite.ipynb b/tests/debug/debug-yosemite.ipynb similarity index 100% rename from tests/metrics/debug-yosemite.ipynb rename to tests/debug/debug-yosemite.ipynb diff --git a/tests/metrics/debug_glocal.py b/tests/debug/debug_glocal.py similarity index 100% rename from tests/metrics/debug_glocal.py rename to tests/debug/debug_glocal.py diff --git a/tests/test_configure.py b/tests/test_configure.py index e5c5e9800..a93539e29 100644 --- a/tests/test_configure.py +++ b/tests/test_configure.py @@ -1,20 +1,6 @@ import pytest -from neuralprophet.configure import Train - - -def generate_config_train_params(overrides={}): - config_train_params = { - "quantiles": None, - "learning_rate": None, - "epochs": None, - "batch_size": None, - "loss_func": "SmoothL1Loss", - "optimizer": "AdamW", - } - for key, value in overrides.items(): - config_train_params[key] = value - return config_train_params +from neuralprophet import NeuralProphet def test_config_training_quantiles(): @@ -26,24 +12,21 @@ def test_config_training_quantiles(): ({"quantiles": [0.2, 0.8]}, [0.5, 0.2, 0.8]), ({"quantiles": [0.5, 0.8]}, [0.5, 0.8]), ] - for overrides, expected in checks: - config_train_params = generate_config_train_params(overrides) - config = Train(**config_train_params) - assert config.quantiles == expected + model = NeuralProphet(**overrides) + assert model.config_model.quantiles == expected def test_config_training_quantiles_error_invalid_type(): - config_train_params = generate_config_train_params() - config_train_params["quantiles"] = "hello world" with pytest.raises(AssertionError) as err: - Train(**config_train_params) - assert str(err.value) == "Quantiles must be in a list format, not None or scalar." + _ = NeuralProphet(quantiles="hello world") + assert str(err.value) == "Quantiles must be provided as list." def test_config_training_quantiles_error_invalid_scale(): - config_train_params = generate_config_train_params() - config_train_params["quantiles"] = [-1] with pytest.raises(Exception) as err: - Train(**config_train_params) + _ = NeuralProphet(quantiles=[-1]) + assert str(err.value) == "The quantiles specified need to be floats in-between (0, 1)." + with pytest.raises(Exception) as err: + _ = NeuralProphet(quantiles=[1.3]) assert str(err.value) == "The quantiles specified need to be floats in-between (0, 1)." diff --git a/tests/test_glocal.py b/tests/test_glocal.py index 9bda1882c..fe4719140 100644 --- a/tests/test_glocal.py +++ b/tests/test_glocal.py @@ -20,7 +20,7 @@ YOS_FILE = os.path.join(DATA_DIR, "yosemite_temps.csv") NROWS = 256 EPOCHS = 1 -BATCH_SIZE = 128 +BATCH_SIZE = 32 LR = 1.0 PLOT = False @@ -60,7 +60,7 @@ def test_regularized_trend_global_local_modeling(): df2_0["ID"] = "df2" df3_0 = df.iloc[256:384, :].copy(deep=True) df3_0["ID"] = "df3" - m = NeuralProphet(n_lags=10, epochs=EPOCHS, trend_global_local="local", trend_reg=1) + m = NeuralProphet(n_lags=10, epochs=EPOCHS, learning_rate=LR, trend_global_local="local", trend_reg=1) train_df, test_df = m.split_df(pd.concat((df1_0, df2_0, df3_0)), valid_p=0.33, local_split=True) m.fit(train_df) future = m.make_future_dataframe(test_df) @@ -286,7 +286,9 @@ def test_adding_new_local_seasonality(): df2_0["ID"] = "df2" df3_0 = df.iloc[256:384, :].copy(deep=True) df3_0["ID"] = "df3" - m = NeuralProphet(epochs=EPOCHS, batch_size=BATCH_SIZE, season_global_local="global", trend_global_local="local") + m = NeuralProphet( + epochs=EPOCHS, learning_rate=LR, batch_size=BATCH_SIZE, season_global_local="global", trend_global_local="local" + ) m.add_seasonality(period=30, fourier_order=8, name="monthly", global_local="local") train_df, test_df = m.split_df(pd.concat((df1_0, df2_0, df3_0)), valid_p=0.33, local_split=True) m.fit(train_df) diff --git a/tests/test_integration.py b/tests/test_integration.py index 8ef45b10a..ee4a9028d 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -611,6 +611,20 @@ def test_loss_func_torch(): m.predict(future) +def test_loss_func_torch_lr_finder(): + log.info("TEST setting torch.nn loss func") + df = pd.read_csv(PEYTON_FILE, nrows=512) + m = NeuralProphet( + epochs=EPOCHS, + batch_size=BATCH_SIZE, + loss_func=torch.nn.MSELoss, + learning_rate=None, + ) + m.fit(df, freq="D") + future = m.make_future_dataframe(df, periods=10, n_historic_predictions=10) + m.predict(future) + + def test_callable_loss(): log.info("TEST Callable Loss") @@ -630,6 +644,7 @@ def my_loss(output, target): batch_size=BATCH_SIZE, seasonality_mode="multiplicative", loss_func=my_loss, + learning_rate=LR, ) m.fit(df, freq="5min") future = m.make_future_dataframe(df, periods=12 * 24, n_historic_predictions=12 * 24) @@ -659,6 +674,7 @@ def forward(self, input, target): epochs=EPOCHS, batch_size=BATCH_SIZE, loss_func=MyLoss, + learning_rate=LR, ) m.fit(df, freq="5min") future = m.make_future_dataframe(df, periods=12, n_historic_predictions=12) diff --git a/tests/test_regularization.py b/tests/test_regularization.py index 6631a4d43..e5c6d96eb 100644 --- a/tests/test_regularization.py +++ b/tests/test_regularization.py @@ -61,14 +61,17 @@ def test_regularization_holidays(): m = NeuralProphet( epochs=20, - batch_size=64, + batch_size=32, learning_rate=0.1, yearly_seasonality=False, weekly_seasonality=False, daily_seasonality=False, growth="off", ) - m = m.add_country_holidays("US", regularization=0.001) + m = m.add_country_holidays( + "US", + regularization=0.0001, + ) m.fit(df, freq="D") to_reduce = [] @@ -80,8 +83,8 @@ def test_regularization_holidays(): to_reduce.append(weight_list[0][0][0]) else: to_preserve.append(weight_list[0][0][0]) - # print(to_reduce) - # print(to_preserve) + # print(f"To reduce (< 0.2) {to_reduce}") + # print(f"To preserve (> 0.5) {to_preserve}") assert np.mean(to_reduce) < 0.2 assert np.mean(to_preserve) > 0.5 @@ -100,7 +103,10 @@ def test_regularization_events(): daily_seasonality=False, growth="off", ) - m = m.add_events(["event_%i" % index for index, _ in enumerate(events)], regularization=REGULARIZATION) + m = m.add_events( + ["event_%i" % index for index, _ in enumerate(events)], + regularization=0.1, + ) events_df = pd.concat( [ pd.DataFrame( @@ -124,9 +130,9 @@ def test_regularization_events(): to_reduce.append(param.detach().numpy()[0][0]) else: to_preserve.append(param.detach().numpy()[0][0]) - # print(to_reduce) - # print(to_preserve) - assert np.mean(to_reduce) < 0.1 + # print(f"To reduce (< 0.2) {to_reduce}") + # print(f"To preserve (> 0.5) {to_preserve}") + assert np.mean(to_reduce) < 0.2 assert np.mean(to_preserve) > 0.5 diff --git a/tests/test_save.py b/tests/test_save.py new file mode 100644 index 000000000..1aeab44fe --- /dev/null +++ b/tests/test_save.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 + +import io +import logging +import os +import pathlib + +import pandas as pd +import pytest + +from neuralprophet import NeuralProphet, load, save + +log = logging.getLogger("NP.test") +log.setLevel("ERROR") +log.parent.setLevel("ERROR") + +DIR = pathlib.Path(__file__).parent.parent.absolute() +DATA_DIR = os.path.join(DIR, "tests", "test-data") +PEYTON_FILE = os.path.join(DATA_DIR, "wp_log_peyton_manning.csv") +AIR_FILE = os.path.join(DATA_DIR, "air_passengers.csv") +YOS_FILE = os.path.join(DATA_DIR, "yosemite_temps.csv") +NROWS = 512 +EPOCHS = 10 +ADDITIONAL_EPOCHS = 5 +LR = 1.0 +BATCH_SIZE = 64 + +PLOT = False + + +def test_save_load(): + df = pd.read_csv(PEYTON_FILE, nrows=NROWS) + m = NeuralProphet( + epochs=EPOCHS, + batch_size=BATCH_SIZE, + learning_rate=LR, + n_lags=6, + n_forecasts=3, + n_changepoints=0, + ) + _ = m.fit(df, freq="D") + future = m.make_future_dataframe(df, periods=3) + forecast = m.predict(df=future) + log.info("testing: save") + save(m, "test_model.pt") + + log.info("testing: load") + m2 = load("test_model.pt") + forecast2 = m2.predict(df=future) + + m3 = load("test_model.pt", map_location="cpu") + forecast3 = m3.predict(df=future) + + # Check that the forecasts are the same + pd.testing.assert_frame_equal(forecast, forecast2) + pd.testing.assert_frame_equal(forecast, forecast3) + + +def test_save_load_io(): + df = pd.read_csv(PEYTON_FILE, nrows=NROWS) + m = NeuralProphet( + epochs=EPOCHS, + batch_size=BATCH_SIZE, + learning_rate=LR, + n_lags=6, + n_forecasts=3, + n_changepoints=0, + ) + _ = m.fit(df, freq="D") + future = m.make_future_dataframe(df, periods=3) + forecast = m.predict(df=future) + + # Save the model to an in-memory buffer + log.info("testing: save to buffer") + buffer = io.BytesIO() + save(m, buffer) + buffer.seek(0) # Reset buffer position to the beginning + + log.info("testing: load from buffer") + m2 = load(buffer) + forecast2 = m2.predict(df=future) + + buffer.seek(0) # Reset buffer position to the beginning for another load + m3 = load(buffer, map_location="cpu") + forecast3 = m3.predict(df=future) + + # Check that the forecasts are the same + pd.testing.assert_frame_equal(forecast, forecast2) + pd.testing.assert_frame_equal(forecast, forecast3) + + +# def test_continue_training_checkpoint(): +# df = pd.read_csv(PEYTON_FILE, nrows=NROWS) +# m = NeuralProphet( +# epochs=EPOCHS, +# batch_size=BATCH_SIZE, +# learning_rate=LR, +# n_lags=6, +# n_forecasts=3, +# n_changepoints=0, +# ) +# metrics = m.fit(df, checkpointing=True, freq="D") +# metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS) +# assert metrics["Loss"].min() >= metrics2["Loss"].min() + + +# def test_continue_training_with_scheduler_selection(): +# df = pd.read_csv(PEYTON_FILE, nrows=NROWS) +# m = NeuralProphet( +# epochs=EPOCHS, +# batch_size=BATCH_SIZE, +# learning_rate=LR, +# n_lags=6, +# n_forecasts=3, +# n_changepoints=0, +# ) +# metrics = m.fit(df, checkpointing=True, freq="D") +# # Continue training with StepLR +# metrics2 = m.fit(df, freq="D", continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR") +# assert metrics["Loss"].min() >= metrics2["Loss"].min() + + +# def test_save_load_continue_training(): +# df = pd.read_csv(PEYTON_FILE, nrows=NROWS) +# m = NeuralProphet( +# epochs=EPOCHS, +# n_lags=6, +# n_forecasts=3, +# n_changepoints=0, +# ) +# metrics = m.fit(df, checkpointing=True, freq="D") +# save(m, "test_model.pt") +# m2 = load("test_model.pt") +# metrics2 = m2.fit(df, continue_training=True, epochs=ADDITIONAL_EPOCHS, scheduler="StepLR") +# assert metrics["Loss"].min() >= metrics2["Loss"].min() diff --git a/tests/test_train_config.py b/tests/test_train_config.py new file mode 100644 index 000000000..e1ecbde8b --- /dev/null +++ b/tests/test_train_config.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 + +import io +import logging +import os +import pathlib + +import pandas as pd +import pytest + +from neuralprophet import NeuralProphet, df_utils, load, save + +log = logging.getLogger("NP.test") +log.setLevel("ERROR") +log.parent.setLevel("ERROR") + +DIR = pathlib.Path(__file__).parent.parent.absolute() +DATA_DIR = os.path.join(DIR, "tests", "test-data") +PEYTON_FILE = os.path.join(DATA_DIR, "wp_log_peyton_manning.csv") +AIR_FILE = os.path.join(DATA_DIR, "air_passengers.csv") +YOS_FILE = os.path.join(DATA_DIR, "yosemite_temps.csv") +NROWS = 512 +EPOCHS = 10 +ADDITIONAL_EPOCHS = 5 +LR = 1.0 +BATCH_SIZE = 64 + +PLOT = False + + +def generate_config_train_params(overrides={}): + config_train_params = { + "learning_rate": None, + "epochs": None, + "batch_size": None, + "loss_func": "SmoothL1Loss", + "optimizer": "AdamW", + } + for key, value in overrides.items(): + config_train_params[key] = value + return config_train_params + + +def test_custom_lr_scheduler(): + df = pd.read_csv(PEYTON_FILE, nrows=NROWS) + + # Set in NeuralProphet() + m = NeuralProphet( + epochs=EPOCHS, + batch_size=BATCH_SIZE, + learning_rate=LR, + scheduler="CosineAnnealingWarmRestarts", + scheduler_args={"T_0": 5, "T_mult": 2}, + ) + metrics = m.fit(df, freq="D") + # Set in NeuralProphet(), no args + m = NeuralProphet( + epochs=EPOCHS, + batch_size=BATCH_SIZE, + learning_rate=LR, + scheduler="StepLR", + ) + metrics = m.fit(df, freq="D") + + # Set in fit() + m = NeuralProphet(epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=LR) + metrics = m.fit( + df, + freq="D", + scheduler="ExponentialLR", + scheduler_args={"gamma": 0.95}, + ) + + # Set in fit(), no args + m = NeuralProphet(epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=LR) + metrics = m.fit( + df, + freq="D", + scheduler="OneCycleLR", + ) diff --git a/tests/test_unit.py b/tests/test_unit.py index 2032ffecb..42f19d218 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -474,6 +474,7 @@ def test_reg_delay(): ) m.fit(df, freq="D") c = m.config_train + # weight, epoch, epoch_iteration_progress for w, e, i in [ (0, 0, 1), (0, 3, 0), @@ -484,7 +485,8 @@ def test_reg_delay(): (1, 7, 1), (1, 8, 0), ]: - weight = c.get_reg_delay_weight(e, i, reg_start_pct=0.5, reg_full_pct=0.8) + progress = float(e + i) / 10.0 + weight = c.get_reg_delay_weight(progress=progress, reg_start_pct=0.5, reg_full_pct=0.8) assert weight == w diff --git a/tests/test_utils.py b/tests/test_utils.py index 1bda8f108..d076d3d28 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import io import logging import os import pathlib @@ -8,7 +7,7 @@ import pandas as pd import pytest -from neuralprophet import NeuralProphet, df_utils, load, save +from neuralprophet import NeuralProphet, df_utils log = logging.getLogger("NP.test") log.setLevel("ERROR") @@ -21,6 +20,7 @@ YOS_FILE = os.path.join(DATA_DIR, "yosemite_temps.csv") NROWS = 512 EPOCHS = 10 +ADDITIONAL_EPOCHS = 5 LR = 1.0 BATCH_SIZE = 64 @@ -39,85 +39,8 @@ def test_create_dummy_datestamps(): _ = m.fit(df_dummy) _ = m.make_future_dataframe(df_dummy, periods=365, n_historic_predictions=True) - + def test_no_log(): df = pd.read_csv(PEYTON_FILE, nrows=NROWS) m = NeuralProphet(epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=LR) _ = m.fit(df, metrics=False, metrics_log_dir=False) - - -def test_save_load(): - df = pd.read_csv(PEYTON_FILE, nrows=NROWS) - m = NeuralProphet( - epochs=EPOCHS, - batch_size=BATCH_SIZE, - learning_rate=LR, - n_lags=6, - n_forecasts=3, - n_changepoints=0, - ) - _ = m.fit(df, freq="D") - future = m.make_future_dataframe(df, periods=3) - forecast = m.predict(df=future) - log.info("testing: save") - save(m, "test_model.pt") - - log.info("testing: load") - m2 = load("test_model.pt") - forecast2 = m2.predict(df=future) - - m3 = load("test_model.pt", map_location="cpu") - forecast3 = m3.predict(df=future) - - # Check that the forecasts are the same - pd.testing.assert_frame_equal(forecast, forecast2) - pd.testing.assert_frame_equal(forecast, forecast3) - - -def test_save_load_io(): - df = pd.read_csv(PEYTON_FILE, nrows=NROWS) - m = NeuralProphet( - epochs=EPOCHS, - batch_size=BATCH_SIZE, - learning_rate=LR, - n_lags=6, - n_forecasts=3, - n_changepoints=0, - ) - _ = m.fit(df, freq="D") - future = m.make_future_dataframe(df, periods=3) - forecast = m.predict(df=future) - - # Save the model to an in-memory buffer - log.info("testing: save to buffer") - buffer = io.BytesIO() - save(m, buffer) - buffer.seek(0) # Reset buffer position to the beginning - - log.info("testing: load from buffer") - m2 = load(buffer) - forecast2 = m2.predict(df=future) - - buffer.seek(0) # Reset buffer position to the beginning for another load - m3 = load(buffer, map_location="cpu") - forecast3 = m3.predict(df=future) - - # Check that the forecasts are the same - pd.testing.assert_frame_equal(forecast, forecast2) - pd.testing.assert_frame_equal(forecast, forecast3) - - -# TODO: add functionality to continue training -# def test_continue_training(): -# df = pd.read_csv(PEYTON_FILE, nrows=NROWS) -# m = NeuralProphet( -# epochs=EPOCHS, -# batch_size=BATCH_SIZE, -# learning_rate=LR, -# n_lags=6, -# n_forecasts=3, -# n_changepoints=0, -# ) -# metrics = m.fit(df, freq="D") -# metrics2 = m.fit(df, freq="D", continue_training=True) -# assert metrics1["Loss"].sum() >= metrics2["Loss"].sum()