diff --git a/torch_dreams/__init__.py b/torch_dreams/__init__.py
index 0c7cefe..68f1dd3 100644
--- a/torch_dreams/__init__.py
+++ b/torch_dreams/__init__.py
@@ -3,7 +3,18 @@
 from .model_bunch import *
 from .tests import *
 from .auto_image_param import AutoImageParam
+from .auto_series_param import AutoSeriesParam
+from .base_series_param import BaseSeriesParam
 from .custom_image_param import CustomImageParam
+from .custom_series_param import CustomSeriesParam
+from .masked_image_param import MaskedImageParam
+
+import torch_dreams.image_transforms as image_transforms
+import torch_dreams.series_transforms as series_transforms
+import torch_dreams.transforms as transforms
+
+from . import series_transforms
+
 
 __version__ = "4.0.0"
 
@@ -12,8 +23,18 @@
     "utils",
     "model_bunch",
     "auto_image_param",
+    "AutoImageParam",
+    "auto_series_param",
+    "AutoSeriesParam",
+    "base_series_param",
+    "BaseSeriesParam",
     "custom_image_param",
+    "CustomImageParam",
+    "custom_series_param",
+    "CustomSeriesParam",
     "masked_image_param",
+    "MaskedImageParam",
     "image_transforms",
-    "transforms"
+    "series_transforms",
+    "transforms",
 ]
diff --git a/torch_dreams/auto_series_param.py b/torch_dreams/auto_series_param.py
new file mode 100644
index 0000000..99e783d
--- /dev/null
+++ b/torch_dreams/auto_series_param.py
@@ -0,0 +1,75 @@
+import torch
+
+from .base_series_param import BaseSeriesParam
+from .utils import init_series_param
+from .utils import fft_to_series
+
+
+class AutoSeriesParam(BaseSeriesParam):
+    """Trainable series parameter which can be used to activate
+    different parts of a neural net
+
+    Args:
+        length (int): The sequence length of the series
+        channels (int): The number of channels of the series
+        device (str): 'cpu' or 'cuda'
+        standard_deviation (float): Standard deviation of the series initiated
+            in the frequency domain.
+        normalize_mean (torch.Tensor, optional): Per-channel mean used by
+            ``normalize``. Defaults to zeros.
+        normalize_std (torch.Tensor, optional): Per-channel standard deviation
+            used by ``normalize``. Defaults to ones.
+        batch_size (int): Number of copies of the series returned by
+            ``forward``.
+        seed (int): Seed for the random initialisation of the parameter.
+
+    NOTE(review): ``fft_to_series`` reshapes the parameter to a leading
+    dimension of 1, so only ``batch_size=1`` is exercised by the tests --
+    confirm the intended behaviour for larger batches.
+    """
+
+    def __init__(
+        self,
+        length,
+        channels,
+        device,
+        standard_deviation,
+        normalize_mean=None,
+        normalize_std=None,
+        batch_size: int = 1,
+        seed: int = 42,
+    ):
+        # The parameter stores (real, imag) pairs, so an odd length is
+        # rounded up to the next even number of elements.
+        param = init_series_param(
+            batch_size=batch_size,
+            channels=channels,
+            length=length + length % 2,
+            sd=standard_deviation,
+            seed=seed,
+            device=device,
+        )
+
+        super().__init__(
+            batch_size=batch_size,
+            channels=channels,
+            length=length,
+            param=param,
+            normalize_mean=normalize_mean,
+            normalize_std=normalize_std,
+            device=device,
+        )
+
+        self.standard_deviation = standard_deviation
+
+    def postprocess(self, device):
+        """Return the series in the time domain, squashed to (0, 1)."""
+        series = fft_to_series(
+            channels=self.channels,
+            length=self.length,
+            series_parameter=self.param,
+            device=device,
+        )
+        # TODO: decorrelate channels (series analogue of lucid_colorspace_to_rgb)
+        series = torch.sigmoid(series)
+        return series
diff --git a/torch_dreams/base_series_param.py b/torch_dreams/base_series_param.py
new file mode 100644
index 0000000..7bd4b45
--- /dev/null
+++ b/torch_dreams/base_series_param.py
@@ -0,0 +1,138 @@
+import torch
+
+
+class BaseSeriesParam(torch.nn.Module):
+    """Common behaviour of trainable series parameters.
+
+    Subclasses provide ``postprocess``, which turns ``self.param`` into a
+    series tensor; ``forward`` normalizes that tensor for the model.
+    """
+
+    def __init__(self, batch_size, channels, length, param, normalize_mean, normalize_std, device):
+        super().__init__()
+
+        self.batch_size = batch_size
+        self.channels = channels
+        self.length = length
+
+        if normalize_mean is None:
+            normalize_mean = torch.FloatTensor([0] * channels)
+        self.normalize_mean = normalize_mean
+
+        if normalize_std is None:
+            normalize_std = torch.FloatTensor([1] * channels)
+        self.normalize_std = normalize_std
+
+        self.param = param
+        self.param.requires_grad_()
+
+        self.device = device
+
+        self.optimizer = None
+
+    def forward(self, device):
+        """This is what the model gets, processed and normalized with the right values
+
+        The model gets: self.normalize(self.postprocess(device))
+
+        Raises:
+            NotImplementedError: if the subclass does not implement ``postprocess``.
+        """
+
+        if self.batch_size == 1:
+            return self.normalize(self.postprocess(device=device), device=device)
+        else:
+            return torch.cat(
+                [
+                    self.normalize(self.postprocess(device=device), device=device)
+                    for _ in range(self.batch_size)
+                ],
+                dim=0,
+            )
+
+    def postprocess(self, device=None):
+        """Moves the series from the frequency domain to the time domain
+
+        ``forward`` calls this with ``device=...``, so overrides must accept it.
+
+        Raises:
+            NotImplementedError: Always, you're in the base class.
+        """
+        raise NotImplementedError
+
+    def normalize(self, x, device='cuda'):
+        """Normalizing wrapper"""
+        return (
+            (x - self.normalize_mean[..., None].to(device))
+            / self.normalize_std[..., None].to(device)
+        )
+
+    def denormalize(self, x, device='cuda'):
+        """Denormalizing wrapper."""
+        return (
+            x * self.normalize_std[..., None].to(device)
+            + self.normalize_mean[..., None].to(device)
+        )
+
+    def fetch_optimizer(self, params_list, optimizer=None, lr=1e-3, weight_decay=0.0):
+        """Build ``optimizer`` (a class/factory) over ``params_list``; AdamW by default."""
+        if optimizer is not None:
+            optimizer = optimizer(params_list, lr=lr, weight_decay=weight_decay)
+        else:
+            optimizer = torch.optim.AdamW(params_list, lr=lr, weight_decay=weight_decay)
+        return optimizer
+
+    def get_optimizer(self, lr, weight_decay):
+        """Create the default optimizer for ``self.param`` and store it on ``self.optimizer``."""
+        self.optimizer = self.fetch_optimizer(
+            params_list=[self.param], lr=lr, weight_decay=weight_decay
+        )
+
+    def clip_grads(self, grad_clip=1.0):
+        """Clip the gradient norm of ``self.param`` to ``grad_clip``."""
+        return torch.nn.utils.clip_grad_norm_(self.param, grad_clip)
+
+    def to_cl_tensor(self, device="cpu"):
+        """Return CL series tensor (channels, length).
+
+        Args:
+            device (str): The device to operate on ('cpu' or 'cuda').
+
+        Returns:
+            torch.Tensor
+        """
+        t = self.forward(device=device)[0].detach()
+        return t
+
+    def to_lc_tensor(self, device="cpu"):
+        """Return LC series tensor (length, channels).
+
+        Args:
+            device (str): The device to operate on ('cpu' or 'cuda').
+
+        Returns:
+            torch.Tensor
+        """
+        t = self.forward(device=device)[0].permute(1, 0).detach()
+        return t
+
+    def __array__(self):
+        """Converts the series parameter to a CL (channels, length) numpy array
+
+        Returns:
+            numpy.ndarray
+        """
+        return self.to_cl_tensor().numpy()
+
+    def save(self, filename):
+        """Save the CL series tensor to ``filename`` with ``torch.save``.
+
+        usage:
+
+            series_param.save(filename='my_series.pt')
+
+        Args:
+            filename (str): path of the output file, e.g. series.pt
+        """
+        tensor = self.to_cl_tensor()
+        torch.save(tensor, filename)
diff --git a/torch_dreams/custom_series_param.py b/torch_dreams/custom_series_param.py
new file mode 100644
index 0000000..859c5a0
--- /dev/null
+++ b/torch_dreams/custom_series_param.py
@@ -0,0 +1,141 @@
+from .base_series_param import BaseSeriesParam
+
+import numpy as np
+import torch
+
+from .utils import (
+    lucid_colorspace_to_rgb,
+    normalize,
+    get_fft_scale_custom_series,
+    cl_series_to_fft_param,
+    fft_to_series
+)
+
+
+class CustomSeriesParam(BaseSeriesParam):
+    """FFT parameterization for custom series.
+
+    Works well with:
+    * lower learning rates (3e-4)
+    * gradients clipped to (0, 0.1)
+    * weight decay (1e-1)
+
+    Args:
+        series (torch.tensor): input tensor with shape [1, channels, length].
+        device (str): 'cuda' or 'cpu'
+        normalize_mean (torch.tensor, optional): per-channel mean, defaults to zeros.
+        normalize_std (torch.tensor, optional): per-channel std, defaults to ones.
+
+    Example:
+    ```
+    series = torch.ones((1, 2, 100))
+    param = CustomSeriesParam(series=series, device='cuda')
+
+    result = dreamy_boi.render(
+        image_parameter=param,
+        layers = [model.Mixed_6c],
+        lr = 3e-4,
+        grad_clip = 0.1,
+        weight_decay= 1e-1
+    )
+    ```
+    """
+    def __init__(
+        self,
+        series,
+        device,
+        # TODO: accept channel_correlation_matrix from the caller
+        normalize_mean=None,
+        normalize_std=None,
+    ):
+        batch_size = series.shape[0]
+        channels = series.shape[1]
+        length = series.shape[2]
+
+        super().__init__(
+            batch_size=batch_size,
+            channels=channels,
+            length=length,
+            # placeholder copy (the base class calls requires_grad_() on it);
+            # the real parameter is installed by set_param below
+            param=series.detach().clone().float(),
+            normalize_mean=normalize_mean,
+            normalize_std=normalize_std,
+            device=device,
+        )
+
+        channel_correlation_matrix = get_normalized_correlation_matrix(channels)
+
+        self.set_param(series, channel_correlation_matrix, device=device)
+
+    def postprocess(self, device):
+        """Invert ``set_param``: spectrum -> series -> original channel space."""
+        # set_param stores the complex rfft spectrum divided by this scale
+        scale = get_fft_scale_custom_series(length=self.length, device=device)
+        spectrum = scale.to(self.param.device) * self.param
+        out = torch.fft.irfft(spectrum, n=self.length, norm="ortho")
+        # undo series_space_to_lucid_space, which multiplied by inverse(matrix.T)
+        mixing = self.channel_correlation_matrix.T.to(out.device)
+        out = torch.matmul(out.permute(0, 2, 1), mixing).permute(0, 2, 1)
+        return out.clamp(0, 1)
+
+    def set_param(self, tensor, channel_correlation_matrix=None, device=None):
+        """sets an NCL tensor as the parameter in the frequency domain,
+        useful for transforming custom series between iterations.
+
+        Use in combination with `self.to_cl_tensor()` like:
+
+        ```
+        a = self.to_cl_tensor()
+        # do something with a
+        t = RandomSeriesScale(0.5, 1.2)
+        a = t(a)
+        # set as parameter again (set_param expects a leading batch axis)
+        self.set_param(a.unsqueeze(0))
+        ```
+
+        WARNING: tensor should have values clipped between 0 and 1.
+
+        Args:
+            tensor (torch.tensor): input tensor with shape [1,channels, length] and values clipped between 0,1.
+            channel_correlation_matrix (torch.tensor, optional): [channels, channels]
+                matrix; defaults to the one used by the previous call.
+            device (str, optional): 'cuda' or 'cpu'; defaults to ``self.device``.
+        """
+        assert len(tensor.shape) == 3
+        assert tensor.shape[0] == 1
+
+        if channel_correlation_matrix is None:
+            channel_correlation_matrix = self.channel_correlation_matrix
+        if device is None:
+            device = self.device
+
+        self.tensor = tensor
+        # kept so that postprocess can undo the channel decorrelation
+        self.channel_correlation_matrix = channel_correlation_matrix
+
+        batch_size = tensor.shape[0]
+        channels = tensor.shape[1]
+        length = tensor.shape[2]
+
+        scale = get_fft_scale_custom_series(length=length, device=device)
+        # TODO: denormalize the tensor before the transform
+        fft_param = cl_series_to_fft_param(tensor, channel_correlation_matrix=channel_correlation_matrix, device=device)
+        self.param = fft_param / scale
+
+        self.param.requires_grad_()
+
+        self.batch_size = batch_size
+        self.channels = channels
+        self.length = length
+        self.device = device
+
+
+def get_normalized_correlation_matrix(channels):
+    """Return a random [channels, channels] matrix with max column norm 1."""
+    # TODO: these values must be passed by the user
+    correlation_svd_sqrt = np.random.rand(channels, channels).astype(np.float32)
+
+    max_norm_svd_sqrt = np.max(np.linalg.norm(correlation_svd_sqrt, axis=0))
+    correlation_normalized = torch.tensor(correlation_svd_sqrt / max_norm_svd_sqrt)
+    return correlation_normalized
diff --git a/torch_dreams/series_transforms.py b/torch_dreams/series_transforms.py
new file mode 100644
index 0000000..f6f5a5c
--- /dev/null
+++ b/torch_dreams/series_transforms.py
@@ -0,0 +1,119 @@
+from __future__ import annotations
+
+from numbers import Number
+from collections.abc import Sequence
+
+import torch
+
+
+class RandomSeriesTranslate(torch.nn.Module):
+    """Randomly shift a series along its last (length) axis.
+
+    Args:
+        translate (float): maximum shift as a fraction of the series length,
+            between 0 and 1.
+        fill (Number | Sequence | None): value(s) written into the gap left by
+            the shift. If None, the series wraps around instead.
+        seed (int): seed of the private random generator.
+    """
+
+    def __init__(
+        self,
+        translate: float,
+        fill: Number | Sequence | None = 0,
+        seed=42,
+    ):
+        super().__init__()
+
+        if not isinstance(translate, Number):
+            raise TypeError(f"translate should be a number but is {type(translate)}.")
+        if not (0.0 <= translate <= 1.0):
+            raise ValueError("translation value should be between 0 and 1")
+        self.translate = translate
+
+        if fill is not None and not isinstance(fill, (Sequence, Number)):
+            raise TypeError("Fill must be either a sequence, a number, or None.")
+        self.fill = fill
+
+        self.seed = seed
+        self.generator = torch.Generator()
+        self.generator.manual_seed(seed)
+
+    def forward(self, series):
+        fill = self.fill
+        _, length = _get_series_dimensions(series)
+
+        max_shift = float(self.translate * length)
+        shift = int(round(torch.empty(1).uniform_(-max_shift, max_shift, generator=self.generator).item()))
+
+        # if fill is None, the overflow values are rolled over to the other end of the series
+        out = torch.roll(series, shift, dims=-1)
+
+        if fill is None or shift == 0:
+            return out
+
+        # number of wrapped-around elements, whatever the shift direction
+        gap = abs(shift)
+
+        if not isinstance(fill, Number):
+            # fill must be a sequence, check that it is long enough for the gap
+            if gap > len(fill):
+                raise RuntimeError(f'random shift greater than fill length ({gap} > {len(fill)})')
+
+            fill = torch.as_tensor(fill, dtype=out.dtype, device=out.device)[:gap]
+
+        if shift > 0:
+            # shifted right: the wrapped values sit at the start
+            out[..., :gap] = fill
+        else:
+            # shifted left: the wrapped values sit at the end
+            out[..., -gap:] = fill
+
+        return out
+
+
+class RandomSeriesScale(torch.nn.Module):
+    """Multiply a series by a random factor between min_scale and max_scale.
+
+    Args:
+        min_scale (float): lower bound of the factor.
+        max_scale (float): upper bound of the factor.
+        seed (int): seed of the private random generator.
+    """
+
+    def __init__(
+        self,
+        min_scale: float,
+        max_scale: float,
+        seed=42,
+    ):
+        super().__init__()
+
+        if not isinstance(min_scale, Number):
+            raise TypeError(f"min_scale should be a number but is {type(min_scale)}.")
+        if not isinstance(max_scale, Number):
+            raise TypeError(f"max_scale should be a number but is {type(max_scale)}.")
+
+        self.min_scale = min_scale
+        self.max_scale = max_scale
+
+        self.seed = seed
+        self.generator = torch.Generator()
+        self.generator.manual_seed(seed)
+
+    def forward(self, series):
+        scale = torch.empty(1).uniform_(self.min_scale, self.max_scale, generator=self.generator).to(series.device)
+        out = series * scale
+        return out
+
+
+def _get_series_dimensions(series):
+    """Return [channels, length] of a (..., channels, length) or 1-D series."""
+    # TODO: _assert_image_tensor(img)
+    if series.ndim == 1:
+        channels = 1
+        length = len(series)
+    else:
+        channels = series.shape[-2]
+        length = series.shape[-1]
+    return [channels, length]
diff --git a/torch_dreams/tests/test_series.py b/torch_dreams/tests/test_series.py
new file mode 100644
index 0000000..b89e32c
--- /dev/null
+++ b/torch_dreams/tests/test_series.py
@@ -0,0 +1,190 @@
+import os
+
+import numpy as np
+import pytest
+import torch
+import torchvision.transforms as transforms
+
+from torch_dreams import Dreamer
+from torch_dreams.auto_series_param import AutoSeriesParam
+from torch_dreams.custom_series_param import CustomSeriesParam
+from torch_dreams.series_transforms import RandomSeriesScale
+from torch_dreams.series_transforms import RandomSeriesTranslate
+from torch_dreams.transforms import random_resize
+
+
+class CNN1d(torch.nn.Module):
+    def __init__(self, in_channels, out_features, channels=25):
+        super(CNN1d, self).__init__()
+        self.conv1 = torch.nn.Conv1d(in_channels, channels, kernel_size=5, stride=2, padding=1)
+        self.conv2 = torch.nn.Conv1d(channels, channels, kernel_size=3, stride=2, padding=1)
+        self.conv3 = torch.nn.Conv1d(channels, channels, kernel_size=3, stride=2, padding=1)
+        self.flatten = torch.nn.Flatten()
+        self.fc = torch.nn.LazyLinear(out_features)
+
+    def forward(self, x):
+        h1 = self.conv1(x).relu()
+        h2 = self.conv2(h1).relu()
+        h3 = self.conv3(h2).relu()
+        return self.fc(self.flatten(h3))
+
+
+@pytest.mark.parametrize("out_features", [3, 10, 21])
+@pytest.mark.parametrize("sequence_length", [11, 20, 40, 99, 1000])
+@pytest.mark.parametrize("channels", [1, 2, 10, 59])
+@pytest.mark.parametrize("batch_size", [1, 16, 64, 256])
+def test_cnn_model_outputs_correct_shape(sequence_length, channels, out_features, batch_size):
+    model = CNN1d(in_channels=channels, out_features=out_features)
+    x = torch.zeros((batch_size, channels, sequence_length))
+
+    assert model(x).shape == (batch_size, out_features)
+
+
+@pytest.mark.parametrize("iters", [1, 2, 10, 20])
+@pytest.mark.parametrize("sequence_length", [11, 20, 40, 99, 1000])
+@pytest.mark.parametrize("channels", [1, 2, 10, 59])
+@pytest.mark.parametrize("batch_size", [1])
+def test_auto_series_param(iters, sequence_length, channels, batch_size):
+    model = CNN1d(in_channels=channels, out_features=10)
+
+    # Prepare lazy modules.
+    x = torch.zeros((batch_size, channels, sequence_length))
+    y = model(x)
+
+    series_param = AutoSeriesParam(
+        length=sequence_length,
+        channels=channels,
+        device="cpu",
+        standard_deviation=0.01,
+        batch_size=batch_size,
+    )
+
+    series_transforms = transforms.Compose(
+        [
+            RandomSeriesTranslate(0.1),
+            RandomSeriesScale(0.5, 1.2),
+        ]
+    )
+
+    dreamy_boi = Dreamer(model=model, device='cpu', quiet=False)
+    dreamy_boi.set_custom_transforms(series_transforms)
+
+    result = dreamy_boi.render(
+        layers=[model.conv1],
+        iters=iters,
+        image_parameter=series_param,
+    )
+
+    assert isinstance(result, AutoSeriesParam), "should be an instance of auto_series_param"
+    assert isinstance(result.__array__(), np.ndarray)
+    assert isinstance(result.to_cl_tensor(), torch.Tensor), "should be a torch.Tensor"
+    assert isinstance(result.to_lc_tensor(), torch.Tensor), "should be a torch.Tensor"
+    assert result.to_cl_tensor().shape == x[0].shape
+
+def test_auto_series_save(iters=2, sequence_length=40, channels=2, batch_size=1):
+    model = CNN1d(in_channels=channels, out_features=10)
+
+    # Prepare lazy modules.
+    x = torch.zeros((batch_size, channels, sequence_length))
+    y = model(x)
+
+    series_param = AutoSeriesParam(
+        length=sequence_length,
+        channels=channels,
+        device="cpu",
+        standard_deviation=0.01,
+    )
+
+    series_transforms = transforms.Compose(
+        [
+            RandomSeriesTranslate(0.1),
+            RandomSeriesScale(0.5, 1.2),
+        ]
+    )
+
+    dreamy_boi = Dreamer(model=model, device='cpu', quiet=False)
+    dreamy_boi.set_custom_transforms(series_transforms)
+
+    result = dreamy_boi.render(
+        layers=[model.conv1],
+        iters=iters,
+        image_parameter=series_param,
+    )
+
+    filename = "test_ts_single_model.jpg"
+    result.save(filename=filename)
+    assert os.path.exists(filename)
+    os.remove(filename)
+
+
+@pytest.mark.parametrize("iters", [1, 2, 10, 20])
+@pytest.mark.parametrize("sequence_length", [11, 20, 40, 99, 1000])
+@pytest.mark.parametrize("channels", [1, 2, 10, 59])
+@pytest.mark.parametrize("batch_size", [1])
+def test_custom_series_param(iters, sequence_length, channels, batch_size):
+    model = CNN1d(in_channels=channels, out_features=10)
+
+    # Prepare lazy modules.
+    x = torch.zeros((batch_size, channels, sequence_length))
+    y = model(x)
+
+    series_param = CustomSeriesParam(
+        series=x,
+        device="cpu",
+    )
+
+    series_transforms = transforms.Compose(
+        [
+            RandomSeriesTranslate(0.1),
+            RandomSeriesScale(0.5, 1.2),
+        ]
+    )
+
+    dreamy_boi = Dreamer(model=model, device='cpu', quiet=False)
+    dreamy_boi.set_custom_transforms(series_transforms)
+
+    result = dreamy_boi.render(
+        layers=[model.conv1],
+        iters=iters,
+        image_parameter=series_param,
+    )
+
+    assert isinstance(result, CustomSeriesParam), "should be an instance of custom_series_param"
+    assert isinstance(result.__array__(), np.ndarray)
+    assert isinstance(result.to_cl_tensor(), torch.Tensor), "should be a torch.Tensor"
+    assert isinstance(result.to_lc_tensor(), torch.Tensor), "should be a torch.Tensor"
+    assert result.to_cl_tensor().shape == x[0].shape
+
+
+def test_custom_series_save(iters=2, sequence_length=40, channels=2, batch_size=1):
+    model = CNN1d(in_channels=channels, out_features=10)
+
+    # Prepare lazy modules.
+    x = torch.zeros((batch_size, channels, sequence_length))
+    y = model(x)
+
+    series_param = CustomSeriesParam(
+        series=x,
+        device="cpu",
+    )
+
+    series_transforms = transforms.Compose(
+        [
+            RandomSeriesTranslate(0.1),
+            RandomSeriesScale(0.5, 1.2),
+        ]
+    )
+
+    dreamy_boi = Dreamer(model=model, device='cpu', quiet=False)
+    dreamy_boi.set_custom_transforms(series_transforms)
+
+    result = dreamy_boi.render(
+        layers=[model.conv1],
+        iters=iters,
+        image_parameter=series_param,
+    )
+
+    filename = "test_ts_single_model.jpg"
+    result.save(filename=filename)
+    assert os.path.exists(filename)
+    os.remove(filename)
diff --git a/torch_dreams/utils.py b/torch_dreams/utils.py
index a03e0ec..0a7e7be 100644
--- a/torch_dreams/utils.py
+++ b/torch_dreams/utils.py
@@ -44,6 +44,25 @@ def init_image_param(height, width, sd=0.01, device="cuda"):
     return spectrum_t
 
 
+def init_series_param(batch_size, channels, length, sd=0.01, seed=42, device="cuda"):
+    """Initializes a series parameter in the frequency domain
+
+    Args:
+        batch_size (int): batch size of series
+        channels (int): number of channels of series
+        length (int): length of series
+        sd (float, optional): Standard deviation of step values. Defaults to 0.01.
+        device (str): 'cpu' or 'cuda'
+
+    Returns:
+        torch.tensor: series param to backpropagate on
+    """
+    np.random.seed(seed=seed)
+    buffer = np.random.normal(size=(batch_size, channels, length), scale=sd).astype(np.float32)
+    spectrum_t = tensor(buffer).float().to(device)
+    return spectrum_t
+
+
 def get_fft_scale(h, w, decay_power=0.75, device="cuda"):
     d = 0.5**0.5  # set center frequency scale to 1
     fy = np.fft.fftfreq(h, d=d)[:, None]
@@ -61,6 +80,23 @@ def get_fft_scale(h, w, decay_power=0.75, device="cuda"):
     return scale
 
 
+def get_fft_series_scale(length: int, decay_power: float = 0.75, device: str = "cuda"):
+    d = 0.5**0.5  # set center frequency scale to 1
+
+    if length % 2 == 1:
+        fx = np.fft.rfftfreq(length, d=d)[: (length + 1) // 2]
+    else:
+        fx = np.fft.rfftfreq(length, d=d)[: length // 2]
+
+    freqs = (fx * fx) ** decay_power
+
+    scale = 1.0 / np.maximum(freqs, 1.0 / (length * d))
+    scale = tensor(scale).float().to(device)
+
+    return scale
+
+
+
 def fft_to_rgb(height, width, image_parameter, device="cuda"):
     """convert image param to NCHW
 
@@ -100,6 +136,40 @@ def fft_to_rgb(height, width, image_parameter, device="cuda"):
     return t
 
 
+def fft_to_series(channels, length, series_parameter, device="cuda"):
+    """convert series param to NCL
+
+    WARNING: torch v1.7.0 works differently from torch v1.8.0 on fft.
+    torch-dreams supports ONLY 1.8.x
+
+    Latest docs: https://pytorch.org/docs/stable/fft.html
+
+    Also refer:
+        https://github.com/pytorch/pytorch/issues/49637
+
+    Args:
+        channels (int): number of channels of series
+        length (int): length of series
+        series_parameter (auto_series_param): auto_series_param.param
+
+    Returns:
+        torch.tensor: NCL tensor
+
+    """
+    scale = get_fft_series_scale(length, device=device).to(series_parameter.device)
+
+    if length % 2 == 1:
+        series_parameter = series_parameter.reshape(1, channels, (length + 1) // 2, 2)
+    else:
+        series_parameter = series_parameter.reshape(1, channels, length // 2, 2)
+
+    series_parameter = torch.complex(series_parameter[..., 0], series_parameter[..., 1])
+    t = scale * series_parameter
+    t = torch.fft.irfft(t, n=length, norm="ortho")
+
+    return t
+
+
 def lucid_colorspace_to_rgb(t, device="cuda"):
 
     t_flat = t.permute(0, 2, 3, 1)
@@ -128,6 +198,16 @@ def get_fft_scale_custom_img(h, w, decay_power=0.75, device="cuda"):
     return scale
 
 
+def get_fft_scale_custom_series(length, decay_power=0.75, device="cuda"):
+    d = 0.5**0.5  # set center frequency scale to 1
+    fx = np.fft.rfftfreq(length, d=d)[: (length // 2) + 1]
+    freqs = (fx * fx) ** decay_power
+    scale = 1.0 / np.maximum(freqs, 1.0 / (length * d))
+    scale = torch.tensor(scale).float().to(device)
+
+    return scale
+
+
 def denormalize(x):
 
     return x.float() * Constants.imagenet_std[..., None, None].to(
@@ -143,6 +223,14 @@ def rgb_to_lucid_colorspace(t, device="cuda"):
     return t
 
 
+def series_space_to_lucid_space(t, channel_correlation_matrix, device="cuda"):
+    t_flat = t.permute(0, 2, 1)
+    inverse = torch.inverse(channel_correlation_matrix.T.to(device))
+    t_flat = torch.matmul(t_flat.to(device), inverse)
+    t = t_flat.permute(0, 2, 1)
+    return t
+
+
 def chw_rgb_to_fft_param(x, device):
     im_tensor = torch.tensor(x).unsqueeze(0).float()
 
@@ -152,6 +240,25 @@ def chw_rgb_to_fft_param(x, device):
     return x
 
 
+def cl_series_to_fft_param(x, channel_correlation_matrix, device):
+    length = x.shape[-1]
+    series_tensor = torch.as_tensor(x).detach().float()
+
+    x = series_space_to_lucid_space(
+        series_tensor,
+        channel_correlation_matrix=channel_correlation_matrix,
+        device=device,
+    )
+
+    # real-input FFT along the length axis; the result is complex with
+    # length // 2 + 1 frequency bins
+    x = torch.fft.rfft(x, n=length, norm="ortho")
+
+    # NOTE: the caller divides this spectrum by get_fft_scale_custom_series
+
+    return x
+
+
 def fft_to_rgb_custom_img(height, width, image_parameter, device="cuda"):
 
     scale = get_fft_scale_custom_img(height, width, device=device).to(
@@ -164,5 +271,5 @@ def fft_to_rgb_custom_img(height, width, image_parameter, device="cuda"):
     sub_version = int(version[1])
 
     t = torch.fft.irfft2(t, s=(height, width), norm="ortho")
-    
+
     return t