Commit

solved criterions and added binary cross entropy
Jorgedavyd committed Oct 8, 2024
1 parent c2abf81 commit 313265e
Showing 2 changed files with 66 additions and 58 deletions.
56 changes: 26 additions & 30 deletions lightorch/nn/criterions.py
@@ -1,5 +1,5 @@
from torch import nn, Tensor
from typing import Sequence, Dict, Tuple, Sequence, List, Union
from typing import Optional, Sequence, Dict, Tuple, Sequence, List, Union
from . import functional as F
from itertools import chain

@@ -14,11 +14,13 @@ def _merge_dicts(dicts: Sequence[Dict[str, float]]) -> Dict[str, float]:
class LighTorchLoss(nn.Module):
def __init__(
self,
labels: Union[Sequence[str], str],
labels: Union[List[str], str],
factors: Union[Dict[str, float], Sequence[Dict[str, float]]],
) -> None:
super().__init__()
self.labels = labels
if isinstance(labels, str):
labels = [labels]
self.labels: List[str] = labels
if "Overall" not in labels:
self.labels.append("Overall")
self.factors = factors
@@ -30,36 +32,29 @@ def __init__(self, *loss) -> None:
loss
), "Not valid input classes, each should be different."
super().__init__(
list(set([*chain.from_iterable([i.labels for i in loss])])),
_merge_dicts([i.factors for i in loss]),
labels = list(set([*chain.from_iterable([i.labels for i in loss])])),
factors = _merge_dicts([i.factors for i in loss]),
)
self.loss = loss

def forward(self, **kwargs) -> Tuple[Tensor, ...]:
loss_ = Tensor([0.0])
out_list = []

for loss in self.loss:
args = loss(**kwargs)
out_list.extend(list(args[:-1]))
loss_ += args[-1]

out_list.append(loss_)

out_list = tuple(out_list)

return out_list


class MSELoss(nn.MSELoss):
def __init__(
self, size_average=None, reduce=None, reduction: str = "mean", factor: float = 1
) -> None:
def __init__(self, size_average=None, reduce=None, reduction: str = "mean", factor: float = 1) -> None:
super(MSELoss, self).__init__(size_average, reduce, reduction)
self.factors = {self.__class__.__name__: factor}
self.labels = [self.__class__.__name__]

def forward(self, **kwargs) -> Tensor:
def forward(self, **kwargs) -> Tuple[Tensor, Tensor]:
out = super().forward(kwargs["input"], kwargs["target"])
return out, out * self.factors[self.__class__.__name__]

@@ -81,11 +76,22 @@ def __init__(
self.factors = {self.__class__.__name__: factor}
self.labels = [self.__class__.__name__]

def forward(self, **kwargs) -> Tensor:
def forward(self, **kwargs) -> Tuple[Tensor, Tensor]:
out = super().forward(kwargs["input"], kwargs["target"])
return out, out * self.factors[self.__class__.__name__]

class BinaryCrossEntropy(nn.BCELoss):
def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean', factor: float = 1) -> None:
super().__init__(weight, size_average, reduce, reduction)
self.factors = {self.__class__.__name__: factor}
self.labels = [self.__class__.__name__]

def forward(self, **kwargs) -> Tuple[Tensor, Tensor]:
out = super().forward(kwargs["input"], kwargs["target"])
return out, out * self.factors[self.__class__.__name__]



class ELBO(LighTorchLoss):
"""
# Variational Autoencoder Loss:
@@ -94,9 +100,11 @@ class ELBO(LighTorchLoss):
"""

def __init__(self, beta: float, reconstruction_criterion: LighTorchLoss) -> None:
factors = {"KL Divergence": beta}
factors.update(reconstruction_criterion.factors)
super().__init__(
["KL Divergence"] + reconstruction_criterion.labels,
{"KL Divergence": beta}.update(reconstruction_criterion.factors),
labels = ["KL Divergence"] + reconstruction_criterion.labels,
factors = factors
)

self.L_recons = reconstruction_criterion
@@ -107,9 +115,7 @@ def forward(self, **kwargs) -> Tuple[Tensor, ...]:
input, target, logvar, mu
"""
*L_recons, L_recons_out = self.L_recons(**kwargs)

L_kl = F.kl_div(kwargs["mu"], kwargs["logvar"])

return (*L_recons, L_kl, L_recons_out + self.beta * L_kl)


@@ -173,10 +179,6 @@ def forward(self, **kwargs) -> Tensor:
)
return out, self.factors[self.__class__.__name__] * out


# pnsr


class PeakSignalNoiseRatio(LighTorchLoss):
"""
forward (input, target)
@@ -190,10 +192,6 @@ def forward(self, **kwargs) -> Tensor:
out = F.psnr(kwargs["input"], kwargs["target"], self.max)
return out, out * self.factors[self.__class__.__name__]


# Total variance


class TV(LighTorchLoss):
"""
# Total Variance (TV)
@@ -207,8 +205,6 @@ def forward(self, **kwargs) -> Tensor:
out = F.total_variance(kwargs["input"])
return out, out * self.factors[self.__class__.__name__]


# lambda
class LagrangianFunctional(LighTorchLoss):
"""
Creates a lagrangian function of the form:
@@ -281,7 +277,7 @@ def __init__(
self.g = g
self.f = f

def forward(self, **kwargs) -> Tensor:
def forward(self, **kwargs) -> Tuple[Tensor, Tensor]:
g_out_list: List[float] = []
g_out_fact: List[float] = []
for constraint in self.g:
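A minimal usage sketch of the updated criterion interface (not part of this commit). It assumes the import path lightorch.nn.criterions matches the file above and that BinaryCrossEntropy accepts the factor argument its __init__ body references; each criterion returns the raw loss alongside the factor-weighted loss, driven by the input/target keyword arguments shown in the diff.

import torch
from lightorch.nn.criterions import MSELoss, BinaryCrossEntropy  # assumed import path

# Predictions in (0, 1) so nn.BCELoss is happy; targets are 0/1 floats.
pred = torch.rand(4, 1, requires_grad=True)
target = torch.randint(0, 2, (4, 1)).float()

mse = MSELoss(factor=0.5)
bce = BinaryCrossEntropy(factor=1.0)  # `factor` assumed, mirroring MSELoss

# Each criterion returns (raw_loss, factor * raw_loss) from keyword arguments.
raw_mse, weighted_mse = mse(input=pred, target=target)
raw_bce, weighted_bce = bce(input=pred, target=target)

(weighted_mse + weighted_bce).backward()  # gradients flow back to `pred`
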
68 changes: 40 additions & 28 deletions lightorch/training/supervised.py
@@ -1,11 +1,17 @@
from lightning.pytorch import LightningModule
from typing import Union, Sequence, Any, Dict, List
from typing import Optional, Union, Sequence, Any, Dict, List
from torch import Tensor, nn
from torch.optim import Optimizer
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from collections import defaultdict
import torch
from torch.optim import Adam, Adadelta, Adamax, AdamW, SGD, LBFGS, RMSprop
from torch.optim.adadelta import Adadelta
from torch.optim.adam import Adam
from torch.optim.adamax import Adamax
from torch.optim.adamw import AdamW
from torch.optim.sgd import SGD
from torch.optim.lbfgs import LBFGS
from torch.optim.rmsprop import RMSprop

from torch.optim.lr_scheduler import (
OneCycleLR,
@@ -53,10 +59,10 @@ def __init__(
self,
*,
optimizer: Union[str, Optimizer],
scheduler: Union[str, LRScheduler] = None,
triggers: Dict[str, Dict[str, float]] = None,
optimizer_kwargs: Dict[str, Any] = None,
scheduler_kwargs: Dict[str, Any] = None,
scheduler: Optional[Union[str, LRScheduler]] = None,
triggers: Optional[Dict[str, Dict[str, float]]] = None,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
scheduler_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
super().__init__()
@@ -147,28 +153,29 @@ def get_param_groups(self) -> List[Dict[str, Union[nn.Module, List[float]]]]:
Given a list of "triggers", the param groups are defined.
"""
if self.triggers is not None:
param_groups: Sequence[Dict[str, List[nn.Module]]] = [
defaultdict(list) for _ in range(len(self.triggers))
param_groups: Sequence[Dict[str, Union[List[nn.Parameter], float]]] = [
defaultdict(list) for _ in self.triggers
]
# Update the model parameters per group and finally add the
# hyperparameters
for param_group, trigger in zip(param_groups, self.triggers):
for idx, trigger in enumerate(self.triggers):
for name, param in self.named_parameters():
if name.startswith(trigger):
param_group["params"].append(param)
param_groups[idx]["params"].append(param)

param_group.update(self.triggers[trigger])
param_groups[idx].update(self.triggers[trigger])

return param_groups
return None
raise TypeError("Triggers are not defined")

def _configure_optimizer(self) -> Optimizer:
if params := self.get_param_groups() is not None:
params = self.get_param_groups()
if params is not None:
if isinstance(self.optimizer, str):
return VALID_OPTIMIZERS[self.optimizer](params)
elif isinstance(self.optimizer, torch.optim.Optimizer):
if valid:=VALID_OPTIMIZERS.get(self.optimizer, None):
return valid(params)
raise TypeError("Not valid optimizer")
elif isinstance(self.optimizer, Optimizer):
return self.optimizer
elif issubclass(self.optimizer, torch.optim.Optimizer):
elif issubclass(self.optimizer, Optimizer):
return self.optimizer(params)
else:

@@ -182,18 +189,23 @@ def _configure_optimizer(self) -> Optimizer:
return self.optimizer(self.parameters(), **self.optimizer_kwargs)

def _configure_scheduler(self, optimizer: Optimizer) -> LRScheduler:
if isinstance(self.scheduler, str):
if self.scheduler == "onecycle":
self.scheduler_kwargs["total_steps"] = (
self.trainer.estimated_stepping_batches
)
return VALID_SCHEDULERS[self.scheduler](optimizer, **self.scheduler_kwargs)
else:
return self.scheduler(optimizer)
if self.scheduler is not None:
if isinstance(self.scheduler, str):
if self.scheduler == "onecycle":
if self.scheduler_kwargs is not None:
self.scheduler_kwargs["total_steps"] = (
self.trainer.estimated_stepping_batches
)
else:
raise ValueError(f'Scheduler kwargs not defined for {self.scheduler}')
if self.scheduler_kwargs is not None:
return VALID_SCHEDULERS[self.scheduler](optimizer, **self.scheduler_kwargs)
else:
return self.scheduler(optimizer)

def configure_optimizers(
self,
) -> Dict[str, Union[Optimizer, Dict[str, Union[float, int, LRScheduler]]]]:
) -> Dict[str, Union[Optimizer, Dict[str, Union[str, int, LRScheduler]]]]:
optimizer = self._configure_optimizer()
if self.scheduler is not None:
scheduler = self._configure_scheduler(optimizer)

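For context, a hedged sketch of how the reworked optimizer/scheduler configuration might be driven from user code (not part of this commit). The class name Module imported from lightorch.training.supervised, the optimizer key "adam", and the scheduler key "onecycle" are assumptions about VALID_OPTIMIZERS / VALID_SCHEDULERS; only the keyword arguments and the prefix-matching behaviour of triggers come from the diff above.

from torch import nn, Tensor
from lightorch.training.supervised import Module  # assumed export name


class AutoEncoder(Module):
    def __init__(self) -> None:
        super().__init__(
            optimizer="adam",                   # assumed key in VALID_OPTIMIZERS
            optimizer_kwargs={"lr": 1e-3},      # used only when no triggers are given
            scheduler="onecycle",               # "onecycle" now requires scheduler_kwargs
            scheduler_kwargs={"max_lr": 1e-2},  # total_steps is injected from the trainer
            triggers={
                # parameter names starting with these prefixes form their own groups
                "encoder": {"lr": 1e-4},
                "decoder": {"lr": 1e-3},
            },
        )
        self.encoder = nn.Linear(16, 4)
        self.decoder = nn.Linear(4, 16)

    def forward(self, x: Tensor) -> Tensor:
        return self.decoder(self.encoder(x))

Any training_step or loss wiring the real Module expects is omitted here; the sketch only exercises the configure_optimizers path touched by this diff.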