flatten training function calls
ourownstory committed Aug 30, 2024
1 parent 565c7d5 commit 72c011b
Showing 1 changed file with 168 additions and 50 deletions.
218 changes: 168 additions & 50 deletions neuralprophet/forecaster.py
@@ -1068,24 +1068,6 @@ def fit(
else:
self.metrics_logger = None

# Pre-processing
# Copy df and save list of unique time series IDs (the latter for global-local modelling if enabled)
df, _, _, self.id_list = df_utils.prep_or_copy_df(df)
df = _check_dataframe(self, df, check_y=True, exogenous=True)
self.data_freq = df_utils.infer_frequency(df, n_lags=self.max_lags, freq=freq)
df = _handle_missing_data(
df=df,
freq=self.data_freq,
n_lags=self.n_lags,
n_forecasts=self.n_forecasts,
config_missing=self.config_missing,
config_regressors=self.config_regressors,
config_lagged_regressors=self.config_lagged_regressors,
config_events=self.config_events,
config_seasonality=self.config_seasonality,
predicting=False,
)

# Setup for global-local modelling: If there is only a single time series, then self.id_list = ['__df__']
self.num_trends_modelled = len(self.id_list) if self.config_trend.trend_global_local == "local" else 1
self.num_seasonalities_modelled = (
@@ -1115,18 +1097,31 @@ def fit(
"computed for any future time, independent of lagged values"
)

# Training
if validation_df is None:
metrics_df = self._train(
df,
progress_bar_enabled=bool(progress),
metrics_enabled=bool(self.metrics),
checkpointing_enabled=checkpointing,
num_workers=num_workers,
deterministic=deterministic,
)
else:
df_val, _, _, _ = df_utils.prep_or_copy_df(validation_df)
##### Data Setup and Training Setup #####
# Set up train dataset and data dependent configurations
df = self._train_data_setup(df, freq=freq)
# Note: _create_dataset() needs to be called after set_auto_seasonalities()
dataset = _create_dataset(self, df, predict_mode=False, prediction_frequency=self.prediction_frequency)
# Determine the max number of epochs
self.config_train.set_auto_batch_epoch(n_data=len(dataset))

# Set up DataLoaders: Train
loader = DataLoader(
dataset,
batch_size=self.config_train.batch_size,
shuffle=True,
num_workers=num_workers,
)
self.config_train.set_batches_per_epoch(len(loader))
log.info(f"Train Dataset size: {len(dataset)}")
log.info(f"Number of batches per training epoch: {len(loader)}")

# Set up DataLoaders: Validation
validation_enabled = validation_df is not None and isinstance(validation_df, pd.DataFrame)
if validation_enabled:
# df_val = self._val_data_setup(validation_df)
df_val = validation_df
df_val, _, _, _ = df_utils.prep_or_copy_df(df_val)
df_val = _check_dataframe(self, df_val, check_y=False, exogenous=False)
df_val = _handle_missing_data(
df=df_val,
@@ -1140,29 +1135,103 @@ def fit(
config_seasonality=self.config_seasonality,
predicting=False,
)
metrics_df = self._train(
df,
df_val=df_val,
progress_bar_enabled=bool(progress),
metrics_enabled=bool(self.metrics),
checkpointing_enabled=checkpointing,
num_workers=num_workers,
deterministic=deterministic,
# df_val, _, _, _ = df_utils.prep_or_copy_df(df_val)
df_val = _normalize(df=df_val, config_normalization=self.config_normalization)
dataset_val = _create_dataset(self, df_val, predict_mode=False)
loader_val = DataLoader(dataset_val, batch_size=min(1024, len(dataset_val)), shuffle=False, drop_last=False)
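# Design note on the validation loader above: shuffle=False because sample order
# does not affect aggregated validation metrics, drop_last=False so no validation
# samples are silently skipped, and the batch size is capped at
# min(1024, len(dataset_val)) so small validation sets still form a single batch.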

# Init the Trainer
self.trainer, checkpoint_callback = utils_lightning.configure_trainer(
config_train=self.config_train,
metrics_logger=self.metrics_logger,
early_stopping_target="Loss_val" if validation_enabled else "Loss",
accelerator=self.accelerator,
progress_bar_enabled=bool(progress),
metrics_enabled=bool(self.metrics),
checkpointing_enabled=checkpointing,
num_batches_per_epoch=len(loader),
deterministic=deterministic,
)
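# For orientation, a hedged sketch of an equivalent trainer built directly with
# pytorch_lightning (assuming the Lightning 2.x API); the project's own
# utils_lightning.configure_trainer is the actual implementation and may choose
# different callbacks and defaults:
# import pytorch_lightning as pl
# from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
# callbacks = [
#     EarlyStopping(monitor="Loss_val", mode="min"),  # matches early_stopping_target above
#     ModelCheckpoint(monitor="Loss_val", mode="min", save_top_k=1),
# ]
# sketch_trainer = pl.Trainer(
#     max_epochs=self.config_train.epochs,
#     accelerator=self.accelerator or "auto",
#     deterministic=deterministic,
#     callbacks=callbacks,
# )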

# Set up the model for training
if not self.fitted:
self.model = self._init_model()

# Find suitable learning rate if not set
if self.config_train.learning_rate is None:
assert not self.fitted, "Learning rate must be provided for re-training a fitted model."

# Init a separate Model, Loader and Trainer copy for LR finder (optional, done for safety)
# Note: Leads to a CUDA issue; needs to be fixed before enabling this feature.
# model_lr_finder = self._init_model()
# loader_lr_finder = DataLoader(
# dataset,
# batch_size=self.config_train.batch_size,
# shuffle=True,
# num_workers=num_workers,
# )
# trainer_lr_finder, _ = utils_lightning.configure_trainer(
# config_train=self.config_train,
# metrics_logger=self.metrics_logger,
# early_stopping_target="Loss",
# accelerator=self.accelerator,
# progress_bar_enabled=progress_bar_enabled,
# metrics_enabled=False,
# checkpointing=False,
# num_batches_per_epoch=len(loader),
# deterministic=deterministic,
# )

# Setup and execute LR finder
suggested_lr = utils_lightning.find_learning_rate(
model=self.model, # model_lr_finder,
loader=loader, # loader_lr_finder,
trainer=self.trainer, # trainer_lr_finder,
train_epochs=self.config_train.epochs,
)
# Clean up the LR finder copies of Model, Loader and Trainer
# del model_lr_finder, loader_lr_finder, trainer_lr_finder

# Save the suggested learning rate
self.config_train.learning_rate = suggested_lr
self.model.finding_lr = False
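# The finder above is roughly equivalent to Lightning's built-in LR range test,
# sketched here assuming a pytorch_lightning 2.x Tuner; the exact sweep in
# utils_lightning.find_learning_rate may differ:
# from pytorch_lightning.tuner import Tuner
# lr_finder = Tuner(self.trainer).lr_find(self.model, train_dataloaders=loader)
# suggested_lr = lr_finder.suggestion()  # picked from the loss-vs-lr curve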

# Execute Training Loop
start = time.time()
self.trainer.fit(
model=self.model,
train_dataloaders=loader,
val_dataloaders=loader_val if validation_enabled else None,
)
log.info("Train Time: {:8.3f}".format(time.time() - start))
self.fitted = True

# Load best model from checkpoint if end state not best
if checkpoint_callback is not None:
if checkpoint_callback.best_model_score < checkpoint_callback.current_score:
log.info(
f"Loading best model with score {checkpoint_callback.best_model_score} from checkpoint (latest \
score is {checkpoint_callback.current_score})"
)
self.model = time_net.TimeNet.load_from_checkpoint(checkpoint_callback.best_model_path)
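# Note on the comparison above: ModelCheckpoint defaults to mode="min", so for a
# monitored loss best_model_score is the lowest value observed during training and
# current_score is the value at the final step. best_model_score < current_score
# therefore means the final weights underperform the checkpointed best, which is
# restored through LightningModule.load_from_checkpoint(best_model_path).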

# Return metrics collected in logger as dataframe
metrics_df = pd.DataFrame(self.metrics_logger.history) if bool(self.metrics) else None

# Show training plot
if progress == "plot":
    if metrics_df is None:
        log.error("Metrics must be enabled to show training progress plot.")
    else:
        if validation_df is None:
            fig = pyplot.plot(metrics_df[["Loss"]])
        else:
            fig = pyplot.plot(metrics_df[["Loss", "Loss_val"]])
        # Only display the plot if the session is interactive, e.g. do not show in GitHub Actions,
        # since it causes an error in the Windows and macOS environments
        if matplotlib.is_interactive():
            fig.show()

return metrics_df
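# Example call of the flattened fit(), assuming a standard (ds, y) input frame at
# daily frequency; progress="plot" requires metrics to be enabled:
# m = NeuralProphet(n_lags=7, n_forecasts=1)
# metrics_df = m.fit(df, freq="D", progress="plot")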

def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, auto_extend=True):
@@ -2684,7 +2753,7 @@ def _init_model(self):
log.debug(model)
return model

def _data_setup(self, df):
def _train_data_setup(self, df, freq="auto"):
"""Executes data preparation steps for training.
Parameters
@@ -2698,7 +2767,24 @@ def _data_setup(self, df):
-------
pd.DataFrame
prepared DataFrame
"""
df, _, _, _ = df_utils.prep_or_copy_df(df)
# Data Pre-processing
# Copy df and save list of unique time series IDs (the latter for global-local modelling if enabled)
df, _, _, self.id_list = df_utils.prep_or_copy_df(df)
df = _check_dataframe(self, df, check_y=True, exogenous=True)
self.data_freq = df_utils.infer_frequency(df, n_lags=self.max_lags, freq=freq)
df = _handle_missing_data(
df=df,
freq=self.data_freq,
n_lags=self.n_lags,
n_forecasts=self.n_forecasts,
config_missing=self.config_missing,
config_regressors=self.config_regressors,
config_lagged_regressors=self.config_lagged_regressors,
config_events=self.config_events,
config_seasonality=self.config_seasonality,
predicting=False,
)
# df, _, _, _ = df_utils.prep_or_copy_df(df)

if not self.fitted:
# Initialize data normalization parameters
@@ -2734,6 +2820,7 @@ def _data_setup(self, df):
def _train(
self,
df: pd.DataFrame,
freq: str = "auto",
df_val: Optional[pd.DataFrame] = None,
progress_bar_enabled: bool = True,
metrics_enabled: bool = False,
@@ -2764,8 +2851,25 @@ def _train(
pd.DataFrame
metrics
"""
# Pre-processing
# Copy df and save list of unique time series IDs (the latter for global-local modelling if enabled)
df, _, _, self.id_list = df_utils.prep_or_copy_df(df)
df = _check_dataframe(self, df, check_y=True, exogenous=True)
self.data_freq = df_utils.infer_frequency(df, n_lags=self.max_lags, freq=freq)
df = _handle_missing_data(
df=df,
freq=self.data_freq,
n_lags=self.n_lags,
n_forecasts=self.n_forecasts,
config_missing=self.config_missing,
config_regressors=self.config_regressors,
config_lagged_regressors=self.config_lagged_regressors,
config_events=self.config_events,
config_seasonality=self.config_seasonality,
predicting=False,
)
# Set up train dataset and data dependent configurations
df = self._data_setup(df)
df = self._train_data_setup(df, freq=freq)
# Note: _create_dataset() needs to be called after set_auto_seasonalities()
dataset = _create_dataset(self, df, predict_mode=False, prediction_frequency=self.prediction_frequency)
# Determine the max number of epochs
@@ -2786,6 +2890,20 @@ def _train(
validation_enabled = df_val is not None and isinstance(df_val, pd.DataFrame)
if validation_enabled:
df_val, _, _, _ = df_utils.prep_or_copy_df(df_val)
df_val = _check_dataframe(self, df_val, check_y=False, exogenous=False)
df_val = _handle_missing_data(
df=df_val,
freq=self.data_freq,
n_lags=self.n_lags,
n_forecasts=self.n_forecasts,
config_missing=self.config_missing,
config_regressors=self.config_regressors,
config_lagged_regressors=self.config_lagged_regressors,
config_events=self.config_events,
config_seasonality=self.config_seasonality,
predicting=False,
)
# df_val, _, _, _ = df_utils.prep_or_copy_df(df_val)
df_val = _normalize(df=df_val, config_normalization=self.config_normalization)
dataset_val = _create_dataset(self, df_val, predict_mode=False)
loader_val = DataLoader(dataset_val, batch_size=min(1024, len(dataset_val)), shuffle=False, drop_last=False)
