-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
1,333 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
VERBOSE: True | ||
|
||
MODEL: | ||
SESSION: 'de_highlight' | ||
INPUT: 'specular' | ||
TARGET: 'diffuse' | ||
|
||
# Optimization arguments. | ||
OPTIM: | ||
BATCH_SIZE: 8 | ||
NUM_EPOCHS: 100 | ||
LR_INITIAL: 2e-4 | ||
LR_MIN: 1e-6 | ||
SEED: 3407 | ||
WANDB: False | ||
|
||
TRAINING: | ||
VAL_AFTER_EVERY: 1 | ||
RESUME: False | ||
WEIGHT: '' | ||
PS_W: 256 | ||
PS_H: 256 | ||
TRAIN_DIR: '' # path to training data | ||
VAL_DIR: '' # path to validation data | ||
SAVE_DIR: '' # path to save models and images | ||
ORI: False | ||
LOG_FILE: '' | ||
|
||
TESTING: | ||
WEIGHT: '' | ||
TEST_DIR: '' # path to testing data | ||
SAVE_IMAGES: True | ||
RESULT_DIR: '' | ||
LOG_FILE: '' | ||
|
||
LOG: | ||
LOG_DIR: '' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .config import Config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Created on Tue Jul 23 14:35:48 2019 | ||
@author: aditya | ||
""" | ||
|
||
r"""This module provides package-wide configuration management.""" | ||
from typing import Any, List | ||
|
||
from yacs.config import CfgNode as CN | ||
|
||
|
||
class Config(object): | ||
r""" | ||
A collection of all the required configuration parameters. This class is a nested dict-like | ||
structure, with nested keys accessible as attributes. It contains sensible default values for | ||
all the parameters, which may be overriden by (first) through a YAML file and (second) through | ||
a list of attributes and values. | ||
Extended Summary | ||
---------------- | ||
This class definition contains default values corresponding to ``joint_training`` phase, as it | ||
is the final training phase and uses almost all the configuration parameters. Modification of | ||
any parameter after instantiating this class is not possible, so you must override required | ||
parameter values in either through ``config_yaml`` file or ``config_override`` list. | ||
Parameters | ||
---------- | ||
config_yaml: str | ||
Path to a YAML file containing configuration parameters to override. | ||
config_override: List[Any], optional (default= []) | ||
A list of sequential attributes and values of parameters to override. This happens after | ||
overriding from YAML file. | ||
Examples | ||
-------- | ||
Let a YAML file named "config.yaml" specify these parameters to override:: | ||
ALPHA: 1000.0 | ||
BETA: 0.5 | ||
>>> _C = Config("config.yaml", ["OPTIM.BATCH_SIZE", 2048, "BETA", 0.7]) | ||
>>> _C.ALPHA # default: 100.0 | ||
1000.0 | ||
>>> _C.BATCH_SIZE # default: 256 | ||
2048 | ||
>>> _C.BETA # default: 0.1 | ||
0.7 | ||
Attributes | ||
---------- | ||
""" | ||
|
||
def __init__(self, config_yaml: str, config_override: List[Any] = []): | ||
self._C = CN() | ||
self._C.GPU = [0] | ||
self._C.VERBOSE = False | ||
|
||
self._C.MODEL = CN() | ||
self._C.MODEL.SESSION = 'de_highlight' | ||
self._C.MODEL.INPUT = 'input' | ||
self._C.MODEL.TARGET = 'target' | ||
|
||
self._C.OPTIM = CN() | ||
self._C.OPTIM.BATCH_SIZE = 1 | ||
self._C.OPTIM.SEED = 3407 | ||
self._C.OPTIM.NUM_EPOCHS = 100 | ||
self._C.OPTIM.NEPOCH_DECAY = [50] | ||
self._C.OPTIM.LR_INITIAL = 0.0002 | ||
self._C.OPTIM.LR_MIN = 0.0002 | ||
self._C.OPTIM.BETA1 = 0.5 | ||
self._C.OPTIM.WANDB = False | ||
|
||
self._C.TRAINING = CN() | ||
self._C.TRAINING.VAL_AFTER_EVERY = 1 | ||
self._C.TRAINING.RESUME = False | ||
self._C.TRAINING.TRAIN_DIR = '../dataset/train' | ||
self._C.TRAINING.VAL_DIR = '../dataset/val' | ||
self._C.TRAINING.SAVE_DIR = 'checkpoints' | ||
self._C.TRAINING.PS_W = 256 | ||
self._C.TRAINING.PS_H = 256 | ||
self._C.TRAINING.ORI = False | ||
self._C.TRAINING.LOG_FILE = 'log.txt' | ||
self._C.TRAINING.WEIGHT = './checkpoints/model_epoch_68.pth' | ||
|
||
self._C.TESTING = CN() | ||
self._C.TESTING.WEIGHT = './checkpoints/model_epoch_68.pth' | ||
self._C.TESTING.SAVE_IMAGES = False | ||
self._C.TESTING.LOG_FILE = 'log.txt' | ||
self._C.TESTING.TEST_DIR = '../dataset/test' | ||
self._C.TESTING.RESULT_DIR = '../result' | ||
|
||
self._C.LOG = CN() | ||
self._C.LOG.LOG_DIR = 'output_dir' | ||
|
||
# Override parameter values from YAML file first, then from override list. | ||
self._C.merge_from_file(config_yaml) | ||
self._C.merge_from_list(config_override) | ||
|
||
# Make an instantiated object of this class immutable. | ||
self._C.freeze() | ||
|
||
def dump(self, file_path: str): | ||
r"""Save config at the specified file path. | ||
Parameters | ||
---------- | ||
file_path: str | ||
(YAML) path to save config at. | ||
""" | ||
self._C.dump(stream=open(file_path, "w")) | ||
|
||
def __getattr__(self, attr: str): | ||
return self._C.__getattr__(attr) | ||
|
||
def __repr__(self): | ||
return self._C.__repr__() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .data_RGB import get_training_data, get_validation_data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import os | ||
from .dataset_RGB import DataLoaderTrain, DataLoaderVal | ||
|
||
|
||
def get_training_data(rgb_dir, inp, target, img_options): | ||
assert os.path.exists(rgb_dir) | ||
return DataLoaderTrain(rgb_dir, inp, target, img_options) | ||
|
||
|
||
def get_validation_data(rgb_dir, inp, target, img_options): | ||
assert os.path.exists(rgb_dir) | ||
return DataLoaderVal(rgb_dir, inp, target, img_options) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
import os | ||
import random | ||
import albumentations as A | ||
import numpy as np | ||
import torchvision.transforms.functional as F | ||
from PIL import Image | ||
from torch.utils.data import Dataset | ||
|
||
|
||
def is_image_file(filename): | ||
return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif']) | ||
|
||
|
||
class DataLoaderTrain(Dataset): | ||
def __init__(self, rgb_dir, inp='input', target='target', img_options=None): | ||
super(DataLoaderTrain, self).__init__() | ||
|
||
inp_files = sorted(os.listdir(os.path.join(rgb_dir, inp))) | ||
tar_files = sorted(os.listdir(os.path.join(rgb_dir, target))) | ||
|
||
self.inp_filenames = [os.path.join(rgb_dir, inp, x) for x in inp_files if is_image_file(x)] | ||
self.tar_filenames = [os.path.join(rgb_dir, target, x) for x in tar_files if is_image_file(x)] | ||
|
||
self.img_options = img_options | ||
self.sizex = len(self.tar_filenames) # get the size of target | ||
|
||
self.transform = A.Compose([ | ||
A.Flip(p=0.3), | ||
A.RandomRotate90(p=0.3), | ||
A.Rotate(p=0.3), | ||
A.Transpose(p=0.3), | ||
A.RandomResizedCrop(height=img_options['h'], width=img_options['w']), | ||
], | ||
additional_targets={ | ||
'target': 'image', | ||
} | ||
) | ||
|
||
def mixup(self, inp_img, tar_img, mode='mixup'): | ||
mixup_index_ = random.randint(0, self.sizex - 1) | ||
|
||
mixup_inp_path = self.inp_filenames[mixup_index_] | ||
mixup_tar_path = self.tar_filenames[mixup_index_] | ||
|
||
mixup_inp_img = Image.open(mixup_inp_path).convert('RGB') | ||
mixup_tar_img = Image.open(mixup_tar_path).convert('RGB') | ||
|
||
mixup_inp_img = np.array(mixup_inp_img) | ||
mixup_tar_img = np.array(mixup_tar_img) | ||
|
||
transformed = self.transform(image=mixup_inp_img, target=mixup_tar_img) | ||
|
||
alpha = 0.2 | ||
lam = np.random.beta(alpha, alpha) | ||
|
||
mixup_inp_img = F.to_tensor(transformed['image']) | ||
mixup_tar_img = F.to_tensor(transformed['target']) | ||
|
||
if mode == 'mixup': | ||
inp_img = lam * inp_img + (1 - lam) * mixup_inp_img | ||
tar_img = lam * tar_img + (1 - lam) * mixup_tar_img | ||
elif mode == 'cutmix': | ||
img_h, img_w = self.img_options['h'], self.img_options['w'] | ||
|
||
cx = np.random.uniform(0, img_w) | ||
cy = np.random.uniform(0, img_h) | ||
|
||
w = img_w * np.sqrt(1 - lam) | ||
h = img_h * np.sqrt(1 - lam) | ||
|
||
x0 = int(np.round(max(cx - w / 2, 0))) | ||
x1 = int(np.round(min(cx + w / 2, img_w))) | ||
y0 = int(np.round(max(cy - h / 2, 0))) | ||
y1 = int(np.round(min(cy + h / 2, img_h))) | ||
|
||
inp_img[:, y0:y1, x0:x1] = mixup_inp_img[:, y0:y1, x0:x1] | ||
tar_img[:, y0:y1, x0:x1] = mixup_tar_img[:, y0:y1, x0:x1] | ||
|
||
return inp_img, tar_img | ||
|
||
def __len__(self): | ||
return self.sizex | ||
|
||
def __getitem__(self, index): | ||
index_ = index % self.sizex | ||
|
||
inp_path = self.inp_filenames[index_] | ||
tar_path = self.tar_filenames[index_] | ||
|
||
inp_img = Image.open(inp_path).convert('RGB') | ||
tar_img = Image.open(tar_path).convert('RGB') | ||
|
||
inp_img = np.array(inp_img) | ||
tar_img = np.array(tar_img) | ||
|
||
transformed = self.transform(image=inp_img, target=tar_img) | ||
|
||
inp_img = F.to_tensor(transformed['image']) | ||
tar_img = F.to_tensor(transformed['target']) | ||
|
||
if index_ > 0 and index_ % 3 == 0: | ||
if random.random() > 0.5: | ||
inp_img, tar_img = self.mixup(inp_img, tar_img, mode='mixup') | ||
else: | ||
inp_img, tar_img = self.mixup(inp_img, tar_img, mode='cutmix') | ||
|
||
filename = os.path.basename(tar_path) | ||
|
||
return inp_img, tar_img, filename | ||
|
||
|
||
class DataLoaderVal(Dataset): | ||
def __init__(self, rgb_dir, inp='input', target='target', img_options=None): | ||
super(DataLoaderVal, self).__init__() | ||
|
||
inp_files = sorted(os.listdir(os.path.join(rgb_dir, inp))) | ||
tar_files = sorted(os.listdir(os.path.join(rgb_dir, target))) | ||
|
||
self.inp_filenames = [os.path.join(rgb_dir, inp, x) for x in inp_files if is_image_file(x)] | ||
self.tar_filenames = [os.path.join(rgb_dir, target, x) for x in tar_files if is_image_file(x)] | ||
|
||
self.img_options = img_options | ||
self.sizex = len(self.tar_filenames) # get the size of target | ||
|
||
self.transform = A.Compose([ | ||
A.Resize(height=img_options['h'], width=img_options['w']), ], | ||
additional_targets={ | ||
'target': 'image', | ||
} | ||
) | ||
|
||
def __len__(self): | ||
return self.sizex | ||
|
||
def __getitem__(self, index): | ||
index_ = index % self.sizex | ||
|
||
inp_path = self.inp_filenames[index_] | ||
tar_path = self.tar_filenames[index_] | ||
|
||
inp_img = Image.open(inp_path).convert('RGB') | ||
tar_img = Image.open(tar_path).convert('RGB') | ||
|
||
inp_img = np.array(inp_img) | ||
tar_img = np.array(tar_img) | ||
|
||
if not self.img_options['ori']: | ||
transformed = self.transform(image=inp_img, target=tar_img) | ||
|
||
inp_img = transformed['image'] | ||
tar_img = transformed['target'] | ||
|
||
inp_img = F.to_tensor(inp_img) | ||
tar_img = F.to_tensor(tar_img) | ||
|
||
filename = os.path.basename(tar_path) | ||
|
||
return inp_img, tar_img, filename |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .model import Model |
Oops, something went wrong.