Skip to content

Commit

Permalink
[Major] Support Custom Learning Rate Scheduler (#1637)
Browse files Browse the repository at this point in the history
* enable re-training

* update scheduler

* change scheduler for continued training

* add test

* fix metrics logging

* include feedback

* get correct optimizer states

* fix tests

* enable setting the scheduler

* update for onecyclelr

* add tests and adapt docstring

* fix array mismatch

* robustify scheduler config

* clean up train config setup

* restructure train model config

* remove continue train

* fix regularization

* fix regularization of holidays test

* address events reg test

* fixed reg tests

* fix save

* move to debug folder

* debugging

* fix custom lr

* set finding lr arg

* add logging of progress and lr

* update lr schedulers to use epochs

* fix lr-finder

* improve num_training calculation for lr-finder and remove loss-min for lr calc

* large changeset - isolate lr-finder

* fix progressbar

* remove dataloader from model

* fix callbacks ProgressBar

* fixing tuner

* fix tuner

* readd prep_or_copy

* undo copy of model, loader, trainer

* add comment about separate lr finder copies

* improve lr finder comment

---------

Co-authored-by: Constantin Weberpals <constantin.weberpals@tum.de>
  • Loading branch information
ourownstory and weberpals authored Aug 30, 2024
1 parent 4744ef1 commit 4459338
Show file tree
Hide file tree
Showing 17 changed files with 1,759 additions and 1,217 deletions.
170 changes: 106 additions & 64 deletions neuralprophet/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,24 @@
@dataclass
class Model:
lagged_reg_layers: Optional[List[int]]
quantiles: Optional[List[float]] = None

def setup_quantiles(self):
# convert quantiles to empty list [] if None
if self.quantiles is None:
self.quantiles = []
# assert quantiles is a list type
assert isinstance(self.quantiles, list), "Quantiles must be provided as list."
# check if quantiles are float values in (0, 1)
assert all(
0 < quantile < 1 for quantile in self.quantiles
), "The quantiles specified need to be floats in-between (0, 1)."
# sort the quantiles
self.quantiles.sort()
# check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index
self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)]
# 0 is the median quantile index
self.quantiles.insert(0, 0.5)


@dataclass
Expand Down Expand Up @@ -92,30 +110,31 @@ class Train:
batch_size: Optional[int]
loss_func: Union[str, torch.nn.modules.loss._Loss, Callable]
optimizer: Union[str, Type[torch.optim.Optimizer]]
quantiles: List[float] = field(default_factory=list)
# quantiles: List[float] = field(default_factory=list)
optimizer_args: dict = field(default_factory=dict)
scheduler: Optional[Type[torch.optim.lr_scheduler.OneCycleLR]] = None
scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = None
scheduler_args: dict = field(default_factory=dict)
early_stopping: Optional[bool] = False
newer_samples_weight: float = 1.0
newer_samples_start: float = 0.0
reg_delay_pct: float = 0.5
reg_lambda_trend: Optional[float] = None
trend_reg_threshold: Optional[Union[bool, float]] = None
n_data: int = field(init=False)
loss_func_name: str = field(init=False)
lr_finder_args: dict = field(default_factory=dict)
pl_trainer_config: dict = field(default_factory=dict)

def __post_init__(self):
# assert the uncertainty estimation params and then finalize the quantiles
self.set_quantiles()
assert self.newer_samples_weight >= 1.0
assert self.newer_samples_start >= 0.0
assert self.newer_samples_start < 1.0
self.set_loss_func()
self.set_optimizer()
self.set_scheduler()
# self.set_loss_func(self.quantiles)

def set_loss_func(self):
# called in TimeNet configure_optimizers:
# self.set_optimizer()
# self.set_scheduler()

def set_loss_func(self, quantiles: List[float]):
if isinstance(self.loss_func, str):
if self.loss_func.lower() in ["smoothl1", "smoothl1loss", "huber"]:
# keeping 'huber' for backwards compatiblility, though not identical
Expand All @@ -135,25 +154,8 @@ def set_loss_func(self):
self.loss_func_name = type(self.loss_func).__name__
else:
raise NotImplementedError(f"Loss function {self.loss_func} not found")
if len(self.quantiles) > 1:
self.loss_func = PinballLoss(loss_func=self.loss_func, quantiles=self.quantiles)

def set_quantiles(self):
# convert quantiles to empty list [] if None
if self.quantiles is None:
self.quantiles = []
# assert quantiles is a list type
assert isinstance(self.quantiles, list), "Quantiles must be in a list format, not None or scalar."
# check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index
self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)]
# check if quantiles are float values in (0, 1)
assert all(
0 < quantile < 1 for quantile in self.quantiles
), "The quantiles specified need to be floats in-between (0, 1)."
# sort the quantiles
self.quantiles.sort()
# 0 is the median quantile index
self.quantiles.insert(0, 0.5)
if len(quantiles) > 1:
self.loss_func = PinballLoss(loss_func=self.loss_func, quantiles=quantiles)

def set_auto_batch_epoch(
self,
Expand Down Expand Up @@ -182,51 +184,88 @@ def set_optimizer(self):
"""
Set the optimizer and optimizer args. If optimizer is a string, then it will be converted to the corresponding
torch optimizer. The optimizer is not initialized yet as this is done in configure_optimizers in TimeNet.
Parameters
----------
optimizer_name : int
Object provided to NeuralProphet as optimizer.
optimizer_args : dict
Arguments for the optimizer.
"""
self.optimizer, self.optimizer_args = utils_torch.create_optimizer_from_config(
self.optimizer, self.optimizer_args
)
if isinstance(self.optimizer, str):
if self.optimizer.lower() == "adamw":
# Tends to overfit, but reliable
self.optimizer = torch.optim.AdamW
self.optimizer_args["weight_decay"] = 1e-3
elif self.optimizer.lower() == "sgd":
# better validation performance, but diverges sometimes
self.optimizer = torch.optim.SGD
self.optimizer_args["momentum"] = 0.9
self.optimizer_args["weight_decay"] = 1e-4
else:
raise ValueError(
f"The optimizer name {self.optimizer} is not supported. Please pass the optimizer class."
)
elif not issubclass(self.optimizer, torch.optim.Optimizer):
raise ValueError("The provided optimizer is not supported.")

def set_scheduler(self):
"""
Set the scheduler and scheduler args.
Set the scheduler and scheduler arg depending on the user selection.
The scheduler is not initialized yet as this is done in configure_optimizers in TimeNet.
"""
self.scheduler = torch.optim.lr_scheduler.OneCycleLR
self.scheduler_args.update(
{
"pct_start": 0.3,
"anneal_strategy": "cos",
"div_factor": 10.0,
"final_div_factor": 10.0,
"three_phase": True,
}
)

def set_lr_finder_args(self, dataset_size, num_batches):
"""
Set the lr_finder_args.
This is the range of learning rates to test.
"""
num_training = 100 + int(np.log10(dataset_size) * 20)
if num_batches < num_training:
log.warning(
f"Learning rate finder: The number of batches ({num_batches}) is too small than the required number \
for the learning rate finder ({num_training}). The results might not be optimal."
)
# num_training = num_batches
self.lr_finder_args.update(
{
"min_lr": 1e-7,
"max_lr": 10,
"num_training": num_training,
"early_stop_threshold": None,
}
)
if self.scheduler is None:
log.warning("No scheduler specified. Falling back to ExponentialLR scheduler.")
self.scheduler = "exponentiallr"

if isinstance(self.scheduler, str):
if self.scheduler.lower() in ["onecycle", "onecyclelr"]:
self.scheduler = torch.optim.lr_scheduler.OneCycleLR
defaults = {
"pct_start": 0.3,
"anneal_strategy": "cos",
"div_factor": 10.0,
"final_div_factor": 10.0,
"three_phase": True,
}
elif self.scheduler.lower() == "steplr":
self.scheduler = torch.optim.lr_scheduler.StepLR
defaults = {
"step_size": 10,
"gamma": 0.1,
}
elif self.scheduler.lower() == "exponentiallr":
self.scheduler = torch.optim.lr_scheduler.ExponentialLR
defaults = {
"gamma": 0.9,
}
elif self.scheduler.lower() == "cosineannealinglr":
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR
defaults = {
"T_max": 50,
}
elif self.scheduler.lower() == "cosineannealingwarmrestarts":
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
defaults = {
"T_0": 5,
"T_mult": 2,
}
else:
raise NotImplementedError(
f"Scheduler {self.scheduler} is not supported from string. Please pass the scheduler class."
)
if self.scheduler_args is not None:
defaults.update(self.scheduler_args)
self.scheduler_args = defaults
else:
assert issubclass(
self.scheduler, torch.optim.lr_scheduler.LRScheduler
), "Scheduler must be a subclass of torch.optim.lr_scheduler.LRScheduler"

def get_reg_delay_weight(self, e, iter_progress, reg_start_pct: float = 0.66, reg_full_pct: float = 1.0):
def get_reg_delay_weight(self, progress, reg_start_pct: float = 0.66, reg_full_pct: float = 1.0):
# Ignore type warning of epochs possibly being None (does not work with dataclasses)
progress = (e + iter_progress) / float(self.epochs) # type: ignore
if reg_start_pct == reg_full_pct:
reg_progress = float(progress > reg_start_pct)
else:
Expand All @@ -239,6 +278,9 @@ def get_reg_delay_weight(self, e, iter_progress, reg_start_pct: float = 0.66, re
delay_weight = 1
return delay_weight

def set_batches_per_epoch(self, batches_per_epoch: int):
self.batches_per_epoch = batches_per_epoch


@dataclass
class Trend:
Expand Down
Loading

0 comments on commit 4459338

Please sign in to comment.