add auto_series_param and custom_series_param #58

Open · wants to merge 11 commits into master
23 changes: 22 additions & 1 deletion torch_dreams/__init__.py
@@ -3,7 +3,18 @@
from .model_bunch import *
from .tests import *
from .auto_image_param import AutoImageParam
from .auto_series_param import AutoSeriesParam
from .base_series_param import BaseSeriesParam
from .custom_image_param import CustomImageParam
from .custom_series_param import CustomSeriesParam
from .masked_image_param import MaskedImageParam

import torch_dreams.image_transforms as image_transforms
import torch_dreams.series_transforms as series_transforms
import torch_dreams.transforms as transforms


__version__ = "4.0.0"

@@ -12,8 +23,18 @@
"utils",
"model_bunch",
"auto_image_param",
"AutoImageParam",
"auto_series_param",
"AutoSeriesParam",
"base_series_param",
"BaseSeriesParam",
"custom_image_param",
"CustomImageParam",
"custom_series_param",
"CustomSeriesParam",
"masked_image_param",
"MaskedImageParam"
"image_transforms",
"transforms"
"series_transforms",
"transforms",
]
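
With these exports in place, the new parameterizations are importable directly from the package root. A minimal sanity check (assuming this branch is installed as `torch_dreams`):

```python
# Quick import check for the new series-parameter exports.
from torch_dreams import AutoSeriesParam, BaseSeriesParam, CustomSeriesParam

print(AutoSeriesParam, BaseSeriesParam, CustomSeriesParam)
```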
75 changes: 75 additions & 0 deletions torch_dreams/auto_series_param.py
@@ -0,0 +1,75 @@
import torch

from .base_series_param import BaseSeriesParam
from .utils import init_series_param
from .utils import fft_to_series


class AutoSeriesParam(BaseSeriesParam):
    """Trainable series parameter that can be optimized to activate
    different parts of a neural net.

    Args:
        length (int): The sequence length of the series.
        channels (int): The number of channels of the series.
        device (str): 'cpu' or 'cuda'
        standard_deviation (float): Standard deviation of the series
            initialized in the frequency domain.
        batch_size (int): The batch size of the input tensor. If
            batch_size=1, no batch dimension is expected.
    """

    def __init__(
        self,
        length,
        channels,
        device,
        standard_deviation,
        normalize_mean=None,
        normalize_std=None,
        batch_size: int = 1,
        seed: int = 42,
    ):
        # an odd length is padded to even with one extra element
        init_length = length + 1 if length % 2 == 1 else length
        param = init_series_param(
            batch_size=batch_size,
            channels=channels,
            length=init_length,
            sd=standard_deviation,
            seed=seed,
            device=device,
        )

        super().__init__(
            batch_size=batch_size,
            channels=channels,
            length=length,
            param=param,
            normalize_mean=normalize_mean,
            normalize_std=normalize_std,
            device=device,
        )

        self.standard_deviation = standard_deviation

    def postprocess(self, device):
        series = fft_to_series(
            channels=self.channels,
            length=self.length,
            series_parameter=self.param,
            device=device,
        )
        # TODO: img = lucid_colorspace_to_rgb(t=img, device=device)
        series = torch.sigmoid(series)
        return series
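
For context, a minimal sketch of how `AutoSeriesParam` would be constructed and stepped; the shapes follow the constructor arguments above, while the `standard_deviation` value here is an assumption, not a recommended default:

```python
import torch
from torch_dreams import AutoSeriesParam

# A trainable series: 2 channels, 100 timesteps, initialized in the
# frequency domain (see init_series_param above).
param = AutoSeriesParam(
    length=100,
    channels=2,
    device="cpu",
    standard_deviation=0.01,  # assumed value, not a recommended default
)

# get_optimizer() attaches an AdamW optimizer to the FFT parameter, and
# forward() returns the postprocessed, normalized series the model sees.
param.get_optimizer(lr=1e-3, weight_decay=0.0)
series = param.forward(device="cpu")
```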
127 changes: 127 additions & 0 deletions torch_dreams/base_series_param.py
@@ -0,0 +1,127 @@
import torch


class BaseSeriesParam(torch.nn.Module):
    def __init__(self, batch_size, channels, length, param, normalize_mean, normalize_std, device):
        super().__init__()

        self.batch_size = batch_size
        self.channels = channels
        self.length = length

        if normalize_mean is None:
            normalize_mean = torch.FloatTensor([0] * channels)
        self.normalize_mean = normalize_mean

        if normalize_std is None:
            normalize_std = torch.FloatTensor([1] * channels)
        self.normalize_std = normalize_std

        self.param = param
        self.param.requires_grad_()

        self.device = device

        self.optimizer = None

    def forward(self, device):
        """Returns the tensor the model gets, processed and normalized
        with the right values: self.normalize(self.postprocess(self.param)).

        For batch_size > 1, the normalized series is repeated along the
        batch dimension.
        """
        if self.batch_size == 1:
            return self.normalize(self.postprocess(device=device), device=device)
        else:
            return torch.cat(
                [
                    self.normalize(self.postprocess(device=device), device=device)
                    for _ in range(self.batch_size)
                ],
                dim=0,
            )

    def postprocess(self, device):
        """Transforms the series parameter from the frequency domain to
        the time domain.

        Raises:
            NotImplementedError: must be implemented by subclasses.
        """
        raise NotImplementedError

    def normalize(self, x, device='cuda'):
        """Normalizing wrapper: subtracts the per-channel mean and divides
        by the per-channel standard deviation."""
        return (
            (x - self.normalize_mean[..., None].to(device))
            / self.normalize_std[..., None].to(device)
        )

    def denormalize(self, x, device='cuda'):
        """Denormalizing wrapper: the inverse of `normalize`."""
        return (
            x * self.normalize_std[..., None].to(device)
            + self.normalize_mean[..., None].to(device)
        )

    def fetch_optimizer(self, params_list, optimizer=None, lr=1e-3, weight_decay=0.0):
        if optimizer is not None:
            optimizer = optimizer(params_list, lr=lr, weight_decay=weight_decay)
        else:
            optimizer = torch.optim.AdamW(params_list, lr=lr, weight_decay=weight_decay)
        return optimizer

    def get_optimizer(self, lr, weight_decay):
        self.optimizer = self.fetch_optimizer(
            params_list=[self.param], lr=lr, weight_decay=weight_decay
        )

    def clip_grads(self, grad_clip=1.0):
        return torch.nn.utils.clip_grad_norm_(self.param, grad_clip)

    def to_cl_tensor(self, device="cpu"):
        """Return a CL series tensor (channels, length).

        Args:
            device (str): The device to operate on ('cpu' or 'cuda').

        Returns:
            torch.Tensor
        """
        t = self.forward(device=device)[0].detach()
        return t

    def to_lc_tensor(self, device="cpu"):
        """Return an LC series tensor (length, channels).

        Args:
            device (str): The device to operate on ('cpu' or 'cuda').

        Returns:
            torch.Tensor
        """
        t = self.forward(device=device)[0].permute(1, 0).detach()
        return t

    def __array__(self):
        """Converts the series parameter to a CL numpy array, enabling
        numpy interop such as np.array(param) and matplotlib plotting.

        Returns:
            numpy.ndarray
        """
        return self.to_cl_tensor().numpy()

    def save(self, filename):
        """Save the series parameter as a CL tensor with torch.save.

        usage:

            series_param.save(filename='my_series.pt')

        Args:
            filename (str): path to write the tensor to, e.g. 'my_series.pt'
        """
        tensor = self.to_cl_tensor()
        torch.save(tensor, filename)
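
The `normalize`/`denormalize` pair broadcasts the per-channel statistics over the length dimension via `[..., None]`. A small round-trip sketch with a hypothetical subclass (only `postprocess` is filled in, since the base class leaves it abstract):

```python
import torch
from torch_dreams.base_series_param import BaseSeriesParam


class IdentitySeriesParam(BaseSeriesParam):
    """Hypothetical subclass: postprocess returns the raw parameter."""
    def postprocess(self, device):
        return self.param.to(device)


x = torch.rand(2, 100)  # (channels, length)
p = IdentitySeriesParam(
    batch_size=1,
    channels=2,
    length=100,
    param=x.clone(),
    normalize_mean=torch.tensor([0.5, 0.5]),
    normalize_std=torch.tensor([0.25, 0.25]),
    device="cpu",
)

y = p.normalize(x, device="cpu")    # (x - mean[..., None]) / std[..., None]
z = p.denormalize(y, device="cpu")  # inverse: recovers x up to float error
assert torch.allclose(x, z, atol=1e-6)
```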
129 changes: 129 additions & 0 deletions torch_dreams/custom_series_param.py
@@ -0,0 +1,129 @@
from .base_series_param import BaseSeriesParam

import numpy as np
import torch

from .utils import (
    lucid_colorspace_to_rgb,
    normalize,
    get_fft_scale_custom_series,
    cl_series_to_fft_param,
    fft_to_series
)


class CustomSeriesParam(BaseSeriesParam):
    """FFT parameterization for custom series.

    Works well with:
    * lower learning rates (3e-4)
    * gradient clipping (0.1)
    * weight decay (1e-1)

    Args:
        series (torch.Tensor): input tensor with shape [batch_size, channels, length].
        device (str): 'cuda' or 'cpu'

    Example:
    ```
    series = torch.ones((1, 2, 100))
    param = CustomSeriesParam(series=series, device='cuda')

    result = dreamy_boi.render(
        image_parameter=param,
        layers=[model.Mixed_6c],
        lr=3e-4,
        grad_clip=0.1,
        weight_decay=1e-1
    )
    ```
    """
    def __init__(
        self,
        series,
        device,
        # channel_correlation_matrix,
        normalize_mean=None,
        normalize_std=None,
    ):
        batch_size = series.shape[0]
        channels = series.shape[1]
        length = series.shape[2]

        super().__init__(
            batch_size=batch_size,
            channels=channels,
            length=length,
            param=series,  # overwritten by set_param below
            normalize_mean=normalize_mean,
            normalize_std=normalize_std,
            device=device,
        )

        channel_correlation_matrix = get_normalized_correlation_matrix(channels)

        self.set_param(series, channel_correlation_matrix, device=device)

    def postprocess(self, device):
        out = fft_to_series(
            channels=self.channels,
            length=self.length,
            series_parameter=self.param,
            device=device,
        )
        out = lucid_colorspace_to_rgb(t=out, device=device).clamp(0, 1)
        return out

    def set_param(self, tensor, channel_correlation_matrix, device):
        """Sets an NCL tensor as the parameter in the frequency domain,
        useful for transforming custom series between iterations.

        Use in combination with `self.to_cl_tensor()` like:

        ```
        a = self.to_cl_tensor()
        # do something with a
        t = transforms.Compose([
            transforms.RandomScale(0.5, 1.2)
        ])
        a = t(a)
        # set as parameter again
        self.set_param(a.unsqueeze(0), channel_correlation_matrix, device)
        ```

        WARNING: tensor should have values clipped between 0 and 1.

        Args:
            tensor (torch.Tensor): input tensor with shape [1, channels, length] and values clipped between 0 and 1.
        """
        assert len(tensor.shape) == 3
        assert tensor.shape[0] == 1

        self.tensor = tensor

        batch_size = tensor.shape[0]
        channels = tensor.shape[1]
        length = tensor.shape[2]

        scale = get_fft_scale_custom_series(length=length, device=device)
        # TODO: denormalize
        # fft_param = cl_series_to_fft_param(self.denormalize(tensor.squeeze(0)), device=device)
        fft_param = cl_series_to_fft_param(tensor, channel_correlation_matrix=channel_correlation_matrix, device=device)
        self.param = fft_param / scale

        self.param.requires_grad_()

        self.batch_size = batch_size
        self.channels = channels
        self.length = length
        self.device = device


def get_normalized_correlation_matrix(channels):
    # TODO: these values should be passed by the user
    correlation_svd_sqrt = np.random.rand(channels, channels).astype(np.float32)

    max_norm_svd_sqrt = np.max(np.linalg.norm(correlation_svd_sqrt, axis=0))
    correlation_normalized = torch.tensor(correlation_svd_sqrt / max_norm_svd_sqrt)
    return correlation_normalized
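
Putting `to_cl_tensor` and `set_param` together, a sketch of the transform-and-reset loop the `set_param` docstring describes; the `clamp` stands in for an arbitrary series transform, and the correlation matrix is the random placeholder produced by `get_normalized_correlation_matrix` above:

```python
import torch
from torch_dreams.custom_series_param import (
    CustomSeriesParam,
    get_normalized_correlation_matrix,
)

device = "cpu"

# (batch, channels, length), values in [0, 1] as set_param expects
series = torch.rand(1, 2, 100)
param = CustomSeriesParam(series=series, device=device)

# Pull the current series out of the frequency domain...
a = param.to_cl_tensor(device=device)  # (channels, length)
a = a.clamp(0, 1)                      # stand-in for a real transform

# ...and push the transformed series back in as the new FFT parameter.
ccm = get_normalized_correlation_matrix(channels=2)
param.set_param(a.unsqueeze(0), channel_correlation_matrix=ccm, device=device)
```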