Commit
Jorgedavyd committed May 29, 2024
2 parents 07fa4c6 + a8baf15 commit feab327
Showing 13 changed files with 493 additions and 190 deletions.
2 changes: 1 addition & 1 deletion lightorch/_version.py
@@ -1 +1 @@
__version__ = '{{VERSION_PLACEHOLDER}}'
__version__ = "{{VERSION_PLACEHOLDER}}"
2 changes: 2 additions & 0 deletions lightorch/nn/complex.py
@@ -1,6 +1,7 @@
from torch import nn, Tensor
from copy import deepcopy


class Complex(nn.Module):
"""
# Complex
@@ -15,4 +16,5 @@ def __init__(self, module: nn.Module) -> None:
def forward(self, x: Tensor) -> Tensor:
return self.Re_mod(x.real) + 1j * self.Im_mod(x.imag)


__all__ = ["Complex"]
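
For reference, a minimal usage sketch of the Complex wrapper above, which applies independent copies of a real-valued module to the real and imaginary parts of its input (the import path and the example shapes are assumptions):

import torch
from torch import nn
from lightorch.nn.complex import Complex

# Wrap a real-valued layer so it accepts complex inputs: one copy of the
# module processes x.real, another processes x.imag, and the results are
# recombined as real + 1j * imag.
layer = Complex(nn.Linear(16, 8))
z = torch.randn(4, 16) + 1j * torch.randn(4, 16)
out = layer(z)  # complex tensor of shape (4, 8)
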
87 changes: 50 additions & 37 deletions lightorch/nn/criterions.py
@@ -4,46 +4,51 @@
from . import functional as F
from itertools import chain


def _merge_dicts(dicts: Sequence[Dict[str, float]]) -> Dict[str, float]:
out = dict()
for dict_ in dicts:
out.update(dict_)
return out


class _Base(nn.Module):
def __init__(
self,
labels: Sequence[str] | str,
factors: Dict[str, float] | Sequence[Dict[str, float]],
) -> None:
super().__init__()
if 'Overall' not in labels:
if "Overall" not in labels:
self.labels = labels.append("Overall")
self.factors = factors


class Loss(_Base):
def __init__(self, *loss) -> None:
super().__init__(
list(set([*chain.from_iterable([i.labels for i in loss])])),
_merge_dicts([i.factors for i in loss])
_merge_dicts([i.factors for i in loss]),
)
assert (len(self.loss)==len(self.factors)), 'Must have the same length of losses as factors'
assert len(self.loss) == len(
self.factors
), "Must have the same length of losses as factors"
self.loss = loss

def forward(self, **kwargs) -> Tuple[Tensor, ...]:
loss = 0
out_list = []

for loss in self.loss:
loss_, out_ = loss(**kwargs)
out_list.append(loss_)
loss += out_

out_list.append(loss)

return tuple(*out_list)


class ELBO(_Base):
"""
# Variational Autoencoder Loss:
@@ -64,9 +69,14 @@ def forward(self, **kwargs) -> Tuple[Tensor, ...]:
"""
input, target, logvar, mu
"""
L_recons, L_recons_out = self.L_recons(kwargs['input'], kwargs['target'])
L_recons, L_recons_out = self.L_recons(kwargs["input"], kwargs["target"])

L_kl = -0.5 * torch.sum(torch.log(kwargs['logvar']) - 1 + kwargs['logvar'] + torch.pow(kwargs['mu'], 2))
L_kl = -0.5 * torch.sum(
torch.log(kwargs["logvar"])
- 1
+ kwargs["logvar"]
+ torch.pow(kwargs["mu"], 2)
)

return (L_recons, L_kl, L_recons_out + self.beta * L_kl)
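
For context, the closed-form KL term of the ELBO for a Gaussian posterior N(mu, sigma^2) against a standard-normal prior is conventionally written as

D_KL = -\frac{1}{2} \sum \left( 1 + \log\sigma^2 - \mu^2 - \sigma^2 \right)

and the value returned above combines the reconstruction term with it as L_recons_out + beta * L_kl; how the logvar argument here maps onto sigma^2 is an assumption.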

@@ -76,8 +86,13 @@ class StyleLoss(_Base):
"""
forward (input, target, feature_extractor: bool = True)
"""
def __init__(self, feature_extractor, sample_tensor: Tensor, factor: float = 1e-3) -> None:
super().__init__(labels=[self.__class__.__name__], factors={self.__class__.__name__: factor})

def __init__(
self, feature_extractor, sample_tensor: Tensor, factor: float = 1e-3
) -> None:
super().__init__(
labels=[self.__class__.__name__], factors={self.__class__.__name__: factor}
)
self.feature_extractor = feature_extractor

F_p: List[int] = []
@@ -90,10 +105,10 @@ def __init__(self, feature_extractor, sample_tensor: Tensor, factor: float = 1e-

def forward(self, **kwargs) -> Tuple[Tensor, ...]:
out = F.style_loss(
kwargs['input'],
kwargs['target'],
kwargs["input"],
kwargs["target"],
self.F_p,
self.feature_extractor if kwargs.get('feature_extractor', True) else None,
self.feature_extractor if kwargs.get("feature_extractor", True) else None,
)
return out, self.factors[self.__class__.__name__] * out
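
For context, style losses of this kind are conventionally built on Gram matrices of feature-extractor activations (Gatys et al.); that F.style_loss follows this exact formulation, with F_p holding per-layer normalization constants, is an assumption:

G^l_{ij}(x) = \sum_k \phi^l(x)_{ik} \phi^l(x)_{jk}
L_{style} = \sum_l \frac{1}{F_p^l} \left\| G^l(input) - G^l(target) \right\|_F^2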

@@ -103,11 +118,11 @@ class PerceptualLoss(_Base):
"""
forward (input, target, feature_extractor: bool = True)
"""
def __init__(self, feature_extractor, sample_tensor: Tensor, factor: float = 1e-3) -> None:
super().__init__(
[self.__class__.__name__],
{self.__class__.__name__: factor}
)

def __init__(
self, feature_extractor, sample_tensor: Tensor, factor: float = 1e-3
) -> None:
super().__init__([self.__class__.__name__], {self.__class__.__name__: factor})
self.feature_extractor = feature_extractor
N_phi_p: List[int] = []

@@ -119,12 +134,12 @@ def __init__(self, feature_extractor, sample_tensor: Tensor, factor: float = 1e-

def forward(self, **kwargs) -> Tensor:
out = F.perceptual_loss(
kwargs['input'],
kwargs['target'],
kwargs["input"],
kwargs["target"],
self.N_phi_p,
self.feature_extractor if kwargs.get('feature_extractor', True) else None,
self.feature_extractor if kwargs.get("feature_extractor", True) else None,
)
return out, self.factors[self.__class__.__name__]*out
return out, self.factors[self.__class__.__name__] * out
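
Similarly, perceptual losses are conventionally distances between feature-extractor activations; assuming F.perceptual_loss follows the standard formulation, with N_phi_p as per-layer normalization:

L_{perc} = \sum_l \frac{1}{N_{\phi,p}^l} \left\| \phi^l(input) - \phi^l(target) \right\|_2^2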


# pnsr
@@ -134,34 +149,32 @@ class PeakNoiseSignalRatio(_Base):
"""
forward (input, target)
"""

def __init__(self, max: float, factor: float = 1) -> None:
super().__init__(
[self.__class__.__name__],
{self.__class__.__name__: factor}
)
super().__init__([self.__class__.__name__], {self.__class__.__name__: factor})
self.max = max

def forward(self, **kwargs) -> Tensor:
out = F.psnr(kwargs['input'], kwargs['target'], self.max)
out = F.psnr(kwargs["input"], kwargs["target"], self.max)
return out, out * self.factors[self.__class__.__name__]
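
For reference, the standard definition that F.psnr presumably implements, with MAX taken from self.max:

PSNR = 10 \log_{10} \left( \frac{MAX^2}{MSE(input, target)} \right)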



# Total variance


class TV(nn.Module):
"""
# Total Variance (TV)
forward (input)
"""

def __init__(self, factor: float = 1):
super().__init__(
[self.__class__.__name__],
{self.__class__.__name__: factor}
)
super().__init__([self.__class__.__name__], {self.__class__.__name__: factor})

def forward(self, **kwargs) -> Tensor:
out = F.total_variance(kwargs['input'])
return out, out*self.factors[self.__class__.__name__]
out = F.total_variance(kwargs["input"])
return out, out * self.factors[self.__class__.__name__]
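
For reference, the anisotropic total variation that F.total_variance presumably computes over an image batch:

TV(x) = \sum_{i,j} \left( |x_{i+1,j} - x_{i,j}| + |x_{i,j+1} - x_{i,j}| \right)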


# lambda
class LagrangianFunctional(_Base):
@@ -205,5 +218,5 @@ def forward(self, out: Tensor, target: Tensor) -> Tensor:
"PeakNoiseSignalRatio",
"StyleLoss",
"PerceptualLoss",
"Loss"
"Loss",
]
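
A sketch of how these criteria are meant to compose through Loss; the argument names follow the forward(**kwargs) signatures above, and this shows the intended calling pattern rather than verified behavior:

import torch
from lightorch.nn.criterions import Loss, PeakNoiseSignalRatio, TV

# Each term returns (raw_value, factor * raw_value); the composite sums the
# weighted parts into the final "Overall" entry.
criterion = Loss(PeakNoiseSignalRatio(max=1.0, factor=0.5), TV(factor=1e-4))
pred, target = torch.rand(2, 3, 32, 32), torch.rand(2, 3, 32, 32)
losses = criterion(input=pred, target=target)
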
84 changes: 69 additions & 15 deletions lightorch/nn/fourier.py
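
For context, the layers in this file implement convolution as elementwise multiplication in the frequency domain via the convolution theorem,

f * g = \mathcal{F}^{-1} \left( \mathcal{F}(f) \cdot \mathcal{F}(g) \right)

which is presumably what the pre_fft and post_ifft flags below toggle on the input and output transforms.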
@@ -7,6 +7,7 @@
import torch.nn.functional as f
from typing import Tuple


class _FourierConvNd(nn.Module):
def __init__(
self,
@@ -33,14 +34,16 @@ def __init__(
self.ifft = lambda x: ifftn(x, dim=(-i for i in range(1, len(kernel_size))))
else:
self.ifft = False

if out_channels == in_channels:
self.one = None
else:
self.one = torch.ones(out_channels, in_channels, 1, 1) + 1j*0
self.one = torch.ones(out_channels, in_channels, 1, 1) + 1j * 0

self.eps = eps
self.weight = nn.Parameter(torch.empty(out_channels, *kernel_size, **self.factory_kwargs))
self.weight = nn.Parameter(
torch.empty(out_channels, *kernel_size, **self.factory_kwargs)
)

if bias:
self.bias = nn.Parameter(torch.empty(out_channels, **self.factory_kwargs))
@@ -65,8 +68,30 @@ def _init_parameters(self) -> None:


class _FourierDeconvNd(_FourierConvNd):
def __init__(self, in_channels: int, out_channels: int, *kernel_size, bias: bool = True, eps: float = 0.00001, pre_fft: bool = True, post_ifft: bool = False, device=None, dtype=None) -> None:
super().__init__(in_channels, out_channels, *kernel_size, bias=bias, eps=eps, pre_fft=pre_fft, post_ifft=post_ifft, device=device, dtype=dtype)
def __init__(
self,
in_channels: int,
out_channels: int,
*kernel_size,
bias: bool = True,
eps: float = 0.00001,
pre_fft: bool = True,
post_ifft: bool = False,
device=None,
dtype=None,
) -> None:
super().__init__(
in_channels,
out_channels,
*kernel_size,
bias=bias,
eps=eps,
pre_fft=pre_fft,
post_ifft=post_ifft,
device=device,
dtype=dtype,
)


class FourierConv1d(_FourierConvNd):
def __init__(self, *args, **kwargs) -> None:
@@ -78,9 +103,12 @@ def forward(self, input: Tensor) -> Tensor:
if self.padding is not None:
out = F.fourierconv1d(input, self.one, self.weight, self.bias)
else:
out = F.fourierconv1d(f.pad(
input, self.padding, mode = 'constant', value = 0
), self.one, self.weight, self.bias)
out = F.fourierconv1d(
f.pad(input, self.padding, mode="constant", value=0),
self.one,
self.weight,
self.bias,
)
if self.ifft:
return self.ifft(out)
return out
@@ -96,8 +124,13 @@ def forward(self, input: Tensor) -> Tensor:
if self.padding is not None:
out = F.fourierconv2d(input, self.one, self.weight, self.bias)
else:
out = F.fourierconv2d(f.pad(input, self.padding, 'constant', value = 0), self.one, self.weight, self.bias)

out = F.fourierconv2d(
f.pad(input, self.padding, "constant", value=0),
self.one,
self.weight,
self.bias,
)

if self.ifft:
return self.ifft(out)
return out
@@ -113,11 +146,17 @@ def forward(self, input: Tensor) -> Tensor:
if self.padding is not None:
out = F.fourierconv3d(input, self.one, self.weight, self.bias)
else:
out = F.fourierconv3d(f.pad(input, self.padding, 'constant', value = 0), self.one, self.weight, self.bias)
out = F.fourierconv3d(
f.pad(input, self.padding, "constant", value=0),
self.one,
self.weight,
self.bias,
)
if self.ifft:
return self.ifft(out)
return out


class FourierDeconv1d(_FourierDeconvNd):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@@ -128,7 +167,12 @@ def forward(self, input: Tensor) -> Tensor:
if self.padding is not None:
out = F.fourierdeconv1d(input, self.one, self.weight, self.bias)
else:
out = F.fourierdeconv1d(f.pad(input, self.padding, 'constant', value = 0), self.one, self.weight, self.bias)
out = F.fourierdeconv1d(
f.pad(input, self.padding, "constant", value=0),
self.one,
self.weight,
self.bias,
)
if self.ifft:
return self.ifft(out)
return out
@@ -144,7 +188,12 @@ def forward(self, input: Tensor) -> Tensor:
if self.padding is not None:
out = F.fourierdeconv2d(input, self.one, self.weight, self.bias)
else:
out = F.fourierdeconv2d(f.pad(input, self.padding, 'constant', value = 0), self.one, self.weight, self.bias)
out = F.fourierdeconv2d(
f.pad(input, self.padding, "constant", value=0),
self.one,
self.weight,
self.bias,
)
if self.ifft:
return self.ifft(out)
return out
@@ -160,7 +209,12 @@ def forward(self, input: Tensor) -> Tensor:
if self.padding is not None:
out = F.fourierdeconv3d(input, self.one, self.weight, self.bias)
else:
out = F.fourierdeconv3d(f.pad(input, self.padding, 'constant', value = 0), self.one, self.weight, self.bias)
out = F.fourierdeconv3d(
f.pad(input, self.padding, "constant", value=0),
self.one,
self.weight,
self.bias,
)
if self.ifft:
return self.ifft(out)
return out
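
A minimal usage sketch for these layers, inferred from the _FourierConvNd signature above (in_channels, out_channels, *kernel_size, ...); the argument layout and shapes are assumptions, not verified behavior:

import torch
from lightorch.nn.fourier import FourierConv2d

# 3 -> 3 channels with a 32x32 frequency-domain kernel; matching channel
# counts keeps self.one unset. pre_fft transforms the input, post_ifft
# brings the output back to the spatial domain.
conv = FourierConv2d(3, 3, 32, 32, pre_fft=True, post_ifft=True)
x = torch.randn(1, 3, 32, 32)
y = conv(x)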
