-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fac3a10
commit fd3e108
Showing
15 changed files
with
37,131 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
VERBOSE: True | ||
|
||
MODEL: | ||
SESSION: 'LOL' | ||
INPUT: 'input' | ||
TARGET: 'target' | ||
|
||
# Optimization arguments. | ||
OPTIM: | ||
BATCH_SIZE: 4 | ||
NUM_EPOCHS: 100 | ||
LR_INITIAL: 2e-4 | ||
LR_MIN: 1e-6 | ||
SEED: 3407 | ||
WANDB: False | ||
|
||
TRAINING: | ||
VAL_AFTER_EVERY: 1 | ||
RESUME: False | ||
PS_W: 128 | ||
PS_H: 128 | ||
TRAIN_DIR: '../dataset/VigSet/train/' # path to training data | ||
VAL_DIR: '../dataset/VigSet/test/' # path to validation data | ||
SAVE_DIR: './checkpoints/' # path to save models and images | ||
ORI: True | ||
|
||
TESTING: | ||
WEIGHT: './checkpoints/best.pth' | ||
SAVE_IMAGES: True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .config import Config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Created on Tue Jul 23 14:35:48 2019 | ||
@author: aditya | ||
""" | ||
|
||
r"""This module provides package-wide configuration management.""" | ||
from typing import Any, List | ||
|
||
from yacs.config import CfgNode as CN | ||
|
||
|
||
class Config(object): | ||
r""" | ||
A collection of all the required configuration parameters. This class is a nested dict-like | ||
structure, with nested keys accessible as attributes. It contains sensible default values for | ||
all the parameters, which may be overriden by (first) through a YAML file and (second) through | ||
a list of attributes and values. | ||
Extended Summary | ||
---------------- | ||
This class definition contains default values corresponding to ``joint_training`` phase, as it | ||
is the final training phase and uses almost all the configuration parameters. Modification of | ||
any parameter after instantiating this class is not possible, so you must override required | ||
parameter values in either through ``config_yaml`` file or ``config_override`` list. | ||
Parameters | ||
---------- | ||
config_yaml: str | ||
Path to a YAML file containing configuration parameters to override. | ||
config_override: List[Any], optional (default= []) | ||
A list of sequential attributes and values of parameters to override. This happens after | ||
overriding from YAML file. | ||
Examples | ||
-------- | ||
Let a YAML file named "config.yaml" specify these parameters to override:: | ||
ALPHA: 1000.0 | ||
BETA: 0.5 | ||
>>> _C = Config("config.yaml", ["OPTIM.BATCH_SIZE", 2048, "BETA", 0.7]) | ||
>>> _C.ALPHA # default: 100.0 | ||
1000.0 | ||
>>> _C.BATCH_SIZE # default: 256 | ||
2048 | ||
>>> _C.BETA # default: 0.1 | ||
0.7 | ||
Attributes | ||
---------- | ||
""" | ||
|
||
def __init__(self, config_yaml: str, config_override: List[Any] = []): | ||
self._C = CN() | ||
self._C.GPU = [0] | ||
self._C.VERBOSE = False | ||
|
||
self._C.MODEL = CN() | ||
self._C.MODEL.SESSION = 'MRI-CT' | ||
self._C.MODEL.INPUT = 'MRI' | ||
self._C.MODEL.TARGET = 'CT' | ||
|
||
self._C.OPTIM = CN() | ||
self._C.OPTIM.BATCH_SIZE = 1 | ||
self._C.OPTIM.SEED = 3407 | ||
self._C.OPTIM.NUM_EPOCHS = 200 | ||
self._C.OPTIM.NEPOCH_DECAY = [100] | ||
self._C.OPTIM.LR_INITIAL = 0.0002 | ||
self._C.OPTIM.LR_MIN = 0.0002 | ||
self._C.OPTIM.BETA1 = 0.5 | ||
self._C.OPTIM.WANDB = False | ||
|
||
self._C.TRAINING = CN() | ||
self._C.TRAINING.VAL_AFTER_EVERY = 1 | ||
self._C.TRAINING.RESUME = False | ||
self._C.TRAINING.TRAIN_DIR = '../dataset/MRI-CT/train' | ||
self._C.TRAINING.VAL_DIR = '../dataset/MRI-CT/test' | ||
self._C.TRAINING.SAVE_DIR = 'checkpoints' | ||
self._C.TRAINING.PS_W = 256 | ||
self._C.TRAINING.PS_H = 256 | ||
self._C.TRAINING.ORI = False | ||
|
||
self._C.TESTING = CN() | ||
self._C.TESTING.WEIGHT = './checkpoints/MRI-PET_epoch_68.pth' | ||
self._C.TESTING.SAVE_IMAGES = False | ||
|
||
# Override parameter values from YAML file first, then from override list. | ||
self._C.merge_from_file(config_yaml) | ||
self._C.merge_from_list(config_override) | ||
|
||
# Make an instantiated object of this class immutable. | ||
self._C.freeze() | ||
|
||
def dump(self, file_path: str): | ||
r"""Save config at the specified file path. | ||
Parameters | ||
---------- | ||
file_path: str | ||
(YAML) path to save config at. | ||
""" | ||
self._C.dump(stream=open(file_path, "w")) | ||
|
||
def __getattr__(self, attr: str): | ||
return self._C.__getattr__(attr) | ||
|
||
def __repr__(self): | ||
return self._C.__repr__() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .data_RGB import get_training_data, get_validation_data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import os | ||
from .dataset_RGB import DataLoaderTrain, DataLoaderVal | ||
|
||
|
||
def get_training_data(rgb_dir, inp, target, img_options): | ||
assert os.path.exists(rgb_dir) | ||
return DataLoaderTrain(rgb_dir, inp, target, img_options) | ||
|
||
|
||
def get_validation_data(rgb_dir, inp, target, img_options): | ||
assert os.path.exists(rgb_dir) | ||
return DataLoaderVal(rgb_dir, inp, target, img_options) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
import os | ||
import random | ||
import albumentations as A | ||
import numpy as np | ||
import torchvision.transforms.functional as F | ||
from PIL import Image | ||
from torch.utils.data import Dataset | ||
|
||
|
||
def is_image_file(filename): | ||
return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif']) | ||
|
||
|
||
class DataLoaderTrain(Dataset): | ||
def __init__(self, rgb_dir, inp='input', target='target', img_options=None): | ||
super(DataLoaderTrain, self).__init__() | ||
|
||
inp_files = sorted(os.listdir(os.path.join(rgb_dir, inp))) | ||
tar_files = sorted(os.listdir(os.path.join(rgb_dir, target))) | ||
|
||
self.inp_filenames = [os.path.join(rgb_dir, inp, x) for x in inp_files if is_image_file(x)] | ||
self.tar_filenames = [os.path.join(rgb_dir, target, x) for x in tar_files if is_image_file(x)] | ||
|
||
self.img_options = img_options | ||
self.sizex = len(self.tar_filenames) # get the size of target | ||
|
||
self.transform = A.Compose([ | ||
A.Flip(p=0.3), | ||
A.Affine(p=0.3), | ||
A.RandomResizedCrop(height=img_options['h'], width=img_options['w']), | ||
], | ||
additional_targets={ | ||
'target': 'image', | ||
} | ||
) | ||
|
||
def mixup(self, inp_img, tar_img, mode='mixup'): | ||
mixup_index_ = random.randint(0, self.sizex - 1) | ||
|
||
mixup_inp_path = self.inp_filenames[mixup_index_] | ||
mixup_tar_path = self.tar_filenames[mixup_index_] | ||
|
||
mixup_inp_img = Image.open(mixup_inp_path).convert('RGB') | ||
mixup_tar_img = Image.open(mixup_tar_path).convert('RGB') | ||
|
||
mixup_inp_img = np.array(mixup_inp_img) | ||
mixup_tar_img = np.array(mixup_tar_img) | ||
|
||
transformed = self.transform(image=mixup_inp_img, target=mixup_tar_img) | ||
|
||
alpha = 0.2 | ||
lam = np.random.beta(alpha, alpha) | ||
|
||
mixup_inp_img = F.to_tensor(transformed['image']) | ||
mixup_tar_img = F.to_tensor(transformed['target']) | ||
|
||
if mode == 'mixup': | ||
inp_img = lam * inp_img + (1 - lam) * mixup_inp_img | ||
tar_img = lam * tar_img + (1 - lam) * mixup_tar_img | ||
elif mode == 'cutmix': | ||
img_h, img_w = self.img_options['h'], self.img_options['w'] | ||
|
||
cx = np.random.uniform(0, img_w) | ||
cy = np.random.uniform(0, img_h) | ||
|
||
w = img_w * np.sqrt(1 - lam) | ||
h = img_h * np.sqrt(1 - lam) | ||
|
||
x0 = int(np.round(max(cx - w / 2, 0))) | ||
x1 = int(np.round(min(cx + w / 2, img_w))) | ||
y0 = int(np.round(max(cy - h / 2, 0))) | ||
y1 = int(np.round(min(cy + h / 2, img_h))) | ||
|
||
inp_img[:, y0:y1, x0:x1] = mixup_inp_img[:, y0:y1, x0:x1] | ||
tar_img[:, y0:y1, x0:x1] = mixup_tar_img[:, y0:y1, x0:x1] | ||
|
||
return inp_img, tar_img | ||
|
||
def __len__(self): | ||
return self.sizex | ||
|
||
def __getitem__(self, index): | ||
index_ = index % self.sizex | ||
|
||
inp_path = self.inp_filenames[index_] | ||
tar_path = self.tar_filenames[index_] | ||
|
||
inp_img = Image.open(inp_path).convert('RGB') | ||
tar_img = Image.open(tar_path).convert('RGB') | ||
|
||
inp_img = np.array(inp_img) | ||
tar_img = np.array(tar_img) | ||
|
||
transformed = self.transform(image=inp_img, target=tar_img) | ||
|
||
inp_img = F.to_tensor(transformed['image']) | ||
tar_img = F.to_tensor(transformed['target']) | ||
|
||
if index_ > 0 and index_ % 3 == 0: | ||
if random.random() > 0.5: | ||
inp_img, tar_img = self.mixup(inp_img, tar_img, mode='mixup') | ||
else: | ||
inp_img, tar_img = self.mixup(inp_img, tar_img, mode='cutmix') | ||
|
||
filename = os.path.basename(tar_path) | ||
|
||
return inp_img, tar_img, filename | ||
|
||
|
||
class DataLoaderVal(Dataset): | ||
def __init__(self, rgb_dir, inp='input', target='target', img_options=None): | ||
super(DataLoaderVal, self).__init__() | ||
|
||
inp_files = sorted(os.listdir(os.path.join(rgb_dir, inp))) | ||
tar_files = sorted(os.listdir(os.path.join(rgb_dir, target))) | ||
|
||
self.inp_filenames = [os.path.join(rgb_dir, inp, x) for x in inp_files if is_image_file(x)] | ||
self.tar_filenames = [os.path.join(rgb_dir, target, x) for x in tar_files if is_image_file(x)] | ||
|
||
self.img_options = img_options | ||
self.sizex = len(self.tar_filenames) # get the size of target | ||
|
||
self.transform = A.Compose([ | ||
A.Resize(height=img_options['h'], width=img_options['w']), ], | ||
additional_targets={ | ||
'target': 'image', | ||
} | ||
) | ||
|
||
def __len__(self): | ||
return self.sizex | ||
|
||
def __getitem__(self, index): | ||
index_ = index % self.sizex | ||
|
||
inp_path = self.inp_filenames[index_] | ||
tar_path = self.tar_filenames[index_] | ||
|
||
inp_img = Image.open(inp_path).convert('RGB') | ||
tar_img = Image.open(tar_path).convert('RGB') | ||
|
||
inp_img = np.array(inp_img) | ||
tar_img = np.array(tar_img) | ||
|
||
if not self.img_options['ori']: | ||
transformed = self.transform(image=inp_img, target=tar_img) | ||
|
||
inp_img = transformed['image'] | ||
tar_img = transformed['target'] | ||
|
||
inp_img = F.to_tensor(inp_img) | ||
tar_img = F.to_tensor(tar_img) | ||
|
||
filename = os.path.basename(tar_path) | ||
|
||
return inp_img, tar_img, filename |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .model import Model |
Oops, something went wrong.