Automatically formatted with black
Jorgedavyd committed Oct 8, 2024
1 parent 313265e commit 664cfa6
Showing 2 changed files with 25 additions and 10 deletions.
25 changes: 18 additions & 7 deletions lightorch/nn/criterions.py
@@ -32,8 +32,8 @@ def __init__(self, *loss) -> None:
             loss
         ), "Not valid input classes, each should be different."
         super().__init__(
-            labels = list(set([*chain.from_iterable([i.labels for i in loss])])),
-            factors = _merge_dicts([i.factors for i in loss]),
+            labels=list(set([*chain.from_iterable([i.labels for i in loss])])),
+            factors=_merge_dicts([i.factors for i in loss]),
         )
         self.loss = loss
 
@@ -48,8 +48,11 @@ def forward(self, **kwargs) -> Tuple[Tensor, ...]:
         out_list = tuple(out_list)
         return out_list
 
+
 class MSELoss(nn.MSELoss):
-    def __init__(self, size_average=None, reduce=None, reduction: str = "mean", factor: float = 1) -> None:
+    def __init__(
+        self, size_average=None, reduce=None, reduction: str = "mean", factor: float = 1
+    ) -> None:
         super(MSELoss, self).__init__(size_average, reduce, reduction)
         self.factors = {self.__class__.__name__: factor}
         self.labels = [self.__class__.__name__]
@@ -80,8 +83,15 @@ def forward(self, **kwargs) -> Tuple[Tensor, Tensor]:
         out = super().forward(kwargs["input"], kwargs["target"])
         return out, out * self.factors[self.__class__.__name__]
 
+
 class BinaryCrossEntropy(nn.BCELoss):
-    def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None:
+    def __init__(
+        self,
+        weight: Optional[Tensor] = None,
+        size_average=None,
+        reduce=None,
+        reduction: str = "mean",
+    ) -> None:
         super().__init__(weight, size_average, reduce, reduction)
         self.factors = {self.__class__.__name__: factor}
         self.labels = [self.__class__.__name__]
@@ -91,7 +101,6 @@ def forward(self, **kwargs) -> Tuple[Tensor, Tensor]:
         return out, out * self.factors[self.__class__.__name__]
 
 
-
 class ELBO(LighTorchLoss):
     """
     # Variational Autoencoder Loss:
@@ -103,8 +112,7 @@ def __init__(self, beta: float, reconstruction_criterion: LighTorchLoss) -> None:
         factors = {"KL Divergence": beta}
         factors.update(reconstruction_criterion.factors)
         super().__init__(
-            labels = ["KL Divergence"] + reconstruction_criterion.labels,
-            factors = factors
+            labels=["KL Divergence"] + reconstruction_criterion.labels, factors=factors
         )
 
         self.L_recons = reconstruction_criterion
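For reference, the quantity this class evaluates is the standard beta-weighted ELBO, with the KL term carrying the "KL Divergence" factor beta registered above (the formulation below is the usual VAE convention; the exact sign and reduction choices are an assumption, not visible in this diff):

    L_total = L_recons + beta * D_KL( q(z|x) || p(z) )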
@@ -179,6 +187,7 @@ def forward(self, **kwargs) -> Tensor:
         )
         return out, self.factors[self.__class__.__name__] * out
 
+
 class PeakSignalNoiseRatio(LighTorchLoss):
     """
     forward (input, target)
@@ -192,6 +201,7 @@ def forward(self, **kwargs) -> Tensor:
         out = F.psnr(kwargs["input"], kwargs["target"], self.max)
         return out, out * self.factors[self.__class__.__name__]
 
+
 class TV(LighTorchLoss):
     """
     # Total Variance (TV)
@@ -205,6 +215,7 @@ def forward(self, **kwargs) -> Tensor:
         out = F.total_variance(kwargs["input"])
         return out, out * self.factors[self.__class__.__name__]
 
+
 class LagrangianFunctional(LighTorchLoss):
     """
     Creates a lagrangian function of the form:
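Throughout this file, forward(**kwargs) returns the pair (raw loss, factor-weighted loss). A minimal usage sketch under that assumption (the import path is inferred from the file path above, and the keyword names follow the forward calls in the diff; the sketch uses MSELoss because BinaryCrossEntropy's reformatted __init__ still references factor without declaring it as a parameter):

import torch
from lightorch.nn.criterions import MSELoss  # module path inferred from the file name above

criterion = MSELoss(factor=0.5)  # factor scales the second (weighted) output
input = torch.randn(8, 3, requires_grad=True)
target = torch.randn(8, 3)

raw, weighted = criterion(input=input, target=target)  # forward(**kwargs) -> (out, out * factor)
weighted.backward()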
10 changes: 7 additions & 3 deletions lightorch/training/supervised.py
@@ -170,7 +170,7 @@ def _configure_optimizer(self) -> Optimizer:
         params = self.get_param_groups()
         if params is not None:
             if isinstance(self.optimizer, str):
-                if valid:=VALID_OPTIMIZERS.get(self.optimizer, None):
+                if valid := VALID_OPTIMIZERS.get(self.optimizer, None):
                     return valid(params)
                 raise TypeError("Not valid optimizer")
             elif isinstance(self.optimizer, Optimizer):
@@ -197,9 +197,13 @@ def _configure_scheduler(self, optimizer: Optimizer) -> LRScheduler:
                     self.trainer.estimated_stepping_batches
                 )
             else:
-                raise ValueError(f'Scheduler kwargs not defined for {self.scheduler}')
+                raise ValueError(
+                    f"Scheduler kwargs not defined for {self.scheduler}"
+                )
         if self.scheduler_kwargs is not None:
-            return VALID_SCHEDULERS[self.scheduler](optimizer, **self.scheduler_kwargs)
+            return VALID_SCHEDULERS[self.scheduler](
+                optimizer, **self.scheduler_kwargs
+            )
         else:
             return self.scheduler(optimizer)
 
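The string branch of _configure_optimizer resolves the optimizer through the VALID_OPTIMIZERS registry, and _configure_scheduler does the same with VALID_SCHEDULERS; the walrus operator binds the lookup result and tests it in a single expression. A self-contained sketch of that pattern with hypothetical registry contents (the diff does not show the library's actual entries):

from functools import partial

from torch.optim import SGD, Adam

# Hypothetical contents; the real VALID_OPTIMIZERS is defined elsewhere in lightorch.
VALID_OPTIMIZERS = {
    "adam": partial(Adam, lr=1e-3),
    "sgd": partial(SGD, lr=1e-2, momentum=0.9),
}

def configure_optimizer(name: str, params):
    # Bind the lookup result and test it in one expression, as in the diff above.
    if valid := VALID_OPTIMIZERS.get(name, None):
        return valid(params)
    raise TypeError("Not valid optimizer")

Calling configure_optimizer("adam", ...) with a model's parameters returns a ready Adam instance, while an unknown key raises TypeError, matching the branch in the diff.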
