-
Notifications
You must be signed in to change notification settings - Fork 1
/
config_args.py
279 lines (254 loc) · 11.4 KB
/
config_args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
import argparse, os, time, func_timeout
from shutil import copyfile, copytree, ignore_patterns
from UTIL.colorful import *
def secure_chained_vars(default_cfg, new_cfg, vb):
    """Recompute chained ("twin") config values whose inputs were overridden.

    For every attribute ``<name>_cv`` of ``default_cfg`` (a chain descriptor
    exposing ``chained_with`` and ``chain_func``), the attribute ``<name>`` is
    recomputed as ``chain_func(*[getattr(default_cfg, v) for v in chained_with])``
    when at least one name in ``chained_with`` appears in ``new_cfg`` and
    ``<name>`` itself was not explicitly overridden in ``new_cfg``.

    Args:
        default_cfg: config class whose attributes are read and rewritten in place.
        new_cfg: overrides that were applied; only key membership is tested.
        vb: if truthy, print a warning for each recomputed value.

    Returns:
        list of the attribute names (without ``_cv``) that were recomputed.

    Raises:
        AssertionError: if a ``<name>_cv`` exists without a ``<name>`` attribute.
    """
    suffix = '_cv'
    altered_cv = []
    # Iterate over a snapshot of the keys: when <name> is inherited from a base
    # class, setattr below adds a new entry to this class's __dict__, which
    # would raise "dictionary changed size during iteration" on the live view.
    for key in list(default_cfg.__dict__):
        if not key.endswith(suffix):
            continue
        # Strip only the trailing suffix; str.replace would also remove any
        # '_cv' occurring inside the name.
        o_key = key[:-len(suffix)]
        if o_key in new_cfg:
            # explicitly overridden by the caller: never clobber it
            continue
        assert hasattr(default_cfg, o_key), ('twin var does not have original')
        chain_var = getattr(default_cfg, key)
        need_refresh = any(dep in new_cfg for dep in chain_var.chained_with)
        if not need_refresh:
            continue
        replace_item = chain_var.chain_func(*[getattr(default_cfg, v) for v in chain_var.chained_with])
        original_item = getattr(default_cfg, o_key)
        if vb:
            print靛('[config] warning, %s is chained by %s, automatic modifying:'%(o_key,
                str(chain_var.chained_with)), original_item, '-->', replace_item)
        setattr(default_cfg, o_key, replace_item)
        altered_cv.append(o_key)
    return altered_cv
# Example arguments for override_config_file:
#   cfg_group: 'config.py->GlobalConfig'
#   new_cfg: {'note': 'pymarl-compat-T3', 'train_time_testing': 'False', 'heartbeat_on': 'False', 'env_name': 'sc2', 'env_path': 'MISSIONS.starcraft.s...nv_wrapper', 'draw_mode': 'Img', 'num_threads': '2', 'report_reward_interval': '2', 'device': 'cuda:3', 'max_n_episode': 10000, 'fold': '1', 'seed': 9995, 'backup_files': []}
#   vb: True
def override_config_file(cfg_group, new_cfg, vb):
    """Apply the overrides in ``new_cfg`` to the config class named by ``cfg_group``.

    ``cfg_group`` has the form ``'<module path>[.py]-><ClassName>'`` (for
    example ``'config.py->GlobalConfig'``); a trailing ``.py`` is dropped
    before the module is imported. Keys whose value is ``None`` are skipped;
    every other key is written through ``my_setattr``, after which chained
    (``*_cv``) values are refreshed via ``secure_chained_vars``.

    Args:
        cfg_group: ``'<module>-><Class>'`` selector.
        new_cfg: mapping of setting name -> new value.
        vb: if truthy, print a banner and a summary of all settings.

    Returns:
        ``[team.split('->')[0] for team in new_cfg['TEAM_NAMES']]`` when
        ``new_cfg`` has a ``'TEAM_NAMES'`` key, otherwise ``None``.
    """
    import importlib
    assert '->' in cfg_group
    banner = '------------- %s -------------' % cfg_group
    if vb:
        print绿(banner)
    module_name, class_name = cfg_group.split('->')
    # equivalent to str.removesuffix('.py') on python>=3.9
    if module_name.endswith('.py'):
        module_name = module_name[:-3]
    target_class = getattr(importlib.import_module(module_name), class_name)
    for key, value in new_cfg.items():
        if value is not None:
            my_setattr(conf_class=target_class, key=key, new_value=value, vb=vb)
    altered_cv = secure_chained_vars(target_class, new_cfg, vb)
    if vb:
        rule = '-' * len(banner)
        print绿(rule)
        arg_summary(target_class, new_cfg, altered_cv)
        print绿(rule, '\n\n\n')
    if 'TEAM_NAMES' not in new_cfg:
        return None
    return [team.split('->')[0] for team in new_cfg['TEAM_NAMES']]
def check_config_relevence(json_data):
    """Assert that every MISSIONS config group belongs to the selected env.

    Reads ``env_name`` and ``env_path`` from the ``'config.py->GlobalConfig'``
    group (both must exist; only ``env_path`` is compared), then asserts that
    each group key containing ``'MISSIONS'`` also contains ``env_path``.

    Raises:
        KeyError: if the core group or one of the two env keys is missing.
        AssertionError: if a MISSIONS group does not match ``env_path``.
    """
    core_group = json_data['config.py->GlobalConfig']
    _ = core_group['env_name']  # looked up only so a missing key still raises
    env_path = core_group['env_path']
    mission_keys = (k for k in json_data.keys() if 'MISSIONS' in k)
    for key in mission_keys:
        assert env_path in key, ('configering wrong env!')
def load_config_via_json(json_data, vb):
    """Apply every config group of a parsed experiment json, in file order.

    Each key of ``json_data`` is a ``'<module>-><Class>'`` group and its value
    is the override dict passed to ``override_config_file``. The list that
    call may return (team names) is not checked here; the former dependency
    check and ``check_config_relevence`` call were disabled.

    Returns:
        Always ``None``.
    """
    for group_name, overrides in json_data.items():
        override_config_file(group_name, overrides, vb)
    return None
def get_core_args(vb=True):
    """Apply only the core ``config.py->GlobalConfig`` group of a json config.

    Parses ``-c/--cfg`` (and ``-s/--skip``, accepted but unused here) from
    ``sys.argv`` with ``parse_known_args`` and applies the
    ``'config.py->GlobalConfig'`` section of the json file through
    ``override_config_file``. Other groups in the file are ignored.

    Args:
        vb: if truthy, warn about ignored extra arguments and print summaries.

    Returns:
        ``None``.

    Raises:
        AssertionError: if ``--cfg`` was not given (command-line mode is not
            supported by this function).
    """
    parser = argparse.ArgumentParser(description='HMP')
    parser.add_argument('-c', '--cfg', help='Path of the configuration file')
    parser.add_argument('-s', '--skip', action='store_true', help='skip logdir check')
    args, unknown = parser.parse_known_args()
    if args.cfg is None:
        # An explicit raise instead of `assert False`: it is not stripped by
        # `python -O`, so the function can no longer fall through to the
        # previously unreachable command-line loader.
        raise AssertionError('get_core_args requires a json config: pass -c/--cfg <file>')
    if len(unknown) > 0 and vb:
        print亮红('Warning! In json setting mode, %s is ignored'%str(unknown))
    import commentjson as json
    with open(args.cfg) as f:
        json_data = json.load(f)
    core_group = 'config.py->GlobalConfig'
    override_config_file(core_group, json_data[core_group], vb)
def prepare_args(vb=True):
    """Load, apply and persist all settings; returns the ready GlobalConfig.

    Reads ``-c/--cfg`` and ``-s/--skip`` from ``sys.argv``. With ``--cfg`` every
    group of the json file is applied (``load_config_via_json``); otherwise the
    command line is applied to ``config.GlobalConfig``
    (``load_config_via_cmdline``). Afterwards, in this order: unless ``--skip``
    is given, prompt if ``cfg.logdir`` already exists (a typed answer becomes
    the new ``note``); create ``cfg.logdir``; copy the json file there as
    ``experiment.json``; back up ``cfg.backup_files``; write machine info.

    Args:
        vb: verbose; print override summaries and warnings.

    Returns:
        The ``config.GlobalConfig`` class with ``machine_info`` set and
        ``cfg_ready = True``.
    """
    parser = argparse.ArgumentParser(description='HMP')
    parser.add_argument('-c', '--cfg', help='Path of the configuration file')
    parser.add_argument('-s', '--skip', action='store_true', help='skip logdir check')
    # parse_known_args: in json mode extra flags are only reported as ignored;
    # in cmdline mode load_config_via_cmdline parses sys.argv with its own parser
    args, unknown = parser.parse_known_args()
    load_via_json = (hasattr(args, 'cfg') and args.cfg is not None)
    skip_logdir_check = (hasattr(args, 'skip') and (args.skip is not None) and args.skip)
    if load_via_json:
        if len(unknown) > 0 and vb:
            print亮红('Warning! In json setting mode, %s is ignored'%str(unknown))
        import commentjson as json
        with open(args.cfg) as f:
            json_data = json.load(f)
        new_args = load_config_via_json(json_data, vb)  # return value not used below
    else:
        new_args = load_config_via_cmdline(vb)  # return value not used below
    # Same class object that override_config_file modified (module 'config'),
    # so it already carries the overrides applied above.
    from config import GlobalConfig as cfg
    note_name_overide = None
    if not skip_logdir_check:
        # None if logdir is new / is a test_stage dir / input was empty or timed out;
        # otherwise the text typed at the prompt
        note_name_overide = check_experiment_log_path(cfg.logdir)
        if note_name_overide is not None:
            # NOTE(review): logdir presumably depends on `note` through a chained
            # (*_cv) setting, which is why cfg.logdir is read again below — confirm
            override_config_file('config.py->GlobalConfig', {'note':note_name_overide}, vb)
    if not os.path.exists(cfg.logdir): os.makedirs(cfg.logdir)
    # NOTE(review): recall_previous_session presumably means "resume an earlier
    # run", in which case the stored experiment.json is kept — confirm
    if load_via_json and (not cfg.recall_previous_session):
        copyfile(args.cfg, '%s/experiment.json'%cfg.logdir)
    backup_files(cfg.backup_files, cfg.logdir)
    cfg.machine_info = register_machine_info(cfg.logdir)
    cfg.cfg_ready = True
    return cfg
def register_machine_info(logdir):
    """Collect host/run metadata and write it to ``<logdir>/info.json``.

    Args:
        logdir: existing directory that receives ``info.json``.

    Returns:
        dict with keys ``HostIP`` (from ``.network.get_host_ip``), ``ExpUUID``
        (``uuid1`` hex), ``RunPath`` (current working directory),
        ``StartDateTime`` (local time, ``%Y-%m-%d %H:%M:%S``) and
        ``DockerContainerHash`` (output of a shell pipeline over
        ``/proc/self/cgroup``, presumably empty outside docker; ``'None'`` if
        running the command raised).
    """
    import json
    import subprocess
    import uuid
    from .network import get_host_ip
    info = {
        'HostIP': get_host_ip(),
        'ExpUUID': uuid.uuid1().hex,
        'RunPath': os.getcwd(),
        'StartDateTime': time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
    }
    try:
        info['DockerContainerHash'] = subprocess.getoutput(r'cat /proc/self/cgroup | grep -o -e "docker/.*"| head -n 1 |sed "s/docker\\/\\(.*\\)/\\1/" |cut -c1-12')
    except Exception:
        # Best effort: metadata collection must not abort a run, but a bare
        # `except:` would also swallow KeyboardInterrupt / SystemExit.
        info['DockerContainerHash'] = 'None'
    with open('%s/info.json'%logdir, 'w') as f:
        json.dump(info, f, indent=4)
    return info
def backup_files(files, logdir):
    """Copy each file or folder in ``files`` into ``<logdir>/backup_files/``.

    Files are copied by basename; folders are copied recursively, merging into
    an existing copy (``dirs_exist_ok=True``) and skipping ``__pycache__``.

    Args:
        files: iterable of file or folder paths.
        logdir: experiment log directory.

    Raises:
        AssertionError: if an entry is neither an existing file nor a folder.
    """
    bkdir = os.path.join(logdir, 'backup_files')
    for file in files:
        # normpath drops a trailing separator, so 'some/folder/' yields the
        # basename 'folder' instead of '' (which would copy the folder's
        # content straight into backup_files/ itself)
        name = os.path.basename(os.path.normpath(file))
        if os.path.isfile(file):
            print绿('[config] Backup File:',file)
            os.makedirs(bkdir, exist_ok=True)
            copyfile(file, os.path.join(bkdir, name))
        else:
            print亮绿('[config] Backup Folder:',file)
            assert os.path.isdir(file), ('cannot find', file)
            copytree(file, os.path.join(bkdir, name),
                     dirs_exist_ok=True, ignore=ignore_patterns("__pycache__"))
    return
def check_experiment_log_path(logdir):
    """Warn before reusing an existing log directory; optionally read a new note.

    Returns ``None`` immediately when ``logdir`` does not exist or contains a
    ``test_stage`` entry. Otherwise it prints a warning and waits for input via
    ``askChoice`` (limited to 60 s by its decorator).

    Args:
        logdir: experiment log directory, with or without a trailing separator.

    Returns:
        The non-empty line typed by the user (``prepare_args`` uses it as the
        new ``note``), or ``None`` on empty input or timeout.
    """
    if not os.path.exists(logdir):
        return None
    # os.path.join works whether or not logdir ends with a separator; plain
    # concatenation looked for '<logdir>test_stage' when it did not.
    if os.path.exists(os.path.join(logdir, 'test_stage')):
        return None
    print亮红('Current log path:', logdir)
    print亮红('Warning! you will overwrite old logs if continue!')
    print亮红("Pause for 60 seconds ( Press Enter to confirm! )")
    try:
        res = askChoice()
    except func_timeout.exceptions.FunctionTimedOut:
        return None
    return res if res != '' else None
@func_timeout.func_set_timeout(60)
def askChoice():
    """Read one line from stdin after a '>>' prompt.

    The func_timeout decorator limits the wait to 60 seconds; on timeout
    ``func_timeout.exceptions.FunctionTimedOut`` propagates to the caller
    (``check_experiment_log_path`` catches it).
    """
    return input('>>')
# def input_or_timeout(timeout):
# for i in range(30):
# time.sleep(1)
# return
def arg_summary(config_class, modify_dict=None, altered_cv=None):
    """Print every public setting of ``config_class``, colored by its status.

    Dunder names and ``*_cv`` chain descriptors are skipped. Keys explicitly
    set in ``modify_dict`` (non-None value) are printed with ``print红`` (red),
    keys listed in ``altered_cv`` with ``print靛`` (indigo), all others with
    ``print绿`` (green).

    Args:
        config_class: class whose ``__dict__`` is listed.
        modify_dict: overrides that were applied; ``None`` means none.
        altered_cv: names recomputed by ``secure_chained_vars``; ``None`` means none.
    """
    # None sentinels instead of mutable default arguments shared across calls
    if modify_dict is None:
        modify_dict = {}
    if altered_cv is None:
        altered_cv = []
    for key in config_class.__dict__:
        if '__' in key:
            continue
        if key.endswith('_cv'):
            continue
        if (key not in modify_dict) or (modify_dict[key] is None):
            if key not in altered_cv:
                print绿(key.center(25), '-->', str(getattr(config_class,key)))
            else:
                print靛(key.center(25), '-->', str(getattr(config_class,key)))
        else:
            print红(key.center(25), '-->', str(getattr(config_class,key)))
def load_config_via_cmdline(vb):
    """Expose each public ``config.GlobalConfig`` setting as ``--<name>`` and apply it.

    Dunder names and ``*_cv`` chain descriptors get no flag. The parsed values
    (all strings; ``None`` = not given) are applied via ``override_config_file``.

    Args:
        vb: if truthy, report duplicate flags and print a settings summary.

    Returns:
        dict of the settings given on the command line (``None`` values dropped).
    """
    parser = argparse.ArgumentParser(description='HMP')
    from config import GlobalConfig as cfg
    for setting_name in cfg.__dict__:
        if '__' in setting_name:
            continue
        if setting_name.endswith('_cv'):
            continue
        try:
            parser.add_argument('--' + setting_name)
        except argparse.ArgumentError:
            # conflicting option string (e.g. a setting named 'help' clashes with
            # argparse's built-in --help): warn ("duplicate argument") and skip it
            if vb: print红('[config] 参数重复! ', setting_name)
    parsed = vars(parser.parse_args())
    args = {key: value for key, value in parsed.items() if value is not None}
    override_config_file('config.py->GlobalConfig', args, vb)
    if vb: arg_summary(cfg, args)
    return args
def my_setattr(conf_class, key, new_value, vb):
    """Set ``conf_class.<key>``, coercing ``new_value`` to the current value's type.

    Coercion is chosen by the type of the existing attribute:
      * float -> ``float(new_value)``
      * bool  -> ``'True'`` / ``'False'`` strings or an actual bool
      * int   -> ``int(new_value)``
      * str   -> stored unchanged
      * list / dict -> ``new_value`` must already be of that type
    Any other current type (including ``None``) is rejected.

    Raises:
        AssertionError: unknown key, bad bool literal, list/dict mismatch, or
            unsupported current type.
        ValueError / TypeError: from ``float()`` / ``int()`` on bad input.
    """
    assert hasattr(conf_class, key), (conf_class, 'has no such config item: **%s**'%key)
    current = getattr(conf_class, key)
    if vb: print绿('[config] override %s:'%key, current, '-->', new_value)
    value = new_value
    # bool must be tested before int because bool is a subclass of int
    if isinstance(current, float):
        value = float(value)
    elif isinstance(current, bool):
        if value == 'True' or value == 'False':
            value = (value == 'True')
        else:
            assert isinstance(value, bool), ('enter True or False, but have:', value)
    elif isinstance(current, int):
        value = int(value)
    elif isinstance(current, str):
        pass
    elif isinstance(current, list):
        assert isinstance(value, list)
    elif isinstance(current, dict):
        assert isinstance(value, dict)
    else:
        assert False, ('not support this type')
    setattr(conf_class, key, value)
def find_all_conf():
    """Discover and import the config classes tagged for the conf system.

    Every ``*.py`` file below the current directory is scanned; a source line
    qualifies when it carries the ADD_TO_CONF_SYSTEM marker and also
    defines a class (the two checks sit on separate lines of this file so
    the scan never matches its own source).

    Returns:
        list of dicts with keys ``line`` (the matching source line), ``file``
        (dotted path keeping ``.py``, e.g. ``'MISSIONS.foo.conf.py'``),
        ``class_name`` and ``class`` (the imported class object).
    """
    import glob
    import importlib
    import re
    marker = 'ADD_TO_CONF_SYSTEM'
    # tolerant of "Name(" and "Name:" forms and of extra whitespace
    name_pattern = re.compile(r'class\s+([A-Za-z_]\w*)')
    conf_class_gather = []
    for python_file in glob.glob('./**/*.py', recursive=True):
        with open(python_file, encoding='UTF-8') as f:
            lines = f.readlines()
        for line in lines:
            if marker not in line:
                continue
            if 'class ' not in line:
                continue
            match = name_pattern.search(line)
            if match is None:
                continue
            conf_class_gather.append({'line': line, 'file': python_file, 'class_name': match.group(1)})
    for target in conf_class_gather:
        # './a/b.py' -> 'a.b.py'; normpath/os.sep also handle Windows separators
        dotted = '.'.join(os.path.normpath(target['file']).split(os.sep))
        target['file'] = dotted
        # Strip only the trailing '.py'; str.replace would also mangle package
        # names such as 'pymarl' ('a.pymarl.x.py' -> 'amarl.x').
        module_name = dotted[:-len('.py')]
        target['class'] = getattr(importlib.import_module(module_name), target['class_name'])
    return conf_class_gather
def make_json(conf_list):
    """Dump the public settings of every discovered config class to ``all_conf.json``.

    Args:
        conf_list: entries as produced by ``find_all_conf``; the ``class`` and
            ``file`` keys are used (``file`` becomes the top-level json key).

    Values that ``json`` cannot encode are stored as the string
    ``'[cannot be json]' + str(value)``. The file is written to the current
    directory. Returns ``None``.
    """
    import json
    out = {}
    for conf in conf_list:
        local_conf = {}
        config_class = conf['class']
        for key in config_class.__dict__:
            # same filter as arg_summary / load_config_via_cmdline: only names
            # ending in '_cv' are chain descriptors, not any name containing it
            if '__' in key or key.endswith('_cv'):
                continue
            item_to_be_serialize = getattr(config_class, key)
            try:
                json.dumps(item_to_be_serialize)
            except Exception:
                # best effort for any unencodable value, without swallowing
                # KeyboardInterrupt / SystemExit like a bare `except:`
                item_to_be_serialize = '[cannot be json]' + str(item_to_be_serialize)
            local_conf[key] = item_to_be_serialize
        out[conf['file']] = local_conf
    with open('all_conf.json', 'w') as f:
        json.dump(out, f, indent=4)
    print亮紫('the conf summary is successfully saved to all_conf.json')
if __name__ == '__main__':
    # Standalone use: collect every tagged config class and dump them to all_conf.json.
    make_json(find_all_conf())