From 664cfa6752c3a9e49ade553525d75a3c290f1582 Mon Sep 17 00:00:00 2001
From: Jorgedavyd
Date: Tue, 8 Oct 2024 18:18:47 +0000
Subject: [PATCH] Automatically formatted with black

---
 lightorch/nn/criterions.py       | 25 ++++++++++++++++++-------
 lightorch/training/supervised.py | 10 +++++++---
 2 files changed, 25 insertions(+), 10 deletions(-)

diff --git a/lightorch/nn/criterions.py b/lightorch/nn/criterions.py
index b9cb3bd..9d59642 100644
--- a/lightorch/nn/criterions.py
+++ b/lightorch/nn/criterions.py
@@ -32,8 +32,8 @@ def __init__(self, *loss) -> None:
             loss
         ), "Not valid input classes, each should be different."
         super().__init__(
-            labels = list(set([*chain.from_iterable([i.labels for i in loss])])),
-            factors = _merge_dicts([i.factors for i in loss]),
+            labels=list(set([*chain.from_iterable([i.labels for i in loss])])),
+            factors=_merge_dicts([i.factors for i in loss]),
         )
         self.loss = loss
 
@@ -48,8 +48,11 @@ def forward(self, **kwargs) -> Tuple[Tensor, ...]:
         out_list = tuple(out_list)
         return out_list
 
+
 class MSELoss(nn.MSELoss):
-    def __init__(self, size_average=None, reduce=None, reduction: str = "mean", factor: float = 1) -> None:
+    def __init__(
+        self, size_average=None, reduce=None, reduction: str = "mean", factor: float = 1
+    ) -> None:
         super(MSELoss, self).__init__(size_average, reduce, reduction)
         self.factors = {self.__class__.__name__: factor}
         self.labels = [self.__class__.__name__]
@@ -80,8 +83,15 @@ def forward(self, **kwargs) -> Tuple[Tensor, Tensor]:
         out = super().forward(kwargs["input"], kwargs["target"])
         return out, out * self.factors[self.__class__.__name__]
 
+
 class BinaryCrossEntropy(nn.BCELoss):
-    def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None:
+    def __init__(
+        self,
+        weight: Optional[Tensor] = None,
+        size_average=None,
+        reduce=None,
+        reduction: str = "mean",
+    ) -> None:
         super().__init__(weight, size_average, reduce, reduction)
         self.factors = {self.__class__.__name__: factor}
         self.labels = [self.__class__.__name__]
@@ -91,7 +101,6 @@ def forward(self, **kwargs) -> Tuple[Tensor, Tensor]:
         return out, out * self.factors[self.__class__.__name__]
 
 
-
 class ELBO(LighTorchLoss):
     """
     # Variational Autoencoder Loss:
@@ -103,8 +112,7 @@ def __init__(self, beta: float, reconstruction_criterion: LighTorchLoss) -> None
         factors = {"KL Divergence": beta}
         factors.update(reconstruction_criterion.factors)
         super().__init__(
-            labels = ["KL Divergence"] + reconstruction_criterion.labels,
-            factors = factors
+            labels=["KL Divergence"] + reconstruction_criterion.labels, factors=factors
         )
         self.L_recons = reconstruction_criterion
 
@@ -179,6 +187,7 @@ def forward(self, **kwargs) -> Tensor:
         )
         return out, self.factors[self.__class__.__name__] * out
 
+
 class PeakSignalNoiseRatio(LighTorchLoss):
     """
     forward (input, target)
@@ -192,6 +201,7 @@ def forward(self, **kwargs) -> Tensor:
         out = F.psnr(kwargs["input"], kwargs["target"], self.max)
         return out, out * self.factors[self.__class__.__name__]
 
+
 class TV(LighTorchLoss):
     """
     # Total Variance (TV)
@@ -205,6 +215,7 @@ def forward(self, **kwargs) -> Tensor:
         out = F.total_variance(kwargs["input"])
         return out, out * self.factors[self.__class__.__name__]
 
+
 class LagrangianFunctional(LighTorchLoss):
     """
     Creates a lagrangian function of the form:
diff --git a/lightorch/training/supervised.py b/lightorch/training/supervised.py
index 50bc31e..f682c7e 100644
--- a/lightorch/training/supervised.py
+++ b/lightorch/training/supervised.py
@@ -170,7 +170,7 @@ def _configure_optimizer(self) -> Optimizer:
         params = self.get_param_groups()
         if params is not None:
             if isinstance(self.optimizer, str):
-                if valid:=VALID_OPTIMIZERS.get(self.optimizer, None):
+                if valid := VALID_OPTIMIZERS.get(self.optimizer, None):
                     return valid(params)
                 raise TypeError("Not valid optimizer")
             elif isinstance(self.optimizer, Optimizer):
@@ -197,9 +197,13 @@ def _configure_scheduler(self, optimizer: Optimizer) -> LRScheduler:
                     self.trainer.estimated_stepping_batches
                 )
                 else:
-                    raise ValueError(f'Scheduler kwargs not defined for {self.scheduler}')
+                    raise ValueError(
+                        f"Scheduler kwargs not defined for {self.scheduler}"
+                    )
             if self.scheduler_kwargs is not None:
-                return VALID_SCHEDULERS[self.scheduler](optimizer, **self.scheduler_kwargs)
+                return VALID_SCHEDULERS[self.scheduler](
+                    optimizer, **self.scheduler_kwargs
+                )
             else:
                 return self.scheduler(optimizer)
 