-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_cr_mel.py
63 lines (56 loc) · 2.4 KB
/
main_cr_mel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import tensorflow as tf
from cr_model_v2 import cr_cfg_process
from cr_model_v2 import cr_model
from cr_model_v2 import cr_model_impl_mel
from cr_model_v2 import cr_model_run
from cr_model_v2 import data_set
from cr_model_v2 import load_data
from utils import cfg_process
from utils import parser_util
def add_arguments(parser):
    """Register this script's command-line options on ``parser``.

    Every option is a string option:

    * ``--config_file``: path of the hparams config file
      (default ``./cr_model_v2/cfgs/mel_ma_nodropout.yml``).
    * ``--config_name``: config name handed to ``YParams`` together with the
      file (default ``mel_ma_nodropout``).
    * ``--gpu``: per its help text, a value for ``CUDA_VISIBLE_DEVICES``
      (default ``''``). Nothing in this function applies it; presumably it is
      consumed downstream from the parsed flags. Verify against the caller.

    Args:
        parser: an argparse-style parser exposing ``add_argument``.
    """
    # (flag, default, help) triples, registered in this exact order.
    option_specs = (
        ('--config_file', './cr_model_v2/cfgs/mel_ma_nodropout.yml',
         'config file about hparams'),
        ('--config_name', 'mel_ma_nodropout',
         'config name for hparams'),
        ('--gpu', '',
         'config for CUDA_VISIBLE_DEVICES'),
    )
    for flag, default_value, help_text in option_specs:
        parser.add_argument(flag, type=str, default=default_value,
                            help=help_text)
def _build_model_registry():
    """Return the mapping from ``model_key`` strings to model classes.

    Built on each call, so the class attributes of ``cr_model`` and
    ``cr_model_impl_mel`` are only looked up when a run actually starts.
    """
    return {
        'CRModel1': cr_model.CRModel1,
        'CRModel2': cr_model.CRModel2,
        'CRModel3': cr_model.CRModel3,
        'MelModel1': cr_model_impl_mel.MelModel1,
        'MelModel2': cr_model_impl_mel.MelModel2,
        'MelModel3': cr_model_impl_mel.MelModel3,
        'MelModel4': cr_model_impl_mel.MelModel4,
        'MelModel5': cr_model_impl_mel.MelModel5,
        'MelModel6': cr_model_impl_mel.MelModel6,
        'MelModel7': cr_model_impl_mel.MelModel7,
        'MelModel8': cr_model_impl_mel.MelModel8,
        'MelModel9': cr_model_impl_mel.MelModel9,
        'MelModel10': cr_model_impl_mel.MelModel10,
        'MelModel11': cr_model_impl_mel.MelModel11,
        'MelModel12': cr_model_impl_mel.MelModel12,
        'Hid2DMelModel': cr_model_impl_mel.Hid2DMelModel,
        'Hid3DMelModel': cr_model_impl_mel.Hid3DMelModel,
    }


def main(unused_argv):
    """Parse options, prepare hparams, then build and run the selected model.

    Steps:
      1. Register the options from ``add_arguments`` and parse them with
         ``parser_util.MyArgumentParser.parse_to_dict``.
      2. Load hparams via ``cfg_process.YParams(config_file, config_name)``
         and pass them through ``cr_cfg_process.CRHpsPreprocessor``.
      3. Print ``yparams.id_str``.
      4. Resolve ``yparams.model_key`` to a model class. An unknown key fails
         here, before ``yparams.save()`` is called.
      5. Call ``yparams.save()`` and instantiate the model with ``yparams``.
      6. Load data. ``load_data.load_data_mix`` is used when the hparams
         contain ``is_rediv_data`` set to ``True``; otherwise
         ``load_data.load_data``. Wrap the data in ``data_set.DataSet`` and
         pass it to ``cr_model_run.CRModelRun(model).run``.

    Args:
      unused_argv: positional argument supplied by the ``tf.app.run`` entry
        point. It is not used; options are parsed by ``MyArgumentParser``.

    Raises:
      KeyError: if ``yparams.model_key`` does not name a known model class.
    """
    parser = parser_util.MyArgumentParser()
    add_arguments(parser)
    args, flags_dict = parser.parse_to_dict()
    yparams = cfg_process.YParams(args.config_file, args.config_name)
    yparams = cr_cfg_process.CRHpsPreprocessor(yparams, flags_dict).preprocess()
    print('id str:', yparams.id_str)

    # Resolve the model class before yparams.save(). A mistyped model_key then
    # fails fast, with a message listing the valid keys, instead of a bare
    # KeyError raised after save() has already run.
    model_registry = _build_model_registry()
    model_key = yparams.model_key
    if model_key not in model_registry:
        raise KeyError(
            'unknown model_key {!r}; expected one of: {}'.format(
                model_key, ', '.join(sorted(model_registry))))
    model_class = model_registry[model_key]

    yparams.save()
    model = model_class(yparams)

    # Deliberately `is True` rather than truthiness: only an explicit boolean
    # True selects load_data_mix.
    if 'is_rediv_data' in yparams and yparams.is_rediv_data is True:
        l_data = load_data.load_data_mix(yparams)
    else:
        l_data = load_data.load_data(yparams)
    d_set = data_set.DataSet(l_data, yparams)
    cr_model_run_v2 = cr_model_run.CRModelRun(model)
    cr_model_run_v2.run(d_set)
if __name__ == '__main__':
    # Script entry point: hand control to tf.app.run, which invokes main().
    # main ignores the argv it receives (its parameter is `unused_argv`).
    tf.app.run(main=main)