From b79b7e109508399c8f8cf05d2454cf6491095835 Mon Sep 17 00:00:00 2001 From: ourownstory Date: Wed, 28 Aug 2024 18:00:38 -0700 Subject: [PATCH] fix lr-finder --- neuralprophet/configure.py | 5 +- neuralprophet/forecaster.py | 77 +- neuralprophet/time_net.py | 11 +- neuralprophet/utils.py | 39 +- tests/debug/debug-energy-price-hourly.ipynb | 995 +++++++------------- 5 files changed, 401 insertions(+), 726 deletions(-) diff --git a/neuralprophet/configure.py b/neuralprophet/configure.py index ac6ea2330..467922213 100644 --- a/neuralprophet/configure.py +++ b/neuralprophet/configure.py @@ -279,10 +279,11 @@ def set_lr_finder_args(self, dataset_size, num_batches): # num_training = num_batches self.lr_finder_args.update( { - "min_lr": 1e-7, - "max_lr": 10, + "min_lr": 1e-8, + "max_lr": 1e1, "num_training": num_training, "early_stop_threshold": None, + "mode": "exponential", } ) diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py index f0f7f1b36..746549462 100644 --- a/neuralprophet/forecaster.py +++ b/neuralprophet/forecaster.py @@ -2796,7 +2796,10 @@ def _train( # Set up data the training dataloader df, _, _, _ = df_utils.prep_or_copy_df(df) # TODO: Can this call be removed? train_loader = self._init_train_loader(df, num_workers) - dataset_size = len(df) # train_loader.dataset + dataset_size = len(train_loader.dataset) # df + batches_per_epoch = len(train_loader) + log.info(f"Dataset size: {dataset_size}") + log.info(f"Number of batches per training epoch: {batches_per_epoch}") # Internal flag to check if validation is enabled validation_enabled = df_val is not None @@ -2818,55 +2821,41 @@ def _train( deterministic=deterministic, ) + # Find suitable learning rate + if not self.config_train.learning_rate: + log.info("No Learning Rate provided. Activating learning rate finder") + # Set parameters for the learning rate finder + self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=batches_per_epoch) + log.info(f"Learning rate finder ---- ARGs: {self.config_train.lr_finder_args}") + self.model.finding_lr = True + tuner = Tuner(self.trainer) + lr_finder = tuner.lr_find( + model=self.model, + train_dataloaders=train_loader, + # val_dataloaders=val_loader, # not used, but may lead to Lightning bug if not provided + **self.config_train.lr_finder_args, + ) + # Estimate the optimal learning rate from the loss curve + assert lr_finder is not None + _, _, lr_suggested = utils.smooth_loss_and_suggest(lr_finder) + self.model.learning_rate = lr_suggested + self.config_train.learning_rate = lr_suggested + log.info(f"Learning rate finder suggested learning rate: {lr_suggested}") + self.model.finding_lr = False + # Tune hyperparams and train if validation_enabled: # Set up data the validation dataloader df_val, _, _, _ = df_utils.prep_or_copy_df(df_val) val_loader = self._init_val_loader(df_val) - if not self.config_train.learning_rate: - # Find suitable learning rate - # Set parameters for the learning rate finder - self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=len(train_loader)) - self.model.finding_lr = True - tuner = Tuner(self.trainer) - lr_finder = tuner.lr_find( - model=self.model, - train_dataloaders=train_loader, - # val_dataloaders=val_loader, # not be used, but may lead to Lightning bug if not provided - **self.config_train.lr_finder_args, - ) - # Estimate the optimal learning rate from the loss curve - assert lr_finder is not None - _, _, self.model.learning_rate = utils.smooth_loss_and_suggest(lr_finder) - self.model.finding_lr = False - start = time.time() - self.trainer.fit( - self.model, - train_loader, - val_loader, - ) - else: - if not self.config_train.learning_rate: - # Find suitable learning rate - # Set parameters for the learning rate finder - self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=len(train_loader)) - self.model.finding_lr = True - tuner = Tuner(self.trainer) - lr_finder = tuner.lr_find( - model=self.model, - train_dataloaders=train_loader, - **self.config_train.lr_finder_args, - ) - assert lr_finder is not None - # Estimate the optimal learning rate from the loss curve - _, _, self.model.learning_rate = utils.smooth_loss_and_suggest(lr_finder) - self.model.finding_lr = False - start = time.time() - self.trainer.fit( - self.model, - train_loader, - ) + self.model.finding_lr = False + start = time.time() + self.trainer.fit( + model=self.model, + train_dataloaders=train_loader, + val_dataloaders=val_loader if validation_enabled else None, + ) log.debug("Train Time: {:8.3f}".format(time.time() - start)) diff --git a/neuralprophet/time_net.py b/neuralprophet/time_net.py index 30fb7a56e..a1148aa4f 100644 --- a/neuralprophet/time_net.py +++ b/neuralprophet/time_net.py @@ -775,8 +775,8 @@ def loss_func(self, inputs, predicted, targets): def training_step(self, batch, batch_idx): inputs, targets, meta = batch - epoch_float = self.trainer.current_epoch + float(batch_idx / self.train_steps_per_epoch) - self.train_progress = epoch_float / self.config_train.epochs + epoch_float = self.trainer.current_epoch + batch_idx / float(self.train_steps_per_epoch) + self.train_progress = epoch_float / float(self.config_train.epochs) # Global-local if self.meta_used_in_model: meta_name_tensor = torch.tensor([self.id_dict[i] for i in meta["df_name"]], device=self.device) @@ -796,7 +796,10 @@ def training_step(self, batch, batch_idx): optimizer.step() scheduler = self.lr_schedulers() - scheduler.step(epoch=epoch_float) + if self.finding_lr: + scheduler.step() + else: + scheduler.step(epoch=epoch_float) if self.finding_lr: # Manually track the loss for the lr finder @@ -874,7 +877,7 @@ def configure_optimizers(self): # Optimizer if self.finding_lr and self.learning_rate is None: - self.learning_rate = self.config_train.lr_finder_args["min_lr"] + self.learning_rate = 0.1 optimizer = self.config_train.optimizer( self.parameters(), lr=self.learning_rate, diff --git a/neuralprophet/utils.py b/neuralprophet/utils.py index cc5a3ed16..309d9098f 100644 --- a/neuralprophet/utils.py +++ b/neuralprophet/utils.py @@ -771,17 +771,17 @@ def smooth_loss_and_suggest(lr_finder, window=10): """ lr_finder_results = lr_finder.results lr = lr_finder_results["lr"] - loss = lr_finder_results["loss"] + loss = np.array(lr_finder_results["loss"]) # Derive window size from num lr searches, ensure window is divisible by 2 # half_window = math.ceil(round(len(loss) * 0.1) / 2) half_window = math.ceil(window / 2) # Pad sequence and initialialize hamming filter - loss = np.pad(np.array(loss), pad_width=half_window, mode="edge") - window = np.hamming(half_window * 2) + loss = np.pad(loss, pad_width=half_window, mode="edge") + hamming_window = np.hamming(2 * half_window) # Convolve the over the loss distribution try: - loss = np.convolve( - window / window.sum(), + loss_smooth = np.convolve( + hamming_window / hamming_window.sum(), loss, mode="valid", )[1:] @@ -790,26 +790,41 @@ def smooth_loss_and_suggest(lr_finder, window=10): f"The number of loss values ({len(loss)}) is too small to apply smoothing with a the window size of " f"{window}." ) + # Suggest the lr with steepest negative gradient try: # Find the steepest gradient and the minimum loss after that - suggestion = lr[np.argmin(np.gradient(loss))] + suggestion_steepest = lr[np.argmin(np.gradient(loss_smooth))] + suggestion_minimum = lr[np.argmin(loss_smooth)] except ValueError: log.error( f"The number of loss values ({len(loss)}) is too small to estimate a learning rate. Increase the number of " "samples or manually set the learning rate." ) raise - suggestion_default = lr_finder.suggestion(skip_begin=10, skip_end=3) - if suggestion is not None and suggestion_default is not None: - log_suggestion_smooth = np.log(suggestion) + # get the tuner's default suggestion + suggestion_default = lr_finder.suggestion(skip_begin=20, skip_end=10) + + log.info(f"Learning rate finder ---- default suggestion: {suggestion_default}") + log.info(f"Learning rate finder ---- steepest: {suggestion_steepest}") + log.info(f"Learning rate finder ---- minimum: {suggestion_minimum}") + if suggestion_steepest is not None and suggestion_minimum is not None and suggestion_default is not None: + log_suggestion_smooth = np.log(suggestion_steepest) + log_suggestion_minimum = np.log(suggestion_minimum) log_suggestion_default = np.log(suggestion_default) - lr_suggestion = np.exp((log_suggestion_smooth + log_suggestion_default) / 2) - elif suggestion is None and suggestion_default is None: + lr_suggestion = np.exp((log_suggestion_smooth + log_suggestion_minimum + log_suggestion_default) / 3) + log.info(f"Learning rate finder ---- log-avg: {lr_suggestion}") + elif suggestion_steepest is None and suggestion_default is None: log.error("Automatic learning rate test failed. Please set manually the learning rate.") raise else: - lr_suggestion = suggestion if suggestion is not None else suggestion_default + lr_suggestion = suggestion_steepest if suggestion_steepest is not None else suggestion_default + + log.info(f"Learning rate finder ---- returning: {lr_suggestion}") + log.info(f"Learning rate finder ---- LR (start): {lr[:5]}") + log.info(f"Learning rate finder ---- LR (end): {lr[-5:]}") + log.info(f"Learning rate finder ---- LOSS (start): {loss[:5]}") + log.info(f"Learning rate finder ---- LOSS (end): {loss[-5:]}") return (loss, lr, lr_suggestion) diff --git a/tests/debug/debug-energy-price-hourly.ipynb b/tests/debug/debug-energy-price-hourly.ipynb index a8c769d20..ab4485f1f 100644 --- a/tests/debug/debug-energy-price-hourly.ipynb +++ b/tests/debug/debug-energy-price-hourly.ipynb @@ -16,7 +16,9 @@ "from plotly.subplots import make_subplots\n", "from plotly_resampler import unregister_plotly_resampler\n", "\n", - "from neuralprophet import NeuralProphet, set_random_seed" + "from neuralprophet import NeuralProphet, set_random_seed, set_log_level\n", + "\n", + "set_log_level(\"INFO\")" ] }, { @@ -169,7 +171,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 5, @@ -186,13 +188,13 @@ " \"yearly_seasonality\": 10,\n", " \"weekly_seasonality\": True,\n", " \"daily_seasonality\": False, # due to conditional daily seasonality\n", - " \"batch_size\": 64,\n", + " \"batch_size\": 32,\n", " \"ar_layers\": [8, 4],\n", " \"lagged_reg_layers\": [8],\n", " # not tuned\n", " \"n_forecasts\": 5,\n", - " \"learning_rate\": 0.1,\n", - " \"epochs\": 20,\n", + " # \"learning_rate\": 0.1,\n", + " \"epochs\": 10,\n", " \"trend_global_local\": \"global\",\n", " \"season_global_local\": \"global\",\n", " \"drop_missing\": True,\n", @@ -239,6 +241,7 @@ "output_type": "stream", "text": [ "INFO - (NP.forecaster.fit) - When Global modeling with local normalization, metrics are displayed in normalized scale.\n", + "WARNING - (NP.forecaster.fit) - Metrics are enabled. Please provide valid metrics logging directory. Setting to CWD\n", "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/df_utils.py:1149: FutureWarning: Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.\n", " converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)\n", "\n", @@ -267,13 +270,15 @@ "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/time_dataset.py:692: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " contains_nan = torch.cat([torch.tensor(contains_nan), torch.ones(n_forecasts, dtype=torch.bool)])\n", "\n", + "INFO - (NP.forecaster._train) - Dataset size: 2758\n", + "INFO - (NP.forecaster._train) - Number of batches per training epoch: 87\n", "INFO - (NP.utils.configure_trainer) - Using accelerator cpu with 1 device(s).\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "a3a2688119ad4f35babe4b5751d7a677", + "model_id": "3c6d261e96524335a24f00923ad36c02", "version_major": 2, "version_minor": 0 }, @@ -288,23 +293,20 @@ "name": "stderr", "output_type": "stream", "text": [ - "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/time_dataset.py:692: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " contains_nan = torch.cat([torch.tensor(contains_nan), torch.ones(n_forecasts, dtype=torch.bool)])\n", - "\n", - "WARNING - (py.warnings._showwarnmsg) - /home/tabletop/github/neural_prophet/neuralprophet/time_dataset.py:692: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " contains_nan = torch.cat([torch.tensor(contains_nan), torch.ones(n_forecasts, dtype=torch.bool)])\n", - "\n" + "INFO - (NP.forecaster._train) - No Learning Rate provided. Activating learning rate finder\n", + "WARNING - (NP.config.set_lr_finder_args) - Learning rate finder: The number of batches (87) is too small than the required number for the learning rate finder (168). The results might not be optimal.\n", + "INFO - (NP.forecaster._train) - Learning rate finder ---- ARGs: {'min_lr': 1e-08, 'max_lr': 10.0, 'num_training': 168, 'early_stop_threshold': None, 'mode': 'exponential'}\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "be32cb53d78b4ead975b12aa5ad15196", + "model_id": "35957cbfc5044d2eab5eb3fe1ccee7c8", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Training: | | 0/? [00:00\n", " \n", " 0\n", - " 0.499936\n", - " 0.583346\n", - " 0.938270\n", + " 0.710609\n", + " 0.819329\n", + " 0.622614\n", " 0.0\n", " 0\n", - " 1.503294\n", - " 2.114124\n", - " 1.916612\n", + " 1.144053\n", + " 1.562842\n", + " 1.368656\n", " 0.0\n", - " 0.004087\n", + " 0.012448\n", " \n", " \n", " 1\n", - " 0.534045\n", - " 0.631530\n", - " 0.440998\n", + " 0.836989\n", + " 0.946932\n", + " 0.733583\n", " 0.0\n", " 1\n", - " 0.718145\n", - " 0.943761\n", - " 0.505523\n", + " 0.532974\n", + " 0.702971\n", + " 0.343495\n", " 0.0\n", - " 0.021600\n", + " 0.039781\n", " \n", " \n", " 2\n", - " 0.542755\n", - " 0.644081\n", - " 0.454675\n", + " 0.588745\n", + " 0.704277\n", + " 0.497290\n", " 0.0\n", " 2\n", - " 0.537536\n", - " 0.724863\n", - " 0.347341\n", + " 0.495191\n", + " 0.658636\n", + " 0.304316\n", " 0.0\n", - " 0.050152\n", + " 0.040028\n", " \n", " \n", " 3\n", - " 0.508438\n", - " 0.616892\n", - " 0.487512\n", + " 0.699847\n", + " 0.818369\n", + " 0.594933\n", " 0.0\n", " 3\n", - " 0.503906\n", - " 0.677358\n", - " 0.312197\n", + " 0.475755\n", + " 0.632438\n", + " 0.283402\n", " 0.0\n", - " 0.078837\n", + " 0.012695\n", " \n", " \n", " 4\n", - " 0.649246\n", - " 0.755430\n", - " 0.545550\n", + " 0.704670\n", + " 0.828111\n", + " 0.594259\n", " 0.0\n", " 4\n", - " 0.505661\n", - " 0.671692\n", - " 0.313073\n", + " 0.460465\n", + " 0.615198\n", + " 0.271323\n", " 0.0\n", - " 0.096699\n", + " 0.004634\n", " \n", " \n", " 5\n", - " 0.463848\n", - " 0.568442\n", - " 0.367994\n", + " 0.648891\n", + " 0.755905\n", + " 0.530240\n", " 0.0\n", " 5\n", - " 0.520102\n", - " 0.691615\n", - " 0.322044\n", + " 0.458983\n", + " 0.614499\n", + " 0.270039\n", " 0.0\n", - " 0.099596\n", + " 0.003871\n", " \n", " \n", " 6\n", - " 0.355072\n", - " 0.410634\n", - " 0.251356\n", + " 0.715093\n", + " 0.839661\n", + " 0.608006\n", " 0.0\n", " 6\n", - " 0.511964\n", - " 0.684423\n", - " 0.316359\n", + " 0.459262\n", + " 0.614916\n", + " 0.269727\n", " 0.0\n", - " 0.097137\n", + " 0.002631\n", " \n", " \n", " 7\n", - " 0.447367\n", - " 0.500184\n", - " 0.336913\n", + " 0.689927\n", + " 0.807763\n", + " 0.577745\n", " 0.0\n", " 7\n", - " 0.503181\n", - " 0.669457\n", - " 0.307173\n", + " 0.455545\n", + " 0.609835\n", + " 0.266736\n", " 0.0\n", - " 0.092315\n", + " 0.001389\n", " \n", " \n", " 8\n", - " 0.821846\n", - " 0.951728\n", - " 0.720978\n", + " 0.646029\n", + " 0.753095\n", + " 0.530006\n", " 0.0\n", " 8\n", - " 0.503031\n", - " 0.671102\n", - " 0.308114\n", + " 0.457110\n", + " 0.611970\n", + " 0.267934\n", " 0.0\n", - " 0.085371\n", + " 0.000618\n", " \n", " \n", " 9\n", - " 0.414638\n", - " 0.474769\n", - " 0.334302\n", + " 0.688786\n", + " 0.806359\n", + " 0.577159\n", " 0.0\n", " 9\n", - " 0.511291\n", - " 0.686945\n", - " 0.311271\n", + " 0.456113\n", + " 0.611694\n", + " 0.267121\n", " 0.0\n", - " 0.076654\n", - " \n", - " \n", - " 10\n", - " 0.606577\n", - " 0.723609\n", - " 0.504883\n", - " 0.0\n", - " 10\n", - " 0.493725\n", - " 0.657971\n", - " 0.301624\n", - " 0.0\n", - " 0.066600\n", - " \n", - " \n", - " 11\n", - " 0.560590\n", - " 0.657100\n", - " 0.453766\n", - " 0.0\n", - " 11\n", - " 0.487225\n", - " 0.654937\n", - " 0.295672\n", - " 0.0\n", - " 0.055713\n", - " \n", - " \n", - " 12\n", - " 0.419592\n", - " 0.459256\n", - " 0.307631\n", - " 0.0\n", - " 12\n", - " 0.479861\n", - " 0.642756\n", - " 0.287683\n", - " 0.0\n", - " 0.044541\n", - " \n", - " \n", - " 13\n", - " 0.492459\n", - " 0.561360\n", - " 0.379794\n", - " 0.0\n", - " 13\n", - " 0.479290\n", - " 0.643680\n", - " 0.284241\n", - " 0.0\n", - " 0.033641\n", - " \n", - " \n", - " 14\n", - " 0.547214\n", - " 0.630017\n", - " 0.432885\n", - " 0.0\n", - " 14\n", - " 0.471661\n", - " 0.633883\n", - " 0.280081\n", - " 0.0\n", - " 0.023563\n", - " \n", - " \n", - " 15\n", - " 0.542842\n", - " 0.630828\n", - " 0.427475\n", - " 0.0\n", - " 15\n", - " 0.467507\n", - " 0.630942\n", - " 0.275439\n", - " 0.0\n", - " 0.014810\n", - " \n", - " \n", - " 16\n", - " 0.497607\n", - " 0.569062\n", - " 0.380621\n", - " 0.0\n", - " 16\n", - " 0.468031\n", - " 0.631191\n", - " 0.276560\n", - " 0.0\n", - " 0.007821\n", - " \n", - " \n", - " 17\n", - " 0.507053\n", - " 0.580275\n", - " 0.390214\n", - " 0.0\n", - " 17\n", - " 0.458170\n", - " 0.620013\n", - " 0.268218\n", - " 0.0\n", - " 0.002948\n", - " \n", - " \n", - " 18\n", - " 0.506170\n", - " 0.578457\n", - " 0.389007\n", - " 0.0\n", - " 18\n", - " 0.460292\n", - " 0.622816\n", - " 0.268188\n", - " 0.0\n", - " 0.000434\n", - " \n", - " \n", - " 19\n", - " 0.508543\n", - " 0.581377\n", - " 0.391374\n", - " 0.0\n", - " 19\n", - " 0.459247\n", - " 0.622094\n", - " 0.267627\n", - " 0.0\n", - " 0.000405\n", + " 0.000613\n", " \n", " \n", "\n", "" ], "text/plain": [ - " MAE_val RMSE_val Loss_val RegLoss_val epoch MAE RMSE \\\n", - "0 0.499936 0.583346 0.938270 0.0 0 1.503294 2.114124 \n", - "1 0.534045 0.631530 0.440998 0.0 1 0.718145 0.943761 \n", - "2 0.542755 0.644081 0.454675 0.0 2 0.537536 0.724863 \n", - "3 0.508438 0.616892 0.487512 0.0 3 0.503906 0.677358 \n", - "4 0.649246 0.755430 0.545550 0.0 4 0.505661 0.671692 \n", - "5 0.463848 0.568442 0.367994 0.0 5 0.520102 0.691615 \n", - "6 0.355072 0.410634 0.251356 0.0 6 0.511964 0.684423 \n", - "7 0.447367 0.500184 0.336913 0.0 7 0.503181 0.669457 \n", - "8 0.821846 0.951728 0.720978 0.0 8 0.503031 0.671102 \n", - "9 0.414638 0.474769 0.334302 0.0 9 0.511291 0.686945 \n", - "10 0.606577 0.723609 0.504883 0.0 10 0.493725 0.657971 \n", - "11 0.560590 0.657100 0.453766 0.0 11 0.487225 0.654937 \n", - "12 0.419592 0.459256 0.307631 0.0 12 0.479861 0.642756 \n", - "13 0.492459 0.561360 0.379794 0.0 13 0.479290 0.643680 \n", - "14 0.547214 0.630017 0.432885 0.0 14 0.471661 0.633883 \n", - "15 0.542842 0.630828 0.427475 0.0 15 0.467507 0.630942 \n", - "16 0.497607 0.569062 0.380621 0.0 16 0.468031 0.631191 \n", - "17 0.507053 0.580275 0.390214 0.0 17 0.458170 0.620013 \n", - "18 0.506170 0.578457 0.389007 0.0 18 0.460292 0.622816 \n", - "19 0.508543 0.581377 0.391374 0.0 19 0.459247 0.622094 \n", + " MAE_val RMSE_val Loss_val RegLoss_val epoch MAE RMSE \\\n", + "0 0.710609 0.819329 0.622614 0.0 0 1.144053 1.562842 \n", + "1 0.836989 0.946932 0.733583 0.0 1 0.532974 0.702971 \n", + "2 0.588745 0.704277 0.497290 0.0 2 0.495191 0.658636 \n", + "3 0.699847 0.818369 0.594933 0.0 3 0.475755 0.632438 \n", + "4 0.704670 0.828111 0.594259 0.0 4 0.460465 0.615198 \n", + "5 0.648891 0.755905 0.530240 0.0 5 0.458983 0.614499 \n", + "6 0.715093 0.839661 0.608006 0.0 6 0.459262 0.614916 \n", + "7 0.689927 0.807763 0.577745 0.0 7 0.455545 0.609835 \n", + "8 0.646029 0.753095 0.530006 0.0 8 0.457110 0.611970 \n", + "9 0.688786 0.806359 0.577159 0.0 9 0.456113 0.611694 \n", "\n", - " Loss RegLoss LR \n", - "0 1.916612 0.0 0.004087 \n", - "1 0.505523 0.0 0.021600 \n", - "2 0.347341 0.0 0.050152 \n", - "3 0.312197 0.0 0.078837 \n", - "4 0.313073 0.0 0.096699 \n", - "5 0.322044 0.0 0.099596 \n", - "6 0.316359 0.0 0.097137 \n", - "7 0.307173 0.0 0.092315 \n", - "8 0.308114 0.0 0.085371 \n", - "9 0.311271 0.0 0.076654 \n", - "10 0.301624 0.0 0.066600 \n", - "11 0.295672 0.0 0.055713 \n", - "12 0.287683 0.0 0.044541 \n", - "13 0.284241 0.0 0.033641 \n", - "14 0.280081 0.0 0.023563 \n", - "15 0.275439 0.0 0.014810 \n", - "16 0.276560 0.0 0.007821 \n", - "17 0.268218 0.0 0.002948 \n", - "18 0.268188 0.0 0.000434 \n", - "19 0.267627 0.0 0.000405 " + " Loss RegLoss LR \n", + "0 1.368656 0.0 0.012448 \n", + "1 0.343495 0.0 0.039781 \n", + "2 0.304316 0.0 0.040028 \n", + "3 0.283402 0.0 0.012695 \n", + "4 0.271323 0.0 0.004634 \n", + "5 0.270039 0.0 0.003871 \n", + "6 0.269727 0.0 0.002631 \n", + "7 0.266736 0.0 0.001389 \n", + "8 0.267934 0.0 0.000618 \n", + "9 0.267121 0.0 0.000613 " ] }, "execution_count": 9, @@ -2375,7 +2047,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "23b4bbc68cea4966bf719a33132a3726", + "model_id": "9c2666a108ac4123919f8203b5f548b1", "version_major": 2, "version_minor": 0 }, @@ -2400,7 +2072,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "dd54d62acf464964b7d15974106128e3", + "model_id": "08c352843e98443a8b24b2713b9636b1", "version_major": 2, "version_minor": 0 }, @@ -2486,7 +2158,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "c873800851d442818fb758ee0b8565b0", + "model_id": "fdb0f6945c704c788e4243c7789a7e29", "version_major": 2, "version_minor": 0 }, @@ -2497,61 +2169,61 @@ " 'mode': 'lines',\n", " 'name': '[R] yhat5 1.0% ~1h',\n", " 'type': 'scatter',\n", - " 'uid': '1e41dda2-f7ca-4501-ae0d-394dbc69313f',\n", + " 'uid': 'cceaf554-f88b-47ac-b077-bd98eebd51bd',\n", " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", " datetime.datetime(2015, 1, 2, 14, 0),\n", " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 18, 0),\n", + " datetime.datetime(2015, 3, 2, 19, 0),\n", " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'y': array([ 7.0392876, 9.7315445, 17.582043 , ..., 48.029076 , 46.43782 ,\n", - " 48.867878 ], dtype=float32)},\n", + " 'y': array([21.843597, 25.104948, 33.001038, ..., 46.6899 , 41.747295, 48.700737],\n", + " dtype=float32)},\n", " {'fill': 'tonexty',\n", " 'fillcolor': 'rgba(45, 146, 255, 0.2)',\n", " 'line': {'color': 'rgba(45, 146, 255, 0.2)', 'width': 1},\n", " 'mode': 'lines',\n", " 'name': '[R] yhat5 99.0% ~1h',\n", " 'type': 'scatter',\n", - " 'uid': '8b1256ed-4f9c-4a39-98f6-d94c1af49272',\n", + " 'uid': 'bda8f5ed-a117-47cf-8298-3b7802a38dcc',\n", " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", " datetime.datetime(2015, 1, 2, 14, 0),\n", " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 19, 0),\n", + " datetime.datetime(2015, 3, 2, 18, 0),\n", " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'y': array([64.257675, 64.03381 , 72.06434 , ..., 74.77048 , 70.81393 , 73.04162 ],\n", + " 'y': array([83.986046, 89.35564 , 74.86833 , ..., 74.58182 , 77.62551 , 77.05947 ],\n", " dtype=float32)},\n", " {'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': '[R] Predicted ~1h',\n", " 'type': 'scatter',\n", - " 'uid': 'c9c42f86-4573-4aea-8a34-4980b55458a9',\n", + " 'uid': 'a0f939c9-0e13-45ce-8081-3dac7cf67c72',\n", " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", " datetime.datetime(2015, 1, 2, 14, 0),\n", " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 19, 0),\n", + " datetime.datetime(2015, 3, 2, 18, 0),\n", " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'y': array([41.12671 , 40.386654, 44.95556 , ..., 63.092262, 65.077774, 63.004234],\n", + " 'y': array([47.22839 , 49.346603, 51.1183 , ..., 58.201473, 60.27031 , 59.059807],\n", " dtype=float32)},\n", " {'marker': {'color': 'blue', 'size': 4, 'symbol': 'x'},\n", " 'mode': 'markers',\n", " 'name': '[R] Predicted ~1h',\n", " 'type': 'scatter',\n", - " 'uid': '8704218f-c879-46d2-98f8-70840910069f',\n", + " 'uid': '555cf752-d8ea-41a9-8f47-596ee22a34be',\n", " 'x': array([datetime.datetime(2015, 1, 2, 13, 0),\n", " datetime.datetime(2015, 1, 2, 14, 0),\n", " datetime.datetime(2015, 1, 2, 15, 0), ...,\n", " datetime.datetime(2015, 3, 2, 17, 0),\n", - " datetime.datetime(2015, 3, 2, 19, 0),\n", + " datetime.datetime(2015, 3, 2, 18, 0),\n", " datetime.datetime(2015, 3, 2, 20, 0)], dtype=object),\n", - " 'y': array([41.12671 , 40.386654, 44.95556 , ..., 63.092262, 65.077774, 63.004234],\n", + " 'y': array([47.22839 , 49.346603, 51.1183 , ..., 58.201473, 60.27031 , 59.059807],\n", " dtype=float32)},\n", " {'marker': {'color': 'black', 'size': 4},\n", " 'mode': 'markers',\n", " 'name': '[R] Actual ~1h',\n", " 'type': 'scatter',\n", - " 'uid': '91a4e9d7-9480-462b-b473-7f38c9371ea4',\n", + " 'uid': '09a782ca-96bb-4a77-80a7-fd42484a363d',\n", " 'x': array([datetime.datetime(2015, 1, 1, 0, 0),\n", " datetime.datetime(2015, 1, 1, 1, 0),\n", " datetime.datetime(2015, 1, 1, 2, 0), ...,\n", @@ -2714,7 +2386,7 @@ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mplot_components\u001b[49m\u001b[43m(\u001b[49m\u001b[43mforecast\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdf_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtest\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/github/neural_prophet/neuralprophet/forecaster.py:2452\u001b[0m, in \u001b[0;36mNeuralProphet.plot_components\u001b[0;34m(self, fcst, df_name, figsize, forecast_in_focus, plotting_backend, components, one_period_per_season)\u001b[0m\n\u001b[1;32m 2450\u001b[0m log_warning_deprecation_plotly(plotting_backend)\n\u001b[1;32m 2451\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m plotting_backend\u001b[38;5;241m.\u001b[39mstartswith(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mplotly\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m-> 2452\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mplot_components_plotly\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2453\u001b[0m \u001b[43m \u001b[49m\u001b[43mm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2454\u001b[0m \u001b[43m \u001b[49m\u001b[43mfcst\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfcst\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2455\u001b[0m \u001b[43m \u001b[49m\u001b[43mplot_configuration\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalid_plot_configuration\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2456\u001b[0m \u001b[43m \u001b[49m\u001b[43mfigsize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mtuple\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m70\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mfigsize\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mfigsize\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m700\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m210\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2457\u001b[0m \u001b[43m \u001b[49m\u001b[43mdf_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdf_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2458\u001b[0m \u001b[43m \u001b[49m\u001b[43mone_period_per_season\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mone_period_per_season\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2459\u001b[0m \u001b[43m \u001b[49m\u001b[43mresampler_active\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mplotting_backend\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mplotly-resampler\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2460\u001b[0m \u001b[43m \u001b[49m\u001b[43mplotly_static\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mplotting_backend\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mplotly-static\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2461\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2462\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2463\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m plot_components(\n\u001b[1;32m 2464\u001b[0m m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 2465\u001b[0m fcst\u001b[38;5;241m=\u001b[39mfcst,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 2470\u001b[0m one_period_per_season\u001b[38;5;241m=\u001b[39mone_period_per_season,\n\u001b[1;32m 2471\u001b[0m )\n", + "File \u001b[0;32m~/github/neural_prophet/neuralprophet/forecaster.py:2465\u001b[0m, in \u001b[0;36mNeuralProphet.plot_components\u001b[0;34m(self, fcst, df_name, figsize, forecast_in_focus, plotting_backend, components, one_period_per_season)\u001b[0m\n\u001b[1;32m 2463\u001b[0m log_warning_deprecation_plotly(plotting_backend)\n\u001b[1;32m 2464\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m plotting_backend\u001b[38;5;241m.\u001b[39mstartswith(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mplotly\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m-> 2465\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mplot_components_plotly\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2466\u001b[0m \u001b[43m \u001b[49m\u001b[43mm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2467\u001b[0m \u001b[43m \u001b[49m\u001b[43mfcst\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfcst\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2468\u001b[0m \u001b[43m \u001b[49m\u001b[43mplot_configuration\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalid_plot_configuration\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2469\u001b[0m \u001b[43m \u001b[49m\u001b[43mfigsize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mtuple\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m70\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mfigsize\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mfigsize\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m700\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m210\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2470\u001b[0m \u001b[43m \u001b[49m\u001b[43mdf_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdf_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2471\u001b[0m \u001b[43m \u001b[49m\u001b[43mone_period_per_season\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mone_period_per_season\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2472\u001b[0m \u001b[43m \u001b[49m\u001b[43mresampler_active\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mplotting_backend\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mplotly-resampler\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2473\u001b[0m \u001b[43m \u001b[49m\u001b[43mplotly_static\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mplotting_backend\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mplotly-static\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2474\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2475\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2476\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m plot_components(\n\u001b[1;32m 2477\u001b[0m m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 2478\u001b[0m fcst\u001b[38;5;241m=\u001b[39mfcst,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 2483\u001b[0m one_period_per_season\u001b[38;5;241m=\u001b[39mone_period_per_season,\n\u001b[1;32m 2484\u001b[0m )\n", "File \u001b[0;32m~/github/neural_prophet/neuralprophet/plot_forecast_plotly.py:332\u001b[0m, in \u001b[0;36mplot_components\u001b[0;34m(m, fcst, plot_configuration, df_name, one_period_per_season, figsize, resampler_active, plotly_static)\u001b[0m\n\u001b[1;32m 327\u001b[0m trace_object \u001b[38;5;241m=\u001b[39m get_forecast_component_props(\n\u001b[1;32m 328\u001b[0m fcst\u001b[38;5;241m=\u001b[39mfcst, df_name\u001b[38;5;241m=\u001b[39mdf_name, comp_name\u001b[38;5;241m=\u001b[39mcomp_name, plot_name\u001b[38;5;241m=\u001b[39mcomp[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mplot_name\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 329\u001b[0m )\n\u001b[1;32m 331\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto-regression\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m name \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlagged regressor\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m name:\n\u001b[0;32m--> 332\u001b[0m trace_object \u001b[38;5;241m=\u001b[39m \u001b[43mget_multiforecast_component_props\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfcst\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfcst\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mcomp\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 333\u001b[0m fig\u001b[38;5;241m.\u001b[39mupdate_layout(barmode\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moverlay\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 335\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m j \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n", "File \u001b[0;32m~/github/neural_prophet/neuralprophet/plot_forecast_plotly.py:603\u001b[0m, in \u001b[0;36mget_multiforecast_component_props\u001b[0;34m(fcst, comp_name, plot_name, multiplicative, bar, focus, num_overplot, **kwargs)\u001b[0m\n\u001b[1;32m 601\u001b[0m y \u001b[38;5;241m=\u001b[39m fcst[\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcomp_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 602\u001b[0m y \u001b[38;5;241m=\u001b[39m y\u001b[38;5;241m.\u001b[39mvalues\n\u001b[0;32m--> 603\u001b[0m \u001b[43my\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 604\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m bar:\n\u001b[1;32m 605\u001b[0m traces\u001b[38;5;241m.\u001b[39mappend(\n\u001b[1;32m 606\u001b[0m go\u001b[38;5;241m.\u001b[39mBar(\n\u001b[1;32m 607\u001b[0m name\u001b[38;5;241m=\u001b[39mplot_name,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 613\u001b[0m )\n\u001b[1;32m 614\u001b[0m )\n", "\u001b[0;31mIndexError\u001b[0m: index -1 is out of bounds for axis 0 with size 0" @@ -2759,7 +2431,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "06ddcd4df2f94141aefe0f5056b5c617", + "model_id": "a09e98a70e3e4d4da2d2fe3865b8abc7", "version_major": 2, "version_minor": 0 }, @@ -2770,18 +2442,18 @@ " 'mode': 'lines',\n", " 'name': 'Trend',\n", " 'type': 'scatter',\n", - " 'uid': '2b2cb1c3-4a0e-4729-92d7-1167a376d732',\n", + " 'uid': '6d93394c-496d-4c52-8c79-e1a55e9bff0d',\n", " 'x': array([datetime.datetime(2015, 1, 1, 0, 0),\n", " datetime.datetime(2015, 2, 28, 23, 0)], dtype=object),\n", " 'xaxis': 'x',\n", - " 'y': array([25.74136 , 14.885769], dtype=float32),\n", + " 'y': array([35.735615, 26.46712 ], dtype=float32),\n", " 'yaxis': 'y'},\n", " {'fill': 'none',\n", " 'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': 'yearly',\n", " 'type': 'scatter',\n", - " 'uid': '965409ff-c889-48d6-bae0-936574b46f88',\n", + " 'uid': '91e577de-4754-46f9-a832-ff20851064d6',\n", " 'x': array([datetime.datetime(2017, 1, 1, 0, 0),\n", " datetime.datetime(2017, 1, 2, 0, 0),\n", " datetime.datetime(2017, 1, 3, 0, 0), ...,\n", @@ -2789,15 +2461,15 @@ " datetime.datetime(2017, 12, 30, 0, 0),\n", " datetime.datetime(2017, 12, 31, 0, 0)], dtype=object),\n", " 'xaxis': 'x2',\n", - " 'y': array([-48.182846, -50.157616, -51.64516 , ..., -38.81324 , -42.12695 ,\n", - " -45.080776], dtype=float32),\n", + " 'y': array([-1.7568997 , -2.1306572 , -2.4605272 , ..., -0.36249205, -0.8088371 ,\n", + " -1.2453859 ], dtype=float32),\n", " 'yaxis': 'y2'},\n", " {'fill': 'none',\n", " 'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': 'weekly',\n", " 'type': 'scatter',\n", - " 'uid': 'ba3b298a-be66-40c5-8264-ef5b716d06bf',\n", + " 'uid': 'fbb9e766-a826-4066-b0a7-9bbd7391dedd',\n", " 'x': array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,\n", " 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,\n", " 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,\n", @@ -2811,124 +2483,119 @@ " 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,\n", " 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167]),\n", " 'xaxis': 'x3',\n", - " 'y': array([-3.04588795e+01, -3.54400597e+01, -3.99797134e+01, -4.43526154e+01,\n", - " -4.81006012e+01, -5.13112793e+01, -5.40508690e+01, -5.63052025e+01,\n", - " -5.81308632e+01, -5.93055077e+01, -6.00412445e+01, -6.02870750e+01,\n", - " -5.99476738e+01, -5.91321754e+01, -5.79114304e+01, -5.63154335e+01,\n", - " -5.43354149e+01, -5.17919426e+01, -4.89967918e+01, -4.60626144e+01,\n", - " -4.26569977e+01, -3.94107780e+01, -3.56459961e+01, -3.19178200e+01,\n", - " -2.82482281e+01, -2.45639687e+01, -2.09295406e+01, -1.71104145e+01,\n", - " -1.35622263e+01, -1.01446161e+01, -7.01516199e+00, -4.13522530e+00,\n", - " -1.30050945e+00, 1.12308824e+00, 3.24085855e+00, 5.06825876e+00,\n", - " 6.57082605e+00, 7.81317091e+00, 8.63896179e+00, 9.19974041e+00,\n", - " 9.46362782e+00, 9.35620880e+00, 8.95828724e+00, 8.39426994e+00,\n", - " 7.48970175e+00, 6.47928476e+00, 5.27563381e+00, 3.69769454e+00,\n", - " 2.23112702e+00, 5.54561198e-01, -9.77979958e-01, -2.75594759e+00,\n", - " -4.51804447e+00, -6.08968639e+00, -7.68710136e+00, -9.01867199e+00,\n", - " -1.04195356e+01, -1.15890474e+01, -1.25064888e+01, -1.32924528e+01,\n", - " -1.37729158e+01, -1.40936012e+01, -1.41434650e+01, -1.39483709e+01,\n", - " -1.34957485e+01, -1.28450623e+01, -1.19179354e+01, -1.07594900e+01,\n", - " -9.40466499e+00, -7.81799698e+00, -6.03295660e+00, -4.35296869e+00,\n", - " -2.36515474e+00, -4.06748503e-01, 1.75163662e+00, 3.73803329e+00,\n", - " 5.86639977e+00, 7.98941374e+00, 9.86133385e+00, 1.17813129e+01,\n", - " 1.33550673e+01, 1.49084988e+01, 1.61548271e+01, 1.73153419e+01,\n", - " 1.81719952e+01, 1.87052574e+01, 1.89303913e+01, 1.88749886e+01,\n", - " 1.85987663e+01, 1.79191532e+01, 1.69609566e+01, 1.57270803e+01,\n", - " 1.42448225e+01, 1.23795385e+01, 1.02109413e+01, 8.01280880e+00,\n", - " 5.49813223e+00, 2.89395428e+00, 5.37771396e-02, -3.00232887e+00,\n", - " -6.02308273e+00, -8.95809174e+00, -1.18841352e+01, -1.48639030e+01,\n", - " -1.78287868e+01, -2.03142662e+01, -2.28633080e+01, -2.51014042e+01,\n", - " -2.72459736e+01, -2.90301781e+01, -3.02393055e+01, -3.12937450e+01,\n", - " -3.19238663e+01, -3.21995316e+01, -3.19541264e+01, -3.12862282e+01,\n", - " -3.01354771e+01, -2.86372204e+01, -2.65992775e+01, -2.41690311e+01,\n", - " -2.13332558e+01, -1.81720409e+01, -1.45338621e+01, -1.03474331e+01,\n", - " -6.08415222e+00, -1.40853214e+00, 3.29980421e+00, 8.35535431e+00,\n", - " 1.37836151e+01, 1.89796047e+01, 2.41642418e+01, 2.93892212e+01,\n", - " 3.48417130e+01, 4.02113190e+01, 4.50756607e+01, 4.96866379e+01,\n", - " 5.40651703e+01, 5.83653717e+01, 6.21502800e+01, 6.54130478e+01,\n", - " 6.82579117e+01, 7.07123795e+01, 7.26812973e+01, 7.40298157e+01,\n", - " 7.47914352e+01, 7.49953308e+01, 7.46389465e+01, 7.36509781e+01,\n", - " 7.20302200e+01, 6.99344711e+01, 6.73451157e+01, 6.42106323e+01,\n", - " 6.03940163e+01, 5.62932777e+01, 5.17331505e+01, 4.69316330e+01,\n", - " 4.13555679e+01, 3.55166130e+01, 2.97218113e+01, 2.36869488e+01,\n", - " 1.76951599e+01, 1.11073389e+01, 4.76392841e+00, -1.35005784e+00,\n", - " -7.75571394e+00, -1.36067734e+01, -1.97326603e+01, -2.53451424e+01],\n", - " dtype=float32),\n", + " 'y': array([ 7.755577 , 7.5237527 , 7.224617 , 6.815465 , 6.314899 ,\n", + " 5.7965226 , 5.222597 , 4.60115 , 3.888913 , 3.192386 ,\n", + " 2.480555 , 1.7581066 , 0.9770121 , 0.21064605, -0.4857402 ,\n", + " -1.1512595 , -1.7806304 , -2.4084048 , -2.9580166 , -3.4257982 ,\n", + " -3.855735 , -4.175071 , -4.4489365 , -4.6284823 , -4.724985 ,\n", + " -4.7434278 , -4.6815186 , -4.530946 , -4.302239 , -4.0017447 ,\n", + " -3.6467676 , -3.2275088 , -2.7205632 , -2.1764457 , -1.6195372 ,\n", + " -1.0332191 , -0.3954899 , 0.27314964, 0.9262204 , 1.5400877 ,\n", + " 2.1368022 , 2.7368646 , 3.2993367 , 3.788525 , 4.2504687 ,\n", + " 4.6347704 , 4.9488535 , 5.214139 , 5.3700595 , 5.456191 ,\n", + " 5.4460826 , 5.3499737 , 5.1563053 , 4.8846197 , 4.5234604 ,\n", + " 4.1024094 , 3.5617476 , 2.9600484 , 2.3084648 , 1.5802894 ,\n", + " 0.81475705, -0.04111661, -0.90458584, -1.7624184 , -2.661798 ,\n", + " -3.53263 , -4.4244337 , -5.295709 , -6.1333375 , -6.931251 ,\n", + " -7.696276 , -8.319346 , -8.925283 , -9.427417 , -9.868788 ,\n", + " -10.188685 , -10.430543 , -10.565057 , -10.580534 , -10.500571 ,\n", + " -10.296801 , -9.9909 , -9.583059 , -9.07224 , -8.425532 ,\n", + " -7.701064 , -6.915518 , -6.0554175 , -5.111462 , -4.0590043 ,\n", + " -3.0526552 , -1.9540824 , -0.86603564, 0.3103157 , 1.494285 ,\n", + " 2.560513 , 3.6657102 , 4.700126 , 5.7225103 , 6.720365 ,\n", + " 7.6138163 , 8.40019 , 9.098643 , 9.724774 , 10.263643 ,\n", + " 10.650285 , 10.948275 , 11.123994 , 11.187449 , 11.121981 ,\n", + " 10.951801 , 10.656306 , 10.26268 , 9.729534 , 9.129556 ,\n", + " 8.439994 , 7.6128716 , 6.767242 , 5.794735 , 4.8272123 ,\n", + " 3.8182733 , 2.7837412 , 1.6946272 , 0.55752414, -0.5014487 ,\n", + " -1.5780605 , -2.5846512 , -3.583228 , -4.562805 , -5.417882 ,\n", + " -6.1975856 , -6.9047456 , -7.558581 , -8.112523 , -8.532403 ,\n", + " -8.8463745 , -9.0614 , -9.167561 , -9.156995 , -9.042185 ,\n", + " -8.828875 , -8.505391 , -8.057592 , -7.5299797 , -6.938876 ,\n", + " -6.275693 , -5.5465274 , -4.701907 , -3.8042374 , -2.929176 ,\n", + " -2.0293174 , -1.1140192 , -0.13168602, 0.7801494 , 1.7001534 ,\n", + " 2.5384672 , 3.4228206 , 4.248964 , 4.967284 , 5.6338806 ,\n", + " 6.2010517 , 6.737668 , 7.1634216 , 7.488221 , 7.7340417 ,\n", + " 7.8822103 , 7.9357567 , 7.8853316 ], dtype=float32),\n", " 'yaxis': 'y3'},\n", " {'fill': 'none',\n", " 'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': 'winter',\n", " 'type': 'scatter',\n", - " 'uid': '6deeeb47-3be1-4b89-a492-0e6bf00cdd53',\n", + " 'uid': '11f43753-196c-416c-8679-caab4f55210d',\n", " 'x': array([ 0, 1, 2, ..., 285, 286, 287]),\n", " 'xaxis': 'x4',\n", - " 'y': array([-0.43253064, 0.38781527, 1.1857711 , ..., -2.2786856 , -1.8861564 ,\n", - " -1.1711025 ], dtype=float32),\n", + " 'y': array([ 1.5749581 , 0.68877584, -0.0385443 , ..., 3.668294 , 3.1973646 ,\n", + " 2.3746142 ], dtype=float32),\n", " 'yaxis': 'y4'},\n", " {'fill': 'none',\n", " 'line': {'color': '#2d92ff', 'width': 2},\n", " 'mode': 'lines',\n", " 'name': 'summer',\n", " 'type': 'scatter',\n", - " 'uid': 'b5b32b83-f643-4973-8db7-1797073f910f',\n", + " 'uid': '82f6c088-001e-487b-947b-2b32fbf6b06c',\n", " 'x': array([ 0, 1, 2, ..., 285, 286, 287]),\n", " 'xaxis': 'x5',\n", - " 'y': array([-21.528275, -22.261412, -22.15886 , ..., -17.52789 , -18.709206,\n", - " -20.179296], dtype=float32),\n", + " 'y': array([ 1.621103 , 0.41932815, -0.4366651 , ..., 4.0205083 , 3.520041 ,\n", + " 2.6508992 ], dtype=float32),\n", " 'yaxis': 'y5'},\n", " {'marker': {'color': '#2d92ff'},\n", " 'name': 'AR',\n", " 'type': 'bar',\n", - " 'uid': 'dea67a1e-e499-4e43-a97e-0fb70895854a',\n", + " 'uid': 'c3664db2-136a-40d7-9103-625733dda176',\n", " 'width': 0.8,\n", " 'x': array([10, 9, 8, 7, 6, 5, 4, 3, 2, 1]),\n", " 'xaxis': 'x6',\n", - " 'y': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", + " 'y': array([-0.3051259 , -0.20342994, 0.05545649, 0.13164242, 0.2857538 ,\n", + " 0.06021553, -0.3720306 , -0.01876787, -0.00064597, 0.04794757],\n", + " dtype=float32),\n", " 'yaxis': 'y6'},\n", " {'marker': {'color': '#2d92ff'},\n", " 'name': 'Lagged Regressor \"temp\"',\n", " 'type': 'bar',\n", - " 'uid': '78f9e1a7-125b-4563-aaec-5ff3fe054d54',\n", + " 'uid': 'd28e2fd9-0522-488a-b89a-02b1296bae1a',\n", " 'width': 0.8,\n", " 'x': array([33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16,\n", " 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]),\n", " 'xaxis': 'x7',\n", - " 'y': array([ 0.4898405 , 1.880786 , 2.1640918 , -0.2928444 , 0.86554664,\n", - " -0.08299019, 0.70314807, 1.0031208 , 0.90763193, 0.4682139 ,\n", - " -1.1896216 , -1.6901426 , -1.2109685 , 0.9098389 , -0.64848685,\n", - " 0.9634216 , 0.91694885, 2.0049295 , 2.8199239 , 0.83436155,\n", - " 1.654415 , 2.4778936 , 0.64203995, 2.3134997 , 1.7692485 ,\n", - " 1.1947386 , 0.62410027, 0.79597855, 2.8871663 , 0.6992378 ,\n", - " 0.69474286, 1.997743 , 2.4678695 ], dtype=float32),\n", + " 'y': array([ 6.2832564e-01, -2.3228288e-01, 6.6017294e-01, 3.3255139e-01,\n", + " 5.0744390e-01, 1.1816436e-01, -1.5548144e-01, 1.6358766e-01,\n", + " -1.7810777e-01, 3.2371131e-01, 4.4875324e-01, -3.1604797e-01,\n", + " -1.3501082e-03, -9.3391158e-02, 8.1444037e-01, -7.3939008e-01,\n", + " 4.2238832e-01, 5.4563276e-02, 3.5837620e-01, -5.2361876e-02,\n", + " -5.4710191e-01, -7.3065239e-01, -3.4761795e-01, 4.7559822e-01,\n", + " 2.0330952e-02, 2.5780448e-01, 1.0076398e-01, 3.2984644e-01,\n", + " 2.2101782e-01, 2.5692052e-01, -8.7424242e-01, -8.4744475e-04,\n", + " 3.0343091e-01], dtype=float32),\n", " 'yaxis': 'y7'},\n", " {'marker': {'color': '#2d92ff'},\n", " 'name': 'Additive event',\n", " 'type': 'bar',\n", - " 'uid': 'df559219-3f3e-45df-864c-8eebbaeadb67',\n", + " 'uid': 'ec08d1ed-660d-43c4-8112-e6bd03345ae9',\n", " 'width': 0.8,\n", - " 'x': array(['Labor Day_+0', 'Labor Day_+1', 'Labor Day_-1', 'Veterans Day_+0',\n", - " 'Veterans Day_+1', 'Veterans Day_-1', 'Martin Luther King Jr. Day_+0',\n", + " 'x': array(['Memorial Day_+0', 'Memorial Day_+1', 'Memorial Day_-1',\n", + " \"Washington's Birthday_+0\", \"Washington's Birthday_+1\",\n", + " \"Washington's Birthday_-1\", 'Columbus Day_+0', 'Columbus Day_+1',\n", + " 'Columbus Day_-1', 'Martin Luther King Jr. Day_+0',\n", " 'Martin Luther King Jr. Day_+1', 'Martin Luther King Jr. Day_-1',\n", " \"New Year's Day_+0\", \"New Year's Day_+1\", \"New Year's Day_-1\",\n", - " \"Washington's Birthday_+0\", \"Washington's Birthday_+1\",\n", - " \"Washington's Birthday_-1\", 'Independence Day_+0',\n", - " 'Independence Day_+1', 'Independence Day_-1', 'Memorial Day_+0',\n", - " 'Memorial Day_+1', 'Memorial Day_-1', 'Columbus Day_+0',\n", - " 'Columbus Day_+1', 'Columbus Day_-1', 'Thanksgiving_+0',\n", - " 'Thanksgiving_+1', 'Thanksgiving_-1', 'Christmas Day_+0',\n", - " 'Christmas Day_+1', 'Christmas Day_-1'], dtype=object),\n", + " 'Thanksgiving_+0', 'Thanksgiving_+1', 'Thanksgiving_-1',\n", + " 'Christmas Day_+0', 'Christmas Day_+1', 'Christmas Day_-1',\n", + " 'Veterans Day_+0', 'Veterans Day_+1', 'Veterans Day_-1',\n", + " 'Independence Day_+0', 'Independence Day_+1', 'Independence Day_-1',\n", + " 'Labor Day_+0', 'Labor Day_+1', 'Labor Day_-1'], dtype=object),\n", " 'xaxis': 'x8',\n", - " 'y': [-6.505150318145752, 3.078960418701172, -3.0603911876678467,\n", - " -1.937178373336792, -0.9162442684173584, -3.922412395477295,\n", - " -43.94681930541992, 48.15086364746094, -47.38690948486328,\n", - " 0.037264175713062286, 4.726099967956543, 0.49383774399757385,\n", - " -0.8578076362609863, -8.193577766418457, 8.767333030700684,\n", - " -2.2701916694641113, -0.19705480337142944, -1.0239486694335938,\n", - " 2.4958767890930176, 5.431707859039307, -3.5964465141296387,\n", - " -3.9246764183044434, 3.1682686805725098, 1.5535764694213867,\n", - " -3.401339054107666, 0.7919614911079407, 1.1661392450332642,\n", - " 0.8668169975280762, -6.069958686828613, -1.4564253091812134],\n", + " 'y': [2.815302848815918, -1.2508420944213867, -2.8079898357391357,\n", + " 1.0754282474517822, 1.3034065961837769, 3.282367706298828,\n", + " -0.20270073413848877, 5.353240489959717, 3.5992088317871094,\n", + " 1.0151381492614746, 2.7972025871276855, 0.7658103108406067,\n", + " 0.6110072731971741, -3.8502631187438965, -1.1796778440475464,\n", + " -1.184556484222412, -1.5890964269638062, 6.425841331481934,\n", + " -0.3609931468963623, 5.461928367614746, -3.2548437118530273,\n", + " -7.143401622772217, 4.949648857116699, -2.3242950439453125,\n", + " -3.5402631759643555, 1.2076290845870972, 5.756880283355713,\n", + " -1.5092906951904297, -2.149829149246216, 2.4397971630096436],\n", " 'yaxis': 'y8'}],\n", " 'layout': {'autosize': True,\n", " 'font': {'size': 10},\n",