Skip to content

Commit

Permalink
Merge branch 'main' into custom-lr-scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
ourownstory authored Aug 28, 2024
2 parents a87651c + 00301ad commit c83e7cc
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 7 deletions.
23 changes: 18 additions & 5 deletions neuralprophet/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,6 @@ def __init__(
self.data_params = None

# Pytorch Lightning Trainer
self.metrics_logger = MetricsLogger(save_dir=os.getcwd())
self.accelerator = accelerator

# set during prediction
Expand Down Expand Up @@ -934,6 +933,7 @@ def fit(
early_stopping: bool = False,
minimal: bool = False,
metrics: Optional[np_types.CollectMetricsMode] = None,
metrics_log_dir: Optional[str] = None,
progress: Optional[str] = "bar",
checkpointing: bool = False,
num_workers: int = 0,
Expand Down Expand Up @@ -1049,16 +1049,29 @@ def fit(
number of epochs to train for."
)

if metrics:
# Setup Metrics
if metrics is not None:
self.metrics = utils_metrics.get_metrics(metrics)

if progress == "plot" and not metrics:
log.info("Progress plot requires metrics to be enabled. Disabling progress plot.")
progress = None
if progress == "plot" and not self.metrics:
log.info("Progress plot requires metrics to be enabled. Setting progress to bar.")
progress = "bar"

if not self.config_normalization.global_normalization:
log.info("When Global modeling with local normalization, metrics are displayed in normalized scale.")

if metrics_log_dir is not None and not self.metrics:
log.error("Metrics are disabled. Ignoring provided logging directory.")
metrics_log_dir = None
if metrics_log_dir is None and self.metrics:
log.warning("Metrics are enabled. Please provide valid metrics logging directory. Setting to CWD")
metrics_log_dir = os.getcwd()

if self.metrics:
self.metrics_logger = MetricsLogger(save_dir=metrics_log_dir)
else:
self.metrics_logger = None

# Pre-processing
# Copy df and save list of unique time series IDs (the latter for global-local modelling if enabled)
df, _, _, self.id_list = df_utils.prep_or_copy_df(df)
Expand Down
6 changes: 4 additions & 2 deletions neuralprophet/utils_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def get_metrics(metric_input):
Dict of names of torchmetrics.Metric metrics
"""
if metric_input is None:
return {}
return False
elif metric_input is False:
return False
elif metric_input is True:
return {"MAE": METRICS["MAE"], "RMSE": METRICS["RMSE"]}
elif isinstance(metric_input, str):
Expand All @@ -51,5 +53,5 @@ def get_metrics(metric_input):
"All metrics must be valid names of torchmetrics.Metric objects."
)
return {k: [v, {}] for k, v in metric_input.items()}
elif metric_input is not False:
else:
raise ValueError("Received unsupported argument for collect_metrics.")
6 changes: 6 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,9 @@ def test_create_dummy_datestamps():
m = NeuralProphet(epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=LR)
_ = m.fit(df_dummy)
_ = m.make_future_dataframe(df_dummy, periods=365, n_historic_predictions=True)


def test_no_log():
    """Fitting with metrics disabled must not create a metrics logger or log dir.

    Regression test: when ``metrics=False``, ``fit`` should run end-to-end
    without requiring a metrics logging directory.
    """
    df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
    m = NeuralProphet(epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=LR)
    # metrics_log_dir is annotated Optional[str]; None (not False) is the
    # correct sentinel for "no logging directory".
    _ = m.fit(df, metrics=False, metrics_log_dir=None)

0 comments on commit c83e7cc

Please sign in to comment.