-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
94 lines (85 loc) · 4.25 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# -*- coding: utf-8 -*-
import os

from easydict import EasyDict as edict
# Root of the configuration tree; `cfg` is an alias of the same object.
__C = edict()
cfg = __C
#
# Dataset Config
#
# Machine-specific dataset root, shared by every dataset path below.
# Set the CFG_DATASET_ROOT environment variable to relocate the datasets
# without editing this file. When it is unset or empty, the original default
# is used, so the resulting paths are unchanged. A trailing '/' is stripped
# so the joins below never produce a double slash.
_DATASET_ROOT = (os.environ.get('CFG_DATASET_ROOT') or '/home/aistudio/dataset').rstrip('/')
__C.DATASETS = edict()
__C.DATASETS.SHAPENET = edict()
__C.DATASETS.SHAPENET.TAXONOMY_FILE_PATH = './datasets/ShapeNet.json'
# The %-style placeholders are left unfilled here; they are presumably
# substituted by the data loader (confirm against the caller).
__C.DATASETS.SHAPENET.RENDERING_PATH = _DATASET_ROOT + '/ShapeNet/ShapeNetRendering/%s/%s/rendering/%02d.png'
__C.DATASETS.SHAPENET.VOXEL_PATH = _DATASET_ROOT + '/ShapeNet/ShapeNetVox32/%s/%s/model.binvox'
# Disabled alternative datasets (see the commented TEST_DATASET options).
# They use the same root, so uncommenting them needs no path edits.
# __C.DATASETS.PASCAL3D = edict()
# __C.DATASETS.PASCAL3D.TAXONOMY_FILE_PATH = './datasets/Pascal3D.json'
# __C.DATASETS.PASCAL3D.ANNOTATION_PATH = _DATASET_ROOT + '/PASCAL3D/Annotations/%s_imagenet/%s.mat'
# __C.DATASETS.PASCAL3D.RENDERING_PATH = _DATASET_ROOT + '/PASCAL3D/Images/%s_imagenet/%s.JPEG'
# __C.DATASETS.PASCAL3D.VOXEL_PATH = _DATASET_ROOT + '/PASCAL3D/CAD/%s/%02d.binvox'
# __C.DATASETS.PIX3D = edict()
# __C.DATASETS.PIX3D.TAXONOMY_FILE_PATH = './datasets/Pix3D.json'
# __C.DATASETS.PIX3D.ANNOTATION_PATH = _DATASET_ROOT + '/Pix3D/pix3d.json'
# __C.DATASETS.PIX3D.RENDERING_PATH = _DATASET_ROOT + '/Pix3D/img/%s/%s.%s'
# __C.DATASETS.PIX3D.VOXEL_PATH = _DATASET_ROOT + '/Pix3D/model/%s/%s/%s.binvox'
#
# Dataset
#
# Train/test dataset selection plus the MEAN/STD values. Each has three
# entries, presumably one per image channel; confirm in the data loader.
__C.DATASET = edict(
    MEAN=[0.5, 0.5, 0.5],
    STD=[0.5, 0.5, 0.5],
    TRAIN_DATASET='ShapeNet',
    TEST_DATASET='ShapeNet',  # other options: 'Pascal3D', 'Pix3D'
)
#
# Common
#
__C.CONST = edict(
    DEVICE='0',
    RNG_SEED=0,
    IMG_W=127,             # width of the input image
    IMG_H=127,             # height of the input image
    N_VOX=32,
    BATCH_SIZE=30,
    N_VIEWS_RENDERING=5,   # how many input views are used
    CROP_IMG_W=96,         # width of the input crop
    CROP_IMG_H=96,         # height of the input crop
    INFO_BATCH=100,        # print once every 100 batches
)
#
# Directories
#
__C.DIR = edict(
    OUT_PATH='./output',
    # RANDOM_BG_PATH='/Datasets/SUN2012/JPEGImages',  # disabled option
)
#
# Network
#
__C.NETWORK = edict(
    LEAKY_VALUE=0.2,
    TCONV_USE_BIAS=False,
    USE_MERGER=True,
)
#
# Training
#
__C.TRAIN = edict(
    RESUME_TRAIN=False,
    NUM_WORKER=4,                    # how many data workers to use
    NUM_EPOCHES=60,
    BRIGHTNESS=0.4,
    CONTRAST=0.4,
    SATURATION=0.4,
    NOISE_STD=0.1,
    # Three [low, high] pairs, presumably one per colour channel; confirm.
    RANDOM_BG_COLOR_RANGE=[[225, 255], [225, 255], [225, 255]],
    POLICY='adam',                   # supported values: sgd, adam
    RES_GRU_NET_LEARNING_RATE=1e-4,
    RES_GRU_NET_LR_MILESTONES=[45],
    BETAS=(0.9, 0.999),
    MOMENTUM=0.9,
    GAMMA=0.5,
    SAVE_FREQ=10,                    # saved weights are overwritten every SAVE_FREQ epochs
    UPDATE_N_VIEWS_RENDERING=False,
)
#
# Testing options
#
__C.TEST = edict(
    # A fixed value of 240 in each of the three channels (low == high).
    RANDOM_BG_COLOR_RANGE=[[240, 240], [240, 240], [240, 240]],
    VOXEL_THRESH=[0.2, 0.3, 0.4, 0.5],
)