Commit
fix lr-finder
ourownstory committed Aug 29, 2024
1 parent c83e7cc commit b79b7e1
Showing 5 changed files with 401 additions and 726 deletions.
5 changes: 3 additions & 2 deletions neuralprophet/configure.py
@@ -279,10 +279,11 @@ def set_lr_finder_args(self, dataset_size, num_batches):
# num_training = num_batches
self.lr_finder_args.update(
{
"min_lr": 1e-7,
"max_lr": 10,
"min_lr": 1e-8,
"max_lr": 1e1,
"num_training": num_training,
"early_stop_threshold": None,
"mode": "exponential",
}
)

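These arguments are forwarded verbatim to Lightning's learning-rate finder, so the new bounds (1e-8 to 1e1), the exponential sweep mode, and the disabled early stop all take effect inside Tuner.lr_find. A minimal, runnable sketch of that hand-off; BoringModel and the random DataLoader are illustrative stand-ins, not NeuralProphet code:

import torch
import pytorch_lightning as pl
from pytorch_lightning.tuner import Tuner

class BoringModel(pl.LightningModule):
    def __init__(self, learning_rate=1e-3):
        super().__init__()
        self.learning_rate = learning_rate  # attribute that lr_find overwrites with its suggestion
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)

x, y = torch.randn(512, 8), torch.randn(512, 1)
loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x, y), batch_size=32)

lr_finder_args = {
    "min_lr": 1e-8,                # lower bound of the sweep
    "max_lr": 1e1,                 # upper bound of the sweep
    "num_training": 100,           # number of lr values to try
    "early_stop_threshold": None,  # never abort the sweep on a diverging loss
    "mode": "exponential",         # grow the lr exponentially between the bounds
}
trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False, enable_progress_bar=False)
lr_finder = Tuner(trainer).lr_find(BoringModel(), train_dataloaders=loader, **lr_finder_args)
print(lr_finder.suggestion())

With early_stop_threshold=None the sweep always covers the full range, so the divergence at the high-lr end is visible to the smoothing routine in utils.smooth_loss_and_suggest further down.
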
77 changes: 33 additions & 44 deletions neuralprophet/forecaster.py
@@ -2796,7 +2796,10 @@ def _train(
# Set up the training dataloader
df, _, _, _ = df_utils.prep_or_copy_df(df) # TODO: Can this call be removed?
train_loader = self._init_train_loader(df, num_workers)
dataset_size = len(df) # train_loader.dataset
dataset_size = len(train_loader.dataset) # df
batches_per_epoch = len(train_loader)
log.info(f"Dataset size: {dataset_size}")
log.info(f"Number of batches per training epoch: {batches_per_epoch}")

# Internal flag to check if validation is enabled
validation_enabled = df_val is not None
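
The switch above from len(df) to len(train_loader.dataset) matters because NeuralProphet trains on windowed samples: the number of training examples produced by the dataloader is not the number of rows in df. A small, self-contained illustration of that distinction (WindowDataset is a toy stand-in, not NeuralProphet's TimeDataset):

import torch
from torch.utils.data import Dataset, DataLoader

class WindowDataset(Dataset):
    def __init__(self, series, n_lags, n_forecasts):
        self.series, self.n_lags, self.n_forecasts = series, n_lags, n_forecasts

    def __len__(self):
        # one sample per valid (lags, forecast) window
        return len(self.series) - self.n_lags - self.n_forecasts + 1

    def __getitem__(self, i):
        x = self.series[i : i + self.n_lags]
        y = self.series[i + self.n_lags : i + self.n_lags + self.n_forecasts]
        return x, y

series = torch.arange(100.0)
loader = DataLoader(WindowDataset(series, n_lags=10, n_forecasts=5), batch_size=32)
print(len(series))          # 100 rows in the "df"
print(len(loader.dataset))  # 86 actual training samples
print(len(loader))          # 3 batches per epoch
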
@@ -2818,55 +2821,41 @@
deterministic=deterministic,
)

# Find suitable learning rate
if not self.config_train.learning_rate:
log.info("No Learning Rate provided. Activating learning rate finder")
# Set parameters for the learning rate finder
self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=batches_per_epoch)
log.info(f"Learning rate finder ---- ARGs: {self.config_train.lr_finder_args}")
self.model.finding_lr = True
tuner = Tuner(self.trainer)
lr_finder = tuner.lr_find(
model=self.model,
train_dataloaders=train_loader,
# val_dataloaders=val_loader, # not used, but may lead to Lightning bug if not provided
**self.config_train.lr_finder_args,
)
# Estimate the optimal learning rate from the loss curve
assert lr_finder is not None
_, _, lr_suggested = utils.smooth_loss_and_suggest(lr_finder)
self.model.learning_rate = lr_suggested
self.config_train.learning_rate = lr_suggested
log.info(f"Learning rate finder suggested learning rate: {lr_suggested}")
self.model.finding_lr = False

# Tune hyperparams and train
if validation_enabled:
# Set up the validation dataloader
df_val, _, _, _ = df_utils.prep_or_copy_df(df_val)
val_loader = self._init_val_loader(df_val)

if not self.config_train.learning_rate:
# Find suitable learning rate
# Set parameters for the learning rate finder
self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=len(train_loader))
self.model.finding_lr = True
tuner = Tuner(self.trainer)
lr_finder = tuner.lr_find(
model=self.model,
train_dataloaders=train_loader,
# val_dataloaders=val_loader, # not used, but may lead to Lightning bug if not provided
**self.config_train.lr_finder_args,
)
# Estimate the optimal learning rate from the loss curve
assert lr_finder is not None
_, _, self.model.learning_rate = utils.smooth_loss_and_suggest(lr_finder)
self.model.finding_lr = False
start = time.time()
self.trainer.fit(
self.model,
train_loader,
val_loader,
)
else:
if not self.config_train.learning_rate:
# Find suitable learning rate
# Set parameters for the learning rate finder
self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=len(train_loader))
self.model.finding_lr = True
tuner = Tuner(self.trainer)
lr_finder = tuner.lr_find(
model=self.model,
train_dataloaders=train_loader,
**self.config_train.lr_finder_args,
)
assert lr_finder is not None
# Estimate the optimal learning rate from the loss curve
_, _, self.model.learning_rate = utils.smooth_loss_and_suggest(lr_finder)
self.model.finding_lr = False
start = time.time()
self.trainer.fit(
self.model,
train_loader,
)
self.model.finding_lr = False
start = time.time()
self.trainer.fit(
model=self.model,
train_dataloaders=train_loader,
val_dataloaders=val_loader if validation_enabled else None,
)

log.debug("Train Time: {:8.3f}".format(time.time() - start))

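The rest of the refactor collapses the duplicated if/else branches: the learning-rate search now runs once, up front, and a single trainer.fit call covers both the validation and no-validation paths by passing val_dataloaders=None. A runnable sketch of that final fit pattern; TinyModel and the random loaders are illustrative stand-ins, not NeuralProphet code:

import torch
import pytorch_lightning as pl

class TinyModel(pl.LightningModule):
    def __init__(self, lr=0.01):
        super().__init__()
        self.lr = lr
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", torch.nn.functional.mse_loss(self.layer(x), y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

def make_loader(n):
    x, y = torch.randn(n, 4), torch.randn(n, 1)
    return torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x, y), batch_size=16)

validation_enabled = True  # flip to False: the same fit call still works
val_loader = make_loader(64) if validation_enabled else None

trainer = pl.Trainer(max_epochs=2, logger=False, enable_checkpointing=False, enable_progress_bar=False)
trainer.fit(
    model=TinyModel(),
    train_dataloaders=make_loader(256),
    val_dataloaders=val_loader,  # None simply disables the validation loop
)
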
11 changes: 7 additions & 4 deletions neuralprophet/time_net.py
@@ -775,8 +775,8 @@ def loss_func(self, inputs, predicted, targets):

def training_step(self, batch, batch_idx):
inputs, targets, meta = batch
epoch_float = self.trainer.current_epoch + float(batch_idx / self.train_steps_per_epoch)
self.train_progress = epoch_float / self.config_train.epochs
epoch_float = self.trainer.current_epoch + batch_idx / float(self.train_steps_per_epoch)
self.train_progress = epoch_float / float(self.config_train.epochs)
# Global-local
if self.meta_used_in_model:
meta_name_tensor = torch.tensor([self.id_dict[i] for i in meta["df_name"]], device=self.device)
@@ -796,7 +796,10 @@ def training_step(self, batch, batch_idx):
optimizer.step()

scheduler = self.lr_schedulers()
scheduler.step(epoch=epoch_float)
if self.finding_lr:
scheduler.step()
else:
scheduler.step(epoch=epoch_float)

if self.finding_lr:
# Manually track the loss for the lr finder
@@ -874,7 +877,7 @@ def configure_optimizers(self):

# Optimizer
if self.finding_lr and self.learning_rate is None:
self.learning_rate = self.config_train.lr_finder_args["min_lr"]
self.learning_rate = 0.1
optimizer = self.config_train.optimizer(
self.parameters(),
lr=self.learning_rate,
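
Behaviour change in this file: during the lr range test (finding_lr=True) the scheduler now advances one plain step per batch, while normal training keeps the fractional-epoch stepping; the configure_optimizers hunk separately changes the fallback starting lr used while the finder runs from lr_finder_args["min_lr"] to 0.1. A small sketch of the two stepping modes; ExponentialLR stands in for whatever scheduler config_train actually configures, and the loop only mimics manual optimization:

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

epochs, train_steps_per_epoch = 2, 4
finding_lr = False  # flip to True to see the plain per-batch stepping used during the lr sweep

for current_epoch in range(epochs):
    for batch_idx in range(train_steps_per_epoch):
        optimizer.step()  # dummy update; a real training_step computes and backpropagates a loss first
        # Fractional progress through training, e.g. 1.25 = second epoch, second batch of four
        epoch_float = current_epoch + batch_idx / float(train_steps_per_epoch)
        train_progress = epoch_float / float(epochs)
        if finding_lr:
            scheduler.step()                   # plain step: one lr increment per batch
        else:
            scheduler.step(epoch=epoch_float)  # epoch-indexed step (deprecated in torch, but still supported)
        print(f"epoch_float={epoch_float:.2f} progress={train_progress:.2f} lr={scheduler.get_last_lr()[0]:.4f}")
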
39 changes: 27 additions & 12 deletions neuralprophet/utils.py
@@ -771,17 +771,17 @@ def smooth_loss_and_suggest(lr_finder, window=10):
"""
lr_finder_results = lr_finder.results
lr = lr_finder_results["lr"]
loss = lr_finder_results["loss"]
loss = np.array(lr_finder_results["loss"])
# Derive window size from num lr searches, ensure window is divisible by 2
# half_window = math.ceil(round(len(loss) * 0.1) / 2)
half_window = math.ceil(window / 2)
# Pad sequence and initialize Hamming filter
loss = np.pad(np.array(loss), pad_width=half_window, mode="edge")
window = np.hamming(half_window * 2)
loss = np.pad(loss, pad_width=half_window, mode="edge")
hamming_window = np.hamming(2 * half_window)
# Convolve the Hamming window over the loss distribution
try:
loss = np.convolve(
window / window.sum(),
loss_smooth = np.convolve(
hamming_window / hamming_window.sum(),
loss,
mode="valid",
)[1:]
@@ -790,26 +790,41 @@ def smooth_loss_and_suggest(lr_finder, window=10):
f"The number of loss values ({len(loss)}) is too small to apply smoothing with a the window size of "
f"{window}."
)

# Suggest the lr with steepest negative gradient
try:
# Find the steepest gradient and the minimum loss after that
suggestion = lr[np.argmin(np.gradient(loss))]
suggestion_steepest = lr[np.argmin(np.gradient(loss_smooth))]
suggestion_minimum = lr[np.argmin(loss_smooth)]
except ValueError:
log.error(
f"The number of loss values ({len(loss)}) is too small to estimate a learning rate. Increase the number of "
"samples or manually set the learning rate."
)
raise
suggestion_default = lr_finder.suggestion(skip_begin=10, skip_end=3)
if suggestion is not None and suggestion_default is not None:
log_suggestion_smooth = np.log(suggestion)
# get the tuner's default suggestion
suggestion_default = lr_finder.suggestion(skip_begin=20, skip_end=10)

log.info(f"Learning rate finder ---- default suggestion: {suggestion_default}")
log.info(f"Learning rate finder ---- steepest: {suggestion_steepest}")
log.info(f"Learning rate finder ---- minimum: {suggestion_minimum}")
if suggestion_steepest is not None and suggestion_minimum is not None and suggestion_default is not None:
log_suggestion_smooth = np.log(suggestion_steepest)
log_suggestion_minimum = np.log(suggestion_minimum)
log_suggestion_default = np.log(suggestion_default)
lr_suggestion = np.exp((log_suggestion_smooth + log_suggestion_default) / 2)
elif suggestion is None and suggestion_default is None:
lr_suggestion = np.exp((log_suggestion_smooth + log_suggestion_minimum + log_suggestion_default) / 3)
log.info(f"Learning rate finder ---- log-avg: {lr_suggestion}")
elif suggestion_steepest is None and suggestion_default is None:
log.error("Automatic learning rate test failed. Please set manually the learning rate.")
raise
else:
lr_suggestion = suggestion if suggestion is not None else suggestion_default
lr_suggestion = suggestion_steepest if suggestion_steepest is not None else suggestion_default

log.info(f"Learning rate finder ---- returning: {lr_suggestion}")
log.info(f"Learning rate finder ---- LR (start): {lr[:5]}")
log.info(f"Learning rate finder ---- LR (end): {lr[-5:]}")
log.info(f"Learning rate finder ---- LOSS (start): {loss[:5]}")
log.info(f"Learning rate finder ---- LOSS (end): {loss[-5:]}")
return (loss, lr, lr_suggestion)
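
The reworked suggestion logic smooths the raw loss curve with a normalized Hamming window (edge-padding by half a window keeps the smoothed array the same length as the lr grid) and then averages, in log space, three candidates: the lr at the steepest descent, the lr at the loss minimum, and the tuner's default suggestion. A self-contained numpy sketch on a synthetic lr-range curve; the curve and the stand-in default suggestion are made up for illustration:

import math
import numpy as np

rng = np.random.default_rng(0)
lr = np.logspace(-8, 1, 200)                             # exponential sweep from 1e-8 to 1e1
loss = 1.0 / (1.0 + np.exp(np.log10(lr) + 4))            # loss falls as the lr becomes effective...
loss += np.where(lr > 1e-1, 10 * (np.log10(lr) + 1), 0)  # ...then diverges once the lr is too large
loss += rng.normal(0.0, 0.02, lr.size)                   # measurement noise

window = 10
half_window = math.ceil(window / 2)
padded = np.pad(loss, pad_width=half_window, mode="edge")
hamming = np.hamming(2 * half_window)
loss_smooth = np.convolve(hamming / hamming.sum(), padded, mode="valid")[1:]  # same length as lr

suggestion_steepest = lr[np.argmin(np.gradient(loss_smooth))]  # steepest negative slope
suggestion_minimum = lr[np.argmin(loss_smooth)]                # lowest smoothed loss
suggestion_default = 1e-3                                      # stand-in for lr_finder.suggestion()

# Geometric mean of the three candidates, i.e. the average of their logs
lr_suggestion = np.exp(
    (np.log(suggestion_steepest) + np.log(suggestion_minimum) + np.log(suggestion_default)) / 3
)
print(suggestion_steepest, suggestion_minimum, lr_suggestion)

Averaging in log space treats learning rates on their natural multiplicative scale, so the combined suggestion sits between the candidates in orders of magnitude rather than being dominated by the largest one.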


