Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
guoxj11 committed Jul 28, 2024
1 parent 109ce04 commit f596132
Show file tree
Hide file tree
Showing 12 changed files with 1,333 additions and 0 deletions.
37 changes: 37 additions & 0 deletions config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Run-time configuration. Values here override the defaults declared in
# config/config.py; every key must already exist there.
VERBOSE: True

MODEL:
  SESSION: 'de_highlight'
  INPUT: 'specular'
  TARGET: 'diffuse'

# Optimization arguments.
OPTIM:
  BATCH_SIZE: 8
  NUM_EPOCHS: 100
  # Learning rates are written with an explicit decimal point so that YAML 1.1
  # loaders (e.g. PyYAML) resolve them as floats instead of plain strings.
  LR_INITIAL: 2.0e-4
  LR_MIN: 1.0e-6
  SEED: 3407
  WANDB: False

TRAINING:
  VAL_AFTER_EVERY: 1
  RESUME: False
  WEIGHT: ''
  PS_W: 256
  PS_H: 256
  TRAIN_DIR: '' # path to training data
  VAL_DIR: '' # path to validation data
  SAVE_DIR: '' # path to save models and images
  ORI: False
  LOG_FILE: ''

TESTING:
  WEIGHT: ''
  TEST_DIR: '' # path to testing data
  SAVE_IMAGES: True
  RESULT_DIR: ''
  LOG_FILE: ''

LOG:
  LOG_DIR: ''
1 change: 1 addition & 0 deletions config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .config import Config
119 changes: 119 additions & 0 deletions config/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 23 14:35:48 2019
@author: aditya
"""

r"""This module provides package-wide configuration management."""
from typing import Any, List, Optional

from yacs.config import CfgNode as CN


class Config(object):
    r"""
    A collection of all the required configuration parameters. This class is a nested dict-like
    structure, with nested keys accessible as attributes. It contains default values for all the
    parameters, which may be overridden (first) through a YAML file and (second) through a list
    of attributes and values.

    Extended Summary
    ----------------
    The configuration is frozen at the end of ``__init__``, so modification of any parameter
    after instantiating this class is not possible. Override the required parameter values
    through either the ``config_yaml`` file or the ``config_override`` list.

    Parameters
    ----------
    config_yaml: str
        Path to a YAML file containing configuration parameters to override.
    config_override: List[Any], optional (default= None)
        A flat list of alternating dotted keys and values to override. This happens after
        overriding from the YAML file. ``None`` means "no overrides".

    Examples
    --------
    Let a YAML file named "config.yml" specify these parameters to override::

        OPTIM:
          BATCH_SIZE: 8

    >>> _C = Config("config.yml", ["OPTIM.NUM_EPOCHS", 10, "TESTING.SAVE_IMAGES", True])
    >>> _C.OPTIM.BATCH_SIZE  # default: 1
    8
    >>> _C.OPTIM.NUM_EPOCHS  # default: 100
    10
    >>> _C.TESTING.SAVE_IMAGES  # default: False
    True
    """

    def __init__(self, config_yaml: str, config_override: Optional[List[Any]] = None):
        self._C = CN()
        self._C.GPU = [0]
        self._C.VERBOSE = False

        self._C.MODEL = CN()
        self._C.MODEL.SESSION = 'de_highlight'
        self._C.MODEL.INPUT = 'input'
        self._C.MODEL.TARGET = 'target'

        self._C.OPTIM = CN()
        self._C.OPTIM.BATCH_SIZE = 1
        self._C.OPTIM.SEED = 3407
        self._C.OPTIM.NUM_EPOCHS = 100
        self._C.OPTIM.NEPOCH_DECAY = [50]
        self._C.OPTIM.LR_INITIAL = 0.0002
        self._C.OPTIM.LR_MIN = 0.0002
        self._C.OPTIM.BETA1 = 0.5
        self._C.OPTIM.WANDB = False

        self._C.TRAINING = CN()
        self._C.TRAINING.VAL_AFTER_EVERY = 1
        self._C.TRAINING.RESUME = False
        self._C.TRAINING.TRAIN_DIR = '../dataset/train'
        self._C.TRAINING.VAL_DIR = '../dataset/val'
        self._C.TRAINING.SAVE_DIR = 'checkpoints'
        self._C.TRAINING.PS_W = 256
        self._C.TRAINING.PS_H = 256
        self._C.TRAINING.ORI = False
        self._C.TRAINING.LOG_FILE = 'log.txt'
        self._C.TRAINING.WEIGHT = './checkpoints/model_epoch_68.pth'

        self._C.TESTING = CN()
        self._C.TESTING.WEIGHT = './checkpoints/model_epoch_68.pth'
        self._C.TESTING.SAVE_IMAGES = False
        self._C.TESTING.LOG_FILE = 'log.txt'
        self._C.TESTING.TEST_DIR = '../dataset/test'
        self._C.TESTING.RESULT_DIR = '../result'

        self._C.LOG = CN()
        self._C.LOG.LOG_DIR = 'output_dir'

        # Override parameter values from YAML file first, then from override list.
        self._C.merge_from_file(config_yaml)
        # A ``None`` default replaces the former shared mutable ``[]`` default.
        self._C.merge_from_list([] if config_override is None else config_override)

        # Make an instantiated object of this class immutable.
        self._C.freeze()

    def dump(self, file_path: str):
        r"""Save config at the specified file path.

        Parameters
        ----------
        file_path: str
            (YAML) path to save config at.
        """
        # The context manager guarantees the file is flushed and closed even if dumping fails.
        with open(file_path, "w") as stream:
            self._C.dump(stream=stream)

    def __getattr__(self, attr: str):
        # ``__getattr__`` only runs when normal lookup fails. If ``_C`` itself is missing
        # (e.g. the instance was created without ``__init__``), looking it up below would
        # re-enter this method forever, so fail fast with the conventional error instead.
        if attr == "_C":
            raise AttributeError(attr)
        return self._C.__getattr__(attr)

    def __repr__(self):
        return self._C.__repr__()
1 change: 1 addition & 0 deletions data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .data_RGB import get_training_data, get_validation_data
12 changes: 12 additions & 0 deletions data/data_RGB.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import os
from .dataset_RGB import DataLoaderTrain, DataLoaderVal


def get_training_data(rgb_dir, inp, target, img_options):
    """Build the training dataset rooted at ``rgb_dir``.

    Parameters
    ----------
    rgb_dir:
        Root directory of the training data; must exist.
    inp, target, img_options:
        Forwarded unchanged to ``DataLoaderTrain``.

    Returns
    -------
    DataLoaderTrain
        The dataset constructed from the given arguments.

    Raises
    ------
    FileNotFoundError
        If ``rgb_dir`` does not exist.
    """
    # Explicit raise instead of ``assert`` so the check is not stripped under ``python -O``.
    if not os.path.exists(rgb_dir):
        raise FileNotFoundError('Training data directory not found: {!r}'.format(rgb_dir))
    return DataLoaderTrain(rgb_dir, inp, target, img_options)


def get_validation_data(rgb_dir, inp, target, img_options):
    """Build the validation dataset rooted at ``rgb_dir``.

    Parameters
    ----------
    rgb_dir:
        Root directory of the validation data; must exist.
    inp, target, img_options:
        Forwarded unchanged to ``DataLoaderVal``.

    Returns
    -------
    DataLoaderVal
        The dataset constructed from the given arguments.

    Raises
    ------
    FileNotFoundError
        If ``rgb_dir`` does not exist.
    """
    # Explicit raise instead of ``assert`` so the check is not stripped under ``python -O``.
    if not os.path.exists(rgb_dir):
        raise FileNotFoundError('Validation data directory not found: {!r}'.format(rgb_dir))
    return DataLoaderVal(rgb_dir, inp, target, img_options)
158 changes: 158 additions & 0 deletions data/dataset_RGB.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import os
import random
import albumentations as A
import numpy as np
import torchvision.transforms.functional as F
from PIL import Image
from torch.utils.data import Dataset


def is_image_file(filename):
    """Tell whether ``filename`` ends with one of the recognised image suffixes.

    The check is a plain, case-sensitive suffix comparison against the listed
    spellings; no leading dot is required for a match.
    """
    recognised_suffixes = ('jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif')
    # ``str.endswith`` accepts a tuple and returns True if any suffix matches.
    return filename.endswith(recognised_suffixes)


class DataLoaderTrain(Dataset):
    """Paired (input, target) image dataset with training-time augmentation.

    Images are read from ``<rgb_dir>/<inp>`` and ``<rgb_dir>/<target>``; the two
    sorted file lists are paired by position.

    Parameters
    ----------
    rgb_dir:
        Root directory containing the ``inp`` and ``target`` sub-directories.
    inp, target:
        Names of the sub-directories holding input and target images.
    img_options:
        Mapping read for the keys ``'h'`` and ``'w'`` (crop height and width).
        Required in practice despite the ``None`` default, because it is
        indexed during construction.

    Raises
    ------
    ValueError
        If the number of input images differs from the number of target images.
    """

    def __init__(self, rgb_dir, inp='input', target='target', img_options=None):
        super().__init__()

        inp_files = sorted(os.listdir(os.path.join(rgb_dir, inp)))
        tar_files = sorted(os.listdir(os.path.join(rgb_dir, target)))

        self.inp_filenames = [os.path.join(rgb_dir, inp, x) for x in inp_files if is_image_file(x)]
        self.tar_filenames = [os.path.join(rgb_dir, target, x) for x in tar_files if is_image_file(x)]

        # Pairs are matched by position, so unequal counts would either misalign
        # samples silently or fail later with an IndexError.
        if len(self.inp_filenames) != len(self.tar_filenames):
            raise ValueError(
                'Number of input images ({}) does not match number of target images ({}) in {!r}'.format(
                    len(self.inp_filenames), len(self.tar_filenames), rgb_dir))

        self.img_options = img_options
        self.sizex = len(self.tar_filenames)  # get the size of target

        # The same random transform is applied to both images of a pair via
        # ``additional_targets``.
        self.transform = A.Compose(
            [
                A.Flip(p=0.3),
                A.RandomRotate90(p=0.3),
                A.Rotate(p=0.3),
                A.Transpose(p=0.3),
                A.RandomResizedCrop(height=img_options['h'], width=img_options['w']),
            ],
            additional_targets={
                'target': 'image',
            },
        )

    @staticmethod
    def _load_rgb(path):
        """Read the image at ``path`` as an RGB array, closing the file afterwards."""
        with Image.open(path) as img:
            return np.array(img.convert('RGB'))

    @staticmethod
    def _cutmix_box(lam, img_h, img_w):
        """Draw a random box ``(x0, x1, y0, y1)`` clipped to an ``img_h`` x ``img_w`` image.

        The box is centred on a uniformly drawn point and has side lengths scaled
        by ``sqrt(1 - lam)`` before clipping.
        """
        cx = np.random.uniform(0, img_w)
        cy = np.random.uniform(0, img_h)

        w = img_w * np.sqrt(1 - lam)
        h = img_h * np.sqrt(1 - lam)

        x0 = int(np.round(max(cx - w / 2, 0)))
        x1 = int(np.round(min(cx + w / 2, img_w)))
        y0 = int(np.round(max(cy - h / 2, 0)))
        y1 = int(np.round(min(cy + h / 2, img_h)))

        return x0, x1, y0, y1

    def mixup(self, inp_img, tar_img, mode='mixup'):
        """Blend a sample with another randomly chosen, augmented sample.

        Parameters
        ----------
        inp_img, tar_img:
            Tensors of the current sample (channel-first, as indexed below).
        mode:
            ``'mixup'`` blends the two samples with weight ``lam``; ``'cutmix'``
            pastes a random box from the other sample in place. Any other value
            leaves the sample unchanged.

        Returns
        -------
        tuple
            The resulting ``(inp_img, tar_img)`` pair.
        """
        mixup_index_ = random.randint(0, self.sizex - 1)

        mixup_inp_img = self._load_rgb(self.inp_filenames[mixup_index_])
        mixup_tar_img = self._load_rgb(self.tar_filenames[mixup_index_])

        transformed = self.transform(image=mixup_inp_img, target=mixup_tar_img)

        alpha = 0.2
        lam = np.random.beta(alpha, alpha)

        mixup_inp_img = F.to_tensor(transformed['image'])
        mixup_tar_img = F.to_tensor(transformed['target'])

        if mode == 'mixup':
            inp_img = lam * inp_img + (1 - lam) * mixup_inp_img
            tar_img = lam * tar_img + (1 - lam) * mixup_tar_img
        elif mode == 'cutmix':
            x0, x1, y0, y1 = self._cutmix_box(lam, self.img_options['h'], self.img_options['w'])

            # Same region is replaced in both images so the pair stays aligned.
            inp_img[:, y0:y1, x0:x1] = mixup_inp_img[:, y0:y1, x0:x1]
            tar_img[:, y0:y1, x0:x1] = mixup_tar_img[:, y0:y1, x0:x1]

        return inp_img, tar_img

    def __len__(self):
        """Return the number of image pairs."""
        return self.sizex

    def __getitem__(self, index):
        """Return ``(inp_img, tar_img, filename)`` for the pair at ``index``.

        ``index`` wraps around modulo the dataset size. ``filename`` is the base
        name of the target image.
        """
        index_ = index % self.sizex

        inp_path = self.inp_filenames[index_]
        tar_path = self.tar_filenames[index_]

        inp_img = self._load_rgb(inp_path)
        tar_img = self._load_rgb(tar_path)

        transformed = self.transform(image=inp_img, target=tar_img)

        inp_img = F.to_tensor(transformed['image'])
        tar_img = F.to_tensor(transformed['target'])

        # Every positive index divisible by 3 is additionally mixed with another
        # sample, choosing between the two modes with equal probability.
        if index_ > 0 and index_ % 3 == 0:
            if random.random() > 0.5:
                inp_img, tar_img = self.mixup(inp_img, tar_img, mode='mixup')
            else:
                inp_img, tar_img = self.mixup(inp_img, tar_img, mode='cutmix')

        filename = os.path.basename(tar_path)

        return inp_img, tar_img, filename


class DataLoaderVal(Dataset):
    """Paired (input, target) image dataset for validation.

    Images are read from ``<rgb_dir>/<inp>`` and ``<rgb_dir>/<target>``; the two
    sorted file lists are paired by position.

    Parameters
    ----------
    rgb_dir:
        Root directory containing the ``inp`` and ``target`` sub-directories.
    inp, target:
        Names of the sub-directories holding input and target images.
    img_options:
        Mapping read for the keys ``'h'`` and ``'w'`` (resize height and width)
        and ``'ori'`` (when truthy, images are returned without resizing).
        Required in practice despite the ``None`` default, because it is
        indexed during construction.

    Raises
    ------
    ValueError
        If the number of input images differs from the number of target images.
    """

    def __init__(self, rgb_dir, inp='input', target='target', img_options=None):
        super().__init__()

        inp_files = sorted(os.listdir(os.path.join(rgb_dir, inp)))
        tar_files = sorted(os.listdir(os.path.join(rgb_dir, target)))

        self.inp_filenames = [os.path.join(rgb_dir, inp, x) for x in inp_files if is_image_file(x)]
        self.tar_filenames = [os.path.join(rgb_dir, target, x) for x in tar_files if is_image_file(x)]

        # Pairs are matched by position, so unequal counts would either misalign
        # samples silently or fail later with an IndexError.
        if len(self.inp_filenames) != len(self.tar_filenames):
            raise ValueError(
                'Number of input images ({}) does not match number of target images ({}) in {!r}'.format(
                    len(self.inp_filenames), len(self.tar_filenames), rgb_dir))

        self.img_options = img_options
        self.sizex = len(self.tar_filenames)  # get the size of target

        # The same resize is applied to both images of a pair via
        # ``additional_targets``.
        self.transform = A.Compose(
            [
                A.Resize(height=img_options['h'], width=img_options['w']),
            ],
            additional_targets={
                'target': 'image',
            },
        )

    @staticmethod
    def _load_rgb(path):
        """Read the image at ``path`` as an RGB array, closing the file afterwards."""
        with Image.open(path) as img:
            return np.array(img.convert('RGB'))

    def __len__(self):
        """Return the number of image pairs."""
        return self.sizex

    def __getitem__(self, index):
        """Return ``(inp_img, tar_img, filename)`` for the pair at ``index``.

        ``index`` wraps around modulo the dataset size. ``filename`` is the base
        name of the target image.
        """
        index_ = index % self.sizex

        inp_path = self.inp_filenames[index_]
        tar_path = self.tar_filenames[index_]

        inp_img = self._load_rgb(inp_path)
        tar_img = self._load_rgb(tar_path)

        # Resize unless the caller asked for images at their original size.
        if not self.img_options['ori']:
            transformed = self.transform(image=inp_img, target=tar_img)

            inp_img = transformed['image']
            tar_img = transformed['target']

        inp_img = F.to_tensor(inp_img)
        tar_img = F.to_tensor(tar_img)

        filename = os.path.basename(tar_path)

        return inp_img, tar_img, filename
1 change: 1 addition & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .model import Model
Loading

0 comments on commit f596132

Please sign in to comment.