# search_params.py — hyperparameter grid search driver for DeepAR (113 lines, 4.05 KB).
import os
import sys
import logging
import argparse
import multiprocessing
from copy import copy
from itertools import product
from subprocess import check_call
from functools import partial
import numpy as np
import utils
# Module logger; utils.set_logger attaches file/console handlers for it.
logger = logging.getLogger('DeepAR.Searcher')
utils.set_logger('param_search.log')
# Interpreter used to spawn train.py children (same environment as this script).
PYTHON = sys.executable
# Module-level globals: populated in main(), read by the pool workers.
gpu_ids: list
param_template: utils.Params
args: argparse.Namespace
search_params: dict
# CLI definition; parsed in main().
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='infant', help='Dataset name')
parser.add_argument('--data-dir', default='data', help='Directory containing the dataset')
parser.add_argument('--model-name', default='param_search', help='Parent directory for all jobs')
parser.add_argument('--relative-metrics', action='store_true', help='Whether to normalize the metrics by label scales')
parser.add_argument('--gpu-ids', nargs='+', default=[0,1,2], type=int, help='GPU ids')
parser.add_argument('--sampling', action='store_true', help='Whether to do ancestral sampling during evaluation')
def launch_training_job(search_range, search_params, param_template, gpu_ids, model_dir, args):
    '''Launch training of the model with one hyperparameter combination.

    Runs inside a multiprocessing.Pool worker: decodes the index combination
    into concrete parameter values, writes a per-job params.json under
    model_dir, and launches train.py as a subprocess pinned to one GPU.

    Args:
        search_range: 1-tuple wrapping one combination of indices into the
            (sorted-key-ordered) candidate-value lists of search_params
        search_params: dict mapping param name -> list of candidate values
        param_template: utils.Params template copied and overridden per job
        gpu_ids: list of GPU ids, one per pool worker
        model_dir: parent directory ('experiments/<model-name>') for all jobs
        args: parsed command-line arguments (dataset, data_dir, flags)
    '''
    search_range = search_range[0]
    # Decode index tuple -> {param: value}; order must match main()'s sorted keys.
    params = {k: search_params[k][search_range[idx]] for idx, k in enumerate(sorted(search_params.keys()))}
    model_param_list = '-'.join('_'.join((k, f'{v:.2f}')) for k, v in params.items())
    # BUG FIX: the skip check previously hard-coded 'experiments/param_search',
    # so it missed existing runs whenever --model-name was not the default.
    # Use model_dir, which is where the jobs are actually written.
    if os.path.exists(os.path.join(model_dir, model_param_list)):
        logger.info('Already exist this train')
    else:
        model_param = copy(param_template)
        for k, v in params.items():
            setattr(model_param, k, v)
        # A fresh Process object created inside a pool worker inherits the
        # worker's identity and appends a per-worker counter, yielding
        # (pool worker id, job index within this worker).
        pool_id, job_idx = multiprocessing.Process()._identity
        gpu_id = gpu_ids[pool_id - 1]  # worker ids are 1-based
        logger.info(f'Worker {pool_id} running {job_idx} using GPU {gpu_id}')
        # Create a new folder in parent_dir with unique_name 'job_name'
        model_name = os.path.join(model_dir, model_param_list)
        model_input = os.path.join(args.model_name, model_param_list)
        if not os.path.exists(model_name):
            os.makedirs(model_name)
        # Write parameters in json file
        json_path = os.path.join(model_name, 'params.json')
        model_param.save(json_path)
        logger.info(f'Params saved to: {json_path}')
        # Launch training with this config
        cmd = f'{PYTHON} train.py ' \
              f'--model-name={model_input} ' \
              f'--dataset={args.dataset} ' \
              f'--data-folder={args.data_dir} ' \
              f'--save-best '
        if args.sampling:
            cmd += ' --sampling'
        if args.relative_metrics:
            cmd += ' --relative-metrics'
        logger.info(cmd)
        # BUG FIX: merge into os.environ instead of replacing it — a bare env
        # dict drops PATH/HOME/etc., which can break the shell (shell=True)
        # and the spawned python interpreter.
        env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu_id), OMP_NUM_THREADS='4')
        check_call(cmd, shell=True, env=env)
def main():
    '''Parse CLI args, build the hyperparameter grid, and fan out one
    training job per combination across a pool of one worker per GPU.'''
    # Load the 'reference' parameters from parent_dir json file
    global param_template, gpu_ids, args, search_params, model_dir
    args = parser.parse_args()
    model_dir = os.path.join('experiments', args.model_name)
    json_file = os.path.join(model_dir, 'params.json')
    # BUG FIX: the old message interpolated args.json, an attribute that does
    # not exist — a missing config raised AttributeError instead of showing
    # the intended assertion message. Report the actual path checked.
    assert os.path.isfile(json_file), f'No json configuration file found at {json_file}'
    param_template = utils.Params(json_file)
    gpu_ids = args.gpu_ids
    logger.info(f'Running on GPU: {gpu_ids}')
    # Perform hypersearch over parameters listed below
    search_params = {
        'lstm_dropout': np.arange(0, 0.201, 0.2).tolist(),
        'lstm_hidden_dim': np.arange(5, 60, 30).tolist(),
        'lam': np.arange(0, 0.3, 0.02).tolist()
    }
    # Cartesian product over index ranges, in the same sorted-key order the
    # worker uses to decode each combination back into parameter values.
    keys = sorted(search_params.keys())
    search_range = list(product(*[range(len(search_params[k])) for k in keys]))
    func_p = partial(launch_training_job, search_params=search_params,
                     param_template=param_template, gpu_ids=gpu_ids,
                     model_dir=model_dir, args=args)
    # One worker per GPU. BUG FIX: close and join the pool so the script
    # waits for all jobs and releases worker processes (it was leaked before).
    pool = multiprocessing.Pool(len(gpu_ids))
    try:
        pool.map(func_p, [(i,) for i in search_range])
    finally:
        pool.close()
        pool.join()


if __name__ == '__main__':
    main()