restructure train model config
ourownstory committed Aug 24, 2024
1 parent 6a74680 commit 420f8a6
Showing 6 changed files with 178 additions and 180 deletions.
48 changes: 24 additions & 24 deletions neuralprophet/configure.py
@@ -23,6 +23,24 @@
@dataclass
class Model:
lagged_reg_layers: Optional[List[int]]
quantiles: Optional[List[float]] = None

def setup_quantiles(self):
# convert quantiles to empty list [] if None
if self.quantiles is None:
self.quantiles = []
# assert quantiles is a list type
assert isinstance(self.quantiles, list), "Quantiles must be provided as list."
# check if quantiles are float values in (0, 1)
assert all(
0 < quantile < 1 for quantile in self.quantiles
), "The quantiles specified need to be floats in-between (0, 1)."
# sort the quantiles
self.quantiles.sort()
# remove 0.5 (or values close to it), as 0.5 is re-inserted below as the first index
self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)]
# 0 is the median quantile index
self.quantiles.insert(0, 0.5)


@dataclass
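For context, a minimal usage sketch of the relocated quantile handling, assuming the Model dataclass above is importable from neuralprophet.configure (the example values are illustrative):

from neuralprophet.configure import Model

model_config = Model(lagged_reg_layers=None, quantiles=[0.9, 0.1, 0.5])
model_config.setup_quantiles()
# quantiles are validated, sorted, stripped of (near-)0.5 entries,
# and 0.5 is re-inserted at index 0 as the median
print(model_config.quantiles)  # [0.5, 0.1, 0.9]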
@@ -92,7 +110,7 @@ class Train:
batch_size: Optional[int]
loss_func: Union[str, torch.nn.modules.loss._Loss, Callable]
optimizer: Union[str, Type[torch.optim.Optimizer]]
quantiles: List[float] = field(default_factory=list)
# quantiles: List[float] = field(default_factory=list)
optimizer_args: dict = field(default_factory=dict)
scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = None
scheduler_args: dict = field(default_factory=dict)
@@ -106,20 +124,19 @@ class Train:
lr_finder_args: dict = field(default_factory=dict)
optimizer_state: dict = field(default_factory=dict)
continue_training: bool = False
trainer_config: dict = field(default_factory=dict)

def __post_init__(self):
# assert the uncertainty estimation params and then finalize the quantiles
# self.set_quantiles()
assert self.newer_samples_weight >= 1.0
assert self.newer_samples_start >= 0.0
assert self.newer_samples_start < 1.0
self.set_loss_func()
# self.set_loss_func(self.quantiles)

# called in TimeNet configure_optimizers:
# self.set_optimizer()
# self.set_scheduler()

def set_loss_func(self):
def set_loss_func(self, quantiles: List[float]):
if isinstance(self.loss_func, str):
if self.loss_func.lower() in ["smoothl1", "smoothl1loss", "huber"]:
# keeping 'huber' for backwards compatibility, though not identical
@@ -139,25 +156,8 @@ def set_loss_func(self):
self.loss_func_name = type(self.loss_func).__name__
else:
raise NotImplementedError(f"Loss function {self.loss_func} not found")
if len(self.quantiles) > 1:
self.loss_func = PinballLoss(loss_func=self.loss_func, quantiles=self.quantiles)

# def set_quantiles(self):
# # convert quantiles to empty list [] if None
# if self.quantiles is None:
# self.quantiles = []
# # assert quantiles is a list type
# assert isinstance(self.quantiles, list), "Quantiles must be in a list format, not None or scalar."
# # check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index
# self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)]
# # check if quantiles are float values in (0, 1)
# assert all(
# 0 < quantile < 1 for quantile in self.quantiles
# ), "The quantiles specified need to be floats in-between (0, 1)."
# # sort the quantiles
# self.quantiles.sort()
# # 0 is the median quantile index
# self.quantiles.insert(0, 0.5)
if len(quantiles) > 1:
self.loss_func = PinballLoss(loss_func=self.loss_func, quantiles=quantiles)

def set_auto_batch_epoch(
self,
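With this restructure, the Train config no longer owns the quantiles; set_loss_func receives them as an argument, e.g. from the Model config. A hedged sketch of the new calling convention (the Train constructor arguments not visible in these hunks, such as learning_rate and epochs, and the external call site are assumptions):

from neuralprophet.configure import Model, Train

model_config = Model(lagged_reg_layers=None, quantiles=[0.25, 0.75])
model_config.setup_quantiles()  # -> [0.5, 0.25, 0.75]

train_config = Train(
    learning_rate=1e-3,  # assumed field, not shown in this diff
    epochs=10,           # assumed field, not shown in this diff
    batch_size=32,
    loss_func="SmoothL1",  # matches the "smoothl1" branch above
    optimizer="AdamW",     # assumed optimizer string
)
# quantiles are now passed in rather than read from self; with more
# than one quantile, the base loss is wrapped in PinballLoss
train_config.set_loss_func(quantiles=model_config.quantiles)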