class Config(object):
    r"""
    Package-wide configuration exposed as a frozen, nested, attribute-accessible tree.

    Default values for every parameter are defined in ``__init__``. They may be
    overridden (first) through a YAML file and (second) through a flat list of
    keys and values. The merged configuration is frozen at the end of
    ``__init__``, so any parameter that must differ from its default has to be
    supplied through ``config_yaml`` or ``config_override``.

    Parameters
    ----------
    config_yaml: str
        Path to a YAML file containing configuration parameters to override.
    config_override: List[Any], optional (default= None)
        A flat list ``[key, value, key, value, ...]`` of parameters to override.
        Nested keys are written with dots (``"OPTIM.BATCH_SIZE"``). This happens
        after overriding from the YAML file. ``None`` means "no overrides".

    Examples
    --------
    Let a YAML file named "config.yml" specify these parameters to override::

        OPTIM:
          BATCH_SIZE: 8
          LR_MIN: 1e-6

    >>> _C = Config("config.yml", ["OPTIM.BATCH_SIZE", 16, "TRAINING.PS_W", 512])
    >>> _C.OPTIM.BATCH_SIZE  # default: 1, YAML: 8, override list: 16
    16
    >>> _C.OPTIM.LR_MIN  # default: 0.0002, YAML: 1e-6
    1e-06
    >>> _C.TRAINING.PS_W  # default: 256
    512
    """

    # ``None`` replaces the former mutable ``[]`` default; the annotation is kept
    # as ``List[Any]`` so that no additional typing import is required.
    def __init__(self, config_yaml: str, config_override: List[Any] = None):
        self._C = CN()
        self._C.GPU = [0]
        self._C.VERBOSE = False

        self._C.MODEL = CN()
        self._C.MODEL.SESSION = 'de_highlight'
        self._C.MODEL.INPUT = 'input'
        self._C.MODEL.TARGET = 'target'

        self._C.OPTIM = CN()
        self._C.OPTIM.BATCH_SIZE = 1
        self._C.OPTIM.SEED = 3407
        self._C.OPTIM.NUM_EPOCHS = 100
        self._C.OPTIM.NEPOCH_DECAY = [50]
        self._C.OPTIM.LR_INITIAL = 0.0002
        # NOTE(review): this default equals LR_INITIAL; the shipped config.yml
        # overrides it with 1e-6.
        self._C.OPTIM.LR_MIN = 0.0002
        self._C.OPTIM.BETA1 = 0.5
        self._C.OPTIM.WANDB = False

        self._C.TRAINING = CN()
        self._C.TRAINING.VAL_AFTER_EVERY = 1
        self._C.TRAINING.RESUME = False
        self._C.TRAINING.TRAIN_DIR = '../dataset/train'
        self._C.TRAINING.VAL_DIR = '../dataset/val'
        self._C.TRAINING.SAVE_DIR = 'checkpoints'
        self._C.TRAINING.PS_W = 256
        self._C.TRAINING.PS_H = 256
        self._C.TRAINING.ORI = False
        self._C.TRAINING.LOG_FILE = 'log.txt'
        self._C.TRAINING.WEIGHT = './checkpoints/model_epoch_68.pth'

        self._C.TESTING = CN()
        self._C.TESTING.WEIGHT = './checkpoints/model_epoch_68.pth'
        self._C.TESTING.SAVE_IMAGES = False
        self._C.TESTING.LOG_FILE = 'log.txt'
        self._C.TESTING.TEST_DIR = '../dataset/test'
        self._C.TESTING.RESULT_DIR = '../result'

        self._C.LOG = CN()
        self._C.LOG.LOG_DIR = 'output_dir'

        # Override parameter values from YAML file first, then from override list.
        self._C.merge_from_file(config_yaml)
        self._C.merge_from_list([] if config_override is None else list(config_override))

        # Make an instantiated object of this class immutable.
        self._C.freeze()

    def dump(self, file_path: str):
        r"""Save config at the specified file path.

        Parameters
        ----------
        file_path: str
            (YAML) path to save config at.
        """
        # The context manager guarantees the file is flushed and closed even
        # if serialisation fails part-way through.
        with open(file_path, "w") as stream:
            self._C.dump(stream=stream)

    def __getattr__(self, attr: str):
        # ``__getattr__`` only runs when normal lookup fails. If ``_C`` itself is
        # missing (e.g. while the object is being copied or unpickled, before
        # ``__init__`` has run), delegating to ``self._C`` would recurse forever.
        if attr == "_C":
            raise AttributeError(attr)
        return self._C.__getattr__(attr)

    def __repr__(self):
        return self._C.__repr__()
def mixup(self, inp_img, tar_img, mode='mixup'):
    """Blend the given (input, target) pair with a second, randomly chosen pair.

    Method of ``DataLoaderTrain``. A partner sample is drawn from the dataset,
    loaded, passed through ``self.transform`` and converted to tensors; it is
    then combined with ``inp_img`` / ``tar_img`` according to ``mode``.

    Parameters
    ----------
    inp_img, tar_img:
        Tensors of the current sample. They are indexed as
        ``[:, y0:y1, x0:x1]`` below, i.e. three axes with rows on axis 1 and
        columns on axis 2 (presumably (C, H, W) as produced by ``to_tensor``
        -- TODO confirm).
    mode: str
        ``'mixup'``  -- convex combination ``lam * a + (1 - lam) * b``.
        ``'cutmix'`` -- paste a rectangular region of the partner sample.
        Any other value leaves the pair unchanged (the partner is still
        loaded and the random draws are still consumed).

    Returns
    -------
    tuple
        ``(inp_img, tar_img)`` after blending.
    """
    # random.randint is inclusive on both ends, so the partner may be any
    # sample of the dataset, including the current one.
    mixup_index_ = random.randint(0, self.sizex - 1)

    mixup_inp_path = self.inp_filenames[mixup_index_]
    mixup_tar_path = self.tar_filenames[mixup_index_]

    mixup_inp_img = Image.open(mixup_inp_path).convert('RGB')
    mixup_tar_img = Image.open(mixup_tar_path).convert('RGB')

    mixup_inp_img = np.array(mixup_inp_img)
    mixup_tar_img = np.array(mixup_tar_img)

    # Input and target go through the same transform call so that both
    # receive the same augmentation.
    transformed = self.transform(image=mixup_inp_img, target=mixup_tar_img)

    # Mixing coefficient drawn from Beta(0.2, 0.2).
    alpha = 0.2
    lam = np.random.beta(alpha, alpha)

    mixup_inp_img = F.to_tensor(transformed['image'])
    mixup_tar_img = F.to_tensor(transformed['target'])

    if mode == 'mixup':
        inp_img = lam * inp_img + (1 - lam) * mixup_inp_img
        tar_img = lam * tar_img + (1 - lam) * mixup_tar_img
    elif mode == 'cutmix':
        img_h, img_w = self.img_options['h'], self.img_options['w']

        # Box centre, uniform over the patch.
        cx = np.random.uniform(0, img_w)
        cy = np.random.uniform(0, img_h)

        # Box sides scale with sqrt(1 - lam), so the (unclipped) box covers a
        # fraction (1 - lam) of the patch area.
        w = img_w * np.sqrt(1 - lam)
        h = img_h * np.sqrt(1 - lam)

        # Clip the box to the patch bounds.
        x0 = int(np.round(max(cx - w / 2, 0)))
        x1 = int(np.round(min(cx + w / 2, img_w)))
        y0 = int(np.round(max(cy - h / 2, 0)))
        y1 = int(np.round(min(cy + h / 2, img_h)))

        # In-place paste: the tensors passed in by the caller are modified.
        inp_img[:, y0:y1, x0:x1] = mixup_inp_img[:, y0:y1, x0:x1]
        tar_img[:, y0:y1, x0:x1] = mixup_tar_img[:, y0:y1, x0:x1]

    return inp_img, tar_img
class Mapping(nn.Module):
    """Small MLP applied to rows of ``inp.shape[1]`` values taken from the input.

    Layer stack: ``Linear -> ReLU``, then ``hidden_layers`` times
    ``Linear -> Tanh``, then a final ``Linear`` (followed by ``Sigmoid`` when
    ``res`` is False).

    Parameters:
        in_features (int): Size of each row fed to the first linear layer.
        hidden_features (int): Width of the hidden layers.
        hidden_layers (int): Number of hidden ``Linear -> Tanh`` pairs.
        out_features (int): Size of each output row. The output is reshaped
            back to the input shape, so this has to equal ``in_features``.
        res (bool): If True, add the flattened input to the MLP output and
            clamp the sum to [0, 1]; if False, end the stack with a sigmoid.
    """

    def __init__(self, in_features=3, hidden_features=256, hidden_layers=3, out_features=3, res=True):
        super().__init__()
        self.res = res

        # Modules are created in the same order as they appear in the
        # sequence, which fixes both the state-dict keys and the order in
        # which the random initialiser is consumed.
        layers = [nn.Linear(in_features, hidden_features), nn.ReLU()]
        for _ in range(hidden_layers):
            layers += [nn.Linear(hidden_features, hidden_features), nn.Tanh()]
        layers.append(nn.Linear(hidden_features, out_features))
        if not res:
            layers.append(torch.nn.Sigmoid())

        self.net = nn.Sequential(*layers)

    def forward(self, inp):
        shape = inp.shape
        # NOTE(review): ``view`` does not permute axes. For a contiguous
        # (B, C, H, W) tensor every row is C consecutive elements in memory,
        # not the C-channel vector of a single pixel -- confirm this is the
        # intended grouping.
        rows = inp.view(-1, shape[1])

        mapped = self.net(rows)
        if self.res:
            mapped = torch.clamp(mapped + rows, 0., 1.)

        return mapped.view(shape)
def window_partition(x, window_size: int, h, w):
    """Pad a channels-first tensor and cut it into square tiles.

    Args:
        x: tensor of shape (B, C, h, w)
        window_size (int): tile side length (M)
        h, w (int): spatial size of ``x``; used to compute the padding

    Returns:
        windows: tensor of shape (B * nH * nW, C, M, M), where nH and nW are
        the numbers of tiles along the padded height and width.

    ``x`` is zero-padded on the right and bottom up to the next multiple of
    ``window_size``.

    NOTE(review): the channel axis is not moved behind the tile-grid axes
    before the final ``view``. The memory order at that point is
    (B, C, nH, nW, M, M), so the C entries along dim 1 of the result are C
    consecutive tiles in that order rather than the C channels of one spatial
    tile. ``window_reverse`` inverts exactly this layout. It is left untouched
    because changing it would alter the output of existing checkpoints --
    confirm whether it is intended.
    """
    pad_r = (window_size - w % window_size) % window_size
    pad_b = (window_size - h % window_size) % window_size
    # F.pad order for the last two dims is (left, right, top, bottom).
    x = F.pad(x, [0, pad_r, 0, pad_b])
    B, C, H, W = x.shape
    x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
    windows = x.permute(0, 1, 2, 4, 3, 5).contiguous().view(-1, C, window_size, window_size)
    return windows


def window_reverse(windows, window_size: int, H: int, W: int):
    """Inverse of ``window_partition``: reassemble tiles and strip the padding.

    Args:
        windows: tensor of shape (B * nH * nW, C, M, M)
        window_size (int): tile side length (M)
        H (int): height of the original (unpadded) image
        W (int): width of the original (unpadded) image

    Returns:
        x: tensor of shape (B, C, H, W)
    """
    pad_r = (window_size - W % window_size) % window_size
    pad_b = (window_size - H % window_size) % window_size
    H = H + pad_b
    W = W + pad_r
    # Integer arithmetic: the padded size is an exact multiple of the tile
    # size, so no float division / truncation is needed to recover B.
    tiles_per_image = (H // window_size) * (W // window_size)
    B = windows.shape[0] // tiles_per_image
    x = windows.view(B, -1, H // window_size, W // window_size, window_size, window_size)
    x = x.permute(0, 1, 2, 4, 3, 5).contiguous().view(B, -1, H, W)
    # Negative padding crops the rows/columns added by window_partition.
    windows = F.pad(x, [0, -pad_r, 0, -pad_b])
    return windows
class FeedForward(nn.Module):
    """Gated convolutional feed-forward block.

    A 1x1 convolution expands ``dim`` channels to ``2 * (3 * dim)``, a 3x3
    depthwise convolution mixes each channel spatially, the result is split
    into two halves along the channel axis, and ``relu(first) * second`` is
    projected back to ``dim`` channels by a 1x1 convolution.

    Args:
        dim: number of input and output channels.
        bias: whether the three convolutions use a bias term.
    """

    def __init__(self, dim, bias):
        super().__init__()

        gate_width = int(dim * 3)      # channels of each half
        expanded = gate_width * 2      # both halves together

        self.project_in = nn.Conv2d(dim, expanded, kernel_size=1, bias=bias)
        # groups == channels makes this a depthwise convolution.
        self.dwconv = nn.Conv2d(expanded, expanded, kernel_size=3, stride=1, padding=1,
                                groups=expanded, bias=bias)
        self.project_out = nn.Conv2d(gate_width, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        features = self.dwconv(self.project_in(x))
        gate, value = torch.chunk(features, 2, dim=1)
        return self.project_out(F.relu(gate) * value)
def create_mask(self, x):
    """Build the additive attention mask for the shifted-window pass.

    Method of ``P_SSSWA``. Only ``x.shape`` and ``x.device`` are read from
    ``x``; its values are not used.

    A label image of size (Hp, Wp) -- H and W rounded up to a multiple of
    ``self.window_size`` -- is divided into 3 x 3 regions by the slices
    below, and each region is filled with its own label. The label image is
    cut into windows with ``self.window_partitions``; inside each window,
    every pair of positions whose labels differ gets -100 and every pair
    with equal labels gets 0.

    Returns
    -------
    attn_mask: tensor of shape [nW, Mh*Mw, Mh*Mw], on ``x.device``.
    """

    n, c, H, W = x.shape
    # Round the spatial size up to the next multiple of the window size.
    Hp = int(np.ceil(H / self.window_size)) * self.window_size
    Wp = int(np.ceil(W / self.window_size)) * self.window_size
    img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]
    # Three bands per axis: everything before the last window, the last
    # window minus its final ``shift_size`` rows/cols, and the final
    # ``shift_size`` rows/cols.
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    # Label the regions 0..8; a later assignment overwrites an earlier one
    # wherever slices happen to overlap.
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    mask_windows = self.window_partitions(img_mask, self.window_size)  # [nW, Mh, Mw, 1]
    mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
    # Pairwise label difference inside each window (broadcast subtraction).
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
    # [nW, Mh*Mw, Mh*Mw]
    # Different labels -> -100 (added to the logits before softmax by the
    # caller); same label -> 0.
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask
# Pixel-wise Spatial-Spectral Shifting Window Attention (P_SSSWA) Transformer
class P_SSSWATransformer(nn.Module):
    """Transformer block built around ``P_SSSWA`` attention.

    ``forward`` computes a frequency-processed copy of the input with
    ``FrequencyProcessor``, feeds the normalised input and the normalised
    frequency features to the attention module, adds the attention output to
    the input, and finishes with a residual feed-forward step.

    Args:
        dim: number of feature channels.
        window_size: window side length passed to ``P_SSSWA``.
        shift_size: window shift passed to ``P_SSSWA``.
        bias: whether the convolutions in the sub-modules use a bias.
    """

    def __init__(self, dim, window_size, shift_size, bias):
        super().__init__()
        # Sub-modules are created in this exact order; it determines the
        # state-dict layout and the order of random initialisation.
        self.norm1 = LayerNorm(dim)
        self.attn = P_SSSWA(dim, window_size, shift_size, bias)
        self.norm2 = LayerNorm(dim)
        self.ffn = FeedForward(dim, bias)

        self.fp = FrequencyProcessor(channels=dim)
        self.norm3 = LayerNorm(dim)

    def forward(self, x):
        freq = self.fp(x)
        spatial_in = self.norm1(x)
        freq_in = self.norm3(freq)
        attended = self.attn(spatial_in, freq_in)

        # Pad the attention output so it matches the spatial size of ``x``
        # (a negative difference crops instead).
        gap_h = x.shape[2] - attended.shape[2]
        gap_w = x.shape[3] - attended.shape[3]
        left, top = gap_w // 2, gap_h // 2
        attended = F.pad(attended, [left, gap_w - left, top, gap_h - top])

        x = x + attended
        return x + self.ffn(self.norm2(x))
# Channel-wise Spatial-Spectral Shifting Window Attention (C_SSSWA) Transformer
class C_SSSWATransformer(nn.Module):
    """Transformer block built around ``C_SSSWA`` attention.

    Same wiring as the pixel-wise variant: the input and a
    ``FrequencyProcessor`` copy of it are normalised separately and handed to
    the attention module; the result is added to the input and followed by a
    residual feed-forward step.

    Args:
        dim: number of feature channels.
        window_size: window side length passed to ``C_SSSWA``.
        shift_size: window shift passed to ``C_SSSWA``.
        bias: whether the convolutions in the sub-modules use a bias.
    """

    def __init__(self, dim, window_size, shift_size, bias):
        super().__init__()
        # Creation order is kept as-is: it fixes state-dict layout and the
        # order of random initialisation.
        self.norm1 = LayerNorm(dim)
        self.attn = C_SSSWA(dim, window_size, shift_size, bias)
        self.norm2 = LayerNorm(dim)
        self.ffn = FeedForward(dim, bias)

        self.fp = FrequencyProcessor(channels=dim)
        self.norm3 = LayerNorm(dim)

    def forward(self, x):
        spectral = self.fp(x)
        normed_x = self.norm1(x)
        normed_spectral = self.norm3(spectral)
        y = self.attn(normed_x, normed_spectral)

        # Bring ``y`` to the spatial size of ``x``: positive differences pad,
        # negative differences crop.
        _, _, x_h, x_w = x.shape
        _, _, y_h, y_w = y.shape
        d_h, d_w = x_h - y_h, x_w - y_w
        y = F.pad(y, [d_w // 2, d_w - d_w // 2,
                      d_h // 2, d_h - d_h // 2])

        residual = x + y
        return residual + self.ffn(self.norm2(residual))
# Channel-Wise Contextual Attention Transformer
class CCATransformer(nn.Module):
    """Pre-norm transformer block using ``ChannelAttention``.

    ``x -> x + attn(norm1(x)) -> (+ ffn(norm2(.)))``, with the attention
    output padded (or cropped) to the spatial size of ``x`` before the first
    residual addition.

    Args:
        dim: number of feature channels.
        num_heads: number of attention heads passed to ``ChannelAttention``.
        bias: whether the convolutions in the sub-modules use a bias.
    """

    def __init__(self, dim, num_heads, bias):
        super().__init__()

        # Creation order is kept as-is: it fixes state-dict layout and the
        # order of random initialisation.
        self.norm1 = LayerNorm(dim)
        self.attn = ChannelAttention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim)
        self.ffn = FeedForward(dim, bias)

    def forward(self, x):
        attn_out = self.attn(self.norm1(x))

        # Spatial size difference between the input and the attention output.
        delta_rows = x.shape[2] - attn_out.shape[2]
        delta_cols = x.shape[3] - attn_out.shape[3]
        pad_spec = [
            delta_cols // 2, delta_cols - delta_cols // 2,   # left, right
            delta_rows // 2, delta_rows - delta_rows // 2,   # top, bottom
        ]
        attn_out = F.pad(attn_out, pad_spec)

        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x
class OverlapPatchEmbed(nn.Module):
    """Embed an image with a single 3x3 convolution.

    The convolution uses stride 1 and padding 1, so the spatial size of the
    input is preserved and only the channel count changes from ``in_c`` to
    ``embed_dim``.

    Args:
        in_c: number of input channels.
        embed_dim: number of output (embedding) channels.
        bias: whether the convolution has a bias term.
    """

    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super().__init__()

        self.proj = nn.Conv2d(
            in_channels=in_c,
            out_channels=embed_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
        )

    def forward(self, x):
        return self.proj(x)
def cat(x1, x2):
    """Concatenate ``x2`` and ``x1`` along the channel axis after size alignment.

    ``x1`` is padded symmetrically (the extra row/column, if any, goes to the
    bottom/right) until its last two dimensions match those of ``x2``; a
    negative size difference crops ``x1`` instead. The result places the
    channels of ``x2`` first, followed by the aligned ``x1``.
    """
    gap_h = x2.shape[2] - x1.shape[2]
    gap_w = x2.shape[3] - x1.shape[3]

    left, top = gap_w // 2, gap_h // 2
    # F.pad order for the last two dims: (left, right, top, bottom).
    aligned = F.pad(x1, [left, gap_w - left, top, gap_h - top])

    return torch.cat([x2, aligned], dim=1)
shift_size=shift_size) for i in range(num_blocks[3])]) + + self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias) + self.decoder_level3 = nn.Sequential(*[ + CCATransformer(dim=int(dim * 2 ** 2), num_heads=heads[2], bias=bias) for i in range(num_blocks[2])]) + + self.up3_2 = Upsample(int(dim * 2 ** 2)) + self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias) + self.decoder_level2 = nn.Sequential(*[ + CCATransformer(dim=int(dim * 2 ** 1), num_heads=heads[1], bias=bias) for i in range(num_blocks[1])]) + + self.up2_1 = Upsample(int(dim * 2 ** 1)) + self.reduce_chan_level1 = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias) + + self.decoder_level1 = nn.Sequential(*[ + Adaptive_LHD_TransformerBlock(dim=int(dim), num_heads=heads[0], bias=bias, window_size=window_size, + shift_size=shift_size) for i in range(num_blocks[0])]) + + self.reduce_chan_ref = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias) + self.refinement = nn.Sequential(*[ + CCATransformer(dim=int(dim), num_heads=heads[0], bias=bias) for i in range(num_refinement_blocks)]) + + self.output = nn.Conv2d(int(dim), out_channels, kernel_size=1, bias=bias) + + def forward(self, inp_img): + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + + inp_enc_level2 = self.down1_2(out_enc_level1) + out_enc_level2 = self.encoder_level2(inp_enc_level2) + + inp_enc_level3 = self.down2_3(out_enc_level2) + out_enc_level3 = self.encoder_level3(inp_enc_level3) + + latent = self.bottleneck(out_enc_level3) + + inp_dec_level3 = cat(latent, out_enc_level3) + inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) + out_dec_level3 = self.decoder_level3(inp_dec_level3) + + inp_dec_level2 = self.up3_2(out_dec_level3) + inp_dec_level2 = cat(inp_dec_level2, out_enc_level2) + inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) + out_dec_level2 = self.decoder_level2(inp_dec_level2) 
+ + inp_dec_level1 = self.up2_1(out_dec_level2) + inp_dec_level1 = cat(inp_dec_level1, out_enc_level1) + inp_dec_level1 = self.reduce_chan_level1(inp_dec_level1) + out_dec_level1 = self.decoder_level1(inp_dec_level1) + + ref_out = self.refinement(out_dec_level1) + + out = self.output(ref_out) + inp_img + + return out + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.sfp = Processor() + + def forward(self, x): + out = self.sfp(x) + return out + + +if __name__ == '__main__': + from thop import profile, clever_format + + t = torch.randn(1, 3, 256, 256).cuda() + model = Model().cuda() + macs, params = profile(model, inputs=(t,)) + macs, params = clever_format([macs, params], "%.3f") + print(macs, params) \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000..adf571b --- /dev/null +++ b/test.py @@ -0,0 +1,80 @@ +import json +import warnings + +from accelerate import Accelerator +from torch.utils.data import DataLoader +from torchmetrics.functional import peak_signal_noise_ratio, structural_similarity_index_measure +from torchmetrics.functional.regression import mean_absolute_error +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +from torchvision.utils import save_image +from tqdm import tqdm + +from config import Config +from data import get_validation_data +from models import * +from utils import * + +warnings.filterwarnings('ignore') + +opt = Config('config.yml') + +seed_everything(opt.OPTIM.SEED) + + +def test(): + accelerator = Accelerator() + + # Data Loader + val_dir = opt.TESTING.TEST_DIR + + criterion_lpips = LearnedPerceptualImagePatchSimilarity(net_type='alex', normalize=True).cuda() + + val_dataset = get_validation_data(val_dir, opt.MODEL.INPUT, opt.MODEL.TARGET, {'w': opt.TRAINING.PS_W, 'h': opt.TRAINING.PS_H, 'ori': opt.TRAINING.ORI}) + testloader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False, num_workers=8, drop_last=False, 
pin_memory=True) + + # Model & Metrics + model = Model() + + load_checkpoint(model, opt.TESTING.WEIGHT) + + model, testloader = accelerator.prepare(model, testloader) + + model.eval() + + size = len(testloader) + stat_psnr = 0 + stat_ssim = 0 + stat_mae = 0 + stat_lpips = 0 + for _, test_data in enumerate(tqdm(testloader)): + # get the inputs; data is a list of [targets, inputs, filename] + inp = test_data[0].contiguous() + tar = test_data[1] + + with torch.no_grad(): + res = model(inp).clamp(0, 1) + + save_image(res, os.path.join(opt.TESTING.RESULT_DIR, test_data[2][0])) + + stat_psnr += peak_signal_noise_ratio(res, tar, data_range=1) + stat_ssim += structural_similarity_index_measure(res, tar, data_range=1) + stat_mae += mean_absolute_error(torch.mul(res, 255), torch.mul(tar, 255)) + stat_lpips += criterion_lpips(res, tar).item() + + stat_psnr /= size + stat_ssim /= size + stat_mae /= size + stat_lpips /= size + + test_info = ("Test Result on {}, check point {}, testing data {}". + format(opt.MODEL.SESSION, opt.TESTING.WEIGHT, opt.TESTING.TEST_DIR)) + log_stats = ("PSNR: {}, SSIM: {}, MAE: {}, LPIPS: {}".format(stat_psnr, stat_ssim, stat_mae, stat_lpips)) + print(test_info) + print(log_stats) + with open(os.path.join(opt.LOG.LOG_DIR, opt.TESTING.LOG_FILE), mode='a', encoding='utf-8') as f: + f.write(json.dumps(test_info) + '\n') + f.write(json.dumps(log_stats) + '\n') + + +if __name__ == '__main__': + test() diff --git a/train.py b/train.py new file mode 100644 index 0000000..905fce4 --- /dev/null +++ b/train.py @@ -0,0 +1,139 @@ +import json +import time +import warnings + +import torch.optim as optim +from accelerate import Accelerator + +from torch.utils.data import DataLoader +from torchmetrics.functional import peak_signal_noise_ratio, structural_similarity_index_measure +from torchmetrics.functional.regression import mean_absolute_error +from tqdm import tqdm + +from config import Config +from data import get_training_data, get_validation_data +from models 
import * +from utils import * + +warnings.filterwarnings('ignore') + +opt = Config('config.yml') + +seed_everything(opt.OPTIM.SEED) + +def train(): + # Accelerate + accelerator = Accelerator(log_with='wandb') if opt.OPTIM.WANDB else Accelerator() + device = accelerator.device + config = { + "dataset": opt.TRAINING.TRAIN_DIR + } + accelerator.init_trackers("Highlight", config=config) + + if accelerator.is_local_main_process: + os.makedirs(opt.TRAINING.SAVE_DIR, exist_ok=True) + + # Data Loader + train_dir = opt.TRAINING.TRAIN_DIR + val_dir = opt.TRAINING.VAL_DIR + + train_dataset = get_training_data(train_dir, opt.MODEL.INPUT, opt.MODEL.TARGET, + {'w': opt.TRAINING.PS_W, 'h': opt.TRAINING.PS_H}) + trainloader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=16, + drop_last=False, pin_memory=True) + val_dataset = get_validation_data(val_dir, opt.MODEL.INPUT, opt.MODEL.TARGET, + {'w': opt.TRAINING.PS_W, 'h': opt.TRAINING.PS_H, 'ori': opt.TRAINING.ORI}) + testloader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False, num_workers=8, drop_last=False, + pin_memory=True) + + # Model & Loss + model = Model() + + # criterion_ssim = SSIM(data_range=1, size_average=True, channel=3).to(device) + criterion_psnr = torch.nn.MSELoss() + + # Optimizer & Scheduler + optimizer_b = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=opt.OPTIM.LR_INITIAL, + betas=(0.9, 0.999), eps=1e-8) + scheduler_b = optim.lr_scheduler.CosineAnnealingLR(optimizer_b, opt.OPTIM.NUM_EPOCHS, eta_min=opt.OPTIM.LR_MIN) + + start_epoch = 1 + + trainloader, testloader = accelerator.prepare(trainloader, testloader) + model = accelerator.prepare(model) + optimizer_b, scheduler_b = accelerator.prepare(optimizer_b, scheduler_b) + + best_epoch = 1 + best_psnr = 0 + size = len(testloader) + # training + for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1): + model.train() + for _, data in enumerate(tqdm(trainloader, disable=not 
accelerator.is_local_main_process)): + inp = data[0].contiguous() + tar = data[1] + + # forward + optimizer_b.zero_grad() + res = model(inp) + + loss_psnr = criterion_psnr(res, tar) + loss_ssim = 1 - structural_similarity_index_measure(res, tar, data_range=1) + + train_loss = loss_psnr + 0.4 * loss_ssim + + # backward + accelerator.backward(train_loss) + optimizer_b.step() + + scheduler_b.step() + + # testing + if epoch % opt.TRAINING.VAL_AFTER_EVERY == 0: + model.eval() + psnr = 0 + ssim = 0 + mae = 0 + for _, test_data in enumerate(tqdm(testloader, disable=not accelerator.is_local_main_process)): + # get the inputs; data is a list of [targets, inputs, filename] + inp = test_data[0].contiguous() + tar = test_data[1] + + with torch.no_grad(): + res = model(inp) + + res, tar = accelerator.gather((res, tar)) + + psnr += peak_signal_noise_ratio(res, tar, data_range=1) + ssim += structural_similarity_index_measure(res, tar, data_range=1) + mae += mean_absolute_error(torch.mul(res, 255), torch.mul(tar, 255)) + + psnr /= size + ssim /= size + mae /= size + + if psnr > best_psnr: + # save model + best_epoch = epoch + best_psnr = psnr + save_checkpoint({ + 'state_dict': model.state_dict(), + }, epoch, opt.MODEL.SESSION, opt.TRAINING.SAVE_DIR) + + accelerator.log({ + "PSNR": psnr, + "SSIM": ssim, + "MAE": mae + }, step=epoch) + + log_stats = ("epoch: {}, PSNR: {}, SSIM: {}, MAE: {}, best PSNR: {}, best epoch: {}" + .format(epoch, psnr, ssim, mae, best_psnr, best_epoch)) + print(log_stats) + with open(os.path.join(opt.LOG.LOG_DIR, opt.TRAINING.LOG_FILE), mode='a', encoding='utf-8') as f: + f.write(json.dumps(log_stats) + '\n') + + accelerator.end_training() + + +if __name__ == '__main__': + train() diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..16281fe --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1 @@ +from .utils import * diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000..11e64fc --- /dev/null +++ 
# utils/utils.py -- seeding and checkpoint helpers.
import os
import random
from collections import OrderedDict

import numpy as np
import torch


def seed_everything(seed: int = 3407) -> None:
    """Seed every random number generator this project touches.

    Sets ``PYTHONHASHSEED`` and seeds ``random``, NumPy and PyTorch (CPU and
    all CUDA devices), then switches cuDNN to deterministic mode with
    autotuning disabled.

    Parameters
    ----------
    seed : int
        Seed value shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every CUDA device (including the current one),
    # so a separate torch.cuda.manual_seed call is not needed.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def save_checkpoint(state, epoch, model_name, outdir) -> None:
    """Serialize ``state`` to ``<outdir>/<model_name>_epoch_<epoch>.pth``.

    Parameters
    ----------
    state
        Object handed to ``torch.save`` (the training script passes a dict
        with a ``'state_dict'`` entry).
    epoch
        Epoch identifier embedded in the file name via ``str(epoch)``.
    model_name : str
        File name prefix.
    outdir : str
        Destination directory, created if missing. An empty string means the
        current working directory.
    """
    # exist_ok avoids the check-then-create race when several processes save
    # at once; the guard skips os.makedirs(''), which would raise.
    if outdir:
        os.makedirs(outdir, exist_ok=True)
    checkpoint_file = os.path.join(outdir, '{}_epoch_{}.pth'.format(model_name, epoch))
    torch.save(state, checkpoint_file)


def load_checkpoint(model, weights) -> None:
    """Load the ``'state_dict'`` entry of a checkpoint file into ``model``.

    Keys carrying the ``'module.'`` prefix have that prefix removed before
    loading; all other keys are used unchanged.

    Parameters
    ----------
    model : torch.nn.Module
        Module whose parameters are overwritten in place.
    weights : str
        Path to a file written by ``torch.save`` that contains a mapping with
        a ``'state_dict'`` key.

    Raises
    ------
    KeyError
        If the checkpoint has no ``'state_dict'`` entry.
    RuntimeError
        If the keys do not match ``model`` (raised by ``load_state_dict``).
    """
    # Deserialize onto the CPU: load_state_dict copies the values into the
    # model's existing parameters on whatever device they live on, so this
    # works on CPU-only hosts and does not pin every process to GPU 0.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(weights, map_location='cpu')

    prefix = 'module.'
    new_state_dict = OrderedDict()
    for key, value in checkpoint['state_dict'].items():
        # Match the full 'module.' prefix so that a legitimate key such as
        # 'module_a.weight' is not truncated.
        name = key[len(prefix):] if key.startswith(prefix) else key
        new_state_dict[name] = value
    model.load_state_dict(new_state_dict)