
Commit

Merge branch 'main' into season-glocal-reg
ourownstory authored Jun 26, 2024
2 parents 274c660 + 92a0198 commit 4d7f519
Showing 7 changed files with 199 additions and 264 deletions.
30 changes: 29 additions & 1 deletion docs/source/tutorials/tutorial10.ipynb
@@ -73,7 +73,7 @@
"collapsed": false
},
"source": [
"Validation is performed by passing the validation set to the fit method during training. The resulting metrics show the performance of the model compared to our validation set."
"Validation is performed by passing the validation set to the fit method during training. The resulting metrics show the performance of the model compared to our validation set. "
]
},
{
@@ -356,6 +356,34 @@
"set_random_seed(0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Aditionally, it is important to make sure to set the flag `deterministic` in the `fit` function to `True`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from neuralprophet import NeuralProphet\n",
"\n",
"# Load the dataset from the CSV file using pandas\n",
"df = pd.read_csv(\"https://github.com/ourownstory/neuralprophet-data/raw/main/kaggle-energy/datasets/tutorial01.csv\")\n",
"\n",
"# Model and prediction\n",
"m = NeuralProphet()\n",
"\n",
"df_train, df_val = m.split_df(df, valid_p=0.2)\n",
"\n",
"# Set the deterministic flag to True\n",
"metrics = m.fit(df_train, validation_df=df_val, progress=None, deterministic=True)"
]
},
{
"attachments": {},
"cell_type": "markdown",
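Taken together with the `set_random_seed(0)` call earlier in the tutorial, the new cell amounts to the following reproducibility recipe. This is a minimal sketch, not part of the tutorial diff itself: it reuses the tutorial's public dataset URL, and fitting twice to compare metrics is only an illustrative check (the `epochs=5` setting is likewise assumed, for speed).

import pandas as pd
from neuralprophet import NeuralProphet, set_random_seed

df = pd.read_csv("https://github.com/ourownstory/neuralprophet-data/raw/main/kaggle-energy/datasets/tutorial01.csv")

def train_once():
    # Re-seed before each run so both runs start from an identical RNG state
    set_random_seed(0)
    m = NeuralProphet(epochs=5)
    df_train, df_val = m.split_df(df, valid_p=0.2)
    # deterministic=True is forwarded to the PyTorch Lightning trainer
    return m.fit(df_train, validation_df=df_val, progress=None, deterministic=True)

metrics_a = train_once()
metrics_b = train_once()
# With seeding plus deterministic=True, the two runs should yield matching metrics
print(metrics_a.equals(metrics_b))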
5 changes: 5 additions & 0 deletions neuralprophet/forecaster.py
@@ -905,6 +905,7 @@ def fit(
checkpointing: bool = False,
continue_training: bool = False,
num_workers: int = 0,
deterministic: bool = False,
):
"""Train, and potentially evaluate model.
@@ -1069,6 +1070,7 @@
checkpointing_enabled=checkpointing,
continue_training=continue_training,
num_workers=num_workers,
deterministic=deterministic,
)
else:
df_val, _, _, _ = df_utils.prep_or_copy_df(validation_df)
@@ -1093,6 +1095,7 @@
checkpointing_enabled=checkpointing,
continue_training=continue_training,
num_workers=num_workers,
deterministic=deterministic,
)

# Show training plot
@@ -2714,6 +2717,7 @@ def _train(
checkpointing_enabled: bool = False,
continue_training=False,
num_workers=0,
deterministic: bool = False,
):
"""
Execute model training procedure for a configured number of epochs.
@@ -2771,6 +2775,7 @@
metrics_enabled=metrics_enabled,
checkpointing_enabled=checkpointing_enabled,
num_batches_per_epoch=len(train_loader),
deterministic=deterministic,
)

# Tune hyperparams and train
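The forecaster.py hunks are pure plumbing: the new `deterministic` keyword is accepted by the public `fit` method and forwarded unchanged through the internal `_train` helper to the trainer configuration. A stripped-down sketch of that pass-through pattern; the class and function bodies below are illustrative stand-ins, not the real NeuralProphet internals.

def configure_trainer(deterministic: bool = False, **trainer_kwargs):
    # In the real code this config dict is handed to pytorch_lightning.Trainer
    config = dict(trainer_kwargs)
    config["deterministic"] = deterministic
    return config

class Forecaster:
    def fit(self, df, validation_df=None, deterministic: bool = False):
        # Public API change: fit() accepts the flag ...
        return self._train(df, df_val=validation_df, deterministic=deterministic)

    def _train(self, df, df_val=None, deterministic: bool = False):
        # ... and passes it on when building the trainer configuration
        return configure_trainer(deterministic=deterministic)

print(Forecaster().fit(df=None, deterministic=True))  # -> {'deterministic': True}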
5 changes: 5 additions & 0 deletions neuralprophet/utils.py
@@ -11,6 +11,7 @@
import pandas as pd
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.seed import seed_everything

from neuralprophet import utils_torch
from neuralprophet.logger import ProgressBar
@@ -710,6 +711,7 @@ def set_random_seed(seed: int = 0):
"""
np.random.seed(seed)
torch.manual_seed(seed)
seed_everything(seed, workers=True)


def set_logger_level(logger, log_level, include_handlers=False):
Expand Down Expand Up @@ -818,6 +820,7 @@ def configure_trainer(
metrics_enabled: bool = False,
checkpointing_enabled: bool = False,
num_batches_per_epoch: int = 100,
deterministic: bool = False,
):
"""
Configures the PyTorch Lightning trainer.
@@ -888,6 +891,8 @@
else:
config["logger"] = False

config["deterministic"] = deterministic

# Configure callbacks
callbacks = []
has_custom_callbacks = True if "callbacks" in config else False
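For context, the two utils.py changes map onto standard PyTorch Lightning mechanics: `seed_everything(seed, workers=True)` seeds Python, NumPy, torch, and dataloader workers, while the `deterministic` entry in the trainer config ends up as the `deterministic` argument of `pl.Trainer`, which enforces deterministic algorithms where available. A self-contained sketch with illustrative helper names, not the actual NeuralProphet functions:

import numpy as np
import torch
import pytorch_lightning as pl
from lightning_fabric.utilities.seed import seed_everything

def set_seed(seed: int = 0):
    # Mirrors the updated utility: seed NumPy and torch, then let Lightning
    # seed everything else, including dataloader worker processes
    np.random.seed(seed)
    torch.manual_seed(seed)
    seed_everything(seed, workers=True)

def build_trainer(deterministic: bool = False) -> pl.Trainer:
    # Stand-in for configure_trainer(): the flag becomes a Trainer argument
    config = {"deterministic": deterministic}
    return pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False, **config)

set_seed(0)
trainer = build_trainer(deterministic=True)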