diff --git a/lightorch/nn/criterions.py b/lightorch/nn/criterions.py
index 5141a29..b9cb3bd 100644
--- a/lightorch/nn/criterions.py
+++ b/lightorch/nn/criterions.py
@@ -1,5 +1,5 @@
 from torch import nn, Tensor
-from typing import Sequence, Dict, Tuple, Sequence, List, Union
+from typing import Optional, Sequence, Dict, Tuple, List, Union
 from . import functional as F
 from itertools import chain
 
@@ -14,11 +14,13 @@ def _merge_dicts(dicts: Sequence[Dict[str, float]]) -> Dict[str, float]:
 class LighTorchLoss(nn.Module):
     def __init__(
         self,
-        labels: Union[Sequence[str], str],
+        labels: Union[List[str], str],
         factors: Union[Dict[str, float], Sequence[Dict[str, float]]],
     ) -> None:
         super().__init__()
-        self.labels = labels
+        if isinstance(labels, str):
+            labels = [labels]
+        self.labels: List[str] = labels
         if "Overall" not in labels:
             self.labels.append("Overall")
         self.factors = factors
@@ -30,36 +32,29 @@ def __init__(self, *loss) -> None:
             loss
         ), "Not valid input classes, each should be different."
         super().__init__(
-            list(set([*chain.from_iterable([i.labels for i in loss])])),
-            _merge_dicts([i.factors for i in loss]),
+            labels = list(set([*chain.from_iterable([i.labels for i in loss])])),
+            factors = _merge_dicts([i.factors for i in loss]),
         )
         self.loss = loss
 
     def forward(self, **kwargs) -> Tuple[Tensor, ...]:
         loss_ = Tensor([0.0])
         out_list = []
-
         for loss in self.loss:
             args = loss(**kwargs)
             out_list.extend(list(args[:-1]))
             loss_ += args[-1]
-
         out_list.append(loss_)
-
         out_list = tuple(out_list)
-
         return out_list
-
 
 
 class MSELoss(nn.MSELoss):
-    def __init__(
-        self, size_average=None, reduce=None, reduction: str = "mean", factor: float = 1
-    ) -> None:
+    def __init__(self, size_average=None, reduce=None, reduction: str = "mean", factor: float = 1) -> None:
         super(MSELoss, self).__init__(size_average, reduce, reduction)
         self.factors = {self.__class__.__name__: factor}
         self.labels = [self.__class__.__name__]
 
-    def forward(self, **kwargs) -> Tensor:
+    def forward(self, **kwargs) -> Tuple[Tensor, Tensor]:
         out = super().forward(kwargs["input"], kwargs["target"])
         return out, out * self.factors[self.__class__.__name__]
 
@@ -81,11 +76,22 @@ def __init__(
         self.factors = {self.__class__.__name__: factor}
         self.labels = [self.__class__.__name__]
 
-    def forward(self, **kwargs) -> Tensor:
+    def forward(self, **kwargs) -> Tuple[Tensor, Tensor]:
+        out = super().forward(kwargs["input"], kwargs["target"])
+        return out, out * self.factors[self.__class__.__name__]
+
+class BinaryCrossEntropy(nn.BCELoss):
+    def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean', factor: float = 1) -> None:
+        super().__init__(weight, size_average, reduce, reduction)
+        self.factors = {self.__class__.__name__: factor}
+        self.labels = [self.__class__.__name__]
+
+    def forward(self, **kwargs) -> Tuple[Tensor, Tensor]:
         out = super().forward(kwargs["input"], kwargs["target"])
         return out, out * self.factors[self.__class__.__name__]
 
 
+
 class ELBO(LighTorchLoss):
     """
     # Variational Autoencoder Loss:
@@ -94,9 +100,11 @@ class ELBO(LighTorchLoss):
     """
 
     def __init__(self, beta: float, reconstruction_criterion: LighTorchLoss) -> None:
+        factors = {"KL Divergence": beta}
+        factors.update(reconstruction_criterion.factors)
         super().__init__(
-            ["KL Divergence"] + reconstruction_criterion.labels,
-            {"KL Divergence": beta}.update(reconstruction_criterion.factors),
+            labels = ["KL Divergence"] + reconstruction_criterion.labels,
+            factors = factors
         )
         self.L_recons = reconstruction_criterion
 
@@ -107,9 +115,7 @@ def forward(self, **kwargs) -> Tuple[Tensor, ...]:
         """
         input, target, logvar, mu
         """
         *L_recons, L_recons_out = self.L_recons(**kwargs)
-
         L_kl = F.kl_div(kwargs["mu"], kwargs["logvar"])
-
         return (*L_recons, L_kl, L_recons_out + self.beta * L_kl)
 
@@ -173,10 +179,6 @@ def forward(self, **kwargs) -> Tensor:
         )
         return out, self.factors[self.__class__.__name__] * out
 
-
-# pnsr
-
-
 class PeakSignalNoiseRatio(LighTorchLoss):
     """
     forward (input, target)
@@ -190,10 +192,6 @@ def forward(self, **kwargs) -> Tensor:
         out = F.psnr(kwargs["input"], kwargs["target"], self.max)
         return out, out * self.factors[self.__class__.__name__]
 
-
-# Total variance
-
-
 class TV(LighTorchLoss):
     """
     # Total Variance (TV)
@@ -207,8 +205,6 @@ def forward(self, **kwargs) -> Tensor:
         out = F.total_variance(kwargs["input"])
         return out, out * self.factors[self.__class__.__name__]
 
-
-# lambda
 class LagrangianFunctional(LighTorchLoss):
     """
     Creates a lagrangian function of the form:
@@ -281,7 +277,7 @@ def __init__(
         self.g = g
         self.f = f
 
-    def forward(self, **kwargs) -> Tensor:
+    def forward(self, **kwargs) -> Tuple[Tensor, Tensor]:
         g_out_list: List[float] = []
         g_out_fact: List[float] = []
         for constraint in self.g:
diff --git a/lightorch/training/supervised.py b/lightorch/training/supervised.py
index a00abde..50bc31e 100644
--- a/lightorch/training/supervised.py
+++ b/lightorch/training/supervised.py
@@ -1,11 +1,17 @@
 from lightning.pytorch import LightningModule
-from typing import Union, Sequence, Any, Dict, List
+from typing import Optional, Union, Sequence, Any, Dict, List
 from torch import Tensor, nn
-from torch.optim import Optimizer
+from torch.optim.optimizer import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
 from collections import defaultdict
 import torch
-from torch.optim import Adam, Adadelta, Adamax, AdamW, SGD, LBFGS, RMSprop
+from torch.optim.adadelta import Adadelta
+from torch.optim.adam import Adam
+from torch.optim.adamax import Adamax
+from torch.optim.adamw import AdamW
+from torch.optim.sgd import SGD
+from torch.optim.lbfgs import LBFGS
+from torch.optim.rmsprop import RMSprop
 
 from torch.optim.lr_scheduler import (
     OneCycleLR,
@@ -53,10 +59,10 @@ def __init__(
         self,
         *,
         optimizer: Union[str, Optimizer],
-        scheduler: Union[str, LRScheduler] = None,
-        triggers: Dict[str, Dict[str, float]] = None,
-        optimizer_kwargs: Dict[str, Any] = None,
-        scheduler_kwargs: Dict[str, Any] = None,
+        scheduler: Optional[Union[str, LRScheduler]] = None,
+        triggers: Optional[Dict[str, Dict[str, float]]] = None,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        scheduler_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -147,28 +153,29 @@ def get_param_groups(self) -> List[Dict[str, Union[nn.Module, List[float]]]]:
         Given a list of "triggers", the param groups are defined.
         """
         if self.triggers is not None:
-            param_groups: Sequence[Dict[str, List[nn.Module]]] = [
-                defaultdict(list) for _ in range(len(self.triggers))
+            param_groups: Sequence[Dict[str, Union[List[nn.Parameter], float]]] = [
+                defaultdict(list) for _ in self.triggers
             ]
-            # Update the model parameters per group and finally add the
-            # hyperparameters
-            for param_group, trigger in zip(param_groups, self.triggers):
+            for idx, trigger in enumerate(self.triggers):
                 for name, param in self.named_parameters():
                     if name.startswith(trigger):
-                        param_group["params"].append(param)
+                        param_groups[idx]["params"].append(param)
 
-                param_group.update(self.triggers[trigger])
+                param_groups[idx].update(self.triggers[trigger])
 
             return param_groups
-        return None
+        raise TypeError("Triggers are not defined")
 
     def _configure_optimizer(self) -> Optimizer:
-        if params := self.get_param_groups() is not None:
+        params = self.get_param_groups()
+        if params is not None:
             if isinstance(self.optimizer, str):
-                return VALID_OPTIMIZERS[self.optimizer](params)
-            elif isinstance(self.optimizer, torch.optim.Optimizer):
+                if valid:=VALID_OPTIMIZERS.get(self.optimizer, None):
+                    return valid(params)
+                raise TypeError("Not valid optimizer")
+            elif isinstance(self.optimizer, Optimizer):
                 return self.optimizer
-            elif issubclass(self.optimizer, torch.optim.Optimizer):
+            elif issubclass(self.optimizer, Optimizer):
                 return self.optimizer(params)
 
             else:
@@ -182,18 +189,23 @@ def _configure_optimizer(self) -> Optimizer:
         return self.optimizer(self.parameters(), **self.optimizer_kwargs)
 
     def _configure_scheduler(self, optimizer: Optimizer) -> LRScheduler:
-        if isinstance(self.scheduler, str):
-            if self.scheduler == "onecycle":
-                self.scheduler_kwargs["total_steps"] = (
-                    self.trainer.estimated_stepping_batches
-                )
-            return VALID_SCHEDULERS[self.scheduler](optimizer, **self.scheduler_kwargs)
-        else:
-            return self.scheduler(optimizer)
+        if self.scheduler is not None:
+            if isinstance(self.scheduler, str):
+                if self.scheduler == "onecycle":
+                    if self.scheduler_kwargs is not None:
+                        self.scheduler_kwargs["total_steps"] = (
+                            self.trainer.estimated_stepping_batches
+                        )
+                    else:
+                        raise ValueError(f'Scheduler kwargs not defined for {self.scheduler}')
+                if self.scheduler_kwargs is not None:
+                    return VALID_SCHEDULERS[self.scheduler](optimizer, **self.scheduler_kwargs)
+            else:
+                return self.scheduler(optimizer)
 
     def configure_optimizers(
         self,
-    ) -> Dict[str, Union[Optimizer, Dict[str, Union[float, int, LRScheduler]]]]:
+    ) -> Dict[str, Union[Optimizer, Dict[str, Union[str, int, LRScheduler]]]]:
         optimizer = self._configure_optimizer()
         if self.scheduler is not None:
             scheduler = self._configure_scheduler(optimizer)