diff --git a/lightorch/_version.py b/lightorch/_version.py
index 89ac199..9be0efa 100644
--- a/lightorch/_version.py
+++ b/lightorch/_version.py
@@ -1 +1 @@
-__version__ = '{{VERSION_PLACEHOLDER}}'
+__version__ = "{{VERSION_PLACEHOLDER}}"
diff --git a/lightorch/nn/complex.py b/lightorch/nn/complex.py
index 812c021..849f1ca 100644
--- a/lightorch/nn/complex.py
+++ b/lightorch/nn/complex.py
@@ -1,6 +1,7 @@
 from torch import nn, Tensor
 from copy import deepcopy
 
+
 class Complex(nn.Module):
     """
     # Complex
@@ -15,4 +16,5 @@ def __init__(self, module: nn.Module) -> None:
     def forward(self, x: Tensor) -> Tensor:
         return self.Re_mod(x.real) + 1j * self.Im_mod(x.imag)
 
+
 __all__ = ["Complex"]
diff --git a/lightorch/nn/criterions.py b/lightorch/nn/criterions.py
index f3df6ee..bdfb251 100644
--- a/lightorch/nn/criterions.py
+++ b/lightorch/nn/criterions.py
@@ -4,12 +4,14 @@
 from . import functional as F
 from itertools import chain
 
+
 def _merge_dicts(dicts: Sequence[Dict[str, float]]) -> Dict[str, float]:
     out = dict()
     for dict_ in dicts:
         out.update(dict_)
     return out
 
+
 class _Base(nn.Module):
     def __init__(
         self,
@@ -17,33 +19,36 @@ def __init__(
         factors: Dict[str, float] | Sequence[Dict[str, float]],
     ) -> None:
         super().__init__()
-        if 'Overall' not in labels:
+        if "Overall" not in labels:
             self.labels = labels.append("Overall")
         self.factors = factors
 
+
 class Loss(_Base):
     def __init__(self, *loss) -> None:
         super().__init__(
             list(set([*chain.from_iterable([i.labels for i in loss])])),
-            _merge_dicts([i.factors for i in loss])
+            _merge_dicts([i.factors for i in loss]),
         )
-        assert (len(self.loss)==len(self.factors)), 'Must have the same length of losses as factors'
+        assert len(self.loss) == len(
+            self.factors
+        ), "Must have the same length of losses as factors"
         self.loss = loss
 
     def forward(self, **kwargs) -> Tuple[Tensor, ...]:
         loss = 0
         out_list = []
-        
+
         for loss in self.loss:
             loss_, out_ = loss(**kwargs)
             out_list.append(loss_)
             loss += out_
-        
+
         out_list.append(loss)
-        
+
         return tuple(*out_list)
-        
-        
+
+
 class ELBO(_Base):
     """
     # Variational Autoencoder Loss:
@@ -64,9 +69,14 @@ def forward(self, **kwargs) -> Tuple[Tensor, ...]:
         """
         input, target, logvar, mu
         """
-        L_recons, L_recons_out = self.L_recons(kwargs['input'], kwargs['target'])
+        L_recons, L_recons_out = self.L_recons(kwargs["input"], kwargs["target"])
 
-        L_kl = -0.5 * torch.sum(torch.log(kwargs['logvar']) - 1 + kwargs['logvar'] + torch.pow(kwargs['mu'], 2))
+        L_kl = -0.5 * torch.sum(
+            torch.log(kwargs["logvar"])
+            - 1
+            + kwargs["logvar"]
+            + torch.pow(kwargs["mu"], 2)
+        )
 
         return (L_recons, L_kl, L_recons_out + self.beta * L_kl)
 
@@ -76,8 +86,13 @@ class StyleLoss(_Base):
     """
     forward (input, target, feature_extractor: bool = True)
     """
-    def __init__(self, feature_extractor, sample_tensor: Tensor, factor: float = 1e-3) -> None:
-        super().__init__(labels=[self.__class__.__name__], factors={self.__class__.__name__: factor})
+
+    def __init__(
+        self, feature_extractor, sample_tensor: Tensor, factor: float = 1e-3
+    ) -> None:
+        super().__init__(
+            labels=[self.__class__.__name__], factors={self.__class__.__name__: factor}
+        )
         self.feature_extractor = feature_extractor
 
         F_p: List[int] = []
@@ -90,10 +105,10 @@ def __init__(self, feature_extractor, sample_tensor: Tensor, factor: float = 1e-
 
     def forward(self, **kwargs) -> Tuple[Tensor, ...]:
         out = F.style_loss(
-            kwargs['input'],
-            kwargs['target'],
+            kwargs["input"],
+            kwargs["target"],
             self.F_p,
-            self.feature_extractor if kwargs.get('feature_extractor', True) else None,
+            self.feature_extractor if kwargs.get("feature_extractor", True) else None,
         )
         return out, self.factors[self.__class__.__name__] * out
 
@@ -103,11 +118,11 @@ class PerceptualLoss(_Base):
     """
     forward (input, target, feature_extractor: bool = True)
     """
-    def __init__(self, feature_extractor, sample_tensor: Tensor, factor: float = 1e-3) -> None:
-        super().__init__(
-            [self.__class__.__name__],
-            {self.__class__.__name__: factor}
-        )
+
+    def __init__(
+        self, feature_extractor, sample_tensor: Tensor, factor: float = 1e-3
+    ) -> None:
+        super().__init__([self.__class__.__name__], {self.__class__.__name__: factor})
         self.feature_extractor = feature_extractor
 
         N_phi_p: List[int] = []
@@ -119,12 +134,12 @@ def __init__(self, feature_extractor, sample_tensor: Tensor, factor: float = 1e-
 
     def forward(self, **kwargs) -> Tensor:
         out = F.perceptual_loss(
-            kwargs['input'],
-            kwargs['target'],
+            kwargs["input"],
+            kwargs["target"],
             self.N_phi_p,
-            self.feature_extractor if kwargs.get('feature_extractor', True) else None,
+            self.feature_extractor if kwargs.get("feature_extractor", True) else None,
         )
-        return out, self.factors[self.__class__.__name__]*out
+        return out, self.factors[self.__class__.__name__] * out
 
 
 # pnsr
@@ -134,19 +149,19 @@ class PeakNoiseSignalRatio(_Base):
     """
     forward (input, target)
     """
+
     def __init__(self, max: float, factor: float = 1) -> None:
-        super().__init__(
-            [self.__class__.__name__],
-            {self.__class__.__name__: factor}
-        )
+        super().__init__([self.__class__.__name__], {self.__class__.__name__: factor})
         self.max = max
 
     def forward(self, **kwargs) -> Tensor:
-        out = F.psnr(kwargs['input'], kwargs['target'], self.max)
+        out = F.psnr(kwargs["input"], kwargs["target"], self.max)
         return out, out * self.factors[self.__class__.__name__]
-    
+
+
 # Total variance
+
 class TV(nn.Module):
     """
     # Total Variance (TV)
@@ -154,14 +169,12 @@ class TV(nn.Module):
     """
 
     def __init__(self, factor: float = 1):
-        super().__init__(
-            [self.__class__.__name__],
-            {self.__class__.__name__: factor}
-        )
+        super().__init__([self.__class__.__name__], {self.__class__.__name__: factor})
 
     def forward(self, **kwargs) -> Tensor:
-        out = F.total_variance(kwargs['input'])
-        return out, out*self.factors[self.__class__.__name__]
+        out = F.total_variance(kwargs["input"])
+        return out, out * self.factors[self.__class__.__name__]
+
 
 # lambda
 class LagrangianFunctional(_Base):
@@ -205,5 +218,5 @@ def forward(self, out: Tensor, target: Tensor) -> Tensor:
     "PeakNoiseSignalRatio",
     "StyleLoss",
     "PerceptualLoss",
-    "Loss"
+    "Loss",
 ]
diff --git a/lightorch/nn/fourier.py b/lightorch/nn/fourier.py
index 84873ba..aa42011 100644
--- a/lightorch/nn/fourier.py
+++ b/lightorch/nn/fourier.py
@@ -7,6 +7,7 @@
 import torch.nn.functional as f
 from typing import Tuple
 
+
 class _FourierConvNd(nn.Module):
     def __init__(
         self,
@@ -33,14 +34,16 @@ def __init__(
             self.ifft = lambda x: ifftn(x, dim=(-i for i in range(1, len(kernel_size))))
         else:
             self.ifft = False
-        
+
         if out_channels == in_channels:
             self.one = None
         else:
-            self.one = torch.ones(out_channels, in_channels, 1, 1) + 1j*0
-        
+            self.one = torch.ones(out_channels, in_channels, 1, 1) + 1j * 0
+
         self.eps = eps
-        self.weight = nn.Parameter(torch.empty(out_channels, *kernel_size, **self.factory_kwargs))
+        self.weight = nn.Parameter(
+            torch.empty(out_channels, *kernel_size, **self.factory_kwargs)
+        )
 
         if bias:
             self.bias = nn.Parameter(torch.empty(out_channels, **self.factory_kwargs))
@@ -65,8 +68,30 @@ def _init_parameters(self) -> None:
 
 
 class _FourierDeconvNd(_FourierConvNd):
-    def __init__(self, in_channels: int, out_channels: int, *kernel_size, bias: bool = True, eps: float = 0.00001, pre_fft: bool = True, post_ifft: bool = False, device=None, dtype=None) -> None:
-        super().__init__(in_channels, out_channels, *kernel_size, bias=bias, eps=eps, pre_fft=pre_fft, post_ifft=post_ifft, device=device, dtype=dtype)
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        *kernel_size,
+        bias: bool = True,
+        eps: float = 0.00001,
+        pre_fft: bool = True,
+        post_ifft: bool = False,
+        device=None,
+        dtype=None,
+    ) -> None:
+        super().__init__(
+            in_channels,
+            out_channels,
+            *kernel_size,
+            bias=bias,
+            eps=eps,
+            pre_fft=pre_fft,
+            post_ifft=post_ifft,
+            device=device,
+            dtype=dtype,
+        )
+
 
 class FourierConv1d(_FourierConvNd):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
@@ -78,9 +103,12 @@ def forward(self, input: Tensor) -> Tensor:
         if self.padding is not None:
             out = F.fourierconv1d(input, self.one, self.weight, self.bias)
         else:
-            out = F.fourierconv1d(f.pad(
-                input, self.padding, mode = 'constant', value = 0
-            ), self.one, self.weight, self.bias)
+            out = F.fourierconv1d(
+                f.pad(input, self.padding, mode="constant", value=0),
+                self.one,
+                self.weight,
+                self.bias,
+            )
         if self.ifft:
             return self.ifft(out)
         return out
@@ -96,8 +124,13 @@ def forward(self, input: Tensor) -> Tensor:
         if self.padding is not None:
             out = F.fourierconv2d(input, self.one, self.weight, self.bias)
         else:
-            out = F.fourierconv2d(f.pad(input, self.padding, 'constant', value = 0), self.one, self.weight, self.bias)
-        
+            out = F.fourierconv2d(
+                f.pad(input, self.padding, "constant", value=0),
+                self.one,
+                self.weight,
+                self.bias,
+            )
+
         if self.ifft:
             return self.ifft(out)
         return out
@@ -113,11 +146,17 @@ def forward(self, input: Tensor) -> Tensor:
         if self.padding is not None:
             out = F.fourierconv3d(input, self.one, self.weight, self.bias)
         else:
-            out = F.fourierconv3d(f.pad(input, self.padding, 'constant', value = 0), self.one, self.weight, self.bias)
+            out = F.fourierconv3d(
+                f.pad(input, self.padding, "constant", value=0),
+                self.one,
+                self.weight,
+                self.bias,
+            )
         if self.ifft:
             return self.ifft(out)
         return out
 
+
 class FourierDeconv1d(_FourierDeconvNd):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
@@ -128,7 +167,12 @@ def forward(self, input: Tensor) -> Tensor:
         if self.padding is not None:
             out = F.fourierdeconv1d(input, self.one, self.weight, self.bias)
         else:
-            out = F.fourierdeconv1d(f.pad(input, self.padding, 'constant', value = 0), self.one, self.weight, self.bias)
+            out = F.fourierdeconv1d(
+                f.pad(input, self.padding, "constant", value=0),
+                self.one,
+                self.weight,
+                self.bias,
+            )
         if self.ifft:
             return self.ifft(out)
         return out
@@ -144,7 +188,12 @@ def forward(self, input: Tensor) -> Tensor:
         if self.padding is not None:
             out = F.fourierdeconv2d(input, self.one, self.weight, self.bias)
         else:
-            out = F.fourierdeconv2d(f.pad(input, self.padding, 'constant', value = 0), self.one, self.weight, self.bias)
+            out = F.fourierdeconv2d(
+                f.pad(input, self.padding, "constant", value=0),
+                self.one,
+                self.weight,
+                self.bias,
+            )
         if self.ifft:
             return self.ifft(out)
         return out
@@ -160,7 +209,12 @@ def forward(self, input: Tensor) -> Tensor:
         if self.padding is not None:
             out = F.fourierdeconv3d(input, self.one, self.weight, self.bias)
         else:
-            out = F.fourierdeconv3d(f.pad(input, self.padding, 'constant', value = 0), self.one, self.weight, self.bias)
+            out = F.fourierdeconv3d(
+                f.pad(input, self.padding, "constant", value=0),
+                self.one,
+                self.weight,
+                self.bias,
+            )
         if self.ifft:
             return self.ifft(out)
         return out
diff --git a/lightorch/nn/functional.py b/lightorch/nn/functional.py
index eae18c5..863040f 100644
--- a/lightorch/nn/functional.py
+++ b/lightorch/nn/functional.py
@@ -5,24 +5,29 @@
 from lightning.pytorch import LightningModule
 from einops import rearrange
 
+
 def _fourierconvNd(x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor:
     # weight -> 1, 1, out channels, *kernel_size
-    x *= weight.reshape(1,1, *weight.shape) # Convolution in the fourier space
+    x *= weight.reshape(1, 1, *weight.shape)  # Convolution in the fourier space
     if bias is not None:
-        return x + bias.reshape(1,1,-1,1,1)
-    
+        return x + bias.reshape(1, 1, -1, 1, 1)
+
     return x
 
-def _fourierdeconvNd(x: Tensor, weight: Tensor, bias: Tensor | None, eps: float = 1e-5) -> Tensor:
+
+def _fourierdeconvNd(
+    x: Tensor, weight: Tensor, bias: Tensor | None, eps: float = 1e-5
+) -> Tensor:
     # weight -> 1, 1, out channels, *kernel_size
-    x /= (weight.reshape(1,1, *weight.shape) + eps) # Convolution in the fourier space
+    x /= weight.reshape(1, 1, *weight.shape) + eps  # Convolution in the fourier space
     if bias is not None:
-        return x + bias.reshape(1,1,-1,1,1)
-    
+        return x + bias.reshape(1, 1, -1, 1, 1)
+
     return x
 
+
 def fourierconv3d(x: Tensor, one: Tensor, weight: Tensor, bias: Tensor | None):
     """
     x (Tensor): batch size, channels, height, width
@@ -34,17 +39,24 @@ def fourierconv3d(x: Tensor, one: Tensor, weight: Tensor, bias: Tensor | None):
     """
     if one is not None:
         # Augment the channel dimension of the input
-        out = F.conv3d(x, one, None, 1) # one: (out_channel, in_channel, *kernel_size)
-    
+        out = F.conv3d(x, one, None, 1)  # one: (out_channel, in_channel, *kernel_size)
+
     # Rearrange tensors for Fourier convolution
-    out = rearrange(out, 'B C (f kd) (h kh) (w kw) -> B (f h w) C kd kh kw', kd=weight.shape[-3], kh=weight.shape[-2], kw=weight.shape[-1])
-    
+    out = rearrange(
+        out,
+        "B C (f kd) (h kh) (w kw) -> B (f h w) C kd kh kw",
+        kd=weight.shape[-3],
+        kh=weight.shape[-2],
+        kw=weight.shape[-1],
+    )
+
     out = _fourierconvNd(out, weight, bias)
 
-    out = rearrange(out, 'B (f h w) C kd kh kw -> B C (f kd) (h kh) (w kw)')
-    
+    out = rearrange(out, "B (f h w) C kd kh kw -> B C (f kd) (h kh) (w kw)")
+
     return out
 
+
 def fourierconv2d(x: Tensor, one: Tensor, weight: Tensor, bias: Tensor | None):
     """
     x (Tensor): batch size, channels, height, width
@@ -56,16 +68,22 @@ def fourierconv2d(x: Tensor, one: Tensor, weight: Tensor, bias: Tensor | None):
     """
     if one is not None:
         # Augment the channel dimension of the input
-        out = F.conv2d(x, one, None, 1) # one: (out_channel, in_channel, *kernel_size)
+        out = F.conv2d(x, one, None, 1)  # one: (out_channel, in_channel, *kernel_size)
 
-    out = rearrange(out, 'B C (h k1) (w k2) -> B (h w) C k1 k2', k1 = weight.shape[-2], k2 = weight.shape[-1])
+    out = rearrange(
+        out,
+        "B C (h k1) (w k2) -> B (h w) C k1 k2",
+        k1=weight.shape[-2],
+        k2=weight.shape[-1],
+    )
 
     out = _fourierconvNd(out, weight, bias)
 
-    out = rearrange(out, 'B (h w) C k1 k2 -> B C (h k1) (w k2)')
+    out = rearrange(out, "B (h w) C k1 k2 -> B C (h k1) (w k2)")
 
     return out
 
+
 def fourierconv1d(x: Tensor, one: Tensor, weight: Tensor, bias: Tensor | None):
     """
     x (Tensor): batch size, channels, sequence length
@@ -77,17 +95,20 @@ def fourierconv1d(x: Tensor, one: Tensor, weight: Tensor, bias: Tensor | None):
     """
     if one is not None:
         # Augment the channel dimension of the input
-        out = F.conv1d(x, one, None, 1) # one: (out_channel, in_channel, *kernel_size)
-    
-    out = rearrange(out, 'B C (l k) -> B l C k', k = weight.shape[-1])
+        out = F.conv1d(x, one, None, 1)  # one: (out_channel, in_channel, *kernel_size)
+
+    out = rearrange(out, "B C (l k) -> B l C k", k=weight.shape[-1])
 
     out = _fourierconvNd(out, weight, bias)
 
-    out = rearrange(out, 'B l C k -> B C (l k)')
+    out = rearrange(out, "B l C k -> B C (l k)")
 
     return out
 
-def fourierdeconv3d(x: Tensor, one: Tensor, weight: Tensor, bias: Tensor | None, eps: float = 1e-5):
+
+def fourierdeconv3d(
+    x: Tensor, one: Tensor, weight: Tensor, bias: Tensor | None, eps: float = 1e-5
+):
     """
     x (Tensor): batch size, channels, height, width
     weight (Tensor): out channels, *kernel_size
@@ -98,18 +119,27 @@ def fourierdeconv3d(x: Tensor, one: Tensor, weight: Tensor, bias: Tensor | None,
     """
     if one is not None:
         # Augment the channel dimension of the input
-        out = F.conv3d(x, one, None, 1) # one: (out_channel, in_channel, *kernel_size)
-    
+        out = F.conv3d(x, one, None, 1)  # one: (out_channel, in_channel, *kernel_size)
+
     # Rearrange tensors for Fourier convolution
-    out = rearrange(out, 'B C (f kd) (h kh) (w kw) -> B (f h w) C kd kh kw', kd=weight.shape[-3], kh=weight.shape[-2], kw=weight.shape[-1])
-    
+    out = rearrange(
+        out,
+        "B C (f kd) (h kh) (w kw) -> B (f h w) C kd kh kw",
+        kd=weight.shape[-3],
+        kh=weight.shape[-2],
+        kw=weight.shape[-1],
+    )
+
     out = _fourierdeconvNd(out, weight, bias, eps)
 
-    out = rearrange(out, 'B (f h w) C kd kh kw -> B C (f kd) (h kh) (w kw)')
-    
+    out = rearrange(out, "B (f h w) C kd kh kw -> B C (f kd) (h kh) (w kw)")
+
     return out
 
-def fourierdeconv2d(x: Tensor, one: Tensor, weight: Tensor, bias: Tensor | None, eps: float = 1e-5):
+
+def fourierdeconv2d(
+    x: Tensor, one: Tensor, weight: Tensor, bias: Tensor | None, eps: float = 1e-5
+):
     """
     x (Tensor): batch size, channels, height, width
     weight (Tensor): out channels, *kernel_size
@@ -120,17 +150,25 @@ def fourierdeconv2d(x: Tensor, one: Tensor, weight: Tensor, bias: Tensor | None,
     """
     if one is not None:
         # Augment the channel dimension of the input
-        out = F.conv2d(x, one, None, 1) # one: (out_channel, in_channel, *kernel_size)
+        out = F.conv2d(x, one, None, 1)  # one: (out_channel, in_channel, *kernel_size)
 
-    out = rearrange(out, 'B C (h k1) (w k2) -> B (h w) C k1 k2', k1 = weight.shape[-2], k2 = weight.shape[-1])
+    out = rearrange(
+        out,
+        "B C (h k1) (w k2) -> B (h w) C k1 k2",
+        k1=weight.shape[-2],
+        k2=weight.shape[-1],
+    )
 
     out = _fourierdeconvNd(out, weight, bias, eps)
 
-    out = rearrange(out, 'B (h w) C k1 k2 -> B C (h k1) (w k2)')
+    out = rearrange(out, "B (h w) C k1 k2 -> B C (h k1) (w k2)")
 
     return out
 
-def fourierdeconv1d(x: Tensor, one: Tensor, weight: Tensor, bias: Tensor | None, eps: float = 1e-5):
+
+def fourierdeconv1d(
+    x: Tensor, one: Tensor, weight: Tensor, bias: Tensor | None, eps: float = 1e-5
+):
     """
     x (Tensor): batch size, channels, sequence length
     weight (Tensor): out channels, kernel_size
@@ -141,47 +179,48 @@ def fourierdeconv1d(x: Tensor, one: Tensor, weight: Tensor, bias: Tensor | None,
     """
     if one is not None:
         # Augment the channel dimension of the input
-        out = F.conv1d(x, one, None, 1) # one: (out_channel, in_channel, *kernel_size)
-    
-    out = rearrange(out, 'B C (l k) -> B l C k', k = weight.shape[-1])
+        out = F.conv1d(x, one, None, 1)  # one: (out_channel, in_channel, *kernel_size)
+
+    out = rearrange(out, "B C (l k) -> B l C k", k=weight.shape[-1])
 
     out = _fourierdeconvNd(out, weight, bias, eps)
 
-    out = rearrange(out, 'B l C k -> B C (l k)')
+    out = rearrange(out, "B l C k -> B C (l k)")
 
     return out
 
 
 def _partialconvnd(
-        conv: F,
-        input: Tensor,
-        mask_in: Tensor,
-        weight: Tensor,
-        one_sum: Tensor, 
-        bias: Optional[Tensor],
-        stride,
-        padding,
-        dilation,
-        update_mask: bool = True
+    conv: F,
+    input: Tensor,
+    mask_in: Tensor,
+    weight: Tensor,
+    one_sum: Tensor,
+    bias: Optional[Tensor],
+    stride,
+    padding,
+    dilation,
+    update_mask: bool = True,
 ) -> Tuple[Tensor, Tensor] | Tensor:
-    
+
     with torch.no_grad():
         sum_m: Tensor = conv(
             mask_in,
             torch.ones_like(weight, requires_grad=False),
-            stride = stride,
-            padding = padding,
-            dilation = dilation
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
         )
 
     if update_mask:
         updated_mask = sum_m.clamp_max(1)
-    
+
     out = conv(input * mask_in, weight, None, stride, padding, dilation)
-    out *= (one_sum / sum_m)
+    out *= one_sum / sum_m
     out += bias
 
-    return (out, updated_mask) if update_mask else out
+    return (out, updated_mask) if update_mask else out
+
 
 def partialconv3d(
     input: Tensor,
@@ -195,7 +234,19 @@
     update_mask: bool = True,
 ) -> Union[Tuple[Tensor, Tensor], Tensor]:
 
-    return _partialconvnd(F.conv3d, input, mask_in, weight, one_sum, bias, stride, padding, dilation, update_mask)
+    return _partialconvnd(
+        F.conv3d,
+        input,
+        mask_in,
+        weight,
+        one_sum,
+        bias,
+        stride,
+        padding,
+        dilation,
+        update_mask,
+    )
+
 
 def partialconv2d(
     input: Tensor,
@@ -209,7 +260,19 @@
     update_mask: bool = True,
 ) -> Union[Tuple[Tensor, Tensor], Tensor]:
 
-    return _partialconvnd(F.conv2d, input, mask_in, weight, one_sum, bias, stride, padding, dilation, update_mask)
+    return _partialconvnd(
+        F.conv2d,
+        input,
+        mask_in,
+        weight,
+        one_sum,
+        bias,
+        stride,
+        padding,
+        dilation,
+        update_mask,
+    )
+
 
 def partialconv1d(
     input: Tensor,
@@ -223,7 +286,19 @@
     update_mask: bool = True,
 ) -> Union[Tuple[Tensor, Tensor], Tensor]:
 
-    return _partialconvnd(F.conv1d, input, mask_in, weight, one_sum, bias, stride, padding, dilation, update_mask)
+    return _partialconvnd(
+        F.conv1d,
+        input,
+        mask_in,
+        weight,
+        one_sum,
+        bias,
+        stride,
+        padding,
+        dilation,
+        update_mask,
+    )
+
 
 def residual_connection(
     x: Tensor,
@@ -232,13 +307,16 @@
 ) -> Tensor:
     return to_dim_layer(x) + sublayer(x)
 
+
 # Criterion functionals
 
+
 def psnr(input: Tensor, target: Tensor, max: float) -> Tensor:
     return 10 * torch.log10(
         torch.div(torch.pow(max, 2), torch.nn.functional.mse_loss(input, target))
     )
 
+
 def style_loss(
     input: Tensor,
     target: Tensor,
@@ -254,6 +332,7 @@
 
     return ((_style_forward(phi_input, phi_output)) / F_p).sum()
 
+
 def perceptual_loss(
     input: Tensor,
     target: Tensor,
@@ -276,9 +355,11 @@
         / N_phi_p
     ).sum()
 
+
 def change_dim(P: List[Tensor]) -> List[Tensor]:
     return [tensor.view(tensor.shape[0], tensor.shape[1], -1) for tensor in P]
 
+
 def _style_forward(input_list: List[Tensor], gt_list: List[Tensor]) -> List[Tensor]:
     return Tensor(
         [
@@ -287,10 +368,12 @@ def _style_forward(input_list: List[Tensor], gt_list: List[Tens
         ]
     )
 
+
 def total_variance(input: Tensor) -> Tensor:
     return torch.mean(torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:])) + torch.mean(
         torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :])
     )
 
+
 def kl_div(mu: Tensor, logvar: Tensor) -> Tensor:
     return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
diff --git a/lightorch/nn/transformer/attention.py b/lightorch/nn/transformer/attention.py
index 3c76ac4..96fb0dc 100644
--- a/lightorch/nn/transformer/attention.py
+++ b/lightorch/nn/transformer/attention.py
@@ -88,9 +88,7 @@ def forward(self, input: Tensor) -> Tensor:
 
 
 class CrossAttention(nn.Module):
-    def __init__(
-        self, attention: _AttentionBase, method: str = "i i c"
-    ) -> None:
+    def __init__(self, attention: _AttentionBase, method: str = "i i c") -> None:
         super().__init__()
         self.attention = attention
         self.method = method.lower()
@@ -99,11 +97,12 @@ def __init__(
             "c c i": lambda input, cross: self.attention(cross, cross, input),
             "i c c": lambda input, cross: self.attention(input, cross, cross),
         }
-        assert (method in self.valid), 'Not valid method'
+        assert method in self.valid, "Not valid method"
 
     def forward(self, input: Tensor, cross: Tensor) -> Tensor:
         return self.valid[self.method](input, cross)
 
+
 class GroupedQueryAttention(_AttentionBase):
     def __init__(
         self,
diff --git a/lightorch/nn/transformer/positional.py b/lightorch/nn/transformer/positional.py
index ef988a6..5790a67 100644
--- a/lightorch/nn/transformer/positional.py
+++ b/lightorch/nn/transformer/positional.py
@@ -123,7 +123,7 @@ def forward(self, x: Tensor) -> Tensor:
                 i / pow(10000, (2 * j) / embed_dim)
             )
         x += pos_embedding.unsqueeze(0)
-        
+
         return self.dropout(x)
diff --git a/lightorch/nn/transformer/transformer.py b/lightorch/nn/transformer/transformer.py
index 545dda0..530026a 100644
--- a/lightorch/nn/transformer/transformer.py
+++ b/lightorch/nn/transformer/transformer.py
@@ -10,6 +10,7 @@
 
 """
 
+
 class _Transformer(nn.Module):
     def __init__(self, self_attention, cross_attention, ffn, postnorm, prenorm) -> None:
         super().__init__()
@@ -20,40 +21,43 @@ def __init__(self, self_attention, cross_attention, ffn, postnorm, prenorm) -> N
         self.prenorm = prenorm if prenorm is not None else nn.Identity()
 
     def _apply_sublayer(self, input: Tensor, sublayer: nn.Module, *args) -> Tensor:
-        return residual_connection(input, lambda x: self.postnorm(sublayer(self.prenorm(x), *args)))
+        return residual_connection(
+            input, lambda x: self.postnorm(sublayer(self.prenorm(x), *args))
+        )
 
     def ffn(self, input: Tensor) -> Tensor:
         return self._apply_sublayer(input, self._ffn)
-    
+
     def cross_attention(self, input: Tensor, cross: Tensor, is_causal) -> Tensor:
         return self._apply_sublayer(input, self._cross_attention, cross, is_causal)
-    
+
     def self_attention(self, input: Tensor, is_causal: bool = False) -> Tensor:
         return self._apply_sublayer(input, self._self_attention, is_causal)
-    
+
 
 class TransformerCell(_Transformer):
     def __init__(
-            self, 
-            *,
-            self_attention: nn.Module = None, 
-            cross_attention: nn.Module = None, 
-            ffn: nn.Module = None,
-            prenorm: nn.Module = None,
-            postnorm: nn.Module = None
+        self,
+        *,
+        self_attention: nn.Module = None,
+        cross_attention: nn.Module = None,
+        ffn: nn.Module = None,
+        prenorm: nn.Module = None,
+        postnorm: nn.Module = None
     ) -> None:
         super().__init__(self_attention, cross_attention, ffn, postnorm, prenorm)
 
+
 class Transformer(nn.Module):
     def __init__(
-            self, 
-            embedding_layer: Optional[nn.Module] = None,
-            positional_encoding: Optional[nn.Module] = None,
-            encoder: Optional[nn.Module] = None,
-            decoder: Optional[nn.Module] = None,
-            fc: Optional[nn.Module] = None,
-            n_layers: int = 1
-        ) -> None:
+        self,
+        embedding_layer: Optional[nn.Module] = None,
+        positional_encoding: Optional[nn.Module] = None,
+        encoder: Optional[nn.Module] = None,
+        decoder: Optional[nn.Module] = None,
+        fc: Optional[nn.Module] = None,
+        n_layers: int = 1,
+    ) -> None:
         super().__init__()
         self.embedding = embedding_layer
         self.pe = positional_encoding
@@ -62,18 +66,18 @@ def __init__(
         else:
             self.encoder = False
         if decoder is not None:
-            self.decoder = nn.ModuleList([decoder for _ in range(n_layers)])
+            self.decoder = nn.ModuleList([decoder for _ in range(n_layers)])
         else:
             self.decoder = False
         self.n_layers = n_layers
         self.fc = fc
-    
+
     def forward(self, **kwargs) -> Tensor:
         if self.embedding is not None:
             out = self.embedding(**kwargs)
         if self.pe is not None:
             out = self.pe(out)
-        
+
         if self.encoder and self.decoder:
             hist: List = []
             for encoder in self.encoder:
@@ -81,47 +85,50 @@ def forward(self, **kwargs) -> Tensor:
                 hist.append(out)
 
             for cross, decoder in zip(hist, self.decoder):
-                out = decoder(**kwargs, cross = cross)
-        
+                out = decoder(**kwargs, cross=cross)
+
             out = self.fc(out)
 
         elif self.encoder:
             for encoder in self.encoder:
                 out = encoder(out)
-        
+
         else:
             for decoder in self.decoder:
                 out = decoder(out)
-        
+
         return out
 
+
 class CrossTransformer(nn.Module):
     def __init__(self, *cells, n_layers: int, fc: nn.Module) -> None:
-        assert (len(cells) == 2), 'Must be 2 transformer cells'
+        assert len(cells) == 2, "Must be 2 transformer cells"
         self.cells = nn.ModuleList([cells for _ in range(n_layers)])
         self.fc = fc
         self.n_layers = n_layers
 
-    def _single_forward(self, cells: Sequence[TransformerCell], first_args: Sequence, second_args: Sequence) -> Tensor:
+    def _single_forward(
+        self,
+        cells: Sequence[TransformerCell],
+        first_args: Sequence,
+        second_args: Sequence,
+    ) -> Tensor:
         out0 = cells[0].self_attention(*first_args)
         out1 = cells[1].self_attention(*second_args)
         out0 = cells[0].cross_attention(out0, out1)
         out1 = cells[1].cross_attention(out1, out0)
-        
+
         out0 = cells[0].ffn(*out0)
         out1 = cells[1].ffn(*out1)
-        return (out0, ), (out1, )
+        return (out0,), (out1,)
+
     def forward(self, first_inputs: Sequence, second_inputs: Sequence) -> Tensor:
         for layer, cells in enumerate(self.cells):
-            first_inputs, second_inputs = self._single_forward(cells, layer, first_inputs, second_inputs)
-        
-
+            first_inputs, second_inputs = self._single_forward(
+                cells, layer, first_inputs, second_inputs
+            )
 
-__all__ = [
-    'Transformer',
-    'TransformerCell',
-    'CrossTransformer'
-]
\ No newline at end of file
+__all__ = ["Transformer", "TransformerCell", "CrossTransformer"]
diff --git a/requirements.sh b/requirements.sh
old mode 100644
new mode 100755
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..545e898
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,132 @@
+#
+# This file is autogenerated by pip-compile with Python 3.11
+# by the following command:
+#
+#    pip-compile requirements.in
+#
+aiohttp==3.9.5
+    # via fsspec
+aiosignal==1.3.1
+    # via aiohttp
+attrs==23.2.0
+    # via aiohttp
+einops==0.8.0
+    # via -r requirements.in
+filelock==3.14.0
+    # via
+    #   torch
+    #   triton
+frozenlist==1.4.1
+    # via
+    #   aiohttp
+    #   aiosignal
+fsspec[http]==2024.5.0
+    # via
+    #   lightning
+    #   pytorch-lightning
+    #   torch
+idna==3.7
+    # via yarl
+jinja2==3.1.4
+    # via torch
+lightning==2.2.5
+    # via -r requirements.in
+lightning-utilities==0.11.2
+    # via
+    #   lightning
+    #   pytorch-lightning
+    #   torchmetrics
+markupsafe==2.1.5
+    # via jinja2
+mpmath==1.3.0
+    # via sympy
+multidict==6.0.5
+    # via
+    #   aiohttp
+    #   yarl
+networkx==3.3
+    # via torch
+numpy==1.26.4
+    # via
+    #   lightning
+    #   pytorch-lightning
+    #   torchmetrics
+    #   torchvision
+nvidia-cublas-cu12==12.1.3.1
+    # via
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via torch
+nvidia-cudnn-cu12==8.9.2.26
+    # via torch
+nvidia-cufft-cu12==11.0.2.54
+    # via torch
+nvidia-curand-cu12==10.3.2.106
+    # via torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-nccl-cu12==2.20.5
+    # via torch
+nvidia-nvjitlink-cu12==12.5.40
+    # via
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via torch
+packaging==24.0
+    # via
+    #   lightning
+    #   lightning-utilities
+    #   pytorch-lightning
+    #   torchmetrics
+pillow==10.3.0
+    # via torchvision
+pytorch-lightning==2.2.5
+    # via lightning
+pyyaml==6.0.1
+    # via
+    #   lightning
+    #   pytorch-lightning
+sympy==1.12
+    # via torch
+torch==2.3.0
+    # via
+    #   -r requirements.in
+    #   lightning
+    #   pytorch-lightning
+    #   torchmetrics
+    #   torchvision
+torchmetrics==1.4.0.post0
+    # via
+    #   lightning
+    #   pytorch-lightning
+torchvision==0.18.0
+    # via -r requirements.in
+tqdm==4.66.4
+    # via
+    #   -r requirements.in
+    #   lightning
+    #   pytorch-lightning
+triton==2.3.0
+    # via torch
+typing-extensions==4.12.0
+    # via
+    #   lightning
+    #   lightning-utilities
+    #   pytorch-lightning
+    #   torch
+yarl==1.9.4
+    # via aiohttp
+
+# The following packages are considered to be unsafe in a requirements file:
+# setuptools
diff --git a/tests/test_nn.py b/tests/test_nn.py
index f75e7bd..144f566 100644
--- a/tests/test_nn.py
+++ b/tests/test_nn.py
@@ -18,46 +18,52 @@
 
 def test_tv() -> None:
     loss = TV(randint)
-    loss(input = input)
-    
+    loss(input=input)
+
 
 def test_style() -> None:
     loss = StyleLoss(model, input, randint)
-    loss(input = input, target = target, feature_extractor = False)
+    loss(input=input, target=target, feature_extractor=False)
 
 
 def test_perc() -> None:
     loss = PerceptualLoss(model, input, randint)
-    loss(input = input, target = target, feature_extractor = False)
+    loss(input=input, target=target, feature_extractor=False)
+
 
 def test_psnr() -> None:
     loss = PeakNoiseSignalRatio(1, randint)
-    loss(input = input, target = target)
+    loss(input=input, target=target)
+
 
 # Integration tests
 def test_lagrange() -> None:
     loss = LagrangianFunctional(
-        nn.MSELoss(), nn.MSELoss(), nn.MSELoss(), lambd = Tensor([randint for _ in range(2)])
+        nn.MSELoss(),
+        nn.MSELoss(),
+        nn.MSELoss(),
+        lambd=Tensor([randint for _ in range(2)]),
     )
     loss(input=input, target=target)
 
+
 def test_loss() -> None:
-    loss = Loss(
-        TV(randint), PeakNoiseSignalRatio(1, randint)
-    )
-    loss(input = input, target = target)
+    loss = Loss(TV(randint), PeakNoiseSignalRatio(1, randint))
+    loss(input=input, target=target)
+
 
 def test_elbo() -> None:
-    loss = ELBO(
-        randint, PeakNoiseSignalRatio(1, randint)
-    )
-    loss(input = input, target = target, mu = mu, logvar = logvar)
+    loss = ELBO(randint, PeakNoiseSignalRatio(1, randint))
+    loss(input=input, target=target, mu=mu, logvar=logvar)
+
 
 def test_fourier() -> None:
     input: Tensor = create_inputs(1, 3, 256, 256)
 
+
 def test_partial() -> None:
     raise NotImplementedError
 
+
 def test_kan() -> None:
     raise NotImplementedError
diff --git a/tests/test_supervised.py b/tests/test_supervised.py
index c16222c..8375c87 100644
--- a/tests/test_supervised.py
+++ b/tests/test_supervised.py
@@ -5,36 +5,36 @@
 import random
 from torch import Tensor, nn
 
+
 in_size: int = 32
 input: Tensor = create_inputs(1, in_size)
 randint: int = random.randint(-100, 100)
 
-#Integrated test
+# Integrated test
+
 
 class SupModel(Module):
     def __init__(self, **hparams) -> None:
         super().__init__(**hparams)
-        self.criterion = PeakNoiseSignalRatio(1, randint) 
+        self.criterion = PeakNoiseSignalRatio(1, randint)
         self.model = Model(in_size)
 
     def forward(self, input: Tensor) -> Tensor:
         return self.model(input)
 
+
 def objective():
     pass
 
+
 def test_supervised() -> None:
     htuning(
-        model_class = SupModel,
+        model_class=SupModel,
         hparam_objective=objective,
-        datamodule = DataModule,
-        valid_metrics = 'MSE',
-        datamodule_kwargs= dict(
-            pin_memory = False,
-            num_workers = 1,
-            batch_size = 1
-        ),
-        directions = 'minimize',
-        precision = 'high',
-        n_trials = 10,
-        trianer_kwargs = dict(fast_dev_run = True)
+        datamodule=DataModule,
+        valid_metrics="MSE",
+        datamodule_kwargs=dict(pin_memory=False, num_workers=1, batch_size=1),
+        directions="minimize",
+        precision="high",
+        n_trials=10,
+        trianer_kwargs=dict(fast_dev_run=True),
     )
diff --git a/tests/utils.py b/tests/utils.py
index ffee27e..48c6e42 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -3,44 +3,51 @@
 from torch import Tensor, nn
 from lightning.pytorch import LightningDataModule
 from torch.utils.data import Dataset, DataLoader
+
+
 def create_inputs(*size) -> Tensor:
     return torch.randn(*size)
 
+
 class Model(nn.Sequential):
     def __init__(self, in_channels) -> None:
-        super().__init__(
-            nn.Linear(in_channels, 10),
-            nn.Linear(10, 1)
-        )
+        super().__init__(nn.Linear(in_channels, 10), nn.Linear(10, 1))
+
     def forward(self, input):
         return super().forward(input)
 
+
 # Add feature extractor
 class Data(Dataset):
-    def __init__(self,) -> None:
+    def __init__(
+        self,
+    ) -> None:
         pass
-    
+
+
 class DataModule(LightningDataModule):
     def __init__(self, batch_size: int, pin_memory: bool = False, num_workers: int = 1):
         self.batch_size = batch_size
         self.pin_memory = pin_memory
         self.num_workers = num_workers
+
     def setup(self) -> None:
         self.train_ds
+
     def train_dataloader(self) -> DataLoader:
         return DataLoader(
             Data(),
             self.batch_size,
             False,
             num_workers=self.num_workers,
-            pin_memory=self.pin_memory
+            pin_memory=self.pin_memory,
         )
-    
+
     def val_dataloader(self) -> DataLoader:
         return DataLoader(
             Data(),
             self.batch_size * 2,
             False,
             num_workers=self.num_workers,
-            pin_memory=self.pin_memory
-        )
\ No newline at end of file
+            pin_memory=self.pin_memory,
+        )