-
Notifications
You must be signed in to change notification settings - Fork 29
/
options.py
311 lines (276 loc) · 16.9 KB
/
options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
import argparse
import os
import time
import numpy as np
from tensorboardX import SummaryWriter
from . import util,dsp,plot
import sys
sys.path.append("..")
from data import statistics,augmenter,dataloader
class Options():
    """Builds, parses and post-processes all command-line options.

    Usage: ``opt = Options().getparse()``.  Many options default to the
    string ``'auto'`` and are resolved later (partly in ``getparse``,
    partly in ``get_auto_options`` once the dataset is loaded).
    """

    def __init__(self):
        # Show default values in --help output.
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.initialized = False  # guards against double add_argument registration

    def initialize(self):
        """Register every command-line argument on the parser (idempotence is
        ensured by the caller via ``self.initialized``)."""
        # ------------------------Base------------------------
        self.parser.add_argument('--gpu_id', type=str, default='0',help='choose which gpu want to use, Single GPU: 0 | 1 | 2 ; Multi-GPU: 0,1,2,3 ; No GPU: -1')
        self.parser.add_argument('--no_cudnn', action='store_true', help='if specified, do not use cudnn')
        # NOTE: several options are typed str with default 'auto' on purpose;
        # getparse()/get_auto_options() convert them to int/list later.
        self.parser.add_argument('--label', type=str, default='auto',help='number of labels')
        self.parser.add_argument('--input_nc', type=str, default='auto', help='number of input channels')
        self.parser.add_argument('--loadsize', type=str, default='auto', help='load data in this size')
        self.parser.add_argument('--finesize', type=str, default='auto', help='crop your data into this size')
        self.parser.add_argument('--label_name', type=str, default='auto',help='name of labels,example:"a,b,c,d,e,f"')
        self.parser.add_argument('--mode', type=str, default='auto',help='classify_1d | classify_2d | autoencoder | domain')
        self.parser.add_argument('--domain_num', type=str, default='2',
            help='number of domain, only available when mode==domain. 2 | auto ,if input 2, train-data is domain 0,test-data is domain 1.')
        self.parser.add_argument('--dataset_dir', type=str, default='./datasets/simple_test',help='your dataset path')
        self.parser.add_argument('--save_dir', type=str, default='./checkpoints/',help='save checkpoints')
        self.parser.add_argument('--tensorboard', type=str, default='./checkpoints/tensorboardX',help='tensorboardX log dir')
        # Placeholder; getparse() replaces this with a SummaryWriter instance.
        self.parser.add_argument('--TBGlobalWriter', type=str, default='',help='')
        # ------------------------Training Matters------------------------
        self.parser.add_argument('--epochs', type=int, default=20,help='end epoch')
        self.parser.add_argument('--lr', type=float, default=0.001,help='learning rate')
        self.parser.add_argument('--batchsize', type=int, default=64,help='batchsize')
        self.parser.add_argument('--load_thread', type=int, default=8,help='how many threads when load data')
        self.parser.add_argument('--best_index', type=str, default='f1',help='select which evaluation index to get the best results in all epochs, f1 | err')
        self.parser.add_argument('--pretrained', type=str, default='',help='pretrained model path. If not specified, fo not use pretrained model')
        self.parser.add_argument('--continue_train', action='store_true', help='if specified, continue train')
        self.parser.add_argument('--weight_level', type=int, default=0,help='change the weight of the loss function to an imbalanced dataset, loss_weight = (1/(class.weight))^opt.weight_level')
        self.parser.add_argument('--network_save_freq', type=int, default=5,help='the freq to save network')
        # ------------------------Preprocessing------------------------
        self.parser.add_argument('--normliaze', type=str, default='None', help='mode of normliaze, z-score | 5_95 | maxmin | None')
        # filter
        self.parser.add_argument('--filter', type=str, default='None', help='type of filter, fft | fir | iir |None')
        self.parser.add_argument('--filter_mod', type=str, default='bandpass', help='mode of fft_filter, bandpass | bandstop')
        self.parser.add_argument('--filter_fs', type=int, default=1000, help='fs of filter')
        self.parser.add_argument('--filter_fc', type=str, default='[]', help='fc of filter, eg. [0.1,10]')
        # filter by wavelet
        self.parser.add_argument('--wave', type=str, default='None', help='wavelet name string, wavelet(eg. dbN symN haar gaus mexh) | None')
        self.parser.add_argument('--wave_level', type=int, default=5, help='decomposition level')
        self.parser.add_argument('--wave_usedcoeffs', type=str, default='[]', help='Coeff used for reconstruction, \
            eg. when level = 6 usedcoeffs=[1,1,0,0,0,0,0] : reconstruct signal with cA6, cD6')
        # ------------------------Data Augmentation------------------------
        # base
        self.parser.add_argument('--augment', type=str, default='None',
            help='all | scale,warp,app,aaft,iaaft,filp,spike,step,slope,white,pink,blue,brown,violet ,enter some of them')
        self.parser.add_argument('--augment_noise_lambda', type=float, default = 1.0, help='noise level(spike,step,slope,white,pink,blue,brown,violet)')
        # for gan,it only support when fold_index = 1 or 0
        self.parser.add_argument('--gan', action='store_true', help='if specified, using gan to augmente dataset')
        self.parser.add_argument('--gan_lr', type=float, default=0.0002,help='learning rate')
        self.parser.add_argument('--gan_augment_times', type=float, default=2,help='how many times that will be augmented by dcgan')
        self.parser.add_argument('--gan_latent_dim', type=int, default=100,help='dimensionality of the latent space')
        self.parser.add_argument('--gan_labels', type=str, default='[]',help='which label that will be augmented by dcgan, eg: [0,1,2,3]')
        self.parser.add_argument('--gan_epochs', type=int, default=100,help='number of epochs of gan training')
        # ------------------------Dataset------------------------
        """--fold_index
        When --k_fold != 0 or 1:
        Cut dataset into sub-set using index , and then run k-fold with sub-set
        If input 'auto', it will shuffle dataset and then cut dataset to sub-dataset equally
        If input 'load', load indexs.npy as fold_index
        If input: [2,4,6,7]
        when len(dataset) == 10
        sub-set: dataset[0:2],dataset[2:4],dataset[4:6],dataset[6:7],dataset[7:]
        -------
        When --k_fold == 0 or 1:
        If input 'auto', it will shuffle dataset and then cut 80% dataset to train and other to eval
        If input 'load', load indexs.npy as fold_index
        If input: [5]
        when len(dataset) == 10
        train-set : dataset[0:5] eval-set : dataset[5:]
        """
        self.parser.add_argument('--k_fold', type=int, default=0,help='fold_num of k-fold.If 0 or 1, no k-fold and cut 80% to train and other to eval')
        self.parser.add_argument('--fold_index', type=str, default='auto',
            help='auto | load | "input_your_index"-where to fold, eg. when 5-fold and input: [2,4,6,7] -> sub-set: dataset[0:2],dataset[2:4],dataset[4:6],dataset[6:7],dataset[7:]')
        self.parser.add_argument('--mergelabel', type=str, default='None',
            help='merge some labels to one label and give the result, example:"[[0,1,4],[2,3,5]]" -> label(0,1,4) regard as 0,label(2,3,5) regard as 1')
        self.parser.add_argument('--mergelabel_name', type=str, default='None',help='name of labels,example:"a,b,c,d,e,f"')
        # ------------------------Network------------------------
        """Available Network
        1d: lstm, cnn_1d, resnet18_1d, resnet34_1d, multi_scale_resnet_1d,
        micro_multi_scale_resnet_1d,autoencoder,mlp
        2d: mobilenet, dfcnn, multi_scale_resnet, resnet18, resnet50, resnet101,
        densenet121, densenet201, squeezenet
        """
        self.parser.add_argument('--model_name', type=str, default='micro_multi_scale_resnet_1d',help='Choose model lstm...')
        self.parser.add_argument('--lstm_inputsize', type=str, default='auto',help='lstm_inputsize of LSTM')
        self.parser.add_argument('--lstm_timestep', type=int, default=100,help='time_step of LSTM')
        # For autoecoder
        self.parser.add_argument('--feature', type=int, default=3, help='number of encoder features')
        # For 2d network(stft spectrum)
        # Please cheek ./save_dir/spectrum_eg.jpg to change the following parameters
        self.parser.add_argument('--spectrum', type=str, default='stft', help='stft | cwt')
        self.parser.add_argument('--spectrum_n_downsample', type=int, default=1, help='downsample befor convert to spectrum')
        self.parser.add_argument('--cwt_wavename', type=str, default='cgau8', help='')
        self.parser.add_argument('--cwt_scale_num', type=int, default=64, help='')
        self.parser.add_argument('--stft_size', type=int, default=512, help='length of each fft segment')
        self.parser.add_argument('--stft_stride', type=int, default=128, help='stride of each fft segment')
        self.parser.add_argument('--stft_no_log', action='store_true', help='if specified, do not log1p spectrum')
        self.parser.add_argument('--img_shape', type=str, default='auto', help='output shape of stft. It depend on \
            stft_size,stft_stride,stft_n_downsample. Do not input this parameter.')
        self.initialized = True

    def getparse(self):
        """Parse argv, log the options, start tensorboard, and resolve the
        string-typed options into their final Python types.

        Side effects: creates ``save_dir``, writes opt.txt via ``util.writelog``,
        creates a tensorboardX SummaryWriter, sets CUDA_VISIBLE_DEVICES, and may
        call ``sys.exit`` on invalid input.

        Returns the parsed/converted ``argparse.Namespace``.
        """
        if not self.initialized:
            self.initialize()
        self.opt = self.parser.parse_args()
        # NOTE: bare string below is a no-op statement kept from the original;
        # it documents the print/save behavior of this section.
        """Print and save options
        It will print both current options and default values(if different).
        It will save options into a text file / [checkpoints_dir] / opt.txt
        """
        input_arg = ''
        opt_message = ''
        opt_message += '----------------- Options ---------------\n'
        # Build a human-readable dump; flag values that differ from defaults.
        for k, v in sorted(vars(self.opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            input_arg += ('--'+str(k)+' '+str(v)+' ')
            opt_message += '{:>20}: {:<30}{}\n'.format(str(k), str(v), comment)
        opt_message += '----------------- End -------------------'
        localtime = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
        util.makedirs(self.opt.save_dir)
        util.writelog(str(localtime)+'\n'+opt_message+'\n', self.opt,True)
        # start tensorboard: log dir is timestamped so runs can be filtered.
        self.opt.tensorboard = os.path.join(self.opt.tensorboard,localtime+'_'+os.path.split(self.opt.save_dir)[1])
        self.opt.TBGlobalWriter = SummaryWriter(self.opt.tensorboard)
        util.writelog('Please run "tensorboard --logdir checkpoints/tensorboardX --host=your_server_ip" and input "'+localtime+'" to filter outputs',self.opt,True)
        self.opt.TBGlobalWriter.add_text('Opt', opt_message.replace('\n', ' \n'))
        self.opt.TBGlobalWriter.add_text('Opt', ('----------------- Input args ---------------\n'+input_arg).replace('\n', ' \n'))
        # auto options (base): convert 'auto'-typed strings to ints when given.
        if self.opt.gpu_id != '-1':
            os.environ["CUDA_VISIBLE_DEVICES"] = str(self.opt.gpu_id)
        if self.opt.label != 'auto':
            self.opt.label = int(self.opt.label)
        if self.opt.input_nc !='auto':
            self.opt.input_nc = int(self.opt.input_nc)
        if self.opt.loadsize !='auto':
            self.opt.loadsize = int(self.opt.loadsize)
        if self.opt.finesize !='auto':
            self.opt.finesize = int(self.opt.finesize)
        if self.opt.lstm_inputsize != 'auto':
            self.opt.lstm_inputsize = int(self.opt.lstm_inputsize)
        # Infer the training mode from the chosen network when mode == 'auto'.
        if self.opt.mode == 'auto':
            if self.opt.model_name in ['lstm', 'cnn_1d', 'resnet18_1d', 'resnet34_1d',
                'multi_scale_resnet_1d','micro_multi_scale_resnet_1d','mlp']:
                self.opt.mode = 'classify_1d'
            elif self.opt.model_name in ['light','dfcnn', 'multi_scale_resnet', 'resnet18', 'resnet50',
                'resnet101','densenet121', 'densenet201', 'squeezenet', 'mobilenet','EarID','MV_Emotion']:
                self.opt.mode = 'classify_2d'
            elif self.opt.model_name == 'autoencoder':
                self.opt.mode = 'autoencoder'
            elif self.opt.model_name in ['dann','dann_base']:
                self.opt.mode = 'domain'
            elif self.opt.model_name in ['dann_lstm']:
                self.opt.mode = 'domain_1d'
            else:
                print('\033[1;31m'+'Error: do not support this network '+self.opt.model_name+'\033[0m')
                sys.exit(0)
        # k_fold == 0 means "no k-fold", normalized to 1 for downstream loops.
        if self.opt.k_fold == 0 :
            self.opt.k_fold = 1
        if self.opt.fold_index == 'auto':
            if os.path.isfile(os.path.join(self.opt.dataset_dir,'index.npy')):
                print('Warning: index.npy exists but does not load it')
        elif self.opt.fold_index == 'load':
            if os.path.isfile(os.path.join(self.opt.dataset_dir,'index.npy')):
                self.opt.fold_index = (np.load(os.path.join(self.opt.dataset_dir,'index.npy'))).tolist()
            else:
                print('Warning: index.npy does not exist')
                sys.exit(0)
        else:
            # User supplied an explicit index list, e.g. "[2,4,6,7]".
            self.opt.fold_index = str2list(self.opt.fold_index,int)
        if self.opt.augment == 'all':
            self.opt.augment = ['scale','warp','spike','step','slope','white','pink','blue','brown','violet','app','aaft','iaaft','filp']
        else:
            self.opt.augment = str2list(self.opt.augment)
        # Convert the remaining list-valued string options.
        self.opt.filter_fc = str2list(self.opt.filter_fc,float)
        self.opt.wave_usedcoeffs = str2list(self.opt.wave_usedcoeffs,int)
        self.opt.gan_labels = str2list(self.opt.gan_labels,int)
        if self.opt.mergelabel != 'None':
            self.opt.mergelabel = str2list(self.opt.mergelabel,int,2)
        if self.opt.mergelabel_name != 'None':
            self.opt.mergelabel_name = str2list(self.opt.mergelabel_name)
        return self.opt
def str2list(string, out_type=str, depth=1):
    """Parse a bracketed, comma-separated option string into a list.

    Args:
        string: text such as ``'[0.1,10]'``, ``'a,b,c'`` or, with
            ``depth=2``, a nested form such as ``'[[0,1,4],[2,3,5]]'``.
        out_type: callable applied to every token (``str``, ``int``,
            ``float``, ...).
        depth: 1 for a flat list, 2 for a list of lists.

    Returns:
        A (possibly nested) list of converted tokens; ``[]`` for any
        other ``depth`` or when no tokens are present.
    """
    string = string.replace(' ', '')
    if depth == 1:
        # Strip all brackets and split on commas; ignore empty tokens
        # (so '[]' and '' both yield []).
        body = string.replace('[', '').replace(']', '')
        return [out_type(tok) for tok in body.split(',') if tok != '']
    elif depth == 2:
        # Fixed: the original appended int(c) one CHARACTER at a time,
        # ignoring out_type and silently corrupting multi-digit (or
        # negative) values. Buffer full tokens between delimiters instead.
        out_list = []
        inner = string[1:-1]  # drop the outermost '[' and ']'
        current = []          # tokens of the sub-list being built
        token = ''
        for c in inner:
            if c == '[':
                current = []
                token = ''
            elif c == ']':
                if token != '':
                    current.append(out_type(token))
                    token = ''
                out_list.append(current)
            elif c == ',':
                if token != '':
                    current.append(out_type(token))
                    token = ''
            else:
                token += c
        return out_list
    return []
def get_auto_options(opt, signals, labels):
    """Resolve every option still set to 'auto' using the loaded dataset.

    Args:
        opt: parsed option namespace from ``Options.getparse()``; mutated
            in place.
        signals: signal array — assumes shape (N, channels, length) for the
            3-dim case, or 4-dim image-like data (TODO confirm exact layout
            of the 4-dim case against the caller).
        labels: integer label array, one entry per sample.

    Returns:
        The same ``opt`` namespace with label/input_nc/loadsize/finesize/
        lstm_inputsize, class weights, label names, domain_num and img_shape
        filled in.

    Side effects: writes logs via ``util.writelog``, draws example plots,
    writes an example image to tensorboard, and may ``sys.exit`` when
    domains.npy is missing.
    """
    shape = signals.shape
    label_cnt, label_cnt_per, label_num = statistics.label_statistics(labels)
    if opt.label == 'auto':
        opt.label = label_num
    if opt.input_nc == 'auto':
        opt.input_nc = shape[1]
    if opt.loadsize == 'auto':
        opt.loadsize = shape[2]
    if opt.finesize == 'auto':
        # Crop window defaults to 90% of the raw signal length.
        opt.finesize = int(shape[2]*0.9)
    if opt.lstm_inputsize == 'auto':
        opt.lstm_inputsize = opt.finesize//opt.lstm_timestep
    # Class weights for the loss: rarer classes get larger weight, scaled by
    # opt.weight_level. (Fixed: removed the dead `opt.weight = np.ones(...)`
    # assignment that was immediately overwritten.)
    opt.weight = 1/label_cnt_per
    opt.weight = np.power((opt.weight/np.min(opt.weight)), opt.weight_level)
    util.writelog('Loss_weight:'+str(opt.weight), opt, True, True)
    import torch
    opt.weight = torch.from_numpy(opt.weight).float()
    if opt.gpu_id != '-1':
        opt.weight = opt.weight.cuda()
    # Label names: default to '0','1',... ; otherwise split a "a,b,c" string.
    if opt.label_name == 'auto':
        opt.label_name = [str(i) for i in range(opt.label)]
    elif not isinstance(opt.label_name, list):
        opt.label_name = opt.label_name.replace(" ", "").split(",")
    # domain_num: either the fixed train/test split (2) or counted from
    # domains.npy next to the dataset.
    if opt.mode in ['domain', 'domain_1d']:
        if opt.domain_num == '2':
            opt.domain_num = 2
        else:
            if os.path.isfile(os.path.join(opt.dataset_dir, 'domains.npy')):
                domains = np.load(os.path.join(opt.dataset_dir, 'domains.npy'))
                domains = dataloader.rebuild_domain(domains)
                opt.domain_num = statistics.label_statistics(domains)[2]
            else:
                print('Please generate domains.npy(np.int64, shape like labels.npy)')
                sys.exit(0)
    # Check the stft/cwt spectrum on one random example signal so the user can
    # tune stft_size/stft_stride before training.
    if opt.mode in ['classify_2d', 'domain'] and signals.ndim == 3:
        spectrums = []
        # Fixed off-by-one: np.random.randint's upper bound is exclusive, so
        # the original (0, shape[0]-1) could never select the last sample.
        data = signals[np.random.randint(0, shape[0])].reshape(shape[1], shape[2])
        data = augmenter.ch1d(opt, data, test_flag=False)
        plot.draw_eg_signals(data, opt)
        for i in range(shape[1]):
            spectrums.append(dsp.signal2spectrum(data[i], opt.stft_size, opt.stft_stride,
                opt.cwt_wavename, opt.cwt_scale_num, opt.spectrum_n_downsample, not opt.stft_no_log, mod=opt.spectrum))
        plot.draw_eg_spectrums(spectrums, opt)
        opt.img_shape = spectrums[0].shape
        h, w = opt.img_shape
        print('Shape of stft spectrum h,w:', opt.img_shape)
        # (Fixed typo in the hint message: "cheek" -> "check".)
        print('\033[1;37m'+'Please check tensorboard->IMAGES->spectrum_eg to change parameters'+'\033[0m')
        if h < 64 or w < 64:
            print('\033[1;33m'+'Warning: spectrum is too small'+'\033[0m')
        if h > 512 or w > 512:
            print('\033[1;33m'+'Warning: spectrum is too large'+'\033[0m')
    # 4-dim input is already image-like: record its spatial shape and log one
    # random example image to tensorboard.
    if signals.ndim == 4:
        opt.img_shape = signals.shape[2], signals.shape[3]
        img = signals[np.random.randint(0, shape[0])]
        opt.TBGlobalWriter.add_image('img_eg', img)
    return opt