forked from lingtengqiu/Deeperlab-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
122 lines (97 loc) · 3.04 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# encoding: utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path as osp
import sys
import time
import math
import numpy as np
from easydict import EasyDict as edict
import argparse
import torch.utils.model_zoo as model_zoo
# One shared EasyDict holds every setting; `config` and `cfg` are simply
# additional names bound to that same object.
C = config = cfg = edict()
C.seed = 12345
# Please configure the root directory (and user-specific paths) the first
# time you use this repository.
#
# Every directory name below is derived from ONE canonical, symlink-resolved
# path. Previously repo_name came from abspath() while abs_dir came from
# realpath(); when the working directory was reached through a symlink the
# two disagreed and the index() lookup raised ValueError.
C.abs_dir = osp.realpath(".")
C.repo_name = osp.basename(C.abs_dir)
C.this_dir = C.abs_dir.split(osp.sep)[-1]
# rindex (not index): the repo name can also occur earlier in the path, e.g.
# /data/exp/data or as a substring of a parent directory name. We want the
# occurrence that ends abs_dir, i.e. the directory we are actually in.
C.root_dir = C.abs_dir[:C.abs_dir.rindex(C.repo_name) + len(C.repo_name)]
print(C.abs_dir)
# ---- Log / snapshot locations ----
# All experiment output lives under <root_dir>/log/<this_dir>/.
C.log_dir = osp.abspath(osp.join(C.root_dir, 'log', C.this_dir))
C.log_dir_link = osp.join(C.abs_dir, 'log')
C.snapshot_dir = osp.abspath(osp.join(C.log_dir, "snapshot"))
# A single timestamp, taken once at import time, is shared by the training
# and validation log names.
exp_time = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())
C.log_file = C.log_dir + '/log_' + exp_time + '.log'
# The "latest training log" path belongs in log_dir next to the timestamped
# logs, exactly like link_val_log_file below. It used to be built from
# C.log_file, which produced '<...>.log/log_last.log': a path *inside* a
# regular file, which can never be created.
C.link_log_file = C.log_dir + '/log_last.log'
C.val_log_file = C.log_dir + '/val_' + exp_time + '.log'
C.link_val_log_file = C.log_dir + '/val_last.log'
"""Data Dir and Weight Dir"""
# Dataset inputs, resolved to absolute paths against the current working
# directory at import time. The plain module-level names (img_root, gt_root,
# train_root, eval_root) are kept alongside the C.* entries.
C.img_root_folder = img_root = osp.abspath("./data/JPEGImages/")
C.gt_root_folder = gt_root = osp.abspath("./data/SegmentationLabel/")
C.train_source = train_root = osp.abspath(
    "./data/ImageSets/Segmentation/train.txt")
C.eval_source = eval_root = osp.abspath(
    "./data/ImageSets/Segmentation/val.txt")
# No test-split list is configured here.
C.is_test = False
"""Path Config"""
def add_path(path):
    """Prepend *path* to ``sys.path`` unless it is already listed there.

    Inserting at the front (rather than appending) gives modules under
    *path* priority over same-named modules found later on the search path.
    Calling this repeatedly with the same *path* is harmless.
    """
    if path in sys.path:
        return
    sys.path.insert(0, path)
# Make the project's `furnace` directory importable by prepending it to
# sys.path. The `utils.pyt_utils` import below presumably resolves from
# furnace/, which is why it must stay after this call. TODO confirm.
add_path(osp.join(C.root_dir, 'furnace'))
# `model_urls` is not used within this section. Presumably it is imported so
# callers can reach it through the config module; verify before removing.
from utils.pyt_utils import model_urls
"""Image Config"""
# Label space: 21 classes. C.background presumably names the label id of the
# background class.
C.num_classes = 21
C.background = 0
# Per-channel normalisation statistics. These are the values commonly used
# with ImageNet-pretrained backbones.
C.image_mean = np.array([0.485, 0.456, 0.406])
C.image_std = np.array([0.229, 0.224, 0.225])
# Every spatial size in this section is 512 px.
C.target_size = C.image_height = C.image_width = 512
# Number of training and evaluation images.
C.num_train_imgs, C.num_eval_imgs = 1464, 1449
""" Settings for network, this would be different for each kind of model"""
# Bias / batch-norm handling for the network.
C.fix_bias, C.fix_bn, C.sync_bn = True, False, True
C.bn_eps, C.bn_momentum = 1e-5, 0.1
# Pretrained weights to initialise from (an Xception-71 checkpoint, judging
# by the file name).
C.pretrained_model = "./pretrain/xception-71.pth"
# Presumably the weight of the auxiliary loss term; confirm in the trainer.
C.aux_loss_alpha = 0.1
"""Train Config"""
C.lr = 1e-3
C.lr_power = 0.9  # exponent of the LR decay schedule (presumably poly)
C.momentum = 0.9
C.weight_decay = 1e-5
# Per-iteration batch size. An older note suggested "4 * num_gpu", but no
# num_gpu setting exists in this config.
C.batch_size = 4
C.nepochs = 120
# Iterations needed to cover the training set once, rounded UP. True
# division is required: with floor division (`//`) the math.ceil() was a
# no-op and the count was silently rounded down whenever num_train_imgs was
# not a multiple of batch_size.
C.niters_per_epoch = int(math.ceil(C.num_train_imgs * 1.0 / C.batch_size))
C.num_workers = 1
# Scale factors presumably sampled for random-rescale augmentation.
C.train_scale_array = [0.5, 0.75, 1, 1.5, 1.75, 2.0]
"""Eval Config"""
# Evaluation: a single scale (1x), flipping disabled, 512 px base and crop.
C.eval_iter = 30
C.eval_stride_rate = 2 / 3
C.eval_scale_array = [1]
C.eval_flip = False
C.eval_base_size = C.eval_crop_size = 512
"""Display Config"""
# Snapshot / record / display intervals (presumably in iterations) and the
# TensorBoardX switch.
C.snapshot_iter, C.record_info_iter, C.display_iter = 10, 20, 50
C.tensorboardX = True
def open_tensorboard():
    """Hook invoked by the ``--tensorboard`` command-line flag.

    Currently a placeholder: it performs no action and returns ``None``.
    """
    return None
if __name__ == '__main__':
    # Script entry: `python config.py -tb` calls open_tensorboard().
    cli = argparse.ArgumentParser()
    cli.add_argument('-tb', '--tensorboard', default=False,
                     action='store_true')
    if cli.parse_args().tensorboard:
        open_tensorboard()